{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Copyright 2020 Erik Härkönen. All rights reserved.\n", "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License. You may obtain a copy\n", "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n", "\n", "# Unless required by applicable law or agreed to in writing, software distributed under\n", "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n", "# OF ANY KIND, either express or implied. See the License for the specific language\n", "# governing permissions and limitations under the License.\n", "\n", "%matplotlib inline\n", "from notebook_init import *\n", "import scipy\n", "\n", "out_root = Path('out/figures/first_20_pcs')\n", "makedirs(out_root, exist_ok=True)\n", "rand = lambda : np.random.randint(np.iinfo(np.int32).max)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def gram_schmidt_cols(V):\n", " Q, R = np.linalg.qr(V)\n", " return Q.T\n", "\n", "def gram_schmidt_rows(V):\n", " Q, R = np.linalg.qr(V.T)\n", " return Q.T\n", "\n", "def make_ortho_normal(N=512):\n", " return scipy.stats.special_ortho_group.rvs(N)\n", "\n", "# Components on rows\n", "def assert_normalized(V):\n", " assert np.allclose(np.diag(np.dot(V1, V1.T)), np.ones(V1.shape[0])), 'Basis not normalized'\n", "\n", "# V = [n_comp, n_dim]\n", "def assert_orthonormal(V):\n", " M = np.dot(V, V.T) # [n_comp, n_comp]\n", " det = np.linalg.det(M)\n", " assert np.allclose(M, np.identity(M.shape[0]), atol=1e-5), f'Basis is not orthonormal (det={det})'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "n_pcs = 20\n", "inst = None\n", "\n", "def generate(model, basis, stds, mean, seeds, name, scale=2.0):\n", " makedirs(out_root / name, exist_ok=True)\n", " for seed in seeds:\n", " print(seed)\n", " \n", " strips = []\n", " \n", " for i in range(n_pcs):\n", " z = model.sample_latent(1, seed=seed)\n", " batch_frames = create_strip_centered(inst, 'latent', 'style', [z],\n", " 0, basis[i], 0, stds[i], 0, mean, scale, 0, 18, num_frames=5)[0]\n", " strips.append(np.hstack(pad_frames(batch_frames, pad_fract_horiz=32)))\n", " #for j, frame in enumerate(batch_frames):\n", " # Image.fromarray(np.uint8(frame*255)).save(out_root / name / f'{seed}_comp{i}_{j}.png')\n", " \n", " for i, strip in enumerate(strips):\n", " Image.fromarray(np.uint8(strip*255)).save(out_root / name / f'{seed}_comp{i}.png', compress_level=1) # converted to jpg for paper\n", "\n", " grid = np.vstack(strips)\n", " \n", " im = Image.fromarray(np.uint8(grid*255))\n", " im.resize((im.width // 2, im.height // 2)).save(out_root / name / f'grid_{seed}.jpg', compress_level=1)\n", " \n", " plt.figure(figsize=(20,40))\n", " plt.title(name)\n", " plt.imshow(grid)\n", " plt.axis('off')\n", " plt.show()\n", "\n", "def get_comp(config, inst):\n", " classname = config.output_class\n", " \n", " # BigGAN components are class agnostic\n", " # => use precomputed husky components\n", " if 'BigGAN' in inst.model.model_name:\n", " config.output_class = 'husky'\n", " \n", " dump = get_or_compute(config, inst)\n", " config.output_class = classname # restore\n", "\n", " return dump\n", "\n", "def gen_principal_components(config, seeds, name, scale=2):\n", " global inst\n", " inst = get_instrumented_model(config, device, inst=inst)\n", " dump_name = get_comp(config, inst)\n", "\n", " model = inst.model\n", " model.truncation = 1.0\n", "\n", " with np.load(dump_name) as data:\n", " lat_comp = torch.from_numpy(data['lat_comp']).to(device)\n", " lat_mean = torch.from_numpy(data['lat_mean']).to(device)\n", " lat_std = data['lat_stdev']\n", "\n", " generate(model, lat_comp, lat_std, lat_mean, seeds, f'{name}_{int(scale)}sigma', scale)\n", "\n", "def gen_normal_w_ortho(config, seeds, name, scale=2):\n", " global inst\n", " inst = get_instrumented_model(config, device, inst=inst)\n", " dump_name = get_comp(config, inst)\n", " \n", " model = inst.model\n", " model.truncation = 1.0\n", "\n", " with np.load(dump_name) as data:\n", " mean = torch.from_numpy(data['lat_mean']).to(device)\n", "\n", " n_comp = model.get_latent_dims() # full rank basis\n", " V = make_ortho_normal(n_comp)\n", " assert_orthonormal(V)\n", "\n", " comp = torch.from_numpy(V).float().unsqueeze(dim=1).to(device)\n", " stdev = torch.ones((n_comp,)).float().to(device)\n", "\n", " generate(model, comp, stdev, mean, seeds, f'{name}_{int(scale)}x', scale)\n", "\n", "# Pseudo-PCA ablation: basis highly shaped by W\n", "def gen_w_ortho_ablation(config, seeds, name, scale=2):\n", " global inst\n", " inst = get_instrumented_model(config, device, inst=inst)\n", " dump_name = get_comp(config, inst)\n", " \n", " model = inst.model\n", " model.truncation = 1.0\n", "\n", " with np.load(dump_name) as data:\n", " mean = torch.from_numpy(data['lat_mean']).to(device)\n", " stdev = torch.from_numpy(data['lat_stdev']).to(device) # use PCA stdevs with random W dirs\n", "\n", " n_comp = model.get_latent_dims() # full rank\n", " V = (model.sample_latent(n_comp, seed=0) - mean).cpu().numpy() # [n_comp, n_dim]\n", " V = V / np.sqrt(np.sum(V*V, axis=-1, keepdims=True)) # normalize rows\n", " V = gram_schmidt_rows(V)\n", " assert_orthonormal(V)\n", "\n", " comp = torch.from_numpy(V).float().unsqueeze(dim=1).to(device)\n", " generate(model, comp, stdev, mean, seeds, f'{name}_{int(scale)}_pc_sigmas', scale)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "seeds = [366745668] # 1502970553, 1235907362, 1302626592]\n", "\n", "# StyleGAN2 ffhq\n", "cfg = Config(components=512, n=1_000_000, batch_size=10_000, use_w=True,\n", " layer='style', model='StyleGAN2', output_class='ffhq')\n", "\n", "#gen_principal_components(cfg, seeds, 'stylegan2_ffhq_pca')\n", "#gen_normal_w_ortho(cfg, seeds, 'stylegan2_ffhq_random', scale=6)\n", "#gen_w_ortho_ablation(cfg, seeds, 'stylegan2_ffhq_ablation')\n", "\n", "# Switch to Z latent space\n", "cfg.use_w = False\n", "gen_normal_w_ortho(cfg, seeds, 'stylegan2_ffhq_random_z', scale=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false, "tags": [] }, "outputs": [], "source": [ "seeds = [697477267] # 901810270, 101052884, 794859404, 1459915324\n", "\n", "# StyleGAN2 car\n", "cfg = Config(components=512, n=1_000_000, batch_size=10_000, use_w=True,\n", " layer='style', model='StyleGAN2', output_class='car')\n", "#gen_principal_components(cfg, seeds, 'stylegan2_car_pca')\n", "gen_normal_w_ortho(cfg, seeds, 'stylegan2_car_random', scale=8)\n", "#gen_w_ortho_ablation(cfg, seeds, 'stylegan2_car_ablation')\n", "\n", "# Switch to Z latent space\n", "cfg.use_w = False\n", "gen_normal_w_ortho(cfg, seeds, 'stylegan2_car_random_z', scale=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "seeds = [1285057649] #1526046390, 1762862368\n", "\n", "# StyleGAN2 cat\n", "cfg = Config(components=512, n=1_000_000, batch_size=10_000, use_w=True,\n", " layer='style', model='StyleGAN2', output_class='cat')\n", "gen_principal_components(cfg, seeds, 'stylegan2_cat_pca')\n", "gen_normal_w_ortho(cfg, seeds, 'stylegan2_cat_random', scale=10)\n", "#gen_w_ortho_ablation(cfg, seeds, 'stylegan2_cat_ablation')\n", "\n", "# Switch to Z latent space\n", "cfg.use_w = False\n", "gen_normal_w_ortho(cfg, seeds, 'stylegan2_cat_random_z', scale=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "seeds = [2129808859] #1903883295\n", "\n", "# BigGAN-512 husky\n", "cfg = Config(components=128, n=1_000_000,\n", " layer='generator.gen_z', model='BigGAN-512', output_class='husky')\n", "gen_principal_components(cfg, seeds, 'biggan512_husky_pca', scale=2)\n", "gen_normal_w_ortho(cfg, seeds, 'biggan512_husky_random', scale=6)\n", "#gen_w_ortho_ablation(cfg, seeds, 'biggan512_husky_ablation', scale=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "seeds = [844616023]\n", "\n", "# BigGAN-512 church\n", "cfg = Config(components=128, n=1_000_000,\n", " layer='generator.gen_z', model='BigGAN-512', output_class='church')\n", "gen_principal_components(cfg, seeds, 'biggan512_church_pca', scale=3)\n", "gen_normal_w_ortho(cfg, seeds, 'biggan512_church_random', scale=8)\n", "#gen_w_ortho_ablation(cfg, seeds, 'biggan512_church_ablation', scale=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7-final" } }, "nbformat": 4, "nbformat_minor": 2 }