matanninio commited on
Commit
022cccc
·
1 Parent(s): 4c8737b

version working with two demos and possibly multi-models

Browse files
Files changed (1) hide show
  1. new_app.py +181 -49
new_app.py CHANGED
@@ -43,21 +43,21 @@ class MammalTask(ABC):
43
  self.description = None
44
  self._demo = None
45
 
46
- @abstractmethod
47
- def generate_prompt(self, **kwargs) -> str:
48
- """Formatting prompt to match pre-training syntax
49
 
50
- Args:
51
- prot1 (_type_): _description_
52
- prot2 (_type_): _description_
53
 
54
- Raises:
55
- No: _description_
56
- """
57
- raise NotImplementedError()
58
 
59
  @abstractmethod
60
- def crate_sample_dict(self, prompt: str, **kwargs) -> dict:
61
  """Formatting prompt to match pre-training syntax
62
 
63
  Args:
@@ -72,19 +72,25 @@ class MammalTask(ABC):
72
  def run_model(self, sample_dict, model:Mammal):
73
  raise NotImplementedError()
74
 
75
- @abstractmethod
76
- def create_demo(self, model_name_dropdown):
77
  """create an gradio demo group
78
 
79
- Returns:
80
- _type_: _description_
 
 
 
 
 
81
  """
82
  raise NotImplementedError()
83
 
 
84
 
85
- def demo(self,model_name_dropdown=None):
86
  if self._demo is None:
87
- self._demo = self.create_demo(model_name_dropdown=model_name_dropdown)
 
88
  return self._demo
89
 
90
  @abstractmethod
@@ -103,7 +109,7 @@ all_models= dict()
103
 
104
  class PpiTask(MammalTask):
105
  def __init__(self):
106
- super().__init__(name="PPI")
107
  self.description = "Protein-Protein Interaction (PPI)"
108
  self.examples = {
109
  "protein_calmodulin": "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK",
@@ -138,17 +144,18 @@ class PpiTask(MammalTask):
138
  Returns:
139
  str: prompt
140
  """
141
- prompt = "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"\
142
  "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"\
143
- f"<SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END>"\
144
  "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"\
145
- f"<SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
146
  return prompt
147
 
148
 
149
- def crate_sample_dict(self,prompt: str, model_holder:MammalObjectBroker):
150
  # Create and load sample
151
  sample_dict = dict()
 
152
  sample_dict[ENCODER_INPUTS_STR] = prompt
153
 
154
  # Tokenize
@@ -176,7 +183,7 @@ class PpiTask(MammalTask):
176
  )
177
  return batch_dict
178
 
179
- def decode_output(self,batch_dict, model_holder):
180
 
181
  # Get output
182
  generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
@@ -187,14 +194,17 @@ class PpiTask(MammalTask):
187
 
188
  def create_and_run_prompt(self,model_name,protein1, protein2):
189
  model_holder = all_models[model_name]
190
- prompt = self.generate_prompt(protein1, protein2)
191
- sample_dict = self.crate_sample_dict(prompt=prompt, model_holder=model_holder)
 
 
 
192
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
193
  res = prompt, *self.decode_output(batch_dict,model_holder=model_holder)
194
  return res
195
 
196
 
197
- def create_demo(self,model_name_dropdown):
198
 
199
  # """
200
  # ### Using the model from
@@ -219,7 +229,7 @@ class PpiTask(MammalTask):
219
  value=self.examples["protein_calcineurin"],
220
  )
221
  with gr.Row():
222
- run_mammal = gr.Button(
223
  "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
224
  )
225
  with gr.Row():
@@ -229,63 +239,185 @@ class PpiTask(MammalTask):
229
  decoded = gr.Textbox(label="Mammal output")
230
  run_mammal.click(
231
  fn=self.create_and_run_prompt,
232
- inputs=[model_name_dropdown, prot1, prot2],
233
  outputs=[prompt_box, decoded, gr.Number(label="PPI score")],
234
  )
235
  with gr.Row():
236
  gr.Markdown(
237
  "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
238
  )
239
- demo.visible = True
240
  return demo
241
 
242
  ppi_task = PpiTask()
243
  all_tasks[ppi_task.name]=ppi_task
244
 
