File size: 3,292 Bytes
ac117b5
71382c0
19dfa7a
 
 
f98cc68
71382c0
19dfa7a
 
71382c0
f98cc68
 
 
 
19dfa7a
 
 
 
 
 
f98cc68
 
 
 
 
 
19dfa7a
f98cc68
19dfa7a
 
71382c0
19dfa7a
 
f98cc68
19dfa7a
 
71382c0
f98cc68
 
 
 
 
71382c0
19dfa7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71382c0
19dfa7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71382c0
ac117b5
 
19dfa7a
 
 
ac117b5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
from mammal_demo.dti_task import DtiTask
from mammal_demo.ppi_task import PpiTask
from mammal_demo.tcr_task import TcrTask

# Registries of everything the app can show, each keyed by the object's ``name``.
all_tasks: dict[str, MammalTask] = {}
all_models: dict[str, MammalObjectBroker] = {}


# ---------------------------------------------------------------------------
# Tasks
# ---------------------------------------------------------------------------
# A task needs to reach the models, because which model it runs depends on the
# state of a widget. Each task therefore receives a reference to the (still
# empty) ``all_models`` dict, which is filled in further below once the model
# brokers have been created.
ppi_task = PpiTask(model_dict=all_models)
all_tasks[ppi_task.name] = ppi_task

# DtiTask instance; the variable keeps its historical "tdi" spelling.
tdi_task = DtiTask(model_dict=all_models)
all_tasks[tdi_task.name] = tdi_task

tcr_task = TcrTask(model_dict=all_models)
all_tasks[tcr_task.name] = tcr_task


# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
# Each broker holds a model together with its tokenizer and downloads them
# lazily. ``task_list`` must name every task the broker should be offered for.
# The base checkpoint is listed for both the PPI and the TCR task.
ppi_model = MammalObjectBroker(
    model_path="ibm/biomed.omics.bl.sm.ma-ted-458m",
    task_list=[ppi_task.name, tcr_task.name],
)
all_models[ppi_model.name] = ppi_model

tdi_model = MammalObjectBroker(
    model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
    task_list=[tdi_task.name],
)
all_models[tdi_model.name] = tdi_model

tcr_model = MammalObjectBroker(
    model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind",
    task_list=[tcr_task.name],
)
all_models[tcr_model.name] = tcr_model

def create_application():
    """Build the Gradio UI for choosing a Mammal task and a matching model.

    The page contains:
      * a task dropdown ("select demo" placeholder plus every key of ``all_tasks``),
      * a model dropdown, hidden until a task with matching models is chosen,
      * one demo section per task, built by ``task.demo(...)``.

    Selecting a task shows only that task's demo and refills the model dropdown
    with the names of the models in ``all_models`` whose ``tasks`` contain the
    selected task.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` application.
    """

    def models_for_task(task_name):
        """Return the names of registered models whose ``tasks`` include ``task_name``."""
        return [
            model_name
            for model_name, model in all_models.items()
            if task_name in model.tasks
        ]

    def task_change(value):
        # One visibility update per task demo, in ``all_tasks`` order; this must
        # line up with the order of the demo outputs registered below.
        visibility = [gr.update(visible=(task_name == value)) for task_name in all_tasks]
        choices = models_for_task(value)
        if choices:
            return (gr.update(choices=choices, value=choices[0], visible=True), *visibility)
        # No model serves this selection (e.g. the "select demo" placeholder):
        # leave the model dropdown untouched. ``gr.skip`` must be *called*; the
        # bare function object is not a valid output value.
        return (gr.skip(), *visibility)

    with gr.Blocks() as application:
        task_dropdown = gr.Dropdown(
            choices=["select demo"] + list(all_tasks.keys()),
            label="Mammal Task",
            interactive=True,
        )
        model_name_dropdown = gr.Dropdown(
            choices=models_for_task(task_dropdown.value),
            interactive=True,
            label="Matching Mammal models",
            visible=False,
        )

        # Each task builds its demo UI here (inside the Blocks context) and
        # returns the component whose visibility ``task_change`` toggles.
        task_demos = [
            task.demo(model_name_widgit=model_name_dropdown)
            for task in all_tasks.values()
        ]

        task_dropdown.change(
            task_change,
            inputs=[task_dropdown],
            outputs=[model_name_dropdown, *task_demos],
        )

    return application


# Module-level handle to the application built by main(); None until then.
full_demo = None


def main():
    """Create the Mammal demo app, keep a global reference to it, and serve it."""
    global full_demo
    application = create_application()
    full_demo = application
    application.launch(share=False, show_error=True)


if __name__ == "__main__":
    main()