In [None]:
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

# Show the top PCs (n_pcs = 14) for StyleGAN2 FFHQ.
# Each latent is centered along the component before the manipulation is applied.
# Also show cleaned-up versions of these PCs, each applied only to a hand-tuned range of style layers.
%matplotlib inline
from notebook_init import *

out_root = Path('out/figures/pca_cleanup')
# Per-frame outputs go into two subfolders: 'tuned' (hand-tuned layer ranges) and 'global' (all layers).
for subdir in ('tuned', 'global'):
    makedirs(out_root / subdir, exist_ok=True)


def rand():
    """Return a random seed in [0, int32 max), drawn from NumPy's global RNG."""
    return np.random.randint(np.iinfo(np.int32).max)

In [None]:
# Load the StyleGAN2 FFHQ generator, instrumented at its 'style' layer, using W-space latents.
use_w = True
# NOTE(review): passes the current `inst` back in (presumably None on a fresh kernel, set by
# notebook_init) so that re-running this cell can reuse an already-loaded model — confirm.
inst = get_instrumented_model('StyleGAN2', 'ffhq', 'style', device, inst=inst, use_w=use_w)
model = inst.model
model.truncation = 1.0  # presumably psi = 1.0, i.e. no truncation

# PCA settings: 80 components, n = 1M (presumably the number of samples drawn).
# get_or_compute returns the path of the dump to load (presumably computing it when missing).
pc_config = Config(
    components=80,
    n=10**6,
    use_w=use_w,
    layer='style',
    model='StyleGAN2',
    output_class='ffhq',
)
dump_name = get_or_compute(pc_config, inst)

# Latent-space components and mean go to `device` as torch tensors;
# the per-component stdevs stay a NumPy array.
with np.load(dump_name) as data:
    lat_comp, lat_mean = (torch.from_numpy(data[key]).to(device)
                          for key in ('lat_comp', 'lat_mean'))
    lat_std = data['lat_stdev']

In [None]:
# Seeds to render. Other seeds tried: 1502970553, 1235907362, 1302626592.
# For exploration, draw random ones instead: seeds_ffhq = [rand() for _ in range(50)]
seeds_ffhq = [366745668]

n_pcs = 14                      # number of leading principal components to show
SIGMA = 2.0                     # edit strength, passed as `sigma` to create_strip_centered
LAYER_START, LAYER_END = 0, 18  # layer range for the edit; presumably all style layers of the FFHQ model — confirm
NUM_FRAMES = 7                  # frames per strip (one strip per PC)

# Case 1: Normal centered PCs, each edited over the full layer range LAYER_START..LAYER_END
for seed in seeds_ffhq:
    print(seed)

    strips = []
    for i in range(n_pcs):
        # NOTE(review): z is re-sampled for every PC; with a fixed seed it should be the same
        # latent each time. Presumably kept inside the loop so that no PC sees a latent a previous
        # create_strip_centered call might have modified — confirm before hoisting it out.
        z = model.sample_latent(1, seed=seed)
        batch_frames = create_strip_centered(
            inst, 'latent', 'style', [z],
            0, lat_comp[i], 0, lat_std[i], 0, lat_mean,
            SIGMA, LAYER_START, LAYER_END, num_frames=NUM_FRAMES)[0]
        strips.append(np.hstack(pad_frames(batch_frames)))

        # Save every frame on its own (frames are presumably floats in [0, 1], scaled to 0..255)
        for j, frame in enumerate(batch_frames):
            Image.fromarray(np.uint8(frame * 255)).save(out_root / 'global' / f'{seed}_pc{i}_{j}.png')

    # One row per PC, stacked vertically
    grid = np.vstack(strips)
    Image.fromarray(np.uint8(grid * 255)).save(out_root / f'grid_{seed}.jpg')

    fig, ax = plt.subplots(figsize=(20, 40))
    ax.imshow(grid)
    ax.axis('off')
    plt.show()
    plt.close(fig)  # free the figure even on non-inline backends (many seeds -> many figures)

In [None]:
# Case 2: hand-tuned layer ranges for some directions.
# Each entry is (pc_index, (layer_start, layer_end), sigma). The trailing comment notes what the
# restricted edit changes and what it leaves intact.
hand_tuned = [
    ( 0, (1, 7), 2.0),    # gender, keep age
    ( 1, (0, 3), 2.0),    # rotate, keep gender
    ( 2, (3, 8), 2.0),    # gender, keep geometry
    ( 3, (2, 8), 2.0),    # age, keep lighting, no hat
    ( 4, (5, 18), 2.0),   # background, keep geometry
    ( 5, (0, 4), 2.0),    # hat, keep lighting and age
    ( 6, (7, 18), 2.0),   # just lighting
    ( 7, (5, 9), 2.0),    # just lighting
    ( 8, (1, 7), 2.0),    # age, keep lighting
    ( 9, (0, 5), 2.0),    # keep lighting
    (10, (7, 9), 2.0),    # hair color, keep geom
    (11, (0, 5), 2.0),    # hair length, keep color
    (12, (8, 9), 2.0),    # light dir lr
    # Alternative tuning for PC 12: (12, (4, 10), 2.0) -> light position LR
    (13, (0, 6), 2.0),    # about the same
]

NUM_FRAMES = 7  # frames per strip (same count as the Case 1 strips)

for seed in seeds_ffhq:
    print(seed)

    strips = []
    for i, (s, e), sigma in hand_tuned:
        # NOTE(review): z is re-sampled per entry (same seed -> presumably the same latent),
        # so no entry sees a latent a previous create_strip_centered call might have modified.
        z = model.sample_latent(1, seed=seed)
        batch_frames = create_strip_centered(
            inst, 'latent', 'style', [z],
            0, lat_comp[i], 0, lat_std[i], 0, lat_mean,
            sigma, s, e, num_frames=NUM_FRAMES)[0]
        strips.append(np.hstack(pad_frames(batch_frames)))

        # Save every frame on its own; the layer range is part of the file name
        for j, frame in enumerate(batch_frames):
            Image.fromarray(np.uint8(frame * 255)).save(out_root / 'tuned' / f'{seed}_pc{i}_s{s}_e{e}_{j}.png')

    # One row per hand-tuned direction, stacked vertically
    grid = np.vstack(strips)
    Image.fromarray(np.uint8(grid * 255)).save(out_root / f'grid_{seed}_tuned.jpg')

    fig, ax = plt.subplots(figsize=(20, 40))
    ax.imshow(grid)
    ax.axis('off')
    plt.show()
    plt.close(fig)  # free the figure even on non-inline backends (many seeds -> many figures)