In [None]:
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

# Figure: BigGAN edit transferability between classes
%matplotlib inline
from notebook_init import *

def rand():
    """Return a random integer seed in [0, int32 max), drawn from NumPy's global RNG.

    Used to try out fresh seeds when searching for good example images.
    """
    return np.random.randint(np.iinfo(np.int32).max)


outdir = Path('out/figures/edit_transferability')
# pathlib equivalent of os.makedirs(..., exist_ok=True); avoids relying on a
# `makedirs` name that only exists via the star import above.
outdir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load BigGAN-512 (initial output class 'husky'), instrumented at 'generator.gen_z'.
# NOTE(review): `inst=inst` reuses an instance from an earlier run; on a fresh
# kernel `inst` must already be defined (presumably by notebook_init) — confirm.
inst = get_instrumented_model('BigGAN-512', 'husky', 'generator.gen_z', device, inst=inst)
model = inst.model
model.truncation = 0.7

# Component decomposition settings; get_or_compute returns the path of a
# NumPy archive holding the components, their mean and their stdev.
pc_config = Config(
    components=80,
    n=1_000_000,
    layer='generator.gen_z',
    model='BigGAN-512',
    output_class='husky',
)
dump_name = get_or_compute(pc_config, inst)

# Read all three arrays while the archive is still open.
with np.load(dump_name) as data:
    lat_comp, lat_mean, lat_std = (data[key] for key in ('lat_comp', 'lat_mean', 'lat_stdev'))

# Named edit presets. Each value is (component_idx, layer_start, layer_end, strength)
# and is splatted into apply_offset as its (idx, start, end, sigma) arguments.
edits = dict(
    translate_x=(0, 0, 15, -3.0),
    zoom=(6, 0, 15, 2.0),
    clouds=(54, 7, 10, 15.0),
    sunlight=(33, 7, 10, 25.0),
)
# Disabled candidates (same tuple layout):
#   'dark_fg':    (51, 7, 10, 20.0)
#   'silhouette': (13, 7, 10, -20.0)
#   'grass_bg':   (69, 3, 7, -20.0)

def apply_offset(z, idx, start, end, sigma):
    """Shift the per-layer latents of `z` along component `idx` for layers [start, end).

    Args:
        z: a single latent, which is repeated for every layer
            (model.get_max_latents() copies), or a list of per-layer latents.
        idx: row of `lat_comp` / `lat_std` to move along.
        start: first layer index to edit (inclusive).
        end: last layer index to edit (exclusive).
        sigma: edit strength; multiplies lat_comp[idx] * lat_std[idx].

    Returns:
        A new list of per-layer latents. The input (list or array) is not modified.
    """
    # Copy list inputs: editing the caller's list in place would silently
    # corrupt latents that are reused for later edits.
    lat = list(z) if isinstance(z, list) else [z] * model.get_max_latents()
    offset = lat_comp[idx] * lat_std[idx] * sigma  # identical for every edited layer
    for i in range(start, end):
        # `+` (not `+=`) is required: the broadcast list holds the same array
        # object in every slot, so in-place addition would hit all layers.
        lat[i] = lat[i] + offset
    return lat

show = True  # True: display the grid inline; False: save the grid as a JPEG in outdir instead

# good geom seeds: 2145371585
# good style seeds: 337336281, 2075156369, 311784160


def _render_edits(output_class, z, edit_names):
    """Render `z` as `output_class`: the unedited image first, then one image per named edit.

    Side effect: switches `model` to `output_class`, and it stays on that class afterwards.
    """
    model.set_output_class(output_class)
    images = [model.sample_np(z)]
    images += [model.sample_np(apply_offset(z, *edits[name])) for name in edit_names]
    return images


def _save_uint8(img, path):
    """Save a float image (assumed to be in [0, 1]) to `path` as 8-bit.

    Values are clipped first so that anything out of range saturates instead of wrapping around.
    """
    Image.fromarray((255 * np.clip(img, 0.0, 1.0)).astype(np.uint8)).save(path)


# Type 1: geometric edit - transfers well
seed1_geom = 2145371585
seed2_geom = 2046317118
print('Seeds geom:', [seed1_geom, seed2_geom])
z1 = model.sample_latent(1, seed=seed1_geom).cpu().numpy()
z2 = model.sample_latent(1, seed=seed2_geom).cpu().numpy()

base_husky, zoom_husky, transl_husky = _render_edits('husky', z1, ['zoom', 'translate_x'])
base_castle, zoom_castle, transl_castle = _render_edits('castle', z2, ['zoom', 'translate_x'])
img_geom1 = np.hstack([base_husky, zoom_husky, transl_husky])
img_geom2 = np.hstack([base_castle, zoom_castle, transl_castle])

# Type 2: style edit - often transfers
seed1_style = 417482011  # rand()
seed2_style = 1026291813
print('Seeds style:', [seed1_style, seed2_style])
z1 = model.sample_latent(1, seed=seed1_style).cpu().numpy()
z2 = model.sample_latent(1, seed=seed2_style).cpu().numpy()

# Rendered lighthouse first, then barn, so the model ends on 'barn' as before.
base_lighthouse, edit1_lighthouse, edit2_lighthouse = _render_edits('lighthouse', z2, ['clouds', 'sunlight'])
base_barn, edit1_barn, edit2_barn = _render_edits('barn', z1, ['clouds', 'sunlight'])
img_style1 = np.hstack([base_barn, edit1_barn, edit2_barn])
img_style2 = np.hstack([base_lighthouse, edit1_lighthouse, edit2_lighthouse])

# Rows: husky, castle (geometric edits), then barn, lighthouse (style edits).
grid = np.vstack([img_geom1, img_geom2, img_style1, img_style2])

if show:
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(grid)
    ax.axis('off')
    plt.show()
else:
    _save_uint8(grid, outdir / f'{seed1_geom}_{seed2_geom}_{seed1_style}_{seed2_style}_transf.jpg')

# Save individual frames as <prefix>_1.png (base), _2.png, _3.png (edits)
frames = {
    'geom_husky': (base_husky, zoom_husky, transl_husky),
    'geom_castle': (base_castle, zoom_castle, transl_castle),
    'style_lighthouse': (base_lighthouse, edit1_lighthouse, edit2_lighthouse),
    'style_barn': (base_barn, edit1_barn, edit2_barn),
}
for prefix, images in frames.items():
    for frame_idx, img in enumerate(images, start=1):
        _save_uint8(img, outdir / f'{prefix}_{frame_idx}.png')
