# sefa / app.py
# Johannes Kolbe
# added some more info
# d5205ed
# python 3.7
"""Demo."""
import numpy as np
import torch
import streamlit as st
import SessionState
from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Loads the generator registered under `model_name`.

    The generator is requested through `load_generator` with
    `from_hf_hub=True`, and the call is wrapped in the Streamlit cache so
    repeated calls with the same name can reuse the loaded object.
    """
    generator = load_generator(model_name, from_hf_hub=True)
    return generator
@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Runs the semantic factorization on the selected layers of `model`.

    Cached pass-through to `factorize_weight`; whatever that helper returns
    is handed back unchanged.
    """
    factorization = factorize_weight(model, layer_idx)
    return factorization
def sample(model, gan_type, num=1):
    """Draws `num` random latent codes and returns them as a numpy array.

    The codes start as standard-normal noise of width `model.z_space_dim`.
    For 'pggan' they are passed through `model.layer0.pixel_norm`; for
    'stylegan' / 'stylegan2' they are passed through `model.mapping` (taking
    the 'w' entry) and then `model.truncation`. Any other `gan_type` leaves
    the raw noise untouched.
    """
    latents = torch.randn(num, model.z_space_dim)

    if gan_type == 'pggan':
        latents = model.layer0.pixel_norm(latents)
    elif gan_type in ('stylegan', 'stylegan2'):
        # The two StyleGAN variants only differ in their truncation settings.
        if gan_type == 'stylegan':
            psi, num_trunc_layers = 0.7, 8
        else:
            psi, num_trunc_layers = 0.5, 18
        w_codes = model.mapping(latents)['w']
        latents = model.truncation(w_codes,
                                   trunc_psi=psi,
                                   trunc_layers=num_trunc_layers)

    return latents.detach().cpu().numpy()
@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesizes an image with the given code.

    Args:
        model: Generator to run. For 'pggan' the model itself is called; for
            'stylegan' / 'stylegan2' its `synthesis` sub-module is called.
        gan_type: One of 'pggan', 'stylegan' or 'stylegan2'.
        code: Latent code; converted with `to_tensor` before being fed to the
            generator.

    Returns:
        The first element of `postprocess` applied to the generator's
        'image' output.

    Raises:
        ValueError: If `gan_type` is not one of the supported types.
    """
    # Pure inference: disable autograd so no graph is recorded for the
    # (potentially very large) synthesis network.
    with torch.no_grad():
        if gan_type == 'pggan':
            image = model(to_tensor(code))['image']
        elif gan_type in ['stylegan', 'stylegan2']:
            image = model.synthesis(to_tensor(code))['image']
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f'Unsupported GAN type: {gan_type!r}')
    image = postprocess(image)[0]
    return image
def _update_slider():
    """Resets every semantic slider to its neutral position.

    Registered as the `on_click` callback of the sidebar 'Reset' button. It
    reads the current number of semantics from `st.session_state` and writes
    the neutral value into the state entry of each `semantic_slider_<idx>`
    widget.
    """
    num_semantics = st.session_state["num_semantics"]
    for sem_idx in range(num_semantics):
        # The sliders are created with float value/min/max/step, so reset them
        # with a float as well rather than the int 0.
        st.session_state[f"semantic_slider_{sem_idx}"] = 0.0
# Main script body (loop for StreamLit).
# NOTE: this is deliberately a comment, not a string literal. A stand-alone
# string expression in the middle of a Streamlit script is picked up by
# Streamlit "magic" and written to the page.
st.title('Closed-Form Factorization of Latent Semantics in GANs')
st.markdown("This space is the ported version of [Closed-Form Factorization of Latent Semantics in GANs](https://github.com/genforce/sefa). It reads all sample models from the Hugging Face Hub")
st.markdown("---")

# --- Sidebar: global options -------------------------------------------------
st.sidebar.title('Options')
st.sidebar.button('Reset', on_click=_update_slider, kwargs={})

model_name = st.sidebar.selectbox(
    'Model to Interpret',
    ['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256'])

model = get_model(model_name)
gan_type = parse_gan_type(model)
layer_idx = st.sidebar.selectbox(
    'Layers to Interpret',
    ['all', '0-1', '2-5', '6-13'])
layers, boundaries, eigen_values = factorize_model(model, layer_idx)

num_semantics = st.sidebar.number_input(
    'Number of semantics', value=5, min_value=0, max_value=None, step=1, key="num_semantics")
# The input has no upper bound, but only len(eigen_values) directions exist;
# clamp so the slider loop below cannot index past the end of `eigen_values`.
num_semantics = min(num_semantics, len(eigen_values))
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}

# Editing range of the sliders depends on the GAN family.
if gan_type == 'pggan':
    max_step = 5.0
elif gan_type == 'stylegan':
    max_step = 2.0
elif gan_type == 'stylegan2':
    max_step = 15.0
else:
    # Fail early and clearly instead of hitting an unbound `max_step` later.
    raise ValueError(f'Unsupported GAN type: {gan_type!r}')

# One slider per semantic direction; its value is the step along that direction.
for sem_idx in steps:
    eigen_value = eigen_values[sem_idx]
    steps[sem_idx] = st.sidebar.slider(
        f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
        value=0.0,
        min_value=-max_step,
        max_value=max_step,
        step=0.04 * max_step,
        key=f"semantic_slider_{sem_idx}")

# --- Main area: placeholders are created first to fix the page layout --------
image_placeholder = st.empty()
button_placeholder = st.empty()
button_totally_random = st.empty()

# Prefer the latent codes stored on disk for this model; if the file is
# missing, fall back to a freshly sampled code.
try:
    base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
except FileNotFoundError:
    base_codes = sample(model, gan_type)

# `SessionState` is a project helper; presumably it keeps these values alive
# across script reruns -- TODO confirm against its implementation.
state = SessionState.get(model_name=model_name,
                         code_idx=0,
                         codes=base_codes[0:1])
# Switching models invalidates the stored code, so start over from the first one.
if state.model_name != model_name:
    state.model_name = model_name
    state.code_idx = 0
    state.codes = base_codes[0:1]

if button_placeholder.button('Next Sample'):
    state.code_idx += 1
    if state.code_idx < base_codes.shape[0]:
        state.codes = base_codes[state.code_idx][np.newaxis]
    else:
        # Ran out of stored codes: continue with random samples.
        state.codes = sample(model, gan_type)

if button_totally_random.button('Totally Random'):
    state.codes = sample(model, gan_type)

# Work on a copy so the stored base code is never modified by the edits.
code = state.codes.copy()
for sem_idx, step in steps.items():
    if gan_type == 'pggan':
        code += boundaries[sem_idx:sem_idx + 1] * step
    elif gan_type in ['stylegan', 'stylegan2']:
        # Only the selected layers are shifted along the direction.
        code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step

image = synthesize(model, gan_type, code)
image_placeholder.image(image / 255.0)

st.markdown("---")
st.markdown("""This space was created by [johko](https://twitter.com/johko990). Main credits go to the original authors Yujun Shen and Bolei Zhou, who created a great code base to work on.
This version loads all models from the Hugging Face Hub.""")