{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "# Comparison to GAN steerability and InterfaceGAN:\n",
    "# for each edit, frames along our direction ('ours', a component returned by\n",
    "# get_or_compute) are shown and saved next to frames along a supervised\n",
    "# direction ('supervised': GAN steerability / InterfaceGAN).\n",
    "%matplotlib inline\n",
    "# NOTE(review): the star import is expected to provide np, plt, Path, makedirs, Image,\n",
    "# pad_frames, get_instrumented_model, Config, get_or_compute and device - TODO confirm.\n",
    "from notebook_init import *\n",
    "import pickle\n",
    "\n",
    "out_root = Path('out/figures/steerability_comp')\n",
    "makedirs(out_root, exist_ok=True)\n",
    "rand = lambda : np.random.randint(np.iinfo(np.int32).max)\n",
    "\n",
    "# Instrumented-model handle passed as get_instrumented_model(..., inst=inst) in every\n",
    "# model cell; defined here so the first model cell does not rely on hidden kernel state.\n",
    "inst = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_strip(frames):\n",
    "    \"\"\"Show `frames` side by side (padded with pad_frames) in one large figure.\"\"\"\n",
    "    plt.figure(figsize=(20,20))\n",
    "    plt.axis('off')\n",
    "    plt.imshow(np.hstack(pad_frames(frames, 64)))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(t):\n",
    "    \"\"\"Return array `t` divided by its L2 norm taken over all elements.\"\"\"\n",
    "    norm = np.sqrt(np.sum(t.reshape(-1)**2))\n",
    "    if norm == 0:\n",
    "        # Fail loudly instead of rendering NaN/inf frames\n",
    "        raise ValueError('normalize: direction has zero norm')\n",
    "    return t / norm\n",
    "\n",
    "def compute(\n",
    "    model,\n",
    "    lat_mean,\n",
    "    prefix,\n",
    "    imgclass,\n",
    "    seeds,\n",
    "    d_ours,\n",
    "    l_start,\n",
    "    l_end,\n",
    "    scale_ours,\n",
    "    d_sup, # single or one per layer\n",
    "    scale_sup,\n",
    "    center=True,\n",
    "    n_frames=5,\n",
    "):\n",
    "    \"\"\"Render, save and show edit strips along 'ours' vs. 'supervised' directions.\n",
    "\n",
    "    For every seed, one base latent is shifted along `d_ours` (on layers\n",
    "    [l_start, l_end)) and along `d_sup` (on all layers) with strengths\n",
    "    a * scale for `n_frames` values of a evenly spaced in [-1, 1]. Each frame\n",
    "    is saved as out_root/imgclass/{prefix}_{name}_{seed}_{i}.png, with name\n",
    "    'ours' or 'supervised', and each strip is displayed inline.\n",
    "\n",
    "    Args:\n",
    "        model: generator wrapper; set_output_class, sample_latent, sample_np\n",
    "            and get_max_latents are used.\n",
    "        lat_mean: mean latent, only used when centering.\n",
    "        prefix: edit name used in filenames and titles (e.g. 'zoom').\n",
    "        imgclass: output class; also the output subdirectory.\n",
    "        seeds: iterable of seeds passed to model.sample_latent (printed as progress).\n",
    "        d_ours: our direction. A single-row array is normalized; a multi-row\n",
    "            array is used row-per-layer without normalization.\n",
    "        l_start, l_end: half-open layer range that `d_ours` is applied to.\n",
    "        scale_ours: maximum edit magnitude for `d_ours` (the sign flips it).\n",
    "        d_sup: supervised direction, single row or one row per layer.\n",
    "        scale_sup: maximum edit magnitude for `d_sup`.\n",
    "        center: if True, remove the component of (base latent - lat_mean)\n",
    "            along normalize(d_sup). Requires a single-row `d_sup`.\n",
    "        n_frames: number of frames per strip.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `center` is True and `d_sup` has more than one row.\n",
    "    \"\"\"\n",
    "    if center and d_sup.shape[0] > 1:\n",
    "        # Centering projects onto one direction; a per-layer d_sup would\n",
    "        # broadcast lat_base to one row per layer and corrupt every frame.\n",
    "        raise ValueError('center=True requires a single-row d_sup')\n",
    "\n",
    "    model.set_output_class(imgclass)\n",
    "    makedirs(out_root / imgclass, exist_ok=True)\n",
    "\n",
    "    n_layers = model.get_max_latents()\n",
    "    runs = [\n",
    "        # (name, direction, scale, layer range)\n",
    "        ('ours', d_ours, scale_ours, (l_start, l_end)),\n",
    "        ('supervised', d_sup, scale_sup, (0, n_layers)),\n",
    "    ]\n",
    "\n",
    "    for seed in seeds:\n",
    "        print(seed)\n",
    "\n",
    "        # Both runs use the same seed, so sample (and center) the base latent once\n",
    "        lat_base = model.sample_latent(1, seed=seed).cpu().numpy()\n",
    "        if center:\n",
    "            # Shift latent to lie on mean along given direction\n",
    "            y = normalize(d_sup) # assume ground truth\n",
    "            dotp = np.sum((lat_base - lat_mean) * y, axis=-1, keepdims=True)\n",
    "            lat_base = lat_base - dotp * y\n",
    "\n",
    "        for name, delta, scale, (l_lo, l_hi) in runs:\n",
    "            # Convert single delta to per-layer delta (to support Steerability StyleGAN)\n",
    "            if delta.shape[0] > 1:\n",
    "                d_per_layer = list(delta) # might have per-layer scales, don't normalize\n",
    "            else:\n",
    "                d_per_layer = [normalize(delta)] * n_layers\n",
    "\n",
    "            frames = []\n",
    "            for a in np.linspace(-1.0, 1.0, n_frames):\n",
    "                w = [lat_base] * n_layers\n",
    "                for l in range(l_lo, l_hi):\n",
    "                    w[l] = w[l] + a * d_per_layer[l] * scale\n",
    "                frames.append(model.sample_np(w))\n",
    "\n",
    "            for i, frame in enumerate(frames):\n",
    "                # Clip before the uint8 cast so out-of-range values cannot wrap around\n",
    "                Image.fromarray(np.uint8(np.clip(frame, 0.0, 1.0) * 255)).save(\n",
    "                    out_root / imgclass / f'{prefix}_{name}_{seed}_{i}.png')\n",
    "\n",
    "            strip = np.hstack(pad_frames(frames, 64))\n",
    "            plt.figure(figsize=(12,12))\n",
    "            plt.imshow(strip)\n",
    "            plt.axis('off')\n",
    "            # Set the title before tight_layout so the layout makes room for it\n",
    "            plt.title(f'{prefix} - {name}, scale={scale}')\n",
    "            plt.tight_layout()\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# BigGAN-512\n",
    "\n",
    "inst = get_instrumented_model('BigGAN-512', 'husky', 'generator.gen_z', device, inst=inst)\n",
    "model = inst.model\n",
    "\n",
    "K = model.get_max_latents()\n",
    "pc_config = Config(components=128, n=1_000_000,\n",
    "    layer='generator.gen_z', model='BigGAN-512', output_class='husky')\n",
    "dump_name = get_or_compute(pc_config, inst)\n",
    "\n",
    "with np.load(dump_name) as data:\n",
    "    lat_comp = data['lat_comp']\n",
    "    lat_mean = data['lat_mean']\n",
    "\n",
    "# Supervised GAN steerability directions. pickle.load can execute arbitrary code:\n",
    "# only load these trusted, locally provided files.\n",
    "with open('data/steerability/biggan_deep_512/gan_steer-linear_zoom_512.pkl', 'rb') as f:\n",
    "    delta_steerability_zoom = pickle.load(f)['w_zoom'].reshape(1, 128)\n",
    "with open('data/steerability/biggan_deep_512/gan_steer-linear_shiftx_512.pkl', 'rb') as f:\n",
    "    delta_steerability_transl = pickle.load(f)['w_shiftx'].reshape(1, 128)\n",
    "\n",
    "# Indices determined by visual inspection\n",
    "delta_ours_transl = lat_comp[0]\n",
    "delta_ours_zoom = lat_comp[6]\n",
    "\n",
    "model.truncation = 0.6\n",
    "compute(model, lat_mean, 'zoom', 'robin', [560157313], delta_ours_zoom, 0, K, -3.0, delta_steerability_zoom, 5.5)\n",
    "compute(model, lat_mean, 'zoom', 'ship', [107715983], delta_ours_zoom, 0, K, -3.0, delta_steerability_zoom, 5.0)\n",
    "\n",
    "compute(model, lat_mean, 'translate', 'golden_retriever', [552411435], delta_ours_transl, 0, K, -2.0, delta_steerability_transl, 4.5)\n",
    "compute(model, lat_mean, 'translate', 'lemon', [331582800], delta_ours_transl, 0, K, -3.0, delta_steerability_transl, 6.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# StyleGAN1-ffhq (InterfaceGAN)\n",
    "\n",
    "inst = get_instrumented_model('StyleGAN', 'ffhq', 'g_mapping', device, use_w=True, inst=inst)\n",
    "model = inst.model\n",
    "\n",
    "K = model.get_max_latents()\n",
    "pc_config = Config(components=128, n=1_000_000, use_w=True,\n",
    "    layer='g_mapping', model='StyleGAN', output_class='ffhq')\n",
    "dump_name = get_or_compute(pc_config, inst)\n",
    "\n",
    "with np.load(dump_name) as data:\n",
    "    lat_comp = data['lat_comp']\n",
    "    lat_mean = data['lat_mean']\n",
    "\n",
    "# SG-ffhq-w, non-conditional\n",
    "d_ffhq_pose = np.load('data/interfacegan/stylegan_ffhq_pose_w_boundary.npy').astype(np.float32)\n",
    "d_ffhq_smile = np.load('data/interfacegan/stylegan_ffhq_smile_w_boundary.npy').astype(np.float32)\n",
    "d_ffhq_gender = np.load('data/interfacegan/stylegan_ffhq_gender_w_boundary.npy').astype(np.float32)\n",
    "d_ffhq_glasses = np.load('data/interfacegan/stylegan_ffhq_eyeglasses_w_boundary.npy').astype(np.float32)\n",
    "\n",
    "# Indices determined by visual inspection\n",
    "d_ours_pose = lat_comp[9]\n",
    "d_ours_smile = lat_comp[44]\n",
    "d_ours_gender = lat_comp[0]\n",
    "d_ours_glasses = lat_comp[12]\n",
    "\n",
    "model.truncation = 1.0 # NOT IMPLEMENTED\n",
    "compute(model, lat_mean, 'pose', 'ffhq', [440608316, 1811098088, 129888612], d_ours_pose, 0, 7, -1.0, d_ffhq_pose, 1.0)\n",
    "compute(model, lat_mean, 'smile', 'ffhq', [1759734403, 1647189561, 70163682], d_ours_smile, 3, 4, -8.5, d_ffhq_smile, 1.0)\n",
    "compute(model, lat_mean, 'gender', 'ffhq', [1302836080, 1746672325], d_ours_gender, 2, 6, -4.5, d_ffhq_gender, 1.5)\n",
    "compute(model, lat_mean, 'glasses', 'ffhq', [1565213752, 1005764659, 1110182583], d_ours_glasses, 0, 2, 4.0, d_ffhq_glasses, 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# StyleGAN1-ffhq (Steerability)\n",
    "\n",
    "inst = get_instrumented_model('StyleGAN', 'ffhq', 'g_mapping', device, use_w=True, inst=inst)\n",
    "model = inst.model\n",
    "\n",
    "K = model.get_max_latents()\n",
    "pc_config = Config(components=128, n=1_000_000, use_w=True,\n",
    "    layer='g_mapping', model='StyleGAN', output_class='ffhq')\n",
    "dump_name = get_or_compute(pc_config, inst)\n",
    "\n",
    "with np.load(dump_name) as data:\n",
    "    lat_comp = data['lat_comp']\n",
    "    lat_mean = data['lat_mean']\n",
    "\n",
    "# SG-ffhq-w, non-conditional\n",
    "# Shapes: [18, 512]\n",
    "d_ffhq_R = np.load('data/steerability/stylegan_ffhq/ffhq_rgb_0.npy').astype(np.float32)\n",
    "d_ffhq_G = np.load('data/steerability/stylegan_ffhq/ffhq_rgb_1.npy').astype(np.float32)\n",
    "d_ffhq_B = np.load('data/steerability/stylegan_ffhq/ffhq_rgb_2.npy').astype(np.float32)\n",
    "\n",
    "# Indices determined by visual inspection\n",
    "d_ours_R = lat_comp[0]\n",
    "d_ours_G = -lat_comp[1]\n",
    "d_ours_B = -lat_comp[2]\n",
    "\n",
    "model.truncation = 1.0 # NOT IMPLEMENTED\n",
    "compute(model, lat_mean, 'red', 'ffhq', [5], d_ours_R, 17, 18, 8.0, d_ffhq_R, 1.0, center=False)\n",
    "compute(model, lat_mean, 'green', 'ffhq', [5], d_ours_G, 17, 18, 15.0, d_ffhq_G, 1.0, center=False)\n",
    "compute(model, lat_mean, 'blue', 'ffhq', [5], d_ours_B, 17, 18, 10.0, d_ffhq_B, 1.0, center=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# StyleGAN1-celebahq (InterfaceGAN)\n",
    "\n",
    "inst = get_instrumented_model('StyleGAN', 'celebahq', 'g_mapping', device, use_w=True, inst=inst)\n",
    "model = inst.model\n",
    "\n",
    "K = model.get_max_latents()\n",
    "pc_config = Config(components=128, n=1_000_000, use_w=True,\n",
    "    layer='g_mapping', model='StyleGAN', output_class='celebahq')\n",
    "dump_name = get_or_compute(pc_config, inst)\n",
    "\n",
    "with np.load(dump_name) as data:\n",
    "    lat_comp = data['lat_comp']\n",
    "    lat_mean = data['lat_mean']\n",
    "\n",
    "# SG-celebahq-w, non-conditional\n",
    "d_celebahq_pose = np.load('data/interfacegan/stylegan_celebahq_pose_w_boundary.npy').astype(np.float32)\n",
    "d_celebahq_smile = np.load('data/interfacegan/stylegan_celebahq_smile_w_boundary.npy').astype(np.float32)\n",
    "d_celebahq_gender = np.load('data/interfacegan/stylegan_celebahq_gender_w_boundary.npy').astype(np.float32)\n",
    "d_celebahq_glasses = np.load('data/interfacegan/stylegan_celebahq_eyeglasses_w_boundary.npy').astype(np.float32)\n",
    "\n",
    "# Indices determined by visual inspection\n",
    "d_ours_pose = lat_comp[7]\n",
    "d_ours_smile = lat_comp[14]\n",
    "d_ours_gender = lat_comp[1]\n",
    "d_ours_glasses = lat_comp[5]\n",
    "\n",
    "model.truncation = 1.0 # NOT IMPLEMENTED\n",
    "compute(model, lat_mean, 'pose', 'celebahq', [1939067252, 1460055449, 329555154], d_ours_pose, 0, 7, -1.0, d_celebahq_pose, 1.0)\n",
    "compute(model, lat_mean, 'smile', 'celebahq', [329187806, 424805522, 1777796971], d_ours_smile, 3, 4, -7.0, d_celebahq_smile, 1.3)\n",
    "compute(model, lat_mean, 'gender', 'celebahq', [1144615644, 967075839, 264878205], d_ours_gender, 0, 2, -3.2, d_celebahq_gender, 1.2)\n",
    "compute(model, lat_mean, 'glasses', 'celebahq', [991993380, 594344173, 2119328990, 1919124025], d_ours_glasses, 0, 1, -10.0, d_celebahq_glasses, 2.0) # hard for both"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# StyleGAN1-cars (Steerability)\n",
    "\n",
    "inst = get_instrumented_model('StyleGAN', 'cars', 'g_mapping', device, use_w=True, inst=inst)\n",
    "model = inst.model\n",
    "\n",
    "K = model.get_max_latents()\n",
    "pc_config = Config(components=128, n=1_000_000, use_w=True,\n",
    "    layer='g_mapping', model='StyleGAN', output_class='cars')\n",
    "dump_name = get_or_compute(pc_config, inst)\n",
    "\n",
    "with np.load(dump_name) as data:\n",
    "    lat_comp = data['lat_comp']\n",
    "    lat_mean = data['lat_mean']\n",
    "\n",
    "# Shapes: [16, 512]\n",
    "d_cars_rot = np.load('data/steerability/stylegan_cars/rotate2d.npy').astype(np.float32)\n",
    "d_cars_shift = np.load('data/steerability/stylegan_cars/shifty.npy').astype(np.float32)\n",
    "\n",
    "# Add two final layers\n",
    "d_cars_rot = np.append(d_cars_rot, np.zeros((2,512), dtype=np.float32), axis=0)\n",
    "d_cars_shift = np.append(d_cars_shift, np.zeros((2,512), dtype=np.float32), axis=0)\n",
    "\n",
    "# compute() applies the supervised edit row-per-layer on layers [0, K)\n",
    "assert d_cars_rot.shape[0] >= K and d_cars_shift.shape[0] >= K, (d_cars_rot.shape, d_cars_shift.shape, K)\n",
    "\n",
    "# Indices determined by visual inspection\n",
    "d_ours_rot = lat_comp[0]\n",
    "d_ours_shift = lat_comp[7]\n",
    "\n",
    "model.truncation = 1.0 # NOT IMPLEMENTED\n",
    "compute(model, lat_mean, 'rotate2d', 'cars', [46, 28], d_ours_rot, 0, 1, 1.0, d_cars_rot, 1.0, center=False)\n",
    "compute(model, lat_mean, 'shifty', 'cars', [0, 13], d_ours_shift, 1, 2, 4.0, d_cars_shift, 1.0, center=False)"
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 2
}