File size: 5,120 Bytes
ff2b8e3
 
7db1e87
ff2b8e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed6b6d6
ff2b8e3
 
 
 
 
 
 
 
 
 
ed6b6d6
ff2b8e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7db1e87
 
 
 
 
 
ff2b8e3
 
7db1e87
ff2b8e3
d5205ed
 
ff2b8e3
7db1e87
ff2b8e3
 
 
ed6b6d6
ff2b8e3
 
 
 
 
 
 
 
 
7db1e87
ff2b8e3
 
 
 
 
 
 
 
 
 
 
 
 
 
7db1e87
 
ff2b8e3
 
 
7db1e87
ff2b8e3
 
 
 
 
 
 
 
 
 
 
 
 
 
7db1e87
ff2b8e3
 
 
 
 
 
7db1e87
 
 
ff2b8e3
 
 
 
 
 
 
 
 
d5205ed
 
 
ff2b8e3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# python 3.7
"""Demo."""
import random

import numpy as np
import torch
import streamlit as st
import SessionState

from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Returns the generator registered under `model_name`.

    The result is memoized by `st.cache` so reruns reuse the loaded model;
    loading is delegated to `load_generator` with `from_hf_hub=True`.
    """
    generator = load_generator(model_name, from_hf_hub=True)
    return generator


@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Runs closed-form factorization of `model` restricted to `layer_idx`.

    Memoized by `st.cache`. The caller unpacks the result as
    `(layers, boundaries, eigen_values)`.
    """
    factorization = factorize_weight(model, layer_idx)
    return factorization


def sample(model, gan_type, num=1):
    """Draws `num` random latent codes for the generator.

    A standard-normal z of shape (num, model.z_space_dim) is drawn, then:
      * 'pggan': pixel-normalized by `model.layer0.pixel_norm`;
      * 'stylegan' / 'stylegan2': mapped to w and truncated
        (psi=0.7 over 8 layers, or psi=0.5 over 18 layers respectively);
      * any other type: returned as the raw z.

    Returns:
        The codes as a NumPy array on the CPU.
    """
    latents = torch.randn(num, model.z_space_dim)
    if gan_type == 'pggan':
        latents = model.layer0.pixel_norm(latents)
    elif gan_type in ('stylegan', 'stylegan2'):
        psi, num_trunc_layers = (0.7, 8) if gan_type == 'stylegan' else (0.5, 18)
        w_codes = model.mapping(latents)['w']
        latents = model.truncation(w_codes,
                                   trunc_psi=psi,
                                   trunc_layers=num_trunc_layers)
    return latents.detach().cpu().numpy()


