import itertools
import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import pytorch_lightning as L
import torchmetrics

from models import dit, ema
import noise_schedule

LOG2 = math.log(2)
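# NLLs are accumulated in nats; dividing by LOG2 converts them to bits for
# the bits-per-dimension metric below.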


@dataclass
class Loss:
    loss: torch.FloatTensor
    nlls: torch.FloatTensor
    token_mask: torch.FloatTensor


class NLL(torchmetrics.MeanMetric):
    """Running mean of per-token negative log-likelihoods (in nats)."""


class BPD(NLL):
    def compute(self) -> torch.Tensor:
        """Computes the bits per dimension.

        Returns:
            bpd
        """
        return self.mean_value / self.weight / LOG2
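
# Example: a mean NLL of LOG2 ≈ 0.693 nats per token corresponds to exactly
# 1 bit per dimension (and, equivalently, a perplexity of 2).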


class Perplexity(NLL):
    def compute(self) -> torch.Tensor:
        """Computes the perplexity.

        Returns:
            Perplexity
        """
        return torch.exp(self.mean_value / self.weight)
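
# Usage sketch (hypothetical values): both metrics consume per-token NLLs in
# nats via MeanMetric.update:
#
#   ppl = Perplexity()
#   ppl.update(torch.tensor([0.5, 0.7]))  # running mean NLL = 0.6 nats
#   ppl.compute()                         # exp(0.6) ≈ 1.822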


class Diffusion(L.LightningModule):
    def __init__(self, config, latent_dim):
        super().__init__()
        self.config = config
        self.latent_dim = latent_dim

        self.backbone = dit.DIT(config, vocab_size=self.latent_dim)
        self.T = self.config.T
        self.subs_masking = self.config.subs_masking

        self.softplus = torch.nn.Softplus()
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        # Accumulate metric state in float64 to avoid precision loss over
        # long evaluation runs.
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
        self.lr = self.config.optim["lr"]
        self.sampling_eps = self.config.training.get("sampling_eps", 1e-5)
        self.time_conditioning = self.config.get("time_conditioning", True)
        self.neg_infinity = -1000000.0
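        # neg_infinity is presumably a finite stand-in for -inf (e.g. for
        # masking logits elsewhere in the model), avoiding NaNs from 0 * inf
        # in mixed-precision arithmetic.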

    def forward(self, latents, sigma):
        """Forward diffusion process: adds Gaussian noise scaled by sigma."""
        # Reshape sigma from (batch,) so it broadcasts over the latent dims;
        # without this, (batch,) would misalign against latents' last dim.
        sigma = sigma.view(-1, *([1] * (latents.dim() - 1)))
        noise = sigma * torch.randn_like(latents)
        noisy_latents = latents + noise
        return noisy_latents
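
    # Shape sketch: for latents of shape (B, L, D) and sigma of shape (B,),
    # the reshape above yields sigma of shape (B, 1, 1), so every element of
    # the batch is perturbed at its own noise level.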

    def reverse_diffusion(self, noisy_latents, sigma):
        """Reverse diffusion process: denoises the latents."""
        denoised_latents = self.backbone(noisy_latents, sigma)
        return denoised_latents

    def training_step(self, batch, batch_idx):
        # Draw one noise level per batch element, uniformly in [0, 1).
        sigma = torch.rand(batch.size(0), device=self.device)
        noisy_latents = self(batch, sigma)
        denoised_latents = self.reverse_diffusion(noisy_latents, sigma)
        # Regress the denoised latents onto the clean inputs.
        loss = F.mse_loss(denoised_latents, batch)
        self.log("train_loss", loss)
        return loss
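
    # The MSE objective trains the backbone as a denoiser: given a noisy
    # latent and its noise level sigma, reconstruct the clean latent.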

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
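

# Minimal usage sketch (hypothetical: assumes a config exposing the fields
# read in __init__, e.g. an omegaconf DictConfig, plus a DataLoader yielding
# latent tensors):
#
#   model = Diffusion(config, latent_dim=config.model.latent_dim)
#   trainer = L.Trainer(max_steps=10_000)
#   trainer.fit(model, train_dataloaders=train_loader)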