In [None]:
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

# Comparison to GAN steerability and InterfaceGAN
%matplotlib inline
from notebook_init import *
import pickle

# Every figure produced by this notebook is written below this directory.
out_root = Path('out/figures/steerability_comp')
makedirs(out_root, exist_ok=True)


def rand():
    """Draw a fresh random seed in [0, int32 max) from NumPy's global RNG."""
    upper = np.iinfo(np.int32).max
    return np.random.randint(upper)

In [None]:
def show_strip(frames):
    """Show `frames` side by side as one large, axis-free strip.

    Frames are padded with `pad_frames(frames, 64)` and concatenated
    horizontally before being drawn on a 20x20-inch figure.
    """
    plt.figure(figsize=(20, 20))
    plt.axis('off')
    strip = np.hstack(pad_frames(frames, 64))
    plt.imshow(strip)
    plt.show()

In [None]:
def normalize(t):
    """Return `t` divided by its Euclidean norm, treating the whole array as one flat vector.

    The result has the same shape as `t`. An all-zero input divides by zero
    (NaN result), exactly like the previous one-line form.
    """
    flat = t.reshape(-1)
    length = np.sqrt(np.sum(flat ** 2))
    return t / length

def compute(
    model,
    lat_mean,
    prefix,
    imgclass,
    seeds,
    d_ours,
    l_start,
    l_end,
    scale_ours,
    d_sup, # single or one per layer
    scale_sup,
    center=True,
    n_frames=5,
):
    """Render, save and display edit strips for our direction vs. a supervised one.

    For every seed the same base latent is moved along two directions:
      * 'ours': `d_ours`, applied only to latent layers in [l_start, l_end);
      * 'supervised': `d_sup`, applied to all `model.get_max_latents()` layers.
    For each direction, `n_frames` images are generated at offsets
    a * scale for a evenly spaced in [-1, 1]. They are saved as
    `out_root / imgclass / f'{prefix}_{name}_{seed}_{i}.png'` and shown as one
    padded strip.

    Args:
        model: generator wrapper; must provide `set_output_class`,
            `sample_latent`, `get_max_latents` and `sample_np`.
        lat_mean: mean latent used for centering (see `center`).
        prefix: edit name used in file names and plot titles (e.g. 'zoom').
        imgclass: output class passed to the model; also the output sub-folder.
        seeds: iterable of seeds passed to `model.sample_latent`.
        d_ours: our direction. A single row (first dim 1) is normalized and
            reused for every layer. A multi-row array is taken as one row per
            layer and used unnormalized.
        l_start, l_end: half-open layer range edited with `d_ours`.
        scale_ours: maximum offset along `d_ours`.
        d_sup: supervised direction, single row or one row per layer
            (same rules as `d_ours`).
        scale_sup: maximum offset along `d_sup`.
        center: if True, remove the component of (lat_base - lat_mean) along
            normalize(d_sup), so edits start from the mean along the
            supervised direction. Intended for a single-row `d_sup` only.
            A per-layer `d_sup` would broadcast the base latent to several
            rows, which is why per-layer callers pass center=False.
        n_frames: number of images per strip (default 5).

    Raises:
        ValueError: if a per-layer direction has fewer rows than the last
            layer it is asked to edit.
    """
    model.set_output_class(imgclass)
    makedirs(out_root / imgclass, exist_ok=True)

    # Loop-invariant setup, hoisted out of the per-seed loop
    num_layers = model.get_max_latents()
    deltas = [d_ours, d_sup]
    scales = [scale_ours, scale_sup]
    ranges = [(l_start, l_end), (0, num_layers)]
    names = ['ours', 'supervised']

    # The supervised direction is assumed to be ground truth for centering
    y = normalize(d_sup) if center else None

    for seed in seeds:
        print(seed)

        for delta, name, scale, l_range in zip(deltas, names, scales, ranges):
            lat_base = model.sample_latent(1, seed=seed).cpu().numpy()

            # Shift latent to lie on mean along given direction
            if center:
                dotp = np.sum((lat_base - lat_mean) * y, axis=-1, keepdims=True)
                lat_base = lat_base - dotp * y

            # Convert single delta to per-layer delta (to support Steerability StyleGAN)
            if delta.shape[0] > 1:
                d_per_layer = list(delta) # might have per-layer scales, don't normalize
                if l_range[0] < l_range[1] and len(d_per_layer) < l_range[1]:
                    raise ValueError(
                        f'{name} direction for {prefix!r} has {len(d_per_layer)} '
                        f'layer rows, but layers up to {l_range[1]} are edited')
            else:
                d_per_layer = [normalize(delta)] * num_layers

            frames = []
            for a in np.linspace(-1.0, 1.0, n_frames):
                # Same base latent for every layer; only edited layers get the offset.
                # `w[l] + ...` builds a new array, so the shared list entries are not mutated.
                w = [lat_base] * num_layers
                for l in range(l_range[0], l_range[1]):
                    w[l] = w[l] + a * d_per_layer[l] * scale
                # Clip so the uint8 conversion below stays well-defined. This assumes
                # sample_np returns floats nominally in [0, 1], which `frame*255` implies.
                frames.append(np.clip(model.sample_np(w), 0.0, 1.0))

            for i, frame in enumerate(frames):
                Image.fromarray(np.uint8(frame * 255)).save(
                    out_root / imgclass / f'{prefix}_{name}_{seed}_{i}.png')

            strip = np.hstack(pad_frames(frames, 64))
            plt.figure(figsize=(12,12))
            plt.imshow(strip)
            plt.axis('off')
            # Set the title before tight_layout so the layout accounts for it
            plt.title(f'{prefix} - {name}, scale={scale}')
            plt.tight_layout()
            plt.show()