@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesizes an image with the given code.

    Args:
        model: Generator returned by `get_model`.
        gan_type: GAN family: 'pggan', 'stylegan' or 'stylegan2'.
        code: Latent code(s) as a NumPy array, converted with `to_tensor`.

    Returns:
        The first image of the batch after `postprocess`.

    Raises:
        ValueError: If `gan_type` is not one of the supported families
            (previously this surfaced as an `UnboundLocalError`).
    """
    if gan_type == 'pggan':
        # PGGAN consumes the code directly through its full forward pass.
        generate = model
    elif gan_type in ('stylegan', 'stylegan2'):
        # StyleGAN codes are already in w space, so skip the mapping network.
        generate = model.synthesis
    else:
        raise ValueError(f'Unsupported GAN type `{gan_type}`!')
    # Inference only: avoid building an autograd graph for every render.
    with torch.no_grad():
        image = generate(to_tensor(code))['image']
    return postprocess(image)[0]


def _update_slider():
    """Resets every visible semantic slider back to zero.

    Registered as the `on_click` callback of the sidebar "Reset" button. It
    writes to `st.session_state`, so the sliders keyed
    `semantic_slider_<idx>` pick up the new value on the next rerun.
    """
    # Tolerate the key being absent instead of raising KeyError.
    num_semantics = st.session_state.get("num_semantics", 0)
    for sem_idx in range(num_semantics):
        # Use a float so the stored value matches the slider's float type
        # (it is created with value=0.0 and a float step).
        st.session_state[f"semantic_slider_{sem_idx}"] = 0.0


def _max_step_for(gan_type):
    """Returns the half-range of the semantic sliders for `gan_type`.

    Raises:
        ValueError: If no slider range is configured for `gan_type`.
    """
    max_steps = {'pggan': 5.0, 'stylegan': 2.0, 'stylegan2': 15.0}
    if gan_type not in max_steps:
        raise ValueError(f'Unsupported GAN type `{gan_type}`!')
    return max_steps[gan_type]


def _edit_code(codes, gan_type, layers, boundaries, steps):
    """Returns a copy of `codes` moved along each selected semantic direction.

    Args:
        codes: Current latent code(s) (NumPy array); not modified in place.
        gan_type: GAN family as returned by `parse_gan_type`.
        layers: Layer indices from `factorize_model`; StyleGAN edits only
            touch these entries of the per-layer code.
        boundaries: Semantic directions; row `i` is scaled by `steps[i]`.
        steps: Mapping from semantic index to its slider value.
    """
    code = codes.copy()
    for sem_idx, step in steps.items():
        offset = boundaries[sem_idx:sem_idx + 1] * step
        if gan_type == 'pggan':
            code += offset
        elif gan_type in ('stylegan', 'stylegan2'):
            code[:, layers, :] += offset
    return code


def main():
    """Builds the Streamlit page: sidebar controls, sample selection and
    the edited image. Streamlit re-executes this on every interaction."""

    st.title('Closed-Form Factorization of Latent Semantics in GANs')
    st.markdown("This space is the ported version of [Closed-Form Factorization of Latent Semantics in GANs](https://github.com/genforce/sefa). It reads all sample models from the Hugging Face Hub")
    st.markdown("---")
    st.sidebar.title('Options')
    st.sidebar.button('Reset', on_click=_update_slider)

    model_name = st.sidebar.selectbox(
        'Model to Interpret',
        ['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',])

    model = get_model(model_name)
    gan_type = parse_gan_type(model)
    layer_idx = st.sidebar.selectbox(
        'Layers to Interpret',
        ['all', '0-1', '2-5', '6-13'])
    layers, boundaries, eigen_values = factorize_model(model, layer_idx)

    num_semantics = st.sidebar.number_input(
        'Number of semantics', value=5, min_value=0, max_value=None, step=1, key="num_semantics")
    # The input has no upper bound; never index past the factorized
    # semantics (previously an IndexError on `eigen_values[sem_idx]`).
    num_semantics = min(int(num_semantics), len(eigen_values), len(boundaries))
    max_step = _max_step_for(gan_type)
    steps = {}
    for sem_idx in range(num_semantics):
        eigen_value = eigen_values[sem_idx]
        steps[sem_idx] = st.sidebar.slider(
            f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
            value=0.0,
            min_value=-max_step,
            max_value=max_step,
            step=0.04 * max_step,
            key=f"semantic_slider_{sem_idx}")

    image_placeholder = st.empty()
    button_placeholder = st.empty()
    button_totally_random = st.empty()

    # Prefer the curated latent codes shipped with the space; fall back to a
    # freshly sampled code when none exist for this model.
    try:
        base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
    except FileNotFoundError:
        base_codes = sample(model, gan_type)

    state = SessionState.get(model_name=model_name,
                             code_idx=0,
                             codes=base_codes[0:1])
    if state.model_name != model_name:
        # Model switched: restart from the first base code of the new model.
        state.model_name = model_name
        state.code_idx = 0
        state.codes = base_codes[0:1]

    if button_placeholder.button('Next Sample'):
        state.code_idx += 1
        if state.code_idx < base_codes.shape[0]:
            state.codes = base_codes[state.code_idx][np.newaxis]
        else:
            state.codes = sample(model, gan_type)

    if button_totally_random.button('Totally Random'):
        state.codes = sample(model, gan_type)

    code = _edit_code(state.codes, gan_type, layers, boundaries, steps)
    image = synthesize(model, gan_type, code)
    image_placeholder.image(image / 255.0)

    st.markdown("---")
    st.markdown("""This space was created by [johko](https://twitter.com/johko990). Main credits go to the original authors Yujun Shen and Bolei Zhou, who created a great code base to work on. 
                This version loads all models from the Hugging Face Hub.""")

if __name__ == '__main__':
    # Entry point when the file is executed as a script (e.g. `streamlit run`).
    main()