File size: 7,748 Bytes
ac117b5
 
 
71382c0
ac117b5
71382c0
 
 
 
 
 
 
ac117b5
71382c0
 
 
ac117b5
 
71382c0
ac117b5
71382c0
 
ac117b5
71382c0
 
 
 
 
 
 
 
 
 
 
 
ac117b5
 
 
 
 
 
71382c0
ac117b5
 
 
71382c0
cec3465
ac117b5
 
cec3465
ac117b5
 
71382c0
ac117b5
 
 
 
 
71382c0
 
 
 
 
 
ac117b5
 
71382c0
ac117b5
 
 
 
71382c0
ac117b5
 
71382c0
 
 
 
 
 
 
 
 
ac117b5
 
71382c0
 
ac117b5
32cc43e
 
 
 
71382c0
ac117b5
71382c0
ac117b5
71382c0
ac117b5
 
 
 
 
 
71382c0
ac117b5
 
 
 
 
 
71382c0
ac117b5
 
 
71382c0
 
 
ac117b5
71382c0
 
ac117b5
 
 
cec3465
71382c0
 
ac117b5
32cc43e
71382c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac117b5
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import gradio as gr
import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
from mammal.keys import *
from mammal.model import Mammal

# Task display names. Each one is used as the dropdown label in the app and as
# the key into `model_paths`, `models` and `tokenizer_op`.

# Protein-protein interaction: the model classifies a protein pair as
# interacting (`<1>`) or non-interacting (`<0>`).
ppi = "Protein-Protein Interaction (PPI)"
# Drug-target binding affinity: a regression task (the checkpoint name
# suggests BindingDB pKd values — confirm against the model card).
dti = "Drug-Target Binding Affinity"

# Hugging Face checkpoint to load for each task, in dropdown order.
model_paths = {
    ppi: "ibm/biomed.omics.bl.sm.ma-ted-458m",
    dti: "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
}


# Load every task's model and tokenizer eagerly at import time.
# NOTE: this could be made lazy so start-up only pays for demos actually used.
# Both dicts are keyed by the same task names as `model_paths`.
models = {}
tokenizer_op = {}

for task, model_path in model_paths.items():
    # `model_paths` has unique keys, so every task is loaded exactly once.
    models[task] = Mammal.from_pretrained(model_path)
    models[task].eval()  # inference only: put the model in evaluation mode
    # Each checkpoint ships the tokenizer it was trained with.
    tokenizer_op[task] = ModularTokenizerOp.from_pretrained(model_path)


### PPI:
# Vocabulary id of the `<1>` class token in the PPI tokenizer. In the PPI
# prompt, `<1>` marks "interacting" and `<0>` marks "non-interacting".
positive_token_id = tokenizer_op[ppi].get_token_id("<1>")

# Default example inputs pre-filled in the two PPI protein textboxes
# (amino-acid sequences, one-letter codes).
protein_calmodulin = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK"
protein_calcineurin = "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ"


def format_prompt_ppi(prot1, prot2):
    """Build a MAMMAL PPI prompt for a pair of protein sequences.

    The layout mirrors the pre-training syntax. A binding-affinity-class
    header with a `<SENTINEL_ID_0>` slot for the model to fill comes first.
    Then each protein follows as a general-protein molecular entity wrapped
    in natural-sequence start/end markers. `<EOS>` closes the prompt.

    Args:
        prot1: First protein amino-acid sequence.
        prot2: Second protein amino-acid sequence.

    Returns:
        The prompt string.
    """

    def _protein_segment(sequence):
        # One protein entity: type tags followed by the delimited sequence.
        return (
            "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            f"<SEQUENCE_NATURAL_START>{sequence}<SEQUENCE_NATURAL_END>"
        )

    header = "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
    return header + _protein_segment(prot1) + _protein_segment(prot2) + "<EOS>"


