{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "%matplotlib inline\n",
    "from notebook_init import *\n",
    "import scipy\n",
    "import scipy.stats  # a bare 'import scipy' does not reliably expose the stats submodule\n",
    "\n",
    "outdir = Path('out/figures/random_baseline')\n",
    "makedirs(outdir, exist_ok=True)\n",
    "\n",
    "def project_ortho(X, comp):\n",
    "    \"\"\"Project tensor 'X' onto orthonormal basis 'comp' and return the coordinates.\n",
    "\n",
    "    'comp' is flattened to [N, -1] with N = comp.shape[0] and 'X' is flattened to 1-D,\n",
    "    so 'X' must hold as many elements as a single basis vector.\n",
    "\n",
    "    Returns a tensor of shape [N, 1, ..., 1] with X.ndim trailing singleton\n",
    "    dimensions, one coordinate per basis vector.\n",
    "    \"\"\"\n",
    "    N = comp.shape[0]\n",
    "    coords = (comp.reshape(N, -1) * X.reshape(-1)).sum(dim=1)\n",
    "    return coords.reshape([N] + [1] * X.ndim)\n",
    "\n",
    "def show_img(img_np, W=6, H=6):\n",
    "    \"\"\"Draw 'img_np' on the current matplotlib axes with the axis decorations hidden.\n",
    "\n",
    "    'W' and 'H' are currently unused (the figure-size call below is disabled);\n",
    "    they are kept so existing calls keep working.\n",
    "    \"\"\"\n",
    "    #plt.figure(figsize=(W,H))\n",
    "    plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.imshow(img_np, interpolation='bilinear')\n",
    "\n",
    "inst = None # reused when possible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false,
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchvision.utils import make_grid\n",
    "\n",
    "def generate(model_name, class_name, seed=None, trunc=0.6, N=5, use_random_basis=True):\n",
    "    \"\"\"Render and save the random-baseline comparison figure for one model/class.\n",
    "\n",
    "    Resamples either the first N or the remaining K-N coordinates of a base latent,\n",
    "    once in the PCA basis and once in a comparison basis, shows the four image grids\n",
    "    in a 2x2 figure and writes the individual images under\n",
    "    outdir / f'{model_name}_{class_name}'.\n",
    "\n",
    "    Args:\n",
    "        model_name: Model identifier passed to Config; 'StyleGAN2' and 'StyleGAN'\n",
    "            select dedicated layers, anything else uses 'generator.gen_z'.\n",
    "        class_name: Output class; also selects the crop applied for 'car' and 'cat'.\n",
    "        seed: Base seed. Seeds seed .. seed+B-1 give the resampled latents and\n",
    "            seed+B the base latent (B = 6). Drawn at random when None.\n",
    "        trunc: Value assigned to model.truncation.\n",
    "        N: Number of leading basis directions that are kept or rerandomized.\n",
    "        use_random_basis: If True compare against a random orthonormal basis,\n",
    "            otherwise against a seeded shuffle of the PCA basis.\n",
    "    \"\"\"\n",
    "    global inst\n",
    "\n",
    "    config = Config(n=1_000_000, batch_size=500, model=model_name,\n",
    "        output_class=class_name, use_w=('StyleGAN' in model_name))\n",
    "\n",
    "    if model_name == 'StyleGAN2':\n",
    "        config.layer = 'style'\n",
    "    elif model_name == 'StyleGAN':\n",
    "        config.layer = 'g_mapping'\n",
    "    else:\n",
    "        config.layer = 'generator.gen_z'\n",
    "        config.n = 1_000_000\n",
    "        # NOTE(review): the decomposition for these models is always requested for\n",
    "        # class 'husky'; 'class_name' is only applied to the model further below.\n",
    "        # Looks intentional (shared components across classes) - confirm.\n",
    "        config.output_class = 'husky'\n",
    "\n",
    "    inst = get_instrumented_model(config, torch.device('cuda'), inst=inst)\n",
    "    model = inst.model\n",
    "\n",
    "    K = model.get_latent_dims()\n",
    "    config.components = K\n",
    "\n",
    "    dump_name = get_or_compute(config, inst)\n",
    "\n",
    "    with np.load(dump_name) as data:\n",
    "        lat_comp = torch.from_numpy(data['lat_comp']).cuda()\n",
    "        lat_mean = torch.from_numpy(data['lat_mean']).cuda()\n",
    "\n",
    "    B = 6\n",
    "    if seed is None:\n",
    "        seed = np.random.randint(np.iinfo(np.int32).max - B)\n",
    "    model.truncation = trunc\n",
    "\n",
    "    if 'BigGAN' in model_name:\n",
    "        model.set_output_class(class_name)\n",
    "\n",
    "    print(f'Seeds: {seed} - {seed+B}')\n",
    "\n",
    "    # Resampling test\n",
    "    w_base = model.sample_latent(1, seed=seed + B)\n",
    "    plt.imshow(model.sample_np(w_base))\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "    # Resample some components: build B latents that share the base latent's\n",
    "    # coordinates except at 'indices', which are taken from freshly sampled latents\n",
    "    def get_batch(indices, basis):\n",
    "        w_batch = torch.zeros(B, K, device=lat_mean.device)\n",
    "        coord_base = project_ortho(w_base - lat_mean, basis)\n",
    "\n",
    "        for i in range(B):\n",
    "            w = model.sample_latent(1, seed=seed + i)\n",
    "            coords = coord_base.clone()\n",
    "            coords_resampled = project_ortho(w - lat_mean, basis)\n",
    "            coords[indices, :, :] = coords_resampled[indices, :, :]\n",
    "            w_batch[i, :] = lat_mean + torch.sum(coords * basis, dim=0)\n",
    "\n",
    "        return w_batch\n",
    "\n",
    "    # Run the model on latents 'w' and draw the cropped outputs as a 3-column grid\n",
    "    def show_grid(w, title):\n",
    "        # No gradients are needed for visualization; also keeps .numpy() below valid\n",
    "        with torch.no_grad():\n",
    "            out = model.forward(w)\n",
    "        if class_name == 'car':\n",
    "            out = out[:, :, 64:-64, :]\n",
    "        elif class_name == 'cat':\n",
    "            out = out[:, :, 18:-8, :]\n",
    "        grid = make_grid(out, nrow=3)\n",
    "        grid_np = grid.clamp(0, 1).permute(1, 2, 0).cpu().numpy()\n",
    "        show_img(grid_np)\n",
    "        plt.title(title)\n",
    "\n",
    "    # Write every image generated from 'w' to '<outdir>/<model>_<class>/<prefix>_<i>.png'\n",
    "    def save_imgs(w, prefix):\n",
    "        for i, img in enumerate(model.sample_np(w)):\n",
    "            if class_name == 'car':\n",
    "                img = img[64:-64, :, :]\n",
    "            elif class_name == 'cat':\n",
    "                img = img[18:-8, :, :]\n",
    "            outpath = outdir / f'{model_name}_{class_name}' / f'{prefix}_{i}.png'\n",
    "            makedirs(outpath.parent, exist_ok=True)\n",
    "            Image.fromarray(np.uint8(img * 255)).save(outpath)\n",
    "            #print('Saving', outpath)\n",
    "\n",
    "    # Only needed by the disabled 'noisy pseudo-PCA' alternative further below\n",
    "    def orthogonalize_rows(V):\n",
    "        Q, _ = np.linalg.qr(V.T)\n",
    "        return Q.T\n",
    "\n",
    "    # V = [n_comp, n_dim]\n",
    "    def assert_orthonormal(V):\n",
    "        M = np.dot(V, V.T) # [n_comp, n_comp]\n",
    "        # The determinant is only evaluated if the check fails\n",
    "        assert np.allclose(M, np.identity(M.shape[0]), atol=1e-5), f'Basis is not orthonormal (det={np.linalg.det(M)})'\n",
    "\n",
    "    plt.figure(figsize=((12,6.5) if class_name in ['car', 'cat'] else (12,8)))\n",
    "\n",
    "    # First N fixed\n",
    "    ind_rand = np.arange(N, K) # N -> K rerandomized\n",
    "    b1 = get_batch(ind_rand, lat_comp)\n",
    "    plt.subplot(2, 2, 1)\n",
    "    show_grid(b1, f'Keep {N} first pca -> Consistent pose')\n",
    "    save_imgs(b1, f'keep_{N}_first_{seed}')\n",
    "\n",
    "    # First N randomized\n",
    "    ind_rand = np.arange(0, N) # 0 -> N rerandomized\n",
    "    b2 = get_batch(ind_rand, lat_comp)\n",
    "    plt.subplot(2, 2, 2)\n",
    "    show_grid(b2, f'Randomize {N} first pca -> Consistent style')\n",
    "    save_imgs(b2, f'randomize_{N}_first_{seed}')\n",
    "\n",
    "    if use_random_basis:\n",
    "        # Random orthonormal basis drawn from p(w)\n",
    "        # Highly shaped by W, sort of a noisy pseudo-PCA\n",
    "        #V = (model.sample_latent(K, seed=seed + B + 1) - lat_mean).cpu().numpy()\n",
    "        #V = V / np.sqrt(np.sum(V*V, axis=-1, keepdims=True)) # normalize rows\n",
    "        #V = orthogonalize_rows(V)\n",
    "\n",
    "        # Isotropic random basis, seeded so that a given 'seed' reproduces the figure\n",
    "        V = scipy.stats.special_ortho_group.rvs(K, random_state=seed + B + 1)\n",
    "        assert_orthonormal(V)\n",
    "\n",
    "        # Place the basis on the same device as the PCA basis it stands in for\n",
    "        rand_basis = torch.from_numpy(V).float().view(lat_comp.shape).to(lat_comp.device)\n",
    "        assert rand_basis.shape == lat_comp.shape, f'Shape mismatch: {rand_basis.shape} != {lat_comp.shape}'\n",
    "    else:\n",
    "        # Just use shuffled PCA basis\n",
    "        rng = np.random.RandomState(seed=seed)\n",
    "        perm = rng.permutation(range(K))\n",
    "        rand_basis = lat_comp[perm, :]\n",
    "\n",
    "    basis_type_str = 'random' if use_random_basis else 'pca_shfl'\n",
    "\n",
    "    # First N random fixed\n",
    "    ind_rand = np.arange(N, K) # N -> K rerandomized\n",
    "    b3 = get_batch(ind_rand, rand_basis)\n",
    "    plt.subplot(2, 2, 3)\n",
    "    show_grid(b3, f'Keep {N} first {basis_type_str} -> Little consistency')\n",
    "    save_imgs(b3, f'keep_{N}_first_{basis_type_str}_{seed}')\n",
    "\n",
    "    # First N random rerandomized\n",
    "    ind_rand = np.arange(0, N) # 0 -> N rerandomized\n",
    "    b4 = get_batch(ind_rand, rand_basis)\n",
    "    plt.subplot(2, 2, 4)\n",
    "    show_grid(b4, f'Randomize {N} first {basis_type_str} -> Little variation')\n",
    "    save_imgs(b4, f'randomize_{N}_first_{basis_type_str}_{seed}')\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# In paper\n",
    "generate('StyleGAN2', 'cat', seed=1866827965, trunc=0.55, N=8)\n",
    "\n",
    "# In supplemental\n",
    "generate('StyleGAN', 'bedrooms', seed=1382244162, trunc=1.0, N=10)\n",
    "generate('StyleGAN', 'ffhq', seed=598174413, trunc=1.0, N=10)\n",
    "generate('BigGAN-256', 'duck', seed=1134462557, trunc=1.0, N=10)\n",
    "generate('StyleGAN2', 'car', seed=1257084100, trunc=0.7, N=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}