AlienChen commited on
Commit
6409d51
·
verified ·
1 Parent(s): 30fd543

Update moo.py

Browse files
Files changed (1) hide show
  1. moo.py +20 -4
moo.py CHANGED
@@ -1,11 +1,9 @@
1
- import yaml
2
- from tqdm import tqdm
3
  import torch
4
- from torch import nn
5
  from transformers import AutoTokenizer
 
 
6
 
7
  from models.peptide_classifiers import *
8
-
9
  from utils.parsing import parse_guidance_args
10
  args = parse_guidance_args()
11
 
@@ -55,6 +53,19 @@ if 'Motif' in args.objectives or 'Specificity' in args.objectives:
55
  motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=args.motif_penalty)
56
  score_models.append(motif_model)
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  for i in range(args.n_batches):
59
  if source_distribution == "uniform":
60
  x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device) # CHANGE!
@@ -87,6 +98,11 @@ for i in range(args.n_batches):
87
  else:
88
  candidate_scores = s(x_1)
89
 
 
 
 
 
 
90
  if isinstance(candidate_scores, tuple):
91
  for score in candidate_scores:
92
  scores.append(score.item())
 
 
 
1
  import torch
 
2
  from transformers import AutoTokenizer
3
+ from pathlib import Path
4
+ import inspect
5
 
6
  from models.peptide_classifiers import *
 
7
  from utils.parsing import parse_guidance_args
8
  args = parse_guidance_args()
9
 
 
53
  motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=args.motif_penalty)
54
  score_models.append(motif_model)
55
 
56
+ objective_line = str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'
57
+
58
+ if Path(args.output_file).exists():
59
+ with open(args.output_file, 'r') as f:
60
+ lines = f.readlines()
61
+
62
+ if lines[0] != objective_line:
63
+ with open(args.output_file, 'w') as f:
64
+ f.write(objective_line)
65
+ else:
66
+ with open(args.output_file, 'w') as f:
67
+ f.write(objective_line)
68
+
69
  for i in range(args.n_batches):
70
  if source_distribution == "uniform":
71
  x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device) # CHANGE!
 
98
  else:
99
  candidate_scores = s(x_1)
100
 
101
+ if args.objectives[i] == 'Half-Life':
102
+ candidate_scores = 10 ** (candidate_scores * 2)
103
+ if args.objectives[i] == 'Affinity':
104
+ candidate_scores = 10 * candidate_scores
105
+
106
  if isinstance(candidate_scores, tuple):
107
  for score in candidate_scores:
108
  scores.append(score.item())