matanninio commited on
Commit
b93c8a7
·
1 Parent(s): e3cb71b

cleanup and normalization of tasks

Browse files
mammal_demo/demo_framework.py CHANGED
@@ -11,6 +11,8 @@ class MammalObjectBroker:
11
  model_path: str,
12
  name: str | None = None,
13
  task_list: list[str] | None = None,
 
 
14
  ) -> None:
15
  self.model_path = model_path
16
  if name is None:
@@ -22,12 +24,14 @@ class MammalObjectBroker:
22
  self.tasks = task_list
23
  self._model: Mammal | None = None
24
  self._tokenizer_op = None
 
 
25
 
26
  @property
27
  def model(self) -> Mammal:
28
  if self._model is None:
29
  self._model = Mammal.from_pretrained(self.model_path)
30
- self._model.eval()
31
  return self._model
32
 
33
  @property
@@ -36,6 +40,11 @@ class MammalObjectBroker:
36
  self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
37
  return self._tokenizer_op
38
 
 
 
 
 
 
39
 
40
  class MammalTask(ABC):
41
  def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
@@ -44,19 +53,6 @@ class MammalTask(ABC):
44
  self._demo = None
45
  self.model_dict = model_dict
46
 
47
- # @abstractmethod
48
- # def _generate_prompt(self, **kwargs) -> str:
49
- # """Formatting prompt to match pre-training syntax
50
-
51
- # Args:
52
- # prot1 (_type_): _description_
53
- # prot2 (_type_): _description_
54
-
55
- # Raises:
56
- # No: _description_
57
- # """
58
- # raise NotImplementedError()
59
-
60
  @abstractmethod
61
  def crate_sample_dict(
62
  self, sample_inputs: dict, model_holder: MammalObjectBroker
@@ -97,10 +93,39 @@ class MammalTask(ABC):
97
  def decode_output(self, batch_dict, model: Mammal) -> list:
98
  raise NotImplementedError()
99
 
100
- # self._setup()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- # def _setup(self):
103
- # pass
 
 
 
104
 
105
 
106
  class TaskRegistry(dict[str, MammalTask]):
@@ -114,7 +139,9 @@ class TaskRegistry(dict[str, MammalTask]):
114
  class ModelRegistry(dict[str, MammalObjectBroker]):
115
  """just a dictionary with a register models"""
116
 
117
- def register_model(self, model_path, task_list=None, name=None):
 
 
118
  """register a model and return the name of the model
119
  Args:
120
  model_path (_type_): _description_
@@ -124,7 +151,10 @@ class ModelRegistry(dict[str, MammalObjectBroker]):
124
  str: model name
125
  """
126
  model_holder = MammalObjectBroker(
127
- model_path=model_path, task_list=task_list, name=name
 
 
 
128
  )
129
  self[model_holder.name] = model_holder
130
  return model_holder.name
 
11
  model_path: str,
12
  name: str | None = None,
13
  task_list: list[str] | None = None,
14
+ *,
15
+ force_preload=False,
16
  ) -> None:
17
  self.model_path = model_path
18
  if name is None:
 
24
  self.tasks = task_list
25
  self._model: Mammal | None = None
26
  self._tokenizer_op = None
27
+ if force_preload:
28
+ self.force_preload()
29
 
30
  @property
31
  def model(self) -> Mammal:
32
  if self._model is None:
33
  self._model = Mammal.from_pretrained(self.model_path)
34
+ self._model.eval()
35
  return self._model
36
 
37
  @property
 
40
  self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
41
  return self._tokenizer_op
42
 
43
+ def force_preload(self):
44
+ """pre-load the model and tokenizer (in this order)"""
45
+ _ = self.model
46
+ _ = self.tokenizer_op
47
+
48
 
49
  class MammalTask(ABC):
50
  def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
 
53
  self._demo = None
