NiniCat commited on
Commit
83a2e73
1 Parent(s): 4d05090

add enzyme buttons

Browse files
Files changed (1) hide show
  1. app.py +199 -205
app.py CHANGED
@@ -15,16 +15,16 @@ selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='select
15
  # Check if the selected model is Cas9
16
  if selected_model == 'Cas9':
17
  # Display buttons for the Cas9 model
18
- if st.button('SPCas9_U6'):
19
  # Placeholder for action when SPCas9_U6 is clicked
20
  pass
21
- if st.button('SPCas9_t7'):
22
  # Placeholder for action when SPCas9_t7 is clicked
23
  pass
24
- if st.button('eSPCas9'):
25
  # Placeholder for action when eSPCas9 is clicked
26
  pass
27
- if st.button('SPCas9_HF1'):
28
  # Placeholder for action when SPCas9_HF1 is clicked
29
  pass
30
  elif selected_model == 'Cas12':
@@ -32,210 +32,204 @@ elif selected_model == 'Cas12':
32
  # TODO: Implement Cas12 model loading logic
33
  raise NotImplementedError("Cas12 model loading not implemented yet.")
34
  elif selected_model == 'Cas13d':
35
- # Assuming tiger module is for Cas13
36
- tiger.load_model() # Assuming there's a load_model function in tiger.py
37
- else:
38
- raise ValueError(f"Unknown model: {model_name}")
39
-
40
-
41
- @st.cache_data
42
- def convert_df(df):
43
- # IMPORTANT: Cache the conversion to prevent computation on every rerun
44
- return df.to_csv().encode('utf-8')
45
-
46
-
47
- def mode_change_callback():
48
- if st.session_state.mode in {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}: # TODO: support titration
49
- st.session_state.check_off_targets = False
50
- st.session_state.disable_off_target_checkbox = True
51
- else:
52
- st.session_state.disable_off_target_checkbox = False
53
-
54
-
55
- def progress_update(update_text, percent_complete):
56
- with progress.container():
57
- st.write(update_text)
58
- st.progress(percent_complete / 100)
59
-
60
-
61
- def initiate_run():
62
-
63
-
64
- # Placeholder for dynamic module import based on selected_model
65
- # model_module = get_model_module(selected_model)
66
- # You will need to implement get_model_module function to import the correct module (cas9, cas12, cas13)
67
-
68
- # ... rest of the initiate_run function ...
69
- # initialize state variables
70
- st.session_state.transcripts = None
71
- st.session_state.input_error = None
72
- st.session_state.on_target = None
73
- st.session_state.titration = None
74
- st.session_state.off_target = None
75
-
76
- # initialize transcript DataFrame
77
- transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])
78
-
79
- # manual entry
80
- if st.session_state.entry_method == ENTRY_METHODS['manual']:
81
- transcripts = pd.DataFrame({
82
- tiger.ID_COL: ['ManualEntry'],
83
- tiger.SEQ_COL: [st.session_state.manual_entry]
84
- }).set_index(tiger.ID_COL)
85
-
86
- # fasta file upload
87
- elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
88
- if st.session_state.fasta_entry is not None:
89
- fasta_path = st.session_state.fasta_entry.name
90
- with open(fasta_path, 'w') as f:
91
- f.write(st.session_state.fasta_entry.getvalue().decode('utf-8'))
92
- transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)
93
- os.remove(fasta_path)
94
-
95
- # convert to upper case as used by tokenizer
96
- transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))
97
-
98
- # ensure all transcripts have unique identifiers
99
- if transcripts.index.has_duplicates:
100
- st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"
101
-
102
- # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N
103
- elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
104
- st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'
105
-
106
- # ensure all transcripts satisfy length requirements
107
- elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
108
- st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)
109
-
110
- # run model if we have any transcripts
111
- elif len(transcripts) > 0:
112
- st.session_state.transcripts = transcripts
113
-
114
-
115
- if __name__ == '__main__':
116
-
117
- # app initialization
118
- if 'mode' not in st.session_state:
119
- st.session_state.mode = tiger.RUN_MODES['all']
120
- st.session_state.disable_off_target_checkbox = True
121
- if 'entry_method' not in st.session_state:
122
- st.session_state.entry_method = ENTRY_METHODS['manual']
123
- if 'transcripts' not in st.session_state:
124
- st.session_state.transcripts = None
125
- if 'input_error' not in st.session_state:
126
- st.session_state.input_error = None
127
- if 'on_target' not in st.session_state:
128
- st.session_state.on_target = None
129
- if 'titration' not in st.session_state:
130
- st.session_state.titration = None
131
- if 'off_target' not in st.session_state:
132
- st.session_state.off_target = None
133
-
134
- # title and documentation
135
- st.markdown(Path('tiger.md').read_text(), unsafe_allow_html=True)
136
- st.divider()
137
-
138
- # mode selection
139
- col1, col2 = st.columns([0.65, 0.35])
140
- with col1:
141
- st.radio(
142
- label='What do you want to predict?',
143
- options=tuple(tiger.RUN_MODES.values()),
144
- key='mode',
145
- on_change=mode_change_callback,
146
- disabled=st.session_state.transcripts is not None,
147
- )
148
- with col2:
149
- st.checkbox(
150
- label='Find off-target effects (slow)',
151
- key='check_off_targets',
152
- disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
153
  )
 
 
 
 
