# (Hugging Face file-page header captured along with the source:)
# mag2mag / diffusion.py
# fpramunno's picture
# Update diffusion.py
# 5d926b5 verified
# raw
# history blame
# 6.22 kB
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 25 14:45:59 2023
@author: pio-r
"""
import torch
from tqdm import tqdm
import torch.nn as nn
import logging
import numpy as np
# Root-logger setup: INFO level, "time - LEVEL: message" lines, 12-hour clock timestamps.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%I:%M:%S",
    format="%(asctime)s - %(levelname)s: %(message)s",
)
class Diffusion_cond:
    """Conditional diffusion helper: linear noise schedule, forward noising and sampling.

    The schedule follows Ho et al. (2020): ``beta`` is linear between
    ``beta_start`` and ``beta_end``, ``alpha = 1 - beta`` and ``alpha_hat`` is
    the cumulative product of ``alpha``. Reverse sampling supports the
    ancestral DDPM update and the DDIM update (Song et al. 2021), both with
    optional classifier-free guidance.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, img_channel=1, device="cuda"):
        """Build the noise schedule on ``device``.

        Args:
            noise_steps: number of diffusion timesteps T.
            beta_start: first value of the linear beta schedule.
            beta_end: last value of the linear beta schedule.
            img_size: height and width of the images produced by :meth:`sample`.
            img_channel: channel count of the images produced by :meth:`sample`.
            device: torch device holding the schedule tensors and the samples.
        """
        self.noise_steps = noise_steps  # number of timesteps T
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_channel = img_channel
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1.0 - self.beta
        one = torch.ones(1, device=device)
        # Per-step alpha shifted right by one (alpha_{t-1}), with 1 at t = 0.
        self.alphas_prev = torch.cat([one, self.alpha[:-1]], dim=0)
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        # Cumulative alpha_hat_{t-1}, with 1 at t = 0; used by the DDIM update.
        self.alphas_cumprod_prev = torch.cat([one, self.alpha_hat[:-1]], dim=0)

    def prepare_noise_schedule(self):
        """Return the linear beta schedule of Ho et al. (2020), ``noise_steps`` values long."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Sample ``x_t ~ q(x_t | x_0)`` in closed form (Ho et al. 2020, eq. 4).

        Args:
            x: clean images; the ``[:, None, None, None]`` broadcasting below
                assumes a 4-D ``(batch, C, H, W)`` tensor — TODO confirm with callers.
            t: 1-D integer tensor of timesteps, one per batch element.

        Returns:
            ``(x_t, eps)``: the noised images and the Gaussian noise that was added.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        eps = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        """Draw ``n`` training timesteps uniformly from ``[1, noise_steps - 1]``.

        Timestep 0 is never drawn, matching the reverse loop in :meth:`sample`,
        which also stops at 1.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, y, labels, cfg_scale=3, eta=1, sampling_mode='ddpm'):
        """Generate ``n`` images by running the reverse diffusion process.

        Args:
            model: noise predictor called as ``model(x, y, labels, t)``;
                ``labels=None`` is used for the unconditional guidance pass.
            n: number of images to generate.
            y: conditioning input forwarded unchanged to ``model``.
            labels: conditioning labels forwarded to ``model``.
            cfg_scale: classifier-free guidance weight; the prediction is
                ``lerp(uncond, cond, cfg_scale)``. ``<= 0`` disables guidance.
            eta: DDIM stochasticity (0 = deterministic); ignored for ``'ddpm'``.
            sampling_mode: ``'ddpm'`` or ``'ddim'``.

        Returns:
            ``uint8`` tensor of shape ``(n, img_channel, img_size, img_size)``
            with values in ``[0, 255]``.

        Raises:
            ValueError: if ``sampling_mode`` is not a supported sampler.
        """
        # Validate before touching the model, instead of silently returning noise.
        if sampling_mode not in ("ddpm", "ddim"):
            raise ValueError(f"The sampler {sampling_mode} is not implemented")
        logging.info(f"Sampling {n} new images....")
        model.eval()
        try:
            with torch.no_grad():  # Algorithm 2 of Ho et al. (2020)
                x = torch.randn((n, self.img_channel, self.img_size, self.img_size), device=self.device)
                steps = reversed(range(1, self.noise_steps))  # T-1 down to 1
                for i in tqdm(steps, position=0, total=max(self.noise_steps - 1, 0)):
                    t = torch.full((n,), i, dtype=torch.long, device=self.device)
                    predicted_noise = model(x, y, labels, t)
                    if cfg_scale > 0:
                        uncond_predicted_noise = model(x, y, None, t)
                        predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                    last = i == 1  # the final step must not inject fresh noise
                    if sampling_mode == "ddpm":
                        x = self._ddpm_step(x, t, predicted_noise, last)
                    else:
                        x = self._ddim_step(x, t, predicted_noise, eta, last)
        finally:
            # Always return the model to training mode, even if sampling raised.
            model.train()
        x = (x.clamp(-1, 1) + 1) / 2  # clip to [-1, 1], then map to [0, 1]
        x = (x * 255).type(torch.uint8)  # map to the 8-bit pixel range
        return x

    def _ddpm_step(self, x, t, predicted_noise, last):
        """Apply one ancestral DDPM update x_t -> x_{t-1}; noise-free when ``last``."""
        alpha = self.alpha[t][:, None, None, None]
        alpha_hat = self.alpha_hat[t][:, None, None, None]
        beta = self.beta[t][:, None, None, None]
        noise = torch.zeros_like(x) if last else torch.randn_like(x)
        mean = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise)
        return mean + torch.sqrt(beta) * noise

    def _ddim_step(self, x, t, predicted_noise, eta, last):
        """Apply one DDIM update x_t -> x_{t-1} (Song et al. 2021, eq. 12); noise-free when ``last``."""
        alpha_hat = self.alpha_hat[t][:, None, None, None]
        alpha_hat_prev = self.alphas_cumprod_prev[t][:, None, None, None]
        sigma = eta * torch.sqrt(
            (1 - alpha_hat_prev) / (1 - alpha_hat) * (1 - alpha_hat / alpha_hat_prev)
        )
        pred_x0 = (x - torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha_hat)
        # Clamp: float rounding (or eta > 1) can push this slightly negative -> NaN.
        direction = torch.sqrt((1 - alpha_hat_prev - sigma ** 2).clamp(min=0)) * predicted_noise
        x_prev = torch.sqrt(alpha_hat_prev) * pred_x0 + direction
        if not last:
            x_prev = x_prev + sigma * torch.randn_like(x)
        return x_prev
