Spaces:
Sleeping
Sleeping
supercat666
commited on
Commit
•
242350b
1
Parent(s):
2b3514d
fix
Browse files
app.py
CHANGED
@@ -144,13 +144,20 @@ if selected_model == 'Cas9':
|
|
144 |
# Prediction button
|
145 |
predict_button = st.button('Predict on-target')
|
146 |
|
|
|
|
|
|
|
|
|
|
|
147 |
# Process predictions
|
148 |
if predict_button and gene_symbol:
|
149 |
with st.spinner('Predicting... Please wait'):
|
150 |
-
predictions, gene_sequence = cas9on.process_gene(gene_symbol, cas9on_path)
|
151 |
sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
|
152 |
st.session_state['on_target_results'] = sorted_predictions
|
153 |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
|
|
|
|
|
154 |
|
155 |
# Notify the user once the process is completed successfully.
|
156 |
st.success('Prediction completed!')
|
@@ -162,44 +169,64 @@ if selected_model == 'Cas9':
|
|
162 |
df = pd.DataFrame(st.session_state['on_target_results'],
|
163 |
columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])
|
164 |
st.dataframe(df)
|
165 |
-
# Now create a Plotly plot with the sorted_predictions
|
|
|
166 |
fig = go.Figure()
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
# Initialize the y position for the positive and negative strands
|
169 |
positive_strand_y = 0.1
|
170 |
negative_strand_y = -0.1
|
171 |
-
|
172 |
-
# Use an offset to spread gRNA sequences vertically
|
173 |
-
offset = 0.05
|
174 |
|
175 |
# Iterate over the sorted predictions to create the plot
|
176 |
-
for i, prediction in enumerate(
|
177 |
-
# Extract data for plotting and convert start and end to integers
|
178 |
chrom, start, end, strand, target, gRNA, pred_score = prediction
|
179 |
start, end = int(start), int(end)
|
180 |
midpoint = (start + end) / 2
|
181 |
|
182 |
-
|
183 |
-
if strand == '1':
|
184 |
y_value = positive_strand_y
|
185 |
arrow_symbol = 'triangle-right'
|
186 |
-
# Increment the y-value for the next positive strand gRNA
|
187 |
positive_strand_y += offset
|
188 |
-
else:
|
189 |
y_value = negative_strand_y
|
190 |
arrow_symbol = 'triangle-left'
|
191 |
-
# Decrement the y-value for the next negative strand gRNA
|
192 |
negative_strand_y -= offset
|
193 |
|
194 |
fig.add_trace(go.Scatter(
|
195 |
x=[midpoint],
|
196 |
-
y=[y_value],
|
197 |
mode='markers+text',
|
198 |
marker=dict(symbol=arrow_symbol, size=10),
|
199 |
name=f"gRNA: {gRNA}",
|
200 |
-
text=f"Rank: {i}",
|
201 |
hoverinfo='text',
|
202 |
-
hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == 1 else '-'}<br>Prediction Score: {pred_score:.4f}",
|
203 |
))
|
204 |
|
205 |
# Update the layout of the plot
|
@@ -208,14 +235,12 @@ if selected_model == 'Cas9':
|
|
208 |
xaxis_title='Genomic Position',
|
209 |
yaxis=dict(
|
210 |
title='Strand',
|
211 |
-
showgrid=True,
|
212 |
-
zeroline=
|
213 |
-
|
214 |
-
|
215 |
-
tickvals=[positive_strand_y, negative_strand_y],
|
216 |
-
ticktext=['+ Strand', '- Strand']
|
217 |
),
|
218 |
-
showlegend=
|
219 |
)
|
220 |
|
221 |
# Display the plot
|
|
|
144 |
# Prediction button
|
145 |
predict_button = st.button('Predict on-target')
|
146 |
|
147 |
+
if 'exons' not in st.session_state:
|
148 |
+
st.session_state['exons'] = []
|
149 |
+
if 'cds' not in st.session_state:
|
150 |
+
st.session_state['cds'] = []
|
151 |
+
|
152 |
# Process predictions
|
153 |
if predict_button and gene_symbol:
|
154 |
with st.spinner('Predicting... Please wait'):
|
155 |
+
predictions, gene_sequence, exons, cds = cas9on.process_gene(gene_symbol, cas9on_path)
|
156 |
sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
|
157 |
st.session_state['on_target_results'] = sorted_predictions
|
158 |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
|
159 |
+
st.session_state['exons'] = exons # Store exon data
|
160 |
+
st.session_state['cds'] = cds # Store CDS data
|
161 |
|
162 |
# Notify the user once the process is completed successfully.
|
163 |
st.success('Prediction completed!')
|
|
|
169 |
df = pd.DataFrame(st.session_state['on_target_results'],
|
170 |
columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])
|
171 |
st.dataframe(df)
|
172 |
+
# Now create a Plotly plot with the sorted_predictions# Initialize Plotly figure
|
173 |
+
# Initialize Plotly figure
|
174 |
fig = go.Figure()
|
175 |
|
176 |
+
# Plot Exons as horizontal lines or rectangles
|
177 |
+
exon_y = 0.2 # Adjust this as needed
|
178 |
+
for exon in st.session_state['exons']:
|
179 |
+
exon_start, exon_end = int(exon['start']), int(exon['end'])
|
180 |
+
fig.add_trace(go.Scatter(
|
181 |
+
x=[exon_start, exon_end],
|
182 |
+
y=[exon_y, exon_y],
|
183 |
+
mode='lines',
|
184 |
+
line=dict(color='purple', width=10), # Adjust styling as needed
|
185 |
+
name='Exon'
|
186 |
+
))
|
187 |
+
|
188 |
+
# Plot CDS as horizontal lines or rectangles
|
189 |
+
cds_y = 0.3 # Adjust this as needed
|
190 |
+
for cds in st.session_state['cds']:
|
191 |
+
cds_start, cds_end = int(cds['start']), int(cds['end'])
|
192 |
+
fig.add_trace(go.Scatter(
|
193 |
+
x=[cds_start, cds_end],
|
194 |
+
y=[cds_y, cds_y],
|
195 |
+
mode='lines',
|
196 |
+
line=dict(color='blue', width=10), # Adjust styling as needed
|
197 |
+
name='CDS'
|
198 |
+
))
|
199 |
+
|
200 |
+
# Plot gRNAs using triangles to indicate direction
|
201 |
# Initialize the y position for the positive and negative strands
|
202 |
positive_strand_y = 0.1
|
203 |
negative_strand_y = -0.1
|
204 |
+
offset = 0.05 # Use an offset to spread gRNA sequences vertically
|
|
|
|
|
205 |
|
206 |
# Iterate over the sorted predictions to create the plot
|
207 |
+
for i, prediction in enumerate(st.session_state['on_target_results'], start=1):
|
|
|
208 |
chrom, start, end, strand, target, gRNA, pred_score = prediction
|
209 |
start, end = int(start), int(end)
|
210 |
midpoint = (start + end) / 2
|
211 |
|
212 |
+
if strand == '1': # Positive strand
|
|
|
213 |
y_value = positive_strand_y
|
214 |
arrow_symbol = 'triangle-right'
|
|
|
215 |
positive_strand_y += offset
|
216 |
+
else: # Negative strand
|
217 |
y_value = negative_strand_y
|
218 |
arrow_symbol = 'triangle-left'
|
|
|
219 |
negative_strand_y -= offset
|
220 |
|
221 |
fig.add_trace(go.Scatter(
|
222 |
x=[midpoint],
|
223 |
+
y=[y_value],
|
224 |
mode='markers+text',
|
225 |
marker=dict(symbol=arrow_symbol, size=10),
|
226 |
name=f"gRNA: {gRNA}",
|
227 |
+
text=f"Rank: {i}",
|
228 |
hoverinfo='text',
|
229 |
+
hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == '1' else '-'}<br>Prediction Score: {pred_score:.4f}",
|
230 |
))
|
231 |
|
232 |
# Update the layout of the plot
|
|
|
235 |
xaxis_title='Genomic Position',
|
236 |
yaxis=dict(
|
237 |
title='Strand',
|
238 |
+
showgrid=True,
|
239 |
+
zeroline=False,
|
240 |
+
tickvals=[positive_strand_y, negative_strand_y, exon_y, cds_y],
|
241 |
+
ticktext=['+ Strand gRNAs', '- Strand gRNAs', 'Exons', 'CDS']
|
|
|
|
|
242 |
),
|
243 |
+
showlegend=True
|
244 |
)
|
245 |
|
246 |
# Display the plot
|
cas9on.py
CHANGED
@@ -104,6 +104,7 @@ def find_crispr_targets(sequence, chr, start, strand, pam="NGG", target_length=2
|
|
104 |
|
105 |
return targets
|
106 |
|
|
|
107 |
def process_gene(gene_symbol, model_path):
|
108 |
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
109 |
all_data = []
|
@@ -118,14 +119,41 @@ def process_gene(gene_symbol, model_path):
|
|
118 |
# Fetch the sequence here and concatenate if multiple transcripts
|
119 |
gene_sequence += fetch_ensembl_sequence(transcript_id) or ''
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
if gene_sequence:
|
122 |
gRNA_sites = find_crispr_targets(gene_sequence, chr, start, strand)
|
123 |
if gRNA_sites:
|
124 |
formatted_data = format_prediction_output(gRNA_sites, model_path)
|
125 |
all_data.extend(formatted_data)
|
126 |
|
127 |
-
# Return
|
128 |
-
return all_data, gene_sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
def create_genbank_features(formatted_data):
|
131 |
features = []
|
|
|
104 |
|
105 |
return targets
|
106 |
|
107 |
+
|
108 |
def process_gene(gene_symbol, model_path):
|
109 |
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
110 |
all_data = []
|
|
|
119 |
# Fetch the sequence here and concatenate if multiple transcripts
|
120 |
gene_sequence += fetch_ensembl_sequence(transcript_id) or ''
|
121 |
|
122 |
+
# Fetch exon and CDS information
|
123 |
+
exons = fetch_ensembl_exons(transcript_id)
|
124 |
+
cds_list = fetch_ensembl_cds(transcript_id)
|
125 |
+
|
126 |
+
# You might want to do something specific with exons and CDS information here
|
127 |
+
# For example, store them, print them, or include them in your analysis
|
128 |
+
|
129 |
if gene_sequence:
|
130 |
gRNA_sites = find_crispr_targets(gene_sequence, chr, start, strand)
|
131 |
if gRNA_sites:
|
132 |
formatted_data = format_prediction_output(gRNA_sites, model_path)
|
133 |
all_data.extend(formatted_data)
|
134 |
|
135 |
+
# Return the data, fetched sequence, and possibly exon/CDS data
|
136 |
+
return all_data, gene_sequence, exons, cds_list
|
137 |
+
|
138 |
+
def fetch_ensembl_exons(transcript_id):
|
139 |
+
"""Fetch exon information for a given transcript from Ensembl."""
|
140 |
+
url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=exon;content-type=application/json"
|
141 |
+
response = requests.get(url)
|
142 |
+
if response.status_code == 200:
|
143 |
+
return response.json() # Returns a list of exons for the transcript
|
144 |
+
else:
|
145 |
+
print(f"Error fetching exon data from Ensembl for transcript {transcript_id}: {response.text}")
|
146 |
+
return None
|
147 |
+
|
148 |
+
def fetch_ensembl_cds(transcript_id):
|
149 |
+
"""Fetch coding sequence (CDS) information for a given transcript from Ensembl."""
|
150 |
+
url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=cds;content-type=application/json"
|
151 |
+
response = requests.get(url)
|
152 |
+
if response.status_code == 200:
|
153 |
+
return response.json() # Returns a list of CDS regions for the transcript
|
154 |
+
else:
|
155 |
+
print(f"Error fetching CDS data from Ensembl for transcript {transcript_id}: {response.text}")
|
156 |
+
return None
|
157 |
|
158 |
def create_genbank_features(formatted_data):
|
159 |
features = []
|