{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GLIDE text-to-image sampling\n",
    "\n",
    "This notebook generates 64x64 images from a text prompt with the GLIDE base model (using classifier-free guidance), then upsamples them to 256x256 with the GLIDE upsampler model.\n",
    "\n",
    "All tunable values (prompt, batch size, guidance scale, upsampling temperature, random seed) live in the *Sampling parameters* cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from IPython.display import display\n",
    "import torch as th\n",
    "\n",
    "from glide_text2im.download import load_checkpoint\n",
    "from glide_text2im.model_creation import (\n",
    "    create_model_and_diffusion,\n",
    "    model_and_diffusion_defaults,\n",
    "    model_and_diffusion_defaults_upsampler\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook supports both CPU and GPU.\n",
    "# On CPU, generating one sample may take on the order of 20 minutes.\n",
    "# On a GPU, it should be under a minute.\n",
    "\n",
    "has_cuda = th.cuda.is_available()\n",
    "device = th.device('cpu' if not has_cuda else 'cuda')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create base model.\n",
    "options = model_and_diffusion_defaults()\n",
    "options['use_fp16'] = has_cuda\n",
    "options['timestep_respacing'] = '100'  # use 100 diffusion steps for fast sampling\n",
    "model, diffusion = create_model_and_diffusion(**options)\n",
    "model.eval()\n",
    "if has_cuda:\n",
    "    model.convert_to_fp16()\n",
    "model.to(device)\n",
    "model.load_state_dict(load_checkpoint('base', device))\n",
    "print('total base parameters', sum(x.numel() for x in model.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create upsampler model.\n",
    "options_up = model_and_diffusion_defaults_upsampler()\n",
    "options_up['use_fp16'] = has_cuda\n",
    "options_up['timestep_respacing'] = 'fast27'  # use 27 diffusion steps for very fast sampling\n",
    "model_up, diffusion_up = create_model_and_diffusion(**options_up)\n",
    "model_up.eval()\n",
    "if has_cuda:\n",
    "    model_up.convert_to_fp16()\n",
    "model_up.to(device)\n",
    "model_up.load_state_dict(load_checkpoint('upsample', device))\n",
    "print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_images(batch: th.Tensor):\n",
    "    \"\"\" Display a batch of images inline, side by side.\n",
    "\n",
    "    Expects a (N, 3, H, W) tensor with values in [-1, 1]; values outside\n",
    "    that range are clamped. The N images are laid out left to right in a\n",
    "    single H x (N*W) RGB image.\n",
    "    \"\"\"\n",
    "    scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()\n",
    "    # (N, 3, H, W) -> (H, N, W, 3) -> (H, N*W, 3): concatenate along width.\n",
    "    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])\n",
    "    display(Image.fromarray(reshaped.numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling parameters\n",
    "prompt = \"an oil painting of a corgi\"\n",
    "batch_size = 1  # number of images generated for the prompt\n",
    "# Classifier-free guidance strength; 1.0 means plain conditional sampling.\n",
    "guidance_scale = 3.0\n",
    "\n",
    "# Tune this parameter to control the sharpness of 256x256 images.\n",
    "# A value of 1.0 is sharper, but sometimes results in grainy artifacts.\n",
    "upsample_temp = 0.997\n",
    "\n",
    "# Random seed. Each sampling cell re-seeds with it, so re-running a cell\n",
    "# draws the same noise. Change it to get a different sample.\n",
    "seed = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################\n",
    "# Sample from the base model #\n",
    "##############################\n",
    "\n",
    "# Create the text tokens to feed to the model.\n",
    "tokens = model.tokenizer.encode(prompt)\n",
    "tokens, mask = model.tokenizer.padded_tokens_and_mask(\n",
    "    tokens, options['text_ctx']\n",
    ")\n",
    "\n",
    "# Create the classifier-free guidance tokens (empty)\n",
    "full_batch_size = batch_size * 2\n",
    "uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(\n",
    "    [], options['text_ctx']\n",
    ")\n",
    "\n",
    "# Pack the tokens together into model kwargs: the first batch_size rows are\n",
    "# conditioned on the prompt, the last batch_size rows on the empty prompt.\n",
    "model_kwargs = dict(\n",
    "    tokens=th.tensor(\n",
    "        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device\n",
    "    ),\n",
    "    mask=th.tensor(\n",
    "        [mask] * batch_size + [uncond_mask] * batch_size,\n",
    "        dtype=th.bool,\n",
    "        device=device,\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Create a classifier-free guidance sampling function\n",
    "def model_fn(x_t, ts, **kwargs):\n",
    "    # Feed the first half of x_t twice, so the same images are paired with\n",
    "    # both the prompt tokens and the empty tokens in kwargs. The second half\n",
    "    # of the batch is discarded after sampling.\n",
    "    half = x_t[: len(x_t) // 2]\n",
    "    combined = th.cat([half, half], dim=0)\n",
    "    model_out = model(combined, ts, **kwargs)\n",
    "    # First 3 channels are eps (presumably the predicted noise); any remaining\n",
    "    # channels are passed through unchanged.\n",
    "    eps, rest = model_out[:, :3], model_out[:, 3:]\n",
    "    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n",
    "    # Guidance: move from the unconditional prediction towards (and, for\n",
    "    # guidance_scale > 1, past) the conditional one.\n",
    "    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
    "    eps = th.cat([half_eps, half_eps], dim=0)\n",
    "    return th.cat([eps, rest], dim=1)\n",
    "\n",
    "# Sample from the base model. Re-seed so re-running this cell draws the same\n",
    "# noise.\n",
    "th.manual_seed(seed)\n",
    "model.del_cache()\n",
    "samples = diffusion.p_sample_loop(\n",
    "    model_fn,\n",
    "    (full_batch_size, 3, options[\"image_size\"], options[\"image_size\"]),\n",
    "    device=device,\n",
    "    clip_denoised=True,\n",
    "    progress=True,\n",
    "    model_kwargs=model_kwargs,\n",
    "    cond_fn=None,\n",
    ")[:batch_size]  # keep only the guided (first) half of the batch\n",
    "model.del_cache()\n",
    "\n",
    "# Show the output\n",
    "show_images(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################\n",
    "# Upsample the 64x64 samples #\n",
    "##############################\n",
    "\n",
    "tokens = model_up.tokenizer.encode(prompt)\n",
    "tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n",
    "    tokens, options_up['text_ctx']\n",
    ")\n",
    "\n",
    "# Create the model conditioning dict.\n",
    "model_kwargs = dict(\n",
    "    # Low-res image to upsample, snapped to 256 evenly spaced levels in\n",
    "    # [-1, 1] (the values an 8-bit image takes after rescaling).\n",
    "    low_res=((samples+1)*127.5).round()/127.5 - 1,\n",
    "\n",
    "    # Text tokens\n",
    "    tokens=th.tensor(\n",
    "        [tokens] * batch_size, device=device\n",
    "    ),\n",
    "    mask=th.tensor(\n",
    "        [mask] * batch_size,\n",
    "        dtype=th.bool,\n",
    "        device=device,\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Sample from the upsampler model. Re-seed so re-running this cell draws the\n",
    "# same initial noise.\n",
    "th.manual_seed(seed)\n",
    "model_up.del_cache()\n",
    "up_shape = (batch_size, 3, options_up[\"image_size\"], options_up[\"image_size\"])\n",
    "up_samples = diffusion_up.ddim_sample_loop(\n",
    "    model_up,\n",
    "    up_shape,\n",
    "    # Scale the initial noise by upsample_temp (see Sampling parameters).\n",
    "    noise=th.randn(up_shape, device=device) * upsample_temp,\n",
    "    device=device,\n",
    "    clip_denoised=True,\n",
    "    progress=True,\n",
    "    model_kwargs=model_kwargs,\n",
    "    cond_fn=None,\n",
    ")\n",
    "model_up.del_cache()\n",
    "\n",
    "# Show the output\n",
    "show_images(up_samples)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}