{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "# Show the top principal components (PCs) for StyleGAN2 ffhq.\n",
    "# The latent is centered along each component before manipulation.\n",
    "# Also show cleaned-up versions of the top PCs (hand-tuned layer ranges),\n",
    "# plus a couple of cleaned-up later style PCs.\n",
    "%matplotlib inline\n",
    "from notebook_init import *\n",
    "\n",
    "out_root = Path('out/figures/pca_cleanup')\n",
    "makedirs(out_root / 'tuned', exist_ok=True)\n",
    "makedirs(out_root / 'global', exist_ok=True)\n",
    "rand = lambda : np.random.randint(np.iinfo(np.int32).max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_w = True\n",
    "\n",
    "# Instrument StyleGAN2 ffhq at the 'style' layer ('inst' is passed in so an existing instance is reused)\n",
    "inst = get_instrumented_model('StyleGAN2', 'ffhq', 'style', device, inst=inst, use_w=use_w)\n",
    "model = inst.model\n",
    "model.truncation = 1.0\n",
    "\n",
    "# Compute PCA (or load a cached result): 80 components from 1M latent samples\n",
    "pc_config = Config(components=80, n=1_000_000, use_w=use_w,\n",
    "    layer='style', model='StyleGAN2', output_class='ffhq')\n",
    "dump_name = get_or_compute(pc_config, inst)\n",
    "\n",
    "# lat_comp: principal directions, lat_mean: latent mean, lat_stdev: per-component standard deviations\n",
    "with np.load(dump_name) as data:\n",
    "    lat_comp = torch.from_numpy(data['lat_comp']).to(device)\n",
    "    lat_mean = torch.from_numpy(data['lat_mean']).to(device)\n",
    "    lat_std = data['lat_stdev']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "seeds_ffhq = [366745668] #, 1502970553, 1235907362, 1302626592]\n",
    "#seeds_ffhq = [rand() for _ in range(50)]\n",
    "\n",
    "n_pcs = 14\n",
    "\n",
    "# Case 1: Normal centered PCs\n",
    "for seed in seeds_ffhq:\n",
    "    print(seed)\n",
    "\n",
    "    strips = []\n",
    "\n",
    "    for i in range(n_pcs):\n",
    "        z = model.sample_latent(1, seed=seed)\n",
    "        # Strip of 7 frames along component i (sigma 2.0), applied over all style layers (0..18)\n",
    "        batch_frames = create_strip_centered(inst, 'latent', 'style', [z],\n",
    "            0, lat_comp[i], 0, lat_std[i], 0, lat_mean, 2.0, 0, 18, num_frames=7)[0]\n",
    "        strips.append(np.hstack(pad_frames(batch_frames)))\n",
    "        for j, frame in enumerate(batch_frames):\n",
    "            Image.fromarray(np.uint8(frame*255)).save(out_root / 'global' / f'{seed}_pc{i}_{j}.png')\n",
    "\n",
    "    #col_left = np.vstack(pad_frames(strips[:n_pcs//2], 0, 64))\n",
    "    #col_right = np.vstack(pad_frames(strips[n_pcs//2:], 0, 64))\n",
    "    grid = np.vstack(strips)\n",
    "\n",
    "    Image.fromarray(np.uint8(grid*255)).save(out_root / f'grid_{seed}.jpg')\n",
    "\n",
    "    plt.figure(figsize=(20,40))\n",
    "    plt.imshow(grid)\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Case 2: hand-tuned layer ranges for some directions\n",
    "# Each entry: (component index, (start layer, end layer), sigma)\n",
    "hand_tuned = [\n",
    "    ( 0, (1, 7), 2.0),   # gender, keep age\n",
    "    ( 1, (0, 3), 2.0),   # rotate, keep gender\n",
    "    ( 2, (3, 8), 2.0),   # gender, keep geometry\n",
    "    ( 3, (2, 8), 2.0),   # age, keep lighting, no hat\n",
    "    ( 4, (5, 18), 2.0),  # background, keep geometry\n",
    "    ( 5, (0, 4), 2.0),   # hat, keep lighting and age\n",
    "    ( 6, (7, 18), 2.0),  # just lighting\n",
    "    ( 7, (5, 9), 2.0),   # just lighting\n",
    "    ( 8, (1, 7), 2.0),   # age, keep lighting\n",
    "    ( 9, (0, 5), 2.0),   # keep lighting\n",
    "    (10, (7, 9), 2.0),   # hair color, keep geom\n",
    "    (11, (0, 5), 2.0),   # hair length, keep color\n",
    "    (12, (8, 9), 2.0),   # light dir lr\n",
    "#   (12, (4, 10), 2.0), # light position LR\n",
    "    (13, (0, 6), 2.0),   # about the same\n",
    "]\n",
    "\n",
    "for seed in seeds_ffhq:\n",
    "    print(seed)\n",
    "\n",
    "    strips = []\n",
    "\n",
    "    for i, (s, e), sigma in hand_tuned:\n",
    "        z = model.sample_latent(1, seed=seed)\n",
    "\n",
    "        batch_frames = create_strip_centered(inst, 'latent', 'style', [z],\n",
    "            0, lat_comp[i], 0, lat_std[i], 0, lat_mean, sigma, s, e, num_frames=7)[0]\n",
    "        strips.append(np.hstack(pad_frames(batch_frames)))\n",
    "        for j, frame in enumerate(batch_frames):\n",
    "            Image.fromarray(np.uint8(frame*255)).save(out_root / 'tuned' / f'{seed}_pc{i}_s{s}_e{e}_{j}.png')\n",
    "\n",
    "    #col_left = np.vstack(pad_frames(strips[:len(strips)//2], 0, 64))\n",
    "    #col_right = np.vstack(pad_frames(strips[len(strips)//2:], 0, 64))\n",
    "    #grid = np.hstack(pad_frames(strips, 16))\n",
    "    grid = np.vstack(strips)\n",
    "\n",
    "    Image.fromarray(np.uint8(grid*255)).save(out_root / f'grid_{seed}_tuned.jpg')\n",
    "\n",
    "    plt.figure(figsize=(20,40))\n",
    "    plt.imshow(grid)\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}