import gradio as gr
import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.keys import (
CLS_PRED,
ENCODER_INPUTS_ATTENTION_MASK,
ENCODER_INPUTS_STR,
ENCODER_INPUTS_TOKENS,
SCORES,
)
from mammal.model import Mammal
from mammal_demo.demo_framework import MammalObjectBroker, MammalTask


class TcrTask(MammalTask):
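    """Demo task: predict whether a TCR beta chain binds a given epitope."""
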
def __init__(self, model_dict):
        super().__init__(
            name="T-cell receptor-peptide binding specificity", model_dict=model_dict
        )
        self.description = "T-cell receptor-peptide binding specificity (TCR)"
self.examples = {
"tcr_beta_seq": "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT",
"epitope_seq": "LLQTGIHVRVSQPSL",
}
self.markup_text = """
# Mammal based T-cell receptors-peptide binding specificity demonstration
Given the TCR beta sequance and the epitope sequacne, estimate the binding specificity.
"""
def create_prompt(self, tcr_beta_seq, epitope_seq):
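        """Build the encoder prompt for the TCR-epitope binding task.

        The prompt interleaves tokenizer-type directives (<@TOKENIZER-TYPE=AA>)
        with special tokens: a binding-affinity classification header, then the
        TCR beta chain and the epitope, each wrapped in molecular-entity and
        sequence-boundary markers.
        """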
prompt = (
"<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
+ f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_TCR_BETA_VDJ><SEQUENCE_NATURAL_START>{tcr_beta_seq}<SEQUENCE_NATURAL_END>"
+ f"<@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_EPITOPE><SEQUENCE_NATURAL_START>{epitope_seq}<SEQUENCE_NATURAL_END><EOS>"
)
        return prompt

def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
"""convert sample_inputs to sample_dict including creating a proper prompt
Args:
sample_inputs (dict): dictionary containing the inputs to the model
model_holder (MammalObjectBroker): model holder
Returns:
dict: sample_dict for feeding into model
"""
sample_dict = dict()
sample_dict[ENCODER_INPUTS_STR] = self.create_prompt(**sample_inputs)
tokenizer_op = model_holder.tokenizer_op
model = model_holder.model
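        # Tokenize the prompt in place: writes token ids and an attention mask
        # back into sample_dict under the corresponding keys.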
tokenizer_op(
sample_dict=sample_dict,
key_in=ENCODER_INPUTS_STR,
key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
)
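        # Convert the tokenizer's outputs to tensors on the model's device.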
sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
sample_dict[ENCODER_INPUTS_TOKENS], device=model.device
)
sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=model.device
)
        return sample_dict

def run_model(self, sample_dict, model: Mammal):
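        """Generate a prediction for a single sample.

        A small max_new_tokens suffices here: the decoder only needs to emit
        the binding-class token (<0> or <1>) plus surrounding special tokens.
        """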
# Generate Prediction
batch_dict = model.generate(
[sample_dict],
output_scores=True,
return_dict_in_generate=True,
max_new_tokens=5,
)
        return batch_dict

@staticmethod
def positive_token_id(tokenizer_op: ModularTokenizerOp):
"""token for positive binding
Args:
model (MammalTrainedModel): model holding tokenizer
Returns:
int: id of positive binding token
"""
return tokenizer_op.get_token_id("<1>")
@staticmethod
def negative_token_id(tokenizer_op: ModularTokenizerOp):
"""token for negative binding
Args:
model (MammalTrainedModel): model holding tokenizer
Returns:
int: id of negative binding token
"""
return tokenizer_op.get_token_id("<0>")
def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
"""
Extract predicted class and scores
"""
        positive_token_id = self.positive_token_id(tokenizer_op)
        negative_token_id = self.negative_token_id(tokenizer_op)
label_id_to_int = {
negative_token_id: 0,
positive_token_id: 1,
}
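        # The class token is read at position 1 of the decoder output;
        # position 0 is typically the decoder start token.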
classification_position = 1
decoder_output = batch_dict[CLS_PRED][0]
decoder_output_scores = batch_dict[SCORES][0]
        if decoder_output_scores is not None:
            score = decoder_output_scores[
                classification_position, positive_token_id
            ].item()
        else:
            # Scores were not returned by generate(); report no score rather
            # than calling .item() on a non-tensor.
            score = None
        ans = [
            tokenizer_op._tokenizer.decode(decoder_output),
            label_id_to_int.get(int(decoder_output[classification_position]), -1),
            score,
        ]
        return ans

def create_and_run_prompt(self, model_name, tcr_beta_seq, epitope_seq):
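        """Build the prompt from the raw inputs, run the selected model,
        and decode the result into (prompt, decoded output, class, score)."""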
model_holder = self.model_dict[model_name]
inputs = {
"tcr_beta_seq": tcr_beta_seq,
"epitope_seq": epitope_seq,
}
sample_dict = self.crate_sample_dict(
sample_inputs=inputs, model_holder=model_holder
)
prompt = sample_dict[ENCODER_INPUTS_STR]
batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
res = prompt, *self.decode_output(
batch_dict, tokenizer_op=model_holder.tokenizer_op
)
        return res

def create_demo(self, model_name_widget):
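        """Assemble the Gradio UI for this task.

        The group is created hidden (visible=False); the surrounding demo
        framework is expected to toggle visibility when the task is selected.
        """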
with gr.Group() as demo:
gr.Markdown(self.markup_text)
with gr.Row():
                tcr_textbox = gr.Textbox(
                    label="T-cell receptor beta sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["tcr_beta_seq"],
                )
                epitope_textbox = gr.Textbox(
                    label="Epitope sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["epitope_seq"],
                )
with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for TCR-Epitope interaction",
                    variant="primary",
                )
with gr.Row():
prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
with gr.Row():
decoded = gr.Textbox(label="Mammal output")
predicted_class = gr.Textbox(label="Mammal prediction")
binding_score = gr.Number(label="Binding score")
run_mammal.click(
fn=self.create_and_run_prompt,
inputs=[model_name_widget, tcr_textbox, epitope_textbox],
outputs=[prompt_box, decoded, predicted_class, binding_score],
)
demo.visible = False
return demo
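
# A minimal usage sketch (hypothetical wiring: the MammalObjectBroker
# constructor arguments and the "tcr-model" key below are assumptions,
# not part of this file):
#
#   model_dict = {"tcr-model": MammalObjectBroker(...)}  # broker holding model + tokenizer
#   task = TcrTask(model_dict=model_dict)
#   prompt, decoded, predicted_class, score = task.create_and_run_prompt(
#       "tcr-model",
#       task.examples["tcr_beta_seq"],
#       task.examples["epitope_seq"],
#   )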