{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "# Figure: BigGAN edit transferability between classes\n",
    "%matplotlib inline\n",
    "from notebook_init import *\n",
    "\n",
    "# Draws a random seed in [0, int32 max); swap it in for a fixed seed below to explore new samples\n",
    "rand = lambda : np.random.randint(np.iinfo(np.int32).max)\n",
    "outdir = Path('out/figures/edit_transferability')\n",
    "makedirs(outdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# NOTE(review): `inst` is read here before this notebook assigns it, so it must\n",
    "# already exist (presumably provided by notebook_init) - confirm.\n",
    "inst = get_instrumented_model('BigGAN-512', 'husky', 'generator.gen_z', device, inst=inst)\n",
    "model = inst.model\n",
    "model.truncation = 0.7\n",
    "\n",
    "# Principal components are taken from the husky class only;\n",
    "# the edits below are then re-applied to other output classes.\n",
    "pc_config = Config(components=80, n=1_000_000,\n",
    "                   layer='generator.gen_z', model='BigGAN-512', output_class='husky')\n",
    "dump_name = get_or_compute(pc_config, inst)\n",
    "\n",
    "with np.load(dump_name) as data:\n",
    "    lat_comp = data['lat_comp']\n",
    "    lat_mean = data['lat_mean']\n",
    "    lat_std = data['lat_stdev']\n",
    "\n",
    "# name: component_idx, layer_start, layer_end, strength\n",
    "# (layer_end is exclusive, see apply_offset)\n",
    "edits = {\n",
    "    'translate_x': ( 0, 0, 15, -3.0),\n",
    "    'zoom':        ( 6, 0, 15,  2.0),\n",
    "    'clouds':      (54, 7, 10, 15.0),\n",
    "    #'dark_fg':    (51, 7, 10, 20.0),\n",
    "    'sunlight':    (33, 7, 10, 25.0),\n",
    "    #'silhouette': (13, 7, 10, -20.0),\n",
    "    #'grass_bg':   (69, 3, 7, -20.0),\n",
    "}\n",
    "\n",
    "def apply_offset(z, idx, start, end, sigma):\n",
    "    \"\"\"Shift latent `z` along principal component `idx` on layers [start, end).\n",
    "\n",
    "    z:     a single latent (repeated model.get_max_latents() times) or a per-layer list\n",
    "    idx:   index into lat_comp / lat_std\n",
    "    start: first layer to edit (inclusive)\n",
    "    end:   last layer to edit (exclusive)\n",
    "    sigma: edit strength, multiplied by lat_std[idx]\n",
    "\n",
    "    Returns a new per-layer list; a list passed in as `z` is left untouched.\n",
    "    \"\"\"\n",
    "    # Copy list inputs so repeated edits of the same latents do not accumulate\n",
    "    lat = list(z) if isinstance(z, list) else [z]*model.get_max_latents()\n",
    "    offset = lat_comp[idx]*lat_std[idx]*sigma\n",
    "    for i in range(start, end):\n",
    "        lat[i] = lat[i] + offset\n",
    "    return lat\n",
    "\n",
    "def _save_uint8(img, path):\n",
    "    \"\"\"Scale `img` by 255, cast to uint8 and save it to `path` (assumes values in [0, 1]).\"\"\"\n",
    "    Image.fromarray((255*img).astype(np.uint8)).save(path)\n",
    "\n",
    "# True: display the figure inline, False: write the grid and the single frames to outdir\n",
    "show = True\n",
    "\n",
    "# good geom seeds: 2145371585\n",
    "# good style seeds: 337336281, 2075156369, 311784160\n",
    "\n",
    "# Single pass with fixed seeds; the loop only matters when seeds are replaced by rand()\n",
    "for _ in range(1):\n",
    "\n",
    "    # Type 1: geometric edit - transfers well\n",
    "\n",
    "    seed1_geom = 2145371585\n",
    "    seed2_geom = 2046317118\n",
    "    print('Seeds geom:', [seed1_geom, seed2_geom])\n",
    "    z1 = model.sample_latent(1, seed=seed1_geom).cpu().numpy()\n",
    "    z2 = model.sample_latent(1, seed=seed2_geom).cpu().numpy()\n",
    "\n",
    "    model.set_output_class('husky')\n",
    "    base_husky = model.sample_np(z1)\n",
    "    zoom_husky = model.sample_np(apply_offset(z1, *edits['zoom']))\n",
    "    transl_husky = model.sample_np(apply_offset(z1, *edits['translate_x']))\n",
    "    img_geom1 = np.hstack([base_husky, zoom_husky, transl_husky])\n",
    "\n",
    "    model.set_output_class('castle')\n",
    "    base_castle = model.sample_np(z2)\n",
    "    zoom_castle = model.sample_np(apply_offset(z2, *edits['zoom']))\n",
    "    transl_castle = model.sample_np(apply_offset(z2, *edits['translate_x']))\n",
    "    img_geom2 = np.hstack([base_castle, zoom_castle, transl_castle])\n",
    "\n",
    "\n",
    "    # Type 2: style edit - often transfers\n",
    "\n",
    "    seed1_style = 417482011 #rand()\n",
    "    seed2_style = 1026291813\n",
    "    print('Seeds style:', [seed1_style, seed2_style])\n",
    "    z1 = model.sample_latent(1, seed=seed1_style).cpu().numpy()\n",
    "    z2 = model.sample_latent(1, seed=seed2_style).cpu().numpy()\n",
    "\n",
    "    # Note: the lighthouse uses z2 and the barn uses z1\n",
    "    model.set_output_class('lighthouse')\n",
    "    base_lighthouse = model.sample_np(z2)\n",
    "    edit1_lighthouse = model.sample_np(apply_offset(z2, *edits['clouds']))\n",
    "    edit2_lighthouse = model.sample_np(apply_offset(z2, *edits['sunlight']))\n",
    "    img_style2 = np.hstack([base_lighthouse, edit1_lighthouse, edit2_lighthouse])\n",
    "\n",
    "    model.set_output_class('barn')\n",
    "    base_barn = model.sample_np(z1)\n",
    "    edit1_barn = model.sample_np(apply_offset(z1, *edits['clouds']))\n",
    "    edit2_barn = model.sample_np(apply_offset(z1, *edits['sunlight']))\n",
    "    img_style1 = np.hstack([base_barn, edit1_barn, edit2_barn])\n",
    "\n",
    "\n",
    "    # One row per class: [base, edit 1, edit 2]\n",
    "    grid = np.vstack([img_geom1, img_geom2, img_style1, img_style2])\n",
    "\n",
    "    if show:\n",
    "        plt.figure(figsize=(12,12))\n",
    "        plt.imshow(grid)\n",
    "        plt.axis('off')\n",
    "        plt.show()\n",
    "    else:\n",
    "        _save_uint8(grid, outdir / f'{seed1_geom}_{seed2_geom}_{seed1_style}_{seed2_style}_transf.jpg')\n",
    "\n",
    "        # Save individual frames\n",
    "        frames = {\n",
    "            'geom_husky': (base_husky, zoom_husky, transl_husky),\n",
    "            'geom_castle': (base_castle, zoom_castle, transl_castle),\n",
    "            'style_lighthouse': (base_lighthouse, edit1_lighthouse, edit2_lighthouse),\n",
    "            'style_barn': (base_barn, edit1_barn, edit2_barn),\n",
    "        }\n",
    "        for prefix, imgs in frames.items():\n",
    "            for num, img in enumerate(imgs, start=1):\n",
    "                _save_uint8(img, outdir / f'{prefix}_{num}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 2
}