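# DDPM on MNIST: a Gaussian diffusion trainer (structure follows lucidrains'
# denoising-diffusion-pytorch) wrapped around a ResNet-style U-Net.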
import math
import os
import logging
from random import random
from collections import namedtuple
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast
from torch.optim import Adam
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, utils

from einops import rearrange, reduce
from tqdm.auto import tqdm
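
# constants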

ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
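
# helper functions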

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

def divisible_by(numer, denom):
    return (numer % denom) == 0

def identity(t, *args, **kwargs):
    return t

def cycle(dl):
    while True:
        for data in dl:
            yield data

def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

def convert_image_to_fn(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

def extract(a, t, x_shape):
    # gather per-sample schedule values at timesteps t, reshaped to broadcast
    # over the image dimensions: (b,) -> (b, 1, 1, 1)
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
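
# normalization functions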

def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5
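
# sinusoidal positional embeds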

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """
    following @crowsonkb's lead with random (learned optional) sinusoidal pos emb
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8
    """

    def __init__(self, dim, is_random = False):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered
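
# beta schedules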

def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in original ddpm paper
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """
    sigmoid schedule
    proposed in https://arxiv.org/abs/2212.11972 - figure 8
    better for images > 64x64, when used during training
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
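
# quick sanity check (illustrative sketch only): any of the schedules should
# produce betas in (0, 0.999] whose cumulative product of alphas decays towards
# zero over the chain, e.g.
#
#   betas = cosine_beta_schedule(1000)
#   alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
#   assert betas.min() > 0 and betas.max() <= 0.999
#   assert alphas_cumprod[-1] < 1e-3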
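
# gaussian diffusion trainer class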

class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_noise',
        beta_schedule = 'linear',
        schedule_fn_kwargs = dict(),
        ddim_sampling_eta = 0.,
        auto_normalize = True,
        offset_noise_strength = 0.,
        min_snr_loss_weight = False,
        min_snr_gamma = 5
    ):
        super().__init__()
        assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
        assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond

        self.model = model

        self.channels = self.model.channels
        self.self_condition = self.model.self_condition

        self.image_size = image_size

        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters; asking for fewer sampling timesteps than
        # training timesteps switches on DDIM sampling
        self.sampling_timesteps = default(sampling_timesteps, timesteps)

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register a float64 schedule tensor as a float32 buffer
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_0) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength (0 disables it)
        self.offset_noise_strength = offset_noise_strength

        # loss weight: signal-to-noise ratio, optionally clipped (min-SNR weighting)
        snr = alphas_cumprod / (1 - alphas_cumprod)

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

        # auto-normalization of data: [0, 1] -> [-1, 1]
        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity

    @property
    def device(self):
        return self.betas.device

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
        model_output = self.model(x, t, x_self_cond)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
        preds = self.model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.inference_mode()
    def p_sample(self, x, t: int, x_self_cond = None):
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device = device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
        noise = torch.randn_like(x) if t > 0 else 0.
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.inference_mode()
    def p_sample_loop(self, shape, return_all_timesteps = False):
        batch, device = shape[0], self.device

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, t, self_cond)
            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def ddim_sample(self, shape, return_all_timesteps = False):
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)

            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def sample(self, batch_size = 16, return_all_timesteps = False):
        image_size, channels = self.image_size, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)

    @torch.inference_mode()
    def interpolate(self, x1, x2, t = None, lam = 0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.full((b,), t, device = device)
        xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2

        x_start = None

        for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, i, self_cond)

        return img

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
        b, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noisy sample
        x = self.q_sample(x_start = x_start, t = t, noise = noise)

        # if doing self-conditioning, 50% of the time predict x_start from the
        # current timestep and condition the unet on it (detached, no gradient)
        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t).pred_x_start
                x_self_cond.detach_()

        # predict and take gradient step
        model_out = self.model(x, t, x_self_cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        b, c, h, w, device, img_size = *img.shape, img.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = self.normalize(img)
        return self.p_losses(img, t, *args, **kwargs)
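
# ResNet U-Net building blocks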

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)

class Swish(nn.Module):
    # x * sigmoid(x), equivalent to nn.SiLU
    def forward(self, x):
        return x * torch.sigmoid(x)

class AttnBlock(nn.Module):
    # full self-attention over the H*W spatial positions, with a residual connection
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        return x + h

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        # note: temb_proj is defined but never applied in forward, so this U-Net is
        # time-unconditioned, apparently deliberately given the 'unet_wo_t' save path
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()

    def forward(self, x, temb):
        h = self.block1(x)
        h = self.block2(h)
        h = h + self.shortcut(x)
        h = self.attn(h)
        return h

class DownSample(nn.Module):
    # halve spatial resolution with a stride-2 conv (temb accepted for interface
    # uniformity, ignored)
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)

    def forward(self, x, temb):
        x = self.main(x)
        return x

class UpSample(nn.Module):
    # double spatial resolution with nearest-neighbor interpolation + conv
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)

    def forward(self, x, temb):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x
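
# the U-Net itself: skip connections are taken only at resolution changes
# (features are stashed before each DownSample and concatenated back in after
# the matching UpSample), rather than after every ResBlock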

class Unet(nn.Module):
    def __init__(self,
                 n_feats=128,
                 ch_mul=[1, 2, 4],
                 attention_mul=[1, 2],
                 t_dim=256,
                 dropout=0.1,
                 channels=1,
                 out_dim=1,
                 num_res_blocks=2,
                 self_condition=False,
                 learned_sinusoidal_cond=False,
                 random_fourier_features=False,
                 learned_sinusoidal_dim=16,
                 sinusoidal_pos_emb_theta=10000,
                 conv=default_conv):
        super().__init__()

        self.n_feats = n_feats
        self.t_dim = t_dim
        self.dropout = dropout
        self.channels = channels
        self.out_dim = out_dim
        self.self_condition = self_condition
        self.kernel_size = 3

        # time embedding
        if learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)
            fourier_dim = self.n_feats

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, self.t_dim),
            nn.GELU(),
            nn.Linear(self.t_dim, self.t_dim)
        )

        # head: lift the input image to the base feature width
        self.head = conv(self.channels, self.n_feats, self.kernel_size)

        # down path; channel_list records channel counts for the skip connections
        channel_list = []
        current_channel = n_feats
        self.downblocks = nn.ModuleList()
        for i, mult in enumerate(ch_mul):
            out_channels = n_feats * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(
                    ResBlock(in_ch=current_channel,
                             out_ch=out_channels,
                             tdim=self.t_dim,
                             dropout=self.dropout,
                             attn=(mult in attention_mul)))
                current_channel = out_channels
                channel_list.append(current_channel)
            if i != len(ch_mul) - 1:
                out_channels = n_feats * ch_mul[i + 1]
                self.downblocks.append(DownSample(current_channel, out_channels))
                channel_list.append((current_channel, out_channels))
                current_channel = out_channels

        # bottleneck
        self.middleblocks = nn.ModuleList([
            ResBlock(in_ch=current_channel, out_ch=current_channel, tdim=self.t_dim, dropout=self.dropout, attn=True),
            ResBlock(in_ch=current_channel, out_ch=current_channel, tdim=self.t_dim, dropout=self.dropout, attn=True),
        ])

        # up path: mirrors the down path; each UpSample is followed by a ResBlock
        # that fuses the concatenated skip connection
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mul))):
            out_channels = n_feats * mult
            for _ in range(num_res_blocks):
                self.upblocks.append(
                    ResBlock(in_ch=channel_list.pop(),
                             out_ch=out_channels,
                             tdim=self.t_dim,
                             dropout=self.dropout,
                             attn=(mult in attention_mul)))
            if i != 0:
                curr_ch, out_ch = channel_list.pop()
                self.upblocks.append(UpSample(out_ch, curr_ch))
                self.upblocks.append(ResBlock(in_ch=curr_ch * 2,
                                              out_ch=curr_ch,
                                              tdim=self.t_dim,
                                              dropout=self.dropout,
                                              attn=(mult in attention_mul)))

            current_channel = out_channels
        assert len(channel_list) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, current_channel),
            Swish(),
            nn.Conv2d(current_channel, self.out_dim, 3, stride=1, padding=1)
        )

    def forward(self, x, t, cond=None):
        # cond (the self-conditioning input) is accepted for interface
        # compatibility with GaussianDiffusion but is unused here
        t = self.time_mlp(t)

        x = self.head(x)
        x_list = []

        # down path: stash features right before each resolution change
        for block in self.downblocks:
            if isinstance(block, DownSample):
                x_list.append(x)
            x = block(x, t)

        for block in self.middleblocks:
            x = block(x, t)

        # up path: the block right after an UpSample consumes the matching skip
        up = False
        for block in self.upblocks:
            if up:
                x = torch.concat([x_list.pop(), x], dim=1)
                up = False
            if isinstance(block, UpSample):
                up = True
            x = block(x, t)

        x = self.tail(x)

        return x
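
# training script
# output locations (Colab-style paths)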
save_path = '/content/DDPM_ResNet_Unet/unet_wo_t/model'
log_path = '/content/DDPM_ResNet_Unet/unet_wo_t/log'

# makedirs (rather than mkdir) so missing parent directories are created too
os.makedirs(log_path, exist_ok=True)
os.makedirs(save_path, exist_ok=True)
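
# logging: file handler at DEBUG, console at INFO; PIL's DEBUG chatter is silenced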
logging.basicConfig(
    filename=os.path.join(log_path, 'info.log'),
    filemode="w",
    level=logging.DEBUG,
    format='[%(asctime)s] %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',
    force=True
)

pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.INFO)

console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)

logger = logging.getLogger('Diffusion_Unet')
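
# instantiate the U-Net and wrap it in the diffusion trainer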
model = Unet(
    n_feats=128,
    ch_mul=[1, 2, 4],
    attention_mul=[1, 2],
    t_dim=512,
    dropout=0.1,
    channels=1,
    out_dim=1,
    num_res_blocks=2,
    self_condition=False,
    learned_sinusoidal_cond=False,
    random_fourier_features=False,
    learned_sinusoidal_dim=16,
    sinusoidal_pos_emb_theta=10000,
)

diffusion_model = GaussianDiffusion(
    model,
    image_size=28,
    timesteps=1000,
    sampling_timesteps=None,   # None -> full 1000-step ancestral sampling
    objective='pred_noise',
    beta_schedule='linear',
    schedule_fn_kwargs=dict(),
    ddim_sampling_eta=0.,
    auto_normalize=True,
    offset_noise_strength=0.,
    min_snr_loss_weight=False,
    min_snr_gamma=5)
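
# data: MNIST in [0, 1]; GaussianDiffusion rescales to [-1, 1] internally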
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.MNIST(root='.', train=True,
                                           download=True, transform=transform)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
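
# optimizer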
train_lr = 1e-4
adam_betas = (0.9, 0.99)

optimizer = Adam(diffusion_model.parameters(),
                 lr=train_lr,
                 betas=adam_betas)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
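
# training loop: log every `iter_print` steps, save sample grids every
# `iter_sample` steps, checkpoint every `save_each` epochs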
max_epoches = 50
iter_print = 100
iter_sample = 1000
save_each = 1

diffusion_model = diffusion_model.to(device)

# optionally resume from a previous checkpoint
last_trained_path = None
if last_trained_path:
    data = torch.load(last_trained_path)
    diffusion_model.load_state_dict(data['model'])
    optimizer.load_state_dict(data['opt'])
    count = data['step']
    start_epoch = data['epoch']
    log_loss = data['loss']
else:
    count = 0
    start_epoch = 1
    log_loss = []

for epoch in range(start_epoch, max_epoches + 1):
    diffusion_model.train()
    for img, _ in train_dataloader:
        img = img.to(device)

        loss = diffusion_model(img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if count % iter_print == 0:
            logger.info('Epoch {}/{}, Iter {}: Loss = {}, lr = {}'.format(
                epoch,
                max_epoches,
                count,
                loss.item(),
                train_lr,
            ))
            log_loss.append(loss.item())

        count += 1

        if count % iter_sample == 0:
            diffusion_model.eval()

            sample_imgs = diffusion_model.sample(batch_size=16)

            utils.save_image(sample_imgs,
                             os.path.join(log_path, f"iter_{count}.png"),
                             nrow=int(math.sqrt(16)))

            # switch back to train mode so dropout stays active for the rest of the epoch
            diffusion_model.train()

    if epoch % save_each == 0:
        data = {
            'model': diffusion_model.state_dict(),
            'opt': optimizer.state_dict(),
            'step': count,
            'epoch': epoch,
            'loss': log_loss,
        }

        torch.save(data, os.path.join(save_path, f"epoch_{epoch}.pth"))
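
# sampling from a saved checkpoint (illustrative sketch; assumes the
# final-epoch file exists):
#
#   data = torch.load(os.path.join(save_path, 'epoch_50.pth'))
#   diffusion_model.load_state_dict(data['model'])
#   diffusion_model.eval()
#   imgs = diffusion_model.sample(batch_size=16)   # (16, 1, 28, 28), values in [0, 1]
#   utils.save_image(imgs, os.path.join(log_path, 'final_samples.png'), nrow=4)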