Spaces:
Running
Running
File size: 1,929 Bytes
c4c7cee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import torch
from torch import nn
# ----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).
class EDMPrecond(torch.nn.Module):
    """EDM preconditioning wrapper (Karras et al., "Elucidating the Design
    Space of Diffusion-Based Generative Models").

    Wraps an inner denoising network and applies the c_skip / c_out /
    c_in / c_noise scalings so that the resulting denoiser D(x; sigma)
    stays well conditioned across all noise levels.
    """

    def __init__(
        self,
        network,
        label_dim=0,  # Number of class labels, 0 = unconditional.
        sigma_min=0,  # Minimum supported noise level.
        sigma_max=float("inf"),  # Maximum supported noise level.
        sigma_data=0.5,  # Expected standard deviation of the training data.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.network = network

    def forward(self, x, sigma, conditioning=None, **network_kwargs):
        """Return the preconditioned denoiser output D(x; sigma)."""
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        # Resolve the conditioning signal: unconditional models discard it;
        # conditional models fall back to an all-zero label when none given.
        if self.label_dim == 0:
            conditioning = None
        elif conditioning is None:
            conditioning = torch.zeros([1, self.label_dim], device=x.device)
        else:
            conditioning = conditioning.to(torch.float32)

        # EDM scaling coefficients, all derived from sigma^2 + sigma_data^2.
        var_sum = sigma**2 + self.sigma_data**2
        c_skip = self.sigma_data**2 / var_sum
        c_out = sigma * self.sigma_data / var_sum.sqrt()
        c_in = 1 / var_sum.sqrt()
        c_noise = sigma.log() / 4

        raw_out = self.network(
            (c_in * x),
            c_noise.flatten(),
            conditioning=conditioning,
            **network_kwargs,
        )
        # Skip connection plus scaled network output.
        return c_skip * x + c_out * raw_out.to(torch.float32)

    def round_sigma(self, sigma):
        """Convert *sigma* to a tensor (EDM applies no quantization)."""
        return torch.as_tensor(sigma)
class DDPMPrecond(nn.Module):
    """Identity preconditioner: runs *batch* through *network* unchanged.

    Unlike :class:`EDMPrecond`, no input/output scaling is applied —
    the raw network output is returned as-is.
    """

    def __init__(self):
        super().__init__()

    def forward(self, network, batch):
        # Straight pass-through; kept as a module for interface symmetry.
        return network(batch)
|