{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Style mixing with BigGAN-512\n",
    "\n",
    "Recreates the StyleGAN1 style-mixing image grid with the class-conditional BigGAN-512 model. For a chosen layer `generator.layers.{l}`, the activations produced from image A's latent are captured and substituted while sampling from image B's latent. All grids and comparison strips are written to `out/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "# Display helpers for showing reference images inline, e.g.\n",
    "# IPyImage('style_mixing.png') for the original StyleGAN1 style mixing grid\n",
    "from IPython.display import HTML, Image as IPyImage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# notebook_init is expected to provide torch, np, plt, Image, makedirs, device\n",
    "# and get_instrumented_model, all of which are used below.\n",
    "from notebook_init import *\n",
    "\n",
    "layer_names = [f'generator.layers.{i}' for i in range(14)] # annotate all shapes\n",
    "inst = get_instrumented_model('BigGAN-512', 'promontory', layer_names, device)\n",
    "model = inst.model\n",
    "inst.close()\n",
    "\n",
    "# Seed both RNGs so the random seeds drawn in generate() are reproducible\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "makedirs('out', exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Style-mixing grids\n",
    "\n",
    "Row 0 shows the B images. Every further row starts with image A and mixes at one of the layers in `layers`. The commented-out calls keep seeds for further classes; uncomment a line to render that class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_uint8(img):\n",
    "    \"\"\"Convert a float image expected in [0, 1] to uint8.\n",
    "\n",
    "    Values are clipped first so anything marginally outside [0, 1]\n",
    "    saturates instead of wrapping around in the uint8 cast.\n",
    "    \"\"\"\n",
    "    return (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)\n",
    "\n",
    "\n",
    "def generate(trunc, cls, custom_seeds=(), layers=(0, 2, 4), N=5):\n",
    "    \"\"\"Build, save and display a style-mixing grid for one output class.\n",
    "\n",
    "    Row 0 is a blank tile followed by the B images (seeds[1:]). Each further\n",
    "    row belongs to one entry of `layers`: it starts with image A (seeds[0])\n",
    "    and then shows, for every B, a sample from B's latent with the output of\n",
    "    `generator.layers.{l}` replaced by A's activations at that layer.\n",
    "    The grid is saved to out/grid_{cls}.png and shown with matplotlib.\n",
    "\n",
    "    Args:\n",
    "        trunc: truncation passed to model.sample_latent.\n",
    "        cls: class name passed to model.set_output_class.\n",
    "        custom_seeds: seeds for the first images (at most N are used);\n",
    "            remaining slots get random seeds from np.random.\n",
    "        layers: generator layer indices to mix at, one grid row each.\n",
    "        N: total number of latents (one A plus N-1 B images).\n",
    "    \"\"\"\n",
    "    inst.remove_edits()\n",
    "    model.set_output_class(cls)\n",
    "\n",
    "    custom_seeds = list(custom_seeds)[:N]  # limit to N images\n",
    "    seeds = np.random.randint(np.iinfo(np.int32).max, size=N)\n",
    "    seeds[:len(custom_seeds)] = custom_seeds\n",
    "    print(seeds, trunc, cls)\n",
    "\n",
    "    latents = [model.sample_latent(1, truncation=trunc, seed=s) for s in seeds]\n",
    "    latent_a = latents[0]\n",
    "\n",
    "    outputs = [model.sample_np(z) for z in latents]\n",
    "    out_a = outputs[0]  # reuse instead of sampling latent A a second time\n",
    "    empty = np.ones_like(out_a)\n",
    "\n",
    "    # Inputs B\n",
    "    rows = [np.hstack([empty] + outputs[1:])]\n",
    "\n",
    "    # Mix style starting from layer l\n",
    "    for layer_num in layers:\n",
    "        layer_name = f'generator.layers.{layer_num}'\n",
    "        inst.close()\n",
    "        # Drop the previous row's edit: if it stayed active, a replacement at\n",
    "        # a later layer (e.g. when `layers` is not ascending) would override\n",
    "        # this row's mix.\n",
    "        inst.remove_edits()\n",
    "        inst.retain_layer(layer_name)\n",
    "\n",
    "        model.partial_forward(latent_a, layer_name)\n",
    "        feat_a = inst.retained_features()[layer_name].detach()\n",
    "\n",
    "        # Use latent of B, early activations of A. The edit is identical for\n",
    "        # every B, so it is installed once per row.\n",
    "        inst.edit_layer(layer_name, ablation=1.0, replacement=feat_a)\n",
    "        hybrids = [model.sample_np(z) for z in latents[1:]]\n",
    "        rows.append(np.hstack([out_a] + hybrids))\n",
    "\n",
    "    # Leave the model unedited for later cells\n",
    "    inst.remove_edits()\n",
    "\n",
    "    grid = np.vstack(rows)\n",
    "    Image.fromarray(to_uint8(grid)).save(f'out/grid_{cls}.png')\n",
    "\n",
    "    plt.figure(figsize=(15,15))\n",
    "    plt.imshow(grid)\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#generate(0.95, 'irish_setter', [716257571, 216337755, 602801999, 1027629257])\n",
    "generate(0.95, 'barn', [237774802, 1498010115, 105741908, 857168362, 639216961])\n",
    "#generate(0.95, 'coral_reef')\n",
    "#generate(0.95, 'lighthouse', [1573600108])\n",
    "#generate(0.95, 'seashore', [1891640828, 130794492, 1321047179, 750963629])\n",
    "generate(0.95, 'castle', [995150904, 530702035])\n",
    "#generate(0.95, 'golden_retriever', [])\n",
    "#generate(0.95, 'goldfinch', [])\n",
    "#generate(0.95, 'indigo_bunting', [1624898412])\n",
    "#generate(0.95, 'red_wine', [])\n",
    "#generate(0.95, 'anemone_fish', [11610217])\n",
    "#generate(0.95, 'earthstar', [])\n",
    "#generate(0.95, 'beer_bottle', [485603871, 527619953])\n",
    "#generate(0.8, 'beer_glass', [])\n",
    "#generate(0.95, 'church', [628962584, 1700971930]) # , 371570218, 1137007398, 1412786664"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer-wise mixing vs. latent interpolation\n",
    "\n",
    "For one content/style seed pair, compare substituting the content activations at each of the first generator layers with plain linear interpolation between the two latents (which does not keep geometry consistent)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show every layer for given content and style pair\n",
    "def blend(cls, seed1, seed2, num_layers=6, num_interp=8):\n",
    "    \"\"\"Save layer-wise mixing and latent-interpolation strips for a seed pair.\n",
    "\n",
    "    Writes two horizontal strips for class `cls`:\n",
    "      out/{cls}_style_layer_comp.png  - the seed1 image, then samples from\n",
    "          z2 with the output of `generator.layers.{l}` replaced by z1's\n",
    "          activations for l = num_layers-1 down to 0, then the seed2 image.\n",
    "      out/{cls}_style_latent_comp.png - num_interp samples on the straight\n",
    "          line from z1 to z2 (does not keep geometry consistent).\n",
    "\n",
    "    Args:\n",
    "        cls: class name passed to model.set_output_class.\n",
    "        seed1: seed of the content latent z1.\n",
    "        seed2: seed of the style latent z2.\n",
    "        num_layers: mix at generator layers 0 .. num_layers-1 (default 6).\n",
    "        num_interp: number of interpolation steps (default 8).\n",
    "    \"\"\"\n",
    "    inst.remove_edits()\n",
    "    model.set_output_class(cls)\n",
    "    z1 = model.sample_latent(seed=seed1)\n",
    "    z2 = model.sample_latent(seed=seed2)\n",
    "\n",
    "    out1 = model.sample_np(z1)\n",
    "    out2 = model.sample_np(z2)\n",
    "\n",
    "    intermed = []\n",
    "    for layer in range(num_layers):\n",
    "        inst.close()\n",
    "        inst.remove_edits()\n",
    "        layer_name = f'generator.layers.{layer}'\n",
    "        inst.retain_layer(layer_name)\n",
    "\n",
    "        # Content features up to layer\n",
    "        model.partial_forward(z1, layer_name)\n",
    "        feat = inst.retained_features()[layer_name].detach()\n",
    "\n",
    "        # New style\n",
    "        inst.edit_layer(layer_name, ablation=1.0, replacement=feat)\n",
    "        intermed.append(model.sample_np(z2))\n",
    "\n",
    "    # Reverse so the mix that keeps z1's activations up to the deepest layer\n",
    "    # sits next to out1 and the shallowest mix sits next to out2\n",
    "    strip = np.hstack([out1] + intermed[::-1] + [out2])\n",
    "    Image.fromarray(to_uint8(strip)).save(f'out/{cls}_style_layer_comp.png')\n",
    "\n",
    "    # Style blending by latent interpolation (does not keep geometry consistent)\n",
    "    inst.remove_edits()\n",
    "    frames = []\n",
    "    for a in np.linspace(0.0, 1.0, num_interp):\n",
    "        z = a*z2 + (1-a)*z1  # a=0 -> z1, a=1 -> z2\n",
    "        frames.append(model.sample_np(z))\n",
    "\n",
    "    interp_strip = np.hstack(frames)\n",
    "    Image.fromarray(to_uint8(interp_strip)).save(f'out/{cls}_style_latent_comp.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blend('castle', 995150904, 1171165061)\n",
    "blend('church', 628962584, 1700971930)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}