wangjin2000 commited on
Commit
b1c82a3
·
verified ·
1 Parent(s): 1ebd41b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -122,13 +122,16 @@ def predict_bind(base_model_path,PEFT_model_path,input_seq):
122
  predictions = torch.argmax(logits, dim=2)
123
 
124
  binding_site=[]
 
125
  # Print the predicted labels for each token
126
  for token, prediction in zip(tokens, predictions[0].numpy()):
127
  if token not in ['<pad>', '<cls>', '<eos>']:
128
- print((token, id2label[prediction]))
129
- if prediction == 1:
130
- print((token, id2label[prediction]))
131
- binding_site.append([token, id2label[prediction]])
 
 
132
  return binding_site
133
 
134
  # fine-tuning function
 
122
  predictions = torch.argmax(logits, dim=2)
123
 
124
  binding_site=[]
125
+ pos = 0
126
  # Print the predicted labels for each token
127
  for token, prediction in zip(tokens, predictions[0].numpy()):
128
  if token not in ['<pad>', '<cls>', '<eos>']:
129
+ pos++
130
+ print((pos, token, id2label[prediction]))
131
+ if prediction == 1:
132
+ print((pos, token, id2label[prediction]))
133
+ binding_site.append([pos, token, id2label[prediction]])
134
+
135
  return binding_site
136
 
137
  # fine-tuning function