Upload 6 files
Browse files- utils/__pycache__/parsing.cpython-310.pyc +0 -0
- utils/__pycache__/parsing.cpython-39.pyc +0 -0
- utils/dataloader.py +41 -0
- utils/dataset.py +17 -0
- utils/flow_utils.py +497 -0
- utils/parsing.py +30 -0
utils/__pycache__/parsing.cpython-310.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|
utils/__pycache__/parsing.cpython-39.pyc
ADDED
|
Binary file (1.26 kB). View file
|
|
|
utils/dataloader.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from functools import partial
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
def collate_fn(batch):
    """Convert the first element of ``batch`` into a dict of tensors.

    Only ``batch[0]`` is read; any further elements are ignored.
    NOTE(review): presumably each dataset item is already a full batch -- confirm.

    Args:
        batch: sequence whose first item maps 'input_ids' and
            'attention_mask' to data accepted by ``torch.tensor``.

    Returns:
        dict with keys 'input_ids' and 'attention_mask' holding tensors.
    """
    first = batch[0]
    return {
        field: torch.tensor(first[field])
        for field in ('input_ids', 'attention_mask')
    }
|
| 13 |
+
|
| 14 |
+
class CustomDataModule(nn.Module):
    """Hold train/val/test datasets and build a ``DataLoader`` for each split.

    Every loader uses 8 workers and pinned memory; only the training
    loader shuffles.
    """

    def __init__(self, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn

    def _build_loader(self, dataset, shuffle):
        """Create a loader over ``dataset`` with the shared settings."""
        return DataLoader(
            dataset,
            collate_fn=partial(self.collate_fn),
            num_workers=8,
            pin_memory=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return a shuffling loader over the training dataset."""
        return self._build_loader(self.train_dataset, True)

    def val_dataloader(self):
        """Return a non-shuffling loader over the validation dataset."""
        return self._build_loader(self.val_dataset, False)

    def test_dataloader(self):
        """Return a non-shuffling loader over the test dataset."""
        return self._build_loader(self.test_dataset, False)
|
utils/dataset.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import pickle
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class EnhancerDataset(torch.utils.data.Dataset):
    """Enhancer sequences and class labels loaded from a pickled data file.

    The pickle is read from ``./dataset/enhancer_data/`` relative to the
    current working directory. Both the one-hot sequences and the one-hot
    labels are reduced to index form with ``argmax`` over the last axis.

    Args:
        mel_enhancer: if true load ``DeepMEL2_data.pkl``, otherwise
            ``DeepFlyBrain_data.pkl``.
        split: prefix/suffix used to pick the ``{split}_data`` and
            ``y_{split}`` entries from the pickle.
    """

    def __init__(self, mel_enhancer=True, split='train'):
        path = f'./dataset/enhancer_data/Deep{"MEL2" if mel_enhancer else "FlyBrain"}_data.pkl'
        # NOTE(review): pickle executes arbitrary code on load; only use trusted files.
        # The context manager closes the handle, which the previous
        # ``pickle.load(open(...))`` form never did.
        with open(path, 'rb') as handle:
            all_data = pickle.load(handle)
        self.seqs = torch.argmax(torch.from_numpy(copy.deepcopy(all_data[f'{split}_data'])), dim=-1)
        self.clss = torch.argmax(torch.from_numpy(copy.deepcopy(all_data[f'y_{split}'])), dim=-1)
        # Number of classes is the width of the label array.
        self.num_cls = all_data[f'y_{split}'].shape[-1]
        self.alphabet_size = 4

    def __len__(self):
        """Return the number of sequences."""
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return the ``(sequence, class)`` pair at ``idx``."""
        return self.seqs[idx], self.clss[idx]
|
utils/flow_utils.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
import pickle
|
| 4 |
+
|
| 5 |
+
import scipy
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from scipy.linalg import sqrtm
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def upgrade_state_dict(state_dict, prefixes=("encoder.sentence_encoder.", "encoder.")):
    """Strip a leading prefix from every key of ``state_dict``.

    Args:
        state_dict: mapping of parameter name to value.
        prefixes: literal prefixes to remove; they are tried in order and
            only one match at the very start of a name is removed.

    Returns:
        A new dict with renamed keys and the original values.
    """
    # Group the alternatives so "^" anchors all of them, and escape each
    # prefix so "." is a literal dot. The ungrouped pattern matched
    # "encoder" plus any character anywhere in a key.
    alternatives = "|".join(re.escape(prefix) for prefix in prefixes)
    pattern = re.compile("^(?:" + alternatives + ")")
    return {pattern.sub("", name): param for name, param in state_dict.items()}
|
| 22 |
+
|
| 23 |
+
def map_t_to_alpha(t, alpha_scale):
    """
    Map t in [0,1) to an alpha value via ``1 - log(1 - t) * alpha_scale``.

    This is the inverse CDF of an exponential distribution, scaled by
    ``alpha_scale`` and shifted by 1.

    Args:
        t (torch.Tensor): values in [0,1).
        alpha_scale (float): multiplier applied to the exponential quantile.

    Returns:
        torch.Tensor: alpha values, equal to 1 where t is 0.

    Raises:
        ValueError: if any element of t is below 0 or at least 1.
    """
    too_large = torch.any(t >= 1)
    too_small = torch.any(t < 0)
    if too_large or too_small:
        raise ValueError("t must be in the range [0,1).")

    exp_quantile = -torch.log(1 - t)
    return 1 + exp_quantile * alpha_scale
|
| 38 |
+
|
| 39 |
+
# return torch.clamp(1 + (-torch.log(1 - t)) * alpha_scale, torch.tensor(8).to(t.device))
|
| 40 |
+
|
| 41 |
+
def load_flybrain_designed_seqs(path):
    """Load pickled DNA strings and encode them as integer indices.

    The pickle at ``path`` must provide a ``'seq'`` entry of strings over
    the characters A, C, G and T, which are mapped to 0, 1, 2 and 3.
    ``torch.tensor`` requires all sequences to have the same length.

    Args:
        path: location of the pickle file.

    Returns:
        A ``torch.long`` tensor with one row per sequence.

    Raises:
        KeyError: if a sequence contains a character other than A/C/G/T.
    """
    order = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    # NOTE(review): pickle executes arbitrary code on load; only use trusted files.
    # The context manager closes the handle, which was previously leaked.
    with open(path, "rb") as handle:
        data = pickle.load(handle)
    arrays = [[order[char] for char in seq] for seq in data['seq']]
    return torch.tensor(arrays, dtype=torch.long)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def update_ema(current_dict, prev_ema, gamma = 0.9):
    """Return a copy of ``prev_ema`` updated with exponential moving averages.

    For each non-NaN value under ``key`` in ``current_dict``, the entry
    ``'ema_' + key`` becomes ``(1 - gamma) * value + gamma * previous``,
    or just ``value`` when there is no previous entry. NaN values are
    skipped. Neither input dict is modified.
    """
    ema = copy.deepcopy(prev_ema)
    for key, value in copy.deepcopy(current_dict).items():
        ema_key = 'ema_' + key
        if np.isnan(value):
            continue
        if ema_key not in prev_ema:
            ema[ema_key] = value
        else:
            ema[ema_key] = (1 - gamma) * value + gamma * prev_ema[ema_key]
    return ema
|
| 62 |
+
|
| 63 |
+
def min_max_str(x):
    """Format the minimum and maximum of ``x`` as ``'min <lo> max <hi>'``."""
    lowest = x.min()
    highest = x.max()
    return f'min {lowest} max {highest}'
|
| 65 |
+
|
| 66 |
+
def get_wasserstein_dist(embeds1, embeds2):
    """Compute a Frechet-style distance between two sets of embeddings.

    The result is the squared distance between the column means plus
    ``trace(S1 + S2 - 2 * sqrtm(S1 @ S2))``, where S1 and S2 are the
    covariances with rows treated as observations.

    Returns NaN if either input contains NaN or is empty.
    """
    if (
        np.isnan(embeds2).any()
        or np.isnan(embeds1).any()
        or len(embeds1) == 0
        or len(embeds2) == 0
    ):
        return float('nan')

    mean_a = embeds1.mean(axis=0)
    cov_a = np.cov(embeds1, rowvar=False)
    mean_b = embeds2.mean(axis=0)
    cov_b = np.cov(embeds2, rowvar=False)

    mean_term = np.sum((mean_a - mean_b) ** 2.0)
    cov_sqrt = sqrtm(cov_a.dot(cov_b))
    # The matrix square root can come back complex; keep only the real part.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    return mean_term + np.trace(cov_a + cov_b - 2.0 * cov_sqrt)
|
| 77 |
+
|
| 78 |
+
def simplex_proj(seq):
    """Project vectors along the last axis of ``seq`` onto the simplex.

    Algorithm from https://arxiv.org/abs/1309.1541 Weiran Wang, Miguel Á. Carreira-Perpiñán

    The input is flattened to rows, each row is projected, and the result
    is returned in the original shape.
    """
    flat = seq.reshape(-1, seq.shape[-1])
    num_rows, num_cols = flat.shape

    ordered, _ = torch.sort(flat, dim=-1, descending=True)
    positions = torch.arange(1, num_cols + 1, dtype=flat.dtype, device=flat.device)
    # Candidate shift for each prefix length: (cumulative sum - 1) / length.
    candidates = (torch.cumsum(ordered, dim=-1) - 1) / positions.unsqueeze(0)

    # Number of sorted entries still above their candidate shift, per row.
    support_size = (ordered > candidates).sum(dim=1, keepdim=True)
    rows = torch.arange(num_rows, dtype=torch.long, device=flat.device).unsqueeze(1)
    shift = candidates[rows, support_size - 1]

    projected = torch.max(flat - shift, torch.zeros_like(flat))
    return projected.view(seq.shape)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def batch_project_simplex(v):
    """Project each row of the 2-D tensor ``v`` onto the simplex.

    Sorts each row in descending order, finds the last sorted position
    whose scaled value exceeds the running sum minus one, and subtracts
    the resulting threshold before clamping at zero.
    """
    num_rows, num_cols = v.shape[0], v.shape[1]
    ordered, _ = torch.sort(v, dim=1, descending=True)
    running_sum = ordered.cumsum(dim=1)
    positions = torch.arange(1, num_cols + 1, device=v.device)

    # argmax over the running count of satisfied conditions picks the
    # index where that count first reaches its maximum.
    satisfied = (ordered * positions) > (running_sum - 1)
    rho = satisfied.int().cumsum(dim=1).argmax(dim=1)

    threshold = (running_sum[torch.arange(num_rows), rho] - 1) / (rho + 1).float()
    zero = torch.tensor(0.0, device=v.device)
    return torch.maximum(v - threshold.unsqueeze(1), zero)
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
    # Manual check: run both simplex projections on the same points and
    # print their difference. The first five rows are softmax outputs,
    # the last five are random values shifted below zero.
    on_simplex = torch.softmax(torch.rand((5,4)), dim=-1)
    below_zero = torch.rand((5,4)) - 1
    points = torch.cat([on_simplex, below_zero])

    proj_batch = batch_project_simplex(points)
    proj_sorted = simplex_proj(points)

    print('ab_proj1 - ab_proj2', proj_batch - proj_sorted)
    print('ab_proj1 - ab', proj_batch - points)
    print('ab_proj2.sum(-1)', proj_sorted.sum(-1))
    print('ab_proj2', proj_sorted)
|
| 115 |
+
|
| 116 |
+
def sample_cond_prob_path(args, seq, alphabet_size):
    """Sample a noised version of ``seq`` and its per-example noise level.

    Args:
        args: namespace with a ``mode`` attribute. Mode 'dirichlet' also
            reads ``alpha_scale`` and ``fix_alpha``.
        seq: 2-D integer tensor of token indices, shape (B, L).
        alphabet_size: number of token classes.

    Returns:
        Tuple ``(xt, alphas)`` where ``alphas`` has shape (B,):
        - 'dirichlet': xt is a Dirichlet sample whose concentration is 1
          everywhere except ``alphas`` at the true token; alphas are
          1 + exponential * alpha_scale, or ``fix_alpha`` when it is truthy.
        - 'distill': xt is a uniform Dirichlet sample; alphas are zeros.
        - 'riemannian': xt interpolates linearly between a uniform
          Dirichlet sample and the one-hot sequence; alphas are the
          interpolation times t.
        - 'ardm' / 'lrar': xt is a one-hot encoding over
          ``alphabet_size + 1`` classes with masked positions set to the
          extra mask index; alphas hold the single sampled mask probability.

    Raises:
        ValueError: if ``args.mode`` is not one of the modes above
            (previously an UnboundLocalError at the return statement).
    """
    B, L = seq.shape
    seq_one_hot = torch.nn.functional.one_hot(seq, num_classes=alphabet_size)
    if args.mode == 'dirichlet':
        # The exponential draw happens even when fix_alpha overrides it.
        alphas = torch.from_numpy(1 + scipy.stats.expon().rvs(size=B) * args.alpha_scale).to(seq.device).float()
        if args.fix_alpha:
            alphas = torch.ones(B, device=seq.device) * args.fix_alpha
        alphas_ = torch.ones(B, L, alphabet_size, device=seq.device)
        alphas_ = alphas_ + seq_one_hot * (alphas[:,None,None] - 1)
        xt = torch.distributions.Dirichlet(alphas_).sample()
    elif args.mode == 'distill':
        alphas = torch.zeros(B, device=seq.device)
        xt = torch.distributions.Dirichlet(torch.ones(B, L, alphabet_size, device=seq.device)).sample()
    elif args.mode == 'riemannian':
        t = torch.rand(B, device=seq.device)
        dirichlet = torch.distributions.Dirichlet(torch.ones(alphabet_size, device=seq.device))
        x0 = dirichlet.sample((B,L))
        x1 = seq_one_hot
        xt = t[:,None,None] * x1 + (1 - t[:,None,None]) * x0
        alphas = t
    elif args.mode == 'ardm' or args.mode == 'lrar':
        mask_prob = torch.rand(1, device=seq.device)
        mask = torch.rand(seq.shape, device=seq.device) < mask_prob
        if args.mode == 'lrar':
            # Mask a contiguous suffix instead of random positions.
            mask = ~(torch.arange(L, device=seq.device) < (1-mask_prob) * L)
        xt = torch.where(mask, alphabet_size, seq) # mask token index
        xt = torch.nn.functional.one_hot(xt, num_classes=alphabet_size + 1).float() # plus one to include index for mask token
        alphas = mask_prob.expand(B)
    else:
        raise ValueError(f"Unknown mode: {args.mode!r}")
    return xt, alphas
|
| 144 |
+
|
| 145 |
+
def expand_simplex(xt, alphas, prior_pseudocount):
    """Split ``xt`` into two weighted copies concatenated on the last axis.

    The per-example weight is
    ``prior_pseudocount / (alphas + prior_pseudocount - 1)``, broadcast
    over the two trailing axes of ``xt``.

    Returns:
        Tuple of the concatenation ``[xt * (1 - w), xt * w]`` and the
        weights ``w`` with shape (len(alphas), 1, 1).
    """
    denominator = alphas + prior_pseudocount - 1
    prior_weights = (prior_pseudocount / denominator)[:, None, None]
    data_part = xt * (1 - prior_weights)
    prior_part = xt * prior_weights
    return torch.cat([data_part, prior_part], -1), prior_weights
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class DirichletConditionalFlow:
    """Tabulated regularized incomplete beta functions over a grid of alphas.

    At construction, ``betainc(alpha, K - 1, b)`` is evaluated for every
    alpha in ``arange(alpha_min, alpha_max + alpha_spacing, alpha_spacing)``
    and for 1000 values of b in [0, 1]. The finite difference of that
    table along the alpha axis is stored as well.

    Args:
        K: the second beta parameter is ``K - 1``.
        alpha_min: first alpha of the grid.
        alpha_max: upper end of the alpha grid.
        alpha_spacing: grid step, also the finite-difference denominator.
    """

    def __init__(self, K=20, alpha_min=1, alpha_max=100, alpha_spacing=0.01):
        self.alphas = np.arange(alpha_min, alpha_max + alpha_spacing, alpha_spacing)
        self.bs = np.linspace(0, 1, 1000)
        self.beta_cdfs = np.array(
            [scipy.special.betainc(alph, K - 1, self.bs) for alph in self.alphas]
        )
        # Forward difference along alpha: one row fewer than the alpha grid.
        self.beta_cdfs_derivative = np.diff(self.beta_cdfs, axis=0) / alpha_spacing
        self.K = K

    def c_factor(self, bs, alpha):
        """Evaluate the tabulated factor at points ``bs`` for one ``alpha``.

        Args:
            bs: numpy array of evaluation points.
            alpha: scalar alpha; the nearest grid alpha selects the
                derivative row.

        Returns:
            numpy array with the same shape as ``bs``.
        """
        out1 = scipy.special.beta(alpha, self.K - 1)
        # np.where evaluates both branches, so divisions by zero can emit
        # runtime warnings even though those entries are replaced by 0.
        out2 = np.where(bs < 1, out1 / ((1 - bs) ** (self.K - 1)), 0)
        out = np.where((bs ** (alpha - 1)) > 0, out2 / (bs ** (alpha - 1)), 0)
        nearest = np.argmin(np.abs(alpha - self.alphas))
        # The derivative table is one row shorter than the alpha grid, so
        # clamp the index; the last grid alpha used to raise IndexError.
        nearest = min(nearest, len(self.beta_cdfs_derivative) - 1)
        I_func = self.beta_cdfs_derivative[nearest]
        interp = -np.interp(bs, self.bs, I_func)
        final = interp * out
        return final
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class GaussianSmearing(torch.nn.Module):
    """Expand scalars into Gaussian radial-basis features.

    Centers are spaced evenly from ``start`` to ``stop``; the width of
    each Gaussian is set by the spacing between the first two centers.
    """
    # used to embed the edge distances
    def __init__(self, start=0.0, stop=5.0, embedding_dim=50):
        super().__init__()
        centers = torch.linspace(start, stop, embedding_dim)
        spacing = (centers[1] - centers[0]).item()
        self.coeff = -0.5 / spacing ** 2
        self.register_buffer("offset", centers)
        self.embedding_dim = embedding_dim

    def forward(self, signal):
        """Return features of shape ``signal.shape + (embedding_dim,)``."""
        out_shape = (*signal.shape, self.embedding_dim)
        # A small constant is added to every difference before squaring.
        delta = signal.view(-1, 1) - self.offset.view(1, -1) + 1E-6
        features = torch.exp(self.coeff * torch.pow(delta, 2))
        return features.view(out_shape)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class MonotonicFunction(torch.nn.Module):
    """Learnable piecewise-linear increasing function of t.

    The unit interval is split into ``num_bins`` equal bins. Bin i spans
    an output range of width ``exp(w[i])``, so the widths stay positive.
    Initial widths are ``init_max / num_bins`` each.
    """

    def __init__(self, init_max, num_bins):
        super().__init__()
        initial_log_widths = torch.ones(num_bins) * np.log(init_max) - np.log(num_bins)
        self.w = torch.nn.Parameter(initial_log_widths)
        self.num_bins = num_bins

    def _edges(self):
        """Return (lower edges, upper edges, widths) of the output bins."""
        widths = torch.exp(self.w)
        upper = torch.cumsum(widths, 0)
        return upper - widths, upper, widths

    def forward(self, t):
        """Evaluate the function at ``t``.

        ``t * num_bins`` is truncated to pick the bin, so ``t`` must stay
        below 1 for the bin index to be valid.
        """
        lower, upper, _ = self._edges()
        idx = (t * self.num_bins).long()
        within_bin = t - idx * (1 / self.num_bins)

        return lower[idx] + (within_bin * self.num_bins) * (upper[idx] - lower[idx])

    def invert(self, f):
        """Return the t for which ``forward(t)`` equals ``f``."""
        lower, _, widths = self._edges()
        # Index of the last bin whose lower edge lies strictly below f.
        idx = (f.unsqueeze(-1) > lower).sum(-1) - 1
        within_bin = f - lower[idx]
        return idx / self.num_bins + within_bin / widths[idx] / self.num_bins

    def derivative(self, t):
        """Return the slope of the function in the bin containing ``t``."""
        lower, upper, _ = self._edges()
        idx = (t * self.num_bins).long()
        return (upper[idx] - lower[idx]) * self.num_bins
|
| 216 |
+
|
| 217 |
+
class SinusoidalEmbedding(nn.Module):
    """Sine/cosine embedding of scalars.

    from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
    """

    def __init__(self, embedding_dim, embedding_scale, max_positions=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_positions = max_positions
        self.embedding_scale = embedding_scale

    def forward(self, signal):
        """Return embeddings of shape ``signal.shape + (embedding_dim,)``.

        The input is multiplied by ``embedding_scale`` first. The sine
        half comes before the cosine half; an odd ``embedding_dim`` gets
        one trailing zero column.
        """
        out_shape = signal.shape
        flat = signal.view(-1) * self.embedding_scale
        half_dim = self.embedding_dim // 2
        # Frequencies decay geometrically from 1 down to 1 / max_positions.
        log_step = math.log(self.max_positions) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=flat.device) * -log_step)
        angles = flat.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.embedding_dim % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1), mode='constant')
        assert emb.shape == (flat.shape[0], self.embedding_dim)
        return emb.view(*out_shape, self.embedding_dim)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Frequencies are drawn once from a normal distribution scaled by
    ``scale`` and stored as a non-trainable parameter.

    from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32
    """

    def __init__(self, embedding_dim=256, scale=1.0):
        super().__init__()
        frequencies = torch.randn(embedding_dim//2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)
        self.embedding_dim = embedding_dim

    def forward(self, signal):
        """Return embeddings of shape ``signal.shape + (embedding_dim,)``.

        The sine half comes before the cosine half.
        """
        out_shape = signal.shape
        flat = signal.view(-1)
        angles = flat[:, None] * self.W[None, :] * 2 * np.pi
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        return emb.view(*out_shape, self.embedding_dim)
|
| 255 |
+
|
| 256 |
+
def get_signal_mapping(embedding_type, embedding_dim, embedding_scale=10000):
    """Build the embedding module selected by ``embedding_type``.

    Args:
        embedding_type: 'sinusoidal', 'fourier' or 'gaussian'.
        embedding_dim: output size of the embedding.
        embedding_scale: passed to the sinusoidal and fourier embeddings;
            not used by 'gaussian', which spans 0.0 to 1.

    Returns:
        The constructed embedding module.

    Raises:
        NotImplementedError: for any other ``embedding_type``. The
            previous ``raise NotImplemented`` produced a TypeError,
            because ``NotImplemented`` is a constant, not an exception.
    """
    if embedding_type == 'sinusoidal':
        emb_func = SinusoidalEmbedding(embedding_dim=embedding_dim, embedding_scale=embedding_scale)
    elif embedding_type == 'fourier':
        emb_func = GaussianFourierProjection(embedding_dim=embedding_dim, scale=embedding_scale)
    elif embedding_type == 'gaussian':
        emb_func = GaussianSmearing(0.0, 1, embedding_dim)
    else:
        raise NotImplementedError(f"Unknown embedding type: {embedding_type!r}")
    return emb_func
|
| 266 |
+
|
| 267 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a cumulative-product function into a beta schedule.

    Step i uses ``1 - alpha_bar((i + 1) / n) / alpha_bar(i / n)``, capped
    at ``max_beta``, where n is ``num_diffusion_timesteps``.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable that takes t from 0 to 1 and returns the
                      cumulative product of (1-beta) up to that point.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    :return: a numpy array with one beta per step.
    """
    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ]
    return np.array(betas)
