File size: 3,253 Bytes
2aa01a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cebbeb
2aa01a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import json
import time
import random

import streamlit as st
import torch

from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample, convert_to_images)

# Where the last generated image is written: "<SAVE_NAME>.<SAVE_EXT>".
# SAVE_EXT doubles as the image format name passed to PIL's save().
SAVE_NAME = 'result'
SAVE_EXT = 'png'
SAVE_PATH = f'{SAVE_NAME}.{SAVE_EXT}'


@st.cache(allow_output_mutation=True)
def load_model(model_name: str = 'biggan-deep-256'):
    """Load a pre-trained BigGAN generator and put it in evaluation mode.

    Cached by Streamlit so the weights are loaded once rather than on every
    script rerun.

    Args:
        model_name: Checkpoint name passed to ``BigGAN.from_pretrained``.
            Defaults to ``'biggan-deep-256'``, the model this app was built for.

    Returns:
        The BigGAN model with ``eval()`` applied.
    """
    # allow_output_mutation=True: legacy st.cache otherwise hashes the returned
    # object on every call to detect mutation, which is costly for a large
    # network and unnecessary for a shared, read-only inference model.
    model = BigGAN.from_pretrained(model_name)
    # Inference only: switch off training-time behaviour.
    model.eval()
    return model


@st.cache
def load_imagenet_classes(path: str = 'imagenet-simple-labels.json'):
    """Read ImageNet class labels and build a label -> class-index mapping.

    The default label file can be fetched with:
        wget https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json

    Args:
        path: Location of the JSON label file. Defaults to the file name
            produced by the ``wget`` command above, relative to the current
            working directory. Assumes the file holds a JSON array of labels.

    Returns:
        ``(class_names, class_to_num)``: the labels in file order, and a dict
        mapping each label to its index in that list. If a label occurs more
        than once, the dict keeps its last index.
    """
    # JSON is UTF-8 by specification; don't depend on the platform default.
    with open(path, 'r', encoding='utf-8') as file:
        class_names = json.load(file)

    class_to_num = {name: idx for idx, name in enumerate(class_names)}
    return class_names, class_to_num


def run_model(class_num: int = 10, seed: int = 42, truncation: float = 0.4, model=None):
    """Generate a single BigGAN sample for one ImageNet class.

    Follows the usage example from the pytorch_pretrained_biggan README.

    Args:
        class_num: ImageNet class index used to build the one-hot class vector.
        seed: Seed forwarded to ``truncated_noise_sample`` for the latent noise.
        truncation: Truncation value passed both to the noise sampler and to
            the model. Defaults to 0.4, the value previously hard-coded here.
        model: BigGAN generator to run. When ``None``, the cached model from
            ``load_model()`` is used, so the function no longer depends on a
            global ``model`` that only exists when the file runs as a script.

    Returns:
        The raw generator output for a batch of one; pass it to
        ``convert_to_images`` to obtain images.
    """
    if model is None:
        model = load_model()

    # Prepare the inputs (batch of one).
    class_vector = one_hot_from_int(class_num, batch_size=1)
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)

    # The helpers return numpy arrays; the model expects torch tensors.
    noise_vector = torch.from_numpy(noise_vector)
    class_vector = torch.from_numpy(class_vector)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(noise_vector, class_vector, truncation)

    return output


def update_seed(min_seed: int = 1, max_seed: int = 500):
    """Streamlit ``on_change`` callback that normalises ``st.session_state.seed``.

    Converts the current seed to ``int``. If that fails, shows an error and
    replaces it with a random seed drawn from ``[min_seed, max_seed]``.

    Args:
        min_seed: Smallest allowed seed; keep in sync with the seed
            ``number_input``'s ``min_value`` (1 in this app).
        max_seed: Largest allowed seed; keep in sync with that widget's
            ``max_value`` (500 in this app).
    """
    try:
        new_seed = int(st.session_state.seed)
    except (TypeError, ValueError):
        # TypeError: the value is None; ValueError: it is not numeric.
        st.error('Something wrong with random seed. It will be chosen randomly')
        # Stay inside the widget's range: the previous randint(0, 200) could
        # return 0, which is below the widget's min_value.
        new_seed = random.randint(min_seed, max_seed)

    st.session_state.seed = new_seed


def update_class_num(class_to_num: dict):
    """Streamlit ``on_change`` callback for the ImageNet class selectbox.

    Maps the label in ``st.session_state.class_name`` to its index via
    ``class_to_num`` and stores it in ``st.session_state.class_num``.
    Raises ``KeyError`` if the label is not in the mapping.
    """
    state = st.session_state
    state.class_num = class_to_num[state.class_name]

if __name__ == '__main__':
    title = \
        """
        ### Inference Homework! (HSE Practical_DL course)
        ### BigGan demo. Choose an imagenet class and random seed to generate some images \U0001F643
        """
    st.markdown(title)

    # Load the BigGAN model (cached by load_model's Streamlit decorator).
    # Kept as a module-level name because run_model() may look up a global `model`.
    model = load_model()

    # Load the ImageNet labels and the label -> class-index mapping.
    classes, class_to_num = load_imagenet_classes()

    # Initialise session-state defaults only on the first run, so that later
    # reruns of this script do not overwrite the user's choices.
    if 'seed' not in st.session_state:
        st.session_state.seed = 42

    if 'class_name' not in st.session_state:
        st.session_state.class_name = classes[0]

    if 'class_num' not in st.session_state:
        # Use the same mapping as the selectbox callback so the initial
        # (class_name, class_num) pair is always consistent.
        st.session_state.class_num = class_to_num[classes[0]]

    # The widget's value is taken from st.session_state.seed (via `key`),
    # so no explicit `value=` is passed.
    st.number_input(
        'Select random seed (positive integer)',
        min_value=1,
        max_value=500,
        key='seed',
        step=1,
        on_change=update_seed
    )

    st.selectbox(
        'Select ImageNet category',
        options=classes,
        key='class_name',
        on_change=update_class_num,
        args=(class_to_num, )
    )

    # inference
    if st.button('Generate'):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        generated = run_model(st.session_state.class_num, st.session_state.seed)
        latency = round(time.perf_counter() - start_time, 4)
        st.info(f'Generation latency - {latency} s.', icon='ℹ️')

        img = convert_to_images(generated)[0]

        # Save the result to disk as before.
        img.save(SAVE_PATH, SAVE_EXT)

        # Display the in-memory image rather than re-reading SAVE_PATH: the
        # file is shared, so another concurrent session may have overwritten it.
        st.image(img)