54
  self.model_dict = model_dict
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @abstractmethod
57
  def crate_sample_dict(
58
  self, sample_inputs: dict, model_holder: MammalObjectBroker
 
93
  def decode_output(self, batch_dict, model: Mammal) -> list:
94
  raise NotImplementedError()
95
 
96
+ # classification helpers
97
+ @staticmethod
98
+ def positive_token_id(tokenizer_op: ModularTokenizerOp) -> int:
99
+ """token for positive binding
100
+
101
+ Args:
102
+ model (MammalTrainedModel): model holding tokenizer
103
+
104
+ Returns:
105
+ int: id of positive binding token
106
+ """
107
+ return tokenizer_op.get_token_id("<1>")
108
+
109
+ @staticmethod
110
+ def negative_token_id(tokenizer_op: ModularTokenizerOp) -> int:
111
+ """token for negative binding
112
+
113
+ Args:
114
+ model (MammalTrainedModel): model holding tokenizer
115
+
116
+ Returns:
117
+ int: id of negative binding token
118
+ """
119
+ return tokenizer_op.get_token_id("<0>")
120
+
121
+ @staticmethod
122
+ def get_label_from_token(tokenizer_op: ModularTokenizerOp, token_id):
123
 
124
+ label_mapping = {
125
+ MammalTask.negative_token_id(tokenizer_op): "negative",
126
+ MammalTask.positive_token_id(tokenizer_op): "positive",
127
+ }
128
+ return label_mapping.get(token_id, token_id)
129
 
130
 
131
  class TaskRegistry(dict[str, MammalTask]):
 
139
  class ModelRegistry(dict[str, MammalObjectBroker]):
140
  """just a dictionary with a register models"""
141
 
142
+ def register_model(
143
+ self, model_path, task_list=None, name=None, *, force_preload=False
144
+ ):
145
  """register a model and return the name of the model
146
  Args:
147
  model_path (_type_): _description_
 
151
  str: model name
152
  """
153
  model_holder = MammalObjectBroker(
154
+ model_path=model_path,
155
+ task_list=task_list,
156
+ name=name,
157
+ force_preload=force_preload,
158
  )
159
  self[model_holder.name] = model_holder
160
  return model_holder.name
mammal_demo/ppi_task.py CHANGED
@@ -1,10 +1,12 @@
1
  import gradio as gr
2
  import torch
 
3
  from mammal.keys import (
4
  CLS_PRED,
5
  ENCODER_INPUTS_ATTENTION_MASK,
6
  ENCODER_INPUTS_STR,
7
  ENCODER_INPUTS_TOKENS,
 
8
  )
9
  from mammal.model import Mammal
10
 
@@ -24,24 +26,12 @@ class PpiTask(MammalTask):
24
 
25
  Given two protein sequences, estimate if the proteins interact or not."""
26
 
27
- @staticmethod
28
- def positive_token_id(model_holder: MammalObjectBroker):
29
- """token for positive binding
30
-
31
- Args:
32
- model (MammalTrainedModel): model holding tokenizer
33
-
34
- Returns:
35
- int: id of positive binding token
36
- """
37
- return model_holder.tokenizer_op.get_token_id("<1>")
38
-
39
- def generate_prompt(self, prot1, prot2):
40
  """Formatting prompt to match pre-training syntax
41
 
42
  Args:
43
- prot1 (str): sequance of protein number 1
44
- prot2 (str): sequance of protein number 2
45
 
46
  Returns:
47
  str: prompt
@@ -49,9 +39,9 @@ class PpiTask(MammalTask):
49
  prompt = (
50
  "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
51
  + "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
52
- + f"<SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END>"
53
  + "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
54
- + f"<SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
55
  )
56
  return prompt
57
 
@@ -74,6 +64,7 @@ class PpiTask(MammalTask):
74
  sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
75
  sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
76
  )
 
77
  return sample_dict
78
 
79
  def run_model(self, sample_dict, model: Mammal):
@@ -86,27 +77,26 @@ class PpiTask(MammalTask):
86
  )
87
  return batch_dict
88
 
89
- def decode_output(self, batch_dict, model_holder: MammalObjectBroker):
90
 
91
  # Get output
92
- generated_output = model_holder.tokenizer_op._tokenizer.decode(
93
- batch_dict[CLS_PRED][0]
94
- )
95
- score = batch_dict["model.out.scores"][0][1][
96
- self.positive_token_id(model_holder)
97
- ].item()
98
 
99
- return generated_output, score
 
100
 
101
- def create_and_run_prompt(self, model_name, protein1, protein2):
102
  model_holder = self.model_dict[model_name]
