matanninio committed on
Commit
4fb0503
1 Parent(s): f98cc68

save snapshot

Browse files
Files changed (1) hide show
  1. mammal_demo/tcr_task.py +196 -0
mammal_demo/tcr_task.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
4
+ from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
5
+ from mammal.keys import (
6
+ ENCODER_INPUTS_STR,
7
+ ENCODER_INPUTS_TOKENS,
8
+ ENCODER_INPUTS_ATTENTION_MASK,
9
+ CLS_PRED,
10
+ SCORES,
11
+ )
12
+ from mammal.model import Mammal
13
+
14
+ from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
15
+
16
+
17
class TcrTask(MammalTask):
    """Gradio demo task: T-cell receptor (TCR) beta / epitope binding specificity.

    Builds the modular-tokenizer prompt for a TCR-beta + epitope pair, runs the
    Mammal model in generation mode, and decodes the binary binding prediction
    (plus a positive-class score) for display in the demo UI.
    """

    def __init__(self, model_dict):
        """
        Args:
            model_dict: mapping of model name -> MammalObjectBroker, consumed
                by the MammalTask base class and by create_and_run_prompt.
        """
        super().__init__(
            name="T-cell receptors-peptide binding specificity",
            model_dict=model_dict,
        )
        self.description = "T-cell receptors-peptide binding specificity (TCR)"
        # Default example inputs pre-filled into the demo textboxes.
        self.examples = {
            "tcr_beta_seq": "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT",
            "epitope_seq": "LLQTGIHVRVSQPSL",
        }
        # Fixed typos: "sequance"/"sequacne" -> "sequence".
        self.markup_text = """
# Mammal based T-cell receptors-peptide binding specificity demonstration

Given the TCR beta sequence and the epitope sequence, estimate the binding specificity.
"""

    def create_prompt(self, tcr_beta_seq, epitope_seq):
        """Build the modular-tokenizer prompt string for one TCR/epitope pair.

        Args:
            tcr_beta_seq: amino-acid sequence of the TCR beta chain.
            epitope_seq: amino-acid sequence of the epitope.

        Returns:
            str: prompt with tokenizer-type markers and entity tags, ending in <EOS>.
        """
        prompt = (
            "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
            f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>{tcr_beta_seq}<SEQUENCE_NATURAL_END>"
            f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_EPITOPE><SEQUENCE_NATURAL_START>{epitope_seq}<SEQUENCE_NATURAL_END><EOS>"
        )
        return prompt

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Convert sample_inputs to a sample_dict, including creating a proper prompt.

        Note: the name "crate_sample_dict" (sic) is kept for interface
        compatibility with the rest of the demo framework.

        Args:
            sample_inputs (dict): inputs to the model, keyed by create_prompt's
                parameter names ("tcr_beta_seq", "epitope_seq").
            model_holder (MammalObjectBroker): holder exposing .model and .tokenizer_op.

        Returns:
            dict: sample_dict ready for feeding into the model.
        """
        sample_dict = dict()
        # BUGFIX: use ** to forward the dict values as keyword arguments.
        # The previous "*sample_inputs" unpacked the dict *keys*, so the prompt
        # was built from the literal strings "tcr_beta_seq"/"epitope_seq".
        sample_dict[ENCODER_INPUTS_STR] = self.create_prompt(**sample_inputs)
        tokenizer_op = model_holder.tokenizer_op
        model = model_holder.model
        # Tokenize the prompt string into token ids + attention mask in-place.
        tokenizer_op(
            sample_dict=sample_dict,
            key_in=ENCODER_INPUTS_STR,
            key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
            key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        )
        # Move tokenized inputs to the model's device as tensors.
        sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
            sample_dict[ENCODER_INPUTS_TOKENS], device=model.device
        )
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
            sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=model.device
        )
        return sample_dict

    def run_model(self, sample_dict, model: Mammal):
        """Run generation on a single sample and return the batch dict.

        output_scores=True is required so decode_output can read SCORES.
        """
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict

    @staticmethod
    def positive_token_id(tokenizer_op: ModularTokenizerOp):
        """Token id for positive binding.

        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer operator.

        Returns:
            int: id of the positive-binding token "<1>".
        """
        return tokenizer_op.get_token_id("<1>")

    @staticmethod
    def negative_token_id(tokenizer_op: ModularTokenizerOp):
        """Token id for negative binding.

        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer operator.

        Returns:
            int: id of the negative-binding token "<0>".
        """
        return tokenizer_op.get_token_id("<0>")

    def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> dict:
        """Extract the predicted class and its score from the generation output.

        Args:
            batch_dict: output of run_model (must contain CLS_PRED and SCORES).
            tokenizer_op (ModularTokenizerOp): tokenizer operator.

        Returns:
            dict: {"pred": 0/1 (-1 if the model emitted an unexpected token),
                   "score": float probability of the positive token, or None
                   when scores are unavailable}.
        """
        # Use the class's own helpers instead of duplicating the token lookup.
        negative_token_id = self.negative_token_id(tokenizer_op)
        positive_token_id = self.positive_token_id(tokenizer_op)

        label_id_to_int = {
            negative_token_id: 0,
            positive_token_id: 1,
        }
        # Position of the classification token in the decoded output
        # (position 0 is the decoder start token).
        classification_position = 1

        decoder_output = batch_dict[CLS_PRED][0]
        decoder_output_scores = batch_dict[SCORES][0]

        if decoder_output_scores is not None:
            score = decoder_output_scores[
                classification_position, positive_token_id
            ].item()
        else:
            # BUGFIX: previous code set scores=[None] and then called
            # .item() on the list, which raised AttributeError.
            score = None

        ans = dict(
            # Map the predicted token id to 0/1; -1 flags an unexpected token.
            pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
            score=score,
        )
        return ans

    def create_and_run_prompt(self, model_name, tcr_beta_seq, epitope_seq):
        """Gradio callback: build the prompt, run the model, decode the result.

        Returns:
            tuple: (prompt string, predicted class, binding score) matching the
            (prompt_box, decoded, binding_score) outputs wired in create_demo.
        """
        model_holder = self.model_dict[model_name]
        inputs = {
            "tcr_beta_seq": tcr_beta_seq,
            "epitope_seq": epitope_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        ans = self.decode_output(batch_dict, tokenizer_op=model_holder.tokenizer_op)
        # BUGFIX: previous code unpacked the dict itself ("*ans"), which yields
        # its *keys* ("pred", "score") — so the UI showed the literal string
        # "pred" and passed the string "score" into a gr.Number output.
        return prompt, ans["pred"], ans["score"]

    def create_demo(self, model_name_widget):
        """Build the (initially hidden) gradio UI group for this task.

        Args:
            model_name_widget: gradio component holding the selected model name;
                passed through as the first input of the click callback.

        Returns:
            gr.Group: the assembled demo group, with .visible = False.
        """
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                tcr_textbox = gr.Textbox(
                    label="T-cell receptor beta sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["tcr_beta_seq"],
                )
                epitope_textbox = gr.Textbox(
                    # Fixed typo: "sequace" -> "sequence".
                    label="Epitope sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["epitope_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for TCL-Epitope Interaction",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal prediction")
                binding_score = gr.Number(label="Binding score")
            run_mammal.click(
                fn=self.create_and_run_prompt,
                inputs=[model_name_widget, tcr_textbox, epitope_textbox],
                outputs=[prompt_box, decoded, binding_score],
            )
        # Hidden by default; the framework toggles visibility per selected task.
        demo.visible = False
        return demo