# -*- coding: utf-8 -*-
"""
Created on Tue Apr 25 14:45:59 2023
@author: pio-r
"""
import torch
from tqdm import tqdm
import torch.nn as nn
import logging
import numpy as np
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")
class Diffusion_cond:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, img_channel=1, device="cuda"):
        self.noise_steps = noise_steps  # number of diffusion timesteps T
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_channel = img_channel
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alphas_prev = torch.cat([torch.tensor([1.0]).to(device), self.alpha[:-1]], dim=0)
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(device), self.alpha_hat[:-1]], dim=0)
        # self.alphas_cumprod_prev = torch.from_numpy(np.append(1, self.alpha_hat[:-1].cpu().numpy())).to(device)

    def prepare_noise_schedule(self):
        # linear variance schedule as proposed by Ho et al. 2020
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        # forward (noising) process q(x_t | x_0) from Ho et al. 2020:
        # x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, y, labels, cfg_scale=3, eta=1, sampling_mode='ddpm'):
        # cfg_scale is only used by the commented-out classifier-free guidance block below
        logging.info(f"Sampling {n} new images....")
        model.eval()  # evaluation mode
        with torch.no_grad():  # Algorithm 2 from DDPM
            x = torch.randn((n, self.img_channel, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):  # reverse loop from T-1 down to 1
                t = (torch.ones(n) * i).long().to(self.device)  # timestep tensor of length n
                predicted_noise = model(x, y, labels, t)
                # if cfg_scale > 0:
                #     uncond_predicted_noise = model(x, y, None, t)
                #     predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]  # cumulative product of alphas up to step t
                alpha_prev = self.alphas_cumprod_prev[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # DDIM noise scale, adapted from Stable Diffusion
                sigma = (
                    eta
                    * torch.sqrt((1 - alpha_prev) / (1 - alpha_hat)
                                 * (1 - alpha_hat / alpha_prev))
                )
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                # pred_x0 = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise)
                pred_x0 = (x - torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha_hat)
                if sampling_mode == 'ddpm':
                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
                elif sampling_mode == 'ddim':
                    noise = torch.randn_like(x)
                    nonzero_mask = (
                        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
                    )
                    x = (
                        torch.sqrt(alpha_prev) * pred_x0 +
                        torch.sqrt(1 - alpha_prev - sigma ** 2) * predicted_noise +
                        nonzero_mask * sigma * noise
                    )
                else:
                    print('The sampler {} is not implemented'.format(sampling_mode))
                    break
        model.train()  # back to training mode
        x = (x.clamp(-1, 1) + 1) / 2  # clamp to [-1, 1], then shift and rescale to [0, 1]
        x = (x * 255).type(torch.uint8)  # map into the valid 8-bit pixel range
        return x
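
# Illustrative usage sketch (not from the original repo): "DummyModel" is a
# hypothetical stand-in for the conditional noise-prediction network, which this
# class assumes has the signature model(x, y, labels, t). Shapes for y and
# labels are assumptions for the sake of a self-contained example.
def _example_usage():
    diffusion = Diffusion_cond(noise_steps=1000, img_size=64, img_channel=1, device="cpu")

    class DummyModel(nn.Module):
        def forward(self, x, y, labels, t):
            return torch.randn_like(x)  # stands in for the predicted noise

    model = DummyModel()
    y = torch.zeros(2, 1, 64, 64)              # conditioning image (assumed shape)
    labels = torch.zeros(2, dtype=torch.long)  # class labels (assumed)

    # Training side: draw random timesteps, noise the images, regress the noise.
    imgs = torch.rand(2, 1, 64, 64) * 2 - 1    # images scaled to [-1, 1]
    t = diffusion.sample_timesteps(imgs.shape[0])
    x_t, eps = diffusion.noise_images(imgs, t)
    loss = nn.MSELoss()(model(x_t, y, labels, t), eps)

    # Inference side: Algorithm 2 (DDPM) or DDIM sampling of n images.
    samples = diffusion.sample(model, n=2, y=y, labels=labels, sampling_mode='ddim')
    return loss, samples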
mse = nn.MSELoss()
def psnr(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
    r"""Calculate the PSNR between two images.

    PSNR is the Peak Signal-to-Noise Ratio, a metric derived from the mean squared error.
    Given an m x n image, the PSNR is:

    .. math::
        \text{PSNR} = 10 \log_{10} \bigg(\frac{\text{MAX}_I^2}{\text{MSE}(I,T)}\bigg)

    where

    .. math::
        \text{MSE}(I,T) = \frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2

    and :math:`\text{MAX}_I` is the maximum possible input value
    (e.g. for floating point images :math:`\text{MAX}_I=1`).

    Args:
        input: the input image with arbitrary shape :math:`(*)`.
        target: the target image with arbitrary shape :math:`(*)`.
        max_val: the maximum value in the input tensor.

    Return:
        the computed loss as a scalar.

    Examples:
        >>> ones = torch.ones(1)
        >>> psnr(ones, 1.2 * ones, 2.)  # 10 * log(4 / ((1.2 - 1) ** 2)) / log(10)
        tensor(20.0000)

    Reference:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(input)}.")
    if not isinstance(target, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(target)}.")
    if input.shape != target.shape:
        raise TypeError(f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}")
    return 10.0 * torch.log10(max_val**2 / mse(input, target))
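
# Minimal sanity check for psnr (illustrative only). The first call mirrors the
# doctest above: MSE = 0.04, so PSNR = 10 * log10(4 / 0.04) = 20 dB.
if __name__ == "__main__":
    ones = torch.ones(1)
    print(psnr(ones, 1.2 * ones, max_val=2.0))  # tensor(20.0000)

    img = torch.rand(1, 1, 8, 8)
    noisy = img + 0.05 * torch.randn_like(img)
    print(psnr(img, noisy, max_val=1.0))  # larger values mean closer images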