matanninio commited on
Commit
0c8cec9
·
1 Parent(s): 994bf05

added the protein solubility demo

Browse files
Files changed (2) hide show
  1. app.py +14 -1
  2. mammal_demo/ps_task.py +127 -0
app.py CHANGED
@@ -4,6 +4,7 @@ from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
4
  from mammal_demo.dti_task import DtiTask
5
  from mammal_demo.ppi_task import PpiTask
6
  from mammal_demo.tcr_task import TcrTask
 
7
 
8
  all_tasks: dict[str, MammalTask] = dict()
9
  all_models: dict[str, MammalObjectBroker] = dict()
@@ -22,6 +23,10 @@ tcr_task = TcrTask(model_dict=all_models)
22
  all_tasks[tcr_task.name] = tcr_task
23
 
24
 
 
 
 
 
25
  # create the model holders. hold the model and the tokenizer, lazy download
26
  # note that the list of relevent tasks needs to be stated.
27
  ppi_model = MammalObjectBroker(
@@ -41,6 +46,13 @@ tcr_model = MammalObjectBroker(
41
  )
42
  all_models[tcr_model.name] = tcr_model
43
 
 
 
 
 
 
 
 
44
  def create_application():
45
  def task_change(value):
46
  visibility = [gr.update(visible=(task == value)) for task in all_tasks.keys()]
@@ -95,7 +107,8 @@ full_demo = None
95
  def main():
96
  global full_demo
97
  full_demo = create_application()
98
- full_demo.launch(show_error=True, share=True)
 
99
 
100
 
101
  if __name__ == "__main__":
 
4
  from mammal_demo.dti_task import DtiTask
5
  from mammal_demo.ppi_task import PpiTask
6
  from mammal_demo.tcr_task import TcrTask
7
+ from mammal_demo.ps_task import PsTask
8
 
9
  all_tasks: dict[str, MammalTask] = dict()
10
  all_models: dict[str, MammalObjectBroker] = dict()
 
23
  all_tasks[tcr_task.name] = tcr_task
24
 
25
 
26
+ ps_task = PsTask(model_dict=all_models)
27
+ all_tasks[ps_task.name] = ps_task
28
+
29
+
30
  # create the model holders. hold the model and the tokenizer, lazy download
31
  # note that the list of relevent tasks needs to be stated.
32
  ppi_model = MammalObjectBroker(
 
46
  )
47
  all_models[tcr_model.name] = tcr_model
48
 
49
+ ps_model = MammalObjectBroker(
50
+ model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility",
51
+ task_list=[ps_task.name]
52
+ )
53
+ all_models[ps_model.name] = ps_model
54
+
55
+
56
  def create_application():
57
  def task_change(value):
58
  visibility = [gr.update(visible=(task == value)) for task in all_tasks.keys()]
 
107
  def main():
108
  global full_demo
109
  full_demo = create_application()
110
+ full_demo.launch(show_error=True, share=False)
111
+ # full_demo.launch(show_error=True, share=True)
112
 
113
 
114
  if __name__ == "__main__":
mammal_demo/ps_task.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
4
+ from mammal.examples.protein_solubility.task import ProteinSolubilityTask
5
+ from mammal.keys import (
6
+ ENCODER_INPUTS_STR,
7
+ CLS_PRED,
8
+ SCORES,
9
+ )
10
+ from mammal.model import Mammal
11
+
12
+ from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
13
+
14
+
15
+ class PsTask(MammalTask):
16
+ def __init__(self, model_dict):
17
+ super().__init__(name="Protein Solubility", model_dict=model_dict)
18
+ self.description = "Protein Solubility (PS)"
19
+ self.examples = {
20
+ "protein_seq": "LLQTGIHVRVSQPSL",
21
+ }
22
+ self.markup_text = """
23
+ # Mammal based TODO: T-cell receptors-peptide binding specificity demonstration
24
+
25
+ Given the TCR beta sequance and the epitope sequacne, estimate the binding specificity.
26
+ """
27
+
28
+
29
+
30
+ def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
31
+ """convert sample_inputs to sample_dict including creating a proper prompt
32
+
33
+ Args:
34
+ sample_inputs (dict): dictionary containing the inputs to the model
35
+ model_holder (MammalObjectBroker): model holder
36
+ Returns:
37
+ dict: sample_dict for feeding into model
38
+ """
39
+ sample_dict = dict(sample_inputs) # shallow copy
40
+ sample_dict = ProteinSolubilityTask.data_preprocessing(
41
+ sample_dict=sample_dict,
42
+ protein_sequence_key="protein_seq",
43
+ tokenizer_op=model_holder.tokenizer_op,
44
+ device=model_holder.model.device,
45
+ )
46
+
47
+ return sample_dict
48
+
49
+ def run_model(self, sample_dict, model: Mammal):
50
+ # Generate Prediction
51
+ batch_dict = model.generate(
52
+ [sample_dict],
53
+ output_scores=True,
54
+ return_dict_in_generate=True,
55
+ max_new_tokens=5,
56
+ )
57
+ return batch_dict
58
+
59
+ def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp)-> dict:
60
+
61
+ """
62
+ Extract predicted class and scores
63
+ """
64
+ ans_dict = ProteinSolubilityTask.process_model_output(
65
+ tokenizer_op=tokenizer_op,
66
+ decoder_output=batch_dict[CLS_PRED][0],
67
+ decoder_output_scores=batch_dict[SCORES][0],
68
+ )
69
+ ans = [
70
+ tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]),
71
+ ans_dict["pred"],
72
+ ans_dict["not_normalized_scores"].item(),
73
+ ans_dict["normalized_scores"].item(),
74
+ ]
75
+ return ans
76
+
77
+
78
+
79
+ def create_and_run_prompt(self, model_name, protein_seq):
80
+ model_holder = self.model_dict[model_name]
81
+ inputs = {
82
+ "protein_seq": protein_seq,
83
+ }
84
+ sample_dict = self.crate_sample_dict(
85
+ sample_inputs=inputs, model_holder=model_holder
86
+ )
87
+ prompt = sample_dict[ENCODER_INPUTS_STR]
88
+ batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
89
+ res = prompt, *self.decode_output(batch_dict, tokenizer_op=model_holder.tokenizer_op)
90
+ return res
91
+
92
+
93
+
94
+ def create_demo(self, model_name_widget):
95
+
96
+
97
+ with gr.Group() as demo:
98
+ gr.Markdown(self.markup_text)
99
+ with gr.Row():
100
+ protein_textbox = gr.Textbox(
101
+ label="Protein sequance",
102
+ # info="standard",
103
+ interactive=True,
104
+ lines=3,
105
+ value=self.examples["protein_seq"],
106
+ )
107
+ with gr.Row():
108
+ run_mammal = gr.Button(
109
+ "Run Mammal prompt for TCL-Epitope Interaction",
110
+ variant="primary",
111
+ )
112
+ with gr.Row():
113
+ prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
114
+
115
+ with gr.Row():
116
+ decoded = gr.Textbox(label="Mammal output")
117
+ predicted_class = gr.Textbox(label="Mammal prediction")
118
+ with gr.Column():
119
+ non_norm_score = gr.Number(label="Non normelized score")
120
+ norm_score = gr.Number(label="Normelized score")
121
+ run_mammal.click(
122
+ fn=self.create_and_run_prompt,
123
+ inputs=[model_name_widget, protein_textbox],
124
+ outputs=[prompt_box, decoded, predicted_class,non_norm_score,norm_score],
125
+ )
126
+ demo.visible = False
127
+ return demo