103
- sample_inputs = {"prot1": protein1, "prot2": protein2}
104
  sample_dict = self.crate_sample_dict(
105
  sample_inputs=sample_inputs, model_holder=model_holder
106
  )
107
  prompt = sample_dict[ENCODER_INPUTS_STR]
108
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
109
- res = prompt, *self.decode_output(batch_dict, model_holder=model_holder)
 
 
110
  return res
111
 
112
  def create_demo(self, model_name_widget: gr.component):
@@ -119,14 +109,14 @@ class PpiTask(MammalTask):
119
  with gr.Group() as demo:
120
  gr.Markdown(self.markup_text)
121
  with gr.Row():
122
- prot1 = gr.Textbox(
123
  label="Protein 1 sequence",
124
  # info="standard",
125
  interactive=True,
126
  lines=3,
127
  value=self.examples["protein_calmodulin"],
128
  )
129
- prot2 = gr.Textbox(
130
  label="Protein 2 sequence",
131
  # info="standard",
132
  interactive=True,
@@ -145,7 +135,7 @@ class PpiTask(MammalTask):
145
  score_box = gr.Number(label="PPI score")
146
  run_mammal.click(
147
  fn=self.create_and_run_prompt,
148
- inputs=[model_name_widget, prot1, prot2],
149
  outputs=[prompt_box, decoded, score_box],
150
  )
151
  with gr.Row():
 
1
  import gradio as gr
2
  import torch
3
+ from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
4
  from mammal.keys import (
5
  CLS_PRED,
6
  ENCODER_INPUTS_ATTENTION_MASK,
7
  ENCODER_INPUTS_STR,
8
  ENCODER_INPUTS_TOKENS,
9
+ SCORES,
10
  )
11
  from mammal.model import Mammal
12
 
 
26
 
27
  Given two protein sequences, estimate if the proteins interact or not."""
28
 
29
+ def generate_prompt(self, protein_seq_1, protein_seq_2):
 
 
 
 
 
 
 
 
 
 
 
 
30
  """Formatting prompt to match pre-training syntax
31
 
32
  Args:
33
+ protein_seq_1 (str): sequance of protein number 1
34
+ protein_seq_2 (str): sequance of protein number 2
35
 
36
  Returns:
37
  str: prompt
 
39
  prompt = (
40
  "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
41
  + "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
42
+ + f"<SEQUENCE_NATURAL_START>{protein_seq_1}<SEQUENCE_NATURAL_END>"
43
  + "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
44
+ + f"<SEQUENCE_NATURAL_START>{protein_seq_2}<SEQUENCE_NATURAL_END><EOS>"
45
  )
46
  return prompt
47
 
 
64
  sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
65
  sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
66
  )
67
+
68
  return sample_dict
69
 
70
  def run_model(self, sample_dict, model: Mammal):
 
77
  )
78
  return batch_dict
79
 
80
+ def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
81
 
82
  # Get output
83
+ generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
84
+ score = batch_dict[SCORES][0][1][self.positive_token_id(tokenizer_op)].item()
 
 
 
 
85
 
86
+ ans = [generated_output, score]
87
+ return ans
88
 
89
+ def create_and_run_prompt(self, model_name, protein_seq_1, protein_seq_2):
90
  model_holder = self.model_dict[model_name]
91
+ sample_inputs = {"protein_seq_1": protein_seq_1, "protein_seq_2": protein_seq_2}
92
  sample_dict = self.crate_sample_dict(
93
  sample_inputs=sample_inputs, model_holder=model_holder
94
  )
95
  prompt = sample_dict[ENCODER_INPUTS_STR]
96
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
97
+ res = prompt, *self.decode_output(
98
+ batch_dict, tokenizer_op=model_holder.tokenizer_op
99
+ )
100
  return res
101
 
102
  def create_demo(self, model_name_widget: gr.component):
 
109
  with gr.Group() as demo:
110
  gr.Markdown(self.markup_text)
111
  with gr.Row():
112
+ protein_seq_1 = gr.Textbox(
113
  label="Protein 1 sequence",
114
  # info="standard",
115
  interactive=True,
116
  lines=3,
117
  value=self.examples["protein_calmodulin"],
118
  )
