AlienChen committed on
Commit
b24eac9
·
verified ·
1 Parent(s): 9fb6785

Upload 6 files

Browse files
utils/__pycache__/parsing.cpython-310.pyc ADDED
Binary file (1.06 kB). View file
 
utils/__pycache__/parsing.cpython-39.pyc ADDED
Binary file (1.26 kB). View file
 
utils/dataloader.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from functools import partial
3
+ from torch.utils.data import DataLoader
4
+ from torch import nn
5
+
6
def collate_fn(batch):
    """Convert the first element of ``batch`` into a dict of tensors.

    Only ``batch[0]`` is read; any further elements of ``batch`` are ignored.

    Args:
        batch: Indexable collection whose first item is a mapping with
            ``'input_ids'`` and ``'attention_mask'`` entries.

    Returns:
        dict: ``{'input_ids': Tensor, 'attention_mask': Tensor}`` built with
        ``torch.tensor`` from the corresponding entries of ``batch[0]``.
    """
    sample = batch[0]
    return {
        field: torch.tensor(sample[field])
        for field in ('input_ids', 'attention_mask')
    }
13
+
14
class CustomDataModule(nn.Module):
    """Hold train/val/test datasets and build one ``DataLoader`` per split.

    Every loader uses 8 workers, pinned memory and the stored collate
    function; ``batch_size`` is not passed, so the ``DataLoader`` default
    applies. Only the training loader shuffles.

    Args:
        train_dataset: Dataset served by :meth:`train_dataloader`.
        val_dataset: Dataset served by :meth:`val_dataloader`.
        test_dataset: Dataset served by :meth:`test_dataloader`.
        collate_fn: Callable handed to every ``DataLoader``.
    """

    def __init__(self, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn

    def _loader_for(self, dataset, shuffle):
        """Build a ``DataLoader`` over ``dataset`` with the shared settings."""
        return DataLoader(
            dataset,
            collate_fn=partial(self.collate_fn),
            num_workers=8,
            pin_memory=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return a shuffling loader over the training dataset."""
        return self._loader_for(self.train_dataset, True)

    def val_dataloader(self):
        """Return a non-shuffling loader over the validation dataset."""
        return self._loader_for(self.val_dataset, False)

    def test_dataloader(self):
        """Return a non-shuffling loader over the test dataset."""
        return self._loader_for(self.test_dataset, False)
utils/dataset.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import pickle
3
+ import torch
4
+
5
class EnhancerDataset(torch.utils.data.Dataset):
    """Dataset of index-encoded sequences and class labels read from a pickle.

    The pickle at ``./dataset/enhancer_data/Deep{MEL2|FlyBrain}_data.pkl`` is
    expected to be a mapping with ``'{split}_data'`` and ``'y_{split}'``
    NumPy arrays. Both are reduced with ``argmax`` over their last axis, so
    they are presumably one-hot encoded (confirm against the data files).

    Args:
        mel_enhancer: If true, load the ``DeepMEL2`` file, otherwise the
            ``DeepFlyBrain`` file.
        split: Prefix/suffix used to select the arrays inside the pickle.

    Attributes set:
        seqs: ``argmax`` over the last axis of ``'{split}_data'``.
        clss: ``argmax`` over the last axis of ``'y_{split}'``.
        num_cls: Size of the last axis of ``'y_{split}'``.
        alphabet_size: Fixed at 4.
    """

    def __init__(self, mel_enhancer=True, split='train'):
        variant = "MEL2" if mel_enhancer else "FlyBrain"
        path = f'./dataset/enhancer_data/Deep{variant}_data.pkl'
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this with data files from a trusted source.
        # The context manager closes the handle the original left open.
        with open(path, 'rb') as handle:
            all_data = pickle.load(handle)

        seq_array = all_data[f'{split}_data']
        label_array = all_data[f'y_{split}']
        # Deep copies keep the tensors from aliasing the unpickled arrays.
        self.seqs = torch.argmax(torch.from_numpy(copy.deepcopy(seq_array)), dim=-1)
        self.clss = torch.argmax(torch.from_numpy(copy.deepcopy(label_array)), dim=-1)
        self.num_cls = label_array.shape[-1]
        self.alphabet_size = 4

    def __len__(self):
        """Return the number of sequences."""
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return ``(sequence_indices, class_index)`` for item ``idx``."""
        return self.seqs[idx], self.clss[idx]
utils/flow_utils.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import pickle
4
+
5
+ import scipy
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from scipy.linalg import sqrtm
11
+ import re
12
+
13
+
14
+
15
+
16
+
17
def upgrade_state_dict(state_dict, prefixes=("encoder.sentence_encoder.", "encoder.")):
    """Strip a leading prefix from every key of ``state_dict``.

    Each key has at most one prefix removed, and only when the key *starts*
    with it. Prefixes are tried in the given order, so list longer prefixes
    before shorter ones that they begin with.

    The previous pattern ``"^" + "|".join(prefixes)`` anchored only the first
    alternative and left ``.`` unescaped, so later prefixes were removed
    anywhere in a key (for example ``decoder.encoder.w`` became ``decoder.w``).

    Args:
        state_dict: Mapping from parameter name to value.
        prefixes: Literal key prefixes to remove.

    Returns:
        dict: A new dict with renamed keys and the original values.
    """
    if not prefixes:
        return dict(state_dict)
    alternatives = "|".join(re.escape(prefix) for prefix in prefixes)
    # Group the alternatives so the ^ anchor applies to all of them.
    pattern = re.compile("^(?:" + alternatives + ")")
    return {pattern.sub("", name): param for name, param in state_dict.items()}
22
+
23
def map_t_to_alpha(t, alpha_scale):
    """
    Convert t in [0,1) to an alpha value via ``1 - log(1 - t) * alpha_scale``.

    This is the inverse CDF of an exponential distribution, scaled by
    ``alpha_scale`` and shifted by one.

    Args:
        t (torch.Tensor): Values in [0,1).
        alpha_scale (float): Multiplier applied to the exponential quantile.

    Returns:
        torch.Tensor: Alpha values, elementwise ``>= 1`` for valid ``t``.

    Raises:
        ValueError: If any element of ``t`` is below 0 or at/above 1.
    """
    too_large = torch.any(t >= 1)
    too_small = torch.any(t < 0)
    if too_large or too_small:
        raise ValueError("t must be in the range [0,1).")

    exponential_quantile = -torch.log(1 - t)
    return 1 + exponential_quantile * alpha_scale
38
+
39
+ # return torch.clamp(1 + (-torch.log(1 - t)) * alpha_scale, torch.tensor(8).to(t.device))
40
+
41
def load_flybrain_designed_seqs(path):
    """Load pickled DNA strings and encode them as integer indices.

    The pickle must contain a mapping whose ``'seq'`` entry is an iterable of
    strings over ``A``/``C``/``G``/``T``. All strings must have the same
    length for the result to form a rectangular tensor.

    Args:
        path: Path of the pickle file.

    Returns:
        torch.Tensor: ``torch.long`` tensor with one row per sequence, using
        A=0, C=1, G=2, T=3.

    Raises:
        KeyError: If a sequence contains a character other than A, C, G, T.
    """
    order = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # files from a trusted source.
    # The context manager closes the handle the original left open.
    with open(path, "rb") as handle:
        data = pickle.load(handle)
    arrays = [[order[char] for char in seq] for seq in data['seq']]
    return torch.tensor(arrays, dtype=torch.long)
49
+
50
+
51
def update_ema(current_dict, prev_ema, gamma = 0.9):
    """Return an updated copy of an exponential-moving-average dict.

    For each ``key`` in ``current_dict`` whose value is not NaN, the entry
    ``'ema_' + key`` becomes ``(1 - gamma) * value + gamma * previous`` when a
    previous entry exists, and ``value`` otherwise. NaN values are skipped.
    Neither input is modified.

    Args:
        current_dict: Mapping from metric name to its current value.
        prev_ema: Mapping from ``'ema_<name>'`` to the previous average.
        gamma: Weight given to the previous average.

    Returns:
        dict: Deep copy of ``prev_ema`` with the updated entries.
    """
    updated = copy.deepcopy(prev_ema)
    for name, value in copy.deepcopy(current_dict).items():
        if np.isnan(value):
            continue
        target = 'ema_' + name
        updated[target] = (
            (1 - gamma) * value + gamma * prev_ema[target]
            if target in prev_ema
            else value
        )
    return updated
62
+
63
def min_max_str(x):
    """Return ``'min <x.min()> max <x.max()>'`` for an object with both methods."""
    lowest = x.min()
    highest = x.max()
    return f'min {lowest} max {highest}'
65
+
66
def get_wasserstein_dist(embeds1, embeds2):
    """Distance between Gaussians fitted to two sets of embeddings.

    Computes ``|mu1 - mu2|^2 + trace(S1 + S2 - 2 * sqrtm(S1 @ S2))`` from the
    per-column means and covariances (rows are samples).

    Args:
        embeds1: 2-D NumPy array of samples.
        embeds2: 2-D NumPy array of samples.

    Returns:
        float: The distance, or NaN if either input is empty or contains NaN.
    """
    contains_nan = np.isnan(embeds2).any() or np.isnan(embeds1).any()
    if contains_nan or len(embeds1) == 0 or len(embeds2) == 0:
        return float('nan')

    mean_a = embeds1.mean(axis=0)
    cov_a = np.cov(embeds1, rowvar=False)
    mean_b = embeds2.mean(axis=0)
    cov_b = np.cov(embeds2, rowvar=False)

    squared_mean_gap = np.sum((mean_a - mean_b) ** 2.0)
    cov_sqrt = sqrtm(cov_a.dot(cov_b))
    # sqrtm can return a complex result; only the real part is used.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    return squared_mean_gap + np.trace(cov_a + cov_b - 2.0 * cov_sqrt)
77
+
78
def simplex_proj(seq):
    """Project each vector along the last axis onto the probability simplex.

    Algorithm from https://arxiv.org/abs/1309.1541 Weiran Wang, Miguel Á. Carreira-Perpiñán

    Args:
        seq: Tensor of any shape; the last dimension holds the vectors.

    Returns:
        Tensor of the same shape as ``seq`` with the projected vectors.
    """
    flat = seq.reshape(-1, seq.shape[-1])
    n_rows, n_cols = flat.shape

    sorted_desc, _ = torch.sort(flat, dim=-1, descending=True)
    shifted_cumsum = torch.cumsum(sorted_desc, dim=-1) - 1
    counts = torch.arange(1, n_cols + 1, dtype=flat.dtype, device=flat.device)
    thresholds = shifted_cumsum / counts.unsqueeze(0)

    # Number of sorted entries that exceed their own candidate threshold.
    num_active = (sorted_desc > thresholds).sum(dim=1, keepdim=True)
    rows = torch.arange(n_rows, dtype=torch.long, device=flat.device).unsqueeze(1)
    shift = thresholds[rows, num_active - 1]

    projected = torch.max(flat - shift, torch.zeros_like(flat))
    return projected.view(seq.shape)
93
+
94
+
95
+
96
def batch_project_simplex(v):
    """Project each row of a 2-D tensor onto the probability simplex.

    Args:
        v: Tensor indexed as ``(row, component)``.

    Returns:
        Tensor of the same shape with every row projected.
    """
    sorted_desc, _ = torch.sort(v, dim=1, descending=True)
    running_sum = sorted_desc.cumsum(dim=1)
    positions = torch.arange(1, v.shape[1] + 1, device=v.device)

    # The cumulative count of satisfied positions peaks at the last satisfied
    # one; argmax picks that index.
    satisfied = (sorted_desc * positions) > (running_sum - 1)
    last_active = satisfied.int().cumsum(dim=1).argmax(dim=1)

    row_ids = torch.arange(v.shape[0])
    shift = (running_sum[row_ids, last_active] - 1) / (last_active + 1).float()
    zero = torch.tensor(0.0, device=v.device)
    return torch.maximum(v - shift.unsqueeze(1), zero)
104
+
105
if __name__ == "__main__":
    # Smoke test: compare the two simplex projections on five rows that
    # already lie on the simplex and five rows with entries in [-1, 0).
    on_simplex = torch.softmax(torch.rand((5, 4)), dim=-1)
    off_simplex = torch.rand((5, 4)) - 1
    points = torch.cat([on_simplex, off_simplex])
    projected_batch = batch_project_simplex(points)
    projected_sorted = simplex_proj(points)
    print('ab_proj1 - ab_proj2', projected_batch - projected_sorted)
    print('ab_proj1 - ab', projected_batch - points)
    print('ab_proj2.sum(-1)', projected_sorted.sum(-1))
    print('ab_proj2', projected_sorted)
115
+
116
def sample_cond_prob_path(args, seq, alphabet_size):
    """Sample a noised point ``xt`` for each sequence, according to ``args.mode``.

    Modes:
        * ``'dirichlet'``: ``alphas = 1 + Exp(1) * args.alpha_scale`` per
          sequence (or the constant ``args.fix_alpha`` when it is truthy);
          ``xt`` is drawn from a Dirichlet whose concentration is 1 everywhere
          except ``alphas`` at the true token.
        * ``'distill'``: ``xt`` is drawn from a flat Dirichlet; ``alphas`` is
          all zeros.
        * ``'riemannian'``: ``xt = t * one_hot(seq) + (1 - t) * x0`` with
          ``t ~ U[0, 1)`` per sequence and ``x0`` from a flat Dirichlet;
          ``alphas`` is ``t``.
        * ``'ardm'`` / ``'lrar'``: tokens are replaced by a mask index
          (``alphabet_size``) and one-hot encoded with ``alphabet_size + 1``
          classes. ``'ardm'`` masks positions at random with a single shared
          probability; ``'lrar'`` masks a suffix of the sequence. ``alphas``
          is the mask probability repeated per sequence.

    Args:
        args: Object with a ``mode`` attribute; ``'dirichlet'`` mode also
            reads ``alpha_scale`` and ``fix_alpha``.
        seq: Integer tensor of shape ``(B, L)`` holding token indices.
        alphabet_size: Number of token classes.

    Returns:
        tuple: ``(xt, alphas)`` as described per mode.

    Raises:
        ValueError: If ``args.mode`` is not one of the modes above. The
            original fell through to the ``return`` and raised
            ``UnboundLocalError``.
    """
    B, L = seq.shape
    seq_one_hot = torch.nn.functional.one_hot(seq, num_classes=alphabet_size)
    if args.mode == 'dirichlet':
        # Imported here because a bare ``import scipy`` is not guaranteed to
        # load the ``stats`` submodule on every SciPy version.
        from scipy.stats import expon
        # Drawn even when fix_alpha overrides it, so NumPy's random stream
        # advances exactly as in the original.
        alphas = torch.from_numpy(1 + expon().rvs(size=B) * args.alpha_scale).to(seq.device).float()
        if args.fix_alpha:
            alphas = torch.ones(B, device=seq.device) * args.fix_alpha
        alphas_ = torch.ones(B, L, alphabet_size, device=seq.device)
        alphas_ = alphas_ + seq_one_hot * (alphas[:, None, None] - 1)
        xt = torch.distributions.Dirichlet(alphas_).sample()
    elif args.mode == 'distill':
        alphas = torch.zeros(B, device=seq.device)
        xt = torch.distributions.Dirichlet(torch.ones(B, L, alphabet_size, device=seq.device)).sample()
    elif args.mode == 'riemannian':
        t = torch.rand(B, device=seq.device)
        dirichlet = torch.distributions.Dirichlet(torch.ones(alphabet_size, device=seq.device))
        x0 = dirichlet.sample((B, L))
        x1 = seq_one_hot
        xt = t[:, None, None] * x1 + (1 - t[:, None, None]) * x0
        alphas = t
    elif args.mode == 'ardm' or args.mode == 'lrar':
        mask_prob = torch.rand(1, device=seq.device)
        mask = torch.rand(seq.shape, device=seq.device) < mask_prob
        if args.mode == 'lrar':
            # Left-to-right: keep a prefix and mask everything after it.
            mask = ~(torch.arange(L, device=seq.device) < (1 - mask_prob) * L)
        xt = torch.where(mask, alphabet_size, seq)  # mask token index
        xt = torch.nn.functional.one_hot(xt, num_classes=alphabet_size + 1).float()  # plus one to include index for mask token
        alphas = mask_prob.expand(B)
    else:
        raise ValueError(f"unknown mode: {args.mode!r}")
    return xt, alphas
144
+
145
def expand_simplex(xt, alphas, prior_pseudocount):
    """Double the last axis of ``xt`` by splitting it with a prior weight.

    The weight is ``prior_pseudocount / (alphas + prior_pseudocount - 1)``,
    reshaped to ``(B, 1, 1)``.

    Args:
        xt: Tensor whose leading dimension matches ``alphas``.
        alphas: 1-D tensor with one value per batch element.
        prior_pseudocount: Scalar pseudocount.

    Returns:
        tuple: ``(expanded, prior_weights)`` where ``expanded`` is
        ``cat([xt * (1 - w), xt * w], -1)``.
    """
    weights = prior_pseudocount / (alphas + prior_pseudocount - 1)
    weights = weights[:, None, None]
    data_part = xt * (1 - weights)
    prior_part = xt * weights
    return torch.cat([data_part, prior_part], -1), weights
148
+
149
+
150
class DirichletConditionalFlow:
    """Tabulated regularized incomplete beta functions and their alpha-derivative.

    On construction, ``betainc(alpha, K - 1, b)`` is evaluated for every
    ``alpha`` on a regular grid and for 1000 values of ``b`` in ``[0, 1]``.
    Its derivative with respect to ``alpha`` is approximated by forward
    differences along the alpha grid.

    Args:
        K: Dimension parameter; ``K - 1`` is the second beta argument.
        alpha_min: First alpha on the grid.
        alpha_max: Last alpha on the grid (approximately inclusive).
        alpha_spacing: Step of the alpha grid.
    """

    def __init__(self, K=20, alpha_min=1, alpha_max=100, alpha_spacing=0.01):
        self.alphas = np.arange(alpha_min, alpha_max + alpha_spacing, alpha_spacing)
        self.bs = np.linspace(0, 1, 1000)
        self.beta_cdfs = np.array(
            [scipy.special.betainc(alph, K - 1, self.bs) for alph in self.alphas]
        )
        # Forward differences: one row fewer than ``self.alphas``.
        self.beta_cdfs_derivative = np.diff(self.beta_cdfs, axis=0) / alpha_spacing
        self.K = K

    def c_factor(self, bs, alpha):
        """Evaluate ``-dI/dalpha * B(alpha, K-1) / ((1-b)^(K-1) * b^(alpha-1))``.

        The derivative is read from the tabulated row for the grid alpha
        nearest to ``alpha`` and interpolated linearly in ``b``. Entries where
        ``b >= 1`` or ``b^(alpha-1) <= 0`` evaluate to 0.

        Args:
            bs: Array of ``b`` values.
            alpha: Scalar alpha.

        Returns:
            np.ndarray: Factor for each element of ``bs``.
        """
        out1 = scipy.special.beta(alpha, self.K - 1)
        out2 = np.where(bs < 1, out1 / ((1 - bs) ** (self.K - 1)), 0)
        bs_pow = bs ** (alpha - 1)
        out = np.where(bs_pow > 0, out2 / bs_pow, 0)
        nearest = int(np.argmin(np.abs(alpha - self.alphas)))
        # The difference table has no row for the last grid alpha; reuse the
        # final available row instead of raising IndexError.
        nearest = min(nearest, len(self.beta_cdfs_derivative) - 1)
        I_func = self.beta_cdfs_derivative[nearest]
        interp = -np.interp(bs, self.bs, I_func)
        return interp * out
169
+
170
+
171
class GaussianSmearing(torch.nn.Module):
    """Expand scalars into Gaussian radial-basis features.

    Centres are ``embedding_dim`` points evenly spaced on ``[start, stop]``;
    each feature is ``exp(coeff * (x - centre + 1e-6)^2)`` with
    ``coeff = -0.5 / spacing^2``.
    """

    # used to embed the edge distances
    def __init__(self, start=0.0, stop=5.0, embedding_dim=50):
        super().__init__()
        centers = torch.linspace(start, stop, embedding_dim)
        spacing = (centers[1] - centers[0]).item()
        self.coeff = -0.5 / spacing ** 2
        self.register_buffer("offset", centers)
        self.embedding_dim = embedding_dim

    def forward(self, signal):
        """Return features of shape ``signal.shape + (embedding_dim,)``."""
        input_shape = signal.shape
        delta = signal.view(-1, 1) - self.offset.view(1, -1) + 1E-6
        features = torch.exp(self.coeff * torch.pow(delta, 2))
        return features.view(*input_shape, self.embedding_dim)
185
+
186
+
187
class MonotonicFunction(torch.nn.Module):
    """Learnable increasing piecewise-linear map defined on ``t`` in ``[0, 1]``.

    The unit interval is split into ``num_bins`` equal bins. Bin ``i`` covers
    an output range of width ``exp(w[i])``, so the function starts at 0 and
    ends at ``sum(exp(w))``, which is ``init_max`` at initialisation.

    Args:
        init_max: Initial value of the function at ``t = 1``.
        num_bins: Number of linear pieces.
    """

    def __init__(self, init_max, num_bins):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(num_bins) * np.log(init_max) - np.log(num_bins))
        self.num_bins = num_bins

    def _bin_index(self, t):
        """Return the bin containing ``t``, clamped to a valid index.

        Clamping lets ``t == 1`` fall in the last bin instead of indexing one
        past the end.
        """
        return (t * self.num_bins).long().clamp(min=0, max=self.num_bins - 1)

    def forward(self, t):
        """Evaluate the function at ``t``."""
        widths = torch.exp(self.w)
        right = torch.cumsum(widths, 0)
        left = right - widths

        bin_idx = self._bin_index(t)
        frac_part = t - bin_idx * (1 / self.num_bins)

        return left[bin_idx] + (frac_part * self.num_bins) * (right[bin_idx] - left[bin_idx])

    def invert(self, f):
        """Return ``t`` such that ``forward(t) == f``."""
        widths = torch.exp(self.w)
        left = torch.cumsum(widths, 0) - widths
        # Count the left edges strictly below f. For f <= 0 the count is 0,
        # which used to give index -1 and wrap to the last bin; clamp to 0.
        bin_idx = ((f.unsqueeze(-1) > left).sum(-1) - 1).clamp(min=0)
        frac_part = f - left[bin_idx]
        return bin_idx / self.num_bins + frac_part / widths[bin_idx] / self.num_bins

    def derivative(self, t):
        """Return the slope of the linear piece containing ``t``."""
        widths = torch.exp(self.w)
        right = torch.cumsum(widths, 0)
        left = right - widths
        bin_idx = self._bin_index(t)
        return (right[bin_idx] - left[bin_idx]) * self.num_bins
216
+
217
class SinusoidalEmbedding(nn.Module):
    """Sinusoidal embedding of scalar signals.

    from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py

    The output is ``[sin(x * f_0..), cos(x * f_0..)]`` with geometrically
    spaced frequencies from 1 down to ``1 / max_positions``, zero-padded by
    one column when ``embedding_dim`` is odd.
    """

    def __init__(self, embedding_dim, embedding_scale, max_positions=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding_scale = embedding_scale
        self.max_positions = max_positions

    def forward(self, signal):
        """Return embeddings of shape ``signal.shape + (embedding_dim,)``."""
        input_shape = signal.shape
        flat = signal.view(-1) * self.embedding_scale
        half_dim = self.embedding_dim // 2
        log_step = math.log(self.max_positions) / (half_dim - 1)
        positions = torch.arange(half_dim, dtype=torch.float32, device=flat.device)
        freqs = torch.exp(positions * -log_step)
        angles = flat.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.embedding_dim % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1), mode='constant')
        assert emb.shape == (flat.shape[0], self.embedding_dim)
        return emb.view(*input_shape, self.embedding_dim)
