{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "ClothingGAN-Demo.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Bm8iDDKC1LZo"
},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mfrashad/ClothingGAN/blob/master/ClothingGAN_Demo.ipynb)\n",
"# Clothing GAN demo\n",
"Notebook by [@mfrashad](https://mfrashad.com)\n",
"\n",
"\n",
"
\n",
"Make sure runtime type is GPU"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 200
},
"cellView": "form",
"id": "Kj8mGkmH0xgA",
"outputId": "6a793110-884d-4f59-89ec-9c5eced9b98a"
},
"source": [
"#@title Install dependencies (restart runtime after installing)\n",
"from IPython.display import Javascript\n",
"display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 200})'''))\n",
"!pip install ninja gradio fbpca boto3 requests==2.23.0 urllib3==1.25.11"
],
"execution_count": 1,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/javascript": [
"google.colab.output.setIframeHeight(0, true, {maxHeight: 200})"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Collecting ninja\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/1d/de/393468f2a37fc2c1dc3a06afc37775e27fde2d16845424141d4da62c686d/ninja-1.10.0.post2-py3-none-manylinux1_x86_64.whl (107kB)\n",
"\r\u001b[K |███ | 10kB 19.1MB/s eta 0:00:01\r\u001b[K |██████ | 20kB 17.8MB/s eta 0:00:01\r\u001b[K |█████████▏ | 30kB 10.3MB/s eta 0:00:01\r\u001b[K |████████████▏ | 40kB 8.4MB/s eta 0:00:01\r\u001b[K |███████████████▎ | 51kB 5.3MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 61kB 5.4MB/s eta 0:00:01\r\u001b[K |█████████████████████▍ | 71kB 6.0MB/s eta 0:00:01\r\u001b[K |████████████████████████▍ | 81kB 6.4MB/s eta 0:00:01\r\u001b[K |███████████████████████████▍ | 92kB 6.5MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▌ | 102kB 6.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 112kB 6.8MB/s \n",
"\u001b[?25hCollecting gradio\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/59/97/b7210489b201409175e63afa3307f8e067fe1289cc19a68003dfeef03f06/gradio-2.1.0-py3-none-any.whl (2.5MB)\n",
"\u001b[K |████████████████████████████████| 2.5MB 8.5MB/s \n",
"\u001b[?25hCollecting fbpca\n",
" Downloading https://files.pythonhosted.org/packages/a7/a5/2085d0645a4bb4f0b606251b0b7466c61326e4a471d445c1c3761a2d07bc/fbpca-1.0.tar.gz\n",
"Collecting boto3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/2c/e1/2c6c374f043c3f22829563b7fb2bf28fe3dca7ce5994bc5ceeff0959d6c9/boto3-1.17.105-py2.py3-none-any.whl (131kB)\n",
"\u001b[K |████████████████████████████████| 133kB 26.6MB/s \n",
"\u001b[?25hRequirement already satisfied: requests==2.23.0 in /usr/local/lib/python3.7/dist-packages (2.23.0)\n",
"Collecting urllib3==1.25.11\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/56/aa/4ef5aa67a9a62505db124a5cb5262332d1d4153462eb8fd89c9fa41e5d92/urllib3-1.25.11-py2.py3-none-any.whl (127kB)\n",
"\u001b[K |████████████████████████████████| 133kB 22.3MB/s \n",
"\u001b[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from gradio) (7.1.2)\n",
"Collecting ffmpy\n",
" Downloading https://files.pythonhosted.org/packages/bf/e2/947df4b3d666bfdd2b0c6355d215c45d2d40f929451cb29a8a2995b29788/ffmpy-0.3.0.tar.gz\n",
"Requirement already satisfied: Flask>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from gradio) (1.1.4)\n",
"Collecting flask-cachebuster\n",
" Downloading https://files.pythonhosted.org/packages/74/47/f3e1fedfaad965c81c2f17234636d72f71450f1b4522ca26d2b7eb4a0a74/Flask-CacheBuster-1.0.0.tar.gz\n",
"Collecting pycryptodome\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/ad/16/9627ab0493894a11c68e46000dbcc82f578c8ff06bc2980dcd016aea9bd3/pycryptodome-3.10.1-cp35-abi3-manylinux2010_x86_64.whl (1.9MB)\n",
"\u001b[K |████████████████████████████████| 1.9MB 26.5MB/s \n",
"\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from gradio) (1.1.5)\n",
"Collecting Flask-Cors>=3.0.8\n",
" Downloading https://files.pythonhosted.org/packages/db/84/901e700de86604b1c4ef4b57110d4e947c218b9997adf5d38fa7da493bce/Flask_Cors-3.0.10-py2.py3-none-any.whl\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gradio) (1.4.1)\n",
"Collecting paramiko\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/95/19/124e9287b43e6ff3ebb9cdea3e5e8e88475a873c05ccdf8b7e20d2c4201e/paramiko-2.7.2-py2.py3-none-any.whl (206kB)\n",
"\u001b[K |████████████████████████████████| 215kB 49.0MB/s \n",
"\u001b[?25hCollecting analytics-python\n",
" Downloading https://files.pythonhosted.org/packages/30/81/2f447982f8d5dec5b56c10ca9ac53e5de2b2e9e2bdf7e091a05731f21379/analytics_python-1.3.1-py2.py3-none-any.whl\n",
"Collecting markdown2\n",
" Downloading https://files.pythonhosted.org/packages/5d/be/3924cc1c0e12030b5225de2b4521f1dc729730773861475de26be64a0d2b/markdown2-2.4.0-py2.py3-none-any.whl\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from gradio) (1.19.5)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from gradio) (3.2.2)\n",
"Collecting Flask-Login\n",
" Downloading https://files.pythonhosted.org/packages/2b/83/ac5bf3279f969704fc1e63f050c50e10985e50fd340e6069ec7e09df5442/Flask_Login-0.5.0-py2.py3-none-any.whl\n",
"Collecting jmespath<1.0.0,>=0.7.1\n",
" Downloading https://files.pythonhosted.org/packages/07/cb/5f001272b6faeb23c1c9e0acc04d48eaaf5c862c17709d20e3469c6e0139/jmespath-0.10.0-py2.py3-none-any.whl\n",
"Collecting s3transfer<0.5.0,>=0.4.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/63/d0/693477c688348654ddc21dcdce0817653a294aa43f41771084c25e7ff9c7/s3transfer-0.4.2-py2.py3-none-any.whl (79kB)\n",
"\u001b[K |████████████████████████████████| 81kB 12.5MB/s \n",
"\u001b[?25hCollecting botocore<1.21.0,>=1.20.105\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/95/da/3417300f85ba5173e8dba9248b9ae8bcb74a8aac1c92fa3d257f99073b9e/botocore-1.20.105-py2.py3-none-any.whl (7.7MB)\n",
"\u001b[K |████████████████████████████████| 7.7MB 45.1MB/s \n",
"\u001b[?25hRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (2021.5.30)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests==2.23.0) (2.10)\n",
"Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (7.1.2)\n",
"Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (2.11.3)\n",
"Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (1.0.1)\n",
"Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from Flask>=1.1.1->gradio) (1.1.0)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->gradio) (2.8.1)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->gradio) (2018.9)\n",
"Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from Flask-Cors>=3.0.8->gradio) (1.15.0)\n",
"Collecting pynacl>=1.0.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/9d/57/2f5e6226a674b2bcb6db531e8b383079b678df5b10cdaa610d6cf20d77ba/PyNaCl-1.4.0-cp35-abi3-manylinux1_x86_64.whl (961kB)\n",
"\u001b[K |████████████████████████████████| 962kB 53.6MB/s \n",
"\u001b[?25hCollecting cryptography>=2.5\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b2/26/7af637e6a7e87258b963f1731c5982fb31cd507f0d90d91836e446955d02/cryptography-3.4.7-cp36-abi3-manylinux2014_x86_64.whl (3.2MB)\n",
"\u001b[K |████████████████████████████████| 3.2MB 49.0MB/s \n",
"\u001b[?25hCollecting bcrypt>=3.1.3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/26/70/6d218afbe4c73538053c1016dd631e8f25fffc10cd01f5c272d7acf3c03d/bcrypt-3.2.0-cp36-abi3-manylinux2010_x86_64.whl (63kB)\n",
"\u001b[K |████████████████████████████████| 71kB 11.2MB/s \n",
"\u001b[?25hCollecting backoff==1.10.0\n",
" Downloading https://files.pythonhosted.org/packages/f0/32/c5dd4f4b0746e9ec05ace2a5045c1fc375ae67ee94355344ad6c7005fd87/backoff-1.10.0-py2.py3-none-any.whl\n",
"Collecting monotonic>=1.5\n",
" Downloading https://files.pythonhosted.org/packages/9a/67/7e8406a29b6c45be7af7740456f7f37025f0506ae2e05fb9009a53946860/monotonic-1.6-py2.py3-none-any.whl\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (0.10.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (1.3.1)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (2.4.7)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->Flask>=1.1.1->gradio) (2.0.1)\n",
"Requirement already satisfied: cffi>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from pynacl>=1.0.1->paramiko->gradio) (1.14.5)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.4.1->pynacl>=1.0.1->paramiko->gradio) (2.20)\n",
"Building wheels for collected packages: fbpca, ffmpy, flask-cachebuster\n",
" Building wheel for fbpca (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for fbpca: filename=fbpca-1.0-cp37-none-any.whl size=11376 sha256=4b14cfd952b104d56a985021caf774f906fed7ca5fae1d1f41c570c6c0ea121c\n",
" Stored in directory: /root/.cache/pip/wheels/53/a2/dd/9b66cf53dbc58cec1e613d216689e5fa946d3e7805c30f60dc\n",
" Building wheel for ffmpy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for ffmpy: filename=ffmpy-0.3.0-cp37-none-any.whl size=4710 sha256=571d113a8f5d748045ade970eca5e1a4bab4ed32cfb262851649d238839f682d\n",
" Stored in directory: /root/.cache/pip/wheels/cc/ac/c4/bef572cb7e52bfca170046f567e64858632daf77e0f34e5a74\n",
" Building wheel for flask-cachebuster (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for flask-cachebuster: filename=Flask_CacheBuster-1.0.0-cp37-none-any.whl size=3372 sha256=37a3576c476072ff54679fb45e5e3c1150e6a1790d23da3873cc2394f3be7741\n",
" Stored in directory: /root/.cache/pip/wheels/9f/fc/a7/ab5712c3ace9a8f97276465cc2937316ab8063c1fea488ea77\n",
"Successfully built fbpca ffmpy flask-cachebuster\n",
"\u001b[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.\u001b[0m\n",
"Installing collected packages: ninja, ffmpy, flask-cachebuster, pycryptodome, Flask-Cors, pynacl, cryptography, bcrypt, paramiko, backoff, monotonic, analytics-python, markdown2, Flask-Login, gradio, fbpca, jmespath, urllib3, botocore, s3transfer, boto3\n",
" Found existing installation: urllib3 1.24.3\n",
" Uninstalling urllib3-1.24.3:\n",
" Successfully uninstalled urllib3-1.24.3\n",
"Successfully installed Flask-Cors-3.0.10 Flask-Login-0.5.0 analytics-python-1.3.1 backoff-1.10.0 bcrypt-3.2.0 boto3-1.17.105 botocore-1.20.105 cryptography-3.4.7 fbpca-1.0 ffmpy-0.3.0 flask-cachebuster-1.0.0 gradio-2.1.0 jmespath-0.10.0 markdown2-2.4.0 monotonic-1.6 ninja-1.10.0.post2 paramiko-2.7.2 pycryptodome-3.10.1 pynacl-1.4.0 s3transfer-0.4.2 urllib3-1.25.11\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qwf-gggHtA_t",
"outputId": "68f7afee-6889-4f9c-97d0-de9ba931c206"
},
"source": [
"!git clone https://github.com/mfrashad/ClothingGAN.git\n",
"%cd ClothingGAN/"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'ClothingGAN'...\n",
"remote: Enumerating objects: 333, done.\u001b[K\n",
"remote: Counting objects: 100% (60/60), done.\u001b[K\n",
"remote: Compressing objects: 100% (37/37), done.\u001b[K\n",
"remote: Total 333 (delta 38), reused 22 (delta 22), pack-reused 273\u001b[K\n",
"Receiving objects: 100% (333/333), 47.08 MiB | 51.89 MiB/s, done.\n",
"Resolving deltas: 100% (108/108), done.\n",
"/content/ClothingGAN\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
},
"cellView": "form",
"id": "2BYsIETVtGnF",
"outputId": "e3b5c2ac-9c18-4917-de0f-1f2c126aa696"
},
"source": [
"#@title Install other dependencies\n",
"from IPython.display import Javascript\n",
"display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 200})'''))\n",
"!git submodule update --init --recursive\n",
"!python -c \"import nltk; nltk.download('wordnet')\""
],
"execution_count": 2,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/javascript": [
"google.colab.output.setIframeHeight(0, true, {maxHeight: 200})"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Submodule 'stylegan/stylegan_tf' (https://github.com/NVlabs/stylegan.git) registered for path 'models/stylegan/stylegan_tf'\n",
"Submodule 'stylegan2/stylegan2-pytorch' (https://github.com/harskish/stylegan2-pytorch.git) registered for path 'models/stylegan2/stylegan2-pytorch'\n",
"Cloning into '/content/ClothingGAN/models/stylegan/stylegan_tf'...\n",
"Cloning into '/content/ClothingGAN/models/stylegan2/stylegan2-pytorch'...\n",
"Submodule path 'models/stylegan/stylegan_tf': checked out '66813a32aac5045fcde72751522a0c0ba963f6f2'\n",
"Submodule path 'models/stylegan2/stylegan2-pytorch': checked out '91ea2a7a4320701535466cce942c9e099d65670e'\n",
"[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
"[nltk_data] Unzipping corpora/wordnet.zip.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"cellView": "form",
"id": "dJg91yvSwKi3",
"outputId": "15c2285c-e8da-45b9-bdc2-9bab3709b270"
},
"source": [
"#@title Load Model\n",
"selected_model = 'lookbook'\n",
"\n",
"# Load model\n",
"from IPython.utils import io\n",
"import torch\n",
"import PIL\n",
"import numpy as np\n",
"import ipywidgets as widgets\n",
"from PIL import Image\n",
"import imageio\n",
"from models import get_instrumented_model\n",
"from decomposition import get_or_compute\n",
"from config import Config\n",
"from skimage import img_as_ubyte\n",
"\n",
"# Speed up computation\n",
"torch.autograd.set_grad_enabled(False)\n",
"torch.backends.cudnn.benchmark = True\n",
"\n",
"# Specify model to use\n",
"config = Config(\n",
" model='StyleGAN2',\n",
" layer='style',\n",
" output_class=selected_model,\n",
" components=80,\n",
" use_w=True,\n",
" batch_size=5_000, # style layer quite small\n",
")\n",
"\n",
"inst = get_instrumented_model(config.model, config.output_class,\n",
" config.layer, torch.device('cuda'), use_w=config.use_w)\n",
"\n",
"path_to_components = get_or_compute(config, inst)\n",
"\n",
"model = inst.model\n",
"\n",
"comps = np.load(path_to_components)\n",
"lst = comps.files\n",
"latent_dirs = []\n",
"latent_stdevs = []\n",
"\n",
"load_activations = False\n",
"\n",
"for item in lst:\n",
" if load_activations:\n",
" if item == 'act_comp':\n",
" for i in range(comps[item].shape[0]):\n",
" latent_dirs.append(comps[item][i])\n",
" if item == 'act_stdev':\n",
" for i in range(comps[item].shape[0]):\n",
" latent_stdevs.append(comps[item][i])\n",
" else:\n",
" if item == 'lat_comp':\n",
" for i in range(comps[item].shape[0]):\n",
" latent_dirs.append(comps[item][i])\n",
" if item == 'lat_stdev':\n",
" for i in range(comps[item].shape[0]):\n",
" latent_stdevs.append(comps[item][i])"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.\n",
"StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.\n",
"Downloading https://drive.google.com/uc?export=download&id=1-F-RMkbHUv_S_k-_olh43mu5rDUMGYKe\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "uCR_3Ghos2kK"
},
"source": [
"#@title Define functions\n",
"from ipywidgets import fixed\n",
"\n",
"# Taken from https://github.com/alexanderkuk/log-progress\n",
"def log_progress(sequence, every=1, size=None, name='Items'):\n",
" from ipywidgets import IntProgress, HTML, VBox\n",
" from IPython.display import display\n",
"\n",
" is_iterator = False\n",
" if size is None:\n",
" try:\n",
" size = len(sequence)\n",
" except TypeError:\n",
" is_iterator = True\n",
" if size is not None:\n",
" if every is None:\n",
" if size <= 200:\n",
" every = 1\n",
" else:\n",
" every = int(size / 200) # every 0.5%\n",
" else:\n",
" assert every is not None, 'sequence is iterator, set every'\n",
"\n",
" if is_iterator:\n",
" progress = IntProgress(min=0, max=1, value=1)\n",
" progress.bar_style = 'info'\n",
" else:\n",
" progress = IntProgress(min=0, max=size, value=0)\n",
" label = HTML()\n",
" box = VBox(children=[label, progress])\n",
" display(box)\n",
"\n",
" index = 0\n",
" try:\n",
" for index, record in enumerate(sequence, 1):\n",
" if index == 1 or index % every == 0:\n",
" if is_iterator:\n",
" label.value = '{name}: {index} / ?'.format(\n",
" name=name,\n",
" index=index\n",
" )\n",
" else:\n",
" progress.value = index\n",
" label.value = u'{name}: {index} / {size}'.format(\n",
" name=name,\n",
" index=index,\n",
" size=size\n",
" )\n",
" yield record\n",
" except:\n",
" progress.bar_style = 'danger'\n",
" raise\n",
" else:\n",
" progress.bar_style = 'success'\n",
" progress.value = index\n",
" label.value = \"{name}: {index}\".format(\n",
" name=name,\n",
" index=str(index or '?')\n",
" )\n",
"\n",
"def name_direction(sender):\n",
" if not text.value:\n",
" print('Please name the direction before saving')\n",
" return\n",
" \n",
" if num in named_directions.values():\n",
" target_key = list(named_directions.keys())[list(named_directions.values()).index(num)]\n",
" print(f'Direction already named: {target_key}')\n",
" print(f'Overwriting... ')\n",
" del(named_directions[target_key])\n",
" named_directions[text.value] = [num, start_layer.value, end_layer.value]\n",
" save_direction(random_dir, text.value)\n",
" for item in named_directions:\n",
" print(item, named_directions[item])\n",
"\n",
"def save_direction(direction, filename):\n",
" filename += \".npy\"\n",
" np.save(filename, direction, allow_pickle=True, fix_imports=True)\n",
" print(f'Latent direction saved as {filename}')\n",
"\n",
"def mix_w(w1, w2, content, style):\n",
" for i in range(0,5):\n",
" w2[i] = w1[i] * (1 - content) + w2[i] * content\n",
"\n",
" for i in range(5, 16):\n",
" w2[i] = w1[i] * (1 - style) + w2[i] * style\n",
" \n",
" return w2\n",
"\n",
"def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):\n",
" # blockPrint()\n",
" model.truncation = truncation\n",
" if w is None:\n",
" w = model.sample_latent(1, seed=seed).detach().cpu().numpy()\n",
" w = [w]*model.get_max_latents() # one per layer\n",
" else:\n",
" w = [np.expand_dims(x, 0) for x in w]\n",
" \n",
" for l in range(start, end):\n",
" for i in range(len(directions)):\n",
" w[l] = w[l] + directions[i] * distances[i] * scale\n",
" \n",
" torch.cuda.empty_cache()\n",
" #save image and display\n",
" out = model.sample_np(w)\n",
" final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((500,500),Image.LANCZOS)\n",
" \n",
" \n",
" if save is not None:\n",
" if disp == False:\n",
" print(save)\n",
" final_im.save(f'out/{seed}_{save:05}.png')\n",
" if disp:\n",
" display(final_im)\n",
" \n",
" return final_im\n",
"\n",
"def generate_mov(seed, truncation, direction_vec, scale, layers, n_frames, out_name = 'out', noise_spec = None, loop=True):\n",
" \"\"\"Generates a mov moving back and forth along the chosen direction vector\"\"\"\n",
" # Example of reading a generated set of images, and storing as MP4.\n",
" %mkdir out\n",
" movieName = f'out/{out_name}.mp4'\n",
" offset = -10\n",
" step = 20 / n_frames\n",
" imgs = []\n",
" for i in log_progress(range(n_frames), name = \"Generating frames\"):\n",
" print(f'\\r{i} / {n_frames}', end='')\n",
" w = model.sample_latent(1, seed=seed).cpu().numpy()\n",
"\n",
" model.truncation = truncation\n",
" w = [w]*model.get_max_latents() # one per layer\n",
" for l in layers:\n",
" if l <= model.get_max_latents():\n",
" w[l] = w[l] + direction_vec * offset * scale\n",
"\n",
" #save image and display\n",
" out = model.sample_np(w)\n",
" final_im = Image.fromarray((out * 255).astype(np.uint8))\n",
" imgs.append(out)\n",
" #increase offset\n",
" offset += step\n",
" if loop:\n",
" imgs += imgs[::-1]\n",
" with imageio.get_writer(movieName, mode='I') as writer:\n",
" for image in log_progress(list(imgs), name = \"Creating animation\"):\n",
" writer.append_data(img_as_ubyte(image))"
],
"execution_count": 4,
"outputs": []
},
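{
"cell_type": "code",
"metadata": {},
"source": [
"# (Optional) Example: animate a latent direction.\n",
"# A minimal usage sketch, not part of the original demo: it sweeps the first latent\n",
"# direction (the 'Sleeve & Size' slider in the demo UI below) across a few middle\n",
"# layers and writes a short MP4 with the generate_mov helper defined above. The seed,\n",
"# layer range, and frame count are arbitrary; MP4 output assumes imageio's ffmpeg\n",
"# backend is available.\n",
"generate_mov(seed=0, truncation=0.5, direction_vec=latent_dirs[0], scale=1,\n",
"             layers=range(3, 8), n_frames=20, out_name='direction_0')"
],
"execution_count": null,
"outputs": []
},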
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 640
},
"cellView": "form",
"id": "jneXxZnNwHo5",
"outputId": "c8e2b76e-3a00-47f5-ba2d-51606f09ee93"
},
"source": [
"#@title Demo UI\n",
"import gradio as gr\n",
"import numpy as np\n",
"\n",
"def generate_image(seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer):\n",
" seed1 = int(seed1)\n",
" seed2 = int(seed2)\n",
"\n",
" scale = 1\n",
" params = {'c0': c0,\n",
" 'c1': c1,\n",
" 'c2': c2,\n",
" 'c3': c3,\n",
" 'c4': c4,\n",
" 'c5': c5,\n",
" 'c6': c6}\n",
"\n",
" param_indexes = {'c0': 0,\n",
" 'c1': 1,\n",
" 'c2': 2,\n",
" 'c3': 3,\n",
" 'c4': 4,\n",
" 'c5': 5,\n",
" 'c6': 6}\n",
"\n",
" directions = []\n",
" distances = []\n",
" for k, v in params.items():\n",
" directions.append(latent_dirs[param_indexes[k]])\n",
" distances.append(v)\n",
"\n",
" w1 = model.sample_latent(1, seed=seed1).detach().cpu().numpy()\n",
" w1 = [w1]*model.get_max_latents() # one per layer\n",
" im1 = model.sample_np(w1)\n",
"\n",
" w2 = model.sample_latent(1, seed=seed2).detach().cpu().numpy()\n",
" w2 = [w2]*model.get_max_latents() # one per layer\n",
" im2 = model.sample_np(w2)\n",
" combined_im = np.concatenate([im1, im2], axis=1)\n",
" input_im = Image.fromarray((combined_im * 255).astype(np.uint8))\n",
" \n",
"\n",
" mixed_w = mix_w(w1, w2, content, style)\n",
" return input_im, display_sample_pytorch(seed1, truncation, directions, distances, scale, int(start_layer), int(end_layer), w=mixed_w, disp=False)\n",
"\n",
"truncation = gr.inputs.Slider(minimum=0, maximum=1, default=0.5, label=\"Truncation\")\n",
"start_layer = gr.inputs.Number(default=0, label=\"Start Layer\")\n",
"end_layer = gr.inputs.Number(default=14, label=\"End Layer\")\n",
"seed1 = gr.inputs.Number(default=0, label=\"Seed 1\")\n",
"seed2 = gr.inputs.Number(default=0, label=\"Seed 2\")\n",
"content = gr.inputs.Slider(label=\"Structure\", minimum=0, maximum=1, default=0.5)\n",
"style = gr.inputs.Slider(label=\"Style\", minimum=0, maximum=1, default=0.5)\n",
"\n",
"slider_max_val = 20\n",
"slider_min_val = -20\n",
"slider_step = 1\n",
"\n",
"c0 = gr.inputs.Slider(label=\"Sleeve & Size\", minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
"c1 = gr.inputs.Slider(label=\"Dress - Jacket\", minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
"c2 = gr.inputs.Slider(label=\"Female Coat\", minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
"c3 = gr.inputs.Slider(label=\"Coat\", minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
"c4 = gr.inputs.Slider(label=\"Graphics\", minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
"c5 = gr.inputs.Slider(label=\"Dark\", minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
"c6 = gr.inputs.Slider(label=\"Less Cleavage\", minimum=slider_min_val, maximum=slider_max_val, default=0)\n",
"\n",
"\n",
"scale = 1\n",
"\n",
"inputs = [seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer]\n",
"\n",
"gr.Interface(generate_image, inputs, [\"image\", \"image\"], live=True, title=\"ClothingGAN\").launch()"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
"This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)\n",
"Running on External URL: https://10342.gradio.app\n",
"Interface loading below...\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(,\n",
" 'http://127.0.0.1:7860/',\n",
" 'https://10342.gradio.app')"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
}
]
}