154
 
155
- # transcript entry
156
- st.selectbox(
157
- label='How would you like to provide transcript(s) of interest?',
158
- options=ENTRY_METHODS.values(),
159
- key='entry_method',
160
- disabled=st.session_state.transcripts is not None
161
- )
162
- if st.session_state.entry_method == ENTRY_METHODS['manual']:
163
- st.text_input(
164
- label='Enter a target transcript:',
165
- key='manual_entry',
166
- placeholder='Upper or lower case',
167
- disabled=st.session_state.transcripts is not None
168
- )
169
- elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
170
- st.file_uploader(
171
- label='Upload a fasta file:',
172
- key='fasta_entry',
173
- disabled=st.session_state.transcripts is not None
174
- )
175
 
176
- # let's go!
177
- st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
178
- progress = st.empty()
179
-
180
- # input error
181
- error = st.empty()
182
- if st.session_state.input_error is not None:
183
- error.error(st.session_state.input_error, icon="🚨")
184
- else:
185
- error.empty()
186
-
187
- # on-target results
188
- on_target_results = st.empty()
189
- if st.session_state.on_target is not None:
190
- with on_target_results.container():
191
- st.write('On-target predictions:', st.session_state.on_target)
192
- st.download_button(
193
- label='Download on-target predictions',
194
- data=convert_df(st.session_state.on_target),
195
- file_name='on_target.csv',
196
- mime='text/csv'
197
- )
198
- else:
199
- on_target_results.empty()
200
-
201
- # titration results
202
- titration_results = st.empty()
203
- if st.session_state.titration is not None:
204
- with titration_results.container():
205
- st.write('Titration predictions:', st.session_state.titration)
206
- st.download_button(
207
- label='Download titration predictions',
208
- data=convert_df(st.session_state.titration),
209
- file_name='titration.csv',
210
- mime='text/csv'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  )
212
- else:
213
- titration_results.empty()
214
-
215
- # off-target results
216
- off_target_results = st.empty()
217
- if st.session_state.off_target is not None:
218
- with off_target_results.container():
219
- if len(st.session_state.off_target) > 0:
220
- st.write('Off-target predictions:', st.session_state.off_target)
221
- st.download_button(
222
- label='Download off-target predictions',
223
- data=convert_df(st.session_state.off_target),
224
- file_name='off_target.csv',
225
- mime='text/csv'
226
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  else:
228
- st.write('We did not find any off-target effects!')
229
- else:
230
- off_target_results.empty()
231
-
232
- # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
233
- if st.session_state.transcripts is not None:
234
- st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
235
- transcripts=st.session_state.transcripts,
236
- mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
237
- check_off_targets=st.session_state.check_off_targets,
238
- status_update_fn=progress_update
239
- )
240
- st.session_state.transcripts = None
241
- st.experimental_rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Check if the selected model is Cas9
16
  if selected_model == 'Cas9':
