Delete moppit.py
moppit.py
DELETED
@@ -1,90 +0,0 @@

import yaml
from tqdm import tqdm
import torch
from torch import nn
from transformers import AutoTokenizer

from models.peptide_classifiers import *

from utils.parsing import parse_guidance_args
args = parse_guidance_args()

import pdb
import random
import inspect

# MOO hyper-parameters
step_size = 1 / 100
n_samples = 1
length = args.length
target = args.target_protein
motifs = args.motifs
vocab_size = 24
source_distribution = "uniform"
device = 'cuda:0'

# Tokenize the target protein with the ESM-2 tokenizer and parse the motif spec
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)
motifs = parse_motifs(motifs).to(device)
print(motifs)

# Load models: the flow-matching solver plus the two guidance classifiers
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)

bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=True)

affinity_predictor = load_affinity_predictor('./classifier_ckpt/binding_affinity_unpooled.pt', device)
affinity_model = AffinityModel(affinity_predictor, target_sequence)

score_models = [motif_model, affinity_model]

for i in range(args.n_batches):
    # Draw the initial sequence from the chosen source distribution
    if source_distribution == "uniform":
        x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
    elif source_distribution == "mask":
        x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
    else:
        raise NotImplementedError

    # Wrap the sequence with BOS (0) and EOS (2) special tokens
    zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
    twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
    x_init = torch.cat([zeros, x_init, twos], dim=1)

    # Multi-objective guided sampling; num_objectives=3 with two score models
    # because the motif model with penalty=True presumably returns a tuple of
    # two scores (see the tuple handling in the scoring loop below)
    x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                       step_size=step_size,
                                       verbose=True,
                                       time_grid=torch.tensor([0.0, 1.0 - 1e-3]),
                                       score_models=score_models,
                                       num_objectives=3,
                                       weights=args.weights)

    # Decode token ids back to amino-acid strings, stripping the
    # 5-character special tokens at both ends
    samples = x_1.tolist()
    samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples]
    print(samples)

    # Score the final sample with each objective; pass t=1 only to models
    # whose forward signature expects a time argument
    scores = []
    for s in score_models:
        sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
        if 't' in sig.parameters:
            candidate_scores = s(x_1, 1)
        else:
            candidate_scores = s(x_1)

        if isinstance(candidate_scores, tuple):
            for score in candidate_scores:
                scores.append(score.item())
        else:
            scores.append(candidate_scores.item())
    print(scores)

    # Append the sample and its scores as one CSV row
    with open(args.output_file, 'a') as f:
        f.write(samples[0])
        for score in scores:
            f.write(f",{score}")
        f.write('\n')
    # samples = x_1.tolist()
    # sample = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples][0]
    # with open(f"/vast/home/c/chentong/MOG-DFM/samples/{name}.csv", "a") as f:
    #     f.write(sample + ',' + str(score_list_0[-1]) + ',' + str(score_list_1[-1]) + '\n')