matanninio committed on
Commit
fda141d
·
1 Parent(s): 49831fb

refactoring to make code more elegant and cleanups

Browse files
app.py CHANGED
@@ -1,57 +1,47 @@
1
  import gradio as gr
2
 
3
- from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
 
 
 
4
  from mammal_demo.dti_task import DtiTask
5
  from mammal_demo.ppi_task import PpiTask
6
- from mammal_demo.tcr_task import TcrTask
7
  from mammal_demo.ps_task import PsTask
 
8
 
9
- all_tasks: dict[str, MammalTask] = dict()
10
- all_models: dict[str, MammalObjectBroker] = dict()
11
-
12
 
13
  # first create the required tasks
14
  # Note that the tasks need access to the models, as the model to use depends on the state of the widget
15
  # we pass the all_models dict and update it when we actually have the models.
16
- ppi_task = PpiTask(model_dict=all_models)
17
- all_tasks[ppi_task.name] = ppi_task
18
-
19
- tdi_task = DtiTask(model_dict=all_models)
20
- all_tasks[tdi_task.name] = tdi_task
21
-
22
- tcr_task = TcrTask(model_dict=all_models)
23
- all_tasks[tcr_task.name] = tcr_task
24
-
25
-
26
- ps_task = PsTask(model_dict=all_models)
27
- all_tasks[ps_task.name] = ps_task
28
 
 
 
 
 
29
 
30
  # create the model holders. hold the model and the tokenizer, lazy download
31
  # note that the list of relevant tasks needs to be stated.
32
- ppi_model = MammalObjectBroker(
33
- model_path="ibm/biomed.omics.bl.sm.ma-ted-458m", task_list=[ppi_task.name,tcr_task.name]
34
- )
35
- all_models[ppi_model.name] = ppi_model
36
-
37
- tdi_model = MammalObjectBroker(
38
  model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
39
- task_list=[tdi_task.name],
40
  )
41
- all_models[tdi_model.name] = tdi_model
42
-
43
- tcr_model = MammalObjectBroker(
44
- model_path= "ibm/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind",
45
- task_list=[tcr_task.name]
46
  )
47
- all_models[tcr_model.name] = tcr_model
48
-
49
- ps_model = MammalObjectBroker(
50
  model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility",
51
- task_list=[ps_task.name]
52
  )
53
- all_models[ps_model.name] = ps_model
54
-
 
 
 
 
 
55
 
56
  def create_application():
57
  def task_change(value):
@@ -62,13 +52,18 @@ def create_application():
62
  if value in model.tasks
63
  ]
64
  if choices:
65
- return (gr.update(choices=choices, value=choices[0], visible=True), *visibility)
 
 
 
66
  else:
67
  return (gr.skip, *visibility)
68
  # return model_name_dropdown
69
 
70
  with gr.Blocks() as application:
71
- task_dropdown = gr.Dropdown(choices=["Select task"] + list(all_tasks.keys()), label="Mammal Task")
 
 
72
  task_dropdown.interactive = True
73
  model_name_dropdown = gr.Dropdown(
74
  choices=[
@@ -85,7 +80,10 @@ def create_application():
85
  task_change,
86
  inputs=[task_dropdown],
87
  outputs=[model_name_dropdown]
88
- + [all_tasks[task].demo(model_name_widgit=model_name_dropdown) for task in all_tasks],
 
 
 
89
  )
90
 
91
  # def set_demo_vis(main_text):
 
1
  import gradio as gr
2
 
3
+ from mammal_demo.demo_framework import (
4
+ ModelRegistry,
5
+ TaskRegistry,
6
+ )
7
  from mammal_demo.dti_task import DtiTask
8
  from mammal_demo.ppi_task import PpiTask
 
9
  from mammal_demo.ps_task import PsTask
10
+ from mammal_demo.tcr_task import TcrTask
11
 
12
+ all_tasks = TaskRegistry()
13
+ all_models = ModelRegistry()
 
14
 
15
  # first create the required tasks
16
  # Note that the tasks need access to the models, as the model to use depends on the state of the widget
17
  # we pass the all_models dict and update it when we actually have the models.
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ ppi_task = all_tasks.register_task(PpiTask(model_dict=all_models))
20
+ tdi_task = all_tasks.register_task(DtiTask(model_dict=all_models))
21
+ tcr_task = all_tasks.register_task(TcrTask(model_dict=all_models))
22
+ ps_task = all_tasks.register_task(PsTask(model_dict=all_models))
23
 
24
  # create the model holders. hold the model and the tokenizer, lazy download
25
  # note that the list of relevant tasks needs to be stated.
26
+ all_models.register_model(
 
 
 
 
 
27
  model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
28
+ task_list=[tdi_task],
29
  )
30
+ all_models.register_model(
31
+ model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind",
32
+ task_list=[tcr_task],
 
 
33
  )
34
+ all_models.register_model(
 
 
35
  model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility",
36
+ task_list=[ps_task],
37
  )
38
+ all_models.register_model(
39
+ model_path="ibm/biomed.omics.bl.sm.ma-ted-458m",
40
+ task_list=[ppi_task, tcr_task],
41
+ )
42
+ all_models.register_model("https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox")
43
+ all_models.register_model("https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda")
44
+ all_models.register_model("https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_bbbp")
45
 
46
  def create_application():
47
  def task_change(value):
 
52
  if value in model.tasks
53
  ]
