File size: 4,799 Bytes
3b72cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed6b6d6
3b72cdb
 
 
 
 
 
 
 
 
 
b1da9de
3b72cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7db1e87
 
 
 
 
3b72cdb
 
 
d5205ed
 
 
3b72cdb
7db1e87
3b72cdb
 
 
dd2f594
3b72cdb
 
 
 
 
 
 
 
 
7db1e87
3b72cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
7db1e87
 
3b72cdb
 
 
7db1e87
3b72cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
7db1e87
3b72cdb
 
 
 
 
 
7db1e87
 
 
3b72cdb
 
 
 
 
 
 
 
d5205ed
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# python 3.7
"""Demo."""

import numpy as np
import torch
import streamlit as st
import SessionState

from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Loads the generator registered under `model_name`.

    The generator is fetched through `load_generator` with `from_hf_hub=True`.
    The call is wrapped in `st.cache`, so script reruns with the same name
    reuse the cached generator.
    """
    generator = load_generator(model_name, from_hf_hub=True)
    return generator


@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Runs `factorize_weight` on `model` for the layers given by `layer_idx`.

    The result of `factorize_weight` is returned untouched; the script unpacks
    it as `(layers, boundaries, eigen_values)`. The call is wrapped in
    `st.cache`, so reruns with the same arguments reuse the cached result.
    """
    factorization = factorize_weight(model, layer_idx)
    return factorization


def sample(model, gan_type, num=1):
    """Draws `num` random latent codes suited to the given generator.

    Codes start as standard-normal noise of width `model.z_space_dim`:
      * 'pggan': the noise is passed through `model.layer0.pixel_norm`.
      * 'stylegan' / 'stylegan2': the noise is mapped with `model.mapping`
        (its 'w' output is used) and then truncated with `model.truncation`.
      * any other type: the raw noise is returned as is.

    Returns:
        The codes as a NumPy array (detached and moved to the CPU).
    """
    latents = torch.randn(num, model.z_space_dim)

    if gan_type == 'pggan':
        latents = model.layer0.pixel_norm(latents)
    elif gan_type in ('stylegan', 'stylegan2'):
        # Both families share the same pipeline; only the truncation
        # settings differ.
        if gan_type == 'stylegan':
            psi, num_trunc_layers = 0.7, 8
        else:
            psi, num_trunc_layers = 0.5, 18
        w_codes = model.mapping(latents)['w']
        latents = model.truncation(w_codes,
                                   trunc_psi=psi,
                                   trunc_layers=num_trunc_layers)

    return latents.detach().cpu().numpy()


@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesizes an image with the given code.

    Args:
        model: Generator to run. For 'pggan' the model itself is called; for
            'stylegan' / 'stylegan2' its `synthesis` sub-module is called.
        gan_type: One of 'pggan', 'stylegan' or 'stylegan2'.
        code: Latent code(s); converted with `to_tensor` before the forward
            pass.

    Returns:
        The first image of the post-processed generator output.

    Raises:
        ValueError: If `gan_type` is not one of the supported types.
    """
    if gan_type == 'pggan':
        image = model(to_tensor(code))['image']
    elif gan_type in ['stylegan', 'stylegan2']:
        image = model.synthesis(to_tensor(code))['image']
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # `image` below.
        raise ValueError(f'Unsupported GAN type: {gan_type!r}')
    # `postprocess` works on the whole batch; only the first image is shown.
    return postprocess(image)[0]

def _update_slider():
    """Resets every semantic slider to its neutral position.

    Used as the `on_click` callback of the sidebar 'Reset' button. It reads
    the current number of semantics from `st.session_state["num_semantics"]`
    and writes `0.0` under each `semantic_slider_<idx>` key.
    """
    num_semantics = st.session_state["num_semantics"]
    for sem_idx in range(num_semantics):
        # The sliders are declared with float value/min/max/step, so reset
        # them with a float to keep the stored value type consistent.
        st.session_state[f"semantic_slider_{sem_idx}"] = 0.0


"""Main function (loop for StreamLit)."""
st.title('Closed-Form Factorization of Latent Semantics in GANs')
st.markdown("This space is the ported version of [Closed-Form Factorization of Latent Semantics in GANs](https://github.com/genforce/sefa). It reads all sample models from the Hugging Face Hub")
st.markdown("---")
    
st.sidebar.title('Options')
st.sidebar.button('Reset', on_click=_update_slider, kwargs={})

model_name = st.sidebar.selectbox(
    'Model to Interpret',
    ['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256'])

model = get_model(model_name)
gan_type = parse_gan_type(model)
layer_idx = st.sidebar.selectbox(
    'Layers to Interpret',
    ['all', '0-1', '2-5', '6-13'])
layers, boundaries, eigen_values = factorize_model(model, layer_idx)

num_semantics = st.sidebar.number_input(
    'Number of semantics', value=5, min_value=0, max_value=None, step=1, key="num_semantics")
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
if gan_type == 'pggan':
    max_step = 5.0
elif gan_type == 'stylegan':
    max_step = 2.0
elif gan_type == 'stylegan2':
    max_step = 15.0
for sem_idx in steps:
    eigen_value = eigen_values[sem_idx]
    steps[sem_idx] = st.sidebar.slider(
        f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
        value=0.0,
        min_value=-max_step,
        max_value=max_step,
        step=0.04 * max_step,
        key=f"semantic_slider_{sem_idx}")

image_placeholder = st.empty()
button_placeholder = st.empty()
button_totally_random = st.empty()

try:
    base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
except FileNotFoundError:
    base_codes = sample(model, gan_type)

state = SessionState.get(model_name=model_name,
                            code_idx=0,
                            codes=base_codes[0:1])
if state.model_name != model_name:
    state.model_name = model_name
    state.code_idx = 0
    state.codes = base_codes[0:1]

if button_placeholder.button('Next Sample'):
    state.code_idx += 1
    if state.code_idx < base_codes.shape[0]:
        state.codes = base_codes[state.code_idx][np.newaxis]
    else:
        state.codes = sample(model, gan_type)

if button_totally_random.button('Totally Random'):
    state.codes = sample(model, gan_type)

code = state.codes.copy()
for sem_idx, step in steps.items():
    if gan_type == 'pggan':
        code += boundaries[sem_idx:sem_idx + 1] * step
    elif gan_type in ['stylegan', 'stylegan2']:
        code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step
image = synthesize(model, gan_type, code)
image_placeholder.image(image / 255.0)

st.markdown("---")
st.markdown("""This space was created by [johko](https://twitter.com/johko990). Main credits go to the original authors Yujun Shen and Bolei Zhou, who created a great code base to work on. 
            This version loads all models from the Hugging Face Hub.""")