def run_prompt(prompt):
    """Run a formatted PPI prompt through the PPI model.

    Args:
        prompt: Prompt string, e.g. as produced by ``format_prompt_ppi``.

    Returns:
        A ``(generated_output, score)`` tuple. ``generated_output`` is the
        decoded generated token sequence. ``score`` is the model's score for
        the ``<1>`` (interacting) class token.
    """
    tokenizer = tokenizer_op[ppi]
    # Resolve the positive-class id from the PPI tokenizer itself. Relying on a
    # module-level global is fragile because it can be rebound elsewhere in the
    # module, e.g. by another task's tokenizer.
    positive_id = tokenizer.get_token_id("<1>")

    # Tokenize the prompt into encoder input ids and an attention mask.
    sample_dict = {ENCODER_INPUTS_STR: prompt}
    sample_dict = tokenizer(
        sample_dict=sample_dict,
        key_in=ENCODER_INPUTS_STR,
        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
    )
    # Convert the tokenizer output into tensors for the model.
    sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
        sample_dict[ENCODER_INPUTS_TOKENS]
    )
    sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
    )

    # Generate a short answer. Scores are requested so the <1> score can be read.
    batch_dict = models[ppi].generate(
        [sample_dict],
        output_scores=True,
        return_dict_in_generate=True,
        max_new_tokens=5,
    )

    generated_output = tokenizer._tokenizer.decode(batch_dict[CLS_PRED][0])
    # NOTE(review): assumes scores are indexed [sample][decoding step][token id],
    # with step 1 being the class-token position — confirm.
    score = batch_dict["model.out.scores"][0][1][positive_id].item()

    return generated_output, score


def create_and_run_prompt(protein1, protein2):
    """Format a PPI prompt for two proteins and run it through the model.

    Returns:
        A 3-tuple ``(prompt, generated_output, score)``, in the same order as
        the ``[prompt_box, decoded, score]`` outputs of the PPI demo button.
    """
    formatted = format_prompt_ppi(protein1, protein2)
    generated, score = run_prompt(prompt=formatted)
    return formatted, generated, score


def create_ppi_demo():
    """Build the initially hidden Gradio group for the PPI demo.

    The group contains:
    - two editable protein-sequence textboxes, pre-filled with calmodulin and
      calcineurin;
    - a button wired to ``create_and_run_prompt``;
    - output fields for the prompt, the decoded model output and the PPI score.

    Must be called inside an active ``gr.Blocks`` context.

    Returns:
        The ``gr.Group`` holding the demo, with ``visible`` set to False.
    """
    header_md = f"""
# Mammal based Protein-Protein Interaction (PPI) demonstration

Given two protein sequences, estimate if the proteins interact or not.

### Using the model from

 ```{model_paths[ppi]} ```
"""
    footer_md = "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"

    # Components render in creation order, so the statement order below is the layout.
    with gr.Group() as group:
        gr.Markdown(header_md)
        with gr.Row():
            first_protein = gr.Textbox(
                label="Protein 1 sequence",
                interactive=True,
                lines=3,
                value=protein_calmodulin,
            )
            second_protein = gr.Textbox(
                label="Protein 2 sequence",
                interactive=True,
                lines=3,
                value=protein_calcineurin,
            )
        with gr.Row():
            run_button = gr.Button(
                "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
            )
        with gr.Row():
            prompt_display = gr.Textbox(label="Mammal prompt", lines=5)
        with gr.Row():
            output_display = gr.Textbox(label="Mammal output")
            score_display = gr.Number(label="PPI score")
            run_button.click(
                fn=create_and_run_prompt,
                inputs=[first_protein, second_protein],
                outputs=[prompt_display, output_display, score_display],
            )
        with gr.Row():
            gr.Markdown(footer_md)
        # Hidden until selected in the application's dropdown.
        group.visible = False
    return group


### DTI:
# Default example inputs pre-filled in the DTI textboxes: a target protein
# (amino-acid sequence) and a drug (SMILES string).
target_seq = "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC"
drug_seq = "CC(=O)NCCC1=CNc2c1cc(OC)cc2"


# Id of the `<1>` token in the DTI tokenizer. It is kept under its own name so
# it does not rebind the PPI demo's module-level `positive_token_id`.
# NOTE(review): the DTI code in this file does not reference it (the DTI
# head is a regression).
positive_token_id_dti = tokenizer_op[dti].get_token_id("<1>")


def format_prompt_dti(prot, drug):
    """Build the DTI model input for a target protein and a drug.

    Args:
        prot: Target protein amino-acid sequence.
        drug: Drug as a SMILES string.

    Returns:
        The sample dict returned by ``DtiBindingdbKdTask.data_preprocessing``.
        The DTI tokenizer is used, and the DTI model's device is passed as
        ``device``. No target normalisation is applied here; mean and std are
        None.
    """
    # Use the caller's sequences. The module defaults `target_seq`/`drug_seq`
    # are only the UI's initial values, not the inputs to score.
    sample_dict = {"target_seq": prot, "drug_seq": drug}
    return DtiBindingdbKdTask.data_preprocessing(
        sample_dict=sample_dict,
        tokenizer_op=tokenizer_op[dti],
        target_sequence_key="target_seq",
        drug_sequence_key="drug_seq",
        norm_y_mean=None,
        norm_y_std=None,
        device=models[dti].device,
    )