245
- ppi_model = MammalObjectBroker(model_path="ibm/biomed.omics.bl.sm.ma-ted-458m", task_list=["PPI"])
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  all_models[ppi_model.name]=ppi_model
248
- # tdi_model = MammalTrainedModel(model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd") TODO: ## task list still empty
249
- # all_models.append(tdi_model)
 
250
 
251
 
252
  def create_application():
253
  def task_change(value):
 
 
254
  choices=[model_name for model_name, model in all_models.items() if value in model.tasks]
255
  if choices:
256
- return gr.update(choices=choices, value=choices[0])
257
  else:
258
- return
259
  # return model_name_dropdown
260
 
261
 
262
- with gr.Blocks() as demo:
263
  task_dropdown = gr.Dropdown(choices=["select demo"] + list(all_tasks.keys()))
264
  task_dropdown.interactive = True
265
  model_name_dropdown = gr.Dropdown(choices=[model_name for model_name, model in all_models.items() if task_dropdown.value in model.tasks], interactive=True)
266
- task_dropdown.change(task_change,inputs=[task_dropdown],outputs=[model_name_dropdown])
267
 
268
 
269
 
270
 
271
 
272
- ppi_demo = all_tasks["PPI"].demo(model_name_dropdown = model_name_dropdown)
273
- ppi_demo.visible = True
274
- # dtb_demo = create_tdb_demo()
 
 
 
 
 
 
 
 
 
275
 
276
- def set_ppi_vis(main_text):
277
- main_text=main_text
278
- print(f"main text is {main_text}")
279
- return gr.Group(visible=True)
280
- #return gr.Group(visible=(main_text == "PPI"))
281
- # , gr.Group( visible=(main_text == "DTI") )
282
 
283
- task_dropdown.change(
284
- set_ppi_vis, inputs=task_dropdown, outputs=[ppi_demo]
285
- )
286
- return demo
287
 
288
  full_demo=None
 
289
  def main():
290
  global full_demo
291
  full_demo = create_application()
 
43
  self.description = None
44
  self._demo = None
45
 
46
+ # @abstractmethod
47
+ # def _generate_prompt(self, **kwargs) -> str:
48
+ # """Formatting prompt to match pre-training syntax
49
 
50
+ # Args:
51
+ # prot1 (_type_): _description_
52
+ # prot2 (_type_): _description_
53
 
54
+ # Raises:
55
+ # No: _description_
56
+ # """
57
+ # raise NotImplementedError()
58
 
59
  @abstractmethod
60
+ def crate_sample_dict(self,sample_inputs: dict, model_holder:MammalObjectBroker) -> dict:
61
  """Formatting prompt to match pre-training syntax
62
 
63
  Args:
 
72
  def run_model(self, sample_dict, model:Mammal):
73
  raise NotImplementedError()
74
 
75
+ def create_demo(self, model_name_widget: gr.component) -> gr.Group:
 
76
  """create an gradio demo group
77
 
78
+ Args:
79
+ model_name_widgit (gr.Component): widget holding the model name to use. This is needed to create
80
+ gradio actions with the current model name as an input
81
+
82
+
83
+ Raises:
84
+ NotImplementedError: _description_
85
  """
86
  raise NotImplementedError()
87
 
88
+
89
 
90
+ def demo(self,model_name_widgit:gr.component=None):
91
  if self._demo is None:
92
+ model_name_widget:gr.component
93
+ self._demo = self.create_demo(model_name_widget=model_name_widgit)
94
  return self._demo
95
 
96
  @abstractmethod
 
109
 
110
  class PpiTask(MammalTask):
111
  def __init__(self):
112
+ super().__init__(name="Protein-Protein Interaction")
113
  self.description = "Protein-Protein Interaction (PPI)"
114
  self.examples = {
115
  "protein_calmodulin": "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK",
 
144
  Returns:
145
  str: prompt
146
  """
147
+ prompt = f"<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"\
148
  "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"\
149
+ "<SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END>"\
150
  "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"\
151
+ "<SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
152
  return prompt
153
 
154
 
155
+ def crate_sample_dict(self,sample_inputs: dict, model_holder:MammalObjectBroker):
156
  # Create and load sample
157
  sample_dict = dict()
