File size: 3,922 Bytes
3b72cdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
# python 3.7
"""Demo."""
import numpy as np
import torch
import streamlit as st
import SessionState
from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Load and return the generator registered under `model_name`.

    The actual loading is delegated to `load_generator`; the result is cached
    with `st.cache` so Streamlit reruns reuse the same generator object.
    """
    generator = load_generator(model_name)
    return generator
@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Return the semantic factorization of `model` for the chosen layers.

    Delegates to `factorize_weight(model, layer_idx)`; the main script unpacks
    the result as `(layers, boundaries, eigen_values)`. Cached with `st.cache`
    so the factorization is not recomputed on every Streamlit rerun.
    """
    factorization = factorize_weight(model, layer_idx)
    return factorization
def sample(model, gan_type, num=1):
    """Sample `num` latent codes for `model` and return them as a NumPy array.

    Gaussian noise of shape `(num, model.z_space_dim)` is drawn and then
    mapped according to `gan_type`:

    - 'pggan': passed through `model.layer0.pixel_norm`.
    - 'stylegan': `model.mapping(...)['w']`, then `model.truncation` with
      `trunc_psi=0.7, trunc_layers=8`.
    - 'stylegan2': `model.mapping(...)['w']`, then `model.truncation` with
      `trunc_psi=0.5, trunc_layers=18`.
    - anything else: the raw noise is returned unchanged.

    Args:
        model: Generator exposing `z_space_dim` and the sub-modules used by
            the chosen `gan_type` (see above).
        gan_type: GAN family string, e.g. as returned by `parse_gan_type`.
        num: Number of codes to draw.

    Returns:
        A host-side NumPy array whose first axis indexes the sampled codes.
    """
    # Place the noise on the generator's device rather than hard-coding CUDA,
    # so a CPU-resident model also works. Objects without parameters keep the
    # previous CUDA placement.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')

    # Sampling is inference-only: don't build an autograd graph.
    with torch.no_grad():
        # Draw from the default CPU generator and then move, as before, so the
        # random stream is the same as the original `.cuda()` version.
        codes = torch.randn(num, model.z_space_dim).to(device)
        if gan_type == 'pggan':
            codes = model.layer0.pixel_norm(codes)
        elif gan_type == 'stylegan':
            codes = model.mapping(codes)['w']
            codes = model.truncation(codes,
                                     trunc_psi=0.7,
                                     trunc_layers=8)
        elif gan_type == 'stylegan2':
            codes = model.mapping(codes)['w']
            codes = model.truncation(codes,
                                     trunc_psi=0.5,
                                     trunc_layers=18)
    return codes.detach().cpu().numpy()
@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesize an image from the given latent code.

    Args:
        model: Generator, e.g. as returned by `get_model`.
        gan_type: 'pggan' runs `model(...)` directly; 'stylegan' and
            'stylegan2' run `model.synthesis(...)`.
        code: Latent code; converted with `to_tensor` before the forward pass.

    Returns:
        `postprocess(...)[0]`, i.e. the first image of the rendered batch.

    Raises:
        ValueError: If `gan_type` is not one of the supported types.
    """
    if gan_type == 'pggan':
        forward = model
    elif gan_type in ['stylegan', 'stylegan2']:
        forward = model.synthesis
    else:
        # Without this branch `image` would be unbound below.
        raise ValueError(f'Unsupported GAN type: `{gan_type}`!')
    # Rendering is inference-only: skip autograd bookkeeping.
    with torch.no_grad():
        image = forward(to_tensor(code))['image']
    return postprocess(image)[0]
"""Main function (loop for StreamLit)."""
st.title('Closed-Form Factorization of Latent Semantics in GANs')
st.sidebar.title('Options')
reset = st.sidebar.button('Reset')
model_name = st.sidebar.selectbox(
'Model to Interpret',
['stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',
'pggan_celebahq1024'])
model = get_model(model_name)
gan_type = parse_gan_type(model)
layer_idx = st.sidebar.selectbox(
'Layers to Interpret',
['all', '0-1', '2-5', '6-13'])
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
num_semantics = st.sidebar.number_input(
'Number of semantics', value=10, min_value=0, max_value=None, step=1)
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
if gan_type == 'pggan':
max_step = 5.0
elif gan_type == 'stylegan':
max_step = 2.0
elif gan_type == 'stylegan2':
max_step = 15.0
for sem_idx in steps:
eigen_value = eigen_values[sem_idx]
steps[sem_idx] = st.sidebar.slider(
f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
value=0.0,
min_value=-max_step,
max_value=max_step,
step=0.04 * max_step if not reset else 0.0)
image_placeholder = st.empty()
button_placeholder = st.empty()
try:
base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
except FileNotFoundError:
base_codes = sample(model, gan_type)
state = SessionState.get(model_name=model_name,
code_idx=0,
codes=base_codes[0:1])
if state.model_name != model_name:
state.model_name = model_name
state.code_idx = 0
state.codes = base_codes[0:1]
if button_placeholder.button('Random', key=0):
state.code_idx += 1
if state.code_idx < base_codes.shape[0]:
state.codes = base_codes[state.code_idx][np.newaxis]
else:
state.codes = sample(model, gan_type)
code = state.codes.copy()
for sem_idx, step in steps.items():
if gan_type == 'pggan':
code += boundaries[sem_idx:sem_idx + 1] * step
elif gan_type in ['stylegan', 'stylegan2']:
code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step
image = synthesize(model, gan_type, code)
image_placeholder.image(image / 255.0)
|