In [None]:
# BigGAN-512: our PCA directions vs. GAN steerability (zoom, horizontal shift)

inst = get_instrumented_model('BigGAN-512', 'husky', 'generator.gen_z', device, inst=inst)
model = inst.model
K = model.get_max_latents()

pc_config = Config(
    components=128,
    n=1_000_000,
    layer='generator.gen_z',
    model='BigGAN-512',
    output_class='husky',
)
dump_name = get_or_compute(pc_config, inst)
with np.load(dump_name) as data:
    lat_comp, lat_mean = data['lat_comp'], data['lat_mean']

# Steerability baselines, each reshaped to a single 128-d row.
# NOTE: pickle.load can execute arbitrary code; only load trusted local files here.
with open('data/steerability/biggan_deep_512/gan_steer-linear_zoom_512.pkl', 'rb') as f:
    delta_steerability_zoom = pickle.load(f)['w_zoom'].reshape(1, 128)
with open('data/steerability/biggan_deep_512/gan_steer-linear_shiftx_512.pkl', 'rb') as f:
    delta_steerability_transl = pickle.load(f)['w_shiftx'].reshape(1, 128)

# Our counterparts: PCA components picked by visual inspection
delta_ours_transl = lat_comp[0]
delta_ours_zoom = lat_comp[6]

model.truncation = 0.6
# Columns: prefix, class, seeds, our direction, our scale, steerability direction, its scale.
# Our direction is applied to all K latent layers.
for prefix, imgclass, seeds, d_ours, scale_ours, d_sup, scale_sup in [
    ('zoom', 'robin', [560157313], delta_ours_zoom, -3.0, delta_steerability_zoom, 5.5),
    ('zoom', 'ship', [107715983], delta_ours_zoom, -3.0, delta_steerability_zoom, 5.0),
    ('translate', 'golden_retriever', [552411435], delta_ours_transl, -2.0, delta_steerability_transl, 4.5),
    ('translate', 'lemon', [331582800], delta_ours_transl, -3.0, delta_steerability_transl, 6.0),
]:
    compute(model, lat_mean, prefix, imgclass, seeds, d_ours, 0, K, scale_ours, d_sup, scale_sup)

In [None]:
# StyleGAN1-ffhq: our PCA directions vs. InterfaceGAN boundaries

inst = get_instrumented_model('StyleGAN', 'ffhq', 'g_mapping', device, use_w=True, inst=inst)
model = inst.model
K = model.get_max_latents()

pc_config = Config(
    components=128,
    n=1_000_000,
    use_w=True,
    layer='g_mapping',
    model='StyleGAN',
    output_class='ffhq',
)
dump_name = get_or_compute(pc_config, inst)
with np.load(dump_name) as data:
    lat_comp, lat_mean = data['lat_comp'], data['lat_mean']

