Update models/peptiverse_classifiers.py
Browse files
models/peptiverse_classifiers.py
CHANGED
|
@@ -230,7 +230,7 @@ class MotifModelWT(nn.Module):
|
|
| 230 |
return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
|
| 231 |
|
| 232 |
def load_bindevaluator(checkpoint_path, device):
|
| 233 |
-
bindevaluator = BindEvaluatorWT.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
|
| 234 |
bindevaluator.eval()
|
| 235 |
for param in bindevaluator.parameters():
|
| 236 |
param.requires_grad = False
|
|
|
|
| 230 |
return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
|
| 231 |
|
| 232 |
def load_bindevaluator(checkpoint_path, device):
|
| 233 |
+
bindevaluator = BindEvaluatorWT.load_from_checkpoint(checkpoint_path, weights_only=False, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
|
| 234 |
bindevaluator.eval()
|
| 235 |
for param in bindevaluator.parameters():
|
| 236 |
param.requires_grad = False
|