#!/usr/bin/env python
# coding: utf-8
# # Library
# In[1]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
from einops import rearrange, reduce, repeat
import math
from random import random
from collections import namedtuple
from functools import partial
from tqdm.auto import tqdm
import logging
import os
from PIL import Image
from torchvision import utils
# # Helper
# ### Constant
# In[2]:
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
# ### Functions
# In[3]:
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
# In[4]:
def cast_tuple(t, length = 1):
if isinstance(t, tuple):
return t
return ((t,) * length)
# In[5]:
def divisible_by(numer, denom):
return (numer % denom) == 0
# In[6]:
def identity(t, *args, **kwargs):
return t
# In[7]:
def cycle(dl):
while True:
for data in dl:
yield data
# In[8]:
def has_int_squareroot(num):
return (math.sqrt(num) ** 2) == num
# In[9]:
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
# In[10]:
def convert_image_to_fn(img_type, image):
if image.mode != img_type:
return image.convert(img_type)
return image
# In[11]:
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
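# Shape sketch for `extract` (illustrative, not executed): `a` holds one scalar per
# timestep and `t` is a batch of timestep indices, so the gathered values are
# reshaped to broadcast over the image dimensions, e.g.
#
#   a = torch.linspace(0, 1, 1000)          # (T,)
#   t = torch.tensor([10, 500, 999])        # (b,)
#   extract(a, t, (3, 1, 28, 28)).shape     # -> torch.Size([3, 1, 1, 1])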
# ### Normalization Functions
# In[12]:
def normalize_to_neg_one_to_one(img):
return img * 2 - 1
def unnormalize_to_zero_to_one(t):
return (t + 1) * 0.5
# ### Sinusoidal positional embeds
# In[13]:
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim, theta = 10000):
super().__init__()
self.dim = dim
self.theta = theta
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(self.theta) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
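# The embedding maps a batch of timesteps (b,) to (b, dim): the first half of the
# features are sines and the second half cosines at geometrically spaced frequencies,
# the same construction as transformer positional encodings, e.g.
#
#   emb = SinusoidalPosEmb(dim=128)
#   emb(torch.arange(4)).shape   # -> torch.Size([4, 128])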
# In[14]:
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
""" following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
def __init__(self, dim, is_random = False):
super().__init__()
assert divisible_by(dim, 2)
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
fouriered = torch.cat((x, fouriered), dim = -1)
return fouriered
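# Note: the output width is dim + 1 because the raw timestep x is concatenated onto the
# sin/cos features; the Unet below accounts for this via `fourier_dim = learned_sinusoidal_dim + 1`.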
# ### Schedule
# In[15]:
def linear_beta_schedule(timesteps):
"""
linear schedule, proposed in original ddpm paper
"""
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
# In[16]:
def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
# In[17]:
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
"""
sigmoid schedule
proposed in https://arxiv.org/abs/2212.11972 - Figure 8
    reported to work better for images > 64x64 when used during training
"""
steps = timesteps + 1
t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
v_start = torch.tensor(start / tau).sigmoid()
v_end = torch.tensor(end / tau).sigmoid()
alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
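# All three schedules are consumed the same way (betas -> alphas -> alphas_cumprod),
# so they can be compared directly before training; a quick sketch (illustrative,
# not executed here):
#
#   for fn in (linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule):
#       a_bar = torch.cumprod(1. - fn(1000), dim=0)
#       print(fn.__name__, float(a_bar[-1]))   # remaining signal at the last timestep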
# # Diffusion model
# In[18]:
class GaussianDiffusion(nn.Module):
# Copy from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L163
def __init__(
self,
model,
*,
image_size,
timesteps = 1000,
sampling_timesteps = None,
objective = 'pred_noise',
beta_schedule = 'linear',
schedule_fn_kwargs = dict(),
ddim_sampling_eta = 0.,
auto_normalize = True,
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
min_snr_gamma = 5
):
super().__init__()
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond
self.model = model
self.channels = self.model.channels
self.self_condition = self.model.self_condition
self.image_size = image_size
self.objective = objective
assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
if beta_schedule == 'linear':
beta_schedule_fn = linear_beta_schedule
elif beta_schedule == 'cosine':
beta_schedule_fn = cosine_beta_schedule
elif beta_schedule == 'sigmoid':
beta_schedule_fn = sigmoid_beta_schedule
else:
raise ValueError(f'unknown beta schedule {beta_schedule}')
betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
# sampling related parameters
self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
assert self.sampling_timesteps <= timesteps
self.is_ddim_sampling = self.sampling_timesteps < timesteps
self.ddim_sampling_eta = ddim_sampling_eta
# helper function to register buffer from float64 to float32
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
# offset noise strength - in blogpost, they claimed 0.1 was ideal
self.offset_noise_strength = offset_noise_strength
# derive loss weight
# snr - signal noise ratio
snr = alphas_cumprod / (1 - alphas_cumprod)
# https://arxiv.org/abs/2303.09556
maybe_clipped_snr = snr.clone()
if min_snr_loss_weight:
maybe_clipped_snr.clamp_(max = min_snr_gamma)
if objective == 'pred_noise':
register_buffer('loss_weight', maybe_clipped_snr / snr)
elif objective == 'pred_x0':
register_buffer('loss_weight', maybe_clipped_snr)
elif objective == 'pred_v':
register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))
# auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
@property
def device(self):
return self.betas.device
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def predict_v(self, x_start, t, noise):
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
)
def predict_start_from_v(self, x_t, t, v):
return (
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
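    # Closed-form identities used by the four helpers above (alpha_bar_t = alphas_cumprod[t]):
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    #   x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
    #   eps = (sqrt(1 / alpha_bar_t) * x_t - x_0) / sqrt(1 / alpha_bar_t - 1)
    #   v   = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
    #   x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v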
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
model_output = self.model(x, t, x_self_cond)
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
if self.objective == 'pred_noise':
pred_noise = model_output
x_start = self.predict_start_from_noise(x, t, pred_noise)
x_start = maybe_clip(x_start)
if clip_x_start and rederive_pred_noise:
pred_noise = self.predict_noise_from_start(x, t, x_start)
elif self.objective == 'pred_x0':
x_start = model_output
x_start = maybe_clip(x_start)
pred_noise = self.predict_noise_from_start(x, t, x_start)
elif self.objective == 'pred_v':
v = model_output
x_start = self.predict_start_from_v(x, t, v)
x_start = maybe_clip(x_start)
pred_noise = self.predict_noise_from_start(x, t, x_start)
return ModelPrediction(pred_noise, x_start)
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
preds = self.model_predictions(x, t, x_self_cond)
x_start = preds.pred_x_start
if clip_denoised:
x_start.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.inference_mode()
def p_sample(self, x, t: int, x_self_cond = None):
b, *_, device = *x.shape, self.device
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
return pred_img, x_start
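    # Ancestral sampling step: x_{t-1} = mu_theta(x_t, t) + exp(0.5 * log_var) * z with z ~ N(0, I),
    # and no noise is added at the final step t == 0, matching the standard DDPM sampler.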
@torch.inference_mode()
def p_sample_loop(self, shape, return_all_timesteps = False):
batch, device = shape[0], self.device
img = torch.randn(shape, device = device)
imgs = [img]
x_start = None
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, t, self_cond)
imgs.append(img)
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
ret = self.unnormalize(ret)
return ret
@torch.inference_mode()
def ddim_sample(self, shape, return_all_timesteps = False):
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
img = torch.randn(shape, device = device)
imgs = [img]
x_start = None
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
self_cond = x_start if self.self_condition else None
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
if time_next < 0:
img = x_start
imgs.append(img)
continue
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c = (1 - alpha_next - sigma ** 2).sqrt()
noise = torch.randn_like(img)
img = x_start * alpha_next.sqrt() + \
c * pred_noise + \
sigma * noise
imgs.append(img)
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
ret = self.unnormalize(ret)
return ret
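    # DDIM update used above: with the predicted x_0 and eps,
    #   x_{t_next} = sqrt(alpha_bar_next) * x_0 + sqrt(1 - alpha_bar_next - sigma^2) * eps + sigma * z,
    # where sigma = eta * sqrt((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha));
    # eta = 0 (the default) makes the sampler deterministic.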
@torch.inference_mode()
def sample(self, batch_size = 16, return_all_timesteps = False):
image_size, channels = self.image_size, self.channels
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
@torch.inference_mode()
def interpolate(self, x1, x2, t = None, lam = 0.5):
b, *_, device = *x1.shape, x1.device
t = default(t, self.num_timesteps - 1)
assert x1.shape == x2.shape
t_batched = torch.full((b,), t, device = device)
xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
img = (1 - lam) * xt1 + lam * xt2
x_start = None
for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, i, self_cond)
return img
@autocast(enabled = False)
def q_sample(self, x_start, t, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
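    # Forward diffusion in closed form: q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
    # i.e. q_sample above returns x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.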
def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
# offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
if offset_noise_strength > 0.:
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
# noise sample
x = self.q_sample(x_start = x_start, t = t, noise = noise)
# if doing self-conditioning, 50% of the time, predict x_start from current set of times
# and condition with unet with that
# this technique will slow down training by 25%, but seems to lower FID significantly
x_self_cond = None
if self.self_condition and random() < 0.5:
with torch.no_grad():
x_self_cond = self.model_predictions(x, t).pred_x_start
x_self_cond.detach_()
# predict and take gradient step
model_out = self.model(x, t, x_self_cond)
if self.objective == 'pred_noise':
target = noise
elif self.objective == 'pred_x0':
target = x_start
elif self.objective == 'pred_v':
v = self.predict_v(x_start, t, noise)
target = v
else:
raise ValueError(f'unknown objective {self.objective}')
loss = F.mse_loss(model_out, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')
loss = loss * extract(self.loss_weight, t, loss.shape)
return loss.mean()
def forward(self, img, *args, **kwargs):
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
img = self.normalize(img)
return self.p_losses(img, t, *args, **kwargs)
# # Resnet Model
# In[19]:
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias)
# In[20]:
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
# In[21]:
class AttnBlock(nn.Module):
def __init__(self, in_ch):
super().__init__()
self.group_norm = nn.GroupNorm(32, in_ch)
self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
def forward(self, x):
B, C, H, W = x.shape
h = self.group_norm(x)
q = self.proj_q(h)
k = self.proj_k(h)
v = self.proj_v(h)
q = q.permute(0, 2, 3, 1).view(B, H * W, C)
k = k.view(B, C, H * W)
w = torch.bmm(q, k) * (int(C) ** (-0.5))
assert list(w.shape) == [B, H * W, H * W]
w = F.softmax(w, dim=-1)
v = v.permute(0, 2, 3, 1).view(B, H * W, C)
h = torch.bmm(w, v)
assert list(h.shape) == [B, H * W, C]
h = h.view(B, H, W, C).permute(0, 3, 1, 2)
h = self.proj(h)
return x + h
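# Shape flow of the self-attention above (single head over spatial positions):
#   q, k, v: (B, C, H, W) projected 1x1, then viewed as (B, H*W, C) / (B, C, H*W)
#   w = softmax(q @ k / sqrt(C)): (B, H*W, H*W)
#   h = w @ v: (B, H*W, C), reshaped back to (B, C, H, W) and added to the residual x.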
# In[22]:
class ResBlock(nn.Module):
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
super().__init__()
self.block1 = nn.Sequential(
nn.GroupNorm(32, in_ch),
Swish(),
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
)
self.temb_proj = nn.Sequential(
Swish(),
nn.Linear(tdim, out_ch),
)
self.block2 = nn.Sequential(
nn.GroupNorm(32, out_ch),
Swish(),
nn.Dropout(dropout),
nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
)
if in_ch != out_ch:
self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
else:
self.shortcut = nn.Identity()
if attn:
self.attn = AttnBlock(out_ch)
else:
self.attn = nn.Identity()
def forward(self, x, temb):
h = self.block1(x)
        # h += self.temb_proj(temb)[:, :, None, None]
        # NOTE: the time-embedding injection is intentionally disabled in this
        # "wo_t" (without time-conditioning) variant; `temb` is still accepted
        # so the block keeps the same interface as the conditioned version.
h = self.block2(h)
h = h + self.shortcut(x)
h = self.attn(h)
return h
# In[23]:
class DownSample(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.main = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
def forward(self, x, temb):
x = self.main(x)
return x
# In[24]:
class UpSample(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.main = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
def forward(self, x, temb):
_, _, H, W = x.shape
x = F.interpolate(
x, scale_factor=2, mode='nearest')
x = self.main(x)
return x
# In[26]:
class Unet(nn.Module):
# Modified from https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/edsr.py#L31
def __init__(self,
n_feats=128,
ch_mul=[1, 2, 4],
attention_mul=[1, 2],
t_dim=256,
dropout=0.1,
channels=1,
out_dim=1,
num_res_blocks=2,
self_condition = False,
learned_sinusoidal_cond=False,
random_fourier_features=False,
learned_sinusoidal_dim=16,
sinusoidal_pos_emb_theta=10000,
conv=default_conv):
super(Unet, self).__init__()
self.n_feats = n_feats
self.t_dim = t_dim
self.dropout = dropout
self.channels = channels
self.out_dim = out_dim
self.self_condition = self_condition
self.kernel_size = 3
# define time embedding
if learned_sinusoidal_cond:
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
fourier_dim = learned_sinusoidal_dim + 1
else:
sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)
fourier_dim = self.n_feats
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, self.t_dim),
nn.GELU(),
nn.Linear(self.t_dim, self.t_dim)
)
# define head module
self.head = conv(self.channels, self.n_feats, self.kernel_size)
# define downsample module
channel_list = []
current_channel = n_feats
self.downblocks = nn.ModuleList()
for i, mult in enumerate(ch_mul):
out_channels = n_feats * mult
for _ in range(num_res_blocks):
self.downblocks.append(
ResBlock(in_ch=current_channel,
out_ch=out_channels,
tdim=self.t_dim,
dropout=self.dropout,
attn=(mult in attention_mul)))
current_channel = out_channels
channel_list.append(current_channel)
if i != len(ch_mul) - 1:
out_channels = n_feats * ch_mul[i + 1]
self.downblocks.append(DownSample(current_channel, out_channels))
channel_list.append((current_channel, out_channels))
current_channel = out_channels
# define middle module
self.middleblocks = nn.ModuleList([
ResBlock(in_ch=current_channel, out_ch=current_channel, tdim=self.t_dim, dropout=self.dropout, attn=True),
ResBlock(in_ch=current_channel, out_ch=current_channel, tdim=self.t_dim, dropout=self.dropout, attn=True),
])
# define upsample module
self.upblocks = nn.ModuleList()
for i, mult in reversed(list(enumerate(ch_mul))):
out_channels = n_feats * mult
for _ in range(num_res_blocks):
self.upblocks.append(
ResBlock(in_ch=channel_list.pop(),
out_ch=out_channels,
tdim=self.t_dim,
dropout=self.dropout,
attn=(mult in attention_mul)))
if i != 0:
curr_ch, out_ch = channel_list.pop()
self.upblocks.append(UpSample(out_ch, curr_ch))
self.upblocks.append(ResBlock(in_ch=curr_ch*2,
out_ch=curr_ch,
tdim=self.t_dim,
dropout=self.dropout,
attn=(mult in attention_mul)))
current_channel = out_channels
assert len(channel_list) == 0
# define tail module
self.tail = nn.Sequential(
nn.GroupNorm(32, current_channel),
Swish(),
nn.Conv2d(current_channel, self.out_dim, 3, stride=1, padding=1)
)
def forward(self, x, t, cond=None):
t = self.time_mlp(t)
# Downsample
x = self.head(x)
x_list = []
for block in self.downblocks:
if isinstance(block, DownSample):
x_list.append(x)
x = block(x, t)
# Middle
for block in self.middleblocks:
x = block(x, t)
# Upsample
up = False
for block in self.upblocks:
if up:
x = torch.concat([x_list.pop(), x], dim=1)
up = False
if isinstance(block, UpSample):
up = True
x = block(x, t)
x = self.tail(x)
return x
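# Shape sanity-check sketch for the U-Net (illustrative, not executed here; the small
# hyperparameters are assumptions chosen only to keep the check cheap):
#
#   net = Unet(n_feats=32, ch_mul=[1, 2], attention_mul=[2], channels=1, out_dim=1)
#   x = torch.randn(2, 1, 28, 28)
#   t = torch.randint(0, 1000, (2,))
#   assert net(x, t).shape == x.shape   # the network is resolution- and channel-preserving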
# # Train
# In[24]:
# output dir
save_path = '/content/DDPM_ResNet_Unet/unet_wo_t/model'
log_path = '/content/DDPM_ResNet_Unet/unet_wo_t/log'
os.makedirs(log_path, exist_ok=True)
os.makedirs(save_path, exist_ok=True)
# In[25]:
# setup logging
# Setup logging to file
logging.basicConfig(
filename=os.path.join(log_path, 'info.log'),
filemode="w",
level=logging.DEBUG,
format= '[%(asctime)s] %(levelname)s - %(message)s',
datefmt='%H:%M:%S',
force=True
)
# Stop PIL from printing to file
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.INFO)
# write and print at the same time
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)
logger = logging.getLogger('Diffusion_Unet')
# In[26]:
# define model
model = Unet(
n_feats=128,
ch_mul=[1,2,4],
attention_mul=[1,2],
t_dim=512,
dropout=0.1,
channels=1, # MNIST
out_dim=1, # MNIST
num_res_blocks=2,
self_condition = False,
learned_sinusoidal_cond=False,
random_fourier_features=False,
learned_sinusoidal_dim=16,
sinusoidal_pos_emb_theta=10000,
)
diffusion_model = GaussianDiffusion(
model,
image_size=28, # MNIST
timesteps=1000,
sampling_timesteps=None,
objective ='pred_noise',
beta_schedule ='linear',
schedule_fn_kwargs=dict(),
ddim_sampling_eta= 0.,
auto_normalize = True,
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
min_snr_gamma = 5)
# In[27]:
# define dataset
transform = transforms.Compose([
transforms.ToTensor(),
# v2.Normalize((0.1307,), (0.3081,)), # https://stackoverflow.com/questions/70892017/normalize-mnist-in-pytorch
])
train_dataset = torchvision.datasets.MNIST(root='.', train=True,
download=True, transform=transform)
# test_dataset = torchvision.datasets.MNIST(root='.', train=True,
# download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)
# In[28]:
# define optimizer
train_lr = 1e-4
adam_betas = (0.9, 0.99)
optimizer = Adam(diffusion_model.parameters(),
lr=train_lr,
betas=adam_betas)
# In[29]:
# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# In[30]:
# trainer
max_epoches = 50
iter_print = 100
iter_sample = 1000
save_each = 1
diffusion_model = diffusion_model.to(device)
last_trained_path = None
if last_trained_path:
data = torch.load(os.path.join(last_trained_path))
diffusion_model.load_state_dict(data['model'])
optimizer.load_state_dict(data['opt'])
count = data['step']
start_epoch = data['epoch']
log_loss = data['loss']
else:
count = 0
start_epoch = 1
log_loss = []
for epoch in range(start_epoch, max_epoches+1):
diffusion_model.train()
for img, _ in train_dataloader:
img = img.to(device)
loss = diffusion_model(img)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if count % iter_print == 0 or count == 0:
logger.info('Epoch {}/{}, Iter {}: Loss = {}, lr = {}'.format(
epoch,
max_epoches,
count,
loss.mean().item(),
train_lr,
))
log_loss.append(loss.mean().item())
loss = None
count += 1
        if count % iter_sample == 0:
            diffusion_model.eval()
            sample_imgs = diffusion_model.sample(batch_size=16)
            utils.save_image(sample_imgs,
                             os.path.join(log_path, f"iter_{count}.png"),
                             nrow = int(math.sqrt(16)))
            diffusion_model.train()  # switch back so dropout stays active for the rest of the epoch
if epoch % save_each == 0:
data = {
'model': diffusion_model.state_dict(),
'opt': optimizer.state_dict(),
'step': count,
'epoch': epoch,
'loss': log_loss,
}
torch.save(data, os.path.join(save_path, f"epoch_{epoch}.pth"))
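# A saved checkpoint can later be restored for sampling, e.g. (illustrative, not
# executed here; adjust the epoch number to a file that actually exists):
#
#   data = torch.load(os.path.join(save_path, 'epoch_50.pth'), map_location=device)
#   diffusion_model.load_state_dict(data['model'])
#   diffusion_model.eval()
#   samples = diffusion_model.sample(batch_size=16)
#   utils.save_image(samples, os.path.join(log_path, 'samples_final.png'), nrow=4)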