{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "NDyknNLUIPKP" }, "source": [ "# Library" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "xtzpYsKzfUqx" }, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.cuda.amp import autocast\n", "\n", "import torchvision\n", "from torchvision.transforms import transforms\n", "from torch.utils.data import DataLoader\n", "\n", "from torch.optim import Adam\n", "\n", "from einops import rearrange, reduce, repeat\n", "import math\n", "from random import random\n", "\n", "from collections import namedtuple\n", "from functools import partial\n", "from tqdm.auto import tqdm\n", "import logging\n", "import os\n", "\n", "from PIL import Image\n", "from torchvision import utils" ] }, { "cell_type": "markdown", "metadata": { "id": "9Di63CM_dLK0" }, "source": [ "# Helper" ] }, { "cell_type": "markdown", "metadata": { "id": "zxVxXvMReAx7" }, "source": [ "### Constant" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "5E-cF3DLdMH6" }, "outputs": [], "source": [ "ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])" ] }, { "cell_type": "markdown", "metadata": { "id": "hdXwRoR7eC1V" }, "source": [ "### Functions" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "y2X3CKF0dQTw" }, "outputs": [], "source": [ "def exists(x):\n", " return x is not None\n", "\n", "def default(val, d):\n", " if exists(val):\n", " return val\n", " return d() if callable(d) else d" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "4NHXsFv6dU46" }, "outputs": [], "source": [ "def cast_tuple(t, length = 1):\n", " if isinstance(t, tuple):\n", " return t\n", " return ((t,) * length)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "r0P0wIhlddrp" }, "outputs": [], "source": [ "def divisible_by(numer, denom):\n", " return (numer % denom) == 0" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "3bAT9MpEdft4" }, "outputs": [], "source": [ "def identity(t, *args, **kwargs):\n", " return t" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "SEjbKEwzdmMS" }, "outputs": [], "source": [ "def cycle(dl):\n", " while True:\n", " for data in dl:\n", " yield data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "hy9C_wBJdnjd" }, "outputs": [], "source": [ "def has_int_squareroot(num):\n", " return (math.sqrt(num) ** 2) == num" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "r_tLcSg_dvSg" }, "outputs": [], "source": [ "def num_to_groups(num, divisor):\n", " groups = num // divisor\n", " remainder = num % divisor\n", " arr = [divisor] * groups\n", " if remainder > 0:\n", " arr.append(remainder)\n", " return arr" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "PyqfVkPDd9AP" }, "outputs": [], "source": [ "def convert_image_to_fn(img_type, image):\n", " if image.mode != img_type:\n", " return image.convert(img_type)\n", " return image" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "c4GkI1-Jr26q" }, "outputs": [], "source": [ "def extract(a, t, x_shape):\n", " b, *_ = t.shape\n", " out = a.gather(-1, t)\n", " return out.reshape(b, *((1,) * (len(x_shape) - 1)))" ] }, { "cell_type": "markdown", "metadata": { "id": "kiTIA-VqeG7_" }, "source": [ "### Normalization Functions" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "fBnwqtaKeNAA" }, "outputs": [], "source": [ "def normalize_to_neg_one_to_one(img):\n", " return img * 2 - 1\n", "\n", "def unnormalize_to_zero_to_one(t):\n", " return (t + 1) * 0.5" ] }, { "cell_type": "markdown", "metadata": { "id": "xdr3M-Kbe_fR" }, "source": [ "### Sinusoidal positional embeds" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "_2aXPqWgfCQ-" }, "outputs": [], "source": [ "class SinusoidalPosEmb(nn.Module):\n", " def __init__(self, dim, theta = 10000):\n", " super().__init__()\n", " self.dim = dim\n", " self.theta = theta\n", "\n", " def forward(self, x):\n", " device = x.device\n", " half_dim = self.dim // 2\n", " emb = math.log(self.theta) / (half_dim - 1)\n", " emb = torch.exp(torch.arange(half_dim, device=device) * -emb)\n", " emb = x[:, None] * emb[None, :]\n", " emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n", " return emb" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "uOz-MCVlfEVs" }, "outputs": [], "source": [ "class RandomOrLearnedSinusoidalPosEmb(nn.Module):\n", " \"\"\" following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb \"\"\"\n", " \"\"\" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 \"\"\"\n", "\n", " def __init__(self, dim, is_random = False):\n", " super().__init__()\n", " assert divisible_by(dim, 2)\n", " half_dim = dim // 2\n", " self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)\n", "\n", " def forward(self, x):\n", " x = rearrange(x, 'b -> b 1')\n", " freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi\n", " fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)\n", " fouriered = torch.cat((x, fouriered), dim = -1)\n", " return fouriered" ] }, { "cell_type": "markdown", "metadata": { "id": "aCIo4WI8rqpj" }, "source": [ "### Schedule" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "SZvqJKr9rsH-" }, "outputs": [], "source": [ "def linear_beta_schedule(timesteps):\n", " \"\"\"\n", " linear schedule, proposed in original ddpm paper\n", " \"\"\"\n", " scale = 1000 / timesteps\n", " beta_start = scale * 0.0001\n", " beta_end = scale * 0.02\n", " return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "HtldF7joruUZ" }, "outputs": [], "source": [ "def cosine_beta_schedule(timesteps, s = 0.008):\n", " \"\"\"\n", " cosine schedule\n", " as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n", " \"\"\"\n", " steps = timesteps + 1\n", " t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps\n", " alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2\n", " alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n", " betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n", " return torch.clip(betas, 0, 0.999)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "56UEThPwrvrj" }, "outputs": [], "source": [ "def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):\n", " \"\"\"\n", " sigmoid schedule\n", " proposed in https://arxiv.org/abs/2212.11972 - Figure 8\n", " better for images > 64x64, when used during training\n", " \"\"\"\n", " steps = timesteps + 1\n", " t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps\n", " v_start = torch.tensor(start / tau).sigmoid()\n", " v_end = torch.tensor(end / tau).sigmoid()\n", " alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)\n", " alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n", " betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n", " return torch.clip(betas, 0, 0.999)" ] }, { "cell_type": "markdown", "metadata": { "id": "JuH_4NvddIt_" }, "source": [ "# Diffusion model" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "HUpKRpzpiQbU" }, "outputs": [], "source": [ "class GaussianDiffusion(nn.Module):\n", " # Copy from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L163\n", "\n", " def __init__(\n", " self,\n", " model,\n", " *,\n", " image_size,\n", " timesteps = 1000,\n", " sampling_timesteps = None,\n", " objective = 'pred_noise',\n", " beta_schedule = 'linear',\n", " schedule_fn_kwargs = dict(),\n", " ddim_sampling_eta = 0.,\n", " auto_normalize = True,\n", " offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise\n", " min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556\n", " min_snr_gamma = 5\n", " ):\n", " super().__init__()\n", " assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)\n", " assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond\n", "\n", " self.model = model\n", "\n", " self.channels = self.model.channels\n", " self.self_condition = self.model.self_condition\n", "\n", " self.image_size = image_size\n", "\n", " self.objective = objective\n", "\n", " assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'\n", "\n", " if beta_schedule == 'linear':\n", " beta_schedule_fn = linear_beta_schedule\n", " elif beta_schedule == 'cosine':\n", " beta_schedule_fn = cosine_beta_schedule\n", " elif beta_schedule == 'sigmoid':\n", " beta_schedule_fn = sigmoid_beta_schedule\n", " else:\n", " raise ValueError(f'unknown beta schedule {beta_schedule}')\n", "\n", " betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)\n", "\n", " alphas = 1. - betas\n", " alphas_cumprod = torch.cumprod(alphas, dim=0)\n", " alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)\n", "\n", " timesteps, = betas.shape\n", " self.num_timesteps = int(timesteps)\n", "\n", " # sampling related parameters\n", "\n", " self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training\n", "\n", " assert self.sampling_timesteps <= timesteps\n", " self.is_ddim_sampling = self.sampling_timesteps < timesteps\n", " self.ddim_sampling_eta = ddim_sampling_eta\n", "\n", " # helper function to register buffer from float64 to float32\n", "\n", " register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))\n", "\n", " register_buffer('betas', betas)\n", " register_buffer('alphas_cumprod', alphas_cumprod)\n", " register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)\n", "\n", " # calculations for diffusion q(x_t | x_{t-1}) and others\n", "\n", " register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))\n", " register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))\n", " register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))\n", " register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))\n", " register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))\n", "\n", " # calculations for posterior q(x_{t-1} | x_t, x_0)\n", "\n", " posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n", "\n", " # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n", "\n", " register_buffer('posterior_variance', posterior_variance)\n", "\n", " # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n", "\n", " register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))\n", " register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))\n", " register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))\n", "\n", " # offset noise strength - in blogpost, they claimed 0.1 was ideal\n", "\n", " self.offset_noise_strength = offset_noise_strength\n", "\n", " # derive loss weight\n", " # snr - signal noise ratio\n", "\n", " snr = alphas_cumprod / (1 - alphas_cumprod)\n", "\n", " # https://arxiv.org/abs/2303.09556\n", "\n", " maybe_clipped_snr = snr.clone()\n", " if min_snr_loss_weight:\n", " maybe_clipped_snr.clamp_(max = min_snr_gamma)\n", "\n", " if objective == 'pred_noise':\n", " register_buffer('loss_weight', maybe_clipped_snr / snr)\n", " elif objective == 'pred_x0':\n", " register_buffer('loss_weight', maybe_clipped_snr)\n", " elif objective == 'pred_v':\n", " register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))\n", "\n", " # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False\n", "\n", " self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity\n", " self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity\n", "\n", " @property\n", " def device(self):\n", " return self.betas.device\n", "\n", " def predict_start_from_noise(self, x_t, t, noise):\n", " return (\n", " extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n", " extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n", " )\n", "\n", " def predict_noise_from_start(self, x_t, t, x0):\n", " return (\n", " (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \\\n", " extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n", " )\n", "\n", " def predict_v(self, x_start, t, noise):\n", " return (\n", " extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -\n", " extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start\n", " )\n", "\n", " def predict_start_from_v(self, x_t, t, v):\n", " return (\n", " extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -\n", " extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v\n", " )\n", "\n", " def q_posterior(self, x_start, x_t, t):\n", " posterior_mean = (\n", " extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n", " extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n", " )\n", " posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n", " posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)\n", " return posterior_mean, posterior_variance, posterior_log_variance_clipped\n", "\n", " def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):\n", " model_output = self.model(x, t, x_self_cond)\n", " maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity\n", "\n", " if self.objective == 'pred_noise':\n", " pred_noise = model_output\n", " x_start = self.predict_start_from_noise(x, t, pred_noise)\n", " x_start = maybe_clip(x_start)\n", "\n", " if clip_x_start and rederive_pred_noise:\n", " pred_noise = self.predict_noise_from_start(x, t, x_start)\n", "\n", " elif self.objective == 'pred_x0':\n", " x_start = model_output\n", " x_start = maybe_clip(x_start)\n", " pred_noise = self.predict_noise_from_start(x, t, x_start)\n", "\n", " elif self.objective == 'pred_v':\n", " v = model_output\n", " x_start = self.predict_start_from_v(x, t, v)\n", " x_start = maybe_clip(x_start)\n", " pred_noise = self.predict_noise_from_start(x, t, x_start)\n", "\n", " return ModelPrediction(pred_noise, x_start)\n", "\n", " def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):\n", " preds = self.model_predictions(x, t, x_self_cond)\n", " x_start = preds.pred_x_start\n", "\n", " if clip_denoised:\n", " x_start.clamp_(-1., 1.)\n", "\n", " model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)\n", " return model_mean, posterior_variance, posterior_log_variance, x_start\n", "\n", " @torch.inference_mode()\n", " def p_sample(self, x, t: int, x_self_cond = None):\n", " b, *_, device = *x.shape, self.device\n", " batched_times = torch.full((b,), t, device = device, dtype = torch.long)\n", " model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)\n", " noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0\n", " pred_img = model_mean + (0.5 * model_log_variance).exp() * noise\n", " return pred_img, x_start\n", "\n", " @torch.inference_mode()\n", " def p_sample_loop(self, shape, return_all_timesteps = False):\n", " batch, device = shape[0], self.device\n", "\n", " img = torch.randn(shape, device = device)\n", " imgs = [img]\n", "\n", " x_start = None\n", "\n", " for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):\n", " self_cond = x_start if self.self_condition else None\n", " img, x_start = self.p_sample(img, t, self_cond)\n", " imgs.append(img)\n", "\n", " ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)\n", "\n", " ret = self.unnormalize(ret)\n", " return ret\n", "\n", " @torch.inference_mode()\n", " def ddim_sample(self, shape, return_all_timesteps = False):\n", " batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective\n", "\n", " times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps\n", " times = list(reversed(times.int().tolist()))\n", " time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]\n", "\n", " img = torch.randn(shape, device = device)\n", " imgs = [img]\n", "\n", " x_start = None\n", "\n", " for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):\n", " time_cond = torch.full((batch,), time, device = device, dtype = torch.long)\n", " self_cond = x_start if self.self_condition else None\n", " pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)\n", "\n", " if time_next < 0:\n", " img = x_start\n", " imgs.append(img)\n", " continue\n", "\n", " alpha = self.alphas_cumprod[time]\n", " alpha_next = self.alphas_cumprod[time_next]\n", "\n", " sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n", " c = (1 - alpha_next - sigma ** 2).sqrt()\n", "\n", " noise = torch.randn_like(img)\n", "\n", " img = x_start * alpha_next.sqrt() + \\\n", " c * pred_noise + \\\n", " sigma * noise\n", "\n", " imgs.append(img)\n", "\n", " ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)\n", "\n", " ret = self.unnormalize(ret)\n", " return ret\n", "\n", " @torch.inference_mode()\n", " def sample(self, batch_size = 16, return_all_timesteps = False):\n", " image_size, channels = self.image_size, self.channels\n", " sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample\n", " return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)\n", "\n", " @torch.inference_mode()\n", " def interpolate(self, x1, x2, t = None, lam = 0.5):\n", " b, *_, device = *x1.shape, x1.device\n", " t = default(t, self.num_timesteps - 1)\n", "\n", " assert x1.shape == x2.shape\n", "\n", " t_batched = torch.full((b,), t, device = device)\n", " xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))\n", "\n", " img = (1 - lam) * xt1 + lam * xt2\n", "\n", " x_start = None\n", "\n", " for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):\n", " self_cond = x_start if self.self_condition else None\n", " img, x_start = self.p_sample(img, i, self_cond)\n", "\n", " return img\n", "\n", " @autocast(enabled = False)\n", " def q_sample(self, x_start, t, noise = None):\n", " noise = default(noise, lambda: torch.randn_like(x_start))\n", "\n", " return (\n", " extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n", " extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n", " )\n", "\n", " def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):\n", " b, c, h, w = x_start.shape\n", "\n", " noise = default(noise, lambda: torch.randn_like(x_start))\n", "\n", " # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise\n", "\n", " offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)\n", "\n", " if offset_noise_strength > 0.:\n", " offset_noise = torch.randn(x_start.shape[:2], device = self.device)\n", " noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')\n", "\n", " # noise sample\n", "\n", " x = self.q_sample(x_start = x_start, t = t, noise = noise)\n", "\n", " # if doing self-conditioning, 50% of the time, predict x_start from current set of times\n", " # and condition with unet with that\n", " # this technique will slow down training by 25%, but seems to lower FID significantly\n", "\n", " x_self_cond = None\n", " if self.self_condition and random() < 0.5:\n", " with torch.no_grad():\n", " x_self_cond = self.model_predictions(x, t).pred_x_start\n", " x_self_cond.detach_()\n", "\n", " # predict and take gradient step\n", "\n", " model_out = self.model(x, t, x_self_cond)\n", "\n", " if self.objective == 'pred_noise':\n", " target = noise\n", " elif self.objective == 'pred_x0':\n", " target = x_start\n", " elif self.objective == 'pred_v':\n", " v = self.predict_v(x_start, t, noise)\n", " target = v\n", " else:\n", " raise ValueError(f'unknown objective {self.objective}')\n", "\n", " loss = F.mse_loss(model_out, target, reduction = 'none')\n", " loss = reduce(loss, 'b ... -> b', 'mean')\n", "\n", " loss = loss * extract(self.loss_weight, t, loss.shape)\n", " return loss.mean()\n", "\n", " def forward(self, img, *args, **kwargs):\n", " b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size\n", " assert h == img_size and w == img_size, f'height and width of image must be {img_size}'\n", " t = torch.randint(0, self.num_timesteps, (b,), device=device).long()\n", "\n", " img = self.normalize(img)\n", " return self.p_losses(img, t, *args, **kwargs)" ] }, { "cell_type": "markdown", "metadata": { "id": "G0EUK2YoOzNh" }, "source": [ "# Resnet Model" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "OzRlUwn_uIZx" }, "outputs": [], "source": [ "def default_conv(in_channels, out_channels, kernel_size, bias=True):\n", " return nn.Conv2d(\n", " in_channels, out_channels, kernel_size,\n", " padding=(kernel_size//2), bias=bias)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "GUE3rp5BxFpe" }, "outputs": [], "source": [ "class Swish(nn.Module):\n", " def forward(self, x):\n", " return x * torch.sigmoid(x)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "arVWRQkzxLt5" }, "outputs": [], "source": [ "class AttnBlock(nn.Module):\n", " def __init__(self, in_ch):\n", " super().__init__()\n", " self.group_norm = nn.GroupNorm(32, in_ch)\n", " self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n", " self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n", " self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n", " self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n", "\n", " def forward(self, x):\n", " B, C, H, W = x.shape\n", " h = self.group_norm(x)\n", " q = self.proj_q(h)\n", " k = self.proj_k(h)\n", " v = self.proj_v(h)\n", "\n", " q = q.permute(0, 2, 3, 1).view(B, H * W, C)\n", " k = k.view(B, C, H * W)\n", " w = torch.bmm(q, k) * (int(C) ** (-0.5))\n", " assert list(w.shape) == [B, H * W, H * W]\n", " w = F.softmax(w, dim=-1)\n", "\n", " v = v.permute(0, 2, 3, 1).view(B, H * W, C)\n", " h = torch.bmm(w, v)\n", " assert list(h.shape) == [B, H * W, C]\n", " h = h.view(B, H, W, C).permute(0, 3, 1, 2)\n", " h = self.proj(h)\n", "\n", " return x + h" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "DI5MM527m9CT" }, "outputs": [], "source": [ "class ResBlock(nn.Module):\n", " def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):\n", " super().__init__()\n", " self.block1 = nn.Sequential(\n", " nn.GroupNorm(32, in_ch),\n", " Swish(),\n", " nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),\n", " )\n", " self.temb_proj = nn.Sequential(\n", " Swish(),\n", " nn.Linear(tdim, out_ch),\n", " )\n", " self.block2 = nn.Sequential(\n", " nn.GroupNorm(32, out_ch),\n", " Swish(),\n", " nn.Dropout(dropout),\n", " nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),\n", " )\n", " if in_ch != out_ch:\n", " self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)\n", " else:\n", " self.shortcut = nn.Identity()\n", " if attn:\n", " self.attn = AttnBlock(out_ch)\n", " else:\n", " self.attn = nn.Identity()\n", "\n", " def forward(self, x, temb):\n", " h = self.block1(x)\n", " # h += self.temb_proj(temb)[:, :, None, None]\n", " h = self.block2(h)\n", "\n", " h = h + self.shortcut(x)\n", " h = self.attn(h)\n", " return h" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "class DownSample(nn.Module):\n", " def __init__(self, in_ch, out_ch):\n", " super().__init__()\n", " self.main = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)\n", "\n", " def forward(self, x, temb):\n", " x = self.main(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "class UpSample(nn.Module):\n", " def __init__(self, in_ch, out_ch):\n", " super().__init__()\n", " self.main = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)\n", "\n", " def forward(self, x, temb):\n", " _, _, H, W = x.shape\n", " x = F.interpolate(\n", " x, scale_factor=2, mode='nearest')\n", " x = self.main(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "class Unet(nn.Module):\n", " # Modified from https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/edsr.py#L31\n", "\n", " def __init__(self,\n", " n_feats=128,\n", " ch_mul=[1, 2, 4],\n", " attention_mul=[1, 2],\n", " t_dim=256,\n", " dropout=0.1,\n", " channels=1,\n", " out_dim=1,\n", " num_res_blocks=2,\n", " self_condition = False,\n", " learned_sinusoidal_cond=False,\n", " random_fourier_features=False,\n", " learned_sinusoidal_dim=16,\n", " sinusoidal_pos_emb_theta=10000,\n", " conv=default_conv):\n", " super(Unet, self).__init__()\n", " \n", " self.n_feats = n_feats\n", " self.t_dim = t_dim\n", " self.dropout = dropout\n", " self.channels = channels\n", " self.out_dim = out_dim\n", " self.self_condition = self_condition\n", " self.kernel_size = 3\n", "\n", " # define time embedding\n", " if learned_sinusoidal_cond:\n", " sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)\n", " fourier_dim = learned_sinusoidal_dim + 1\n", " else:\n", " sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)\n", " fourier_dim = self.n_feats\n", "\n", " self.time_mlp = nn.Sequential(\n", " sinu_pos_emb,\n", " nn.Linear(fourier_dim, self.t_dim),\n", " nn.GELU(),\n", " nn.Linear(self.t_dim, self.t_dim)\n", " )\n", "\n", " # define head module\n", " self.head = conv(self.channels, self.n_feats, self.kernel_size)\n", "\n", " # define downsample module\n", " channel_list = []\n", " current_channel = n_feats\n", " self.downblocks = nn.ModuleList()\n", " for i, mult in enumerate(ch_mul):\n", " out_channels = n_feats * mult\n", " for _ in range(num_res_blocks):\n", " self.downblocks.append(\n", " ResBlock(in_ch=current_channel,\n", " out_ch=out_channels,\n", " tdim=self.t_dim,\n", " dropout=self.dropout,\n", " attn=(mult in attention_mul)))\n", " current_channel = out_channels\n", " channel_list.append(current_channel)\n", " if i != len(ch_mul) - 1:\n", " out_channels = n_feats * ch_mul[i + 1]\n", " self.downblocks.append(DownSample(current_channel, out_channels))\n", " channel_list.append((current_channel, out_channels))\n", " current_channel = out_channels\n", " \n", " # define middle module\n", " self.middleblocks = nn.ModuleList([\n", " ResBlock(in_ch=current_channel, out_ch=current_channel, tdim=self.t_dim, dropout=self.dropout, attn=True),\n", " ResBlock(in_ch=current_channel, out_ch=current_channel, tdim=self.t_dim, dropout=self.dropout, attn=True),\n", " ])\n", " \n", " # define upsample module\n", " self.upblocks = nn.ModuleList()\n", " for i, mult in reversed(list(enumerate(ch_mul))):\n", " out_channels = n_feats * mult\n", " for _ in range(num_res_blocks): \n", " self.upblocks.append(\n", " ResBlock(in_ch=channel_list.pop(),\n", " out_ch=out_channels,\n", " tdim=self.t_dim,\n", " dropout=self.dropout,\n", " attn=(mult in attention_mul)))\n", " if i != 0:\n", " curr_ch, out_ch = channel_list.pop()\n", " self.upblocks.append(UpSample(out_ch, curr_ch))\n", " self.upblocks.append(ResBlock(in_ch=curr_ch*2,\n", " out_ch=curr_ch,\n", " tdim=self.t_dim,\n", " dropout=self.dropout,\n", " attn=(mult in attention_mul)))\n", " \n", " current_channel = out_channels\n", " assert len(channel_list) == 0\n", " \n", "\n", " # define tail module\n", " self.tail = nn.Sequential(\n", " nn.GroupNorm(32, current_channel),\n", " Swish(),\n", " nn.Conv2d(current_channel, self.out_dim, 3, stride=1, padding=1)\n", " )\n", "\n", "\n", " def forward(self, x, t, cond=None):\n", " t = self.time_mlp(t)\n", "\n", " # Downsample\n", " x = self.head(x)\n", " x_list = []\n", "\n", " for block in self.downblocks:\n", " if isinstance(block, DownSample):\n", " x_list.append(x)\n", " x = block(x, t)\n", " \n", " # Middle\n", " for block in self.middleblocks:\n", " x = block(x, t)\n", " \n", " # Upsample\n", " up = False\n", " for block in self.upblocks:\n", " if up:\n", " x = torch.concat([x_list.pop(), x], dim=1)\n", " up = False\n", " if isinstance(block, UpSample):\n", " up = True\n", " x = block(x, t)\n", " \n", "\n", " x = self.tail(x)\n", "\n", " return x" ] }, { "cell_type": "markdown", "metadata": { "id": "sKYnOD4tgoVm" }, "source": [ "# Train" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "IX17Bgq4ponl" }, "outputs": [], "source": [ "# output dir\n", "save_path = 'resnet/model'\n", "log_path = 'resnet/log'\n", "\n", "if not os.path.exists(log_path):\n", " os.mkdir(log_path)\n", "if not os.path.exists(save_path):\n", " os.mkdir(save_path)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "M0xgP-lfqNyA" }, "outputs": [], "source": [ "# setup logging\n", "\n", "# Setup logging to file\n", "logging.basicConfig(\n", " filename=os.path.join(log_path, 'info.log'),\n", " filemode=\"w\",\n", " level=logging.DEBUG,\n", " format= '[%(asctime)s] %(levelname)s - %(message)s',\n", " datefmt='%H:%M:%S',\n", " force=True\n", " )\n", "\n", "\n", "# Stop PIL from printing to file\n", "pil_logger = logging.getLogger('PIL')\n", "pil_logger.setLevel(logging.INFO)\n", "\n", "# write and print at the same time\n", "console = logging.StreamHandler()\n", "console.setLevel(logging.INFO)\n", "logging.getLogger().addHandler(console)\n", "\n", "logger = logging.getLogger('Diffusion_Resnet')" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "7D_66Fxzgpsj" }, "outputs": [], "source": [ "# define model\n", "model = Unet(\n", " n_feats=128,\n", " ch_mul=[1,2,4],\n", " attention_mul=[0],\n", " t_dim=512,\n", " dropout=0.1,\n", " channels=1, # MNIST\n", " out_dim=1, # MNIST\n", " num_res_blocks=2,\n", " self_condition = False,\n", " learned_sinusoidal_cond=False,\n", " random_fourier_features=False,\n", " learned_sinusoidal_dim=16,\n", " sinusoidal_pos_emb_theta=10000,\n", ")\n", "\n", "diffusion_model = GaussianDiffusion(\n", " model,\n", " image_size=28, # MNIST\n", " timesteps=1000,\n", " sampling_timesteps=None,\n", " objective ='pred_noise',\n", " beta_schedule ='linear',\n", " schedule_fn_kwargs=dict(),\n", " ddim_sampling_eta= 0.,\n", " auto_normalize = True,\n", " offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise\n", " min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556\n", " min_snr_gamma = 5)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3mhhRbVYioA4", "outputId": "641a6961-9b1b-49f2-f938-19ed852cd6e6" }, "outputs": [], "source": [ "# define dataset\n", "transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " # v2.Normalize((0.1307,), (0.3081,)), # https://stackoverflow.com/questions/70892017/normalize-mnist-in-pytorch\n", "])\n", "\n", "train_dataset = torchvision.datasets.MNIST(root='.', train=True,\n", " download=True, transform=transform)\n", "# test_dataset = torchvision.datasets.MNIST(root='.', train=True,\n", "# download=True, transform=transform)\n", "\n", "train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n", "# test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "y8TRMUr7lNh0" }, "outputs": [], "source": [ "# define optimizer\n", "train_lr = 1e-4\n", "adam_betas = (0.9, 0.99)\n", "\n", "optimizer = Adam(diffusion_model.parameters(),\n", " lr=train_lr,\n", " betas=adam_betas)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "UmBTWmMxn7W-" }, "outputs": [], "source": [ "# device\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "857c3f15450d4d529c0bf11d41cb4253", "d8f0eb19fdb64b0092b3fecaf236ed5d", "ff5ac0305129457da54f3b404671e794", "8736253b92c8489ba57b953bc2d6627b", "4919f968866f459bb98dfdab3322cf0b", "af8e238099d34e3d9b0cc05abe39d6df", "a1338d5096c44c2c88a849046f56ec52", "a9ff57c43d214f38ba3229aa5f590099", "b7a09364b1004e44840e51ce713d52be", "23ff0ecca3934992bb3469a018799350", "42760c1e73e64231aec761748f35611e", "09780340d728481c9405c24129822397", "9e37e76eba5b4db19bae5c631fab55e5", "911eb17aba3146298c6c6d1553289d7e", "aaee07d9bd6b4e35a213dc05f2fcbf18", "eb7f875ad14942ba81f20a9a97de755f", "a35fa34039aa40daba66db1cb9a981fa", "7d87e432d5bb4e5ebf4d2d07ddf44ffe", "ccd35f64801b4c83ad0343e6d5abb375", "fbec2ad31f1e4dc2a3123bd0085c3916", "f207d57428fe49bf8077164627a361f6", "27ae0d5085654196a56b06535389b7a7", "565592bf04bf4cefbbf896766ffc2988", "c973c00dc83a4d55a3e38eadf471ccba", "6e92910bdb4945fd82c0280539fcdb00", "8ab8438ac7ca418eae5a283ba701850f", "9851ba32d7b34c109e8c6641df754a25", "3ed540f9a929408d9f59ea371b0b3433", "23f8195c3f774538aea3c4d1a6db1e23", "7178b563fb4e473c8a6933d8fca8cf2a", "383764f160b6475aba892a6cce70786e", "5934379970f34b00a6f8f8d2675694b0", "75472ac55bbc445da8773d5ef5600adc", "c22ac378bc0c47dcac01fcef6c2b74b0", "1f050054fad14efcbbea80bfc81744c5", "ca177467d2484580a3fbbc94c8fe73fb", "583380888879419488ce9768080ccdeb", "d82fb696ffdd4d5e801291c09874f7bc", "7880e017257c4ae99fc1185d0a042679", "a44430de99254f4e9ec65dce41a08213", "55b391fa2eee4c438e764f37b8a72886", "8c13e2206c214ff996c4e872f4ca6f90", "8e32633a653c499bbde29054586b3082", "3213360337034c31ba6844bffd10d8b0", "374d57e5cf4a45b5980ea2a2facabcd9", "885afe867fe34968b2400f9c427e45bb", "c6f5fcf8391c4e7e9cdfc9802271d5ab", "75d8e8068d404c25889b172fbd6709a6", "852d204c326345f0b497865db9e4208a", "c1e1ebeca32b47d5a4b1a4f8e480becd", "21816c2e65de4949b2a07b688f305ca1", "f8694040d74b4ab296538d98963e71f8", "f484370cd40244c6b28c71a0b0db65c4", "8abebe68b98a43d3a4fee381338ed642", "15b63fc5268749b29fd1970e8f57c0cf", "85e7b53cc6634bf2a9344029b65b811e", "1b032ff718974bf98b155909041e34a4", "e3b844b79f8a458ca56b582a97008c23", "d79fda82790a41d69c38e4a50c087301", "d036c52a50894d248435f0e2d459cd29", "7260aee3d5ef4168a46d941fcda0d383", "fbf62338858a405e99d45534fbf75f86", "cde40d8669ab4835b81bd06fecead120", "540dc27e86094ac4863c35dcbc4d3f9b", "f39c7ff2b0c844f3ad049d3e812d79b9", "4682b3abe3134554b0b3b8333ceb0fc6", "4142a38fed0b4f648e3afe81385595fa", "f5608785587048e49533d91cf8356732", "a20aebe2d8474b979e0eb73141de8704", "1352a16733d0433c857b1215bcd730fd", "2c1eca9bf78a4046acc0cc955492e8b5", "817890da63304994b579481621fc9d4f", "0c3d2867c7194f58ae83fef73505471f", "879ee37ff7b749d49d05573ead370739", "41cd16b082114bae9fb9ed3c7df5d71d", "fb162bc227774ca194ba7f944b7522f8", "0f4d4cd6aba04b188cd992b2a7f5f3dc", "302228f02177406c97db62041dfa2c59", "8f23a14df0fe47b6867ade436cae8e89", "8d3c5741c8a740efb96c4c9dd3f4bf70", "a664c8d5231f45c2ae94e7201037a12c", "a1e1869682fc457a99ac5a18cf7bcb25", "6ea0969878e548c298057391f7bcc112", "9ccddafc65634b2a898d803d94da3be8", "cd7f63747a4a49069595951f44dc7b20", "07735c43098b4e49b9db4a9321c5d446", "f69d418d6916405486f4c4916ae0b991", "d52330c4d01d4fdc93fc2efd06b21887", "d8f1e7d1aacd4c88a7e238d446bb74bb", "474169cf04b64c609780548a17e5c6cf", "6a5fe9a9aa204b9286808de48e59303f", "6f9e7f52f74a4646b2f99ddb41c21b69", "6810d8d15ee149c09c6193205a17ef3a", "ca8350a7fff748f38c12a7e4a65aa264", "72eb86918e4846c2958aa5f286508114", "90365a41563b46e4b46981d6734c3349", "545ec778cd9548508e48f9ad874858b3", "4f57c41ca1a54feeb46c0a00f43dfe8c", "9ca3b701041a4274aaeee1c5afac0f31", "0834e883a51a4403a60bcfc49cf0f193", "2e79117731da4ca19cc7827056dd343b", "6ccc897173ef4ba3882b7048527e5e65", "7b39bfb578b747f1a2adbb7d4ab39b6f", "261875b26e6d431691d62908e44ec8ec", "1a7ce70b672549508107f3e1aff33a7f", "23e58eb20453421c98f34b68cad1ea0c", "f5145344c18e428984541ba4da095d4c", "5168777e2c14422fb83401ba8a645426", "dd1120f5759644259e62d581bb907595", "fc6de7e7b7b84772986135bc1a2a0755", "9c5f76a4712e48fc9245a40f4420eb3d", "e213af73f15945f998648e52ebcdaab1", "24fe4edd5bb946de941597d393db613a", "5111966d34474761837cd98026c82e5b", "558e21a38c7747128943d1f80a872695", "2257c6054068440ca9cf673326b3661c", "c8bf57439669431f8abb19ca73a97d4a", "b02a2ec8411d433aaa9a1391bc274d10", "7be58c89172b42efba026a5bf986b392", "9e1ec32d3e194e8981671b02b3d6bb4f", "17e0e9ec01d2456d869c28cffab85f0c", "6beaba46a9e34d048da0e9448781aa72", "9abcec203b26409087d96f3e3a73db96", "c7ee49cea60d4857ae1e44376dc1fc75", "5eae9fb2a5294281abb3fa4bb0200cb7", "c72222f736d04b84a5a1a8f0db331577", "92f7c981deaf4fd897c4a88b327f8b3e", "0d1a536bd55e490cb76fe0d840908fe6", "e95075ea297c4110a51822226f3f9370", "9aaef5a48ecc488c8af227469dec1e04", "8a9b7826399e4051b03451616c610274", "131d964e534548eb9f17f9fe7a4125a3", "9c90572096c4428fb6c334fc859a7580", "e22d51e41f7b4743bca1455b24dc4f66", "6e5ac5e267bc4cb48d3d1373c654ad98", "84ff76db262449dab456ba1d289a290f", "0d5d18897d104e0b8d161a4c8ebc871e", "51837fa329d7442e906017cdfabc660a", "eae0e8b77f654b7389a629d21951e760", "96d06c1e31004fd7b11abec0d6f2aa65", "2a3f63d62e3846d197346f2d8168b729", "49972abac2274a9d8ec37779cc65b337", "65584f2726b24b0cb606c4d32a87cada", "263e281b773e4c31b68da11a9c5aea03", "0c3ab5dd1e654e2595872e61fab4776a", "5a1d55f7197a45c7ade8680775ba8aec", "807b5f6c67b84133af43af8cd729137c", "63775c1cbdd04d80bccb4c40a1cecf1d", "351a938d65fc446885341a0677b02441", "dbf68c06f10c4305a107d8a44f89c437", "9febdeea2bf44ed499be28a975fdd637", "318938cb1f2f4ebea65a1cd59f420c28", "1278bc75f6cf4964b941bcd7fc1c72df", "75fb3fd270de45b8b936b48a4f761639", "1d5a425c04044734b1fe49d111b8e814", "b72b5e10e21b42ec80715a33414f11d8", "98cafc16043d42dd8d02d9f858099016", "cbd36a3cbd3f4e92929c6c593d6811be", "90eaf5412a2447c69154bc553841938d", "3891957a20a4454bb0cde2b4a793d3a4", "a6983213805745d6b690555f6f50d0e0", "f495962491594efeb10dbd14d6dee12d", "98fa690742f948d783eae2c9aa05e2d1", "e47a0ae960484281885416ff79b13227", "d47545e83a4c4c2d9f0f33045b41a65e", "71df511caf154e6b82cddbcebe49ecb9", "db5e579295044022ba5a9c12ea1edebc", "eafc72897f4e4120897c4618a270d48c", "19a38bf332294ab1827ceadd52a523e6", "88f872d67981464ebed43cebeddcd1ef", "db3bc67efcd046ba8268f1b10104dfea", "c954cedd654d42348f20a1b151822ada", "abb1e721a4d6486abcc89360365f4af8", "033eec76cad84bed9977b88ad2d562de", "a91ff9e6e19b4736aba2ec95c71478e1", "8cd2ea272add4a14b7ba05023993a195", "5be58639cfb344798434adc748a03ae8", "6e59fbf41be64d6a8b73cb61c9fdb9b6", "c6354e3c0339480e98a70ae380ada473", "14142bb9177c42e1beaf0387ec848860", "26faabb69fc548a5b4847154706c8e95", "e48bcec2540347ddbd8cfbda6e152063", "1e892f2fc7af4e5092b528bcae108056", "d1068325215347588f984bb89d1174ff", "22a23bb6948e4012a21ad0c305526ada", "3fc0bf560cfb4653a6ad9bfcb4e60e0f", "e942deb7b82349c1842aa97da94cbabd", "9e0334712c5645c58bd98320386385cc", "d0a1130ae8b54cb48af375f58afeb9db", "26b5b77cc55e4b9dbf6e8d2dcbc75812", "304b72600fc94af6b04938eaced28aad", "f46737f123114b97876473045528fe5a", "82eaf05999314425a134cc73245e7850", "1c0a7a5a5c094464a937e37e9d6c754b", "3901eb01537c48449bbc3b020392b844", "3aaa5171e84040e6b61364ff4cb8c84c", "382954f171dd4e318a85bd583eadca73", "0be16d4a86ff4889b914a0e3af7b47a6", "a80f58405dc54ebc819978144b479781", "2a622b52036a4f3182afce07dca3c182", "f925d984ea184d39bf38f73bfd027d9a", "141c321476954d67b5202d5ae8535a9b", "4f3b3b384bdd44d78d3d1bae5e749dc1", "742764cce8ad4ca2a3efeb2ef1c60faa", "c4036c19fdaf44479b4783b51d082e18", "66e76848072e402a85152e8f014046d1", "793b581d82ae4ab69d56279e3f4d3812", "123f11205ee24ab5aeb77048d1f2d8e5", "fc469c1965c64c0cb0a6399a3dcca442", "536582c2dd4f4a77bf27342be14f05bb", "72384a47b1c94da3bfee80a39474c338", "d691b615e0844eeeb02224fb3e6f2b08", "0d7b477d713644a39340c9882ebfaf69", "aeb68184b87646c196becd6eaac6fcbf", "23c228dbd62942708f2060742aafbe6a", "12c181ba36504f15837d81fcee8fe0a4", "1b530eb483c3437c811439f94bab7d0f", "5addb197ecf7451c82b0a90f23f12842", "979312569b624f1d9ee235ee6e114de0", "9154c46e418f4b41b259bd5ff4ee34c6", "09b98afdcafc4d0c84ea208f13e21d44", "cc13db89822b4e27a61ee1c065927cba", "0b5949b83bac403a90772b846858f9ab", "5dee0e7d35d44e1596c932fef0c1d6b4", "534248acf3b04bd5b0aaaa85f5b841e3", "44f44d9d76e1495193a98154b8197594", "ac2bda45d75b4faabc5662c00c0fa094", "303db6df6d494944bdc4c823de508c1a", "99c5772fafbd477a932fff478d02ee22", "2001e1a36b9446ca9d8c056b468ac902", "317dff1ffc75454d9a19a01f1da6e7c1", "fa9932a3c7784e74877e24054dc812c1", "e6d311c2426d46adba18c7ce76baa2e3", "9527dc2106354a59b19d6534b7563b5e", "83313d45f3df443483ec20d4c2655603", "f7156fa67b9148538eb22958558b5c09", "df8b5231ef264851bc887dbfa32f1e90", "c26c3a58ba014dd3abe192b43288d40d", "62ee44fa5568426982a4e117b186e2a9", "cae06780e8064d7097f3c5b98b7db552", "25796d342fea47b8a95061c5bd0fe9c9", "5f442b316ab74df29a055dfd7a6a8339", "503951fef852409d9a6df4cdaaa44520", "1a723fa277884c019c1b19ef3b97e1da", "9457a23761a046588570b95c69f28e51", "ec71e594c76d46f2816f06743f93d73b", "39dcb083faf24eafabae7f598505d69f", "cc4813f5ef984613a44b731ff74c9c42", "85064c9751994315a3c85538131a1545", "b71a71aadd4049c1a8281ea8487d3bae", "65855f3cad194dcf99722ca0fc4a2439", "864e6e43c81d4a17bebdf89a4f3d6749", "8701f498e8944e8aa7a393ebde866b4a", "35b697118dea48bb82d84503eed11362", "add5fb60fdb74b4c85f03cc4abb367b0", "f25ff7e82a9f41e09279d37d8d47d73e", "54076fc088964b169473c21c2172c6b4", "205be1b44b0a44068dc259f9d91ef637", "2416512323c243c4abd81d1bd48282e3", "518cd5b295eb4d739cee4c1fb0f96dc8", "5b9d2eaac57244ddad27e4c9cc66cd5e", "6cb28be1780d4f75a925da5404cc89a3", "c27811d082954ea6bfec493a16da60eb", "6c3062bf869548e08855c8966e848545", "ecac7e6bb9784dc793f683ce7f1cc4a6", "5329158506e1499fa2c151e7ce19147f", "ecc74d7a85674a3d81d7cf1490e01ba6", "39058df4917042a0a6cf88f1886591a4", "06d16d5568964531b881f66d848740a1", "6f012108aeef42afa9d7f4000e643df8", "ad745f3ee22043febd0ef3554d573771", "588cabf315cf4577abe6c861b5a81575", "a5dacd4fa6f14c39a15224a6934f316b", "cc2f093be2ab4af5ab646b8ca7806aa9", "fe2b9a4922704d9c968854cdc2914334", "caf35dc15cab4143808b7361018cc776", "5280fbbd1d704cdca9082928da7948fc", "bf5434d60960498ebc0ed007b8c99060", "73d1ff6cba654a499b6ea0523a6b534a", "de0f74ef095c457888ad40821d4353d2", "ed4c5638a98848b89be93c15dbaaa8b4", "8104c493dcd84420b318a1a97e5e1976", "25687a365dd0456db1662f6179c794e1", "96d95f61c06d48389e60b869ff8b91ae", "8d16b44709ee403292a4f028880574d9", "e175255c01104398ad42390384a0da82", "a57c705d7fc246a0a745d3f180865e01", "6b692966fcfe49a9a0937efdded3eb83", "3af2432c95f94eac85b34bc1c05c0bb5", "666ed6880e704532a79d6eb0dff47264", "43eb93cf2d4744b5954be580d94632ae", "947dcd807b6f42abb56400e8f1accd26", "a55e935e5afa459db325598e527cb5d0", "ed0f8b2b16ba4496812d71dfa95406ea", "b8cfed6f951e4cf694303c3170095a70", "a75e44f024034a00872c06df20b7f83e", "2c5274f677b4431086cc499d90a51611", "1b2e76e3eed142c2b39ae18ce8bae669", "ef8b762e82194a9180a7aa1b71e16e00", "91054e4a2f1c4e619b6d36683b04f9b7", "45dac5c21e9348e889cb284b1db4236c", "8ca737628a5844329c2a2c58ddf58d57", "b7bedf67b21644be9da9f4266b46e368", "08dfc324e6a841ab8b61d3729058ca64", "459df98a78c845fa954c74af3c5c644b", "f6e649776bbe48eab6a0022e7d0d6b28", "373e2e3dc869409aa52f1dc7319bd409", "2c01626ca5ca4d4f8f617c5524fb91ca", "cddb10f49ee84fa7bbecea126386b74a", "d506c5295c9d471cb4f2c01b49930942", "6452749859c245d98c826f5903b4c44b", "b021e48c1c294b489bc683ac27a63338", "56c23502f16443859eb6b180a31e7488", "015c83218c5f429fa026885c30db4ff3", "4a459a42c7bc4a4298e302c587f0e6a9", "571320e39cc54afe941f2a2dcc4dbfd8", "e5e3b4bb465643478a43aeed601e1794", "438819c10f974244aa4303f95431fc1d", "08b2d7bbe1254673831fd569eb6a6d49", "8d4cffef45024c1db0451b61cbfee77f", "1aeaed6b40e746169c5d8be1dc85dfa4", "47017cc8cd9b49d7bbd8d63da7a5d593", "5c58acac4f0a482b92624560a02bf969", "9be3c20cc99d4df58ed947b24034651d", "67d65a64c00e4d5e8cbfaabcf958a13f", "02a7bf81fa0c425cb08fff4782ee5cf6", "b2fdbafafac442e6bb13df5a3b39f7db", "3ed6f145d6c549d8a90fefda283c0640", "478bd283a36041b8b1c18807ff56f68f", "28b6f7d2eb7541cea287f2cdeaf89135", "2f5066a1e44847c998ce7e190182605a", "92076e5186af4b339cf1e3c5e86d0ad9", "0b3390c0ec0044218f9590324a3e7e74", "e44180b3b4cd494d800e496b0fb1fc16", "bf2e91a0a5c5443dac1d1233dcb5d9f4", "0886a59ad02a45349c71f169ed936f68", "6a211d4d8e104485a1f1e0dde6717b3d", "aa9c0be7601f4e99854f0a1cc0b2c5dd", "b5a38be3c6f744c78395e69f3cf177ec", "d56ae5dbabe64edaaf0cf58ff140d874", "782cdb95f2ff49b8aaf9418f98a82ed5", "9ebfde3148ca420babb6b0fba3444029", "fe73ebe19d8f4f9badb520435ac39037", "94e04964186443eea8665205f20b0a48", "2175967d0c0749d38ce7b051a7cac82b", "68491e3bc29846db9bc52a0f2d1af39d", "18aceecba8814c3fa900432272d2bae6", "fa3f05b4fd9548348ee6614e0c756ea6", "c854272b9d5641fba81d7fd0a273e748", "e890ddd289c54a06bf0c89ce1ed64c60", "47a35f325de447ce89a2b3308fb74762", "01ab680b94574e04a683f27ed19d1125" ] }, "id": "E-tdEl4sn9N1", "outputId": "d5c9d0a2-f938-4d79-ed09-531465e7ca22" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/50, Iter 0: Loss = 0.9049434065818787, lr = 0.0001\n", "Epoch 1/50, Iter 100: Loss = 0.0803382471203804, lr = 0.0001\n", "Epoch 1/50, Iter 200: Loss = 0.09960899502038956, lr = 0.0001\n", "Epoch 1/50, Iter 300: Loss = 0.06725724041461945, lr = 0.0001\n", "Epoch 1/50, Iter 400: Loss = 0.033031366765499115, lr = 0.0001\n", "Epoch 1/50, Iter 500: Loss = 0.04144970700144768, lr = 0.0001\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[30], line 42\u001b[0m\n\u001b[0;32m 33\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m count \u001b[38;5;241m%\u001b[39m iter_print \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m count \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m 34\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m, Iter \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m: Loss = \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m, lr = \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[0;32m 35\u001b[0m epoch,\n\u001b[0;32m 36\u001b[0m max_epoches,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 39\u001b[0m train_lr,\n\u001b[0;32m 40\u001b[0m ))\n\u001b[1;32m---> 42\u001b[0m log_loss\u001b[38;5;241m.\u001b[39mappend(\u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 44\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 46\u001b[0m count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "# trainer\n", "max_epoches = 50\n", "iter_print = 100\n", "iter_sample = 1000\n", "save_each = 1\n", "\n", "diffusion_model = diffusion_model.to(device)\n", "\n", "last_trained_path = None\n", "if last_trained_path:\n", " data = torch.load(os.path.join(last_trained_path))\n", " diffusion_model.load_state_dict(data['model'])\n", " optimizer.load_state_dict(data['opt'])\n", " count = data['step']\n", " start_epoch = data['epoch']\n", " log_loss = data['loss']\n", "else:\n", " count = 0\n", " start_epoch = 1\n", " log_loss = []\n", "\n", "for epoch in range(start_epoch, max_epoches+1):\n", " diffusion_model.train()\n", " for img, _ in train_dataloader:\n", " img = img.to(device)\n", "\n", " loss = diffusion_model(img)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " if count % iter_print == 0 or count == 0:\n", " logger.info('Epoch {}/{}, Iter {}: Loss = {}, lr = {}'.format(\n", " epoch,\n", " max_epoches,\n", " count,\n", " loss.mean().item(),\n", " train_lr,\n", " ))\n", "\n", " log_loss.append(loss.mean().item())\n", "\n", " loss = None\n", "\n", " count += 1\n", "\n", " if count % iter_sample == 0:\n", " diffusion_model.eval()\n", "\n", " sample_imgs = diffusion_model.sample(batch_size=16)\n", "\n", " utils.save_image(sample_imgs,\n", " os.path.join(log_path, f\"iter_{count}.png\"),\n", " nrow = int(math.sqrt(16)))\n", "\n", "\n", " if epoch % save_each == 0:\n", " data = {\n", " 'model': diffusion_model.state_dict(),\n", " 'opt': optimizer.state_dict(),\n", " 'step': count,\n", " 'epoch': epoch,\n", " 'loss': log_loss,\n", " }\n", "\n", " torch.save(data, os.path.join(save_path, f\"epoch_{epoch}.pth\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Sample" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "model = Unet(\n", " n_feats=128,\n", " ch_mul=[1,2,4],\n", " attention_mul=[1,2],\n", " t_dim=512,\n", " dropout=0.1,\n", " channels=1, # MNIST\n", " out_dim=1, # MNIST\n", " num_res_blocks=2,\n", " self_condition = False,\n", " learned_sinusoidal_cond=False,\n", " random_fourier_features=False,\n", " learned_sinusoidal_dim=16,\n", " sinusoidal_pos_emb_theta=10000,\n", ")\n", "\n", "diffusion_model = GaussianDiffusion(\n", " model,\n", " image_size=28, # MNIST\n", " timesteps=1000,\n", " sampling_timesteps=None,\n", " objective ='pred_noise',\n", " beta_schedule ='linear',\n", " schedule_fn_kwargs=dict(),\n", " ddim_sampling_eta= 0.,\n", " auto_normalize = True,\n", " offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise\n", " min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556\n", " min_snr_gamma = 5)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "last_trained_path = 'unet_wo_t\\model\\epoch_30.pth'\n", "diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path))['model'])" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "sample_path = 'unet_wo_t'\n", "sample_batch = 32" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "diffusion_model = diffusion_model.to(device)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2e2060e2817449b384cbafe1bba9bc66", "version_major": 2, "version_minor": 0 }, "text/plain": [ "sampling loop time step: 0%| | 0/1000 [00:00