#!/usr/bin/env python
# coding: utf-8

# # Library

# In[1]:

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam

from einops import rearrange, reduce, repeat

import math
from random import random
from collections import namedtuple
from functools import partial
from tqdm.auto import tqdm
import logging
import os
from PIL import Image
from torchvision import utils

# # Helper

# ### Constant

# In[2]:

ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])

# ### Functions

# In[3]:

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# In[4]:

def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

# In[5]:

def divisible_by(numer, denom):
    return (numer % denom) == 0

# In[6]:

def identity(t, *args, **kwargs):
    return t

# In[7]:

def cycle(dl):
    while True:
        for data in dl:
            yield data

# In[8]:

def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num

# In[9]:

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# In[10]:

def convert_image_to_fn(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# In[11]:

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

# ### Normalization Functions

# In[12]:

def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

# ### Sinusoidal positional embeds

# In[13]:

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

# In[14]:

class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """
    following @crowsonkb's lead with random (learned optional) sinusoidal pos emb
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8
    """

    def __init__(self, dim, is_random = False):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

# ### Schedule

# In[15]:

def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in the original DDPM paper
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

# In[16]:

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

# In[17]:

def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """
    sigmoid schedule, proposed in https://arxiv.org/abs/2212.11972 - Figure 8
    better for images > 64x64, when used during training
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
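
# Quick comparison of the three schedules above (not part of the original notebook): each
# schedule is characterised by the cumulative signal level alpha_bar_t it induces. The
# `timesteps = 1000` value here is just an illustrative choice matching the model below.

for _name, _schedule_fn in [('linear', linear_beta_schedule),
                            ('cosine', cosine_beta_schedule),
                            ('sigmoid', sigmoid_beta_schedule)]:
    _betas = _schedule_fn(1000)
    _alphas_cumprod = torch.cumprod(1. - _betas, dim=0)
    print(f'{_name:8s} alpha_bar at t=0/500/999: '
          f'{_alphas_cumprod[0]:.3e} / {_alphas_cumprod[500]:.3e} / {_alphas_cumprod[999]:.3e}')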

# # Diffusion model

# In[18]:

class GaussianDiffusion(nn.Module):
    # Copy from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L163
    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_noise',
        beta_schedule = 'linear',
        schedule_fn_kwargs = dict(),
        ddim_sampling_eta = 0.,
        auto_normalize = True,
        offset_noise_strength = 0.,  # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        super().__init__()

        assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
        assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond

        self.model = model

        self.channels = self.model.channels
        self.self_condition = self.model.self_condition

        self.image_size = image_size

        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

        # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False

        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity

    @property
    def device(self):
        return self.betas.device

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
        model_output = self.model(x, t, x_self_cond)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
        preds = self.model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.inference_mode()
    def p_sample(self, x, t: int, x_self_cond = None):
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device = device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.inference_mode()
    def p_sample_loop(self, shape, return_all_timesteps = False):
        batch, device = shape[0], self.device

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, t, self_cond)
            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def ddim_sample(self, shape, return_all_timesteps = False):
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)

            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def sample(self, batch_size = 16, return_all_timesteps = False):
        image_size, channels = self.image_size, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)

    @torch.inference_mode()
    def interpolate(self, x1, x2, t = None, lam = 0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.full((b,), t, device = device)
        xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2

        x_start = None

        for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, i, self_cond)

        return img

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
        b, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise

        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noise sample

        x = self.q_sample(x_start = x_start, t = t, noise = noise)

        # if doing self-conditioning, 50% of the time, predict x_start from current set of times
        # and condition with unet with that
        # this technique will slow down training by 25%, but seems to lower FID significantly

        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t).pred_x_start
                x_self_cond.detach_()

        # predict and take gradient step

        model_out = self.model(x, t, x_self_cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        b, c, h, w, device, img_size = *img.shape, img.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = self.normalize(img)
        return self.p_losses(img, t, *args, **kwargs)

# # Resnet Model

# In[19]:

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

# In[20]:

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

# In[21]:

class AttnBlock(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        return x + h

# In[22]:

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()

    def forward(self, x, temb):
        h = self.block1(x)
        # timestep-embedding injection left disabled in this variant (cf. the 'unet_wo_t' checkpoint path below)
        # h += self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h

# In[23]:

class DownSample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)

    def forward(self, x, temb):
        x = self.main(x)
        return x

# In[24]:

class UpSample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)

    def forward(self, x, temb):
        _, _, H, W = x.shape
        x = F.interpolate(
            x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x

# In[26]:

class Unet(nn.Module):
    # Modified from https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/edsr.py#L31
    def __init__(self,
                 n_feats=128,
                 ch_mul=[1, 2, 4],
                 attention_mul=[1, 2],
                 t_dim=256,
                 dropout=0.1,
                 channels=1,
                 out_dim=1,
                 num_res_blocks=2,
                 self_condition = False,
                 learned_sinusoidal_cond=False,
                 random_fourier_features=False,
                 learned_sinusoidal_dim=16,
                 sinusoidal_pos_emb_theta=10000,
                 conv=default_conv):
        super(Unet, self).__init__()

        self.n_feats = n_feats
        self.t_dim = t_dim
        self.dropout = dropout
        self.channels = channels
        self.out_dim = out_dim
        self.self_condition = self_condition
        self.kernel_size = 3

        # define time embedding
        if learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)
            fourier_dim = self.n_feats

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, self.t_dim),
            nn.GELU(),
            nn.Linear(self.t_dim, self.t_dim)
        )

        # define head module
        self.head = conv(self.channels, self.n_feats, self.kernel_size)

        # define downsample module
        channel_list = []
        current_channel = n_feats
        self.downblocks = nn.ModuleList()
        for i, mult in enumerate(ch_mul):
            out_channels = n_feats * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(
                    ResBlock(in_ch=current_channel, out_ch=out_channels,
                             tdim=self.t_dim, dropout=self.dropout,
                             attn=(mult in attention_mul)))
                current_channel = out_channels
                channel_list.append(current_channel)
            if i != len(ch_mul) - 1:
                out_channels = n_feats * ch_mul[i + 1]
                self.downblocks.append(DownSample(current_channel, out_channels))
                channel_list.append((current_channel, out_channels))
                current_channel = out_channels

        # define middle module
        self.middleblocks = nn.ModuleList([
            ResBlock(in_ch=current_channel, out_ch=current_channel,
                     tdim=self.t_dim, dropout=self.dropout, attn=True),
            ResBlock(in_ch=current_channel, out_ch=current_channel,
                     tdim=self.t_dim, dropout=self.dropout, attn=True),
        ])

        # define upsample module
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mul))):
            out_channels = n_feats * mult
            for _ in range(num_res_blocks):
                self.upblocks.append(
                    ResBlock(in_ch=channel_list.pop(), out_ch=out_channels,
                             tdim=self.t_dim, dropout=self.dropout,
                             attn=(mult in attention_mul)))
            if i != 0:
                curr_ch, out_ch = channel_list.pop()
                self.upblocks.append(UpSample(out_ch, curr_ch))
                self.upblocks.append(ResBlock(in_ch=curr_ch*2, out_ch=curr_ch,
                                              tdim=self.t_dim, dropout=self.dropout,
                                              attn=(mult in attention_mul)))
            current_channel = out_channels
        assert len(channel_list) == 0

        # define tail module
        self.tail = nn.Sequential(
            nn.GroupNorm(32, current_channel),
            Swish(),
            nn.Conv2d(current_channel, self.out_dim, 3, stride=1, padding=1)
        )

    def forward(self, x, t, cond=None):
        t = self.time_mlp(t)

        # Downsample
        x = self.head(x)
        x_list = []
        for block in self.downblocks:
            if isinstance(block, DownSample):
                x_list.append(x)
            x = block(x, t)

        # Middle
        for block in self.middleblocks:
            x = block(x, t)

        # Upsample
        up = False
        for block in self.upblocks:
            if up:
                x = torch.concat([x_list.pop(), x], dim=1)
                up = False
            if isinstance(block, UpSample):
                up = True
            x = block(x, t)

        x = self.tail(x)
        return x
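
# Quick shape check (not in the original notebook): with an MNIST-style configuration like the
# one used in the training section below, the Unet should map a (B, 1, 28, 28) batch back to the
# same shape. The reduced widths (n_feats=32, t_dim=64, num_res_blocks=1) are illustrative
# choices to keep this check cheap.

_check_model = Unet(n_feats=32, ch_mul=[1, 2, 4], attention_mul=[1, 2],
                    t_dim=64, dropout=0.1, channels=1, out_dim=1, num_res_blocks=1)
_check_x = torch.randn(2, 1, 28, 28)
_check_t = torch.randint(0, 1000, (2,))
assert _check_model(_check_x, _check_t).shape == _check_x.shape
del _check_model, _check_x, _check_t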

# # Train

# In[24]:


# In[25]:


# In[26]:

# define model
model = Unet(
    n_feats=128,
    ch_mul=[1, 2, 4],
    attention_mul=[1, 2],
    t_dim=512,
    dropout=0.1,
    channels=1,  # MNIST
    out_dim=1,   # MNIST
    num_res_blocks=2,
    self_condition = False,
    learned_sinusoidal_cond=False,
    random_fourier_features=False,
    learned_sinusoidal_dim=16,
    sinusoidal_pos_emb_theta=10000,
)

diffusion_model = GaussianDiffusion(
    model,
    image_size=28,  # MNIST
    timesteps=1000,
    sampling_timesteps=None,
    objective = 'pred_noise',
    beta_schedule = 'linear',
    schedule_fn_kwargs=dict(),
    ddim_sampling_eta = 0.,
    auto_normalize = True,
    offset_noise_strength = 0.,  # https://www.crosslabs.org/blog/diffusion-with-offset-noise
    min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
    min_snr_gamma = 5)

# In[27]:


# In[28]:


# In[29]:

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# In[30]:

# trainer
max_epoches = 50
iter_print = 100
iter_sample = 1000
save_each = 1

diffusion_model = diffusion_model.to(device)
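
# The training cells of this notebook appear to have been cleared (empty cells above), which is
# why DataLoader / Adam / transforms are imported but never used. The block below is only a
# minimal, hedged sketch of what a training loop could look like with the hyperparameters defined
# above -- the MNIST dataset root, batch size, learning rate, and checkpoint naming are
# assumptions, not the original training code. It is disabled by default so the script still only
# loads a checkpoint and samples.

RUN_TRAINING_SKETCH = False

if RUN_TRAINING_SKETCH:
    train_set = torchvision.datasets.MNIST(
        root='./data', train=True, download=True,
        transform=transforms.ToTensor())  # ToTensor gives [0, 1]; GaussianDiffusion normalizes to [-1, 1]
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

    optimizer = Adam(diffusion_model.parameters(), lr=1e-4)

    step = 0
    for epoch in range(max_epoches):
        for imgs, _ in train_loader:
            imgs = imgs.to(device)
            loss = diffusion_model(imgs)  # forward() draws random t and returns the training loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % iter_print == 0:
                print(f'epoch {epoch} step {step} loss {loss.item():.4f}')
            step += 1

        if (epoch + 1) % save_each == 0:
            torch.save({'model': diffusion_model.state_dict()}, f'epoch_{epoch + 1}.pth')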

last_trained_path = '/content/DDPM_ResNet_Unet/unet_wo_t/model/epoch_30.pth'
diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path), map_location=device)['model'])

sample_path = '/content/DDPM_ResNet_Unet/unet_wo_t/sample'
if not os.path.exists(sample_path):
    os.mkdir(sample_path)

num_sample = 10000
sample_batch = 16
count = 0

# round num_sample up to a multiple of sample_batch
if num_sample % sample_batch != 0:
    num_sample = num_sample + (sample_batch - (num_sample % sample_batch))

for batch in range(num_sample // sample_batch):
    imgs = diffusion_model.sample(batch_size=sample_batch, return_all_timesteps=False)
    for i in range(imgs.size(0)):
        torchvision.utils.save_image(imgs[i, :, :, :], os.path.join(sample_path, f'{count}.png'))
        count += 1
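
# Optional (not in the original notebook): also save the last sampled batch as a single grid for
# quick visual inspection; `nrow=4` is an arbitrary choice. This reuses the otherwise unused
# `from torchvision import utils` import.
utils.save_image(imgs, os.path.join(sample_path, 'last_batch_grid.png'), nrow=4)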