# vladmir077 — commit 6e231e7: "first commit"
import json
import time
import random
import streamlit as st
import torch
import onnxruntime
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample, convert_to_images)
# File name and image format used when the generated picture is written to disk.
SAVE_NAME = 'result'
SAVE_EXT = 'png'
SAVE_PATH = f'{SAVE_NAME}.{SAVE_EXT}'
def _cache_resource(func):
    """Cache *func* as a shared, non-serialized resource across reruns.

    Prefers ``st.cache_resource`` (Streamlit >= 1.18). On older Streamlit it
    falls back to ``st.cache(allow_output_mutation=True)``, so the returned
    object (e.g. an inference session) is not hashed or copied on every rerun.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_session(model_path: str = 'biggan.onnx', providers=None):
    """Create the ONNX Runtime session for the BigGAN model (cached).

    Args:
        model_path: Path to the exported ONNX model. Defaults to
            ``'biggan.onnx'``, relative to the working directory.
        providers: Optional execution-provider list forwarded to
            ``onnxruntime.InferenceSession``. ``None`` keeps ONNX Runtime's
            default selection, as before.

    Returns:
        onnxruntime.InferenceSession: The loaded inference session.
    """
    return onnxruntime.InferenceSession(model_path, providers=providers)
def _cache_data(func):
    """Cache *func*'s plain-data return value across Streamlit reruns.

    Prefers ``st.cache_data`` (Streamlit >= 1.18) and falls back to the
    deprecated ``st.cache`` on older Streamlit versions.
    """
    if hasattr(st, 'cache_data'):
        return st.cache_data(func)
    return st.cache(func)


@_cache_data
def load_imagenet_classes(labels_path: str = 'imagenet-simple-labels.json'):
    """Load ImageNet class names and build a name -> index lookup (cached).

    The default labels file can be fetched with:
        wget https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json

    Args:
        labels_path: Path to a JSON file. Assumed to hold an array of class
            names ordered by class index (TODO confirm for other files).

    Returns:
        tuple: ``(class_names, class_to_num)``. ``class_names`` is the parsed
        JSON value. ``class_to_num[name]`` is the position of ``name`` in
        ``class_names``; if a name occurs more than once, the last index wins.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(labels_path, 'r', encoding='utf-8') as file:
        class_names = json.load(file)
    class_to_num = {name: idx for idx, name in enumerate(class_names)}
    return class_names, class_to_num
def run_model(ort_sess, class_num: int = 10, seed: int = 42, truncation: float = 0.4):
    """Run one BigGAN generation for a single ImageNet class.

    Follows the usage example from the pytorch_pretrained_biggan README, but
    executes the forward pass through an ONNX Runtime session.

    Args:
        ort_sess: Inference session. Its first input receives the noise
            vector and its second input the one-hot class vector.
        class_num: ImageNet class index to condition on.
        seed: Seed forwarded to ``truncated_noise_sample``.
        truncation: Truncation forwarded to ``truncated_noise_sample``.
            Defaults to 0.4, the value that used to be hard-coded here.
            NOTE(review): if the exported graph bakes in its own truncation
            constant, keep the two in sync — confirm against the export.

    Returns:
        list: The outputs returned by ``ort_sess.run`` (all model outputs).
    """
    class_vector = one_hot_from_int(class_num, batch_size=1)
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    # Fetch the input metadata once instead of once per input.
    inputs = ort_sess.get_inputs()
    ort_inputs = {
        inputs[0].name: noise_vector,
        inputs[1].name: class_vector,
    }
    # ``None`` asks ONNX Runtime to return every model output.
    return ort_sess.run(None, ort_inputs)
def update_seed():
    """``on_change`` callback for the seed widget: sanitize ``st.session_state.seed``.

    Coerces the stored value to ``int``. If that fails, it shows an error and
    replaces the value with a random seed in ``[1, 200]``. The lower bound is
    1 (not 0) so the fallback respects the seed widget's ``min_value=1`` and
    its "positive integer" label.
    """
    try:
        new_seed = int(st.session_state.seed)
    except (TypeError, ValueError):
        # ValueError: non-numeric text; TypeError: e.g. a None value.
        st.error('Something wrong with random seed. It will be chosen randomly')
        new_seed = random.randint(1, 200)
    st.session_state.seed = new_seed
def update_class_num(class_to_num: dict):
    """``on_change`` callback for the category selectbox.

    Looks up the class name stored in ``st.session_state.class_name`` and
    writes its index into ``st.session_state.class_num``.

    Args:
        class_to_num: Mapping from class name to class index.
    """
    state = st.session_state
    selected = state.class_name
    state.class_num = class_to_num[selected]
if __name__ == '__main__':
    title = """
### Inference Homework! (HSE Practical_DL course)
### BigGan demo. Choose an imagenet class and random seed to generate some images \U0001F643
"""
    st.markdown(title)

    # Both loaders are cached, so these calls are cheap after the first run.
    ort_sess = load_session()
    classes, class_to_num = load_imagenet_classes()

    # Initialize widget-backed session state on a session's first run.
    if 'seed' not in st.session_state:
        st.session_state.seed = 42
    if 'class_name' not in st.session_state:
        st.session_state.class_name = classes[0]
    if 'class_num' not in st.session_state:
        # Derive the index from the selected name instead of assuming 0,
        # so the two state entries cannot disagree.
        st.session_state.class_num = class_to_num[st.session_state.class_name]

    st.number_input(
        'Select random seed (positive integer)',
        min_value=1,
        max_value=500,
        key='seed',
        step=1,
        on_change=update_seed,
    )
    st.selectbox(
        'Select ImageNet category',
        options=classes,
        key='class_name',
        on_change=update_class_num,
        args=(class_to_num,),
    )

    # Inference
    if st.button('Generate'):
        # perf_counter is monotonic and high-resolution: the right clock for latency.
        start_time = time.perf_counter()
        generated = run_model(ort_sess, st.session_state.class_num, st.session_state.seed)
        latency = round(time.perf_counter() - start_time, 4)
        st.info(f'Generation latency - {latency} s.', icon='ℹ️')

        img = convert_to_images(generated[0])[0]
        # Still save the result as a PNG file, as before...
        img.save(SAVE_PATH, SAVE_EXT)
        # ...but display the in-memory image. SAVE_PATH is one fixed file
        # shared by every session of the app, so re-reading it could show
        # another user's concurrently generated image.
        st.image(img, output_format='PNG')