17
  # Display buttons for the Cas9 model
18
+ if st.checkbox('SPCas9_U6'):
19
  # Placeholder for action when SPCas9_U6 is clicked
20
  pass
21
+ if st.checkbox('SPCas9_t7'):
22
  # Placeholder for action when SPCas9_t7 is clicked
23
  pass
24
+ if st.checkbox('eSPCas9'):
25
  # Placeholder for action when eSPCas9 is clicked
26
  pass
27
+ if st.checkbox('SPCas9_HF1'):
28
  # Placeholder for action when SPCas9_HF1 is clicked
29
  pass
30
  elif selected_model == 'Cas12':
 
32
  # TODO: Implement Cas12 model loading logic
33
  raise NotImplementedError("Cas12 model loading not implemented yet.")
34
  elif selected_model == 'Cas13d':
35
+ ENTRY_METHODS = dict(
36
+ manual='Manual entry of single transcript',
37
+ fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  )
39
+ @st.cache_data
40
+ def convert_df(df):
41
+ # IMPORTANT: Cache the conversion to prevent computation on every rerun
42
+ return df.to_csv().encode('utf-8')
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ def mode_change_callback():
46
+ if st.session_state.mode in {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}: # TODO: support titration
47
+ st.session_state.check_off_targets = False
48
+ st.session_state.disable_off_target_checkbox = True
49
+ else:
50
+ st.session_state.disable_off_target_checkbox = False
51
+
52
+
53
+ def progress_update(update_text, percent_complete):
54
+ with progress.container():
55
+ st.write(update_text)
56
+ st.progress(percent_complete / 100)
57
+
58
+
59
+ def initiate_run():
60
+
61
+ # initialize state variables
62
+ st.session_state.transcripts = None
63
+ st.session_state.input_error = None
64
+ st.session_state.on_target = None
65
+ st.session_state.titration = None
66
+ st.session_state.off_target = None
67
+
68
+ # initialize transcript DataFrame
69
+ transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])
70
+
71
+ # manual entry
72
+ if st.session_state.entry_method == ENTRY_METHODS['manual']:
73
+ transcripts = pd.DataFrame({
74
+ tiger.ID_COL: ['ManualEntry'],
75
+ tiger.SEQ_COL: [st.session_state.manual_entry]
76
+ }).set_index(tiger.ID_COL)
77
+
78
+ # fasta file upload
79
+ elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
80
+ if st.session_state.fasta_entry is not None:
81
+ fasta_path = st.session_state.fasta_entry.name
82
+ with open(fasta_path, 'w') as f:
83
+ f.write(st.session_state.fasta_entry.getvalue().decode('utf-8'))
84
+ transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)
85
+ os.remove(fasta_path)
86
+
87
+ # convert to upper case as used by tokenizer
88
+ transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))
89
+
90
+ # ensure all transcripts have unique identifiers
91
+ if transcripts.index.has_duplicates:
92
+ st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"
93
+
94
+ # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N
95
+ elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
96
+ st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'
97
+
98
+ # ensure all transcripts satisfy length requirements
99
+ elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
100
+ st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)
101
+
102
+ # run model if we have any transcripts
103
+ elif len(transcripts) > 0:
104
+ st.session_state.transcripts = transcripts
105
+
106
+
107
+ if __name__ == '__main__':
108
+
109
+ # app initialization
110
+ if 'mode' not in st.session_state:
111
+ st.session_state.mode = tiger.RUN_MODES['all']
112
+ st.session_state.disable_off_target_checkbox = True
113
+ if 'entry_method' not in st.session_state:
114
+ st.session_state.entry_method = ENTRY_METHODS['manual']
115
+ if 'transcripts' not in st.session_state:
116
+ st.session_state.transcripts = None
117
+ if 'input_error' not in st.session_state:
118
+ st.session_state.input_error = None
119
+ if 'on_target' not in st.session_state:
120
+ st.session_state.on_target = None
121
+ if 'titration' not in st.session_state:
122
+ st.session_state.titration = None
123
+ if 'off_target' not in st.session_state:
124
+ st.session_state.off_target = None
125
+
126
+ # title and documentation
127
+ st.markdown(Path('tiger.md').read_text(), unsafe_allow_html=True)
128
+ st.divider()
129
+
130
+ # mode selection
131
+ col1, col2 = st.columns([0.65, 0.35])
132
+ with col1:
133
+ st.radio(
134
+ label='What do you want to predict?',
135
+ options=tuple(tiger.RUN_MODES.values()),
136
+ key='mode',
137
+ on_change=mode_change_callback,
138
+ disabled=st.session_state.transcripts is not None,
139
+ )
140
+ with col2:
141
+ st.checkbox(
142
+ label='Find off-target effects (slow)',
143
+ key='check_off_targets',
144
+ disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
145
+ )
146
+
147
+ # transcript entry
148
+ st.selectbox(
149
+ label='How would you like to provide transcript(s) of interest?',
150
+ options=ENTRY_METHODS.values(),
151
+ key='entry_method',
152
+ disabled=st.session_state.transcripts is not None
153
  )