119
+ protein_seq_2 = gr.Textbox(
120
  label="Protein 2 sequence",
121
  # info="standard",
122
  interactive=True,
 
135
  score_box = gr.Number(label="PPI score")
136
  run_mammal.click(
137
  fn=self.create_and_run_prompt,
138
+ inputs=[model_name_widget, protein_seq_1, protein_seq_2],
139
  outputs=[prompt_box, decoded, score_box],
140
  )
141
  with gr.Row():
mammal_demo/ps_task.py CHANGED
@@ -10,6 +10,9 @@ from mammal.model import Mammal
10
 
11
  from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
12
 
 
 
 
13
 
14
  class PsTask(MammalTask):
15
  def __init__(self, model_dict):
@@ -34,7 +37,7 @@ Given the protein sequence, estimate if it's water-soluble.
34
  dict: sample_dict for feeding into model
35
  """
36
  sample_dict = dict(sample_inputs) # shallow copy
37
- sample_dict = ProteinSolubilityTask.data_preprocessing(
38
  sample_dict=sample_dict,
39
  protein_sequence_key="protein_seq",
40
  tokenizer_op=model_holder.tokenizer_op,
@@ -57,7 +60,7 @@ Given the protein sequence, estimate if it's water-soluble.
57
  """
58
  Extract predicted class and scores
59
  """
60
- ans_dict = ProteinSolubilityTask.process_model_output(
61
  tokenizer_op=tokenizer_op,
62
  decoder_output=batch_dict[CLS_PRED][0],
63
  decoder_output_scores=batch_dict[SCORES][0],
@@ -72,11 +75,11 @@ Given the protein sequence, estimate if it's water-soluble.
72
 
73
  def create_and_run_prompt(self, model_name, protein_seq):
74
  model_holder = self.model_dict[model_name]
75
- inputs = {
76
  "protein_seq": protein_seq,
77
  }
78
  sample_dict = self.crate_sample_dict(
79
- sample_inputs=inputs, model_holder=model_holder
80
  )
81
  prompt = sample_dict[ENCODER_INPUTS_STR]
82
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
 
10
 
11
  from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
12
 
13
+ data_preprocessing = ProteinSolubilityTask.data_preprocessing
14
+ process_model_output = ProteinSolubilityTask.process_model_output
15
+
16
 
17
  class PsTask(MammalTask):
18
  def __init__(self, model_dict):
 
37
  dict: sample_dict for feeding into model
38
  """
39
  sample_dict = dict(sample_inputs) # shallow copy
40
+ sample_dict = data_preprocessing(
41
  sample_dict=sample_dict,
42
  protein_sequence_key="protein_seq",
43
  tokenizer_op=model_holder.tokenizer_op,
 
60
  """
61
  Extract predicted class and scores
62
  """
63
+ ans_dict = process_model_output(
64
  tokenizer_op=tokenizer_op,
65
  decoder_output=batch_dict[CLS_PRED][0],
66
  decoder_output_scores=batch_dict[SCORES][0],
 
75
 
76
  def create_and_run_prompt(self, model_name, protein_seq):
77
  model_holder = self.model_dict[model_name]
78
+ sample_inputs = {
79
  "protein_seq": protein_seq,
80
  }
81
  sample_dict = self.crate_sample_dict(
82
+ sample_inputs=sample_inputs, model_holder=model_holder
83
  )
84
  prompt = sample_dict[ENCODER_INPUTS_STR]
85
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
mammal_demo/tcr_task.py CHANGED
@@ -29,7 +29,7 @@ class TcrTask(MammalTask):
29
  Given a TCR beta chain and epitope amino acid sequences, estimate the binding affinity score.
30
  """
31
 
32
- def create_prompt(self, tcr_beta_seq, epitope_seq):
33
  prompt = (
34
  "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
35
  + f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>{tcr_beta_seq}<SEQUENCE_NATURAL_END>"
@@ -48,20 +48,21 @@ Given a TCR beta chain and epitope amino acid sequences, estimate the binding af
48
  dict: sample_dict for feeding into model
49
  """
50
  sample_dict = dict()
51
- sample_dict[ENCODER_INPUTS_STR] = self.create_prompt(**sample_inputs)
52
- tokenizer_op = model_holder.tokenizer_op
53
- model = model_holder.model
54
- tokenizer_op(
 
55
  sample_dict=sample_dict,
56
  key_in=ENCODER_INPUTS_STR,
57
  key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
58
  key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
59
  )
60
  sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
61
- sample_dict[ENCODER_INPUTS_TOKENS], device=model.device
62
  )
63
  sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
64
- sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=model.device
65
  )
66
 
67
  return sample_dict
@@ -76,47 +77,26 @@ Given a TCR beta chain and epitope amino acid sequences, estimate the binding af
76
  )
