{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "68fece49", "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import matplotlib.pyplot as plt\n", "\n", "class DoubleConv(nn.Module):\n", " def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):\n", " super().__init__()\n", " self.residual = residual\n", " if not mid_channels:\n", " mid_channels = out_channels\n", " self.double_conv = nn.Sequential(\n", " nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),\n", " nn.GroupNorm(1, mid_channels),\n", " nn.GELU(),\n", " nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),\n", " nn.GroupNorm(1, out_channels),\n", " )\n", "\n", " def forward(self, x):\n", " if self.residual:\n", " return F.gelu(x + self.double_conv(x))\n", " else:\n", " return self.double_conv(x)\n", "\n", "class Down(nn.Module):\n", " def __init__(self, in_channels, out_channels, emb_dim=256):\n", " super().__init__()\n", " self.maxpool_conv = nn.Sequential(\n", " nn.MaxPool2d(2),\n", " DoubleConv(in_channels, in_channels, residual=True),\n", " DoubleConv(in_channels, out_channels),\n", " )\n", "\n", " self.emb_layer = nn.Sequential(\n", " nn.SiLU(),\n", " nn.Linear(\n", " emb_dim,\n", " out_channels\n", " ),\n", " )\n", "\n", " def forward(self, x, t):\n", " x = self.maxpool_conv(x)\n", " emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])\n", " return x + emb\n", "\n", "class Up(nn.Module):\n", " def __init__(self, in_channels, out_channels, emb_dim=256):\n", " super().__init__()\n", "\n", " self.up = nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=True)\n", " self.conv = nn.Sequential(\n", " DoubleConv(in_channels, in_channels, residual=True),\n", " DoubleConv(in_channels, out_channels, in_channels // 2),\n", " )\n", "\n", " self.emb_layer = nn.Sequential(\n", " nn.SiLU(),\n", " nn.Linear(\n", " emb_dim,\n", " out_channels\n", " ),\n", " )\n", "\n", " def forward(self, x, skip_x, t):\n", " x = self.up(x)\n", " x = torch.cat([skip_x, x], dim=1)\n", " x = self.conv(x)\n", " emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])\n", " return x + emb\n", "\n", "class UNet(nn.Module):\n", " def __init__(self, c_in=3, c_out=3, time_dim=256, device=\"cuda\"):\n", " super().__init__()\n", " self.device = device\n", " self.time_dim = time_dim\n", "\n", " self.inc = DoubleConv(c_in, 64)\n", " self.down1 = Down(64, 128)\n", " self.down2 = Down(128, 256)\n", " self.down3 = Down(256, 256)\n", "\n", " self.bot1 = DoubleConv(256, 512)\n", " self.bot2 = DoubleConv(512, 512)\n", " self.bot3 = DoubleConv(512, 256)\n", "\n", " self.up1 = Up(512, 128)\n", " self.up2 = Up(256, 64)\n", " self.up3 = Up(128, 64)\n", " self.outc = nn.Conv2d(64, c_out, kernel_size=1)\n", "\n", " def positional_encoding(self, t, channels):\n", " inv_freq = 1.0 / (\n", " 10000\n", " ** (torch.arange(0, channels, 2, device=self.device).float() / channels)\n", " )\n", " pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)\n", " pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)\n", " pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)\n", " return pos_enc\n", "\n", " def forward(self, image, t):\n", " t = t.unsqueeze(-1).type(torch.float)\n", " t = self.positional_encoding(t, self.time_dim)\n", "\n", " x1 = self.inc(image)\n", " x2 = self.down1(x1, t)\n", " x3 = self.down2(x2, t)\n", " x4 = 
self.down3(x3, t)\n", "\n", " x4 = self.bot1(x4)\n", " # x4 = self.bot2(x4)\n", " x4 = self.bot3(x4)\n", "\n", " x = self.up1(x4, x3, t)\n", " x = self.up2(x, x2, t)\n", " x = self.up3(x, x1, t)\n", " output = self.outc(x)\n", " return output\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "model = UNet(device = device).to(device)\n", "model.load_state_dict(torch.load('Model_Saved_States/diffusion_64.pth'))\n", "img_size = 64\n", "class Diffusion():\n", " def __init__(self, time_steps = 500, beta_start = 0.0001, beta_stop = 0.02, image_size = 64, device = device):\n", " self.time_steps = time_steps\n", " self.beta_start = beta_start\n", " self.beta_stop = beta_stop\n", " self.img_size = image_size\n", " self.device = device\n", "\n", " self.beta = self.beta_schedule()\n", " self.beta = self.beta.to(device)\n", " self.alpha = 1 - self.beta\n", " self.alpha = self.alpha.to(device)\n", " self.alpha_hat = torch.cumprod(self.alpha, dim = 0).to(device)\n", "\n", "\n", " def beta_schedule(self):\n", " return torch.linspace(self.beta_start, self.beta_stop, self.time_steps)\n", "\n", " def noise_images(self, images, t):\n", " sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None,]\n", " sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None,]\n", " noises = torch.randn_like(images)\n", " noised_images = sqrt_alpha_hat * images + sqrt_one_minus_alpha_hat * noises\n", " return noised_images, noises\n", "\n", " def random_timesteps(self, n):\n", " return torch.randint(low=1, high=self.time_steps, size=(n,))\n", "\n", " def generate_samples(self, model, n):\n", " with torch.no_grad():\n", " x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)\n", " for i in range(self.time_steps - 1, 1, -1):\n", " t = (torch.ones(n) * i).long().to(self.device)\n", " predicted_noise = model(x, t)\n", " alpha = self.alpha[t][:, None, None, None]\n", " alpha_hat = self.alpha_hat[t][:, None, None, None]\n", " beta = self.beta[t][:, None, None, None]\n", " if i > 1:\n", " noise = torch.randn_like(x)\n", " else:\n", " noise = torch.zeros_like(x)\n", " x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise\n", "\n", " return (x[0].cpu().numpy().transpose(1, 2, 0) / 255)\n", " #show_images\n", "\n", "diffusion = Diffusion()\n" ] }, { "cell_type": "code", "execution_count": 26, "id": "a80516cd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7867\n", "Running on public URL: https://080248f8c7c14eec1e.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
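  {
   "cell_type": "markdown",
   "id": "ddpm-reverse-step-note",
   "metadata": {},
   "source": [
    "For reference, the update inside `Diffusion.generate_samples` implements the standard DDPM reverse step (Ho et al., 2020), with the code's `alpha_hat` playing the role of $\\bar{\\alpha}_t = \\prod_{s=1}^{t} \\alpha_s$:\n",
    "\n",
    "$$x_{t-1} = \\frac{1}{\\sqrt{\\alpha_t}} \\left( x_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\, \\epsilon_\\theta(x_t, t) \\right) + \\sqrt{\\beta_t} \\, z,$$\n",
    "\n",
    "where $\\epsilon_\\theta(x_t, t)$ is the U-Net's noise prediction and $z \\sim \\mathcal{N}(0, I)$ for $t > 1$, with $z = 0$ at the final step."
   ]
  },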
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "def greet(n):\n", " image = diffusion.generate_samples(model, n = 1)\n", " image = (np.clip(image * 255, -1, 1) + 1) / 2\n", " plt.imshow(image)\n", " return image\n", "\n", "iface = gr.Interface(fn=greet, inputs=\"number\", outputs=\"image\")\n", "iface.launch(share = True)" ] }, { "cell_type": "code", "execution_count": null, "id": "cc6f5064", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 5 }