# inference_hw_torch / biggan_torch_space.py
# Last commit 7cebbeb by vladmir077: "some changes to remove warning issue"
import json
import time
import random
import streamlit as st
import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample, convert_to_images)
# Location (relative to the working directory) where the last generated
# image is written; the extension doubles as the PIL save format.
SAVE_NAME = 'result'
SAVE_EXT = 'png'
SAVE_PATH = f'{SAVE_NAME}.{SAVE_EXT}'
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the pre-trained ``biggan-deep-256`` generator in eval mode.

    The result is cached by Streamlit so the weights are loaded once instead
    of on every script rerun. ``allow_output_mutation=True`` stops
    ``st.cache`` from hashing the returned model on every rerun to check it
    for mutation. That check is slow for a large ``torch`` module and is a
    source of cache warnings.

    Returns:
        The ``BigGAN`` model, switched to evaluation mode.
    """
    model = BigGAN.from_pretrained('biggan-deep-256')
    # Inference only: disable training-specific layer behaviour.
    model.eval()
    return model
@st.cache
def load_imagenet_classes(path: str = 'imagenet-simple-labels.json'):
    """Load ImageNet class names and a name -> index lookup table.

    The default labels file can be downloaded with:
    wget https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json

    Args:
        path: JSON file containing a list of class names. A name's position
            in the list is used as its class index.

    Returns:
        A ``(class_names, class_to_num)`` tuple: the list as loaded, and a
        dict mapping each name to its list position. If a name repeats,
        the last position wins.
    """
    # JSON is UTF-8 text; be explicit rather than relying on the platform's
    # default encoding (e.g. cp1252 on Windows).
    with open(path, 'r', encoding='utf-8') as file:
        class_names = json.load(file)
    class_to_num = {name: idx for idx, name in enumerate(class_names)}
    return class_names, class_to_num
def run_model(class_num: int = 10, seed: int = 42, *, truncation: float = 0.4, generator=None):
    """Generate a single BigGAN image for an ImageNet class.

    Follows the usage example from the ``pytorch_pretrained_biggan`` README.

    Args:
        class_num: ImageNet class index to condition the generator on.
        seed: Seed forwarded to ``truncated_noise_sample`` for the latent
            noise vector.
        truncation: Truncation value passed both to the noise sampler and to
            the generator. Defaults to 0.4, the value previously hard-coded.
        generator: BigGAN model to run. When omitted, the module-level
            ``model`` global is used. That global is only assigned in this
            file's ``__main__`` block.

    Returns:
        The raw generator output for a batch of one. Pass it to
        ``convert_to_images`` to obtain displayable images.
    """
    if generator is None:
        # Backward-compatible fallback: the original read this global.
        generator = model
    # Build the inputs (batch of one) and convert them to tensors.
    class_vector = torch.from_numpy(one_hot_from_int(class_num, batch_size=1))
    noise_vector = torch.from_numpy(
        truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = generator(noise_vector, class_vector, truncation)
    return output
def update_seed():
    """Normalise ``st.session_state.seed`` after the seed widget changes.

    Used as the ``on_change`` callback of the seed ``number_input``. The
    current value is coerced to ``int``. If that fails, an error is shown and
    a random seed in [1, 200] is stored instead.
    """
    try:
        new_seed = int(st.session_state.seed)
    except (TypeError, ValueError):
        st.error('Something wrong with random seed. It will be chosen randomly')
        # Start at 1, not 0: the seed widget declares min_value=1, so the
        # fallback must stay inside the widget's range.
        new_seed = random.randint(1, 200)
    st.session_state.seed = new_seed
def update_class_num(class_to_num: dict):
    """Keep ``class_num`` in session state in sync with the chosen class name.

    Registered as the ``on_change`` callback of the ImageNet category
    selectbox.

    Args:
        class_to_num: Mapping from ImageNet class name to class index.
    """
    st.session_state['class_num'] = class_to_num[st.session_state['class_name']]
if __name__ == '__main__':
    title = """
### Inference Homework! (HSE Practical_DL course)
### BigGan demo. Choose an imagenet class and random seed to generate some images \U0001F643
"""
    st.markdown(title)

    # Both loaders are wrapped in st.cache, so repeated runs reuse the results.
    model = load_model()
    classes, class_to_num = load_imagenet_classes()

    # Initialise widget-backed state only if it is not present yet.
    if 'seed' not in st.session_state:
        st.session_state.seed = 42
    if 'class_name' not in st.session_state:
        st.session_state.class_name = classes[0]
    if 'class_num' not in st.session_state:
        # Derive the index from the selected name so the two always agree.
        st.session_state.class_num = class_to_num[st.session_state.class_name]

    st.number_input(
        'Select random seed (positive integer)',
        min_value=1,
        max_value=500,
        key='seed',
        step=1,
        on_change=update_seed,
    )
    st.selectbox(
        'Select ImageNet category',
        options=classes,
        key='class_name',
        on_change=update_class_num,
        args=(class_to_num,),
    )

    # Inference
    if st.button('Generate'):
        # perf_counter is monotonic and high-resolution; suited to latency.
        start_time = time.perf_counter()
        generated = run_model(st.session_state.class_num, st.session_state.seed)
        latency = round(time.perf_counter() - start_time, 4)
        st.info(f'Generation latency - {latency} s.', icon='ℹ️')

        img = convert_to_images(generated)[0]
        # The PNG is still written to disk. Display the in-memory image rather
        # than reading SAVE_PATH back: it is a fixed filename, so another
        # session generating at the same time could overwrite it in between.
        img.save(SAVE_PATH, SAVE_EXT)
        st.image(img)