# Supervised baselines: InterfaceGAN W-space boundaries (SG-ffhq-w, non-conditional)
d_ffhq_pose = np.load('data/interfacegan/stylegan_ffhq_pose_w_boundary.npy').astype(np.float32)
d_ffhq_smile = np.load('data/interfacegan/stylegan_ffhq_smile_w_boundary.npy').astype(np.float32)
d_ffhq_gender = np.load('data/interfacegan/stylegan_ffhq_gender_w_boundary.npy').astype(np.float32)
d_ffhq_glasses = np.load('data/interfacegan/stylegan_ffhq_eyeglasses_w_boundary.npy').astype(np.float32)

# Our counterparts: PCA components picked by visual inspection
d_ours_pose = lat_comp[9]
d_ours_smile = lat_comp[44]
d_ours_gender = lat_comp[0]
d_ours_glasses = lat_comp[12]

model.truncation = 1.0 # NOT IMPLEMENTED
# Columns: prefix, seeds, our direction, our layer range [start, end), our scale,
# InterfaceGAN boundary, its scale
for prefix, seeds, d_ours, l_start, l_end, scale_ours, d_sup, scale_sup in [
    ('pose', [440608316, 1811098088, 129888612], d_ours_pose, 0, 7, -1.0, d_ffhq_pose, 1.0),
    ('smile', [1759734403, 1647189561, 70163682], d_ours_smile, 3, 4, -8.5, d_ffhq_smile, 1.0),
    ('gender', [1302836080, 1746672325], d_ours_gender, 2, 6, -4.5, d_ffhq_gender, 1.5),
    ('glasses', [1565213752, 1005764659, 1110182583], d_ours_glasses, 0, 2, 4.0, d_ffhq_glasses, 1.0),
]:
    compute(model, lat_mean, prefix, 'ffhq', seeds, d_ours, l_start, l_end, scale_ours, d_sup, scale_sup)

In [None]:
# StyleGAN1-ffhq: our PCA directions vs. GAN steerability color edits (R, G, B)

inst = get_instrumented_model('StyleGAN', 'ffhq', 'g_mapping', device, use_w=True, inst=inst)
model = inst.model
K = model.get_max_latents()

pc_config = Config(
    components=128,
    n=1_000_000,
    use_w=True,
    layer='g_mapping',
    model='StyleGAN',
    output_class='ffhq',
)
dump_name = get_or_compute(pc_config, inst)
with np.load(dump_name) as data:
    lat_comp, lat_mean = data['lat_comp'], data['lat_mean']

# Steerability baselines for SG-ffhq-w (non-conditional), one row per layer: shapes [18, 512]
d_ffhq_R = np.load('data/steerability/stylegan_ffhq/ffhq_rgb_0.npy').astype(np.float32)
d_ffhq_G = np.load('data/steerability/stylegan_ffhq/ffhq_rgb_1.npy').astype(np.float32)
d_ffhq_B = np.load('data/steerability/stylegan_ffhq/ffhq_rgb_2.npy').astype(np.float32)

# Our counterparts: PCA components picked by visual inspection (signs flipped to match)
d_ours_R = lat_comp[0]
d_ours_G = -lat_comp[1]
d_ours_B = -lat_comp[2]

model.truncation = 1.0 # NOT IMPLEMENTED
# Columns: prefix, our direction, our scale, steerability direction.
# Seed 5 throughout; we edit only layer 17; no centering (per-layer baselines).
for prefix, d_ours, scale_ours, d_sup in [
    ('red', d_ours_R, 8.0, d_ffhq_R),
    ('green', d_ours_G, 15.0, d_ffhq_G),
    ('blue', d_ours_B, 10.0, d_ffhq_B),
]:
    compute(model, lat_mean, prefix, 'ffhq', [5], d_ours, 17, 18, scale_ours, d_sup, 1.0, center=False)

In [None]:
# StyleGAN1-celebahq: our PCA directions vs. InterfaceGAN boundaries

inst = get_instrumented_model('StyleGAN', 'celebahq', 'g_mapping', device, use_w=True, inst=inst)
model = inst.model

