|
|
|
"""Demo.""" |
|
|
|
import numpy as np |
|
import torch |
|
import streamlit as st |
|
import SessionState |
|
|
|
from models import parse_gan_type |
|
from utils import to_tensor |
|
from utils import postprocess |
|
from utils import load_generator |
|
from utils import factorize_weight |
|
|
|
|
|
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Loads the generator identified by `model_name`.

    The call is forwarded to `load_generator` with `from_hf_hub=True`, and the
    result is cached by Streamlit so that script reruns reuse the same object.
    """
    generator = load_generator(model_name, from_hf_hub=True)
    return generator
|
|
|
|
|
@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Runs `factorize_weight` on `model` for the layers given by `layer_idx`.

    The result of `factorize_weight` is returned untouched and cached by
    Streamlit; the caller unpacks it as `(layers, boundaries, eigen_values)`.
    """
    factorization = factorize_weight(model, layer_idx)
    return factorization
|
|
|
|
|
def sample(model, gan_type, num=1):
    """Samples latent codes for the given generator.

    Args:
        model: Generator exposing `z_space_dim`. Depending on `gan_type`, it
            must also expose `layer0.pixel_norm` (`'pggan'`) or `mapping` and
            `truncation` (`'stylegan'` / `'stylegan2'`).
        gan_type: Type of the generator. `'pggan'` normalizes the sampled
            codes with `model.layer0.pixel_norm`; `'stylegan'` and
            `'stylegan2'` take the `'w'` entry of `model.mapping(...)` and
            apply `model.truncation`. For any other value the Gaussian
            codes are returned unchanged.
        num: Number of codes to sample. (default: 1)

    Returns:
        A `numpy.ndarray` with the sampled codes.
    """
    # Sampling is pure inference, so run the generator without autograd
    # tracking instead of building a graph that is detached right afterwards.
    with torch.no_grad():
        codes = torch.randn(num, model.z_space_dim)
        if gan_type == 'pggan':
            codes = model.layer0.pixel_norm(codes)
        elif gan_type == 'stylegan':
            codes = model.mapping(codes)['w']
            codes = model.truncation(codes,
                                     trunc_psi=0.7,
                                     trunc_layers=8)
        elif gan_type == 'stylegan2':
            codes = model.mapping(codes)['w']
            codes = model.truncation(codes,
                                     trunc_psi=0.5,
                                     trunc_layers=18)
    return codes.detach().cpu().numpy()
|
|
|
|
|
@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesizes an image with the given code.

    Args:
        model: Generator. It is called directly for `'pggan'`; for
            `'stylegan'` and `'stylegan2'` its `synthesis` sub-module is
            called instead.
        gan_type: Type of the generator, one of `'pggan'`, `'stylegan'` or
            `'stylegan2'`.
        code: Latent code, converted with `to_tensor` before being fed to the
            generator.

    Returns:
        The first element of `postprocess` applied to the generator's
        `'image'` output.

    Raises:
        ValueError: If `gan_type` is not one of the supported types.
    """
    # Synthesis is inference only; no autograd graph is needed.
    with torch.no_grad():
        if gan_type == 'pggan':
            image = model(to_tensor(code))['image']
        elif gan_type in ['stylegan', 'stylegan2']:
            image = model.synthesis(to_tensor(code))['image']
        else:
            # Fail with a clear message instead of an unbound `image` below.
            raise ValueError(f'Unsupported GAN type: `{gan_type}`!')
    return postprocess(image)[0]
|
|
|
|
# Main function (loop for StreamLit).
#
# NOTE: This header is a `#` comment on purpose. A bare string literal in the
# middle of a script is not a docstring, and Streamlit's "magic" feature
# would write it into the rendered app.
st.title('Closed-Form Factorization of Latent Semantics in GANs')
st.sidebar.title('Options')
reset = st.sidebar.button('Reset')

model_name = st.sidebar.selectbox(
    'Model to Interpret',
    ['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512',
     'stylegan_cat256'])

model = get_model(model_name)
gan_type = parse_gan_type(model)
layer_idx = st.sidebar.selectbox(
    'Layers to Interpret',
    ['all', '0-1', '2-5', '6-13'])
layers, boundaries, eigen_values = factorize_model(model, layer_idx)

# Each selected semantic indexes into `eigen_values` below, so the input is
# capped at the number of available entries; an unbounded input would raise
# an IndexError as soon as the user asks for more.
max_semantics = len(eigen_values)
num_semantics = st.sidebar.number_input(
    'Number of semantics',
    value=min(5, max_semantics),
    min_value=0,
    max_value=max_semantics,
    step=1)
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}

# Slider range per GAN type.
if gan_type == 'pggan':
    max_step = 5.0
elif gan_type == 'stylegan':
    max_step = 2.0
elif gan_type == 'stylegan2':
    max_step = 15.0

for sem_idx in steps:
    eigen_value = eigen_values[sem_idx]
    # NOTE(review): passing a different `step` while `reset` is pressed
    # presumably makes Streamlit treat the slider as a new widget so it falls
    # back to `value=0.0` -- confirm against the Streamlit version in use.
    steps[sem_idx] = st.sidebar.slider(
        f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
        value=0.0,
        min_value=-max_step,
        max_value=max_step,
        step=0.04 * max_step if not reset else 0.0)

image_placeholder = st.empty()
button_placeholder = st.empty()

# Prefer latent codes stored on disk; otherwise sample a fresh one.
try:
    base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
except FileNotFoundError:
    base_codes = sample(model, gan_type)

state = SessionState.get(model_name=model_name,
                         code_idx=0,
                         codes=base_codes[0:1])
# Start over from the first code whenever a different model is selected.
if state.model_name != model_name:
    state.model_name = model_name
    state.code_idx = 0
    state.codes = base_codes[0:1]

# 'Random' walks through the loaded codes first and samples new ones once
# they are exhausted.
if button_placeholder.button('Random', key=0):
    state.code_idx += 1
    if state.code_idx < base_codes.shape[0]:
        state.codes = base_codes[state.code_idx][np.newaxis]
    else:
        state.codes = sample(model, gan_type)

# Edit a copy so the code kept in the session state stays unmodified.
code = state.codes.copy()
for sem_idx, step in steps.items():
    if gan_type == 'pggan':
        code += boundaries[sem_idx:sem_idx + 1] * step
    elif gan_type in ['stylegan', 'stylegan2']:
        # Only the layers chosen for interpretation are shifted.
        code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step
image = synthesize(model, gan_type, code)
image_placeholder.image(image / 255.0)
|
|