import json
import time
import random

import streamlit as st
import torch  # noqa: F401  (not used directly; kept from the original import set)
import onnxruntime
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int,  # noqa: F401
                                       truncated_noise_sample, convert_to_images)

SAVE_NAME = 'result'
SAVE_EXT = 'png'
SAVE_PATH = SAVE_NAME + '.' + SAVE_EXT

# Bounds of the seed widget. The random fallback in update_seed must stay
# inside them, otherwise the widget rejects the value on the next rerun.
SEED_MIN = 1
SEED_MAX = 500
DEFAULT_SEED = 42


def _cache_resource(func):
    """Cache an unpicklable resource (e.g. an ONNX session) across reruns.

    Prefers ``st.cache_resource`` (Streamlit >= 1.18). Older versions fall
    back to ``st.cache`` with output-mutation checking disabled, because
    plain ``st.cache`` tries to hash the returned object.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


def _cache_data(func):
    """Cache a serialisable return value across reruns.

    Prefers ``st.cache_data``; falls back to the legacy ``st.cache``.
    """
    if hasattr(st, 'cache_data'):
        return st.cache_data(func)
    return st.cache(func)


@_cache_resource
def load_session():
    """Create the ONNX Runtime session for ``biggan.onnx`` (cached across reruns)."""
    ort_session = onnxruntime.InferenceSession('biggan.onnx')
    return ort_session


@_cache_data
def load_imagenet_classes():
    """Load ImageNet class names and a name -> index lookup table.

    The labels file must exist in the working directory, e.g.:
    wget https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json

    Returns:
        tuple: ``(class_names, class_to_num)``, where ``class_names`` is the
        list loaded from the JSON file and ``class_to_num`` maps each name to
        its list index (if a name were repeated, the last index would win).
    """
    with open('imagenet-simple-labels.json', 'r', encoding='utf-8') as file:
        class_names = json.load(file)
    class_to_num = {name: idx for idx, name in enumerate(class_names)}
    return class_names, class_to_num


def run_model(ort_sess, class_num: int = 10, seed: int = 42, truncation: float = 0.4):
    """Generate one BigGAN sample with the given ONNX Runtime session.

    Args:
        ort_sess: Session returned by ``load_session``.
        class_num: ImageNet class index to condition on.
        seed: Seed for the truncated noise sample (same seed -> same noise).
        truncation: Truncation used when sampling the noise vector.
            NOTE(review): only the noise sampling uses it here; if the exported
            graph also has a truncation value it is presumably baked in at
            export time — confirm the two match.

    Returns:
        The list returned by ``InferenceSession.run``; element ``[0]`` is the
        generated image batch consumed by ``convert_to_images``.
    """
    class_vector = one_hot_from_int(class_num, batch_size=1)
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    graph_inputs = ort_sess.get_inputs()
    # Assumes the model was exported with (noise, class) as its first two
    # inputs, in that order — verify against the export script.
    ort_inputs = {
        graph_inputs[0].name: noise_vector,
        graph_inputs[1].name: class_vector,
    }
    return ort_sess.run(None, ort_inputs)


def update_seed():
    """on_change callback for the seed widget: normalise the value to an int.

    If the value cannot be converted, an error is shown and a random seed
    within ``[SEED_MIN, SEED_MAX]`` is used instead.
    """
    try:
        new_seed = int(st.session_state.seed)
    except (TypeError, ValueError):
        st.error('Something wrong with random seed. It will be chosen randomly')
        new_seed = random.randint(SEED_MIN, SEED_MAX)
    st.session_state.seed = new_seed


def update_class_num(class_to_num: dict):
    """on_change callback for the class selectbox: store the selected class index."""
    new_class_name = st.session_state.class_name
    st.session_state.class_num = class_to_num[new_class_name]


def main():
    """Render the demo page and run generation when the button is pressed."""
    title = """
### Inference Homework! (HSE Practical_DL course)
### BigGan demo.
Choose an imagenet class and random seed to generate some images \U0001F643
"""
    st.markdown(title)

    ort_sess = load_session()
    classes, class_to_num = load_imagenet_classes()

    # Initialise state before creating the widgets that share these keys.
    if 'seed' not in st.session_state:
        st.session_state.seed = DEFAULT_SEED
    if 'class_name' not in st.session_state:
        st.session_state.class_name = classes[0]
    if 'class_num' not in st.session_state:
        # Same lookup update_class_num performs, so both paths agree.
        st.session_state.class_num = class_to_num[st.session_state.class_name]

    st.number_input(
        'Select random seed (positive integer)',
        min_value=SEED_MIN,
        max_value=SEED_MAX,
        key='seed',
        step=1,
        on_change=update_seed,
    )
    st.selectbox(
        'Select ImageNet category',
        options=classes,
        key='class_name',
        on_change=update_class_num,
        args=(class_to_num,),
    )

    if st.button('Generate'):
        start_time = time.perf_counter()
        generated = run_model(ort_sess, st.session_state.class_num, st.session_state.seed)
        latency = round(time.perf_counter() - start_time, 4)
        st.info(f'Generation latency - {latency} s.', icon='ℹ️')

        img = convert_to_images(generated[0])[0]
        # The file is still written as before, but the image is displayed from
        # memory: SAVE_PATH is shared by all sessions, so re-reading it could
        # show an image produced by a concurrent user.
        img.save(SAVE_PATH, SAVE_EXT)
        st.image(img)


if __name__ == '__main__':
    main()