In [None]:
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

%matplotlib inline
from notebook_init import *
import scipy

# Root directory for the PNGs saved by generate(); one subfolder per model/class pair.
outdir = Path('out') / 'figures' / 'random_baseline'
makedirs(outdir, exist_ok=True)

# Project tensor 'X' onto orthonormal basis 'comp', return coordinates
def project_ortho(X, comp):
 N = comp.shape[0]
 coords = (comp.reshape(N, -1) * X.reshape(-1)).sum(dim=1)
 return coords.reshape([N]+[1]*X.ndim)

def show_img(img_np, W=6, H=6):
 #plt.figure(figsize=(W,H))
 plt.axis('off')
 plt.tight_layout()
 plt.imshow(img_np, interpolation='bilinear')
 
inst = None # instrumented model cache: generate() passes it to get_instrumented_model(..., inst=inst) and stores the result, so it is reused when possible

In [None]:
from torchvision.utils import make_grid

def generate(model_name, class_name, seed=None, trunc=0.6, N=5, use_random_basis=True):
    """Compare latent resampling in the PCA basis against a baseline basis.

    A base latent is drawn, and B images are produced by replacing some of its
    basis coordinates with those of freshly sampled latents. Four 3x2 grids
    are drawn in a 2x2 matplotlib figure, and every image is saved as a PNG:
      1. PCA basis: keep the first N coordinates, resample the rest.
      2. PCA basis: resample the first N coordinates, keep the rest.
      3./4. The same two experiments in a baseline basis: a random rotation
         (`use_random_basis=True`) or a seeded permutation of the PCA basis.

    Args:
        model_name: 'StyleGAN2' uses layer 'style', 'StyleGAN' uses 'g_mapping';
            any other name uses 'generator.gen_z'. If 'BigGAN' is in the name,
            `class_name` is also set on the model via set_output_class.
        class_name: output class for the config / model. 'car' and 'cat'
            outputs are cropped vertically before display and saving.
        seed: base seed. Seeds seed..seed+B-1 give the resampled latents and
            seed+B gives the base latent. A random seed is drawn if None.
        trunc: value assigned to `model.truncation`.
        N: number of leading basis components that are kept / resampled.
        use_random_basis: random rotation (True) or shuffled PCA (False) as
            the baseline basis.

    Side effects: updates the global `inst`, writes PNGs under
    `outdir / f'{model_name}_{class_name}'`, prints the seeds, and shows figures.
    """
    # Local import: a bare `import scipy` does not load `scipy.stats` on older SciPy versions.
    from scipy.stats import special_ortho_group

    global inst

    config = Config(n=1_000_000, batch_size=500, model=model_name,
                    output_class=class_name, use_w=('StyleGAN' in model_name))

    if model_name == 'StyleGAN2':
        config.layer = 'style'
    elif model_name == 'StyleGAN':
        config.layer = 'g_mapping'
    else:
        config.layer = 'generator.gen_z'
        # NOTE(review): the decomposition is always requested for class 'husky'
        # (presumably to reuse one cached PCA across BigGAN classes); the
        # requested class is applied to the model further below. Confirm intent.
        config.output_class = 'husky'

    inst = get_instrumented_model(config, torch.device('cuda'), inst=inst)
    model = inst.model

    K = model.get_latent_dims()
    config.components = K

    dump_name = get_or_compute(config, inst)
    with np.load(dump_name) as data:
        lat_comp = torch.from_numpy(data['lat_comp']).cuda()
        lat_mean = torch.from_numpy(data['lat_mean']).cuda()

    B = 6  # images per grid (make_grid with nrow=3 -> 3 columns x 2 rows)
    if seed is None:
        # Leave headroom so that seed + B still fits in int32.
        seed = np.random.randint(np.iinfo(np.int32).max - B)
    model.truncation = trunc

    if 'BigGAN' in model_name:
        model.set_output_class(class_name)

    print(f'Seeds: {seed} - {seed+B}')

    # Resampling test: show the base latent whose coordinates are partially kept.
    w_base = model.sample_latent(1, seed=seed + B)
    plt.imshow(model.sample_np(w_base))
    plt.axis('off')
    plt.show()

    # Vertical (row) crop for non-square classes, shared by display and saving.
    row_crop = {'car': slice(64, -64), 'cat': slice(18, -8)}.get(class_name, slice(None))

    def get_batch(indices, basis):
        """Return B latents: w_base with the coordinates at `indices` taken from fresh samples."""
        w_batch = torch.zeros(B, K).cuda()
        coord_base = project_ortho(w_base - lat_mean, basis)

        for i in range(B):
            w = model.sample_latent(1, seed=seed + i)
            coords = coord_base.clone()
            coords_resampled = project_ortho(w - lat_mean, basis)
            coords[indices, :, :] = coords_resampled[indices, :, :]
            w_batch[i, :] = lat_mean + torch.sum(coords * basis, dim=0)

        return w_batch

    def show_grid(w, title):
        """Render latents `w` as a cropped image grid on the current subplot."""
        out = model.forward(w)[:, :, row_crop, :]
        grid = make_grid(out, nrow=3)
        grid_np = grid.clamp(0, 1).permute(1, 2, 0).cpu().numpy()
        show_img(grid_np)
        plt.title(title)

    def save_imgs(w, prefix):
        """Save each cropped image of latents `w` as `{prefix}_{i}.png`."""
        save_dir = outdir / f'{model_name}_{class_name}'
        makedirs(save_dir, exist_ok=True)  # hoisted: once per call instead of once per image
        for i, img in enumerate(model.sample_np(w)):
            img = img[row_crop, :, :]
            Image.fromarray(np.uint8(img * 255)).save(save_dir / f'{prefix}_{i}.png')

    # V = [n_comp, n_dim]
    def assert_orthonormal(V):
        """Assert that the rows of V are orthonormal (V @ V.T == I within 1e-5)."""
        M = np.dot(V, V.T)  # [n_comp, n_comp]
        det = np.linalg.det(M)
        assert np.allclose(M, np.identity(M.shape[0]), atol=1e-5), f'Basis is not orthonormal (det={det})'

    def panel(pos, indices, basis, title, prefix):
        """Resample `indices` in `basis`, draw the grid at subplot `pos`, and save the images."""
        batch = get_batch(indices, basis)
        plt.subplot(2, 2, pos)
        show_grid(batch, title)
        save_imgs(batch, prefix)

    plt.figure(figsize=((12, 6.5) if class_name in ['car', 'cat'] else (12, 8)))

    keep_first = np.arange(N, K)   # first N fixed: components N..K-1 are rerandomized
    rand_first = np.arange(0, N)   # first N rerandomized: components 0..N-1

    panel(1, keep_first, lat_comp,
          f'Keep {N} first pca -> Consistent pose', f'keep_{N}_first_{seed}')
    panel(2, rand_first, lat_comp,
          f'Randomize {N} first pca -> Consistent style', f'randomize_{N}_first_{seed}')

    if use_random_basis:
        # Isotropic random rotation of the K-dim latent space, seeded so the
        # baseline is reproducible for a given `seed`. (An alternative baseline
        # would be Gram-Schmidt on samples of p(w) - mean, a noisy pseudo-PCA.)
        V = special_ortho_group.rvs(K, random_state=seed)
        assert_orthonormal(V)

        # Follow lat_comp's device/layout; the original referenced an undefined `device`.
        rand_basis = torch.from_numpy(V).float().view(lat_comp.shape).to(lat_comp.device)
        assert rand_basis.shape == lat_comp.shape, f'Shape mismatch: {rand_basis.shape} != {lat_comp.shape}'
    else:
        # Just use the PCA basis with its components shuffled.
        rng = np.random.RandomState(seed=seed)
        perm = rng.permutation(range(K))
        rand_basis = lat_comp[perm, :]

    basis_type_str = 'random' if use_random_basis else 'pca_shfl'

    panel(3, keep_first, rand_basis,
          f'Keep {N} first {basis_type_str} -> Little consistency',
          f'keep_{N}_first_{basis_type_str}_{seed}')
    panel(4, rand_first, rand_basis,
          f'Randomize {N} first {basis_type_str} -> Little variation',
          f'randomize_{N}_first_{basis_type_str}_{seed}')

    plt.show()


# In paper
generate('StyleGAN2', 'cat', seed=1866827965, trunc=0.55, N=8)
 
# In supplemental
generate('StyleGAN', 'bedrooms', seed=1382244162, trunc=1.0, N=10)
generate('StyleGAN', 'ffhq', seed=598174413, trunc=1.0, N=10)
generate('BigGAN-256', 'duck', seed=1134462557, trunc=1.0, N=10)
generate('StyleGAN2', 'car', seed=1257084100, trunc=0.7, N=5)