def create_and_run_prompt_dtb(prot, drug):
    """Predict drug-target binding affinity for one protein/drug pair.

    Returns:
        A 3-tuple in the order of the DTI demo button's outputs:
        ``(encoder_input, "model.out.dti_bindingdb_kd", predicted_value)``.
    """
    pred_key = "model.out.dti_bindingdb_kd"
    sample = format_prompt_dti(prot, drug)

    # Encoder-only forward pass over a batch of one sample.
    raw_output = models[dti].forward_encoder_only([sample])

    # Map the raw scalar prediction back to the target scale.
    # NOTE(review): these are presumably the mean/std used to normalise the
    # fine-tuning targets — confirm against the training config.
    processed = DtiBindingdbKdTask.process_model_output(
        raw_output,
        scalars_preds_processed_key=pred_key,
        norm_y_mean=5.79384684128215,
        norm_y_std=1.33808027428196,
    )
    predicted_value = float(processed[pred_key][0])
    return sample["data.query.encoder_input"], pred_key, predicted_value


def create_tdb_demo():
    """Build the initially hidden Gradio group for the drug-target affinity demo.

    The group contains:
    - editable textboxes for the target protein and the drug (SMILES),
      pre-filled with the module defaults;
    - a button wired to ``create_and_run_prompt_dtb``;
    - output fields for the model input, the output key and the DTI score.

    Must be called inside an active ``gr.Blocks`` context.

    Returns:
        The ``gr.Group`` holding the demo, with ``visible`` set to False.
    """
    header_md = f"""
# Mammal based Target-Drug binding affinity demonstration

Given a protein sequence and a drug (in SMILES), estimate the binding affinity.

### Using the model from

 ```{model_paths[dti]} ```
"""
    # Components render in creation order, so the statement order below is the layout.
    with gr.Group() as group:
        gr.Markdown(header_md)
        with gr.Row():
            protein_input = gr.Textbox(
                label="Protein sequence",
                interactive=True,
                lines=3,
                value=target_seq,
            )
            drug_input = gr.Textbox(
                label="drug sequence (SMILES)",
                interactive=True,
                lines=3,
                value=drug_seq,
            )
        with gr.Row():
            run_button = gr.Button(
                "Run Mammal prompt for Target Drug Affinity", variant="primary"
            )
        with gr.Row():
            prompt_display = gr.Textbox(label="Mammal prompt", lines=5)
        with gr.Row():
            output_display = gr.Textbox(label="Mammal output")
            score_display = gr.Number(label="DTI score")
            run_button.click(
                fn=create_and_run_prompt_dtb,
                inputs=[protein_input, drug_input],
                outputs=[prompt_display, output_display, score_display],
            )
        # Hidden until selected in the application's dropdown.
        group.visible = False
    return group


def create_application():
    """Build the top-level Gradio app.

    A dropdown selects which demo group (PPI or DTI) is shown. Both groups
    start hidden, and choosing "select demo" hides both again.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks() as demo:
        selector = gr.Dropdown(choices=["select demo", ppi, dti])
        selector.interactive = True
        ppi_group = create_ppi_demo()
        dti_group = create_tdb_demo()

        # The callback name is kept as-is: Gradio derives the event's API
        # endpoint name from the function name.
        def set_ppi_vis(main_text):
            # One visibility update per group, in the order of `outputs` below.
            show_ppi = main_text == ppi
            show_dti = main_text == dti
            return gr.Group(visible=show_ppi), gr.Group(visible=show_dti)

        selector.change(
            set_ppi_vis, inputs=selector, outputs=[ppi_group, dti_group]
        )
    return demo


def main():
    """Build the demo application and serve it.

    ``show_error=True`` surfaces callback exceptions in the UI.
    ``share=True`` asks Gradio for a public share link.
    """
    app = create_application()
    app.launch(show_error=True, share=True)


if __name__ == "__main__":
    main()