54
  if choices:
55
+ return (
56
+ gr.update(choices=choices, value=choices[0], visible=True),
57
+ *visibility,
58
+ )
59
  else:
60
  return (gr.skip, *visibility)
61
  # return model_name_dropdown
62
 
63
  with gr.Blocks() as application:
64
+ task_dropdown = gr.Dropdown(
65
+ choices=["Select task"] + list(all_tasks.keys()), label="Mammal Task"
66
+ )
67
  task_dropdown.interactive = True
68
  model_name_dropdown = gr.Dropdown(
69
  choices=[
 
80
  task_change,
81
  inputs=[task_dropdown],
82
  outputs=[model_name_dropdown]
83
+ + [
84
+ all_tasks[task].demo(model_name_widgit=model_name_dropdown)
85
+ for task in all_tasks
86
+ ],
87
  )
88
 
89
  # def set_demo_vis(main_text):
mammal_demo/demo_framework.py CHANGED
@@ -90,15 +90,41 @@ class MammalTask(ABC):
90
 
91
  def demo(self, model_name_widgit: gr.component = None):
92
  if self._demo is None:
93
- model_name_widget: gr.component
94
  self._demo = self.create_demo(model_name_widget=model_name_widgit)
95
  return self._demo
96
 
97
  @abstractmethod
98
- def decode_output(self, batch_dict, model: Mammal):
99
  raise NotImplementedError()
100
 
101
  # self._setup()
102
 
103
  # def _setup(self):
104
  # pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def demo(self, model_name_widgit: gr.component = None):
92
  if self._demo is None:
 
93
  self._demo = self.create_demo(model_name_widget=model_name_widgit)
94
  return self._demo
95
 
96
  @abstractmethod
97
+ def decode_output(self, batch_dict, model: Mammal) -> list:
98
  raise NotImplementedError()
99
 
100
  # self._setup()
101
 
102
  # def _setup(self):
103
  # pass
104
+
105
+
106
+ class TaskRegistry(dict[str, MammalTask]):
107
+ """just a dictionary with a register method"""
108
+
109
+ def register_task(self, task: MammalTask):
110
+ self[task.name] = task
111
+ return task.name
112
+
113
+
114
+ class ModelRegistry(dict[str, MammalObjectBroker]):
115
+ """just a dictionary with a register_model method"""
116
+
117
+ def register_model(self, model_path, task_list=None, name=None):
118
+ """register a model and return the name of the model
119
+ Args:
120
+ model_path (str): path or Hugging Face identifier of the model to register
121
+ name (optional str): explicit name for the model
122
+
123
+ Returns:
124
+ str: model name
125
+ """
126
+ model_holder = MammalObjectBroker(
127
+ model_path=model_path, task_list=task_list, name=name
128
+ )
129
+ self[model_holder.name] = model_holder
130
+ return model_holder.name
mammal_demo/ps_task.py CHANGED
@@ -1,10 +1,9 @@
1
  import gradio as gr
2
- import torch
3
  from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
4
  from mammal.examples.protein_solubility.task import ProteinSolubilityTask
5
  from mammal.keys import (
6
- ENCODER_INPUTS_STR,
7
  CLS_PRED,
 
8
  SCORES,
9
  )
10
  from mammal.model import Mammal
@@ -25,8 +24,6 @@ class PsTask(MammalTask):
25
  Given the protein sequance, estimate if it's soluble or insoluble.
26
  """
27
 
28
-
29
-
30
  def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
31
  """convert sample_inputs to sample_dict including creating a proper prompt
32
 
@@ -36,12 +33,12 @@ Given the protein sequance, estimate if it's soluble or insoluble.
36
  Returns:
37
  dict: sample_dict for feeding into model
38
  """
39
- sample_dict = dict(sample_inputs) # shallow copy
40
  sample_dict = ProteinSolubilityTask.data_preprocessing(
41
- sample_dict=sample_dict,
42
- protein_sequence_key="protein_seq",
43
- tokenizer_op=model_holder.tokenizer_op,
44
- device=model_holder.model.device,
45
  )
46
 
47
  return sample_dict
@@ -56,8 +53,7 @@ Given the protein sequance, estimate if it's soluble or insoluble.
56
  )
57
  return batch_dict
58
 
59
- def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp)-> dict:
60
-
61
  """
62
  Extract predicted class and scores
63
  """
@@ -71,11 +67,9 @@ Given the protein sequance, estimate if it's soluble or insoluble.
71
  ans_dict["pred"],
72
  ans_dict["not_normalized_scores"].item(),
73
  ans_dict["normalized_scores"].item(),
74
- ]
75
  return ans
76
 
77
-
78
-
79
  def create_and_run_prompt(self, model_name, protein_seq):
80
  model_holder = self.model_dict[model_name]
81
  inputs = {
@@ -86,14 +80,13 @@ Given the protein sequance, estimate if it's soluble or insoluble.
86
  )
87
  prompt = sample_dict[ENCODER_INPUTS_STR]
88
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
89
- res = prompt, *self.decode_output(batch_dict, tokenizer_op=model_holder.tokenizer_op)
 
 
90
  return res
91
 
92
-
93
-
94
  def create_demo(self, model_name_widget):
95
 
96
-
97
  with gr.Group() as demo:
98
  gr.Markdown(self.markup_text)
99
  with gr.Row():
@@ -121,7 +114,13 @@ Given the protein sequance, estimate if it's soluble or insoluble.
121
  run_mammal.click(
122
  fn=self.create_and_run_prompt,
123
  inputs=[model_name_widget, protein_textbox],
124
- outputs=[prompt_box, decoded, predicted_class,non_norm_score,norm_score],
 
 
 
 
 
 
125
  )
126
  demo.visible = False
127
  return demo
 
1
  import gradio as gr
 
2
  from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
3
  from mammal.examples.protein_solubility.task import ProteinSolubilityTask
4
  from mammal.keys import (
 
5
  CLS_PRED,
6
+ ENCODER_INPUTS_STR,
7
  SCORES,
8
  )
9
  from mammal.model import Mammal
 
24
  Given the protein sequance, estimate if it's soluble or insoluble.
25
  """
26
 
 
 
27
  def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
28
  """convert sample_inputs to sample_dict including creating a proper prompt
29
 
 
33
  Returns:
34
  dict: sample_dict for feeding into model
35
  """
36
+ sample_dict = dict(sample_inputs) # shallow copy
37
  sample_dict = ProteinSolubilityTask.data_preprocessing(
38
+ sample_dict=sample_dict,
39
+ protein_sequence_key="protein_seq",
40
+ tokenizer_op=model_holder.tokenizer_op,
41
+ device=model_holder.model.device,
42
  )
43
 
44
  return sample_dict
 
53
  )
