{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from glide_text2im.download import load_checkpoint\n",
    "from glide_text2im.model_creation import (\n",
    "    create_model_and_diffusion,\n",
    "    model_and_diffusion_defaults,\n",
    "    model_and_diffusion_defaults_upsampler,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook supports both CPU and GPU.\n",
    "# On CPU, generating one sample may take on the order of 20 minutes.\n",
    "# On a GPU, it should be under a minute.\n",
    "\n",
    "has_cuda = th.cuda.is_available()\n",
    "device = th.device('cuda' if has_cuda else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the base model.\n",
    "options = model_and_diffusion_defaults()\n",
    "options['inpaint'] = True\n",
    "options['use_fp16'] = has_cuda\n",
    "options['timestep_respacing'] = '100'  # use 100 diffusion steps for fast sampling\n",
    "model, diffusion = create_model_and_diffusion(**options)\n",
    "model.eval()\n",
    "if has_cuda:\n",
    "    model.convert_to_fp16()\n",
    "model.to(device)\n",
    "model.load_state_dict(load_checkpoint('base-inpaint', device))\n",
    "print('total base parameters', sum(x.numel() for x in model.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the upsampler model.\n",
    "options_up = model_and_diffusion_defaults_upsampler()\n",
    "options_up['inpaint'] = True\n",
    "options_up['use_fp16'] = has_cuda\n",
    "options_up['timestep_respacing'] = 'fast27'  # use 27 diffusion steps for very fast sampling\n",
    "model_up, diffusion_up = create_model_and_diffusion(**options_up)\n",
    "model_up.eval()\n",
    "if has_cuda:\n",
    "    model_up.convert_to_fp16()\n",
    "model_up.to(device)\n",
    "model_up.load_state_dict(load_checkpoint('upsample-inpaint', device))\n",
    "print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_images(batch: th.Tensor):\n",
    "    \"\"\"Display a batch of images inline.\"\"\"\n",
    "    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()\n",
    "    # Lay the batch out horizontally: (N, 3, H, W) -> (H, N * W, 3).\n",
    "    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])\n",
    "    display(Image.fromarray(reshaped.numpy()))\n",
    "\n",
    "def read_image(path: str, size: int = 256) -> th.Tensor:\n",
    "    \"\"\"Load an image as a (1, 3, size, size) tensor scaled to [-1, 1].\"\"\"\n",
    "    pil_img = Image.open(path).convert('RGB')\n",
    "    pil_img = pil_img.resize((size, size), resample=Image.BICUBIC)\n",
    "    img = np.array(pil_img)\n",
    "    return th.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1"
   ]
  },
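  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cells below read 'grass.png' from the working directory; the\n",
    "# original file is not bundled with this notebook. As a convenience,\n",
    "# this sketch writes a flat-green placeholder if the file is missing,\n",
    "# so the notebook can run end to end. The placeholder color is an\n",
    "# arbitrary assumption; substitute a real photo for meaningful results.\n",
    "import os\n",
    "\n",
    "if not os.path.exists('grass.png'):\n",
    "    placeholder = np.zeros((256, 256, 3), dtype=np.uint8)\n",
    "    placeholder[:, :] = (86, 152, 68)  # arbitrary grass-like green\n",
    "    Image.fromarray(placeholder).save('grass.png')"
   ]
  },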
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling parameters\n",
    "prompt = \"a corgi in a field\"\n",
    "batch_size = 1\n",
    "guidance_scale = 5.0\n",
    "\n",
    "# Tune this parameter to control the sharpness of 256x256 images.\n",
    "# A value of 1.0 is sharper, but sometimes results in grainy artifacts.\n",
    "upsample_temp = 0.997\n",
    "\n",
    "# Source image we are inpainting\n",
    "source_image_256 = read_image('grass.png', size=256)\n",
    "source_image_64 = read_image('grass.png', size=64)\n",
    "\n",
    "# The mask is 1 where the source image is kept and 0 where the model\n",
    "# should inpaint. Build it as a single-channel 0/1 mask at 64x64, then\n",
    "# upsample it with nearest-neighbor interpolation for the second stage.\n",
    "# Here we erase everything from row 20 down.\n",
    "source_mask_64 = th.ones_like(source_image_64)[:, :1]\n",
    "source_mask_64[:, :, 20:] = 0\n",
    "source_mask_256 = F.interpolate(source_mask_64, (256, 256), mode='nearest')\n",
    "\n",
    "# Visualize the image we are inpainting\n",
    "show_images(source_image_256 * source_mask_256)"
   ]
  },
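  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch, on dummy tensors rather than real model output, of\n",
    "# the classifier-free guidance update used by model_fn in the next\n",
    "# cell: the guided epsilon starts at the unconditional prediction and\n",
    "# moves toward the conditional one, scaled by guidance_scale.\n",
    "cond_eps = th.randn(1, 3, 64, 64)\n",
    "uncond_eps = th.randn(1, 3, 64, 64)\n",
    "guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
    "print('guided eps shape:', tuple(guided_eps.shape))"
   ]
  },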
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################\n",
    "# Sample from the base model #\n",
    "##############################\n",
    "\n",
    "# Create the text tokens to feed to the model.\n",
    "tokens = model.tokenizer.encode(prompt)\n",
    "tokens, mask = model.tokenizer.padded_tokens_and_mask(\n",
    "    tokens, options['text_ctx']\n",
    ")\n",
    "\n",
    "# Create the classifier-free guidance tokens (empty)\n",
    "full_batch_size = batch_size * 2\n",
    "uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(\n",
    "    [], options['text_ctx']\n",
    ")\n",
    "\n",
    "# Pack the tokens together into model kwargs.\n",
    "model_kwargs = dict(\n",
    "    tokens=th.tensor(\n",
    "        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device\n",
    "    ),\n",
    "    mask=th.tensor(\n",
    "        [mask] * batch_size + [uncond_mask] * batch_size,\n",
    "        dtype=th.bool,\n",
    "        device=device,\n",
    "    ),\n",
    "\n",
    "    # Masked inpainting image\n",
    "    inpaint_image=(source_image_64 * source_mask_64).repeat(full_batch_size, 1, 1, 1).to(device),\n",
    "    inpaint_mask=source_mask_64.repeat(full_batch_size, 1, 1, 1).to(device),\n",
    ")\n",
    "\n",
    "# Create a classifier-free guidance sampling function\n",
    "def model_fn(x_t, ts, **kwargs):\n",
    "    half = x_t[: len(x_t) // 2]\n",
    "    combined = th.cat([half, half], dim=0)\n",
    "    model_out = model(combined, ts, **kwargs)\n",
    "    eps, rest = model_out[:, :3], model_out[:, 3:]\n",
    "    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n",
    "    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
    "    eps = th.cat([half_eps, half_eps], dim=0)\n",
    "    return th.cat([eps, rest], dim=1)\n",
    "\n",
    "def denoised_fn(x_start):\n",
    "    # Force the model to have the exact right x_start predictions\n",
    "    # for the part of the image which is known.\n",
    "    return (\n",
    "        x_start * (1 - model_kwargs['inpaint_mask'])\n",
    "        + model_kwargs['inpaint_image'] * model_kwargs['inpaint_mask']\n",
    "    )\n",
    "\n",
    "# Sample from the base model.\n",
    "model.del_cache()\n",
    "samples = diffusion.p_sample_loop(\n",
    "    model_fn,\n",
    "    (full_batch_size, 3, options[\"image_size\"], options[\"image_size\"]),\n",
    "    device=device,\n",
    "    clip_denoised=True,\n",
    "    progress=True,\n",
    "    model_kwargs=model_kwargs,\n",
    "    cond_fn=None,\n",
    "    denoised_fn=denoised_fn,\n",
    ")[:batch_size]  # keep only the conditional half of the batch\n",
    "model.del_cache()\n",
    "\n",
    "# Show the output\n",
    "show_images(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################\n",
    "# Upsample the 64x64 samples #\n",
    "##############################\n",
    "\n",
    "tokens = model_up.tokenizer.encode(prompt)\n",
    "tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n",
    "    tokens, options_up['text_ctx']\n",
    ")\n",
    "\n",
    "# Create the model conditioning dict.\n",
    "model_kwargs = dict(\n",
    "    # Low-res image to upsample.\n",
    "    low_res=((samples + 1) * 127.5).round() / 127.5 - 1,\n",
    "\n",
    "    # Text tokens\n",
    "    tokens=th.tensor(\n",
    "        [tokens] * batch_size, device=device\n",
    "    ),\n",
    "    mask=th.tensor(\n",
    "        [mask] * batch_size,\n",
    "        dtype=th.bool,\n",
    "        device=device,\n",
    "    ),\n",
    "\n",
    "    # Masked inpainting image.\n",
    "    inpaint_image=(source_image_256 * source_mask_256).repeat(batch_size, 1, 1, 1).to(device),\n",
    "    inpaint_mask=source_mask_256.repeat(batch_size, 1, 1, 1).to(device),\n",
    ")\n",
    "\n",
    "def denoised_fn(x_start):\n",
    "    # Force the model to have the exact right x_start predictions\n",
    "    # for the part of the image which is known.\n",
    "    return (\n",
    "        x_start * (1 - model_kwargs['inpaint_mask'])\n",
    "        + model_kwargs['inpaint_image'] * model_kwargs['inpaint_mask']\n",
    "    )\n",
    "\n",
    "# Sample from the upsampler model.\n",
    "model_up.del_cache()\n",
    "up_shape = (batch_size, 3, options_up[\"image_size\"], options_up[\"image_size\"])\n",
    "up_samples = diffusion_up.p_sample_loop(\n",
    "    model_up,\n",
    "    up_shape,\n",
    "    noise=th.randn(up_shape, device=device) * upsample_temp,\n",
    "    device=device,\n",
    "    clip_denoised=True,\n",
    "    progress=True,\n",
    "    model_kwargs=model_kwargs,\n",
    "    cond_fn=None,\n",
    "    denoised_fn=denoised_fn,\n",
    ")[:batch_size]\n",
    "model_up.del_cache()\n",
    "\n",
    "# Show the output\n",
    "show_images(up_samples)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}