77
  return batch_dict
78
 
79
- @staticmethod
80
- def positive_token_id(tokenizer_op: ModularTokenizerOp):
81
- """token for positive binding
82
-
83
- Args:
84
- model (MammalTrainedModel): model holding tokenizer
85
-
86
- Returns:
87
- int: id of positive binding token
88
- """
89
- return tokenizer_op.get_token_id("<1>")
90
-
91
- @staticmethod
92
- def negative_token_id(tokenizer_op: ModularTokenizerOp):
93
- """token for negative binding
94
-
95
- Args:
96
- model (MammalTrainedModel): model holding tokenizer
97
-
98
- Returns:
99
- int: id of negative binding token
100
- """
101
- return tokenizer_op.get_token_id("<0>")
102
-
103
  def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
104
  """
105
  Extract predicted class and scores
106
  """
107
 
108
- # positive_token_id = self.positive_token_id(tokenizer_op)
109
- # negative_token_id = self.negative_token_id(tokenizer_op)
110
 
111
- negative_token_id = tokenizer_op.get_token_id("<0>")
112
- positive_token_id = tokenizer_op.get_token_id("<1>")
113
 
114
  label_id_to_int = {
115
- negative_token_id: 0,
116
- positive_token_id: 1,
117
  }
118
  classification_position = 1
119
 
 
 
 
120
  decoder_output = batch_dict[CLS_PRED][0]
121
  decoder_output_scores = batch_dict[SCORES][0]
122
 
@@ -126,7 +106,7 @@ Given a TCR beta chain and epitope amino acid sequences, estimate the binding af
126
  scores = [None]
127
 
128
  ans = [
129
- tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]),
130
  label_id_to_int.get(int(decoder_output[classification_position]), -1),
131
  scores.item(),
132
  ]
 
29
  Given a TCR beta chain and epitope amino acid sequences, estimate the binding affinity score.
30
  """
31
 
32
+ def generate_prompt(self, tcr_beta_seq, epitope_seq):
33
  prompt = (
34
  "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
35
  + f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>{tcr_beta_seq}<SEQUENCE_NATURAL_END>"
 
48
  dict: sample_dict for feeding into model
49
  """
50
  sample_dict = dict()
51
+ prompt = self.generate_prompt(**sample_inputs)
52
+ sample_dict[ENCODER_INPUTS_STR] = prompt
53
+
54
+ # Tokenize
55
+ sample_dict = model_holder.tokenizer_op(
56
  sample_dict=sample_dict,
57
  key_in=ENCODER_INPUTS_STR,
58
  key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
59
  key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
60
  )
61
  sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
62
+ sample_dict[ENCODER_INPUTS_TOKENS]
63
  )
64
  sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
65
+ sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
66
  )
67
 
68
  return sample_dict
 
77
  )
78
  return batch_dict
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
81
  """
82
  Extract predicted class and scores
83
  """
84
 
85
+ positive_token_id = self.positive_token_id(tokenizer_op)
86
+ negative_token_id = self.negative_token_id(tokenizer_op)
87
 
88
+ # negative_token_id = tokenizer_op.get_token_id("<0>")
89
+ # positive_token_id = tokenizer_op.get_token_id("<1>")
90
 
91
  label_id_to_int = {
92
+ negative_token_id: "negative",
93
+ positive_token_id: "positive",
94
  }
95
  classification_position = 1
96
 
97
+ # Get output
98
+ generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
99
+
100
  decoder_output = batch_dict[CLS_PRED][0]
101
  decoder_output_scores = batch_dict[SCORES][0]
102
 
 
106
  scores = [None]
107
 
108
  ans = [
109
+ generated_output,
110
  label_id_to_int.get(int(decoder_output[classification_position]), -1),
111
  scores.item(),
112
  ]