54
  return batch_dict
55
 
56
+ def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
 
57
  """
58
  Extract predicted class and scores
59
  """
 
67
  ans_dict["pred"],
68
  ans_dict["not_normalized_scores"].item(),
69
  ans_dict["normalized_scores"].item(),
70
+ ]
71
  return ans
72
 
 
 
73
  def create_and_run_prompt(self, model_name, protein_seq):
74
  model_holder = self.model_dict[model_name]
75
  inputs = {
 
80
  )
81
  prompt = sample_dict[ENCODER_INPUTS_STR]
82
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
83
+ res = prompt, *self.decode_output(
84
+ batch_dict, tokenizer_op=model_holder.tokenizer_op
85
+ )
86
  return res
87
 
 
 
88
  def create_demo(self, model_name_widget):
89
 
 
90
  with gr.Group() as demo:
91
  gr.Markdown(self.markup_text)
92
  with gr.Row():
 
114
  run_mammal.click(
115
  fn=self.create_and_run_prompt,
116
  inputs=[model_name_widget, protein_textbox],
117
+ outputs=[
118
+ prompt_box,
119
+ decoded,
120
+ predicted_class,
121
+ non_norm_score,
122
+ norm_score,
123
+ ],
124
  )
125
  demo.visible = False
126
  return demo