237
+
238
+
239
class GaussianFourierProjection(nn.Module):
    """Random Fourier features of scalar signals, with frozen Gaussian frequencies.

    from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32

    ``embedding_dim // 2`` frequencies are drawn once from ``N(0, scale^2)``
    and stored as a parameter with ``requires_grad=False``.
    """

    def __init__(self, embedding_dim=256, scale=1.0):
        super().__init__()
        frequencies = torch.randn(embedding_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)
        self.embedding_dim = embedding_dim

    def forward(self, signal):
        """Return ``[sin(2*pi*x*W), cos(2*pi*x*W)]`` with shape ``signal.shape + (embedding_dim,)``."""
        input_shape = signal.shape
        flat = signal.view(-1)
        phase = flat[:, None] * self.W[None, :] * 2 * np.pi
        features = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return features.view(*input_shape, self.embedding_dim)
255
+
256
def get_signal_mapping(embedding_type, embedding_dim, embedding_scale=10000):
    """Build the embedding module selected by ``embedding_type``.

    Args:
        embedding_type: ``'sinusoidal'``, ``'fourier'`` or ``'gaussian'``.
        embedding_dim: Size of the produced embedding.
        embedding_scale: Passed as ``embedding_scale`` to
            ``SinusoidalEmbedding`` and as ``scale`` to
            ``GaussianFourierProjection``; unused for ``'gaussian'``, which
            always covers the range 0.0 to 1.

    Returns:
        The constructed embedding module.

    Raises:
        NotImplementedError: For any other ``embedding_type``. The original
            ``raise NotImplemented`` raised ``TypeError`` because
            ``NotImplemented`` is not an exception class.
    """
    if embedding_type == 'sinusoidal':
        return SinusoidalEmbedding(embedding_dim=embedding_dim, embedding_scale=embedding_scale)
    if embedding_type == 'fourier':
        return GaussianFourierProjection(embedding_dim=embedding_dim, scale=embedding_scale)
    if embedding_type == 'gaussian':
        return GaussianSmearing(0.0, 1, embedding_dim)
    raise NotImplementedError(f"unknown embedding_type: {embedding_type!r}")
