AlienChen commited on
Commit
40897d7
·
verified ·
1 Parent(s): d17476e

Update moppit.py

Browse files
Files changed (1) hide show
  1. moppit.py +71 -31
moppit.py CHANGED
@@ -1,9 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from transformers import AutoTokenizer
3
  from pathlib import Path
4
  import inspect
5
 
6
- from models.peptide_classifiers import *
 
7
  from utils.parsing import parse_guidance_args
8
  args = parse_guidance_args()
9
 
@@ -17,9 +46,9 @@ device = 'cuda:0'
17
 
18
  length = args.length
19
  target = args.target_protein
 
20
  if args.motifs:
21
  motifs = parse_motifs(args.motifs).to(device)
22
- print(motifs)
23
 
24
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
25
  target_sequence = tokenizer(target, return_tensors='pt').to(device)
@@ -29,29 +58,30 @@ solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_lo
29
 
30
  score_models = []
31
  if 'Hemolysis' in args.objectives:
32
- hemolysis_model = HemolysisModel(device=device)
33
  score_models.append(hemolysis_model)
34
  if 'Non-Fouling' in args.objectives:
35
- nonfouling_model = NonfoulingModel(device=device)
36
  score_models.append(nonfouling_model)
37
  if 'Solubility' in args.objectives:
38
- solubility_model = SolubilityModel(device=device)
39
  score_models.append(solubility_model)
 
 
 
40
  if 'Half-Life' in args.objectives:
41
- halflife_model = HalfLifeModel(device=device)
42
  score_models.append(halflife_model)
43
  if 'Affinity' in args.objectives:
44
- affinity_predictor = load_affinity_predictor(device)
45
- affinity_model = AffinityModel(affinity_predictor, target_sequence, device)
46
  score_models.append(affinity_model)
47
-
48
- if 'Specificity' in args.objectives:
49
- motif_penalty = True
50
- else:
51
- motif_penalty = False
52
  if 'Motif' in args.objectives or 'Specificity' in args.objectives:
53
  bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
54
- motif_model = MotifModel(bindevaluator, target_sequence['input_ids'], motifs, penalty=motif_penalty)
 
 
 
 
55
  score_models.append(motif_model)
56
 
57
  objective_line = "Binder," + str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'
@@ -68,36 +98,46 @@ else:
68
  f.write(objective_line)
69
 
70
  for i in range(args.n_batches):
71
- if source_distribution == "uniform":
72
- x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device) # CHANGE!
73
- elif source_distribution == "mask":
74
- x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
75
  else:
76
- raise NotImplementedError
 
 
 
 
 
77
 
78
- zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
79
- twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
80
- x_init = torch.cat([zeros, x_init, twos], dim=1)
 
 
 
 
 
 
 
81
 
82
  x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
83
  step_size=step_size,
84
  verbose=True,
85
  time_grid=torch.tensor([0.0, 1.0-1e-3]),
86
  score_models=score_models,
87
- num_objectives=len(score_models) + int(motif_penalty),
88
- weights=args.weights)
89
-
90
- samples = x_1.tolist()
91
- samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples]
92
- print(samples)
93
 
94
  scores = []
 
95
  for i, s in enumerate(score_models):
96
  sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
97
  if 't' in sig.parameters:
98
- candidate_scores = s(x_1, 1)
99
  else:
100
- candidate_scores = s(x_1)
101
 
102
  if args.objectives[i] == 'Affinity':
103
  candidate_scores = 10 * candidate_scores
@@ -110,7 +150,7 @@ for i in range(args.n_batches):
110
  print(scores)
111
 
112
  with open(args.output_file, 'a') as f:
113
- f.write(samples[0])
114
  for score in scores:
115
  f.write(f",{score}")
116
  f.write('\n')
 
1
+ import os
2
+ import warnings
3
+ import logging
4
+
5
+ import os
6
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
+
9
+ warnings.filterwarnings("ignore")
10
+ warnings.filterwarnings("ignore", category=UserWarning)
11
+ warnings.filterwarnings("ignore", category=FutureWarning)
12
+
13
+ from sklearn.exceptions import InconsistentVersionWarning
14
+ warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
15
+
16
+ logging.getLogger().setLevel(logging.ERROR)
17
+ logging.getLogger("lightning").setLevel(logging.ERROR)
18
+ logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
19
+ logging.getLogger("transformers").setLevel(logging.ERROR)
20
+ logging.getLogger("absl").setLevel(logging.ERROR)
21
+
22
+ from transformers import logging as hf_logging
23
+ hf_logging.set_verbosity_error()
24
+ hf_logging.disable_progress_bar()
25
+
26
+ logging.getLogger("lightning.fabric.utilities.seed").setLevel(logging.ERROR)
27
+ logging.getLogger("pytorch_lightning.utilities.seed").setLevel(logging.ERROR)
28
+
29
  import torch