mammal_demo/tcr_task.py CHANGED
@@ -1,12 +1,11 @@
1
  import gradio as gr
2
  import torch
3
  from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
4
- from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
5
  from mammal.keys import (
 
 
6
  ENCODER_INPUTS_STR,
7
  ENCODER_INPUTS_TOKENS,
8
- ENCODER_INPUTS_ATTENTION_MASK,
9
- CLS_PRED,
10
  SCORES,
11
  )
12
  from mammal.model import Mammal
@@ -16,10 +15,12 @@ from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
16
 
17
  class TcrTask(MammalTask):
18
  def __init__(self, model_dict):
19
- super().__init__(name="T-cell receptors-peptide binding specificity", model_dict=model_dict)
 
 
20
  self.description = "T-cell receptors-peptide binding specificity (TCR)"
21
  self.examples = {
22
- "tcr_beta_seq": "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT",
23
  "epitope_seq": "LLQTGIHVRVSQPSL",
24
  }
25
  self.markup_text = """
@@ -28,20 +29,14 @@ class TcrTask(MammalTask):
28
  Given the TCR beta sequance and the epitope sequacne, estimate the binding specificity.
29
  """
30
 
31
-
32
-
33
-
34
- def create_prompt(self,tcr_beta_seq, epitope_seq):
35
  prompt = (
36
- "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"+
37
- f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>{tcr_beta_seq}<SEQUENCE_NATURAL_END>"+
38
- f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_EPITOPE><SEQUENCE_NATURAL_START>{epitope_seq}<SEQUENCE_NATURAL_END><EOS>"
39
  )
40
-
41
- return prompt
42
-
43
-
44
 
 
45
 
46
  def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
47
  """convert sample_inputs to sample_dict including creating a proper prompt
@@ -52,15 +47,15 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
52
  Returns:
53
  dict: sample_dict for feeding into model
54
  """
55
- sample_dict= dict()
56
  sample_dict[ENCODER_INPUTS_STR] = self.create_prompt(**sample_inputs)
57
  tokenizer_op = model_holder.tokenizer_op
58
  model = model_holder.model
59
  tokenizer_op(
60
- sample_dict=sample_dict,
61
- key_in=ENCODER_INPUTS_STR,
62
- key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
63
- key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
64
  )
65
  sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
66
  sample_dict[ENCODER_INPUTS_TOKENS], device=model.device
@@ -92,7 +87,7 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
92
  int: id of positive binding token
93
  """
94
  return tokenizer_op.get_token_id("<1>")
95
-
96
  @staticmethod
97
  def negative_token_id(tokenizer_op: ModularTokenizerOp):
98
  """token for negative binding
@@ -105,15 +100,14 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
105
  """
106
  return tokenizer_op.get_token_id("<0>")
107
 
108
- def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp)-> dict:
109
-
110
  """
111
  Extract predicted class and scores
112
  """
113
-
114
  # positive_token_id = self.positive_token_id(tokenizer_op)
115
  # negative_token_id = self.negative_token_id(tokenizer_op)
116
-
117
  negative_token_id = tokenizer_op.get_token_id("<0>")
118
  positive_token_id = tokenizer_op.get_token_id("<1>")
119
 
@@ -123,14 +117,13 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
123
  }
124
  classification_position = 1
125
 
126
- decoder_output=batch_dict[CLS_PRED][0]
127
- decoder_output_scores=batch_dict[SCORES][0]
128
-
129
 
130
  if decoder_output_scores is not None:
131
- scores = decoder_output_scores[classification_position,positive_token_id]
132
  else:
133
- scores=[None]
134
 
135
  ans = [
136
  tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]),
@@ -139,8 +132,6 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
139
  ]
140
  return ans
141
 
142
-
143
-
144
  def create_and_run_prompt(self, model_name, tcr_beta_seq, epitope_seq):
145
  model_holder = self.model_dict[model_name]
146
  inputs = {
@@ -152,14 +143,13 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
152
  )
153
  prompt = sample_dict[ENCODER_INPUTS_STR]
154
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
155
- res = prompt, *self.decode_output(batch_dict, tokenizer_op=model_holder.tokenizer_op)
 
 
156
  return res
157
 
158
-
159
-
160
  def create_demo(self, model_name_widget):
161
 
162
-
163
  with gr.Group() as demo:
164
  gr.Markdown(self.markup_text)
165
  with gr.Row():
@@ -192,7 +182,7 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
192
  run_mammal.click(
193
  fn=self.create_and_run_prompt,
194
  inputs=[model_name_widget, tcr_textbox, epitope_textbox],
195
- outputs=[prompt_box, decoded, predicted_class,binding_score],
196
  )
197
  demo.visible = False
198
  return demo
 
1
  import gradio as gr
2
  import torch
3
  from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
 
4
  from mammal.keys import (
5
+ CLS_PRED,
6
+ ENCODER_INPUTS_ATTENTION_MASK,
7
  ENCODER_INPUTS_STR,
8
  ENCODER_INPUTS_TOKENS,
 
 
9
  SCORES,
10
  )
11
  from mammal.model import Mammal
 
15
 
16
  class TcrTask(MammalTask):
17
  def __init__(self, model_dict):
18
+ super().__init__(
19
+ name="T-cell receptors-peptide binding specificity", model_dict=model_dict
20
+ )
21
  self.description = "T-cell receptors-peptide binding specificity (TCR)"
22
  self.examples = {
23
+ "tcr_beta_seq": "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT",
24
  "epitope_seq": "LLQTGIHVRVSQPSL",
25
  }
26
  self.markup_text = """
 
