{ "cells": [
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {
   "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
   "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
   "execution": {
    "iopub.execute_input": "2021-12-22T15:46:09.354280Z",
    "iopub.status.busy": "2021-12-22T15:46:09.353419Z"
   }
  },
  "outputs": [],
  "source": [
   "#@title # [minDALL-E](https://github.com/kakaobrain/minDALL-E) Inferencing\n",
   "#modified version of mega b notebook https://colab.research.google.com/drive/1Gg7-c7LrUTNfQ-Fk-BVNCe9kvedZZsAh?usp=sharing\n",
   "# Setup: clone minDALL-E, install its dependencies, and load the minDALL-E 1.3B\n",
   "# model plus the CLIP ViT-B/32 model that the next cell uses to re-rank samples.\n",
   "\n",
   "import os\n",
   "from IPython.display import clear_output\n",
   "\n",
   "!git clone -q https://github.com/kakaobrain/minDALL-E.git\n",
   "\n",
   "%cd minDALL-E/\n",
   "\n",
   "# Version specifiers are quoted: unquoted, the shell parses `>=` as an output\n",
   "# redirection, so pip installs an unpinned package and writes a stray file\n",
   "# such as '=0.10.2'. %pip installs into the environment of the running kernel.\n",
   "%pip install -q 'tokenizers>=0.10.2'\n",
   "%pip install -q 'pyflakes>=2.2.0'\n",
   "%pip install -q 'tqdm>=4.46.0'\n",
   "%pip install -q 'pytorch-lightning>=1.5'\n",
   "%pip install -q einops\n",
   "%pip install -q omegaconf\n",
   "%pip install -q git+https://github.com/openai/CLIP.git\n",
   "\n",
   "clear_output()\n",
   "print(\"Downloading model...\")\n",
   "device = 'cuda:0'  # hard-coded first CUDA GPU; both models are moved here\n",
   "\n",
   "from dalle.utils.utils import set_seed\n",
   "from dalle.models import Dalle\n",
   "import clip\n",
   "import math\n",
   "\n",
   "%matplotlib inline\n",
   "\n",
   "model = Dalle.from_pretrained('minDALL-E/1.3B') # This will automatically download the pretrained model.\n",
   "model.to(device=device)\n",
   "model_clip, preprocess_clip = clip.load(\"ViT-B/32\", device=device)\n",
   "model_clip.to(device=device)\n",
   "clear_output()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {
   "execution": {
"iopub.execute_input": "2021-12-22T03:40:27.709632Z",
    "iopub.status.busy": "2021-12-22T03:40:27.709269Z",
    "iopub.status.idle": "2021-12-22T03:40:27.800815Z",
    "shell.execute_reply": "2021-12-22T03:40:27.799634Z",
    "shell.execute_reply.started": "2021-12-22T03:40:27.709509Z"
   }
  },
  "outputs": [],
  "source": [
   "from matplotlib import pyplot as plt\n",
   "import math\n",
   "import clip\n",
   "import torch\n",
   "import numpy as np\n",
   "from dalle.models import Dalle\n",
   "from dalle.utils.utils import clip_score, set_seed\n",
   "\n",
   "# Sample `num_candidates` images for `prompt` in batches, re-rank them with CLIP,\n",
   "# and plot them in ranked order on a near-square grid.\n",
   "# Relies on `model`, `model_clip`, `preprocess_clip` and `device` from the setup cell.\n",
   "\n",
   "prompt = \"A large bedroom with three armchairs next to a painting of the sun\"\n",
   "print(prompt)\n",
   "\n",
   "num_candidates = 256\n",
   "batch_size = 32  # candidates requested per model.sampling call\n",
   "seed = 0         # passed to minDALL-E's set_seed so re-runs are reproducible\n",
   "images = []\n",
   "\n",
   "set_seed(seed)\n",
   "torch.cuda.empty_cache()\n",
   "\n",
   "# Step through num_candidates in batch_size strides and shrink the final batch,\n",
   "# so num_candidates need not be a multiple of batch_size (truncating with\n",
   "# int(num_candidates / 32) dropped the remainder and broke the plot loop below).\n",
   "for start in range(0, num_candidates, batch_size):\n",
   "    with torch.no_grad():\n",
   "        images.append(model.sampling(prompt=prompt,\n",
   "                                     top_k=128,\n",
   "                                     top_p=None,\n",
   "                                     softmax_temperature=0.7,\n",
   "                                     num_candidates=min(batch_size, num_candidates - start),\n",
   "                                     device=device).cpu().numpy())\n",
   "\n",
   "    torch.cuda.empty_cache()\n",
   "\n",
   "images = np.concatenate(images)\n",
   "# Move axis 1 to the end; presumably (N, C, H, W) -> (N, H, W, C) for imshow.\n",
   "images = np.transpose(images, (0, 2, 3, 1))\n",
   "\n",
   "with torch.no_grad():\n",
   "    rank = clip_score(prompt=prompt,\n",
   "                      images=images,\n",
   "                      model_clip=model_clip,\n",
   "                      preprocess_clip=preprocess_clip,\n",
   "                      device=device)\n",
   "\n",
   "torch.cuda.empty_cache()\n",
   "\n",
   "images = images[rank]  # reorder by the ranking returned from clip_score\n",
   "\n",
   "n = num_candidates\n",
   "\n",
   "# A ceil-based grid always has at least n cells; an int(sqrt(n)) x int(sqrt(n))\n",
   "# grid is too small (add_subplot raises) whenever n is not a perfect square.\n",
   "ncols = math.ceil(math.sqrt(n))\n",
   "nrows = math.ceil(n / ncols)\n",
   "\n",
   "fig = plt.figure(figsize=(6*ncols, 6*nrows))\n",
   "for i in range(n):\n",
   "    ax = fig.add_subplot(nrows, ncols, i+1)\n",
   "    ax.imshow(images[i])\n",
   "    ax.set_axis_off()\n",
   "\n",
   "plt.tight_layout()\n",
   "plt.show()"
  ]
 }
],
"metadata": {
 "kernelspec": {
  "display_name": "Python 3 (ipykernel)",
  "language": "python",
  "name": "python3"
 },
 "language_info": {
  "codemirror_mode": {
   "name": "ipython",
   "version": 3
  },
  "file_extension": ".py",
  "mimetype": "text/x-python",
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": "3.8.10"
 }
},
"nbformat": 4,
"nbformat_minor": 4 }