|
| 285 |
+
|
| 286 |
+
def get_beta_schedule(num_steps):
    """Return ``num_steps`` betas from a squared-cosine alpha-bar curve."""

    def cosine_alpha_bar(t):
        # Squared cosine with a 0.008 offset, normalized by 1.008.
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    return betas_for_alpha_bar(num_steps, cosine_alpha_bar)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class GaussianDiffusionSchedule:
    """
    Precomputed coefficients for a Gaussian diffusion process.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param timesteps: number of diffusion steps; the betas come from
                      get_beta_schedule(timesteps).
    :param noise_scale: multiplier applied to the noise drawn in q_sample
                        and folded into the posterior variance.
    """

    def __init__(
        self,
        timesteps,
        noise_scale=1.0,
    ):
        # Use float64 for accuracy.
        betas = np.array(get_beta_schedule(timesteps), dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.timesteps = int(betas.shape[0])
        self.noise_scale = noise_scale

        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        cumprod_prev = np.append(1.0, cumprod[:-1])
        self.alphas_cumprod = cumprod
        self.alphas_cumprod_prev = cumprod_prev
        self.alphas_cumprod_next = np.append(cumprod[1:], 0.0)
        assert cumprod_prev.shape == (self.timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        variance = betas * (1.0 - cumprod_prev) / (1.0 - cumprod)
        self.posterior_variance = variance
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain: entry 0 is replaced by entry 1.
        self.posterior_log_variance_clipped = np.log(
            np.append(variance[1], variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(cumprod_prev) / (1.0 - cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - cumprod_prev) * np.sqrt(alphas) / (1.0 - cumprod)
        )

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise; when omitted,
                      standard normal noise times noise_scale is drawn.
        :return: A noisy version of x_start.
        """
        if noise is None:
            # add scaling here
            noise = self.noise_scale * torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        signal_coef = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_coef = _extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return signal_coef * x_start + noise_coef * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        The variance is scaled by noise_scale squared, and the clipped
        log-variance is shifted by 2 * log(noise_scale).

        :return: a tuple (mean, variance, clipped log-variance).
        """
        assert x_start.shape == x_t.shape
        shape = x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, shape
        )

        posterior_variance = (self.noise_scale ** 2) * posterior_variance
        posterior_log_variance_clipped = 2 * np.log(self.noise_scale) + posterior_log_variance_clipped

        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
| 409 |
+
"""
|
| 410 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
| 411 |
+
|
| 412 |
+
:param arr: the 1-D numpy array.
|
| 413 |
+
:param timesteps: a tensor of indices into the array to extract.
|
| 414 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
| 415 |
+
dimension equal to the length of timesteps.
|
| 416 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
| 417 |
+
"""
|
| 418 |
+
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
| 419 |
+
while len(res.shape) < len(broadcast_shape):
|
| 420 |
+
res = res[..., None]
|
| 421 |
+
return res.expand(broadcast_shape)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or if a section has fewer steps than the
                        count requested for it.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # Report the requested count; the message used to print the
            # total number of timesteps instead.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        # The first `extra` sections absorb the remainder, one step each.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
|
| 478 |
+
|
| 479 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    The cosine half comes before the sine half; an odd dim gets one
    trailing zero column.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    exponents = torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * exponents / half).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
|
utils/parsing.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import ArgumentParser
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
def _str2bool(value):
    """Convert a command-line token to a bool; raise ValueError if unrecognized."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def parse_guidance_args():
    """Parse the guidance command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = ArgumentParser()

    parser.add_argument("--num_div", type=int, default=64)
    parser.add_argument("--lambda_", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--alpha_r", type=float, default=0.5)
    parser.add_argument("--eta", type=float, default=1.0)
    # The Phi defaults are given in degrees and stored in radians.
    parser.add_argument("--Phi_init", type=float, default=math.radians(45.0))
    parser.add_argument("--Phi_min", type=float, default=math.radians(15.0))
    parser.add_argument("--Phi_max", type=float, default=math.radians(75.0))
    parser.add_argument("--tau", type=float, default=0.3)
    parser.add_argument("--T", type=int, default=100)
    parser.add_argument("--length", type=int, default=12)
    # ``type=bool`` turns every non-empty string, including "False", into
    # True, so the flag is parsed with an explicit converter instead.
    parser.add_argument("--is_peptide", type=_str2bool, default=True)
    parser.add_argument("--n_samples", type=int, default=5)
    parser.add_argument("--n_batches", type=int, default=2)
    parser.add_argument("--target_protein", type=str, default="AAAAA")
    parser.add_argument("--target_enhancer_class", type=int, default=0)
    parser.add_argument("--target_DNA_shape", type=str, default='HelT')
    parser.add_argument("--motifs", type=str, required=False)
    parser.add_argument("--weights", type=float, nargs='+', required=False)
    parser.add_argument("--output_file", type=str, default='moo_outputs.txt')
    parser.add_argument("--motif_penalty", action='store_true')

    args = parser.parse_args()
    return args
|