File size: 8,964 Bytes
e93c659
 
 
 
 
 
 
 
 
 
fdf5616
306ab4d
 
bc09cb1
 
 
 
 
 
 
 
 
fdf5616
bc09cb1
 
 
 
 
e93c659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306ab4d
 
 
 
 
 
e93c659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os
import tiger
import pandas as pd
import streamlit as st
from pathlib import Path

# Transcript entry options: internal key -> label displayed in the entry-method selectbox.
ENTRY_METHODS = {
    'manual': 'Manual entry of single transcript',
    'fasta': "Fasta file upload (supports multiple transcripts if they have unique ID's)",
}

# CRISPR systems offered in the model selectbox (see load_model for which ones are implemented).
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']

# NOTE: this widget is created at module level, i.e. before anything in the __main__ section renders.
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')

def load_model(model_name):
    """Load the prediction model for the requested CRISPR system.

    Args:
        model_name: One of the entries of CRISPR_MODELS ('Cas9', 'Cas12', 'Cas13d').

    Returns:
        Whatever ``tiger.load_model()`` returns when ``model_name`` is 'Cas13d'.
        NOTE(review): assumes the tiger module exposes ``load_model`` -- confirm.

    Raises:
        NotImplementedError: For 'Cas9' and 'Cas12', which have no loader yet.
        ValueError: For any other model name.
    """
    # Only the Cas13d model (assumed to be backed by the tiger module) is available.
    if model_name == 'Cas13d':
        return tiger.load_model()

    # TODO: Implement Cas9 model loading logic
    if model_name == 'Cas9':
        raise NotImplementedError("Cas9 model loading not implemented yet.")

    # TODO: Implement Cas12 model loading logic
    if model_name == 'Cas12':
        raise NotImplementedError("Cas12 model loading not implemented yet.")

    raise ValueError(f"Unknown model: {model_name}")



