File size: 3,915 Bytes
3b72cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1da9de
3b72cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# python 3.7
"""Demo."""

import numpy as np
import torch
import streamlit as st
import SessionState

from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Returns the pretrained generator registered under `model_name`.

    The result is memoized by `st.cache`, so Streamlit reruns with the same
    name reuse the already-loaded generator instead of loading it again.
    """
    generator = load_generator(model_name)
    return generator


@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Computes the closed-form semantic directions for `model`.

    Thin memoized wrapper around `factorize_weight`; `layer_idx` selects the
    generator layers whose weights are factorized (e.g. 'all' or '0-1').
    Whatever `factorize_weight` returns is passed through unchanged.
    """
    factorization = factorize_weight(model, layer_idx)
    return factorization


def sample(model, gan_type, num=1):
    """Draws `num` random latent codes suited to the given generator family.

    Args:
        model: Generator exposing `z_space_dim` and, depending on `gan_type`,
            `layer0.pixel_norm` (PGGAN) or `mapping` + `truncation`
            (StyleGAN / StyleGAN2).
        gan_type: Generator family name, e.g. 'pggan', 'stylegan', 'stylegan2'.
            Any other value returns the raw Gaussian samples.
        num: Number of codes to draw.

    Returns:
        The codes as a numpy array, detached from autograd and moved to CPU.
    """
    # (trunc_psi, trunc_layers) used for the truncation trick per family.
    truncation_cfg = {'stylegan': (0.7, 8), 'stylegan2': (0.5, 18)}

    latent = torch.randn(num, model.z_space_dim)
    if gan_type == 'pggan':
        latent = model.layer0.pixel_norm(latent)
    elif gan_type in truncation_cfg:
        psi, n_layers = truncation_cfg[gan_type]
        w_code = model.mapping(latent)['w']
        latent = model.truncation(w_code,
                                  trunc_psi=psi,
                                  trunc_layers=n_layers)
    return latent.detach().cpu().numpy()


@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesizes an image with the given code.

    Args:
        model: Generator to run. PGGAN models are called directly; StyleGAN /
            StyleGAN2 models are run through `model.synthesis`.
        gan_type: Generator family: 'pggan', 'stylegan' or 'stylegan2'.
        code: Latent code, converted to a tensor with `to_tensor`.

    Returns:
        The first image of the batch produced by `postprocess`.

    Raises:
        ValueError: If `gan_type` is not one of the supported families
            (previously this surfaced as an opaque UnboundLocalError).
    """
    if gan_type == 'pggan':
        forward = model
    elif gan_type in ['stylegan', 'stylegan2']:
        forward = model.synthesis
    else:
        raise ValueError(f'Unsupported GAN type `{gan_type}`!')
    # Pure inference: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        image = forward(to_tensor(code))['image']
    image = postprocess(image)[0]
    return image


"""Main function (loop for StreamLit)."""
st.title('Closed-Form Factorization of Latent Semantics in GANs')
st.sidebar.title('Options')
reset = st.sidebar.button('Reset')

model_name = st.sidebar.selectbox(
    'Model to Interpret',
    ['stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',
        'pggan_celebahq1024'])

model = get_model(model_name)
gan_type = parse_gan_type(model)
layer_idx = st.sidebar.selectbox(
    'Layers to Interpret',
    ['all', '0-1', '2-5', '6-13'])
layers, boundaries, eigen_values = factorize_model(model, layer_idx)

# How many semantic directions to expose as sliders. The upper bound is the
# number of directions the factorization produced, so the `eigen_values`
# lookup below can never index past the end.
num_available = len(eigen_values)
num_semantics = st.sidebar.number_input(
    'Number of semantics',
    value=min(10, num_available),
    min_value=0,
    max_value=num_available,
    step=1)

# Slider half-range per generator family.
max_steps = {'pggan': 5.0, 'stylegan': 2.0, 'stylegan2': 15.0}
if gan_type not in max_steps:
    # Fail loudly instead of leaving `max_step` undefined.
    raise ValueError(f'Unsupported GAN type `{gan_type}`!')
max_step = max_steps[gan_type]

# Maps semantic index -> slider value (distance to move along that direction).
steps = {}
for sem_idx in range(num_semantics):
    eigen_value = eigen_values[sem_idx]
    steps[sem_idx] = st.sidebar.slider(
        f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
        value=0.0,
        min_value=-max_step,
        max_value=max_step,
        # NOTE(review): switching `step` when Reset is clicked appears meant
        # to change the widget's identity so Streamlit recreates the slider
        # at value 0.0; confirm the installed Streamlit accepts step=0.0.
        step=0.04 * max_step if not reset else 0.0)

# The image slot is reserved before the button slot, so the rendered image
# is placed above the "Random" button.
image_placeholder = st.empty()
button_placeholder = st.empty()

# Use the latent codes stored for this model when the file exists; otherwise
# fall back to a single freshly sampled code.
try:
    base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
except FileNotFoundError:
    base_codes = sample(model, gan_type)

# Per-session state: owning model, index into `base_codes`, and the code
# currently shown. Start over whenever the selected model changes.
first_code = base_codes[0:1]
state = SessionState.get(model_name=model_name,
                         code_idx=0,
                         codes=first_code)
if state.model_name != model_name:
    state.model_name = model_name
    state.code_idx = 0
    state.codes = first_code

# "Random" advances through the stored codes, then switches to sampling.
if button_placeholder.button('Random', key=0):
    state.code_idx += 1
    if state.code_idx >= base_codes.shape[0]:
        state.codes = sample(model, gan_type)
    else:
        state.codes = base_codes[state.code_idx][np.newaxis]

# Move a copy of the current code along every semantic direction by the
# amount chosen on its slider.
code = state.codes.copy()
for sem_idx, step in steps.items():
    shift = boundaries[sem_idx:sem_idx + 1] * step
    if gan_type == 'pggan':
        code += shift
    elif gan_type in ['stylegan', 'stylegan2']:
        # Style-based generators: only the factorized layers are shifted.
        code[:, layers, :] += shift
image = synthesize(model, gan_type, code)
# Presumably `postprocess` yields 0-255 pixel values; rescale to [0, 1]
# floats for display (TODO confirm against utils.postprocess).
image_placeholder.image(image / 255.0)