AlienChen committed (verified)
Commit e588998 · Parent: 4f6a4c3

Upload 10 files

Files changed (10):
  1. classifier.py +490 -0
  2. dataloader.py +692 -0
  3. diffusion.py +1629 -0
  4. eval_utils.py +90 -0
  5. noise_schedule.py +160 -0
  6. requirements.yaml +49 -0
  7. sample.py +124 -0
  8. tokenizer.py +279 -0
  9. uncond_sample.py +116 -0
  10. utils.py +86 -0
classifier.py ADDED
@@ -0,0 +1,490 @@
+ import itertools
+ import typing
+
+ import hydra.utils
+ import lightning as L
+ import torch
+ import torch.nn.functional as F
+ import torchmetrics
+ import transformers
+
+ import dataloader
+ import models.dit
+ import models.dimamba  # Ensure submodule is loaded for the 'dimamba' branch below.
+ import noise_schedule
+
+
+ class MicroAveragingMetric(torchmetrics.Metric):
+   """Micro-averaging metric.
+
+   Adapted from https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py#L12
+   """
+
+   def __init__(self, class_idx: typing.Optional[int] = 1,
+                dist_sync_on_step=False):
+     super().__init__(dist_sync_on_step=dist_sync_on_step)
+     self.class_idx = torch.tensor(class_idx) \
+       if class_idx is not None else None
+     self.add_state("numerator", default=torch.tensor(0.0),
+                    dist_reduce_fx="sum")
+     self.add_state("denominator", default=torch.tensor(0.0),
+                    dist_reduce_fx="sum")
+
+   def _update(
+       self, numerator, denominator, preds, y) -> tuple:
+     raise NotImplementedError
+
+   def update(self, logits: torch.Tensor, y: torch.Tensor):
+     # Update metric states.
+     preds = torch.argmax(logits, dim=-1)
+     y = y.view(-1)
+     assert preds.shape == y.shape, \
+       f"preds shape {preds.shape} != y shape {y.shape}"
+     self.numerator, self.denominator = self._update(
+       self.numerator, self.denominator, preds, y)
+
+   def compute(self):
+     # Compute final result.
+     value = self.numerator.float() / self.denominator \
+       if self.denominator.item() > 0. else torch.tensor(0.0)
+     return value
+
+   def reset(self):
+     self.numerator = torch.tensor(0.0).to(self.device)
+     self.denominator = torch.tensor(0.0).to(self.device)
+
+
+ class CrossEntropy(MicroAveragingMetric):
+   """Calculates cross-entropy loss."""
+
+   def _update(
+       self, numerator, denominator, logits, y) -> tuple:
+     with torch.no_grad():
+       numerator += F.cross_entropy(
+         logits.view(-1, logits.size(-1)),
+         y.view(-1),
+         ignore_index=-100,
+         reduction='sum')
+       denominator += y.numel()
+     return numerator, denominator
+
+   # Overrides parent class to use logits and not (argmax) preds.
+   def update(self, logits: torch.Tensor, y: torch.Tensor):
+     y = y.view(-1)
+     self.numerator, self.denominator = self._update(
+       self.numerator, self.denominator, logits, y)
+
+
+ class Accuracy(MicroAveragingMetric):
+   """Calculates accuracy.
+
+   Can be used to calculate accuracy per class.
+   Copied from:
+   https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
+   """
+
+   def _update(
+       self, numerator, denominator, preds, y) -> tuple:
+     if self.class_idx is None:
+       numerator += (preds == y).sum()
+       denominator += y.numel()
+     else:
+       class_idx = self.class_idx
+       relevant_idxs = (y == class_idx)
+       numerator += (preds[relevant_idxs] == class_idx).sum()
+       denominator += relevant_idxs.sum()
+       relevant_idxs = (y != class_idx)
+       numerator += (preds[relevant_idxs] != class_idx).sum()
+       denominator += relevant_idxs.sum()
+     return numerator, denominator
+
+
+ class Precision(MicroAveragingMetric):
+   """Calculates precision.
+
+   Can be used to calculate precision per class.
+   Adapted from:
+   https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
+   """
+
+   def _update(self, numerator, denominator, preds, y) -> tuple:
+     class_idx = self.class_idx
+     relevant_idxs = (preds == class_idx)
+     numerator += (y[relevant_idxs] == class_idx).sum()
+     denominator += relevant_idxs.sum()
+     return numerator, denominator
+
+
+ class Recall(MicroAveragingMetric):
+   """Calculates recall.
+
+   Can be used to calculate recall per class.
+   Adapted from:
+   https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
+   """
+
+   def _update(self, numerator, denominator, preds, y) -> tuple:
+     class_idx = self.class_idx
+     relevant_idxs = (y == class_idx)
+     numerator += (preds[relevant_idxs] == class_idx).sum()
+     denominator += relevant_idxs.sum()
+     return numerator, denominator
+
+
+ class Classifier(L.LightningModule):
+   def __init__(
+       self,
+       config,
+       tokenizer: transformers.PreTrainedTokenizer,
+       pretrained_backbone: typing.Optional[torch.nn.Module] = None):
+     super().__init__()
+     self.save_hyperparameters(ignore=['pretrained_backbone'])
+     self.config = config
+
+     # This param indicates whether this model will be used
+     # for guidance (False) or only evaluation (True).
+     self.is_eval_classifier = getattr(
+       config, 'is_eval_classifier', False)
+
+     self.tokenizer = tokenizer
+     self.vocab_size = tokenizer.vocab_size
+     self.antithetic_sampling = config.training.antithetic_sampling
+     self.importance_sampling = config.training.importance_sampling
+     self.change_of_variables = config.training.change_of_variables
+     if (not hasattr(self.tokenizer, 'mask_token')
+         or self.tokenizer.mask_token is None):
+       self.mask_index = self.vocab_size
+       self.vocab_size += 1
+     else:
+       self.mask_index = self.tokenizer.mask_token_id
+
+     if config.classifier_backbone == 'dit':
+       self.classifier_model = models.dit.DITClassifier(
+         self.config, vocab_size=self.vocab_size)
+     elif self.config.classifier_backbone == 'dimamba':
+       self.classifier_model = models.dimamba.DiMambaClassifier(
+         self.config, vocab_size=self.vocab_size,
+         pad_token_id=self.tokenizer.pad_token_id)
+     elif config.classifier_backbone == 'hyenadna':
+       hyena_config = transformers.AutoConfig.from_pretrained(
+         config.classifier_model.hyena_model_name_or_path,
+         n_layer=config.classifier_model.n_layer,
+         trust_remote_code=True)
+       self.classifier_model = transformers.AutoModelForSequenceClassification.from_config(
+         hyena_config,
+         pretrained=False,
+         num_labels=config.data.num_classes,
+         problem_type='single_label_classification',
+         trust_remote_code=True)
+     else:
+       raise NotImplementedError(
+         f"Classifier backbone "
+         f"{self.config.classifier_backbone} not "
+         f"implemented.")
+     if pretrained_backbone is not None:  # For PPLM / NoS
+       self.classifier_model.load_pretrained_encoder(
+         pretrained_backbone)
+     # Metrics are automatically reset at end of epoch.
+     metrics = torchmetrics.MetricCollection({
+       'cross_entropy': CrossEntropy(),
+       'accuracy': Accuracy(class_idx=None),
+     })
+     if config.data.num_classes > 2:
+       for c in range(config.data.num_classes):
+         metrics.add_metrics(
+           {f"accuracy_class{c}": Accuracy(class_idx=c),
+            f"precision_class{c}": Precision(class_idx=c),
+            f"recall_class{c}": Recall(class_idx=c)})
+     else:
+       metrics.add_metrics(
+         {'precision': Precision(class_idx=1),
+          'recall': Recall(class_idx=1)})
+     metrics.set_dtype(torch.float64)
+     self.train_metrics = metrics.clone(prefix='train/')
+     self.valid_metrics = metrics.clone(prefix='val/')
+
+     self.T = config.T
+     self.noise = noise_schedule.get_noise(config,
+                                           dtype=self.dtype)
+     self.sampling_eps = config.training.sampling_eps
+     self.lr = config.optim.lr
+     self.time_conditioning = config.time_conditioning
+     self.fast_forward_epochs = None
+     self.fast_forward_batches = None
+
+   def on_load_checkpoint(self, checkpoint):
+     # Copied from:
+     # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
+     self.fast_forward_epochs = checkpoint['loops'][
+       'fit_loop']['epoch_progress']['current']['completed']
+     self.fast_forward_batches = checkpoint['loops'][
+       'fit_loop']['epoch_loop.batch_progress'][
+       'current']['completed']
+
+   def on_save_checkpoint(self, checkpoint):
+     # Copied from:
+     # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
+     # ['epoch_loop.batch_progress']['total']['completed'] is
+     # 1 iteration behind, so we're using the optimizer's
+     # progress.
+     checkpoint['loops']['fit_loop'][
+       'epoch_loop.batch_progress']['total'][
+       'completed'] = checkpoint['loops']['fit_loop'][
+       'epoch_loop.automatic_optimization.optim_progress'][
+       'optimizer']['step']['total'][
+       'completed'] * self.trainer.accumulate_grad_batches
+     checkpoint['loops']['fit_loop'][
+       'epoch_loop.batch_progress']['current'][
+       'completed'] = checkpoint['loops']['fit_loop'][
+       'epoch_loop.automatic_optimization.optim_progress'][
+       'optimizer']['step']['current'][
+       'completed'] * self.trainer.accumulate_grad_batches
+     # _batches_that_stepped tracks the number of global
+     # steps, not the number of local steps, so we don't
+     # multiply with self.trainer.accumulate_grad_batches
+     # here.
+     checkpoint['loops']['fit_loop'][
+       'epoch_loop.state_dict'][
+       '_batches_that_stepped'] = \
+       checkpoint['loops']['fit_loop'][
+         'epoch_loop.automatic_optimization.optim_progress'][
+         'optimizer']['step']['total']['completed']
+     if 'sampler' not in checkpoint.keys():
+       checkpoint['sampler'] = {}
+     if hasattr(self.trainer.train_dataloader.sampler,
+                'state_dict'):
+       sampler_state_dict = self.trainer.\
+         train_dataloader.sampler.state_dict()
+       checkpoint['sampler'][
+         'random_state'] = sampler_state_dict.get(
+         'random_state', None)
+     else:
+       checkpoint['sampler']['random_state'] = None
+
+   def on_train_start(self):
+     # Adapted from:
+     # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
+     distributed = (
+       self.trainer._accelerator_connector.use_distributed_sampler
+       and self.trainer._accelerator_connector.is_distributed)
+     if distributed:
+       sampler_cls = dataloader.FaultTolerantDistributedSampler
+     else:
+       sampler_cls = dataloader.RandomFaultTolerantSampler
+     updated_dls = []
+     for dl in self.trainer.fit_loop._combined_loader.flattened:
+       if hasattr(dl.sampler, 'shuffle'):
+         dl_sampler = sampler_cls(
+           dl.dataset, shuffle=dl.sampler.shuffle)
+       else:
+         dl_sampler = sampler_cls(dl.dataset)
+       if (distributed
+           and self.fast_forward_epochs is not None
+           and self.fast_forward_batches is not None):
+         dl_sampler.load_state_dict({
+           'epoch': self.fast_forward_epochs,
+           'counter': (self.fast_forward_batches
+                       * self.config.loader.batch_size)})
+       updated_dls.append(
+         torch.utils.data.DataLoader(
+           dl.dataset,
+           batch_size=self.config.loader.batch_size,
+           num_workers=self.config.loader.num_workers,
+           pin_memory=self.config.loader.pin_memory,
+           sampler=dl_sampler,
+           shuffle=False,
+           persistent_workers=self.config.loader.persistent_workers))
+     self.trainer.fit_loop._combined_loader.flattened = updated_dls
+
+   def forward(self, x, sigma=None, x_emb=None, attention_mask=None):
+     """Returns logits.
+
+     x_emb can be provided during PPLM / NoS-style guidance
+     (see: https://arxiv.org/abs/2305.20009).
+     """
+     if self.is_eval_classifier:
+       logits = self.classifier_model(x)
+       if hasattr(logits, 'logits'):
+         logits = logits.logits
+     else:
+       sigma = self._process_sigma(sigma) if sigma is not None else sigma
+       with torch.cuda.amp.autocast(dtype=torch.float32):
+         logits = self.classifier_model(
+           x, sigma, x_emb=x_emb, attention_mask=attention_mask)
+     return logits
+
+   def get_log_probs(self, x, sigma, x_emb=None):
+     """Returns log probabilities.
+
+     Use for CBG-style guidance.
+     """
+     if self.is_eval_classifier:
+       raise NotImplementedError(
+         '`get_log_probs` not implemented for classifiers '
+         'that are meant to be used for evaluation purposes '
+         'only.')
+     with torch.cuda.amp.autocast(dtype=torch.float32):
+       return torch.nn.functional.log_softmax(
+         self.forward(x, sigma, x_emb=x_emb), dim=-1)
+
+   def training_step(self, batch, batch_idx):
+     loss = self._compute_loss(batch, prefix='train')
+     self.log(name='trainer/loss',
+              value=loss.item(),
+              on_step=True,
+              on_epoch=False,
+              sync_dist=True,
+              prog_bar=True)
+     self.log(name='lr',
+              value=self.trainer.optimizers[0].param_groups[0]['lr'],
+              on_step=True,
+              on_epoch=False,
+              sync_dist=True,
+              prog_bar=True,
+              logger=False)
+     return loss
+
+   def validation_step(self, batch, batch_idx):
+     return self._compute_loss(batch, prefix='val')
+
+   def configure_optimizers(self):
+     # TODO(yair): Lightning currently giving this warning when using `fp16`:
+     # "Detected call of `lr_scheduler.step()` before `optimizer.step()`."
+     # Not clear if this is a problem or not.
+     # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
+     optimizer = torch.optim.AdamW(
+       itertools.chain(self.classifier_model.parameters(),
+                       self.noise.parameters()),
+       lr=self.config.optim.lr,
+       betas=(self.config.optim.beta1,
+              self.config.optim.beta2),
+       eps=self.config.optim.eps,
+       weight_decay=self.config.optim.weight_decay)
+
+     scheduler = hydra.utils.instantiate(
+       self.config.lr_scheduler, optimizer=optimizer)
+     scheduler_dict = {
+       'scheduler': scheduler,
+       'interval': 'step',
+       'monitor': 'val/loss',
+       'name': 'trainer/lr',
+     }
+     return [optimizer], [scheduler_dict]
+
+   def _q_xt(self, x, move_chance):
+     """Computes the noisy sample xt.
+
+     Args:
+       x: int torch.Tensor with shape (batch_size,
+         diffusion_model_input_length), input.
+       move_chance: float torch.Tensor with shape
+         (batch_size, 1).
+     """
+     move_indices = torch.rand(
+       *x.shape, device=x.device) < move_chance
+     if self.config.diffusion == 'absorbing_state':
+       return torch.where(move_indices, self.mask_index, x)
+     if self.config.diffusion == 'uniform':
+       uniform_tensor = torch.randint(
+         0, self.vocab_size, x.shape, device=x.device)
+       return torch.where(move_indices, uniform_tensor, x)
+     raise NotImplementedError(
+       f'Diffusion type {self.config.diffusion} not '
+       'implemented.')
+
+   def _compute_loss(self, batch, prefix):
+     x0 = batch['input_ids']
+     attention_mask = batch['attention_mask']
+     t = None
+     if self.is_eval_classifier:
+       logits = self.forward(x0)
+     elif self.config.parameterization == 'ar':
+       # Do not add noise for AR FUDGE and AR PPLM.
+       logits = self.forward(
+         x0, attention_mask=attention_mask)
+     else:
+       t = self._sample_t(x0.shape[0])
+       if self.T > 0:
+         t = (t * self.T).to(torch.int)
+         t = t / self.T
+         # t \in {1/T, 2/T, ..., 1}
+         t += (1 / self.T)
+       if self.change_of_variables:
+         time_conditioning = t[:, None]
+         f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
+         f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
+         move_chance = torch.exp(f_0 + t * (f_T - f_0))
+         move_chance = move_chance[:, None]
+       else:
+         sigma, _ = self.noise(t)
+         time_conditioning = sigma[:, None]
+         move_chance = 1 - torch.exp(-sigma[:, None])
+
+       xt = self._q_xt(x0, move_chance)
+       logits = self.forward(
+         xt, time_conditioning, attention_mask=attention_mask)
+     if hasattr(self.config.data, 'label_col'):
+       if f"{self.config.data.label_col}_threshold" in batch:
+         y = batch[f"{self.config.data.label_col}_threshold"]
+       else:
+         y = batch[self.config.data.label_col]
+     else:
+       y = batch['label']
+     if (not self.is_eval_classifier
+         and getattr(self.config.training,
+                     'use_label_smoothing', False)):
+       # Interpolate between one-hot and uniform distribution.
+       labels = (
+         torch.nn.functional.one_hot(
+           y, self.config.data.num_classes) * (1 - t)[..., None]
+         + (1 / self.config.data.num_classes) * t[..., None])
+     else:
+       labels = y.view(-1)
+     if getattr(self.config, 'is_fudge_classifier', False):
+       expanded_y = y.unsqueeze(1).expand(-1, logits.shape[1])  # batch x seq
+       logits = logits.view(
+         -1, self.config.data.num_classes)[attention_mask.flatten() == 1, ...]
+       y = expanded_y.flatten().long()[attention_mask.flatten() == 1]
+       loss = torch.nn.functional.cross_entropy(
+         logits,
+         y,
+         ignore_index=-100,
+         reduction='mean')
+     else:
+       loss = torch.nn.functional.cross_entropy(
+         logits.view(-1, logits.size(-1)),
+         labels,
+         ignore_index=-100,
+         reduction='mean')
+
+     if prefix == 'train':
+       self.train_metrics.update(logits, y)
+       metrics = self.train_metrics
+     elif prefix == 'val':
+       self.valid_metrics.update(logits, y)
+       metrics = self.valid_metrics
+     elif prefix == 'test':
+       self.test_metrics.update(logits, y)
+       metrics = self.test_metrics
+     else:
+       raise ValueError(f'Invalid prefix: {prefix}')
+
+     self.log_dict(metrics,
+                   on_step=False,
+                   on_epoch=True,
+                   sync_dist=True)
+     return loss
+
+   def _sample_t(self, n):
+     _eps_t = torch.rand(n, device=self.device)
+     if self.antithetic_sampling:
+       # Stratify the n draws across [0, 1) to reduce the
+       # variance of the time-step estimate.
+       offset = torch.arange(n, device=self.device) / n
+       _eps_t = (_eps_t / n + offset) % 1
+     t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
+     if self.importance_sampling:
+       return self.noise.importance_sampling_transformation(t)
+     return t
+
+   def _process_sigma(self, sigma):
+     if sigma.ndim > 1:
+       sigma = sigma.squeeze(-1)
+     if not self.time_conditioning:
+       sigma = torch.zeros_like(sigma)
+     assert sigma.ndim == 1, sigma.shape
+     return sigma
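
All of the metric classes above share one pattern: `update()` accumulates a numerator and denominator per batch, and `compute()` performs a single division at the end, so results stay exact under distributed training (the two states are sum-reduced across ranks before the division). A minimal sketch of that cycle with the `Precision` class from this file (assumes it is run from the repo root with the repo's dependencies installed, so `classifier.py` is importable):

import torch

from classifier import Precision

# Precision for class index 1 on two identical toy batches.
metric = Precision(class_idx=1)
logits = torch.tensor([[0.1, 2.0], [3.0, 0.2], [0.3, 1.5]])
labels = torch.tensor([1, 0, 1])
metric.update(logits, labels)  # preds = argmax(logits) = [1, 0, 1]
metric.update(logits, labels)  # states accumulate additively across batches
print(metric.compute())  # tensor(1.) -- 4 true positives / 4 predicted positives
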
dataloader.py ADDED
@@ -0,0 +1,692 @@
+ import functools
+ import itertools
+ import math
+ import os
+ import re
+ import shutil
+ import typing
+ import urllib.request
+ import zipfile
+
+ import datasets
+ import fsspec
+ import numpy as np
+ import tokenizers
+ import torch
+ import transformers
+ import lightning as L
+ from torch.utils.data import DataLoader
+ from functools import partial
+
+ import custom_datasets.discretized_cifar10
+ import custom_datasets.ten_species_dataset
+ import utils
+
+ LOGGER = utils.get_logger(__name__)
+
+
+ # noinspection RegExpRedundantEscape
+ def lm1b_detokenizer(x):
+   x = x.replace('http : / / ', 'http://')
+   x = x.replace('https : / / ', 'https://')
+   x = re.sub(r' \'(\w+)', r"'\1", x)
+   x = re.sub(r' (\w+) \. ', r' \1. ', x)
+   x = re.sub(r' (\w+) \.$', r' \1.', x)
+   x = x.replace(' ? ', '? ')
+   x = re.sub(r' \?$', '?', x)
+   x = x.replace(' ! ', '! ')
+   x = re.sub(r' \!$', '!', x)
+   x = x.replace(' , ', ', ')
+   x = x.replace(' : ', ': ')
+   x = x.replace(' ; ', '; ')
+   x = x.replace(' / ', '/')
+   x = re.sub(r'\" ([^\"]+) \"', r'"\1"', x)
+   x = re.sub(r'\' ([^\']+) \'', r"'\1'", x)
+   x = re.sub(r'\( ([^\(\)]+) \)', r"(\1)", x)
+   x = re.sub(r'\[ ([^\[\]]+) \]', r"[\1]", x)
+   x = x.replace('$ ', '$')
+   x = x.replace('£ ', '£')
+   return x
+
+
+ class Text8Tokenizer(transformers.PreTrainedTokenizer):
+   def __init__(
+       self,
+       bos_token='[BOS]',
+       eos_token='[EOS]',
+       sep_token='[SEP]',
+       cls_token='[CLS]',
+       pad_token='[PAD]',
+       mask_token='[MASK]',
+       unk_token='[UNK]',
+       **kwargs):
+     self.characters = list('abcdefghijklmnopqrstuvwxyz ')
+     self._vocab_str_to_int = {
+       '[CLS]': 0,
+       '[SEP]': 1,
+       '[BOS]': 2,
+       '[EOS]': 3,
+       '[MASK]': 4,
+       '[PAD]': 5,
+       '[RESERVED]': 6,
+       '[UNK]': 7,
+       **{ch: i + 8 for i, ch in enumerate(self.characters)}}
+     self._vocab_int_to_str = {
+       v: k for k, v in self._vocab_str_to_int.items()}
+     super().__init__(
+       bos_token=bos_token,
+       eos_token=eos_token,
+       sep_token=sep_token,
+       cls_token=cls_token,
+       pad_token=pad_token,
+       mask_token=mask_token,
+       unk_token=unk_token,
+       **kwargs)
+
+   @property
+   def vocab_size(self) -> int:
+     return len(self._vocab_str_to_int)
+
+   def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
+     return list(text.lower())
+
+   def _convert_token_to_id(self, token: str) -> int:
+     return self._vocab_str_to_int.get(
+       token, self._vocab_str_to_int['[UNK]'])
+
+   def _convert_id_to_token(self, index: int) -> str:
+     return self._vocab_int_to_str[index]
+
+   def convert_tokens_to_string(self, tokens):
+     return ''.join(tokens)
+
+   def get_vocab(self) -> typing.Dict[str, int]:
+     return self._vocab_str_to_int
+
+
+ def get_text8_dataset(cache_dir, max_seq_length=256,
+                       drop_last=True, crop_train=False):
+   """Adapted from:
+   https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344
+
+   Args:
+     cache_dir: str, path to cache directory.
+     max_seq_length: int, maximum length of sequences.
+       (default: 256, as in D3PM codebase.)
+     drop_last: bool, whether to drop the last incomplete
+       batch. (default: True, as in D3PM codebase.)
+     crop_train: bool, whether to subsample contiguous
+       subsequences from training examples. Serves to
+       make sure transformer models with absolute position
+       embeddings do not have incorrect position-wise
+       marginals. (default: False, but necessary to match D3PM AR.)
+
+   Returns:
+     dataset: datasets.DatasetDict, with keys 'train',
+       'validation', 'test'.
+   """
+   url = 'http://mattmahoney.net/dc/text8.zip'
+   if not crop_train:
+     cache_dir = f'{cache_dir}/text8'
+   else:
+     cache_dir = f'{cache_dir}/text8-crop-train'
+   split_names = ['train', 'validation', 'test']
+   if not all([
+       utils.fsspec_exists(os.path.join(cache_dir, split))
+       for split in split_names]):
+     # Check if raw data exists.
+     raw_cache_dir = os.path.join(cache_dir, 'raw_data')
+     if not all([
+         utils.fsspec_exists(
+           os.path.join(raw_cache_dir, f'text8.{split}.txt'))
+         for split in split_names]):
+       if not utils.fsspec_exists(
+           os.path.join(raw_cache_dir, 'text8.zip')):
+         utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True)
+         LOGGER.info('Downloading text8 from URL {}.'.format(url))
+         with (urllib.request.urlopen(url) as in_stream,
+               open(os.path.join(raw_cache_dir, 'text8.zip'),
+                    'wb') as out_file):
+           shutil.copyfileobj(in_stream, out_file)
+
+       with fsspec.open(
+           os.path.join(raw_cache_dir, 'text8.zip'),
+           'rb') as f:
+         rawdata = zipfile.ZipFile(f).read(
+           'text8').decode('utf-8')
+
+       # Splits taken from D3PM codebase.
+       splits = {
+         'train': rawdata[:90_000_000],
+         'validation': rawdata[90_000_000: 95_000_000],
+         'test': rawdata[95_000_000:],
+       }
+
+       for split, data in splits.items():
+         _path = os.path.join(raw_cache_dir,
+                              f'text8.{split}.txt')
+         with fsspec.open(_path, 'w') as f:
+           f.write(data)
+     else:
+       splits = {}
+       for split in split_names:
+         _path = os.path.join(raw_cache_dir,
+                              f'text8.{split}.txt')
+         with fsspec.open(_path, 'r') as f:
+           splits[split] = f.read()
+
+     # Chunk and save as datasets.DatasetDict.
+     def chunks(lst, n):
+       """Yield successive n-sized chunks from lst."""
+       for i in range(0, len(lst), n):
+         yield lst[i:i + n]
+
+     dataset_dict = {}
+     for k, v in splits.items():
+       if k == 'train' and crop_train:
+         chunk_size = 2 * max_seq_length
+       else:
+         chunk_size = max_seq_length
+       text = list(chunks(v, chunk_size))
+       if drop_last and len(text[-1]) < chunk_size:
+         text = text[:-1]
+       dataset_dict[k] = datasets.Dataset.from_dict({'text': text})
+     dataset = datasets.DatasetDict(dataset_dict)
+     dataset.save_to_disk(cache_dir)
+   else:
+     dataset = datasets.load_from_disk(cache_dir)
+
+   return dataset
+
+
+ def _group_texts(examples, block_size, bos, eos,
+                  add_special_tokens=True):
+   # Concatenate all texts.
+   concatenated_examples = list(itertools.chain(*examples['input_ids']))
+   total_length = len(concatenated_examples)
+   # TODO(yair): look into not dropping the remainder but rather padding it.
+   # We drop the small remainder, and if the total_length < block_size - 2
+   # we exclude this batch and return an empty dict.
+   # We could add padding if the model supported it instead of
+   # this drop; you can customize this part to your needs.
+   # `-2` to account for [BOS] and [EOS] to be added below.
+   new_block_size = block_size - (2 if add_special_tokens else 0)
+   total_length = (total_length // new_block_size) * new_block_size
+   # Split by chunks of max_len.
+   result = {}
+   _values = []
+   _attn_masks = []
+   for i in range(0, total_length, new_block_size):
+     if add_special_tokens:
+       _values.append(
+         [bos]
+         + concatenated_examples[i: i + new_block_size]
+         + [eos])
+     else:
+       _values.append(
+         concatenated_examples[i: i + new_block_size])
+     _attn_masks.append(torch.ones(block_size))
+   result['input_ids'] = _values
+   result['attention_mask'] = _attn_masks
+   return result
+
+
+ def get_dataset(
+     dataset_name, tokenizer, wrap, mode, cache_dir,
+     block_size=1024, num_proc=len(os.sched_getaffinity(0)),
+     streaming=False, override_cache=False,
+     add_special_tokens=True,
+     label_col=None, label_threshold=None):
+   if label_col is not None:
+     label_suffix = f'_label-{label_col}'
+     if label_threshold is not None:
+       label_suffix += f'_threshold-{label_threshold}'
+   else:
+     label_suffix = ''
+   if wrap:
+     filename = f'{dataset_name}_{mode}_bs{block_size}_wrapped{label_suffix}.dat'
+   else:
+     filename = f'{dataset_name}_{mode}_bs{block_size}_unwrapped{label_suffix}.dat'
+   _path = os.path.join(cache_dir, filename)
+   if utils.fsspec_exists(_path) and not override_cache:
+     LOGGER.info(f'Loading data from: {_path}')
+     return datasets.load_from_disk(_path).with_format('torch')
+   LOGGER.info(f'Generating new data at: {_path}')
+
+   crop_train = dataset_name == 'text8-crop'
+   if mode == 'train' and crop_train:
+     # Double block size for sub-sampling.
+     block_size *= 2
+
+   if dataset_name == 'text8':
+     assert wrap
+     dataset = get_text8_dataset(
+       cache_dir, max_seq_length=block_size)
+   elif dataset_name == 'amazon_polarity':
+     dataset = datasets.load_dataset(
+       'amazon_polarity',
+       cache_dir=cache_dir,
+       streaming=streaming)
+   elif dataset_name == 'qm9':
+     dataset = datasets.load_dataset(
+       'yairschiff/qm9',
+       cache_dir=cache_dir,
+       streaming=streaming,
+       split='train')  # Dataset only has 'train' split.
+     if label_threshold is not None:
+       # Bucket the regression target into percentile bins.
+       pctiles = label_threshold if isinstance(label_threshold, list) \
+         else [label_threshold]
+       pctile_values = np.percentile(dataset[label_col],
+                                     q=pctiles)
+       threshold = np.ones(len(dataset[label_col])) * len(pctiles)
+       for i, p in reversed(list(enumerate(sorted(pctile_values)))):
+         threshold[dataset[label_col] <= p] = i
+       dataset = dataset.add_column(
+         f"{label_col}_threshold", threshold.astype(int))
+       label_col = f"{label_col}_threshold"
+     dataset = dataset.train_test_split(
+       test_size=0.05, seed=42)  # Hard-coded seed & size.
+     dataset = dataset[mode]
+   elif dataset_name == 'ten_species':
+     return custom_datasets.ten_species_dataset.TenSpeciesDataset(
+       split=mode,
+       tokenizer=tokenizer,
+       max_length=block_size,
+       rc_aug=False,  # TODO: find way to pass this
+       add_special_tokens=add_special_tokens)
+   else:
+     dataset = datasets.load_dataset(
+       dataset_name,
+       cache_dir=cache_dir,
+       streaming=streaming)
+
+   if dataset_name == 'qm9':
+     data = dataset
+   else:
+     data = dataset[mode]
+
+   if dataset_name == 'lm1b':
+     detokenizer = lm1b_detokenizer
+   else:
+     detokenizer = None
+
+   def _apply_detokenizer(detoker):
+     def detok(text):
+       for j, t in enumerate(text, 0):
+         text[j] = detoker(t)
+       return text
+     return detok
+
+   EOS = tokenizer.encode(tokenizer.eos_token)[0]
+   BOS = tokenizer.encode(tokenizer.bos_token)[0]
+
+   def preprocess_and_tokenize(example):
+     if 'amazon_polarity' in dataset_name:
+       text = example['content']
+     elif 'qm9' in dataset_name:
+       text = example['canonical_smiles']
+     elif dataset_name == 'ten_species':
+       text = example['sequence']
+     else:
+       text = example['text']
+
+     if detokenizer is not None:
+       text = _apply_detokenizer(detokenizer)(text)
+
+     tokenizer.padding_side = 'right'
+     tokenizer.truncation_side = 'right'
+
+     if wrap:
+       tokens = tokenizer(text,
+                          add_special_tokens=False,
+                          return_attention_mask=False,
+                          return_token_type_ids=False)
+       if add_special_tokens:
+         tokens = {'input_ids':
+                   [t + [EOS] for t in tokens['input_ids']]}
+         # Still missing BOS; will be added in group_texts.
+       else:
+         tokens = {'input_ids': tokens['input_ids']}
+     else:
+       tokens = tokenizer(text,
+                          max_length=block_size,
+                          padding='max_length',
+                          truncation=True,
+                          add_special_tokens=add_special_tokens,
+                          return_attention_mask=True,
+                          return_token_type_ids=add_special_tokens)
+     return tokens
+
+   if streaming:
+     tokenized_dataset = data.map(
+       preprocess_and_tokenize,
+       batched=True,
+       desc='Tokenizing')
+   else:
+     tokenized_dataset = data.map(
+       preprocess_and_tokenize,
+       batched=True,
+       num_proc=num_proc,
+       load_from_cache_file=True,
+       desc='Tokenizing')
+   keep_cols = ['input_ids', 'token_type_ids',
+                'attention_mask']
+   if label_col is not None:
+     keep_cols.append(label_col)
+   tokenized_dataset = tokenized_dataset.remove_columns(
+     [col for col in tokenized_dataset.column_names
+      if col not in keep_cols])
+
+   if not wrap:
+     tokenized_dataset.save_to_disk(_path)
+     return tokenized_dataset.with_format('torch')
+
+   group_texts = functools.partial(
+     _group_texts, block_size=block_size, bos=BOS, eos=EOS,
+     add_special_tokens=add_special_tokens)
+   if streaming:
+     chunked_dataset = tokenized_dataset.map(
+       group_texts,
+       batched=True,
+       desc='Grouping')
+   else:
+     chunked_dataset = tokenized_dataset.map(
+       group_texts,
+       batched=True,
+       num_proc=num_proc,
+       load_from_cache_file=True,
+       desc='Grouping')
+   chunked_dataset.save_to_disk(_path)
+   chunked_dataset = chunked_dataset.with_format('torch')
+   return chunked_dataset
+
+
+ def get_tokenizer(config):
+   if config.data.tokenizer_name_or_path == 'text8':
+     tokenizer = Text8Tokenizer()
+   elif config.data.tokenizer_name_or_path == 'bert-base-uncased':
+     tokenizer = transformers.BertTokenizer.from_pretrained(
+       'bert-base-uncased')
+   elif config.data.tokenizer_name_or_path == 'raw_pixels':
+     tokenizer = custom_datasets.discretized_cifar10.DummyVisionTokenizer(
+       256, 32,
+       add_mask_token=config.data.add_mask_token,
+       add_special_tokens=config.data.add_special_tokens)
+   else:
+     tokenizer = transformers.AutoTokenizer.from_pretrained(
+       config.data.tokenizer_name_or_path,
+       trust_remote_code=True)
+
+   if (isinstance(tokenizer, transformers.GPT2TokenizerFast)
+       or isinstance(tokenizer, transformers.GPT2Tokenizer)):
+     tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
+       (tokenizer.bos_token, tokenizer.bos_token_id),
+       (tokenizer.eos_token, tokenizer.eos_token_id))
+
+   # For wrapped batches:
+   # [BOS] sent1 [EOS] sent2-fragment [EOS]
+   # [BOS] sent2-fragment [EOS] sent3 [EOS]
+   if tokenizer.bos_token is None:
+     if tokenizer.cls_token is None:
+       raise AttributeError(
+         'Tokenizer must have a bos_token or '
+         f'cls_token: {tokenizer}')
+     tokenizer.bos_token = tokenizer.cls_token
+   if tokenizer.eos_token is None:
+     if tokenizer.sep_token is None:
+       raise AttributeError(
+         'Tokenizer must have an eos_token '
+         f'or sep_token: {tokenizer}')
+     tokenizer.eos_token = tokenizer.sep_token
+   if tokenizer.pad_token is None and not config.is_vision:
+     tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+   return tokenizer
+
+
+ def get_dataloaders(config, tokenizer, skip_train=False,
+                     skip_valid=False, valid_seed=None):
+   num_gpus = torch.cuda.device_count()
+   assert (config.loader.global_batch_size
+           == (config.loader.batch_size
+               * config.trainer.num_nodes
+               * num_gpus
+               * config.trainer.accumulate_grad_batches))
+   if config.loader.global_batch_size % (
+       num_gpus * config.trainer.accumulate_grad_batches) != 0:
+     raise ValueError(
+       f'Train batch size {config.loader.global_batch_size} '
+       f'not divisible by {num_gpus} gpus with accumulation '
+       f'{config.trainer.accumulate_grad_batches}.')
+   if config.loader.eval_global_batch_size % num_gpus != 0:
+     raise ValueError(
+       f'Eval batch size {config.loader.eval_global_batch_size} '
+       f'not divisible by {num_gpus}.')
+   label_col = getattr(config.data, 'label_col', None)
+   if skip_train:
+     train_set = None
+   else:
+     if 'cifar10' in config.data.train:
+       train_set = custom_datasets.discretized_cifar10.DiscreteCIFAR10(
+         config.data.train, train=True, download=True)
+     else:
+       train_set = get_dataset(
+         config.data.train,
+         tokenizer,
+         mode='train',
+         wrap=config.data.wrap,
+         cache_dir=config.data.cache_dir,
+         block_size=config.model.length,
+         override_cache=config.data.override_cache,
+         add_special_tokens=config.data.add_special_tokens,
+         label_col=label_col,
+         label_threshold=getattr(config.data,
+                                 'label_col_pctile', None))
+   if config.data.valid in [
+       'text8', 'lm1b', 'amazon_polarity', 'qm9',
+       'ten_species']:
+     validation_split = 'test'
+   else:
+     validation_split = 'validation'
+   if skip_valid:
+     valid_set = None
+   else:
+     if 'cifar10' in config.data.train:
+       valid_set = custom_datasets.discretized_cifar10.DiscreteCIFAR10(
+         config.data.valid, train=False, download=True)
+     else:
+       valid_set = get_dataset(
+         config.data.valid,
+         tokenizer,
+         wrap=config.data.wrap,
+         mode=validation_split,
+         cache_dir=config.data.cache_dir,
+         block_size=config.model.length,
+         streaming=False,
+         override_cache=config.data.override_cache,
+         add_special_tokens=config.data.add_special_tokens,
+         label_col=label_col,
+         label_threshold=getattr(config.data,
+                                 'label_col_pctile', None))
+
+   if skip_train:
+     train_loader = None
+   else:
+     train_loader = torch.utils.data.DataLoader(
+       train_set,
+       batch_size=config.loader.batch_size,
+       num_workers=config.loader.num_workers,
+       pin_memory=config.loader.pin_memory,
+       shuffle=not config.data.streaming,
+       persistent_workers=config.loader.persistent_workers)
+     train_loader.tokenizer = tokenizer
+   if skip_valid:
+     valid_loader = None
+   else:
+     if valid_seed is None:
+       shuffle_valid = False
+       generator = None
+     else:
+       shuffle_valid = True
+       generator = torch.Generator().manual_seed(valid_seed)
+     valid_loader = torch.utils.data.DataLoader(
+       valid_set,
+       batch_size=config.loader.eval_batch_size,
+       num_workers=config.loader.num_workers,
+       pin_memory=config.loader.pin_memory,
+       shuffle=shuffle_valid,
+       generator=generator)
+     # Will be used in generative perplexity calculation.
+     valid_loader.tokenizer = tokenizer
+
+   return train_loader, valid_loader
+
+
+ # Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py
+ class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
+
+   def __init__(self, *args, generator=None, **kwargs):
+     # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
+     # which should be reproducible if pl.seed_everything was called beforehand.
+     # This means that changing the seed of the experiment will also change the
+     # sampling order.
+     if generator is None:
+       seed = int(torch.empty((), dtype=torch.int64).random_().item())
+       generator = torch.Generator().manual_seed(seed)
+     kwargs.pop('shuffle', None)
+     super().__init__(*args, generator=generator, **kwargs)
+     self.counter = 0
+     self.restarting = False
+
+   def state_dict(self):
+     # `self.state` holds the generator state captured at the
+     # start of `__iter__`, so a resumed sampler replays the
+     # same permutation (as in the flash-attention original).
+     return {'random_state': self.state,
+             'counter': self.counter}
+
+   def load_state_dict(self, state_dict):
+     self.generator.set_state(state_dict.get('random_state'))
+     self.counter = state_dict['counter']
+     # self.start_counter = self.counter
+     self.restarting = True
+
+   # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
+   # epoch, and subsequent epoch will have very few batches.
+   def __iter__(self) -> typing.Iterator[int]:
+     n = len(self.data_source)
+
+     self.state = self.generator.get_state()
+     indices = torch.randperm(n, generator=self.generator).tolist()
+
+     if not self.restarting:
+       self.counter = 0
+     else:
+       indices = indices[self.counter:]
+       self.restarting = False
+
+     for index in indices:
+       self.counter += 1
+       yield index
+
+     self.counter = 0
+
+
+ class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
+
+   def __init__(self, *args, **kwargs):
+     super().__init__(*args, **kwargs)
+     self.counter = 0
+     self.restarting = False
+
+   def state_dict(self):
+     return {'epoch': self.epoch, 'counter': self.counter}
+
+   def load_state_dict(self, state_dict):
+     self.epoch = state_dict['epoch']
+     self.counter = state_dict['counter']
+     self.restarting = True
+
+   # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
+   # epoch, and subsequent epoch will have very few batches.
+   def __iter__(self):
+     if self.shuffle:
+       # Deterministically shuffle based on epoch and seed.
+       g = torch.Generator()
+       g.manual_seed(self.seed + self.epoch)
+       indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
+     else:
+       indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
+
+     if not self.drop_last:
+       # Add extra samples to make it evenly divisible.
+       padding_size = self.total_size - len(indices)
+       if padding_size <= len(indices):
+         indices += indices[:padding_size]
+       else:
+         indices += (indices * math.ceil(
+           padding_size / len(indices)))[:padding_size]
+     else:
+       # Remove tail of data to make it evenly divisible.
+       indices = indices[:self.total_size]
+     assert len(indices) == self.total_size
+
+     # Subsample.
+     indices = indices[self.rank:self.total_size:self.num_replicas]
+     assert len(indices) == self.num_samples
+
+     if not self.restarting:
+       self.counter = 0
+     else:
+       indices = indices[self.counter:]
+       self.restarting = False
+
+     for index in indices:
+       self.counter += 1
+       yield index
+
+     self.counter = 0
+
+
+ def collate_fn(batch):
+   # Note: only `batch[0]` is used, i.e. each dataset item is
+   # assumed to already be a full batch; the DataLoaders below
+   # therefore keep the default `batch_size` of 1.
+   input_ids = torch.tensor(batch[0]['input_ids'])
+   attention_mask = torch.tensor(batch[0]['attention_mask'])
+   return {
+     'input_ids': input_ids,
+     'attention_mask': attention_mask
+   }
+
+
+ class CustomDataModule(L.LightningDataModule):
+   def __init__(self, train_dataset, val_dataset, test_dataset,
+                tokenizer, config, batch_size: int = 8,
+                collate_fn=collate_fn):
+     super().__init__()
+     self.train_dataset = train_dataset
+     self.val_dataset = val_dataset
+     self.test_dataset = test_dataset
+     self.batch_size = batch_size
+     self.tokenizer = tokenizer
+     self.collate_fn = collate_fn
+     self.config = config
+
+   def train_dataloader(self):
+     return DataLoader(self.train_dataset,
+                       collate_fn=partial(self.collate_fn),
+                       num_workers=self.config.loader.num_workers,
+                       pin_memory=self.config.loader.pin_memory,
+                       shuffle=not self.config.data.streaming,
+                       persistent_workers=self.config.loader.persistent_workers)
+
+   def val_dataloader(self):
+     return DataLoader(self.val_dataset,
+                       collate_fn=partial(self.collate_fn),
+                       num_workers=self.config.loader.num_workers,
+                       pin_memory=self.config.loader.pin_memory,
+                       shuffle=False)
+
+   def test_dataloader(self):
+     return DataLoader(self.test_dataset,
+                       collate_fn=partial(self.collate_fn),
+                       num_workers=self.config.loader.num_workers,
+                       pin_memory=self.config.loader.pin_memory,
+                       shuffle=not self.config.data.streaming)
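
Both fault-tolerant samplers implement the same resume contract: `state_dict()` records the RNG state (or the epoch) plus a counter of indices already served, and the first `__iter__` after `load_state_dict()` rebuilds the same permutation and skips the first `counter` entries. A sketch of that contract for `RandomFaultTolerantSampler` (a hypothetical toy run; it assumes `dataloader.py` is importable and that `state_dict()` stores the pre-permutation RNG state saved in `__iter__`, as in the flash-attention samplers cited above):

import torch

from dataloader import RandomFaultTolerantSampler

data = list(range(10))
sampler = RandomFaultTolerantSampler(data)

it = iter(sampler)
consumed = [next(it) for _ in range(4)]  # serve 4 indices, then "crash"
snapshot = sampler.state_dict()          # {'random_state': ..., 'counter': 4}

restored = RandomFaultTolerantSampler(data)
restored.load_state_dict(snapshot)
remaining = list(restored)  # same permutation, minus the 4 consumed indices
assert sorted(consumed + remaining) == data
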
diffusion.py ADDED
@@ -0,0 +1,1629 @@
1
+ """Module for modeling discrete diffusion
2
+ (absorbing state or uniform) and AR
3
+ (a special case of absorbing state).
4
+ """
5
+ import itertools
6
+ import math
7
+ import typing
8
+ from dataclasses import dataclass
9
+
10
+ import hydra.utils
11
+ import lightning as L
12
+ import numpy as np
13
+ import omegaconf
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torchmetrics
17
+ import transformers
18
+ from mamba_ssm.utils.generation import InferenceParams
19
+ from torch import Tensor
20
+ from tqdm.auto import tqdm
21
+ import pdb
22
+ import gc
23
+
24
+ import classifier
25
+ import dataloader
26
+ import models
27
+ import noise_schedule
28
+
29
+ LOG2 = math.log(2)
30
+
31
+
32
+ def _sample_categorical(categorical_probs):
33
+ gumbel_norm = (
34
+ 1e-10
35
+ - (torch.rand_like(categorical_probs) + 1e-10).log()).to(categorical_probs.dtype)
36
+ return (categorical_probs / gumbel_norm).argmax(dim=-1)
37
+
38
+
39
+ def _unsqueeze(x, reference):
40
+ return x.view(
41
+ * x.shape,
42
+ * ((1,) * (len(reference.shape) - len(x.shape))))
43
+
44
+
45
+ @dataclass
46
+ class Loss:
47
+ loss: torch.FloatTensor
48
+ nlls: torch.FloatTensor
49
+ token_mask: torch.FloatTensor
50
+ recon_loss: typing.Optional[torch.FloatTensor] = None
51
+ diffusion_loss: typing.Optional[torch.FloatTensor] = None
52
+
53
+
54
+ class NLL(torchmetrics.aggregation.MeanMetric):
55
+ pass
56
+
57
+
58
+ class BPD(NLL):
59
+ def compute(self) -> Tensor:
60
+ """Computes the bits per dimension.
61
+
62
+ Returns:
63
+ bpd
64
+ """
65
+ return self.mean_value / self.weight / LOG2
66
+
67
+
68
+ class Perplexity(NLL):
69
+ def compute(self) -> Tensor:
70
+ """Computes the Perplexity.
71
+
72
+ Returns:
73
+ Perplexity
74
+ """
75
+ return torch.exp(self.mean_value / self.weight)
76
+
77
+
78
+ class Diffusion(L.LightningModule):
79
+ def __init__(
80
+ self,
81
+ config,
82
+ tokenizer: transformers.PreTrainedTokenizer):
83
+ super().__init__()
84
+ self.save_hyperparameters()
85
+ self.config = config
86
+
87
+ self.tokenizer = tokenizer
88
+ self.vocab_size = tokenizer.vocab_size
89
+
90
+ self.antithetic_sampling = config.training.antithetic_sampling
91
+ self.importance_sampling = config.training.importance_sampling
92
+ self.change_of_variables = config.training.change_of_variables
93
+ self.noise = noise_schedule.get_noise(config, dtype=self.dtype)
94
+
95
+ if self.config.is_vision:
96
+ self.mask_index = getattr(tokenizer, 'mask_token_id', -1)
97
+ else:
98
+ if (not hasattr(self.tokenizer, 'mask_token')
99
+ or tokenizer.mask_token is None):
100
+ self.mask_index = self.vocab_size
101
+ self.vocab_size += 1
102
+ else:
103
+ self.mask_index = tokenizer.mask_token_id
104
+
105
+ # Note: creating limiting distribution with
106
+ # broadcast-able batch and sequence dimensions.
107
+ self.parameterization = config.parameterization
108
+ self.diffusion = config.diffusion
109
+ if config.parameterization == 'ar':
110
+ self.limiting_distribution = None
111
+ else:
112
+ if self.diffusion == 'absorbing_state':
113
+ # Not needed, posterior calculated explicitly.
114
+ limiting_distribution = None
115
+ elif self.diffusion == 'uniform':
116
+ limiting_distribution = torch.ones(
117
+ (1, 1, self.vocab_size), dtype=self.dtype) / self.vocab_size
118
+ else:
119
+ raise NotImplementedError(
120
+ f"Diffusion type {self.diffusion} not implemented.")
121
+ self.register_buffer('limiting_distribution',
122
+ limiting_distribution)
123
+
124
+ self.T = config.T
125
+ self.subs_masking = config.subs_masking
126
+ self.time_conditioning = config.time_conditioning
127
+
128
+ if self.config.backbone == 'dit':
129
+ self.backbone = models.dit.DIT(
130
+ self.config, vocab_size=self.vocab_size)
131
+ elif self.config.backbone == 'dimamba':
132
+ self.backbone = models.dimamba.DiMamba(
133
+ self.config, vocab_size=self.vocab_size,
134
+ pad_token_id=self.tokenizer.pad_token_id)
135
+ elif self.config.backbone == 'unet':
136
+ self.backbone = models.unet.UNet(
137
+ self.config, vocab_size=self.vocab_size)
138
+ elif self.config.backbone == 'hf_dit':
139
+ self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(
140
+ config.model.pretrained_model_name_or_path, trust_remote_code=True)
141
+ else:
142
+ raise NotImplementedError(
143
+ f"Backbone {self.config.backbone} not implemented.")
144
+
145
+ self.lr = self.config.optim.lr
146
+ self.sampling_eps = config.training.sampling_eps
147
+
148
+ self.softplus = torch.nn.Softplus()
149
+ self.neg_infinity = -1_000_000.0
150
+
151
+ if config.training.ema > 0:
152
+ self.ema = models.ema.ExponentialMovingAverage(
153
+ itertools.chain(self.backbone.parameters(),
154
+ self.noise.parameters()),
155
+ decay=config.training.ema)
156
+ else:
157
+ self.ema = None
158
+
159
+ # metrics are automatically reset at end of epoch
160
+ metrics = torchmetrics.MetricCollection({
161
+ 'nll': NLL(),
162
+ 'bpd': BPD(),
163
+ 'ppl': Perplexity(),
164
+ })
165
+ metrics.set_dtype(torch.float64)
166
+ self.train_metrics = metrics.clone(prefix='train/')
167
+ self.valid_metrics = metrics.clone(prefix='val/')
168
+ self.test_metrics = metrics.clone(prefix='test/')
169
+
170
+ self.fast_forward_epochs = None
171
+ self.fast_forward_batches = None
172
+
173
+ self._validate_configuration()
174
+
175
+ def _validate_configuration(self):
176
+ assert not (self.change_of_variables
177
+ and self.importance_sampling)
178
+ if self.diffusion != 'absorbing_state':
179
+ assert self.parameterization not in {'ar', 'subs'}
180
        if self.T > 0:
            assert self.parameterization in {'d3pm', 'subs'}
        if self.subs_masking:
            assert self.parameterization == 'd3pm'

    def on_load_checkpoint(self, checkpoint):
        if self.limiting_distribution is not None:
            checkpoint['state_dict'][
                'limiting_distribution'] = self.limiting_distribution.to(
                list(checkpoint['state_dict'].values())[0].device)
        if self.ema:
            self.ema.load_state_dict(checkpoint['ema'])
        # Copied from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
        self.fast_forward_epochs = checkpoint['loops'][
            'fit_loop']['epoch_progress']['current']['completed']
        self.fast_forward_batches = checkpoint['loops'][
            'fit_loop']['epoch_loop.batch_progress'][
            'current']['completed']

    def on_save_checkpoint(self, checkpoint):
        # Do not save this buffer.
        checkpoint['state_dict'].pop('limiting_distribution', None)
        if self.ema:
            checkpoint['ema'] = self.ema.state_dict()
        # Copied from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
        # ['epoch_loop.batch_progress']['total']['completed'] is
        # 1 iteration behind, so we use the optimizer's progress.
        checkpoint['loops']['fit_loop'][
            'epoch_loop.batch_progress']['total'][
            'completed'] = checkpoint['loops']['fit_loop'][
            'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total'][
            'completed'] * self.trainer.accumulate_grad_batches
        checkpoint['loops']['fit_loop'][
            'epoch_loop.batch_progress']['current'][
            'completed'] = checkpoint['loops']['fit_loop'][
            'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['current'][
            'completed'] * self.trainer.accumulate_grad_batches
        # _batches_that_stepped tracks the number of global steps,
        # not the number of local steps, so we don't multiply by
        # self.trainer.accumulate_grad_batches here.
        checkpoint['loops']['fit_loop'][
            'epoch_loop.state_dict'][
            '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
            'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total']['completed']
        if 'sampler' not in checkpoint.keys():
            checkpoint['sampler'] = {}
        if hasattr(self.trainer.train_dataloader.sampler,
                   'state_dict'):
            sampler_state_dict = self.trainer.\
                train_dataloader.sampler.state_dict()
            checkpoint['sampler'][
                'random_state'] = sampler_state_dict.get(
                'random_state', None)
        else:
            checkpoint['sampler']['random_state'] = None

    def on_train_start(self):
        if self.ema:
            self.ema.move_shadow_params_to_device(self.device)
        # Adapted from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
        distributed = (
            self.trainer._accelerator_connector.use_distributed_sampler
            and self.trainer._accelerator_connector.is_distributed)
        if distributed:
            sampler_cls = dataloader.FaultTolerantDistributedSampler
        else:
            sampler_cls = dataloader.RandomFaultTolerantSampler
        # Hoisted out of the loop below. Note that the
        # fault-tolerant dl_sampler is currently constructed but
        # not passed to the rebuilt DataLoaders (see the
        # commented-out `sampler=` argument).
        from functools import partial
        from dataloader import collate_fn
        collate_partial = partial(collate_fn)
        updated_dls = []
        for dl in self.trainer.fit_loop._combined_loader.flattened:
            if hasattr(dl.sampler, 'shuffle'):
                dl_sampler = sampler_cls(
                    dl.dataset, shuffle=dl.sampler.shuffle)
            else:
                dl_sampler = sampler_cls(dl.dataset)
            if (distributed
                    and self.fast_forward_epochs is not None
                    and self.fast_forward_batches is not None):
                dl_sampler.load_state_dict({
                    'epoch': self.fast_forward_epochs,
                    'counter': (self.fast_forward_batches
                                * self.config.loader.batch_size)})

            torch.cuda.empty_cache()

            updated_dls.append(
                torch.utils.data.DataLoader(
                    dl.dataset,
                    # batch_size=self.config.loader.batch_size,
                    num_workers=self.config.loader.num_workers,
                    pin_memory=self.config.loader.pin_memory,
                    # sampler=dl_sampler,
                    shuffle=False,
                    persistent_workers=self.config.loader.persistent_workers,
                    collate_fn=collate_partial))
        self.trainer.fit_loop._combined_loader.flattened = updated_dls

    def configure_optimizers(self):
        # TODO(yair): Lightning currently gives this warning when using `fp16`:
        # "Detected call of `lr_scheduler.step()` before `optimizer.step()`."
        # Not clear if this is a problem or not.
        # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
        optimizer = torch.optim.AdamW(
            itertools.chain(self.backbone.parameters(),
                            self.noise.parameters()),
            lr=self.config.optim.lr,
            betas=(self.config.optim.beta1,
                   self.config.optim.beta2),
            eps=self.config.optim.eps,
            weight_decay=self.config.optim.weight_decay)

        scheduler = hydra.utils.instantiate(
            self.config.lr_scheduler, optimizer=optimizer)
        scheduler_dict = {
            'scheduler': scheduler,
            'interval': 'step',
            'monitor': 'val/loss',
            'name': 'trainer/lr',
        }
        return [optimizer], [scheduler_dict]

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        if self.ema:
            self.ema.update(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))

    def _subs_parameterization(self, logits, xt):
        # "Zero masking prob":
        # log prob at the mask index = -infinity.
        logits[..., self.mask_index] += self.neg_infinity

        # "Copy over":
        # apply updates directly in the logits matrix. For the
        # logits of the unmasked tokens, set all values to
        # -infinity except for the indices corresponding to the
        # unmasked tokens themselves.
        unmasked_indices = (xt != self.mask_index)
        logits[unmasked_indices] = self.neg_infinity
        logits[unmasked_indices, xt[unmasked_indices]] = 0

        # Normalize the logits such that logits.exp() is a
        # probability distribution over vocab_size.
        return logits.log_softmax(dim=-1)

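    # Illustrative sketch of the two SUBS constraints (comments
    # only; shapes and values are made up, assuming mask_index = 3):
    #
    #   logits = torch.zeros(1, 2, 4)   # B=1, L=2, V=4
    #   xt = torch.tensor([[3, 1]])     # position 0 is masked
    #   log_p = self._subs_parameterization(logits, xt)
    #   log_p[0, 0, 3]  -> -inf  (a mask is never decoded as a mask)
    #   log_p[0, 1]     -> 0 at token 1, -inf elsewhere (copied over)
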
    def _process_sigma(self, sigma):
        if sigma is None:
            assert self.parameterization == 'ar'
            return sigma
        if sigma.ndim > 1:
            sigma = sigma.squeeze(-1)
        if not self.time_conditioning:
            sigma = torch.zeros_like(sigma)
        assert sigma.ndim == 1, sigma.shape
        return sigma

    def forward(self, x, sigma, cond=None, x_emb=None, **kwargs):
        """Returns log_probs / logits."""
        sigma = self._process_sigma(sigma)
        with torch.cuda.amp.autocast(dtype=torch.float32):
            logits = self.backbone(x, sigma, cond, x_emb=x_emb, **kwargs)

        if self.parameterization == 'subs':
            # Returns log_probs.
            return self._subs_parameterization(
                logits=logits, xt=x)
        if self.parameterization in {'ar', 'd3pm'}:
            # Returns log_probs.
            if self.subs_masking:  # Can use "zero masking prob".
                logits[:, :, self.mask_index] += self.neg_infinity
            return logits.log_softmax(dim=-1)
        return logits

    def _compute_posterior(self, x, xt, alpha_s, alpha_t):
        """Computes the posterior / approximate posterior.

        Args:
            x: Either the clean input `x0` (one-hot), or the
                model's predicted `x_theta`, of shape (B, L, V).
            xt: The noisy latent (as indices) of shape (B, L).
            alpha_s: Noise level at s of shape (B, [L | 1], 1).
            alpha_t: Noise level at t of shape (B, [L | 1], 1).

        Returns:
            Posterior / approximate posterior of shape (B, L, V).
        """
        alpha_ts = alpha_t / alpha_s
        d_alpha = alpha_s - alpha_t
        xt_one_hot = F.one_hot(xt, self.vocab_size)
        if self.diffusion == 'uniform':
            return (
                (alpha_t * self.vocab_size * x * xt_one_hot +
                 (alpha_ts - alpha_t) * xt_one_hot +
                 d_alpha * x +
                 (1 - alpha_ts) * (1 - alpha_s) * self.limiting_distribution)
                /
                (alpha_t * self.vocab_size * torch.gather(x, -1, xt[..., None]) +
                 (1 - alpha_t))
            )
        raise NotImplementedError(
            f"Diffusion type {self.diffusion} not implemented.")

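    # Sanity check for the uniform posterior above (a sketch, not
    # executed anywhere): with alpha_s == alpha_t, alpha_ts = 1 and
    # d_alpha = 0, so the expression collapses to a point mass on
    # xt, i.e. a zero-length reverse step changes nothing. Plugging
    # in x = one_hot(x0) gives the true posterior q(x_s | x_t, x_0);
    # plugging in x = x_theta gives the approximate posterior
    # p_theta(x_s | x_t).
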
    def _d3pm_loss(self, model_output, xt, x0, t):
        assert self.config.noise.type == 'loglinear', (
            'D3PM loss is only implemented for log-linear noise.')
        dt = 1 / self.T

        if torch.is_tensor(t):
            t = t[:, None]
            assert t.ndim == 2
            t = t.clamp(0., 1. - 1e-4)
        alpha_t = 1 - t + torch.zeros_like(xt)
        alpha_s = 1 - (t - dt) + torch.zeros_like(xt)

        if self.diffusion == 'absorbing_state':
            log_x_theta_at_x0 = torch.gather(
                model_output, -1, x0[:, :, None]).squeeze(-1)
            log_x_theta_at_m = model_output[:, :, self.mask_index]
            x_theta_at_m = log_x_theta_at_m.exp()

            term_1_coef = dt / t
            term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
            term_1_log_dr = log_x_theta_at_x0

            term_2_coef = 1 - dt / t
            term_2_log_nr = term_1_log_nr
            term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)

            L_vb_masked = (
                term_1_coef * (term_1_log_nr - term_1_log_dr)
                + term_2_coef * (term_2_log_nr - term_2_log_dr))

            L_vb = L_vb_masked * (xt == self.mask_index)
        elif self.diffusion == 'uniform':
            posterior = self._compute_posterior(
                x=F.one_hot(x0, num_classes=self.vocab_size).to(self.dtype),
                xt=xt,
                alpha_s=alpha_s[..., None],
                alpha_t=alpha_t[..., None])
            posterior_pred = self._compute_posterior(
                x=model_output.exp(),
                xt=xt,
                alpha_s=alpha_s[..., None],
                alpha_t=alpha_t[..., None])
            # KL(q(x_s | x_t, x_0) || p_theta(x_s | x_t)); the small
            # epsilon guards both logs against zero probabilities.
            L_vb = (
                posterior * (torch.log(posterior + 1e-12)
                             - torch.log(posterior_pred + 1e-12))
            ).sum(dim=-1)
        else:
            raise NotImplementedError(
                f"Diffusion type {self.diffusion} not implemented for D3PM.")
        return self.T * L_vb

    def _reconstruction_loss(self, x0, cond=None):
        # For the D3PM parameterization.
        assert self.config.noise.type == 'loglinear', (
            'Reconstruction loss is only implemented for '
            'log-linear noise.')
        t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
                         device=self.device)
        time_conditioning = self.noise(t0)[0][:, None]
        model_output_t0 = self.forward(x0, time_conditioning,
                                       cond=cond)
        return - torch.gather(input=model_output_t0,
                              dim=-1,
                              index=x0[:, :, None]).squeeze(-1)

    def _sample_t(self, n):
        _eps_t = torch.rand(n, device=self.device)
        if self.antithetic_sampling:
            offset = torch.arange(n, device=self.device) / n
            _eps_t = (_eps_t / n + offset) % 1
        t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
        if self.importance_sampling:
            return self.noise.importance_sampling_transformation(t)
        return t

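    # Antithetic sampling stratifies the batch's timesteps: with
    # offsets k/n plus a shared jitter u/n, the n samples tile
    # [0, 1) evenly, reducing the variance of the ELBO estimate.
    # Illustrative values for n = 4 and u = 0.1 (before the
    # sampling_eps affine map):
    #
    #   _eps_t = [0.025, 0.275, 0.525, 0.775]
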
    def _q_xt(self, x, move_chance):
        """Computes the noisy sample xt.

        Args:
            x: int torch.Tensor with shape (batch_size,
                diffusion_model_input_length), input.
            move_chance: float torch.Tensor with shape
                (batch_size, 1).
        """
        move_indices = torch.rand(
            *x.shape, device=x.device) < move_chance
        if self.diffusion == 'absorbing_state':
            return torch.where(move_indices, self.mask_index, x)
        if self.diffusion == 'uniform':
            uniform_tensor = torch.randint(
                0, self.vocab_size, x.shape, device=x.device)
            return torch.where(move_indices, uniform_tensor, x)
        if self.diffusion == 'uniform_data_marginals':
            return torch.where(
                move_indices,
                self._sample_prior(*x.shape),
                x)
        raise NotImplementedError(
            f"Diffusion type {self.diffusion} not implemented.")

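    # Forward-corruption sketch (made-up values): with
    # move_chance = 0.3, roughly 30% of positions are resampled.
    #
    #   x  = [[5,  9, 2, 7]]
    #   xt = [[5,  M, 2, M]]   # absorbing_state, M = mask_index
    #   xt = [[5, 13, 2, 4]]   # uniform: random vocabulary tokens
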
    def _forward_pass_diffusion(self, x0, cond=None):
        t = self._sample_t(x0.shape[0])
        if self.T > 0:
            t = (t * self.T).to(torch.int)
            t = t / self.T
            # t \in {1/T, 2/T, ..., 1}
            t += (1 / self.T)

        if self.change_of_variables:
            time_conditioning = t[:, None]
            f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
            f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
            move_chance = torch.exp(f_0 + t * (f_T - f_0))
            move_chance = move_chance[:, None]
            sigma, dsigma = None, None
        else:
            sigma, dsigma = self.noise(t)
            time_conditioning = sigma[:, None]
            move_chance = 1 - torch.exp(-sigma[:, None])

        xt = self._q_xt(x0, move_chance)
        model_output = self.forward(xt, time_conditioning,
                                    cond=cond)

        # Discrete (finite T) time.
        if self.T > 0:
            diffusion_loss = self._d3pm_loss(
                model_output=model_output, xt=xt, x0=x0, t=t)
            if self.parameterization == 'd3pm':
                reconstruction_loss = self._reconstruction_loss(
                    x0, cond=cond)
                if self.training and self.config.training.use_simple_ce_loss:
                    loss = -torch.gather(
                        input=model_output,
                        dim=-1,
                        index=x0[:, :, None]).squeeze(-1)
                else:
                    loss = reconstruction_loss + diffusion_loss
                return {
                    'recon_loss': reconstruction_loss,
                    'diffusion_loss': diffusion_loss,
                    'loss': loss}
            if self.parameterization == 'subs':
                if self.training and self.config.training.use_simple_ce_loss:
                    loss = -torch.gather(
                        input=model_output,
                        dim=-1,
                        index=x0[:, :, None]).squeeze(-1)
                else:
                    loss = diffusion_loss
                return {'diffusion_loss': diffusion_loss, 'loss': loss}
            raise ValueError(
                f"Invalid parameterization: {self.parameterization} for T > 0.")

        # Continuous (T --> infty) time.
        if self.diffusion == 'absorbing_state':
            # SUBS parameterization, continuous time.
            log_p_theta = torch.gather(
                input=model_output,
                dim=-1,
                index=x0[:, :, None]).squeeze(-1)

            if self.change_of_variables or self.importance_sampling:
                if self.training and self.config.training.use_simple_ce_loss:
                    return {
                        'diffusion_loss': log_p_theta * torch.log1p(
                            -torch.exp(- self.noise.sigma_min)),
                        'loss': -log_p_theta}
                return log_p_theta * torch.log1p(
                    -torch.exp(- self.noise.sigma_min))

            if self.training and self.config.training.use_simple_ce_loss:
                return {
                    # ELBO term, logged for monitoring; negated to
                    # match the non-simple-CE return below.
                    'diffusion_loss': -log_p_theta * (
                        dsigma / torch.expm1(sigma))[:, None],
                    # Simple cross-entropy: -log p_theta(x0 | xt).
                    'loss': -log_p_theta}
            return - log_p_theta * (dsigma / torch.expm1(sigma))[:, None]

        if self.diffusion == 'uniform':
            assert self.config.noise.type == 'loglinear', (
                'Continuous-time uniform diffusion is only '
                'implemented for log-linear noise.')
            # TODO: Currently alpha_t' and alpha_t are hardcoded
            # for the log-linear noise schedule.
            # Make generic (as above, for absorbing state):
            #   alpha_t_prime = -dsigma * (-sigma).exp()
            #   alpha_t = (-sigma).exp()
            alpha_t_prime = -1.
            alpha_t = 1. - t[..., None, None]  # B, 1, 1

            # x_bar = N * alpha_t * x + 1 - alpha_t ; B, L, V
            x_bar = (self.vocab_size * alpha_t
                     * F.one_hot(x0, self.vocab_size).float()
                     + 1 - alpha_t)
            x_bar_theta = (self.vocab_size * alpha_t
                           * model_output.exp() + 1 - alpha_t)

            # alpha_t' / (N * alpha_t)
            coeff = alpha_t_prime / (self.vocab_size * alpha_t)  # B, 1, 1

            # Term 1: indices where z_t = 1.
            x_bar_zt = torch.gather(x_bar, -1, xt[..., None])  # B, L, 1
            x_bar_theta_zt = torch.gather(
                x_bar_theta, -1, xt[..., None])  # B, L, 1
            term1 = ((self.vocab_size / x_bar_zt)
                     - (self.vocab_size / x_bar_theta_zt))  # B, L, 1

            # Term 2: indices where z_t = 0.
            term2 = (  # B, L, V before summing --> B, L, 1 after
                (x_bar / x_bar_zt) *
                (
                    x_bar_theta_zt.log() - x_bar_theta.log() +
                    x_bar.log() - x_bar_zt.log()
                )
            )
            term2 = term2.sum(dim=-1, keepdim=True)  # B, L, 1

            diffusion_loss = (coeff * (term1 - term2)).squeeze()  # B, L
            reconstruction_loss = self._reconstruction_loss(
                x0, cond=cond)
            if self.training and self.config.training.use_simple_ce_loss:
                return {
                    'recon_loss': reconstruction_loss,
                    'diffusion_loss': diffusion_loss,
                    'loss': -torch.gather(
                        input=model_output,
                        dim=-1,
                        index=x0[:, :, None]).squeeze(-1)}
            return {
                'recon_loss': reconstruction_loss,
                'diffusion_loss': diffusion_loss,
                'loss': diffusion_loss
                if getattr(self.config, 'zero_recon_loss', False)
                else diffusion_loss + reconstruction_loss}
        raise NotImplementedError(
            f"Diffusion type {self.diffusion} not "
            "implemented for the continuous-time case.")

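    # Note on the uniform continuous-time ELBO above: for each
    # vocabulary entry k, x_bar[..., k] = N * alpha_t * <x, e_k>
    # + 1 - alpha_t equals N * q(z_t = k | x), so term1 compares
    # data and model marginals at the observed index z_t, while
    # term2 accumulates the log-ratio over all other indices.
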
    def _maybe_sub_sample(self, x0, attention_mask):
        seqlen = x0.shape[1]
        # if seqlen > self.config.model.length:
        #   assert seqlen == 2 * self.config.model.length
        #   # Cropping is necessary for the text8-crop dataset;
        #   # try the same starting point for now.
        #   start = np.random.choice(self.config.model.length)
        #   end = start + self.config.model.length
        #   input_tokens = x0[:, start: end]
        #   output_tokens = x0[:, start + 1: end + 1]
        #   new_attention_mask = attention_mask[:, start: end]

        #   # Helps with validation PPL, since the val
        #   # examples will all start and end with BOS/EOS.
        #   input_tokens[:, 0] = self.tokenizer.bos_token_id
        #   output_tokens[:, -1] = self.tokenizer.eos_token_id
        # elif self.parameterization == 'ar':
        #   input_tokens = x0[:, :-1]
        #   output_tokens = x0[:, 1:]
        #   new_attention_mask = attention_mask[:, 1:]
        # else:
        #   input_tokens = x0
        #   output_tokens = None
        #   new_attention_mask = attention_mask

        input_tokens = x0
        output_tokens = None
        new_attention_mask = attention_mask
        return input_tokens, output_tokens, new_attention_mask

    def _loss(self, x0, attention_mask, cond=None):
        (input_tokens, output_tokens,
         attention_mask) = self._maybe_sub_sample(
            x0, attention_mask)

        recon_loss, diffusion_loss = None, None

        if (cond is not None and self.training
                and self.config.training.guidance is not None
                and self.config.training.guidance.cond_dropout > 0):
            # Randomly mask out the conditioning for
            # classifier-free guidance training.
            p = torch.bernoulli(
                torch.ones_like(cond) *
                self.config.training.guidance.cond_dropout).to(torch.bool)
            # Use the num_classes index as the conditioning
            # mask_token_id.
            cond[p] = self.config.data.num_classes

        if self.parameterization == 'ar':
            logprobs = self.forward(
                input_tokens, sigma=None, cond=cond)
            loss = - logprobs.gather(
                -1, output_tokens[:, :, None])[:, :, 0]
        else:
            loss = self._forward_pass_diffusion(input_tokens,
                                                cond=cond)
            if isinstance(loss, dict):
                # Some branches return only a subset of the keys.
                recon_loss = loss.get('recon_loss', None)
                diffusion_loss = loss.get('diffusion_loss', None)
                loss = loss['loss']

        nlls = loss * attention_mask
        count = attention_mask.sum()

        if (self.config.training.compute_loss_on_pad_tokens
                and self.training):
            token_nll = loss.mean()
        else:
            batch_nll = nlls.sum()
            token_nll = batch_nll / count

        if recon_loss is not None and diffusion_loss is not None:
            with torch.no_grad():
                recon_loss_batch = (recon_loss * attention_mask).sum() / count
                diffusion_loss_batch = (diffusion_loss * attention_mask).sum() / count
            return Loss(loss=token_nll,
                        nlls=nlls,
                        token_mask=attention_mask,
                        recon_loss=recon_loss_batch,
                        diffusion_loss=diffusion_loss_batch)
        return Loss(loss=token_nll,
                    nlls=nlls,
                    token_mask=attention_mask)

    def _compute_loss(self, batch, prefix):
        if 'attention_mask' in batch:
            attention_mask = batch['attention_mask']
        else:
            attention_mask = None
        cond = None
        # Training for / using CFG.
        if (self.config.training.guidance is not None
                or (hasattr(self.config, 'guidance')
                    and self.config.guidance is not None
                    and self.config.guidance.method == 'cfg')):
            if self.config.data.label_col in batch:
                cond = batch[self.config.data.label_col]
            elif f"{self.config.data.label_col}_threshold" in batch:
                cond = batch[f"{self.config.data.label_col}_threshold"]
            else:
                raise RuntimeError(
                    f"Conditioning {self.config.data.label_col}"
                    f" not found in batch.")
        losses = self._loss(batch['input_ids'], attention_mask,
                            cond=cond)

        if prefix == 'train':
            self.train_metrics.update(losses.nlls, losses.token_mask)
            metrics = self.train_metrics
        elif prefix == 'val':
            self.valid_metrics.update(losses.nlls, losses.token_mask)
            metrics = self.valid_metrics
        elif prefix == 'test':
            self.test_metrics.update(losses.nlls, losses.token_mask)
            metrics = self.test_metrics
        else:
            raise ValueError(f"Invalid prefix: {prefix}")

        self.log_dict(metrics,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        return losses

    def training_step(self, batch, batch_idx):
        losses = self._compute_loss(batch, prefix='train')
        self.log(name='trainer/loss',
                 value=losses.loss.item(),
                 on_step=True,
                 on_epoch=True,
                 sync_dist=True,
                 prog_bar=True)
        if losses.recon_loss is not None:
            self.log(name='trainer/recon_loss',
                     value=losses.recon_loss.item(),
                     on_step=True,
                     on_epoch=True,
                     sync_dist=True,
                     prog_bar=False)
            self.log(name='trainer/diffusion_loss',
                     value=losses.diffusion_loss.item(),
                     on_step=True,
                     on_epoch=True,
                     sync_dist=True,
                     prog_bar=False)
        self.log(name='lr',
                 value=self.trainer.optimizers[0].param_groups[0]['lr'],
                 on_step=True,
                 on_epoch=False,
                 sync_dist=True,
                 prog_bar=True, logger=False)
        return losses.loss

    def validation_step(self, batch, batch_idx):
        losses = self._compute_loss(batch, prefix='val')
        self.log(name='trainer/val_loss',
                 value=losses.loss.item(),
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 sync_dist=True)
        return losses.loss

    def load_ema_params(self):
        if self.ema:
            self.ema.store(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))
            self.ema.copy_to(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))

    def _restore_non_ema_params(self):
        if self.ema:
            self.ema.restore(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))

    def on_validation_epoch_start(self):
        gc.collect()
        torch.cuda.empty_cache()
        self.load_ema_params()
        assert self.valid_metrics.nll.mean_value == 0
        assert self.valid_metrics.nll.weight == 0

    def on_validation_epoch_end(self):
        # self._restore_non_ema_params()
        # if (not self.trainer.sanity_checking
        #     and self.config.eval.generate_samples
        #     and self.trainer.global_rank == 0):
        #   self.config.sampling.batch_size = 1
        #   if self.config.is_vision:
        #     samples = []
        #     if self.config.training.guidance is not None:
        #       # Generate one image per class (up to 10 images).
        #       guidance = {
        #         'method': 'cfg', 'condition': 0, 'gamma': 1.0}
        #       omegaconf.OmegaConf.update(
        #         self.config, key='guidance', value=guidance,
        #         force_add=True)
        #       for i in range(max(self.config.data.num_classes, 10)):
        #         self.config.guidance.condition = i
        #         samples.append(self.sample())
        #     else:
        #       # Generate ten images.
        #       for i in range(10):
        #         samples.append(self.sample())
        #     image_samples = self.tokenizer.batch_decode(
        #       torch.concat(samples, dim=0))
        #     if hasattr(self.trainer.logger, 'log_image'):
        #       self.trainer.logger.log_image(
        #         key=f"samples@global_step{self.global_step}",
        #         caption=[str(i) for i in range(len(samples))],
        #         images=[s for s in image_samples.float()])
        #   else:
        #     if self.config.training.guidance is not None:
        #       guidance = {
        #         'method': 'cfg', 'condition': 0, 'gamma': 1.0}
        #       omegaconf.OmegaConf.update(
        #         self.config, key='guidance', value=guidance,
        #         force_add=True)
        #       for i in range(self.config.data.num_classes):
        #         self.config.guidance.condition = i
        #         samples = self.sample()
        #         decoded_samples = self.tokenizer.batch_decode(
        #           samples)
        #         if hasattr(self.trainer.logger, 'log_table'):
        #           # Log some generated samples.
        #           self.trainer.logger.log_table(
        #             key=f"samples@global_step{self.global_step}_class-{i}",
        #             columns=['Generated Samples'],
        #             data=[decoded_samples])
        #     else:
        #       self.config.sampling.batch_size = 2
        #       samples = self.sample()
        #       decoded_samples = self.tokenizer.batch_decode(
        #         samples)
        #       if hasattr(self.trainer.logger, 'log_table'):
        #         # Log some generated samples.
        #         self.trainer.logger.log_table(
        #           key=f"samples@global_step{self.global_step}",
        #           columns=['Generated Samples'],
        #           data=[[s] for s in decoded_samples])
        gc.collect()
        torch.cuda.empty_cache()
        self._restore_non_ema_params()

    def _sample_prior(self, *batch_dims):
        if self.diffusion == 'absorbing_state':
            return self.mask_index * torch.ones(
                *batch_dims, dtype=torch.int64, device=self.device)
        if self.diffusion == 'uniform':
            return torch.randint(
                0, self.vocab_size, batch_dims, dtype=torch.int64,
                device=self.device)
        if self.diffusion == 'uniform_data_marginals':
            if self.limiting_distribution.squeeze().ndim == 2:
                batch_dims = (batch_dims[0],)
            return torch.distributions.Categorical(
                self.limiting_distribution.squeeze()).sample(
                sample_shape=torch.Size(batch_dims))
        raise NotImplementedError(
            f'Diffusion type {self.diffusion} not '
            'implemented.')

    def sample(
            self,
            eps=1e-5,  # Note: differs from self.config.training.sampling_eps.
            target_sequence: torch.Tensor = None,
            target_motifs: torch.Tensor = None,
            classifier_model=None):
        """Generates samples from the (EMA) model.

        Supports both AR and diffusion sampling.
        Supports:
            - standard decoding,
            - classifier-free guidance,
            - classifier-based guidance:
                CBG / FUDGE, NOS / PPLM.
        """
        # WARNING: Lightning auto-casting is not working in this method.
        if not self.config.eval.disable_ema:
            self.load_ema_params()
        if getattr(self.config, 'guidance', None) is not None:
            if self.config.guidance.method == 'cfg':
                cond = (torch.ones(self.config.sampling.batch_size,
                                   device=self.device) *
                        self.config.guidance.condition).to(torch.long)
            else:
                cond = None
            if ((self.parameterization == 'ar'
                 and self.config.guidance.method in {'fudge', 'pplm'})
                    or self.config.guidance.method in {'cbg', 'nos'}):
                if classifier_model is None:
                    classifier_model = classifier.Classifier.load_from_checkpoint(
                        self.config.guidance.classifier_checkpoint_path,
                        tokenizer=self.tokenizer,
                        config=self.config, logger=False)
                    classifier_model = classifier_model.to(self.device)
                classifier_model.eval()
            else:
                classifier_model = None
        else:
            classifier_model, cond = None, None

        if self.parameterization == 'ar':
            samples = self._ar_sample(
                classifier_model=classifier_model, cond=cond)
        else:  # Diffusion sampling.
            samples = self._diffusion_sample(
                classifier_model=classifier_model, cond=cond,
                eps=eps,
                target_sequence=target_sequence,
                target_motifs=target_motifs)
        if not self.config.eval.disable_ema:
            self._restore_non_ema_params()
        return samples

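    # Hypothetical usage sketch (attribute names follow the config
    # keys referenced above; not part of the module):
    #
    #   model.config.sampling.batch_size = 4
    #   seqs = model.sample()   # decoding per the current config
    #   # With config.guidance = {'method': 'cfg', 'condition': 2,
    #   #                         'gamma': 2.0}, sample() instead
    #   # decodes sequences guided toward class 2.
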
    @torch.no_grad()
    def _ar_sample(
            self,
            classifier_model: typing.Optional[classifier.Classifier] = None,
            cond: typing.Optional[torch.Tensor] = None,
    ):
        # Precompute the token buffer.
        num_pred_tokens = self.config.model.length - 1
        x = torch.zeros(
            (self.config.sampling.batch_size, num_pred_tokens + 1),
            dtype=torch.long,
            device=self.device)
        x[:, 0] = self.tokenizer.bos_token_id
        # Precompute Gumbel sampling noise.
        if (getattr(self.config, 'guidance', None) is not None
                and self.config.guidance.method == 'fudge'):
            noise = torch.distributions.Gumbel(0, 1).sample(
                (self.config.sampling.batch_size,  # type: ignore
                 num_pred_tokens,
                 self.config.guidance.topk)).to(self.device)
        else:
            noise = torch.distributions.Gumbel(0, 1).sample(
                (self.config.sampling.batch_size,  # type: ignore
                 num_pred_tokens,
                 self.vocab_size)).to(self.device)
        if self.config.sampling.use_float64:
            noise = noise.to(torch.float64)
        pbar = tqdm(range(num_pred_tokens), desc='AR Sampling',
                    leave=False)
        inference_params = InferenceParams(
            max_seqlen=num_pred_tokens,
            max_batch_size=x.shape[0],
            seqlen_offset=1)
        # For CFG we do two forward passes, one for the conditional
        # model and one for the unconditional, so we need two
        # copies of inference_params.
        uncond_inference_params = InferenceParams(
            max_seqlen=num_pred_tokens,
            max_batch_size=x.shape[0],
            seqlen_offset=1)
        for i in pbar:
            if getattr(self.config, 'guidance', None) is None:
                if self.config.backbone == 'dimamba':
                    log_probs = self.forward(
                        x[:, i:i + 1], None, cond=None,
                        inference_params=inference_params)
                else:
                    log_probs = self.forward(x[:, :i + 1],
                                             None, cond=None)
                if self.config.sampling.use_float64:
                    log_probs = log_probs.to(torch.float64)
                next_log_probs = log_probs[:, -1]
                y = (next_log_probs + noise[:, i]).argmax(-1)
            else:
                if self.config.guidance.method == 'cfg':
                    if self.config.backbone == 'dimamba':
                        next_log_probs = self._ar_cfg_denoise(
                            cond=cond,
                            gamma=self.config.guidance.gamma,
                            x=x[:, i:i + 1],
                            i=i,
                            inference_params=(inference_params,
                                              uncond_inference_params))
                    else:
                        next_log_probs = self._ar_cfg_denoise(
                            cond=cond,
                            gamma=self.config.guidance.gamma,
                            x=x,
                            i=i)
                    y = (next_log_probs + noise[:, i]).argmax(-1)
                elif self.config.guidance.method == 'fudge':
                    if self.config.backbone == 'dimamba':
                        next_log_probs, top_indices = self._ar_fudge_denoise(
                            classifier_model=classifier_model,
                            guidance_cond=self.config.guidance.condition,
                            topk=self.config.guidance.topk,
                            gamma=self.config.guidance.gamma,
                            x=x[:, i:i + 1],
                            i=i,
                            inference_params=inference_params)
                    else:
                        next_log_probs, top_indices = self._ar_fudge_denoise(
                            classifier_model=classifier_model,
                            guidance_cond=self.config.guidance.condition,
                            topk=self.config.guidance.topk,
                            gamma=self.config.guidance.gamma,
                            x=x,
                            i=i)
                    y = torch.gather(
                        top_indices,
                        1,
                        (next_log_probs + noise[:, i]).argmax(-1).unsqueeze(1)
                    ).squeeze(1)
                elif self.config.guidance.method == 'pplm':
                    raise NotImplementedError
                else:
                    raise NotImplementedError(
                        f"Guidance method {self.config.guidance.method} "
                        "not implemented.")
            pbar.set_postfix(
                prob_check=(next_log_probs.exp().sum() / x.shape[0]).item(),
                nan_check=bool(next_log_probs.isnan().sum() > 0))
            x[:, i + 1] = y
        return x

    def _ar_cfg_denoise(
            self,
            cond: torch.Tensor,
            gamma: float,
            x: torch.Tensor,
            i: int,
            **kwargs
    ) -> torch.Tensor:
        if gamma == 0.0:  # Sample unconditionally.
            mask_cond = (torch.ones_like(cond) *
                         self.config.data.num_classes)
            if self.config.backbone == 'dimamba':
                inference_params = kwargs.pop('inference_params')
                log_probs = self.forward(
                    x[:, :i + 1], None, cond=mask_cond,
                    inference_params=inference_params[1])
            else:
                log_probs = self.forward(
                    x[:, :i + 1], None, cond=mask_cond, **kwargs)
        elif gamma == 1.0:  # Sample conditionally.
            if self.config.backbone == 'dimamba':
                inference_params = kwargs.pop('inference_params')
                log_probs = self.forward(
                    x[:, :i + 1], None, cond=cond,
                    inference_params=inference_params[0])
            else:
                log_probs = self.forward(
                    x[:, :i + 1], None, cond=cond, **kwargs)
        else:  # Sample from the tempered distribution.
            mask_cond = (torch.ones_like(cond) *
                         self.config.data.num_classes)
            if self.config.backbone == 'dimamba':
                inference_params = kwargs.pop('inference_params')
                log_probs_cond = self.forward(
                    x[:, :i + 1], None, cond=cond,
                    inference_params=inference_params[0])
                log_probs_uncond = self.forward(
                    x[:, :i + 1], None, cond=mask_cond,
                    inference_params=inference_params[1])
            else:
                log_probs_cond = self.forward(
                    x[:, :i + 1], None, cond=cond, **kwargs)
                log_probs_uncond = self.forward(
                    x[:, :i + 1], None, cond=mask_cond, **kwargs)

            log_probs = gamma * log_probs_cond + (1 - gamma) * log_probs_uncond
            # Gamma > 1.0 causes instability for Mamba; re-normalize.
            log_probs = log_probs.log_softmax(dim=-1)
        return log_probs[:, -1]

    def _ar_fudge_denoise(
            self,
            classifier_model: classifier.Classifier,
            guidance_cond: int,
            topk: int,
            gamma: float,
            x: torch.Tensor,
            i: int,
            **kwargs
    ) -> typing.Tuple[torch.Tensor, torch.LongTensor]:
        log_probs = self.forward(
            x[:, :i + 1], None, cond=None, **kwargs)
        next_log_probs = log_probs[:, -1]
        top_logits, top_indices = next_log_probs.topk(topk, dim=-1)
        t_candidates = torch.cat(
            [x[:, :i + 1].unsqueeze(1).expand(-1, topk, -1),
             top_indices.unsqueeze(2)],
            dim=2).view(-1, i + 2)  # (B * K), L

        t = torch.zeros(t_candidates.shape[0],
                        device=self.device)
        sigma, dsigma = self.noise(t)
        time_conditioning = sigma[:, None]

        classifier_log_prob = classifier_model.get_log_probs(
            t_candidates, time_conditioning)
        classifier_log_prob = classifier_log_prob[:, i + 1, :].view(
            x.shape[0], topk, -1)[..., guidance_cond]  # (batch, topk)
        next_log_probs = (top_logits
                          + gamma * classifier_log_prob).log_softmax(dim=-1)
        return next_log_probs, top_indices

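    # FUDGE reweighting sketch (illustrative log-probs): only the
    # top-k LM tokens are rescored by the classifier,
    # score = log p_LM(y | x) + gamma * log p_clf(c | x, y).
    #
    #   top_logits   = [-0.5, -1.0, -2.0]
    #   clf_log_prob = [-3.0, -0.2, -0.1]
    #   gamma = 1   ->  [-3.5, -1.2, -2.1]  -> candidate 1 wins
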
    def _ar_pplm_denoise(
            self,
            classifier_model: classifier.Classifier,
            guidance_cond: int,
            num_ppl_steps: int,
            pplm_step_size: float,
            pplm_stability_coef: float,
            x: torch.Tensor,
            i: int,
    ):
        raise NotImplementedError

    @torch.no_grad()
    def _diffusion_sample(
            self,
            classifier_model: typing.Optional[classifier.Classifier] = None,
            cond: typing.Optional[torch.Tensor] = None,
            eps: float = 1e-5,  # Note: differs from self.config.training.sampling_eps.
            target_sequence: torch.Tensor = None,
            target_motifs: torch.Tensor = None,
    ):
        xt = self._sample_prior(
            self.config.sampling.batch_size,
            self.config.model.length
        ).to(self.device)

        timesteps = torch.linspace(
            1, eps, self.config.sampling.steps + 1, device=self.device)
        dt = (1 - eps) / self.config.sampling.steps
        pbar = tqdm(range(self.config.sampling.steps),
                    desc='Sampling',
                    leave=False)
        NFEs = 0
        cache = None

        for i in pbar:
            t = timesteps[i]
            if self.T > 0:  # t in {1/T, ..., 1}, to match training.
                t = (t * self.T).to(torch.int)
                t = t / self.T
                t += (1 / self.T)
            t = t * torch.ones(xt.shape[0], 1, device=self.device)
            if cache is None:
                NFEs += 1
            sigma_t, _ = self.noise(t)
            sigma_s, _ = self.noise(t - dt)
            if sigma_t.ndim > 1:
                sigma_t = sigma_t.squeeze(-1)
            if sigma_s.ndim > 1:
                sigma_s = sigma_s.squeeze(-1)
            assert sigma_t.ndim == 1, sigma_t.shape
            assert sigma_s.ndim == 1, sigma_s.shape
            move_chance_t = 1 - torch.exp(-sigma_t)
            move_chance_s = 1 - torch.exp(-sigma_s)
            move_chance_t = move_chance_t[:, None, None]
            move_chance_s = move_chance_s[:, None, None]
            assert move_chance_t.ndim == 3, move_chance_t.shape

            if getattr(self.config, 'guidance', None) is None:
                xs, q_xs, cache = self._ddpm_denoise(
                    xt=xt,
                    time_conditioning=sigma_t,
                    move_chance_t=move_chance_t,
                    move_chance_s=move_chance_s,
                    cache=cache)
            else:
                if self.config.guidance.method == 'cfg':
                    xs, q_xs, cache = self._cfg_denoise(
                        cond=cond,
                        gamma=self.config.guidance.gamma,
                        xt=xt,
                        time_conditioning=sigma_t,
                        move_chance_t=move_chance_t,
                        move_chance_s=move_chance_s,
                        cache=cache)
                elif self.config.guidance.method == 'cbg':
                    xs, q_xs, cache = self._cbg_denoise(
                        classifier_model=classifier_model,
                        conditioning_class=self.config.guidance.condition,
                        gamma=self.config.guidance.gamma,
                        use_approx=self.config.guidance.use_approx,
                        xt=xt,
                        time_conditioning=sigma_t,
                        move_chance_t=move_chance_t,
                        move_chance_s=move_chance_s,
                        target_sequence=target_sequence,
                        target_motifs=target_motifs,
                        cache=cache)
                elif self.config.guidance.method == 'nos':
                    xs, q_xs, cache = self._nos_denoise(
                        classifier_model=classifier_model,
                        conditioning_class=self.config.guidance.condition,
                        num_nos_steps=self.config.guidance.num_nos_steps,
                        nos_step_size=self.config.guidance.nos_step_size,
                        nos_stability_coef=self.config.guidance.nos_stability_coef,
                        xt=xt,
                        time_conditioning=sigma_t,
                        move_chance_t=move_chance_t,
                        move_chance_s=move_chance_s)
                else:
                    raise NotImplementedError(
                        f"Guidance method {self.config.guidance.method} "
                        "not implemented.")
            pbar.set_postfix(
                NFEs=NFEs,
                prob_check=(q_xs.sum() / xt.numel()).item(),
                nan_check=bool(q_xs.isnan().sum() > 0))
            if (not self.config.sampling.use_cache
                    or not torch.allclose(xs, xt)):
                # Disable caching.
                cache = None
            xt = xs
        return xt

    def _ddpm_denoise(
            self,
            xt: torch.Tensor,
            time_conditioning: torch.Tensor,
            move_chance_t: torch.Tensor,
            move_chance_s: torch.Tensor,
            cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor,
                      typing.Dict[str, torch.Tensor]]:
        # Compute x_theta.
        if cache is not None:
            log_x_theta = cache['log_x_theta']
        else:
            log_x_theta = self.forward(xt, time_conditioning,
                                       cond=None)
            if self.config.sampling.use_float64:
                log_x_theta = log_x_theta.to(torch.float64)
        x_theta = log_x_theta.exp()

        # Compute the posterior.
        if self.diffusion == 'absorbing_state':
            q_xs = x_theta * (move_chance_t - move_chance_s)
            q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
            q_xs /= move_chance_t
        elif self.diffusion == 'uniform':
            q_xs = self._compute_posterior(
                x=x_theta,
                xt=xt,
                alpha_s=1 - move_chance_s,
                alpha_t=1 - move_chance_t)
        else:
            raise NotImplementedError(
                f"Diffusion type {self.diffusion} not implemented.")

        # Sample from the posterior.
        xs = _sample_categorical(q_xs)
        if self.diffusion == 'absorbing_state':
            # Carry-over: tokens that are already unmasked never change.
            copy_flag = (xt != self.mask_index).to(torch.bool)
            q_xs[copy_flag] = 0.0
            q_xs[copy_flag, xt[copy_flag]] = 1.0
            xs = torch.where(copy_flag, xt, xs)

        return xs, q_xs, {'log_x_theta': log_x_theta}

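    # Absorbing-state reverse step, worked with made-up numbers:
    # for move_chance_t = 0.5 and move_chance_s = 0.4, a masked
    # position stays masked with probability 0.4 / 0.5 = 0.8 and
    # otherwise decodes from x_theta with total mass
    # (0.5 - 0.4) / 0.5 = 0.2; the two cases sum to one.
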
    def _cfg_denoise(
            self,
            cond: torch.Tensor,
            gamma: float,
            xt: torch.Tensor,
            time_conditioning: torch.Tensor,
            move_chance_t: torch.Tensor,
            move_chance_s: torch.Tensor,
            cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor,
                      typing.Dict[str, torch.Tensor]]:
        # Compute log_x_theta.
        if cache is not None:
            log_x_theta_uncond = cache['log_x_theta_uncond']
            log_x_theta_cond = cache['log_x_theta_cond']
        else:
            if gamma == 0.0:  # Sample unconditionally.
                mask_cond = (torch.ones_like(cond) *
                             self.config.data.num_classes)
                log_x_theta_uncond = self.forward(
                    xt, time_conditioning, cond=mask_cond)
                log_x_theta_cond = None
            elif gamma == 1.0:  # Sample conditionally.
                log_x_theta_cond = self.forward(xt, time_conditioning,
                                                cond=cond)
                log_x_theta_uncond = None
            else:  # Sample from the tempered distribution.
                log_x_theta_cond = self.forward(xt, time_conditioning,
                                                cond=cond)
                mask_cond = (torch.ones_like(cond) *
                             self.config.data.num_classes)
                log_x_theta_uncond = self.forward(xt,
                                                  time_conditioning,
                                                  cond=mask_cond)
        # Compute the (weighted) posterior.
        if (log_x_theta_cond is None  # gamma == 0
                or log_x_theta_uncond is None):  # or gamma == 1
            log_x_theta = (log_x_theta_uncond
                           if log_x_theta_uncond is not None
                           else log_x_theta_cond)
            x_theta = log_x_theta.exp()
            if self.diffusion == 'absorbing_state':
                q_xs = x_theta * (move_chance_t - move_chance_s)
                q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
                q_xs /= move_chance_t
            elif self.diffusion == 'uniform':
                q_xs = self._compute_posterior(
                    x=x_theta,
                    xt=xt,
                    alpha_s=1 - move_chance_s,
                    alpha_t=1 - move_chance_t)
            else:
                raise NotImplementedError(
                    f"Diffusion type {self.diffusion} not implemented.")
        else:  # gamma != 0 and gamma != 1
            if self.diffusion == 'absorbing_state':
                log_x_theta = (gamma * log_x_theta_cond
                               + (1 - gamma) * log_x_theta_uncond)
                x_theta = log_x_theta.softmax(dim=-1)
                q_xs = x_theta * (move_chance_t - move_chance_s)
                q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
                q_xs /= move_chance_t
            elif (self.diffusion == 'uniform'
                  or self.diffusion == 'uniform_data_marginals'):
                log_q_xs_uncond = self._compute_posterior(
                    x=log_x_theta_uncond.exp(),
                    xt=xt,
                    alpha_s=1 - move_chance_s,
                    alpha_t=1 - move_chance_t).log()
                log_q_xs_cond = self._compute_posterior(
                    x=log_x_theta_cond.exp(),
                    xt=xt,
                    alpha_s=1 - move_chance_s,
                    alpha_t=1 - move_chance_t).log()
                log_q_xs = (gamma * log_q_xs_cond +
                            (1 - gamma) * log_q_xs_uncond)
                q_xs = log_q_xs.softmax(dim=-1)
            else:
                raise NotImplementedError(
                    f"Diffusion type {self.diffusion} not implemented.")

        # Sample from the posterior.
        xs = _sample_categorical(q_xs)
        if self.diffusion == 'absorbing_state':
            copy_flag = (xt != self.mask_index).to(torch.bool)
            q_xs[copy_flag] = 0.0
            q_xs[copy_flag, xt[copy_flag]] = 1.0
            xs = torch.where(copy_flag, xt, xs)

        return xs, q_xs, {'log_x_theta_uncond': log_x_theta_uncond,
                          'log_x_theta_cond': log_x_theta_cond}

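    # CFG tempering sketch: the guided prediction interpolates
    # (and, for gamma > 1, extrapolates) log-probabilities,
    # gamma * log p(x | c) + (1 - gamma) * log p(x). E.g., with
    # cond = -1.0, uncond = -2.0 and gamma = 2.0 the raw score is
    # 2 * (-1.0) - 1 * (-2.0) = 0.0 before renormalization.
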
    def _cbg_denoise(
            self,
            conditioning_class: int,
            gamma: float,
            classifier_model: classifier.Classifier,
            xt: torch.Tensor,
            time_conditioning: torch.Tensor,
            move_chance_t: torch.Tensor,
            move_chance_s: torch.Tensor,
            target_sequence: torch.Tensor = None,
            target_motifs: torch.Tensor = None,
            use_approx: bool = False,  # Whether to use the first-order approximation.
            cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor,
                      typing.Dict[str, torch.Tensor]]:
        if cache is not None:
            log_x_theta = cache['log_x_theta']
            classifier_log_prob = cache['classifier_log_prob']
        else:
            # Diffusion model.
            log_x_theta = self.forward(xt, time_conditioning,
                                       cond=None)
            # Classifier model.
            if use_approx:
                xt_one_hot = torch.nn.functional.one_hot(
                    xt, self.vocab_size).to(torch.float)
                with torch.enable_grad():
                    xt_one_hot.requires_grad_(True)
                    classifier_log_prob_xt = classifier_model.get_log_probs(
                        xt_one_hot, time_conditioning)
                    classifier_log_prob_xt[..., conditioning_class].sum().backward()
                    grad_log_prob_xt = xt_one_hot.grad

                # First-order (Taylor) estimate of the log-ratio.
                classifier_log_prob_ratio = (
                    grad_log_prob_xt
                    - (xt_one_hot * grad_log_prob_xt).sum(dim=-1, keepdim=True)
                ).detach().requires_grad_(False)
                classifier_log_prob = (
                    classifier_log_prob_ratio +
                    classifier_log_prob_xt[..., conditioning_class][..., None, None]
                ).detach().requires_grad_(False)
            else:
                # Copied from https://github.com/hnisonoff/discrete_guidance/blob/main/src/fm_utils.py#L441
                bsz, seq_len = xt.shape
                # Create bsz*seq_len*N copies of the input sequences.
                # Shape: (bsz, 1, seq_len) -> (bsz, seq_len*N, seq_len)
                # (where N = vocab_size).
                xt_expand = xt.unsqueeze(1).repeat(
                    1, seq_len * self.vocab_size, 1)
                # Flatten the batch and transition dimensions.
                # Shape: (bsz, seq_len*N, seq_len) -> (bsz*seq_len*N, seq_len)
                xt_expand = xt_expand.view(-1, seq_len)

                # Create indices for all possible transitions.
                # Shape: (seq_len*N,) -> (bsz, seq_len*N) -> (bsz*seq_len*N,)
                jump_idx = torch.arange(seq_len * self.vocab_size).to(xt.device)
                jump_idx = jump_idx.repeat(bsz, 1).flatten()

                # Create a tensor for states after one transition.
                xt_jumps = xt_expand.clone()

                # Calculate which dimension changes for each transition.
                # Shape: (bsz*seq_len*N,)
                jump_dims = jump_idx // self.vocab_size

                # Calculate the new value for the changed dimension.
                # Shape: (bsz*seq_len*N,)
                jump_states = jump_idx % self.vocab_size

                # Apply transitions by assigning new values at the
                # transition dimensions.
                # Shape: (bsz*seq_len*N, seq_len)
                xt_jumps[
                    torch.arange(jump_idx.size(0), device=xt.device),
                    jump_dims,  # Index the transitioned dimension.
                ] = jump_states  # Assign the new state.

                # classifier_log_prob = (classifier_model.get_log_probs(
                #     xt_jumps, time_conditioning.repeat(seq_len * self.vocab_size)
                # ))[..., conditioning_class].reshape(bsz, seq_len, self.vocab_size)

                target_sequence = target_sequence.to(self.device)
                mask_vec = torch.tensor(
                    [1 if i - 1 in target_motifs else 0
                     for i in range(target_sequence.shape[1])]).to(self.device)

                bindevaluator_probs = classifier_model.get_probs(
                    xt_jumps, target_sequence.repeat(xt_jumps.shape[0], 1))

                # Guard against log(0).
                bindevaluator_probs = torch.where(
                    bindevaluator_probs == 0,
                    torch.tensor(1e-8, dtype=bindevaluator_probs.dtype),
                    bindevaluator_probs)
                classifier_log_prob = torch.log(bindevaluator_probs) * mask_vec

                # Average the log-probabilities over the motif positions.
                classifier_log_prob = classifier_log_prob.sum(dim=-1) / mask_vec.sum()
                classifier_log_prob = classifier_log_prob.reshape(
                    bsz, seq_len, self.vocab_size)

                # classifier_log_prob = (torch.exp(classifier_model.get_log_probs(
                #     xt_jumps, target_sequence.repeat(xt_jumps.shape[0], 1)
                # )) * mask_vec).sum(dim=-1).log().reshape(bsz, seq_len, self.vocab_size)

                # (bsz, seq_len, N) / (bsz, seq_len, N, tgt_len)

        # Compute the unguided posterior.
        if self.diffusion == 'absorbing_state':
            diffusion_log_probs = log_x_theta + torch.log(
                1. - (move_chance_s / move_chance_t))
            diffusion_log_probs[..., self.mask_index] = torch.log(
                move_chance_s / move_chance_t)[:, :, 0]
            diffusion_log_probs = diffusion_log_probs.detach()
        elif self.diffusion == 'uniform':
            diffusion_log_probs = self._compute_posterior(
                x=log_x_theta.exp(),
                xt=xt,
                alpha_s=1 - move_chance_s,
                alpha_t=1 - move_chance_t).log()
        else:
            raise NotImplementedError(
                f"Diffusion type {self.diffusion} not implemented.")

        # Apply guidance.
        with torch.no_grad():
            if self.diffusion == 'absorbing_state':
                guided_log_probs = (gamma * classifier_log_prob) + diffusion_log_probs
                copy_flag = (xt != self.mask_index)
                guided_log_probs[copy_flag] = self.neg_infinity
                guided_log_probs[copy_flag, xt[copy_flag]] = 0.0
            elif self.diffusion == 'uniform':
                guided_log_probs = (gamma * classifier_log_prob) + diffusion_log_probs
            else:
                raise NotImplementedError(
                    f"Diffusion type {self.diffusion} not implemented.")

            guided_probs = guided_log_probs.softmax(dim=-1)
        # Sample from the guided posterior.
        xs = _sample_categorical(guided_probs)
        if self.diffusion == 'absorbing_state':
            xs = torch.where(copy_flag.to(bool), xt, xs)
        return xs, guided_probs, {'log_x_theta': log_x_theta,
                                  'classifier_log_prob': classifier_log_prob}

    def _nos_denoise(
            self,
            classifier_model: classifier.Classifier,
            num_nos_steps: int,
            nos_step_size: float,
            nos_stability_coef: float,
            conditioning_class: int,
            xt: torch.Tensor,
            time_conditioning: torch.Tensor,
            move_chance_t: torch.Tensor,
            move_chance_s: torch.Tensor,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, None]:
        # Compute the original diffusion_log_probs and hidden states.
        copy_flag = (xt != self.mask_index).to(torch.bool)
        with torch.no_grad():
            time_conditioning = self._process_sigma(time_conditioning)
            with torch.cuda.amp.autocast(dtype=torch.float32):
                logits, hidden_states = self.backbone(
                    xt, time_conditioning, cond=None,
                    return_hidden_states=True)
            if self.parameterization == 'subs':
                log_x_theta = self._subs_parameterization(
                    logits=logits, xt=xt)
            elif self.parameterization == 'd3pm':
                # Returns log_probs.
                if self.subs_masking:  # Can use "zero masking prob".
                    logits[:, :, self.mask_index] += self.neg_infinity
                log_x_theta = logits.log_softmax(dim=-1)
            else:
                raise NotImplementedError(
                    f"Parameterization {self.parameterization} "
                    "not implemented for NOS guidance.")
            if self.diffusion == 'absorbing_state':
                diffusion_log_probs = log_x_theta + torch.log(
                    1. - (move_chance_s / move_chance_t))
                diffusion_log_probs[..., self.mask_index] = torch.log(
                    move_chance_s / move_chance_t)[:, :, 0]
                diffusion_log_probs[copy_flag] = self.neg_infinity
                diffusion_log_probs[copy_flag, xt[copy_flag]] = 0.0
            elif self.diffusion == 'uniform':
                diffusion_log_probs = self._compute_posterior(
                    x=log_x_theta.exp(),
                    xt=xt,
                    alpha_s=1 - move_chance_s,
                    alpha_t=1 - move_chance_t).log()

        # Perform NOS steps.
        kl_loss = torch.nn.KLDivLoss(reduction='batchmean',
                                     log_target=True)
        delta = torch.nn.Parameter(
            torch.zeros_like(hidden_states[-1]),
            requires_grad=True)
        optimizer = torch.optim.Adagrad([delta], lr=nos_step_size)
        with torch.enable_grad():
            for _ in tqdm(range(num_nos_steps),
                          desc='NOS', leave=False):
                h_current = hidden_states[-1] + delta
                target_loss = classifier_model.get_log_probs(
                    xt, time_conditioning,
                    x_emb=h_current)[..., conditioning_class].sum()
                with torch.cuda.amp.autocast(dtype=torch.float32):
                    new_logits = self.forward(xt, time_conditioning,
                                              cond=None,
                                              x_emb=h_current)
                if self.diffusion == 'absorbing_state':
                    adjusted_log_probs = new_logits + torch.log(
                        1. - (move_chance_s / move_chance_t))
                    adjusted_log_probs[
                        ..., self.mask_index] = torch.log(
                        move_chance_s / move_chance_t)[:, :, 0]
                    adjusted_log_probs[copy_flag] = self.neg_infinity
                    adjusted_log_probs[copy_flag, xt[copy_flag]] = 0.0
                elif self.diffusion == 'uniform':
                    adjusted_log_probs = self._compute_posterior(
                        x=new_logits.exp(),
                        xt=xt,
                        alpha_s=1 - move_chance_s,
                        alpha_t=1 - move_chance_t).log()
                kl = kl_loss(adjusted_log_probs, diffusion_log_probs)
                loss = -target_loss + nos_stability_coef * kl
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        with torch.cuda.amp.autocast(dtype=torch.float32):
            guided_logits = self.forward(
                xt, time_conditioning,
                cond=None,
                x_emb=hidden_states[-1] + delta.data)
        if self.diffusion == 'absorbing_state':
            diffusion_log_probs = guided_logits + torch.log(
                1. - (move_chance_s / move_chance_t))
            diffusion_log_probs[
                ..., self.mask_index] = torch.log(
                move_chance_s / move_chance_t)[:, :, 0]
            diffusion_log_probs = diffusion_log_probs.detach()
            guided_probs = diffusion_log_probs.exp()
        elif self.diffusion == 'uniform':
            guided_probs = self._compute_posterior(
                x=guided_logits.exp(),
                xt=xt,
                alpha_s=1 - move_chance_s,
                alpha_t=1 - move_chance_t).detach()
        else:
            raise NotImplementedError(
                f"Diffusion type {self.diffusion} not implemented.")

        xs = _sample_categorical(guided_probs)
        if self.diffusion == 'absorbing_state':
            xs = torch.where(copy_flag, xt, xs)

        return xs, guided_probs, None
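
    # NOS in brief: instead of editing tokens, it optimizes a
    # perturbation `delta` on the final hidden states so that the
    # classifier's log-probability of the target class increases,
    # while the KL term keeps the perturbed denoising distribution
    # close to the unguided one; nos_stability_coef trades off the
    # two objectives.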
eval_utils.py ADDED
@@ -0,0 +1,90 @@
import os

import torch
import transformers
from tqdm import tqdm

import diffusion


def compute_ppl(
        pretrained_model,
        val_ds
):
    ppl_metrics = diffusion.Perplexity().to('cuda')
    pbar = tqdm(val_ds, desc='PPL')
    for batch in pbar:
        input_ids = batch['input_ids'].to('cuda')
        if 'attention_mask' in batch:
            attention_mask = batch['attention_mask'].to('cuda')
        else:
            attention_mask = None
        losses = pretrained_model._loss(input_ids, attention_mask)
        ppl_metrics.update(losses.nlls, losses.token_mask)
        pbar.set_postfix({'ppl': ppl_metrics.compute().item()})
    return ppl_metrics.compute().item()


def compute_generative_ppl(
        sentences,
        eval_model_name_or_path,
        gen_ppl_eval_batch_size=8,
        max_length=128):
    gen_ppl_metric = diffusion.Perplexity().to('cuda')
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    eval_model_tokenizer = transformers.AutoTokenizer.from_pretrained(
        eval_model_name_or_path)
    if eval_model_tokenizer.pad_token is None:
        eval_model_tokenizer.pad_token = \
            eval_model_tokenizer.eos_token
        eval_model_tokenizer.pad_token_id = \
            eval_model_tokenizer.eos_token_id
    eval_model = transformers.AutoModelForCausalLM.from_pretrained(
        eval_model_name_or_path).eval()
    eval_model = eval_model.to('cuda')
    # Re-tokenize using the eval model's tokenizer.
    tokenizer_kwargs = {
        'return_tensors': 'pt',
        'return_token_type_ids': False,
        'return_attention_mask': True,
        'truncation': True,
        'padding': True,
        'max_length': max_length,
    }
    eval_context_size = 1024
    samples = eval_model_tokenizer(
        sentences, **tokenizer_kwargs)
    attn_mask = samples['attention_mask']
    samples = samples['input_ids']
    attn_mask = attn_mask.to('cuda')
    samples = samples.to('cuda')
    num_batches = samples.shape[0] // gen_ppl_eval_batch_size
    for i in tqdm(range(num_batches),
                  desc='Gen. PPL', leave=False):
        _samples = torch.split(
            samples[i * gen_ppl_eval_batch_size:
                    (i + 1) * gen_ppl_eval_batch_size],
            eval_context_size,
            dim=-1)
        _attn_mask = torch.split(
            attn_mask[i * gen_ppl_eval_batch_size:
                      (i + 1) * gen_ppl_eval_batch_size],
            eval_context_size,
            dim=-1)
        for (sample_chunk, attn_mask_chunk) in zip(
                _samples, _attn_mask):
            logits = eval_model(
                sample_chunk, attention_mask=attn_mask_chunk)[0]
            logits = logits.transpose(-1, -2)

            nlls = torch.nn.functional.cross_entropy(
                logits[..., :-1],
                sample_chunk[..., 1:],
                reduction='none')
            # first_eos = (sample_chunk == eval_model_tokenizer.eos_token_id).cumsum(-1) == 1
            # token_mask = (sample_chunk != eval_model_tokenizer.eos_token_id)
            # gen_ppl_metric.update(
            #     nlls, first_eos[..., 1:] + token_mask[..., 1:])
            gen_ppl_metric.update(
                nlls, attn_mask_chunk[..., 1:])
    return gen_ppl_metric.compute().item()
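
# Hypothetical usage sketch ('gpt2-large' is just an example
# checkpoint name, not one mandated by this repo):
#
#   sentences = ['a generated sample', 'another generated sample']
#   gen_ppl = compute_generative_ppl(
#       sentences, eval_model_name_or_path='gpt2-large',
#       gen_ppl_eval_batch_size=2, max_length=128)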
noise_schedule.py ADDED
@@ -0,0 +1,160 @@
import abc

import torch
import torch.nn as nn

# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def get_noise(config, dtype=torch.float32):
    if config.noise.type == 'geometric':
        return GeometricNoise(config.noise.sigma_min,
                              config.noise.sigma_max)
    elif config.noise.type == 'loglinear':
        return LogLinearNoise()
    elif config.noise.type == 'cosine':
        return CosineNoise()
    elif config.noise.type == 'cosinesqr':
        return CosineSqrNoise()
    elif config.noise.type == 'linear':
        return Linear(config.noise.sigma_min,
                      config.noise.sigma_max,
                      dtype)
    else:
        raise NotImplementedError(
            f'{config.noise.type} noise schedule is not '
            f'implemented.')


def binary_discretization(z):
    z_hard = torch.sign(z)
    z_soft = z / torch.norm(z, dim=-1, keepdim=True)
    return z_soft + (z_hard - z_soft).detach()


class Noise(abc.ABC, nn.Module):
    """Base Noise class.

    Defines the forward signature, which returns the
    total noise and the rate of noise for a given timestep.
    """
    def forward(self, t):
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)

    @abc.abstractmethod
    def rate_noise(self, t):
        """Rate of change of noise, i.e. g(t)."""
        pass

    @abc.abstractmethod
    def total_noise(self, t):
        r"""Total noise, i.e. \int_0^t g(s) ds + g(0)."""
        pass


class CosineNoise(Noise):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2)
        return - torch.log(self.eps + (1 - self.eps) * cos)


class CosineSqrNoise(Noise):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * (
            torch.cos(t * torch.pi / 2) ** 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2) ** 2
        return - torch.log(self.eps + (1 - self.eps) * cos)


class Linear(Noise):
    def __init__(self, sigma_min=0, sigma_max=10,
                 dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t):
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return (self.sigma_min + t *
                (self.sigma_max - self.sigma_min))

    def importance_sampling_transformation(self, t):
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(
            - torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)


class GeometricNoise(Noise):
    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
            self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t


class LogLinearNoise(Noise):
    """Log-linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and
    ~1 when t varies from 0 to 1. Total noise is
    -log(1 - (1 - eps) * t), so sigma will be
    (1 - eps) * t.
    """
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(
            torch.tensor(0.0))

    def rate_noise(self, t):
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        t = - torch.expm1(- sigma_t) / (1 - self.eps)
        return t
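As a quick sanity check (not part of the repo), each schedule's `rate_noise` should be the time derivative of its `total_noise`; a finite-difference probe confirms this for `LogLinearNoise`:

```python
import torch
from noise_schedule import LogLinearNoise

noise = LogLinearNoise(eps=1e-3)
t = torch.linspace(0.01, 0.99, 50)
total, rate = noise(t)

# Central finite difference of total_noise should match rate_noise.
h = 1e-4
fd = (noise.total_noise(t + h) - noise.total_noise(t - h)) / (2 * h)
print(torch.allclose(fd, rate, rtol=1e-3))  # True
```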
requirements.yaml ADDED
@@ -0,0 +1,49 @@
name: ct_udlm
channels:
  - pytorch
  - nvidia
  - anaconda
  - defaults
dependencies:
  - cuda-nvcc=12.4.99
  - ipykernel=6.29.5
  - ipython=8.15.0
  - ipywidgets=8.1.2
  - pip=23.3.1
  - python=3.9.20
  - pip:
    - biopython==1.84
    - causal-conv1d==1.4.0
    - datasets==2.18.0
    - einops==0.8.0
    - flash-attn==2.7.2.post1
    - fsspec==2024.2.0
    - git-lfs==1.6
    - h5py==3.10.0
    - huggingface-hub==0.26.2
    - hydra-core==1.3.2
    - ipdb==0.13.13
    - jupyter==1.1.1
    - jupyterlab==4.1.8
    - lightning==2.2.1
    - lightning-utilities==0.11.9
    - mamba-ssm==1.2.0.post1
    - matplotlib==3.9.2
    - notebook==7.1.1
    - numpy==1.26.4
    - omegaconf==2.3.0
    - pandas==2.2.1
    - pytorch-image-generation-metrics==0.6.1
    - rdkit==2024.3.6
    - regex==2024.11.6
    - rich==13.7.1
    - safetensors==0.4.5
    - scikit-learn==1.4.0
    - scipy==1.13.1
    - seaborn==0.13.2
    - timm==0.9.16
    - tokenizers==0.15.2
    - torchmetrics==1.6.0
    - tqdm==4.67.0
    - transformers==4.38.2
    - wandb==0.13.5
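Assuming a standard conda setup, this environment can be created with `conda env create -f requirements.yaml` and activated with `conda activate ct_udlm`. Note that `flash-attn` and `mamba-ssm` typically need a CUDA toolchain at install time, which is presumably why `cuda-nvcc` is pinned above.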
sample.py ADDED
@@ -0,0 +1,124 @@
import os
import hydra
import lightning as L
import numpy as np
import omegaconf
import pandas as pd
import rdkit
import rich.syntax
import rich.tree
import torch
from tqdm.auto import tqdm
import pdb

import dataloader
import diffusion
from models.bindevaluator import BindEvaluator

rdkit.rdBase.DisableLog('rdApp.error')

omegaconf.OmegaConf.register_new_resolver(
    'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
    'device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(
    'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
    'div_up', lambda x, y: (x + y - 1) // y)
omegaconf.OmegaConf.register_new_resolver(
    'if_then_else',
    lambda condition, x, y: x if condition else y)


def _print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True) -> None:
    """Prints the content of a DictConfig as a tree structure using the Rich library.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig.
    """
    style = 'dim'
    tree = rich.tree.Tree('CONFIG', style=style,
                          guide_style=style)

    fields = config.keys()
    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        branch_content = str(config_section)
        if isinstance(config_section, omegaconf.DictConfig):
            branch_content = omegaconf.OmegaConf.to_yaml(
                config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
    rich.print(tree)


def parse_motif(motif: str) -> torch.Tensor:
    """Expands a comma-separated string of indices and inclusive
    ranges (e.g. '3-5, 8') into a flat tensor of indices."""
    parts = motif.split(',')
    result = []

    for part in parts:
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))

    return torch.tensor(result)


@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    # Reproducibility
    L.seed_everything(config.seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # _print_config(config, resolve=True)
    print(f"Checkpoint: {config.eval.checkpoint_path}")

    tokenizer = dataloader.get_tokenizer(config)
    target_sequence = tokenizer(
        config.eval.target_sequence, return_tensors='pt')['input_ids']

    pretrained = diffusion.Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config, logger=False)
    pretrained.eval()

    bindevaluator = BindEvaluator.load_from_checkpoint(
        config.guidance.classifier_checkpoint_path,
        n_layers=8,
        d_model=128,
        d_hidden=128,
        n_head=8,
        d_k=64,
        d_v=128,
        d_inner=64)

    samples = []
    for _ in tqdm(
            range(config.sampling.num_sample_batches),
            desc='Gen. batches', leave=False):
        sample = pretrained.sample(
            target_sequence=target_sequence,
            target_motifs=parse_motif(config.eval.target_motifs),
            classifier_model=bindevaluator)
        # print(f"Batch took {time.time() - start:.2f} seconds.")
        samples.extend(
            pretrained.tokenizer.batch_decode(sample))

    # Strip whitespace and the 5-character <bos>/<eos> special tokens.
    samples = [sample.replace(' ', '')[5:-5] for sample in samples]
    print(samples)


if __name__ == '__main__':
    main()
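For reference, `parse_motif` accepts both single indices and inclusive ranges:

```python
parse_motif('3-5, 8, 12')
# tensor([ 3,  4,  5,  8, 12])
```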
tokenizer.py ADDED
@@ -0,0 +1,279 @@
"""Custom Tokenization classes."""

import collections
import json
import os
import re
from typing import List, Optional, Tuple, Union

from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {'vocab_file': 'vocab.json'}
PRETRAINED_VOCAB_FILES_MAP = {
    'qm9': {
        'vocab_file': {
            'yairschiff/qm9-tokenizer': 'https://huggingface.co/yairschiff/qm9-tokenizer/resolve/main/vocab.json'
        }
    },
    'zinc250k': {
        'vocab_file': {
            'yairschiff/zinc250k-tokenizer': 'https://huggingface.co/yairschiff/zinc250k-tokenizer/resolve/main/vocab.json'
        }
    }
}


class SMILESTokenizer(PreTrainedTokenizer):
    r"""
    Construct a regex-based tokenizer for SMILES datasets.

    This tokenizer inherits from [`PreTrainedTokenizer`],
    which contains most of the main methods. Users should
    refer to this superclass for more information regarding
    those methods.

    Adapted from:
    https://huggingface.co/ibm/MoLFormer-XL-both-10pct

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token not in the vocabulary
            cannot be converted to an ID and is set to be
            this token instead.
        sep_token (`str`, *optional*, defaults to `"<eos>"`):
            The separator token, which is used when building
            a sequence from multiple sequences, e.g., two
            sequences for sequence classification or a
            text and a question for question answering.
            It is also used as the last token of a sequence
            built with special tokens.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example, when
            batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"<bos>"`):
            The classifier token, which is used when doing
            sequence classification (classification of the
            whole sequence instead of per-token
            classification). It is the first token of the
            sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the
            token used when training this model with masked
            language modeling; it is the token the model
            will try to predict.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token='<unk>',
        sep_token='<eos>',
        pad_token='<pad>',
        cls_token='<bos>',
        mask_token='<mask>',
        **kwargs,
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path "
                f"'{vocab_file}'.")
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            vocab_from_file = json.load(vocab_handle)
        # Re-index to account for special tokens
        self.vocab = {
            cls_token: 0,
            sep_token: 1,
            mask_token: 2,
            pad_token: 3,
            unk_token: 4,
            **{k: v + 5 for k, v in vocab_from_file.items()}
        }

        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        # Regex pattern taken from:
        # https://github.com/pschwllr/MolecularTransformer
        self.pattern = (
            r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
        )
        self.regex_tokenizer = re.compile(self.pattern)

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text, **kwargs):
        split_tokens = self.regex_tokenizer.findall(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens into a single string."""
        out_string = "".join(tokens).strip()
        return out_string

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of
        sequences for sequence classification tasks by
        concatenating and adding special tokens.
        A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will
                be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence
                pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids)
            with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no
        special tokens added. This method is called when
        adding special tokens using the tokenizer
        `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether the token list is already formatted
                with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range
            [0, 1]: 1 for a special token, 0 for a sequence
            token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True)

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be
        used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns
        the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence
                pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(
        self, save_directory: str,
        filename_prefix: Optional[str] = None
    ) -> Union[Tuple[str], None]:
        if not os.path.isdir(save_directory):
            logger.error(
                f"Vocabulary path ({save_directory}) should "
                "be a directory.")
            return None
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(
                json.dumps(
                    self.vocab,
                    indent=2,
                    sort_keys=True,
                    ensure_ascii=False
                ) + "\n")
        return (vocab_file,)


class QM9Tokenizer(SMILESTokenizer):
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP['qm9']


class Zinc250kTokenizer(SMILESTokenizer):
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP['zinc250k']
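A minimal usage sketch, assuming a local `vocab.json` compatible with this class (e.g. the file hosted in the `yairschiff/qm9-tokenizer` repo):

```python
tokenizer = SMILESTokenizer(vocab_file='vocab.json')

# The regex keeps bracket atoms, two-letter halogens (Cl, Br), and
# ring-closure digits as single tokens.
print(tokenizer.tokenize('CC(=O)Oc1ccccc1C(=O)O'))
# ['C', 'C', '(', '=', 'O', ')', 'O', 'c', '1', 'c', 'c', 'c', 'c', 'c',
#  '1', 'C', '(', '=', 'O', ')', 'O']

ids = tokenizer('CC(=O)O')['input_ids']
print(ids[0], ids[-1])  # 0 1  -- wrapped in <bos> (id 0) and <eos> (id 1)
```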
uncond_sample.py ADDED
@@ -0,0 +1,116 @@
import os
import hydra
import lightning as L
import numpy as np
import omegaconf
import pandas as pd
import rdkit
import rich.syntax
import rich.tree
import torch
from tqdm.auto import tqdm
import pdb
import csv

import dataloader
import diffusion

rdkit.rdBase.DisableLog('rdApp.error')

omegaconf.OmegaConf.register_new_resolver(
    'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
    'device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(
    'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
    'div_up', lambda x, y: (x + y - 1) // y)
omegaconf.OmegaConf.register_new_resolver(
    'if_then_else',
    lambda condition, x, y: x if condition else y)


def _print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True) -> None:
    """Prints the content of a DictConfig as a tree structure using the Rich library.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig.
    """
    style = 'dim'
    tree = rich.tree.Tree('CONFIG', style=style,
                          guide_style=style)

    fields = config.keys()
    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        branch_content = str(config_section)
        if isinstance(config_section, omegaconf.DictConfig):
            branch_content = omegaconf.OmegaConf.to_yaml(
                config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
    rich.print(tree)


def parse_range(tgt_range: str) -> list:
    """Expands a comma-separated string of integers and inclusive
    ranges (e.g. '90-95, 100') into a flat list of integers."""
    parts = tgt_range.split(',')
    result = []

    for part in parts:
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))

    return result


@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    # Reproducibility
    L.seed_everything(config.seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # _print_config(config, resolve=True)
    print(f"Checkpoint: {config.eval.checkpoint_path}")

    tokenizer = dataloader.get_tokenizer(config)

    pretrained = diffusion.Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config, logger=False)
    pretrained.eval()

    target_lengths = parse_range(config.model.length_range)

    for length in target_lengths:
        config.model.length = length + 2
        samples = []
        for _ in tqdm(
                range(config.sampling.num_sample_batches),
                desc='Gen. batches', leave=False):
            sample = pretrained.sample()
            # print(f"Batch took {time.time() - start:.2f} seconds.")
            samples.extend(
                pretrained.tokenizer.batch_decode(sample))

        # print([sample.replace(' ', '')[5:-5] for sample in samples])

        samples = [sample.replace(' ', '')[5:-5] for sample in samples]
        print(samples)

        # df = pd.DataFrame(samples, columns=['sequence'])
        # df.to_csv(f'/home/tc415/discrete-diffusion-guidance/samples/{length}.csv', index=False)


if __name__ == '__main__':
    main()
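A note on the loop above: `config.model.length = length + 2` presumably reserves two positions for the `<bos>` and `<eos>` special tokens, which the `[5:-5]` slice then strips from each decoded string (both special tokens decode to five-character strings).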
utils.py ADDED
@@ -0,0 +1,86 @@
"""Console logger utilities.

Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
"""

import logging

import fsspec
import lightning
import torch
from timm.scheduler import CosineLRScheduler


def fsspec_exists(filename):
    """Check if a file exists using fsspec."""
    fs, _ = fsspec.core.url_to_fs(filename)
    return fs.exists(filename)


def fsspec_listdir(dirname):
    """Listdir in a manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    return fs.ls(dirname)


def fsspec_mkdirs(dirname, exist_ok=True):
    """Mkdirs in a manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    fs.makedirs(dirname, exist_ok=exist_ok)


def print_nans(tensor, name):
    if torch.isnan(tensor).any():
        print(name, tensor)


class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """Wraps timm.scheduler.CosineLRScheduler.

    Enables calling scheduler.step() without passing in an epoch.
    Supports resuming as well.
    Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # We call either step or step_update, depending on
        # whether we're using the scheduler every epoch or
        # every step. Otherwise, Lightning will always call
        # step (i.e., meant for each epoch), and if we set the
        # scheduler interval to "step", then the learning rate
        # update will be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes a multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # This ensures all logging levels get marked with the rank-zero
    # decorator; otherwise logs would get multiplied for each GPU
    # process in a multi-GPU setup.
    for level_name in ('debug', 'info', 'warning', 'error',
                       'exception', 'fatal', 'critical'):
        setattr(logger,
                level_name,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, level_name)))

    return logger
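A minimal sketch of how the scheduler wrapper above might be wired up outside of Lightning (all hyperparameters here are illustrative assumptions, not values used by the repo):

```python
import torch
from utils import CosineDecayWarmupLRScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = CosineDecayWarmupLRScheduler(
    optimizer,
    t_initial=10_000,    # total steps for the cosine decay
    warmup_t=500,        # linear warmup steps
    warmup_lr_init=1e-6,
    lr_min=1e-6,
    t_in_epochs=False)   # step-based: step() routes to step_update()

for _ in range(1_000):
    optimizer.step()
    scheduler.step()     # no epoch argument needed
```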