@st.cache_data
def convert_df(df):
    """Serialize a DataFrame to UTF-8 encoded CSV bytes for a download button.

    Wrapped in ``st.cache_data`` so the conversion is not recomputed on every rerun.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')


def mode_change_callback():
    """Sync the off-target checkbox with the selected run mode.

    Registered as the ``on_change`` callback of the mode radio button. For the 'all' and
    'titration' run modes the off-target checkbox is cleared and disabled; for any other
    mode it is enabled again (its checked state is left untouched).
    """
    state = st.session_state
    # TODO: support titration
    off_target_unavailable = state.mode in {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}
    if off_target_unavailable:
        state.check_off_targets = False
    state.disable_off_target_checkbox = off_target_unavailable


def progress_update(update_text, percent_complete):
    """Show a status message and progress bar in the module-level ``progress`` placeholder.

    Args:
        update_text: Message written above the progress bar.
        percent_complete: Progress on a 0-100 scale; divided by 100 before being
            handed to ``st.progress``.
    """
    status_area = progress.container()
    with status_area:
        st.write(update_text)
        st.progress(percent_complete / 100)


def initiate_run():
    """Validate the user's transcript input and queue it for prediction.

    Registered as the ``on_click`` callback of the 'Get predictions!' button. The function

    1. resets the result and error entries in ``st.session_state``,
    2. builds a transcript DataFrame from either the manual text entry or the uploaded
       fasta file (depending on ``st.session_state.entry_method``),
    3. upper-cases every sequence and replaces 'U' with 'T',
    4. validates IDs, alphabet and minimum length.

    On success the DataFrame is stored in ``st.session_state.transcripts``; the main script
    runs the model whenever that entry is not None. On failure a user-facing message is
    stored in ``st.session_state.input_error`` and ``transcripts`` stays None.
    """
    import tempfile  # local import: only needed to stage an uploaded fasta file

    # TODO: dispatch on the selected CRISPR model (e.g. a get_model_module(selected_model)
    # helper that imports the matching cas9 / cas12 / cas13 module). Only tiger is used for now.

    # initialize state variables
    st.session_state.transcripts = None
    st.session_state.input_error = None
    st.session_state.on_target = None
    st.session_state.titration = None
    st.session_state.off_target = None

    # initialize transcript DataFrame
    transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

    # manual entry
    if st.session_state.entry_method == ENTRY_METHODS['manual']:
        transcripts = pd.DataFrame({
            tiger.ID_COL: ['ManualEntry'],
            tiger.SEQ_COL: [st.session_state.manual_entry]
        }).set_index(tiger.ID_COL)

    # fasta file upload
    elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
        uploaded = st.session_state.fasta_entry
        if uploaded is not None:
            # The upload's name is client-controlled: keep only its final path component
            # (so any extension is preserved for the loader) and never use it as a path
            # relative to the working directory.
            safe_name = os.path.basename(uploaded.name)
            if safe_name in {'', '.', '..'}:
                safe_name = 'upload.fasta'

            # Stage the file in a private temporary directory: concurrent sessions cannot
            # collide, and the directory is removed even if parsing raises.
            with tempfile.TemporaryDirectory() as tmp_dir:
                fasta_path = os.path.join(tmp_dir, safe_name)
                with open(fasta_path, 'w') as f:
                    f.write(uploaded.getvalue().decode('utf-8'))
                # uniqueness is checked below so it can be reported as an input error
                transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)

    # convert to upper case as used by tokenizer
    transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

    # ensure all transcripts have unique identifiers
    if transcripts.index.has_duplicates:
        st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"

    # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N
    elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
        st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'

    # ensure all transcripts satisfy length requirements
    elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
        st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)

    # run model if we have any transcripts
    elif len(transcripts) > 0:
        st.session_state.transcripts = transcripts

    # nothing to run (no file uploaded, or the file held no records): tell the user
    # instead of silently doing nothing
    else:
        st.session_state.input_error = 'No transcripts found. Please provide at least one transcript.'


if __name__ == '__main__':

    # app initialization: seed session state the first time a session runs the script
    if 'mode' not in st.session_state:
        # the default mode is 'all', for which the off-target checkbox starts out disabled
        st.session_state.mode = tiger.RUN_MODES['all']
        st.session_state.disable_off_target_checkbox = True
    if 'entry_method' not in st.session_state:
        st.session_state.entry_method = ENTRY_METHODS['manual']
    # pending input, validation error and the three result tables all start out empty
    for state_key in ('transcripts', 'input_error', 'on_target', 'titration', 'off_target'):
        if state_key not in st.session_state:
            st.session_state[state_key] = None

    # title and documentation (rendered from tiger.md, resolved relative to the working directory)
    st.markdown(Path('tiger.md').read_text(), unsafe_allow_html=True)
    st.divider()

    # mode selection
    # Every input widget below is disabled while st.session_state.transcripts is not None,
    # i.e. while validated input is waiting to be processed at the end of this script.
    col1, col2 = st.columns([0.65, 0.35])
    with col1:
        # bound to st.session_state.mode; mode_change_callback updates the off-target checkbox state
        st.radio(
            label='What do you want to predict?',
            options=tuple(tiger.RUN_MODES.values()),
            key='mode',
            on_change=mode_change_callback,
            disabled=st.session_state.transcripts is not None,
        )
    with col2:
        # bound to st.session_state.check_off_targets; additionally disabled when the
        # current mode does not allow off-target checking (disable_off_target_checkbox)
        st.checkbox(
            label='Find off-target effects (slow)',
            key='check_off_targets',
            disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
        )

    # transcript entry
    # bound to st.session_state.entry_method; the options are the ENTRY_METHODS labels
    st.selectbox(
        label='How would you like to provide transcript(s) of interest?',
        options=ENTRY_METHODS.values(),
        key='entry_method',
        disabled=st.session_state.transcripts is not None
    )
    # only the widget matching the chosen entry method is rendered; initiate_run reads
    # its value from st.session_state.manual_entry or st.session_state.fasta_entry
    if st.session_state.entry_method == ENTRY_METHODS['manual']:
        st.text_input(
            label='Enter a target transcript:',
            key='manual_entry',
            placeholder='Upper or lower case',
            disabled=st.session_state.transcripts is not None
        )
    elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
        st.file_uploader(
            label='Upload a fasta file:',
            key='fasta_entry',
            disabled=st.session_state.transcripts is not None
        )

    # let's go!
    # initiate_run validates the input and, on success, stores it in st.session_state.transcripts
    st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
    # placeholder that progress_update fills with status text and a progress bar
    progress = st.empty()

    # input error
    # Each section below reserves a placeholder with st.empty(), then either fills it
    # from session state or explicitly clears it.
    error = st.empty()
    if st.session_state.input_error is not None:
        # message set by initiate_run when validation fails
        error.error(st.session_state.input_error, icon="🚨")
    else:
        error.empty()

    # on-target results
    on_target_results = st.empty()
    if st.session_state.on_target is not None:
        with on_target_results.container():
            st.write('On-target predictions:', st.session_state.on_target)
            # convert_df returns the table as UTF-8 CSV bytes (cached via st.cache_data)
            st.download_button(
                label='Download on-target predictions',
                data=convert_df(st.session_state.on_target),
                file_name='on_target.csv',
                mime='text/csv'
            )
    else:
        on_target_results.empty()

    # titration results
    titration_results = st.empty()
    if st.session_state.titration is not None:
        with titration_results.container():
            st.write('Titration predictions:', st.session_state.titration)
            st.download_button(
                label='Download titration predictions',
                data=convert_df(st.session_state.titration),
                file_name='titration.csv',
                mime='text/csv'
            )
    else:
        titration_results.empty()

    # off-target results
    off_target_results = st.empty()
    if st.session_state.off_target is not None:
        with off_target_results.container():
            # a non-None but empty result is reported as a message instead of an
            # empty table with a download button
            if len(st.session_state.off_target) > 0:
                st.write('Off-target predictions:', st.session_state.off_target)
                st.download_button(
                    label='Download off-target predictions',
                    data=convert_df(st.session_state.off_target),
                    file_name='off_target.csv',
                    mime='text/csv'
                )
            else:
                st.write('We did not find any off-target effects!')
    else:
        off_target_results.empty()

    # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
    if st.session_state.transcripts is not None:
        # the radio widget stores the displayed label; invert RUN_MODES to recover its key
        st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
            transcripts=st.session_state.transcripts,
            mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
            check_off_targets=st.session_state.check_off_targets,
            status_update_fn=progress_update
        )

        # clear the pending input so the widgets are re-enabled and the model is not run again
        st.session_state.transcripts = None

        # Rerun the script so the sections above render the fresh results. st.experimental_rerun
        # is deprecated in favour of st.rerun and missing from newer Streamlit releases, so prefer
        # st.rerun and only fall back to the old name on versions that predate it.
        rerun_script = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun_script()