29
  Given the TCR beta sequance and the epitope sequacne, estimate the binding specificity.
30
  """
31
 
32
+ def create_prompt(self, tcr_beta_seq, epitope_seq):
 
 
 
33
  prompt = (
34
+ "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
35
+ + f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>{tcr_beta_seq}<SEQUENCE_NATURAL_END>"
36
+ + f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_EPITOPE><SEQUENCE_NATURAL_START>{epitope_seq}<SEQUENCE_NATURAL_END><EOS>"
37
  )
 
 
 
 
38
 
39
+ return prompt
40
 
41
  def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
42
  """convert sample_inputs to sample_dict including creating a proper prompt
 
47
  Returns:
48
  dict: sample_dict for feeding into model
49
  """
50
+ sample_dict = dict()
51
  sample_dict[ENCODER_INPUTS_STR] = self.create_prompt(**sample_inputs)
52
  tokenizer_op = model_holder.tokenizer_op
53
  model = model_holder.model
54
  tokenizer_op(
55
+ sample_dict=sample_dict,
56
+ key_in=ENCODER_INPUTS_STR,
57
+ key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
58
+ key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
59
  )
60
  sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
61
  sample_dict[ENCODER_INPUTS_TOKENS], device=model.device
 
87
  int: id of positive binding token
88
  """
89
  return tokenizer_op.get_token_id("<1>")
90
+
91
  @staticmethod
92
  def negative_token_id(tokenizer_op: ModularTokenizerOp):
93
  """token for negative binding
 
100
  """
101
  return tokenizer_op.get_token_id("<0>")
102
 
103
+ def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
 
104
  """
105
  Extract predicted class and scores
106
  """
107
+
108
  # positive_token_id = self.positive_token_id(tokenizer_op)
109
  # negative_token_id = self.negative_token_id(tokenizer_op)
110
+
111
  negative_token_id = tokenizer_op.get_token_id("<0>")
112
  positive_token_id = tokenizer_op.get_token_id("<1>")
113
 
 
117
  }
118
  classification_position = 1
119
 
120
+ decoder_output = batch_dict[CLS_PRED][0]
121
+ decoder_output_scores = batch_dict[SCORES][0]
 
122
 
123
  if decoder_output_scores is not None:
124
+ scores = decoder_output_scores[classification_position, positive_token_id]
125
  else:
126
+ scores = [None]
127
 
128
  ans = [
129
  tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]),
 
132
  ]
133
  return ans
134
 
 
 
135
  def create_and_run_prompt(self, model_name, tcr_beta_seq, epitope_seq):
136
  model_holder = self.model_dict[model_name]
137
  inputs = {
 
143
  )
144
  prompt = sample_dict[ENCODER_INPUTS_STR]
145
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
146
+ res = prompt, *self.decode_output(
147
+ batch_dict, tokenizer_op=model_holder.tokenizer_op
148
+ )
149
  return res
150
 
 
 
151
  def create_demo(self, model_name_widget):
152
 
 
153
  with gr.Group() as demo:
154
  gr.Markdown(self.markup_text)
155
  with gr.Row():
 
182
  run_mammal.click(
183
  fn=self.create_and_run_prompt,
184
  inputs=[model_name_widget, tcr_textbox, epitope_textbox],
185
+ outputs=[prompt_box, decoded, predicted_class, binding_score],
186
  )
187
  demo.visible = False
188
  return demo