266
+
267
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a cumulative-product schedule ``alpha_bar(t)`` into betas.

    Step ``i`` gets ``beta_i = 1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)``,
    capped at ``max_beta``.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t in [0, 1] and returning the
                      cumulative product of (1 - beta) up to that point
                      of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    :return: a 1-D numpy array of ``num_diffusion_timesteps`` betas.
    """
    total = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((step + 1) / total) / alpha_bar(step / total), max_beta)
        for step in range(total)
    ])
285
+
286
def get_beta_schedule(num_steps):
    """Return ``num_steps`` betas for a squared-cosine ``alpha_bar`` schedule.

    ``alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2`` is discretized
    by ``betas_for_alpha_bar`` using its default ``max_beta``.
    """
    def cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    return betas_for_alpha_bar(num_steps, cosine_alpha_bar)
292
+
293
+
294
class GaussianDiffusionSchedule:
    """
    Precomputed coefficient tables for a Gaussian diffusion process.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    The betas come from ``get_beta_schedule(timesteps)``; all tables are
    float64 numpy arrays with one entry per timestep.

    :param timesteps: number of diffusion steps.
    :param noise_scale: multiplier applied to freshly sampled noise in
                        ``q_sample`` and folded into the posterior variance
                        returned by ``q_posterior_mean_variance``.
    """

    def __init__(self, timesteps, noise_scale=1.0):
        # Use float64 for accuracy.
        betas = np.array(get_beta_schedule(timesteps), dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.timesteps = int(betas.shape[0])
        self.noise_scale = noise_scale

        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        cumprod_prev = np.append(1.0, cumprod[:-1])
        self.alphas_cumprod = cumprod
        self.alphas_cumprod_prev = cumprod_prev
        self.alphas_cumprod_next = np.append(cumprod[1:], 0.0)
        assert cumprod_prev.shape == (self.timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        one_minus_cumprod = 1.0 - cumprod
        self.sqrt_alphas_cumprod = np.sqrt(cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(one_minus_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(one_minus_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        variance = betas * (1.0 - cumprod_prev) / one_minus_cumprod
        self.posterior_variance = variance
        # The posterior variance is 0 at the beginning of the chain, so its
        # first entry is replaced by the second before taking the log.
        self.posterior_log_variance_clipped = np.log(
            np.append(variance[1], variance[1:])
        )
        self.posterior_mean_coef1 = betas * np.sqrt(cumprod_prev) / one_minus_cumprod
        self.posterior_mean_coef2 = (
            (1.0 - cumprod_prev) * np.sqrt(alphas) / one_minus_cumprod
        )

    def q_sample(self, x_start, t, noise=None):
        """
        Sample from q(x_t | x_0), i.e. diffuse ``x_start`` for ``t + 1`` steps.

        :param x_start: the initial data batch.
        :param t: tensor of step indices (0 means one step).
        :param noise: optional noise of the same shape as ``x_start``; when
                      omitted, standard normal noise scaled by
                      ``noise_scale`` is drawn.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = self.noise_scale * torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        signal_coef = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_coef = _extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return signal_coef * x_start + noise_coef * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior
        q(x_{t-1} | x_t, x_0).

        :return: ``(mean, variance, clipped_log_variance)``; the variance is
                 multiplied by ``noise_scale ** 2`` and the log variance is
                 shifted by ``2 * log(noise_scale)``.
        """
        assert x_start.shape == x_t.shape
        shape = x_t.shape
        coef_start = _extract_into_tensor(self.posterior_mean_coef1, t, shape)
        coef_t = _extract_into_tensor(self.posterior_mean_coef2, t, shape)
        posterior_mean = coef_start * x_start + coef_t * x_t

        variance = _extract_into_tensor(self.posterior_variance, t, shape)
        log_variance = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, shape
        )

        variance = (self.noise_scale ** 2) * variance
        log_variance = 2 * np.log(self.noise_scale) + log_variance

        assert (
            posterior_mean.shape[0]
            == variance.shape[0]
            == log_variance.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, variance, log_variance
406
+
407
+
408
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
409
+ """
410
+ Extract values from a 1-D numpy array for a batch of indices.
411
+
412
+ :param arr: the 1-D numpy array.
413
+ :param timesteps: a tensor of indices into the array to extract.
414
+ :param broadcast_shape: a larger shape of K dimensions with the batch
415
+ dimension equal to the length of timesteps.
416
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
417
+ """
418
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
419
+ while len(res.shape) < len(broadcast_shape):
420
+ res = res[..., None]
421
+ return res.expand(broadcast_shape)
422
+
423
+
424
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or if a section has fewer steps than
                        requested from it.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # Report the requested count; the original message printed
            # num_timesteps here, which is not the number that failed.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        # The first ``extra`` sections absorb the remainder, one step each.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
478
+
479
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings ``[cos(t * f), sin(t * f)]`` for timesteps.

    Frequencies are ``exp(-log(max_period) * k / (dim // 2))`` for
    ``k = 0 .. dim // 2 - 1``. One zero column is appended when ``dim`` is odd.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    exponents = torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * exponents / half).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    pieces = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        pieces.append(torch.zeros_like(angles[:, :1]))
    return torch.cat(pieces, dim=-1)
utils/parsing.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import math
3
+
4
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    from argparse import ArgumentTypeError
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_guidance_args():
    """Parse the guidance/sampling command-line options from ``sys.argv``.

    ``--is_peptide`` accepts textual booleans. With the original
    ``type=bool`` any non-empty string, including ``"False"``, parsed as
    ``True``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = ArgumentParser()

    parser.add_argument("--num_div", type=int, default=64)
    parser.add_argument("--lambda_", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--alpha_r", type=float, default=0.5)
    parser.add_argument("--eta", type=float, default=1.0)
    # Angle defaults are given in degrees and stored in radians.
    parser.add_argument("--Phi_init", type=float, default=math.radians(45.0))
    parser.add_argument("--Phi_min", type=float, default=math.radians(15.0))
    parser.add_argument("--Phi_max", type=float, default=math.radians(75.0))
    parser.add_argument("--tau", type=float, default=0.3)
    parser.add_argument("--T", type=int, default=100)
    parser.add_argument("--length", type=int, default=12)
    parser.add_argument("--is_peptide", type=_str_to_bool, default=True)
    parser.add_argument("--n_samples", type=int, default=5)
    parser.add_argument("--n_batches", type=int, default=2)
    parser.add_argument("--target_protein", type=str, default="AAAAA")
    parser.add_argument("--target_enhancer_class", type=int, default=0)
    parser.add_argument("--target_DNA_shape", type=str, default='HelT')
    parser.add_argument("--motifs", type=str, required=False)
    parser.add_argument("--weights", type=float, nargs='+', required=False)
    parser.add_argument("--output_file", type=str, default='moo_outputs.txt')
    parser.add_argument("--motif_penalty", action='store_true')

    args = parser.parse_args()
    return args