158
+ prompt = self.generate_prompt(*sample_inputs)
159
  sample_dict[ENCODER_INPUTS_STR] = prompt
160
 
161
  # Tokenize
 
183
  )
184
  return batch_dict
185
 
186
+ def decode_output(self,batch_dict, model_holder:MammalObjectBroker):
187
 
188
  # Get output
189
  generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
 
194
 
195
  def create_and_run_prompt(self,model_name,protein1, protein2):
196
  model_holder = all_models[model_name]
197
+ sample_inputs = {"prot1":protein1,
198
+ "prot2":protein2
199
+ }
200
+ sample_dict = self.crate_sample_dict(sample_inputs=sample_inputs, model_holder=model_holder)
201
+ prompt = sample_dict[ENCODER_INPUTS_STR]
202
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
203
  res = prompt, *self.decode_output(batch_dict,model_holder=model_holder)
204
  return res
205
 
206
 
207
+ def create_demo(self,model_name_widget:gr.component):
208
 
209
  # """
210
  # ### Using the model from
 
229
  value=self.examples["protein_calcineurin"],
230
  )
231
  with gr.Row():
232
+ run_mammal: gr.Button = gr.Button(
233
  "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
234
  )
235
  with gr.Row():
 
239
  decoded = gr.Textbox(label="Mammal output")
240
  run_mammal.click(
241
  fn=self.create_and_run_prompt,
242
+ inputs=[model_name_widget, prot1, prot2],
243
  outputs=[prompt_box, decoded, gr.Number(label="PPI score")],
244
  )
245
  with gr.Row():
246
  gr.Markdown(
247
  "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
248
  )
249
+ demo.visible = False
250
  return demo
251
 
252
  ppi_task = PpiTask()
253
  all_tasks[ppi_task.name]=ppi_task
254
 
 
255
 
256
+ class DtiTask(MammalTask):
257
+ def __init__(self):
258
+ super().__init__(name="Drug-Target Binding Affinity")
259
+ self.description = "Drug-Target Binding Affinity (tdi)"
260
+ self.examples = {
261
+ "target_seq": "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
262
+ "drug_seq":"CC(=O)NCCC1=CNc2c1cc(OC)cc2"
263
+ }
264
+ self.markup_text = """
265
+ # Mammal based Target-Drug binding affinity demonstration
266
+
267
+ Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
268
+ """
269
+
270
+ def crate_sample_dict(self, sample_inputs:dict, model_holder:MammalObjectBroker):
271
+ """convert sample_inputs to sample_dict including creating a proper prompt
272
+
273
+ Args:
274
+ sample_inputs (dict): dictionary containing the inputs to the model
275
+ model_holder (MammalObjectBroker): model holder
276
+ Returns:
277
+ dict: sample_dict for feeding into model
278
+ """
279
+ sample_dict = dict(sample_inputs)
280
+ sample_dict = DtiBindingdbKdTask.data_preprocessing(
281
+ sample_dict=sample_dict,
282
+ tokenizer_op=model_holder.tokenizer_op,
283
+ target_sequence_key="target_seq",
284
+ drug_sequence_key="drug_seq",
285
+ norm_y_mean=None,
286
+ norm_y_std=None,
287
+ device=model_holder.model.device,
288
+ )
289
+ return sample_dict
290
+
291
+
292
+ def run_model(self, sample_dict, model: Mammal):
293
+ # Generate Prediction
294
+ batch_dict = model.generate(
295
+ [sample_dict],
296
+ output_scores=True,
297
+ return_dict_in_generate=True,
298
+ max_new_tokens=5,
299
+ )
300
+ return batch_dict
301
+
302
+ def decode_output(self,batch_dict, model_holder):
303
+
304
+ # Get output
305
+ generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
306
+ score = batch_dict["model.out.scores"][0][1][self.positive_token_id(model_holder)].item()
307
+
308
+ return generated_output, score
309
+
310
+
311
+ def create_and_run_prompt(self,model_name,target_seq, drug_seq):
312
+ model_holder = all_models[model_name]
313
+ inputs = {
314
+ "target_seq": target_seq,
315
+ "drug_seq": drug_seq,
316
+ }
317
+ sample_dict = self.crate_sample_dict(sample_inputs=inputs, model_holder=model_holder)
318
+ prompt=sample_dict[ENCODER_INPUTS_STR]
319
+ batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
320
+ res = prompt, *self.decode_output(batch_dict,model_holder=model_holder)
321
+ return res
322
+
323
+
324
+ def create_demo(self,model_name_widget):
325
+
326
+ # """
327
+ # ### Using the model from
328
+
329
+ # ```{model} ```
330
+ # """
331
+ with gr.Group() as demo:
332
+ gr.Markdown(self.markup_text)
333
+ with gr.Row():
334
+ target_textbox = gr.Textbox(
335
+ label="target sequence",
336
+ # info="standard",
337
+ interactive=True,
338
+ lines=3,
339
+ value=self.examples["target_seq"],
340
+ )
341
+ drug_textbox = gr.Textbox(
342
+ label="Drug sequance (in SMILES)",
343
+ # info="standard",
344
+ interactive=True,
345
+ lines=3,
346
+ value=self.examples["drug_seq"],
347
+ )
348
+ with gr.Row():
349
+ run_mammal = gr.Button(
350
+ "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
351
+ )
352
+ with gr.Row():
353
+ prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
354
+
355
+ with gr.Row():
356
+ decoded = gr.Textbox(label="Mammal output")
357
+ run_mammal.click(
358
+ fn=self.create_and_run_prompt,
359
+ inputs=[model_name_widget, target_textbox, drug_textbox],
360
+ outputs=[prompt_box, decoded, gr.Number(label="PPI score")],
361
+ )
362
+ with gr.Row():
363
+ gr.Markdown(
364
+ "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
365
+ )
366
+ demo.visible = False
367
+ return demo
368
+
369
+ tdi_task = DtiTask()
370
+ all_tasks[tdi_task.name]=tdi_task
371
+
372
+ ppi_model = MammalObjectBroker(model_path="ibm/biomed.omics.bl.sm.ma-ted-458m", task_list=[ppi_task.name])
373
  all_models[ppi_model.name]=ppi_model
