matanninio commited on
Commit
292a922
1 Parent(s): 4fb0503

tcr task works

Browse files
Files changed (1) hide show
  1. mammal_demo/tcr_task.py +9 -7
mammal_demo/tcr_task.py CHANGED
@@ -53,7 +53,7 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
53
  dict: sample_dict for feeding into model
54
  """
55
  sample_dict= dict()
56
- sample_dict[ENCODER_INPUTS_STR] = self.create_prompt(*sample_inputs)
57
  tokenizer_op = model_holder.tokenizer_op
58
  model = model_holder.model
59
  tokenizer_op(
@@ -132,10 +132,11 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
132
  else:
133
  scores=[None]
134
 
135
- ans = dict(
136
- pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
137
- score=scores.item(),
138
- )
 
139
  return ans
140
 
141
 
@@ -185,12 +186,13 @@ Given the TCR beta sequance and the epitope sequacne, estimate the binding speci
185
  prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
186
 
187
  with gr.Row():
188
- decoded = gr.Textbox(label="Mammal prediction")
 
189
  binding_score = gr.Number(label="Binding score")
190
  run_mammal.click(
191
  fn=self.create_and_run_prompt,
192
  inputs=[model_name_widget, tcr_textbox, epitope_textbox],
193
- outputs=[prompt_box, decoded, binding_score],
194
  )
195
  demo.visible = False
196
  return demo
 
53
  dict: sample_dict for feeding into model
54
  """
55
  sample_dict= dict()
56
+ sample_dict[ENCODER_INPUTS_STR] = self.create_prompt(**sample_inputs)
57
  tokenizer_op = model_holder.tokenizer_op
58
  model = model_holder.model
59
  tokenizer_op(
 
132
  else:
133
  scores=[None]
134
 
135
+ ans = [
136
+ tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]),
137
+ label_id_to_int.get(int(decoder_output[classification_position]), -1),
138
+ scores.item(),
139
+ ]
140
  return ans
141
 
142
 
 
186
  prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
187
 
188
  with gr.Row():
189
+ decoded = gr.Textbox(label="Mammal output")
190
+ predicted_class = gr.Textbox(label="Mammal prediction")
191
  binding_score = gr.Number(label="Binding score")
192
  run_mammal.click(
193
  fn=self.create_and_run_prompt,
194
  inputs=[model_name_widget, tcr_textbox, epitope_textbox],
195
+ outputs=[prompt_box, decoded, predicted_class,binding_score],
196
  )
197
  demo.visible = False
198
  return demo