# Shared mean-squared-error criterion: averages the squared differences over all elements.
mse = nn.MSELoss(reduction="mean")
def psnr(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
    r"""Compute the PSNR (Peak Signal-to-Noise Ratio) between two images.

    PSNR is a log-scaled transform of the mean squared error; higher is better.
    Given an m x n image, the PSNR is:

    .. math::

        \text{PSNR} = 10 \log_{10} \bigg(\frac{\text{MAX}_I^2}{MSE(I,T)}\bigg)

    where

    .. math::

        \text{MSE}(I,T) = \frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2

    and :math:`\text{MAX}_I` is the maximum possible input value
    (e.g. for floating point images :math:`\text{MAX}_I=1`).

    Args:
        input: the input image with arbitrary shape :math:`(*)`.
        target: the target image, with the same shape as ``input``.
        max_val: the maximum possible value of the input.

    Returns:
        The PSNR in decibels as a 0-dim tensor (``inf`` when the images are identical).

    Raises:
        TypeError: if either argument is not a ``torch.Tensor`` or their shapes differ.

    Examples:
        >>> ones = torch.ones(1)
        >>> psnr(ones, 1.2 * ones, 2.)  # 10 * log(4/((1.2-1)**2)) / log(10)
        tensor(20.0000)

    Reference:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(input)}.")
    if not isinstance(target, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(target)}.")
    if input.shape != target.shape:
        # TypeError (rather than ValueError) is kept for backward compatibility.
        raise TypeError(f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}")
    # Same computation as nn.MSELoss() with the default mean reduction.
    mse_value = nn.functional.mse_loss(input, target)
    return 10.0 * torch.log10(max_val ** 2 / mse_value)