{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "ClothingGAN-Demo.ipynb", "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "Bm8iDDKC1LZo" }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mfrashad/ClothingGAN/blob/master/ClothingGAN_Demo.ipynb)\n", "# Clothing GAN demo\n", "Notebook by [@mfrashad](https://mfrashad.com)\n", "\n", "\n", "
\n", "Make sure runtime type is GPU" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 200 }, "cellView": "form", "id": "Kj8mGkmH0xgA", "outputId": "6a793110-884d-4f59-89ec-9c5eced9b98a" }, "source": [ "#@title Install dependencies (restart runtime after installing)\n", "from IPython.display import Javascript\n", "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 200})'''))\n", "!pip install ninja gradio fbpca boto3 requests==2.23.0 urllib3==1.25.11" ], "execution_count": 1, "outputs": [ { "output_type": "display_data", "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 200})" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "Collecting ninja\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/1d/de/393468f2a37fc2c1dc3a06afc37775e27fde2d16845424141d4da62c686d/ninja-1.10.0.post2-py3-none-manylinux1_x86_64.whl (107kB)\n", "\r\u001b[K |███ | 10kB 19.1MB/s eta 0:00:01\r\u001b[K |██████ | 20kB 17.8MB/s eta 0:00:01\r\u001b[K |█████████▏ | 30kB 10.3MB/s eta 0:00:01\r\u001b[K |████████████▏ | 40kB 8.4MB/s eta 0:00:01\r\u001b[K |███████████████▎ | 51kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 61kB 5.4MB/s eta 0:00:01\r\u001b[K |█████████████████████▍ | 71kB 6.0MB/s eta 0:00:01\r\u001b[K |████████████████████████▍ | 81kB 6.4MB/s eta 0:00:01\r\u001b[K |███████████████████████████▍ | 92kB 6.5MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▌ | 102kB 6.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 112kB 6.8MB/s \n", "\u001b[?25hCollecting gradio\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/59/97/b7210489b201409175e63afa3307f8e067fe1289cc19a68003dfeef03f06/gradio-2.1.0-py3-none-any.whl (2.5MB)\n", "\u001b[K |████████████████████████████████| 2.5MB 8.5MB/s \n", "\u001b[?25hCollecting fbpca\n", " Downloading 
https://files.pythonhosted.org/packages/a7/a5/2085d0645a4bb4f0b606251b0b7466c61326e4a471d445c1c3761a2d07bc/fbpca-1.0.tar.gz\n", "Collecting boto3\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/2c/e1/2c6c374f043c3f22829563b7fb2bf28fe3dca7ce5994bc5ceeff0959d6c9/boto3-1.17.105-py2.py3-none-any.whl (131kB)\n", "\u001b[K |████████████████████████████████| 133kB 26.6MB/s \n", "\u001b[?25hRequirement already satisfied: requests==2.23.0 in /usr/local/lib/python3.7/dist-packages (2.23.0)\n", "Collecting urllib3==1.25.11\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/56/aa/4ef5aa67a9a62505db124a5cb5262332d1d4153462eb8fd89c9fa41e5d92/urllib3-1.25.11-py2.py3-none-any.whl (127kB)\n", "\u001b[K |████████████████████████████████| 133kB 22.3MB/s \n", "\u001b[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from gradio) (7.1.2)\n", "Collecting ffmpy\n", " Downloading https://files.pythonhosted.org/packages/bf/e2/947df4b3d666bfdd2b0c6355d215c45d2d40f929451cb29a8a2995b29788/ffmpy-0.3.0.tar.gz\n", "Requirement already satisfied: Flask>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from gradio) (1.1.4)\n", "Collecting flask-cachebuster\n", " Downloading https://files.pythonhosted.org/packages/74/47/f3e1fedfaad965c81c2f17234636d72f71450f1b4522ca26d2b7eb4a0a74/Flask-CacheBuster-1.0.0.tar.gz\n", "Collecting pycryptodome\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ad/16/9627ab0493894a11c68e46000dbcc82f578c8ff06bc2980dcd016aea9bd3/pycryptodome-3.10.1-cp35-abi3-manylinux2010_x86_64.whl (1.9MB)\n", "\u001b[K |████████████████████████████████| 1.9MB 26.5MB/s \n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from gradio) (1.1.5)\n", "Collecting Flask-Cors>=3.0.8\n", " Downloading https://files.pythonhosted.org/packages/db/84/901e700de86604b1c4ef4b57110d4e947c218b9997adf5d38fa7da493bce/Flask_Cors-3.0.10-py2.py3-none-any.whl\n", "Requirement 
already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gradio) (1.4.1)\n", "Collecting paramiko\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/95/19/124e9287b43e6ff3ebb9cdea3e5e8e88475a873c05ccdf8b7e20d2c4201e/paramiko-2.7.2-py2.py3-none-any.whl (206kB)\n", "\u001b[K |████████████████████████████████| 215kB 49.0MB/s \n", "\u001b[?25hCollecting analytics-python\n", " Downloading https://files.pythonhosted.org/packages/30/81/2f447982f8d5dec5b56c10ca9ac53e5de2b2e9e2bdf7e091a05731f21379/analytics_python-1.3.1-py2.py3-none-any.whl\n", "Collecting markdown2\n", " Downloading https://files.pythonhosted.org/packages/5d/be/3924cc1c0e12030b5225de2b4521f1dc729730773861475de26be64a0d2b/markdown2-2.4.0-py2.py3-none-any.whl\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from gradio) (1.19.5)\n", "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from gradio) (3.2.2)\n", "Collecting Flask-Login\n", " Downloading https://files.pythonhosted.org/packages/2b/83/ac5bf3279f969704fc1e63f050c50e10985e50fd340e6069ec7e09df5442/Flask_Login-0.5.0-py2.py3-none-any.whl\n", "Collecting jmespath<1.0.0,>=0.7.1\n", " Downloading https://files.pythonhosted.org/packages/07/cb/5f001272b6faeb23c1c9e0acc04d48eaaf5c862c17709d20e3469c6e0139/jmespath-0.10.0-py2.py3-none-any.whl\n", "Collecting s3transfer<0.5.0,>=0.4.0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/63/d0/693477c688348654ddc21dcdce0817653a294aa43f41771084c25e7ff9c7/s3transfer-0.4.2-py2.py3-none-any.whl (79kB)\n", "\u001b[K |████████████████████████████████| 81kB 12.5MB/s \n", "\u001b[?25hCollecting botocore<1.21.0,>=1.20.105\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/95/da/3417300f85ba5173e8dba9248b9ae8bcb74a8aac1c92fa3d257f99073b9e/botocore-1.20.105-py2.py3-none-any.whl (7.7MB)\n", "\u001b[K |████████████████████████████████| 7.7MB 45.1MB/s \n", "\u001b[?25hRequirement already 
satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (2021.5.30)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (3.0.4)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (2.10)\n", "Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (7.1.2)\n", "Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (2.11.3)\n", "Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (1.0.1)\n", "Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (1.1.0)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->gradio) (2.8.1)\n", "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->gradio) (2018.9)\n", "Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from Flask-Cors>=3.0.8->gradio) (1.15.0)\n", "Collecting pynacl>=1.0.1\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/9d/57/2f5e6226a674b2bcb6db531e8b383079b678df5b10cdaa610d6cf20d77ba/PyNaCl-1.4.0-cp35-abi3-manylinux1_x86_64.whl (961kB)\n", "\u001b[K |████████████████████████████████| 962kB 53.6MB/s \n", "\u001b[?25hCollecting cryptography>=2.5\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b2/26/7af637e6a7e87258b963f1731c5982fb31cd507f0d90d91836e446955d02/cryptography-3.4.7-cp36-abi3-manylinux2014_x86_64.whl (3.2MB)\n", "\u001b[K |████████████████████████████████| 3.2MB 49.0MB/s \n", "\u001b[?25hCollecting bcrypt>=3.1.3\n", "\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/26/70/6d218afbe4c73538053c1016dd631e8f25fffc10cd01f5c272d7acf3c03d/bcrypt-3.2.0-cp36-abi3-manylinux2010_x86_64.whl (63kB)\n", "\u001b[K |████████████████████████████████| 71kB 11.2MB/s \n", "\u001b[?25hCollecting backoff==1.10.0\n", " Downloading https://files.pythonhosted.org/packages/f0/32/c5dd4f4b0746e9ec05ace2a5045c1fc375ae67ee94355344ad6c7005fd87/backoff-1.10.0-py2.py3-none-any.whl\n", "Collecting monotonic>=1.5\n", " Downloading https://files.pythonhosted.org/packages/9a/67/7e8406a29b6c45be7af7740456f7f37025f0506ae2e05fb9009a53946860/monotonic-1.6-py2.py3-none-any.whl\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (0.10.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (1.3.1)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (2.4.7)\n", "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->Flask>=1.1.1->gradio) (2.0.1)\n", "Requirement already satisfied: cffi>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from pynacl>=1.0.1->paramiko->gradio) (1.14.5)\n", "Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.4.1->pynacl>=1.0.1->paramiko->gradio) (2.20)\n", "Building wheels for collected packages: fbpca, ffmpy, flask-cachebuster\n", " Building wheel for fbpca (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for fbpca: filename=fbpca-1.0-cp37-none-any.whl size=11376 sha256=4b14cfd952b104d56a985021caf774f906fed7ca5fae1d1f41c570c6c0ea121c\n", " Stored in directory: /root/.cache/pip/wheels/53/a2/dd/9b66cf53dbc58cec1e613d216689e5fa946d3e7805c30f60dc\n", " Building wheel for ffmpy (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", " Created wheel for ffmpy: filename=ffmpy-0.3.0-cp37-none-any.whl size=4710 sha256=571d113a8f5d748045ade970eca5e1a4bab4ed32cfb262851649d238839f682d\n", " Stored in directory: /root/.cache/pip/wheels/cc/ac/c4/bef572cb7e52bfca170046f567e64858632daf77e0f34e5a74\n", " Building wheel for flask-cachebuster (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for flask-cachebuster: filename=Flask_CacheBuster-1.0.0-cp37-none-any.whl size=3372 sha256=37a3576c476072ff54679fb45e5e3c1150e6a1790d23da3873cc2394f3be7741\n", " Stored in directory: /root/.cache/pip/wheels/9f/fc/a7/ab5712c3ace9a8f97276465cc2937316ab8063c1fea488ea77\n", "Successfully built fbpca ffmpy flask-cachebuster\n", "\u001b[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.\u001b[0m\n", "Installing collected packages: ninja, ffmpy, flask-cachebuster, pycryptodome, Flask-Cors, pynacl, cryptography, bcrypt, paramiko, backoff, monotonic, analytics-python, markdown2, Flask-Login, gradio, fbpca, jmespath, urllib3, botocore, s3transfer, boto3\n", " Found existing installation: urllib3 1.24.3\n", " Uninstalling urllib3-1.24.3:\n", " Successfully uninstalled urllib3-1.24.3\n", "Successfully installed Flask-Cors-3.0.10 Flask-Login-0.5.0 analytics-python-1.3.1 backoff-1.10.0 bcrypt-3.2.0 boto3-1.17.105 botocore-1.20.105 cryptography-3.4.7 fbpca-1.0 ffmpy-0.3.0 flask-cachebuster-1.0.0 gradio-2.1.0 jmespath-0.10.0 markdown2-2.4.0 monotonic-1.6 ninja-1.10.0.post2 paramiko-2.7.2 pycryptodome-3.10.1 pynacl-1.4.0 s3transfer-0.4.2 urllib3-1.25.11\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qwf-gggHtA_t", "outputId": "68f7afee-6889-4f9c-97d0-de9ba931c206" }, "source": [
 "# Clone once and enter the repo. Guarded so that re-running the cell (or running it\n",
 "# from inside ClothingGAN/) neither fails on an existing clone nor nests directories.\n",
 "import os\n",
 "if os.path.basename(os.getcwd()) != 'ClothingGAN':\n",
 "    if not os.path.isdir('ClothingGAN'):\n",
 "        !git clone https://github.com/mfrashad/ClothingGAN.git\n",
 "    %cd ClothingGAN/"
 ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "Cloning into 'ClothingGAN'...\n", "remote: Enumerating objects: 333, done.\u001b[K\n", "remote: Counting objects: 100% (60/60), done.\u001b[K\n", "remote: Compressing objects: 100% (37/37), done.\u001b[K\n", "remote: Total 333 (delta 38), reused 22 (delta 22), pack-reused 273\u001b[K\n", "Receiving objects: 100% (333/333), 47.08 MiB | 51.89 MiB/s, done.\n", "Resolving deltas: 100% (108/108), done.\n", "/content/ClothingGAN\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 153 }, "cellView": "form", "id": "2BYsIETVtGnF", "outputId": "e3b5c2ac-9c18-4917-de0f-1f2c126aa696" }, "source": [
 "#@title Install other dependencies\n",
 "from IPython.display import Javascript\n",
 "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 200})'''))\n",
 "# Fetch the repo's git submodules (StyleGAN / StyleGAN2 code) and the NLTK WordNet corpus.\n",
 "!git submodule update --init --recursive\n",
 "!python -c \"import nltk; nltk.download('wordnet')\""
 ], "execution_count": 2, "outputs": [ { "output_type": "display_data", "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 200})" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "Submodule 'stylegan/stylegan_tf' (https://github.com/NVlabs/stylegan.git) registered for path 'models/stylegan/stylegan_tf'\n", "Submodule 'stylegan2/stylegan2-pytorch' (https://github.com/harskish/stylegan2-pytorch.git) registered for path 'models/stylegan2/stylegan2-pytorch'\n", "Cloning into '/content/ClothingGAN/models/stylegan/stylegan_tf'...\n", "Cloning into '/content/ClothingGAN/models/stylegan2/stylegan2-pytorch'...\n", "Submodule path 'models/stylegan/stylegan_tf': checked out '66813a32aac5045fcde72751522a0c0ba963f6f2'\n", "Submodule path 'models/stylegan2/stylegan2-pytorch': checked out '91ea2a7a4320701535466cce942c9e099d65670e'\n", "[nltk_data] Downloading package wordnet to /root/nltk_data...\n", "[nltk_data] Unzipping corpora/wordnet.zip.\n" ], "name": "stdout" } ] }, {
"cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "cellView": "form", "id": "dJg91yvSwKi3", "outputId": "15c2285c-e8da-45b9-bdc2-9bab3709b270" }, "source": [ "#@title Load Model\n", "selected_model = 'lookbook'\n", "\n", "# Load model\n", "from IPython.utils import io\n", "import torch\n", "import PIL\n", "import numpy as np\n", "import ipywidgets as widgets\n", "from PIL import Image\n", "import imageio\n", "from models import get_instrumented_model\n", "from decomposition import get_or_compute\n", "from config import Config\n", "from skimage import img_as_ubyte\n", "\n", "# Speed up computation\n", "torch.autograd.set_grad_enabled(False)\n", "torch.backends.cudnn.benchmark = True\n", "\n", "# Specify model to use\n", "config = Config(\n", " model='StyleGAN2',\n", " layer='style',\n", " output_class=selected_model,\n", " components=80,\n", " use_w=True,\n", " batch_size=5_000, # style layer quite small\n", ")\n", "\n", "inst = get_instrumented_model(config.model, config.output_class,\n", " config.layer, torch.device('cuda'), use_w=config.use_w)\n", "\n", "path_to_components = get_or_compute(config, inst)\n", "\n", "model = inst.model\n", "\n", "comps = np.load(path_to_components)\n", "lst = comps.files\n", "latent_dirs = []\n", "latent_stdevs = []\n", "\n", "load_activations = False\n", "\n", "for item in lst:\n", " if load_activations:\n", " if item == 'act_comp':\n", " for i in range(comps[item].shape[0]):\n", " latent_dirs.append(comps[item][i])\n", " if item == 'act_stdev':\n", " for i in range(comps[item].shape[0]):\n", " latent_stdevs.append(comps[item][i])\n", " else:\n", " if item == 'lat_comp':\n", " for i in range(comps[item].shape[0]):\n", " latent_dirs.append(comps[item][i])\n", " if item == 'lat_stdev':\n", " for i in range(comps[item].shape[0]):\n", " latent_stdevs.append(comps[item][i])" ], "execution_count": 3, "outputs": [ { "output_type": "stream", "text": [ "StyleGAN2: Optimized CUDA op FusedLeakyReLU not 
available, using native PyTorch fallback.\n", "StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.\n", "Downloading https://drive.google.com/uc?export=download&id=1-F-RMkbHUv_S_k-_olh43mu5rDUMGYKe\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "cellView": "form", "id": "uCR_3Ghos2kK" }, "source": [
 "#@title Define functions\n",
 "from ipywidgets import fixed\n",
 "\n",
 "# Taken from https://github.com/alexanderkuk/log-progress\n",
 "def log_progress(sequence, every=1, size=None, name='Items'):\n",
 "    \"\"\"Yield the items of `sequence` while showing an ipywidgets progress bar.\n",
 "\n",
 "    `every` sets how often the widget is refreshed (None picks ~0.5% steps when the\n",
 "    length is known; it must be given for iterators). `size` overrides len(sequence).\n",
 "    \"\"\"\n",
 "    from ipywidgets import IntProgress, HTML, VBox\n",
 "    from IPython.display import display\n",
 "\n",
 "    is_iterator = False\n",
 "    if size is None:\n",
 "        try:\n",
 "            size = len(sequence)\n",
 "        except TypeError:\n",
 "            is_iterator = True\n",
 "    if size is not None:\n",
 "        if every is None:\n",
 "            if size <= 200:\n",
 "                every = 1\n",
 "            else:\n",
 "                every = int(size / 200)  # every 0.5%\n",
 "    else:\n",
 "        assert every is not None, 'sequence is iterator, set every'\n",
 "\n",
 "    if is_iterator:\n",
 "        progress = IntProgress(min=0, max=1, value=1)\n",
 "        progress.bar_style = 'info'\n",
 "    else:\n",
 "        progress = IntProgress(min=0, max=size, value=0)\n",
 "    label = HTML()\n",
 "    box = VBox(children=[label, progress])\n",
 "    display(box)\n",
 "\n",
 "    index = 0\n",
 "    try:\n",
 "        for index, record in enumerate(sequence, 1):\n",
 "            if index == 1 or index % every == 0:\n",
 "                if is_iterator:\n",
 "                    label.value = '{name}: {index} / ?'.format(\n",
 "                        name=name,\n",
 "                        index=index\n",
 "                    )\n",
 "                else:\n",
 "                    progress.value = index\n",
 "                    label.value = u'{name}: {index} / {size}'.format(\n",
 "                        name=name,\n",
 "                        index=index,\n",
 "                        size=size\n",
 "                    )\n",
 "            yield record\n",
 "    except BaseException:\n",
 "        # Any failure, interrupt or early close of the generator turns the bar red; re-raise.\n",
 "        progress.bar_style = 'danger'\n",
 "        raise\n",
 "    else:\n",
 "        progress.bar_style = 'success'\n",
 "        progress.value = index\n",
 "        label.value = \"{name}: {index}\".format(\n",
 "            name=name,\n",
 "            index=str(index or '?')\n",
 "        )\n",
 "\n",
 "def name_direction(sender):\n",
 "    \"\"\"ipywidgets button callback: store and save the current direction under `text.value`.\n",
 "\n",
 "    NOTE(review): leftover from a widget-based workflow. It reads the globals `text`,\n",
 "    `num`, `named_directions` and `random_dir`, which this notebook never defines, and\n",
 "    `start_layer`/`end_layer`, which the Demo UI cell rebinds to gradio inputs. Calling it\n",
 "    here fails unless those globals are set up first.\n",
 "    \"\"\"\n",
 "    if not text.value:\n",
 "        print('Please name the direction before saving')\n",
 "        return\n",
 "\n",
 "    # Stored values are [num, start, end] lists, so match on the stored direction index.\n",
 "    for target_key in [k for k, v in named_directions.items() if v[0] == num]:\n",
 "        print(f'Direction already named: {target_key}')\n",
 "        print('Overwriting... ')\n",
 "        del named_directions[target_key]\n",
 "    named_directions[text.value] = [num, start_layer.value, end_layer.value]\n",
 "    save_direction(random_dir, text.value)\n",
 "    for item in named_directions:\n",
 "        print(item, named_directions[item])\n",
 "\n",
 "def save_direction(direction, filename):\n",
 "    \"\"\"Save `direction` to `<filename>.npy` with np.save and report the path.\"\"\"\n",
 "    filename += \".npy\"\n",
 "    np.save(filename, direction, allow_pickle=True, fix_imports=True)\n",
 "    print(f'Latent direction saved as {filename}')\n",
 "\n",
 "def mix_w(w1, w2, content, style, n_coarse=5):\n",
 "    \"\"\"Blend two per-layer latent lists layer by layer; writes into and returns `w2`.\n",
 "\n",
 "    Layers [0, n_coarse) are interpolated with weight `content`, all remaining layers of\n",
 "    `w2` with weight `style`; a weight of 0 keeps w1's layer and 1 keeps w2's.\n",
 "    Assumes len(w1) >= len(w2).\n",
 "    \"\"\"\n",
 "    for i in range(len(w2)):\n",
 "        weight = content if i < n_coarse else style\n",
 "        w2[i] = w1[i] * (1 - weight) + w2[i] * weight\n",
 "    return w2\n",
 "\n",
 "def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):\n",
 "    \"\"\"Render one image, optionally moved along latent directions.\n",
 "\n",
 "    For every layer in [start, end) (clamped to the available layers), adds\n",
 "    directions[i] * distances[i] * scale for each direction. If `w` is None the latent\n",
 "    is sampled from `seed`; otherwise `w` is a per-layer sequence of latents.\n",
 "    Returns a 500x500 PIL image, displays it if `disp`, and saves it to\n",
 "    out/{seed}_{save:05}.png when `save` is given. `noise_spec` is unused.\n",
 "    \"\"\"\n",
 "    model.truncation = truncation\n",
 "    if w is None:\n",
 "        w = model.sample_latent(1, seed=seed).detach().cpu().numpy()\n",
 "        w = [w]*model.get_max_latents() # one per layer\n",
 "    else:\n",
 "        w = [np.expand_dims(x, 0) for x in w]\n",
 "\n",
 "    # Clamp the layer range so out-of-range UI values neither raise nor wrap around.\n",
 "    start = max(int(start), 0)\n",
 "    end = min(int(end), len(w))\n",
 "    for l in range(start, end):\n",
 "        for direction, distance in zip(directions, distances):\n",
 "            w[l] = w[l] + direction * distance * scale\n",
 "\n",
 "    torch.cuda.empty_cache()\n",
 "    # Render, then upscale for display.\n",
 "    out = model.sample_np(w)\n",
 "    final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((500,500),Image.LANCZOS)\n",
 "\n",
 "    if save is not None:\n",
 "        if not disp:\n",
 "            print(save)\n",
 "        final_im.save(f'out/{seed}_{save:05}.png')\n",
 "    if disp:\n",
 "        display(final_im)\n",
 "\n",
 "    return final_im\n",
 "\n",
 "def generate_mov(seed, truncation, direction_vec, scale, layers,
n_frames, out_name = 'out', noise_spec = None, loop=True):\n",
 "    \"\"\"Generates a mov moving back and forth along the chosen direction vector.\n",
 "\n",
 "    Renders `n_frames` frames of the `seed` latent, offsetting each valid index in\n",
 "    `layers` by direction_vec * offset * scale with offset stepping from -10 towards +10.\n",
 "    If `loop`, the frames are appended again in reverse. Writes out/<out_name>.mp4 and\n",
 "    returns that path. `noise_spec` is unused.\n",
 "    \"\"\"\n",
 "    import os\n",
 "    os.makedirs('out', exist_ok=True)\n",
 "    movieName = f'out/{out_name}.mp4'\n",
 "\n",
 "    # Truncation and the base latent are identical for every frame, so set them once.\n",
 "    model.truncation = truncation\n",
 "    base_w = model.sample_latent(1, seed=seed).cpu().numpy()\n",
 "    n_latents = model.get_max_latents()\n",
 "    # Only indices that exist in the per-layer list can be offset.\n",
 "    valid_layers = [l for l in layers if 0 <= l < n_latents]\n",
 "\n",
 "    offset = -10\n",
 "    step = 20 / n_frames\n",
 "    imgs = []\n",
 "    for _ in log_progress(range(n_frames), name = \"Generating frames\"):\n",
 "        w = [base_w] * n_latents  # one per layer; base_w itself is never modified\n",
 "        for l in valid_layers:\n",
 "            w[l] = w[l] + direction_vec * offset * scale\n",
 "        imgs.append(model.sample_np(w))\n",
 "        offset += step\n",
 "    if loop:\n",
 "        imgs += imgs[::-1]\n",
 "    with imageio.get_writer(movieName, mode='I') as writer:\n",
 "        for image in log_progress(imgs, name = \"Creating animation\"):\n",
 "            writer.append_data(img_as_ubyte(image))\n",
 "    return movieName"
 ], "execution_count": 4, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 640 }, "cellView": "form", "id": "jneXxZnNwHo5", "outputId": "c8e2b76e-3a00-47f5-ba2d-51606f09ee93" }, "source": [
 "#@title Demo UI\n",
 "import gradio as gr\n",
 "import numpy as np\n",
 "\n",
 "def generate_image(seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer):\n",
 "    \"\"\"Gradio callback: mix two seeded garments, then edit the mix along latent directions.\n",
 "\n",
 "    Returns (side-by-side preview of the two source images, edited mixed image).\n",
 "    `content`/`style` are the mix_w weights; c0..c6 are the distances along\n",
 "    latent_dirs[0..6], applied to layers [start_layer, end_layer).\n",
 "    \"\"\"\n",
 "    seed1 = int(seed1)\n",
 "    seed2 = int(seed2)\n",
 "    scale = 1\n",
 "\n",
 "    # Slider ci moves along latent_dirs[i].\n",
 "    distances = [c0, c1, c2, c3, c4, c5, c6]\n",
 "    directions = latent_dirs[:len(distances)]\n",
 "\n",
 "    # display_sample_pytorch sets model.truncation before rendering; set it here as well\n",
 "    # so the two preview renders use the current slider value, not the previous call's.\n",
 "    model.truncation = truncation\n",
 "\n",
 "    n_latents = model.get_max_latents()\n",
 "    w1 = [model.sample_latent(1, seed=seed1).detach().cpu().numpy()] * n_latents  # one per layer\n",
 "    w2 = [model.sample_latent(1, seed=seed2).detach().cpu().numpy()] * n_latents\n",
 "    im1 = model.sample_np(w1)\n",
 "    im2 = model.sample_np(w2)\n",
 "    combined_im = np.concatenate([im1, im2], axis=1)\n",
 "    input_im = Image.fromarray((combined_im * 255).astype(np.uint8))\n",
 "\n",
 "    mixed_w = mix_w(w1, w2, content, style)\n",
 "    return input_im, display_sample_pytorch(seed1, truncation, directions, distances, scale, int(start_layer), int(end_layer), w=mixed_w, disp=False)\n",
 "\n",
 "truncation = gr.inputs.Slider(minimum=0, maximum=1, default=0.5, label=\"Truncation\")\n",
 "start_layer = gr.inputs.Number(default=0, label=\"Start Layer\")\n",
 "end_layer = gr.inputs.Number(default=14, label=\"End Layer\")\n",
 "seed1 = gr.inputs.Number(default=0, label=\"Seed 1\")\n",
 "seed2 = gr.inputs.Number(default=0, label=\"Seed 2\")\n",
 "content = gr.inputs.Slider(label=\"Structure\", minimum=0, maximum=1, default=0.5)\n",
 "style = gr.inputs.Slider(label=\"Style\", minimum=0, maximum=1, default=0.5)\n",
 "\n",
 "slider_max_val = 20\n",
 "slider_min_val = -20\n",
 "slider_step = 1  # NOTE(review): not passed to the sliders below\n",
 "\n",
 "# One slider per latent direction, in latent_dirs order (c0 -> latent_dirs[0], ...).\n",
 "direction_labels = [\"Sleeve & Size\", \"Dress - Jacket\", \"Female Coat\", \"Coat\", \"Graphics\", \"Dark\", \"Less Cleavage\"]\n",
 "c0, c1, c2, c3, c4, c5, c6 = [\n",
 "    gr.inputs.Slider(label=label, minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
 "    for label in direction_labels\n",
 "]\n",
 "\n",
 "scale = 1\n",
 "\n",
 "inputs = [seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer]\n",
 "\n",
 "gr.Interface(generate_image, inputs, [\"image\", \"image\"], live=True, title=\"ClothingGAN\").launch()"
 ], "execution_count": 5, "outputs": [ { "output_type": "stream", "text": [ "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n", "This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)\n", "Running on External URL: https://10342.gradio.app\n", "Interface loading below...\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "execute_result", "data": { "text/plain": [ "(,\n", " 'http://127.0.0.1:7860/',\n", " 'https://10342.gradio.app')" ] }, "metadata": { "tags": [] }, "execution_count": 5 } ] } ] }