Update moppit.py
Browse files
moppit.py
CHANGED
|
@@ -22,7 +22,7 @@ if args.motifs:
|
|
| 22 |
print(motifs)
|
| 23 |
|
| 24 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 25 |
-
target_sequence = tokenizer(target, return_tensors='pt')
|
| 26 |
|
| 27 |
# Load Models
|
| 28 |
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)
|
|
@@ -35,14 +35,14 @@ if 'Non-Fouling' in args.objectives:
|
|
| 35 |
nonfouling_model = NonfoulingModel(device=device)
|
| 36 |
score_models.append(nonfouling_model)
|
| 37 |
if 'Solubility' in args.objectives:
|
| 38 |
-
solubility_model =
|
| 39 |
score_models.append(solubility_model)
|
| 40 |
if 'Half-Life' in args.objectives:
|
| 41 |
halflife_model = HalfLifeModel(device=device)
|
| 42 |
score_models.append(halflife_model)
|
| 43 |
if 'Affinity' in args.objectives:
|
| 44 |
-
affinity_predictor = load_affinity_predictor('./classifier_ckpt/
|
| 45 |
-
affinity_model = AffinityModel(affinity_predictor, target_sequence)
|
| 46 |
score_models.append(affinity_model)
|
| 47 |
if 'Motif' in args.objectives or 'Specificity' in args.objectives:
|
| 48 |
bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
|
|
@@ -50,7 +50,7 @@ if 'Motif' in args.objectives or 'Specificity' in args.objectives:
|
|
| 50 |
motif_penalty = True
|
| 51 |
else:
|
| 52 |
motif_penalty = False
|
| 53 |
-
motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=motif_penalty)
|
| 54 |
score_models.append(motif_model)
|
| 55 |
|
| 56 |
objective_line = "Binder," + str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'
|
|
|
|
| 22 |
print(motifs)
|
| 23 |
|
| 24 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 25 |
+
target_sequence = tokenizer(target, return_tensors='pt').to(device)
|
| 26 |
|
| 27 |
# Load Models
|
| 28 |
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)
|
|
|
|
| 35 |
nonfouling_model = NonfoulingModel(device=device)
|
| 36 |
score_models.append(nonfouling_model)
|
| 37 |
if 'Solubility' in args.objectives:
|
| 38 |
+
solubility_model = SolubilityModel(device=device)
|
| 39 |
score_models.append(solubility_model)
|
| 40 |
if 'Half-Life' in args.objectives:
|
| 41 |
halflife_model = HalfLifeModel(device=device)
|
| 42 |
score_models.append(halflife_model)
|
| 43 |
if 'Affinity' in args.objectives:
|
| 44 |
+
affinity_predictor = load_affinity_predictor('../classifier_ckpt/wt_affinity.pt', device)
|
| 45 |
+
affinity_model = AffinityModel(affinity_predictor, target_sequence, device)
|
| 46 |
score_models.append(affinity_model)
|
| 47 |
if 'Motif' in args.objectives or 'Specificity' in args.objectives:
|
| 48 |
bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
|
|
|
|
| 50 |
motif_penalty = True
|
| 51 |
else:
|
| 52 |
motif_penalty = False
|
| 53 |
+
motif_model = MotifModel(bindevaluator, target_sequence['input_ids'], motifs, penalty=motif_penalty)
|
| 54 |
score_models.append(motif_model)
|
| 55 |
|
| 56 |
objective_line = "Binder," + str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'
|