In [None]:
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

%matplotlib inline
from notebook_init import *
import scipy

# Root directory that receives every figure written by this notebook.
out_root = Path('out/figures/first_20_pcs')
makedirs(out_root, exist_ok=True)


def rand():
    """Return a random integer in [0, int32 max) drawn from NumPy's global RNG."""
    upper_bound = np.iinfo(np.int32).max
    return np.random.randint(upper_bound)

In [None]:
def gram_schmidt_cols(V):
    """Orthonormalize the columns of ``V`` using a QR decomposition.

    Note that the *transpose* of the Q factor is returned, so the
    orthonormal vectors end up on the rows of the result.
    """
    q_factor, _ = np.linalg.qr(V)
    return q_factor.T

def gram_schmidt_rows(V):
    """Orthonormalize the rows of ``V`` using a QR decomposition.

    The QR factorization is taken of ``V.T`` and the Q factor is transposed
    back, so the orthonormal vectors are on the rows of the result.
    """
    q_factor, _ = np.linalg.qr(V.T)
    return q_factor.T

def make_ortho_normal(N=512, random_state=None):
    """Draw a random N x N matrix from SciPy's special orthogonal group.

    Args:
        N: Dimension of the square matrix.
        random_state: Optional seed / generator forwarded to SciPy so the
            drawn basis can be reproduced. ``None`` (the default) keeps the
            previous behaviour of using SciPy's global random state.

    Returns:
        The matrix returned by ``scipy.stats.special_ortho_group.rvs``.
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.stats` has been loaded as an attribute.
    from scipy.stats import special_ortho_group

    return special_ortho_group.rvs(N, random_state=random_state)

# Components on rows
def assert_normalized(V):
    """Assert that every row of ``V`` has unit Euclidean norm.

    Args:
        V: 2-D array with one component per row.

    Raises:
        AssertionError: If any row's squared norm is not close to 1.
    """
    # The diagonal of V @ V.T holds the squared norm of each row.
    # (The original referenced an undefined name `V1` instead of `V`.)
    squared_norms = np.diag(np.dot(V, V.T))
    assert np.allclose(squared_norms, np.ones(V.shape[0])), 'Basis not normalized'

# V = [n_comp, n_dim]
def assert_orthonormal(V):
    """Assert that the rows of ``V`` form an orthonormal set.

    The Gram matrix ``V V^T`` is compared against the identity with an
    absolute tolerance of 1e-5; its determinant is included in the failure
    message as a diagnostic.
    """
    gram = V @ V.T  # [n_comp, n_comp]
    det = np.linalg.det(gram)
    identity = np.eye(gram.shape[0])
    is_orthonormal = np.allclose(gram, identity, atol=1e-5)
    assert is_orthonormal, f'Basis is not orthonormal (det={det})'

In [None]:
# Number of leading basis directions rendered per seed (loop bound in generate()).
n_pcs = 20
# Module-level instrumented-model handle; the gen_* functions rebind it via
# `global inst` and pass the previous value back to get_instrumented_model().
inst = None

def generate(model, basis, stds, mean, seeds, name, scale=2.0):
    """Render and save one image strip per basis direction for every seed.

    For each seed, the first ``n_pcs`` entries of ``basis`` / ``stds`` are
    passed (together with ``mean`` and ``scale``) to
    ``create_strip_centered``; each resulting frame batch is padded and
    stacked horizontally into a strip. Strips are saved individually as PNG,
    stacked vertically into a grid that is saved at half resolution as JPG,
    and the grid is displayed with matplotlib.

    Relies on the module-level names ``out_root``, ``n_pcs`` and ``inst``.
    Output goes to ``out_root / name``.
    """
    target_dir = out_root / name
    makedirs(target_dir, exist_ok=True)

    for seed in seeds:
        print(seed)

        strips = []
        for comp_idx in range(n_pcs):
            # The latent is re-sampled with the same seed for every component.
            latent = model.sample_latent(1, seed=seed)
            frames = create_strip_centered(
                inst, 'latent', 'style', [latent],
                0, basis[comp_idx], 0, stds[comp_idx], 0, mean, scale, 0, 18,
                num_frames=5,
            )[0]
            padded = pad_frames(frames, pad_fract_horiz=32)
            strips.append(np.hstack(padded))
            #for j, frame in enumerate(frames):
            #    Image.fromarray(np.uint8(frame*255)).save(out_root / name / f'{seed}_comp{comp_idx}_{j}.png')

        # One PNG per component strip (converted to jpg for paper).
        for comp_idx, strip in enumerate(strips):
            strip_img = Image.fromarray(np.uint8(strip * 255))
            strip_img.save(target_dir / f'{seed}_comp{comp_idx}.png', compress_level=1)

        grid = np.vstack(strips)

        # Overview grid, downscaled by a factor of two in each dimension.
        # NOTE(review): compress_level is a PNG option in Pillow - confirm it
        # has any effect when saving a .jpg.
        grid_img = Image.fromarray(np.uint8(grid * 255))
        half_size = (grid_img.width // 2, grid_img.height // 2)
        grid_img.resize(half_size).save(target_dir / f'grid_{seed}.jpg', compress_level=1)

        plt.figure(figsize=(20, 40))
        plt.title(name)
        plt.imshow(grid)
        plt.axis('off')
        plt.show()

def get_comp(config, inst):
    """Return the result of ``get_or_compute`` for ``config`` and ``inst``.

    For models whose name contains 'BigGAN', ``config.output_class`` is
    temporarily set to 'husky' for the lookup (BigGAN components are class
    agnostic, so the precomputed husky components are used). The caller's
    ``output_class`` is always restored, even if ``get_or_compute`` raises.

    Args:
        config: Object with a writable ``output_class`` attribute.
        inst: Object exposing ``inst.model.model_name``.

    Returns:
        Whatever ``get_or_compute(config, inst)`` returns.
    """
    original_class = config.output_class

    # BigGAN components are class agnostic
    # => use precomputed husky components
    if 'BigGAN' in inst.model.model_name:
        config.output_class = 'husky'

    try:
        return get_or_compute(config, inst)
    finally:
        # Restore on success and on failure, so the caller's config is
        # never left pointing at the wrong class.
        config.output_class = original_class

def gen_principal_components(config, seeds, name, scale=2):
    """Render strips along the components stored in the dump for ``config``.

    Reads 'lat_comp', 'lat_mean' and 'lat_stdev' from the file returned by
    ``get_comp`` and forwards them to ``generate``. Results are written under
    the run name ``f'{name}_{int(scale)}sigma'``.

    Rebinds the module-level ``inst`` and sets the model's truncation to 1.0.
    """
    global inst
    inst = get_instrumented_model(config, device, inst=inst)
    dump_path = get_comp(config, inst)

    gan = inst.model
    gan.truncation = 1.0

    with np.load(dump_path) as archive:
        components = torch.from_numpy(archive['lat_comp']).to(device)
        center = torch.from_numpy(archive['lat_mean']).to(device)
        sigmas = archive['lat_stdev']  # stays a NumPy array here

    run_name = f'{name}_{int(scale)}sigma'
    generate(gan, components, sigmas, center, seeds, run_name, scale)

def gen_normal_w_ortho(config, seeds, name, scale=2):
    """Render strips along a random orthonormal basis with unit stdevs.

    The basis is a full-rank matrix from ``make_ortho_normal`` sized by
    ``model.get_latent_dims()``; only 'lat_mean' is read from the dump
    returned by ``get_comp``. Results are written under the run name
    ``f'{name}_{int(scale)}x'``.

    Rebinds the module-level ``inst`` and sets the model's truncation to 1.0.
    """
    global inst
    inst = get_instrumented_model(config, device, inst=inst)
    dump_path = get_comp(config, inst)

    gan = inst.model
    gan.truncation = 1.0

    with np.load(dump_path) as archive:
        center = torch.from_numpy(archive['lat_mean']).to(device)

    n_dims = gan.get_latent_dims()  # full rank basis
    rotation = make_ortho_normal(n_dims)
    assert_orthonormal(rotation)

    # Insert a singleton axis at position 1 so each direction is [1, n_dims].
    directions = torch.from_numpy(rotation).float().unsqueeze(dim=1).to(device)
    unit_sigmas = torch.ones((n_dims,)).float().to(device)

    run_name = f'{name}_{int(scale)}x'
    generate(gan, directions, unit_sigmas, center, seeds, run_name, scale)

# Pseudo-PCA ablation: basis highly shaped by W
def gen_w_ortho_ablation(config, seeds, name, scale=2):
    """Render strips along an orthonormalized basis built from sampled latents.

    ``model.get_latent_dims()`` latents are sampled with seed 0, centred on
    the stored 'lat_mean', row-normalized and orthonormalized with
    ``gram_schmidt_rows``. The stored 'lat_stdev' values are used as the
    per-direction stdevs. Results are written under the run name
    ``f'{name}_{int(scale)}_pc_sigmas'``.

    Rebinds the module-level ``inst`` and sets the model's truncation to 1.0.
    """
    global inst
    inst = get_instrumented_model(config, device, inst=inst)
    dump_path = get_comp(config, inst)

    gan = inst.model
    gan.truncation = 1.0

    with np.load(dump_path) as archive:
        center = torch.from_numpy(archive['lat_mean']).to(device)
        # use PCA stdevs with random W dirs
        sigmas = torch.from_numpy(archive['lat_stdev']).to(device)

    n_dims = gan.get_latent_dims()  # full rank
    samples = gan.sample_latent(n_dims, seed=0)
    offsets = (samples - center).cpu().numpy()  # [n_comp, n_dim]
    row_norms = np.sqrt(np.sum(offsets*offsets, axis=-1, keepdims=True))
    unit_rows = offsets / row_norms  # normalize rows
    ortho_rows = gram_schmidt_rows(unit_rows)
    assert_orthonormal(ortho_rows)

    directions = torch.from_numpy(ortho_rows).float().unsqueeze(dim=1).to(device)
    run_name = f'{name}_{int(scale)}_pc_sigmas'
    generate(gan, directions, sigmas, center, seeds, run_name, scale)

In [None]:
# Seed(s) rendered in this cell; alternative seeds are kept in the trailing comment.
seeds = [366745668] # 1502970553, 1235907362, 1302626592]

# StyleGAN2 ffhq
cfg = Config(components=512, n=1_000_000, batch_size=10_000, use_w=True,
 layer='style', model='StyleGAN2', output_class='ffhq')

# The use_w=True runs are currently disabled for this model.
#gen_principal_components(cfg, seeds, 'stylegan2_ffhq_pca')
#gen_normal_w_ortho(cfg, seeds, 'stylegan2_ffhq_random', scale=6)
#gen_w_ortho_ablation(cfg, seeds, 'stylegan2_ffhq_ablation')

# Switch to Z latent space
# (cfg is mutated in place, so this must stay after any use_w=True calls above)
cfg.use_w = False
gen_normal_w_ortho(cfg, seeds, 'stylegan2_ffhq_random_z', scale=10)

In [None]:
# Seed(s) rendered in this cell; alternative seeds are kept in the trailing comment.
seeds = [697477267] # 901810270, 101052884, 794859404, 1459915324

# StyleGAN2 car
cfg = Config(components=512, n=1_000_000, batch_size=10_000, use_w=True,
 layer='style', model='StyleGAN2', output_class='car')
# Only the random-basis run is enabled while use_w=True.
#gen_principal_components(cfg, seeds, 'stylegan2_car_pca')
gen_normal_w_ortho(cfg, seeds, 'stylegan2_car_random', scale=8)
#gen_w_ortho_ablation(cfg, seeds, 'stylegan2_car_ablation')

# Switch to Z latent space
# (cfg is mutated in place, so this must stay after the use_w=True calls above)
cfg.use_w = False
gen_normal_w_ortho(cfg, seeds, 'stylegan2_car_random_z', scale=10)

In [None]:
# Seed(s) rendered in this cell; alternative seeds are kept in the trailing comment.
seeds = [1285057649] #1526046390, 1762862368

# StyleGAN2 cat
cfg = Config(components=512, n=1_000_000, batch_size=10_000, use_w=True,
 layer='style', model='StyleGAN2', output_class='cat')
# Stored-component run (default scale) and random-basis run with use_w=True; the ablation is disabled.
gen_principal_components(cfg, seeds, 'stylegan2_cat_pca')
gen_normal_w_ortho(cfg, seeds, 'stylegan2_cat_random', scale=10)
#gen_w_ortho_ablation(cfg, seeds, 'stylegan2_cat_ablation')

# Switch to Z latent space
# (cfg is mutated in place, so this must stay after the use_w=True calls above)
cfg.use_w = False
gen_normal_w_ortho(cfg, seeds, 'stylegan2_cat_random_z', scale=10)

In [None]:
# Seed(s) rendered in this cell; an alternative seed is kept in the trailing comment.
seeds = [2129808859] #1903883295

# BigGAN-512 husky
cfg = Config(components=128, n=1_000_000,
 layer='generator.gen_z', model='BigGAN-512', output_class='husky')
# Stored-component run at scale=2 and random-basis run at scale=6; the ablation is disabled.
gen_principal_components(cfg, seeds, 'biggan512_husky_pca', scale=2)
gen_normal_w_ortho(cfg, seeds, 'biggan512_husky_random', scale=6)
#gen_w_ortho_ablation(cfg, seeds, 'biggan512_husky_ablation', scale=2)

In [None]:
# Seed(s) rendered in this cell.
seeds = [844616023]

# BigGAN-512 church
# (get_comp() looks up components under output_class='husky' for BigGAN models)
cfg = Config(components=128, n=1_000_000,
 layer='generator.gen_z', model='BigGAN-512', output_class='church')
# Stored-component run at scale=3 and random-basis run at scale=8; the ablation is disabled.
gen_principal_components(cfg, seeds, 'biggan512_church_pca', scale=3)
gen_normal_w_ortho(cfg, seeds, 'biggan512_church_random', scale=8)
#gen_w_ortho_ablation(cfg, seeds, 'biggan512_church_ablation', scale=3)