AlienChen commited on
Commit
5d5918d
·
verified ·
1 Parent(s): f71f3fa

Delete moppit.py

Browse files
Files changed (1) hide show
  1. moppit.py +0 -90
moppit.py DELETED
@@ -1,90 +0,0 @@
1
- import yaml
2
- from tqdm import tqdm
3
- import torch
4
- from torch import nn
5
- from transformers import AutoTokenizer
6
-
7
- from models.peptide_classifiers import *
8
-
9
- from utils.parsing import parse_guidance_args
10
- args = parse_guidance_args()
11
-
12
- import pdb
13
- import random
14
- import inspect
15
-
16
- # MOO hyper-parameters
17
- step_size = 1 / 100
18
- n_samples = 1
19
- length = args.length
20
- target = args.target_protein
21
- motifs = args.motifs # args.motifs
22
- vocab_size = 24
23
- source_distribution = "uniform"
24
- device = 'cuda:0'
25
-
26
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
27
- target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)
28
- motifs = parse_motifs(motifs).to(device)
29
- print(motifs)
30
-
31
- # Load Models
32
- solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)
33
-
34
- bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
35
- motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=True)
36
-
37
- affinity_predictor = load_affinity_predictor('./classifier_ckpt/binding_affinity_unpooled.pt', device)
38
- affinity_model = AffinityModel(affinity_predictor, target_sequence)
39
-
40
- score_models = [motif_model, affinity_model]
41
-
42
- for i in range(args.n_batches):
43
- if source_distribution == "uniform":
44
- x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
45
- elif source_distribution == "mask":
46
- x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
47
- else:
48
- raise NotImplementedError
49
-
50
- zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
51
- twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
52
- x_init = torch.cat([zeros, x_init, twos], dim=1)
53
-
54
- x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
55
- step_size=step_size,
56
- verbose=True,
57
- time_grid=torch.tensor([0.0, 1.0-1e-3]),
58
- score_models=score_models,
59
- num_objectives=3,
60
- weights=args.weights)
61
-
62
- samples = x_1.tolist()
63
- samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples]
64
- print(samples)
65
-
66
- scores = []
67
- for i, s in enumerate(score_models):
68
- sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
69
- if 't' in sig.parameters:
70
- candidate_scores = s(x_1, 1)
71
- else:
72
- candidate_scores = s(x_1)
73
-
74
- if isinstance(candidate_scores, tuple):
75
- for score in candidate_scores:
76
- scores.append(score.item())
77
- else:
78
- scores.append(candidate_scores.item())
79
- print(scores)
80
-
81
- with open(args.output_file, 'a') as f:
82
- f.write(samples[0])
83
- for score in scores:
84
- f.write(f",{score}")
85
- f.write('\n')
86
- # samples = x_1.tolist()
87
- # sample = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples][0]
88
- # with open(f"/vast/home/c/chentong/MOG-DFM/samples/{name}.csv", "a") as f:
89
- # f.write(sample + ',' + str(score_list_0[-1]) + ',' + str(score_list_1[-1]) + '\n')
90
-