K = model.get_max_latents()
pc_config = Config(components=128, n=1_000_000, use_w=True,
    layer='g_mapping', model='StyleGAN', output_class='celebahq')
dump_name = get_or_compute(pc_config, inst)

with np.load(dump_name) as data:
    lat_comp = data['lat_comp']
    lat_mean = data['lat_mean']

# SG-celebahq-w, non-conditional: InterfaceGAN W-space boundaries
d_celebahq_pose = np.load('data/interfacegan/stylegan_celebahq_pose_w_boundary.npy').astype(np.float32)
d_celebahq_smile = np.load('data/interfacegan/stylegan_celebahq_smile_w_boundary.npy').astype(np.float32)
d_celebahq_gender = np.load('data/interfacegan/stylegan_celebahq_gender_w_boundary.npy').astype(np.float32)
d_celebahq_glasses = np.load('data/interfacegan/stylegan_celebahq_eyeglasses_w_boundary.npy').astype(np.float32)

# Indices determined by visual inspection
d_ours_pose = lat_comp[7]
d_ours_smile = lat_comp[14]
d_ours_gender = lat_comp[1]
d_ours_glasses = lat_comp[5]

model.truncation = 1.0 # NOT IMPLEMENTED
# Columns: prefix, seeds, our direction, our layer range [start, end), our scale,
# InterfaceGAN boundary, its scale
for prefix, seeds, d_ours, l_start, l_end, scale_ours, d_sup, scale_sup in [
    ('pose', [1939067252, 1460055449, 329555154], d_ours_pose, 0, 7, -1.0, d_celebahq_pose, 1.0),
    ('smile', [329187806, 424805522, 1777796971], d_ours_smile, 3, 4, -7.0, d_celebahq_smile, 1.3),
    ('gender', [1144615644, 967075839, 264878205], d_ours_gender, 0, 2, -3.2, d_celebahq_gender, 1.2),
    # glasses: hard for both methods
    ('glasses', [991993380, 594344173, 2119328990, 1919124025], d_ours_glasses, 0, 1, -10.0, d_celebahq_glasses, 2.0),
]:
    compute(model, lat_mean, prefix, 'celebahq', seeds, d_ours, l_start, l_end, scale_ours, d_sup, scale_sup)

In [None]:
# StyleGAN1-cars: our PCA directions vs. GAN steerability (2D rotation, vertical shift)

inst = get_instrumented_model('StyleGAN', 'cars', 'g_mapping', device, use_w=True, inst=inst)
model = inst.model

K = model.get_max_latents()
pc_config = Config(components=128, n=1_000_000, use_w=True,
    layer='g_mapping', model='StyleGAN', output_class='cars')
dump_name = get_or_compute(pc_config, inst)

with np.load(dump_name) as data:
    lat_comp = data['lat_comp']
    lat_mean = data['lat_mean']

# Steerability baselines, one row per layer. Shapes: [16, 512]
d_cars_rot = np.load('data/steerability/stylegan_cars/rotate2d.npy').astype(np.float32)
d_cars_shift = np.load('data/steerability/stylegan_cars/shifty.npy').astype(np.float32)


def pad_to_layers(d, n_layers):
    """Append zero rows (i.e. no edit) so per-layer direction `d` has at least `n_layers` rows."""
    missing = max(n_layers - d.shape[0], 0)
    return np.append(d, np.zeros((missing, d.shape[1]), dtype=d.dtype), axis=0)


# compute() indexes one row per latent layer, so add the missing final layers.
# Pad to K rather than a hard-coded 2 rows.
d_cars_rot = pad_to_layers(d_cars_rot, K)
d_cars_shift = pad_to_layers(d_cars_shift, K)

# Indices determined by visual inspection
d_ours_rot = lat_comp[0]
d_ours_shift = lat_comp[7]

model.truncation = 1.0 # NOT IMPLEMENTED
compute(model, lat_mean, 'rotate2d', 'cars', [46, 28], d_ours_rot, 0, 1, 1.0, d_cars_rot, 1.0, center=False)
compute(model, lat_mean, 'shifty', 'cars', [0, 13], d_ours_shift, 1, 2, 4.0, d_cars_shift, 1.0, center=False)