30
  from transformers import AutoTokenizer
31
  from pathlib import Path
32
  import inspect
33
 
34
+ # from models.peptide_classifiers import *
35
+ from models.peptiverse_classifiers import *
36
  from utils.parsing import parse_guidance_args
37
  args = parse_guidance_args()
38
 
 
46
 
47
  length = args.length
48
  target = args.target_protein
49
+
50
  if args.motifs:
51
  motifs = parse_motifs(args.motifs).to(device)
 
52
 
53
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
54
  target_sequence = tokenizer(target, return_tensors='pt').to(device)
 
58
 
59
  score_models = []
60
  if 'Hemolysis' in args.objectives:
61
+ hemolysis_model = HemolysisWT()
62
  score_models.append(hemolysis_model)
63
  if 'Non-Fouling' in args.objectives:
64
+ nonfouling_model = NonfoulingWT()
65
  score_models.append(nonfouling_model)
66
  if 'Solubility' in args.objectives:
67
+ solubility_model = Solubility()
68
  score_models.append(solubility_model)
69
+ if 'Permeability' in args.objectives:
70
+ permeability_model = PermeabilityWT()
71
+ score_models.append(permeability_model)
72
  if 'Half-Life' in args.objectives:
73
+ halflife_model = HalfLifeWT()
74
  score_models.append(halflife_model)
75
  if 'Affinity' in args.objectives:
76
+ affinity_model = AffinityWT(target)
 
77
  score_models.append(affinity_model)
 
 
 
 
 
78
  if 'Motif' in args.objectives or 'Specificity' in args.objectives:
79
  bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
80
+ if 'Specificity' in args.objectives:
81
+ args.specificity = True
82
+ else:
83
+ args.specificity = False
84
+ motif_model = MotifModelWT(bindevaluator, target_sequence['input_ids'], motifs, tokenizer, device, penalty=args.specificity)
85
  score_models.append(motif_model)
86
 
87
  objective_line = "Binder," + str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'
 
98
  f.write(objective_line)
99
 
100
  for i in range(args.n_batches):
101
+ if args.starting_sequence:
102
+ x_init = tokenizer(args.starting_sequence, return_tensors='pt')['input_ids'].to(device)
 
 
103
  else:
104
+ if source_distribution == "uniform":
105
+ x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device) # CHANGE!
106
+ elif source_distribution == "mask":
107
+ x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
108
+ else:
109
+ raise NotImplementedError
110
 
111
+ zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
112
+ twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
113
+ x_init = torch.cat([zeros, x_init, twos], dim=1)
114
+
115
+ if args.fixed_positions is not None:
116
+ fixed_positions = parse_motifs(args.fixed_positions).tolist()
117
+ else:
118
+ fixed_positions = []
119
+
120
+ invalid_tokens = torch.tensor([0, 1, 2, 3], device=device)
121
 
122
  x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
123
  step_size=step_size,
124
  verbose=True,
125
  time_grid=torch.tensor([0.0, 1.0-1e-3]),
126
  score_models=score_models,
127
+ num_objectives=len(score_models) + int(args.specificity),
128
+ weights=args.weights,
129
+ tokenizer=tokenizer,
130
+ fixed_positions=fixed_positions,
131
+ invalid_tokens=invalid_tokens)
 
132
 
133
  scores = []
134
+ input_seqs = [tokenizer.batch_decode(x_1)[0].replace(' ', '')[5:-5]]
135
  for i, s in enumerate(score_models):
136
  sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
137
  if 't' in sig.parameters:
138
+ candidate_scores = s(input_seqs, 1)
139
  else:
140
+ candidate_scores = s(input_seqs)
141
 
142
  if args.objectives[i] == 'Affinity':
143
  candidate_scores = 10 * candidate_scores
 
150
  print(scores)
151
 
152
  with open(args.output_file, 'a') as f:
153
+ f.write(input_seqs[0])
154
  for score in scores:
155
  f.write(f",{score}")
156
  f.write('\n')