154
+ if st.session_state.entry_method == ENTRY_METHODS['manual']:
155
+ st.text_input(
156
+ label='Enter a target transcript:',
157
+ key='manual_entry',
158
+ placeholder='Upper or lower case',
159
+ disabled=st.session_state.transcripts is not None
 
 
 
 
 
 
 
 
160
  )
161
+ elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
162
+ st.file_uploader(
163
+ label='Upload a fasta file:',
164
+ key='fasta_entry',
165
+ disabled=st.session_state.transcripts is not None
166
+ )
167
+
168
+ # let's go!
169
+ st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
170
+ progress = st.empty()
171
+
172
+ # input error
173
+ error = st.empty()
174
+ if st.session_state.input_error is not None:
175
+ error.error(st.session_state.input_error, icon="🚨")
176
  else:
177
+ error.empty()
178
+
179
+ # on-target results
180
+ on_target_results = st.empty()
181
+ if st.session_state.on_target is not None:
182
+ with on_target_results.container():
183
+ st.write('On-target predictions:', st.session_state.on_target)
184
+ st.download_button(
185
+ label='Download on-target predictions',
186
+ data=convert_df(st.session_state.on_target),
187
+ file_name='on_target.csv',
188
+ mime='text/csv'
189
+ )
190
+ else:
191
+ on_target_results.empty()
192
+
193
+ # titration results
194
+ titration_results = st.empty()
195
+ if st.session_state.titration is not None:
196
+ with titration_results.container():
197
+ st.write('Titration predictions:', st.session_state.titration)
198
+ st.download_button(
199
+ label='Download titration predictions',
200
+ data=convert_df(st.session_state.titration),
201
+ file_name='titration.csv',
202
+ mime='text/csv'
203
+ )
204
+ else:
205
+ titration_results.empty()
206
+
207
+ # off-target results
208
+ off_target_results = st.empty()
209
+ if st.session_state.off_target is not None:
210
+ with off_target_results.container():
211
+ if len(st.session_state.off_target) > 0:
212
+ st.write('Off-target predictions:', st.session_state.off_target)
213
+ st.download_button(
214
+ label='Download off-target predictions',
215
+ data=convert_df(st.session_state.off_target),
216
+ file_name='off_target.csv',
217
+ mime='text/csv'
218
+ )
219
+ else:
220
+ st.write('We did not find any off-target effects!')
221
+ else:
222
+ off_target_results.empty()
223
+
224
+ # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
225
+ if st.session_state.transcripts is not None:
226
+ st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
227
+ transcripts=st.session_state.transcripts,
228
+ mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
229
+ check_off_targets=st.session_state.check_off_targets,
230
+ status_update_fn=progress_update
231
+ )
232
+ st.session_state.transcripts = None
233
+ st.experimental_rerun()
234
+ else:
235
+ raise ValueError(f"Unknown model: {model_name}")