|
|
|
""" |
|
Created on Tue Apr 25 14:45:59 2023 |
|
|
|
@author: pio-r |
|
""" |
|
|
|
import torch |
|
from tqdm import tqdm |
|
import torch.nn as nn |
|
import logging |
|
import numpy as np |
|
|
|
# Configure the root logger once at import time: INFO level and above, each
# record prefixed with a 12-hour wall-clock time (no date) and the level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s: %(message)s",
    level=logging.INFO,
    datefmt="%I:%M:%S",
)
|
|
|
|
|
class Diffusion_cond:
    """Gaussian diffusion process with a linear beta schedule and conditional sampling.

    The constructor precomputes the per-step schedule tensors on ``device``:

    * ``beta``                 -- linearly spaced from ``beta_start`` to ``beta_end``
    * ``alpha``                -- ``1 - beta``
    * ``alpha_hat``            -- cumulative product of ``alpha``
    * ``alphas_prev``          -- ``alpha`` shifted right by one step, padded with 1.0
    * ``alphas_cumprod_prev``  -- ``alpha_hat`` shifted right by one step, padded with 1.0

    Args:
        noise_steps: number of diffusion steps (length of every schedule tensor).
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        img_size: height and width of the square images drawn by :meth:`sample`.
        img_channel: number of channels of the images drawn by :meth:`sample`.
        device: torch device on which schedule tensors and samples are placed.
    """

    # Sampler names accepted by :meth:`sample`.
    _SAMPLING_MODES = ("ddpm", "ddim")

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, img_channel=1, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_channel = img_channel
        self.img_size = img_size
        self.device = device

        one = torch.ones(1, device=device)
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alphas_prev = torch.cat([one, self.alpha[:-1]], dim=0)
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        self.alphas_cumprod_prev = torch.cat([one, self.alpha_hat[:-1]], dim=0)

    def prepare_noise_schedule(self):
        """Return the linear beta schedule as a 1-D tensor of length ``noise_steps``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Apply the forward (noising) process to ``x`` at timesteps ``t``.

        Args:
            x: batch to corrupt. The per-sample coefficients are reshaped to
                ``(len(t), 1, 1, 1)``, so ``x`` must broadcast against that
                (i.e. a 4-D batch with ``len(t)`` samples).
            t: 1-D integer tensor of timestep indices into the schedule.

        Returns:
            Tuple ``(x_t, noise)`` where ``noise`` is standard Gaussian noise
            shaped like ``x`` and
            ``x_t = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise``.
        """
        alpha_hat_t = self.alpha_hat[t]
        sqrt_alpha_hat = torch.sqrt(alpha_hat_t)[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - alpha_hat_t)[:, None, None, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n):
        """Draw ``n`` timesteps uniformly from ``[1, noise_steps)`` (CPU int64 tensor)."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, y, labels, cfg_scale=3, eta=1, sampling_mode='ddpm'):
        """Generate ``n`` images by running the reverse process from pure noise.

        Args:
            model: callable invoked as ``model(x, y, labels, t)`` returning the
                predicted noise for ``x``; must also provide ``eval()`` and
                ``train()``.
            n: number of images to generate.
            y: conditioning input forwarded to ``model`` unchanged.
            labels: label conditioning forwarded to ``model``; the guidance
                pass calls the model with ``None`` in its place.
            cfg_scale: classifier-free guidance weight. When ``> 0`` the model
                is evaluated a second time without labels and the prediction
                becomes ``uncond + cfg_scale * (cond - uncond)``.
            eta: scale of the stochastic term in the ``'ddim'`` update
                (``0`` makes that update deterministic). Unused for ``'ddpm'``.
            sampling_mode: ``'ddpm'`` or ``'ddim'``.

        Returns:
            ``torch.uint8`` tensor with values in ``[0, 255]``, shaped
            ``(n, img_channel, img_size, img_size)`` provided the model returns
            noise predictions shaped like its input.

        Raises:
            ValueError: if ``sampling_mode`` is not a supported sampler. This
                is checked before any work is done, so a typo can no longer
                produce an image of clamped random noise.
        """
        if sampling_mode not in self._SAMPLING_MODES:
            raise ValueError('The sampler {} is not implemented'.format(sampling_mode))

        logging.info(f"Sampling {n} new images....")

        # Remember the caller's mode so that a model handed over in eval mode
        # (e.g. an EMA copy) is not silently flipped to training mode.
        was_training = getattr(model, "training", True)
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn((n, self.img_channel, self.img_size, self.img_size)).to(self.device)

                # Descending range (noise_steps-1 ... 1); unlike reversed(range)
                # it has a length, so tqdm can display the total.
                for i in tqdm(range(self.noise_steps - 1, 0, -1), position=0):
                    t = torch.full((n,), i, dtype=torch.long, device=self.device)

                    predicted_noise = model(x, y, labels, t)
                    if cfg_scale > 0:
                        uncond_predicted_noise = model(x, y, None, t)
                        predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)

                    # Per-sample schedule values, broadcastable over (n, C, H, W).
                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]

                    # No fresh noise is injected on the last DDPM step.
                    if i > 1:
                        noise = torch.randn_like(x)
                    else:
                        noise = torch.zeros_like(x)

                    if sampling_mode == 'ddpm':
                        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
                    else:  # 'ddim' (validated above)
                        alpha_prev = self.alphas_cumprod_prev[t][:, None, None, None]
                        sigma = (
                            eta
                            * torch.sqrt((1 - alpha_prev) / (1 - alpha_hat)
                                         * (1 - alpha_hat / alpha_prev))
                        )
                        pred_x0 = (x - torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha_hat)
                        # NOTE(review): DDIM draws its own noise on every step,
                        # discarding the draw above; both draws are kept so
                        # seeded runs reproduce earlier outputs. The loop stops
                        # at t == 1, so this mask is always 1 here.
                        noise = torch.randn_like(x)
                        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
                        x = (
                            torch.sqrt(alpha_prev) * pred_x0 +
                            torch.sqrt(1 - alpha_prev - sigma ** 2) * predicted_noise +
                            nonzero_mask * sigma * noise
                        )
        finally:
            if was_training:
                model.train()

        # Map [-1, 1] -> [0, 1] -> [0, 255] and truncate to bytes.
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x
|
|
|
# Shared mean-squared-error criterion (default 'mean' reduction) used by psnr().
mse = nn.MSELoss()


def psnr(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
    r"""Compute the peak signal-to-noise ratio (PSNR) between two tensors.

    PSNR is Peak Signal to Noise Ratio, which is similar to mean squared error.
    Given an m x n image, the PSNR is:

    .. math::

        \text{PSNR} = 10 \log_{10} \bigg(\frac{\text{MAX}_I^2}{MSE(I,T)}\bigg)

    where

    .. math::

        \text{MSE}(I,T) = \frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2

    and :math:`\text{MAX}_I` is the maximum possible input value
    (e.g for floating point images :math:`\text{MAX}_I=1`).

    Args:
        input: the input image with arbitrary shape :math:`(*)`.
        target: the reference image, with the same shape as ``input``.
        max_val: The maximum value in the input tensor.

    Return:
        the computed PSNR as a scalar tensor. Identical inputs give a zero
        MSE and therefore ``inf``.

    Raises:
        TypeError: if ``input`` or ``target`` is not a ``torch.Tensor``, or if
            their shapes differ.

    Examples:
        >>> ones = torch.ones(1)
        >>> psnr(ones, 1.2 * ones, 2.) # 10 * log(4/((1.2-1)**2)) / log(10)
        tensor(20.0000)

    Reference:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition
    """
    # Each message reports the type of the argument that actually failed.
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(input)}.")

    if not isinstance(target, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(target)}.")

    # TypeError (rather than ValueError) is kept so existing callers that
    # catch it keep working.
    if input.shape != target.shape:
        raise TypeError(f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}")

    return 10.0 * torch.log10(max_val**2 / mse(input, target))