374
+
375
+ tdi_model = MammalObjectBroker(model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd", task_list=[tdi_task.name])
376
+ all_models[tdi_model.name]=tdi_model
377
 
378
 
379
  def create_application():
380
  def task_change(value):
381
+ visibility = [gr.update(visible=(task==value)) for task in all_tasks.keys()]
382
+ # all_tasks[task].demo().visible =
383
  choices=[model_name for model_name, model in all_models.items() if value in model.tasks]
384
  if choices:
385
+ return (gr.update(choices=choices, value=choices[0]),*visibility)
386
  else:
387
+ return (gr.skip,*visibility)
388
  # return model_name_dropdown
389
 
390
 
391
+ with gr.Blocks() as application:
392
  task_dropdown = gr.Dropdown(choices=["select demo"] + list(all_tasks.keys()))
393
  task_dropdown.interactive = True
394
  model_name_dropdown = gr.Dropdown(choices=[model_name for model_name, model in all_models.items() if task_dropdown.value in model.tasks], interactive=True)
 
395
 
396
 
397
 
398
 
399
 
400
+ ppi_demo = all_tasks[ppi_task.name].demo(model_name_widgit = model_name_dropdown)
401
+ # ppi_demo.visible = True
402
+ dtb_demo = all_tasks[tdi_task.name].demo(model_name_widgit = model_name_dropdown)
403
+
404
+ task_dropdown.change(task_change,inputs=[task_dropdown],outputs=[model_name_dropdown]+[all_tasks[task].demo() for task in all_tasks])
405
+
406
+ # def set_demo_vis(main_text):
407
+ # main_text=main_text
408
+ # print(f"main text is {main_text}")
409
+ # return gr.Group(visible=True)
410
+ # #return gr.Group(visible=(main_text == "PPI"))
411
+ # # , gr.Group( visible=(main_text == "DTI") )
412
 
 
 
 
 
 
 
413
 
414
+ # task_dropdown.change(
415
+ # set_ppi_vis, inputs=task_dropdown, outputs=[ppi_demo]
416
+ # )
417
+ return application
418
 
419
  full_demo=None
420
+
421
  def main():
422
  global full_demo
423
  full_demo = create_application()