Spaces:
Runtime error
Runtime error
import json | |
import time | |
import random | |
import streamlit as st | |
import torch | |
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample, convert_to_images) | |
# File name (stem + extension) used when saving the generated image.
SAVE_NAME = 'result'
SAVE_EXT = 'png'
SAVE_PATH = f'{SAVE_NAME}.{SAVE_EXT}'
def load_model(model_name: str = 'biggan-deep-256'):
    """Load a pre-trained BigGAN generator and switch it to inference mode.

    Args:
        model_name: Identifier passed to ``BigGAN.from_pretrained``. Defaults to
            ``'biggan-deep-256'``, the checkpoint this app was written for.

    Returns:
        The BigGAN model, already put into eval mode.
    """
    model = BigGAN.from_pretrained(model_name)
    # eval() turns off training-time behaviour of layers such as batch-norm/dropout.
    model.eval()
    return model
def load_imagenet_classes(path: str = 'imagenet-simple-labels.json'):
    """Read the ImageNet class names and build a name -> class-index lookup.

    The file is expected to be a JSON list of class names whose list position is
    the ImageNet class index, e.g. the file obtained with:
    wget https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json

    Args:
        path: Location of the JSON labels file (relative to the current working
            directory by default).

    Returns:
        ``(class_names, class_to_num)``: the loaded list, and a dict mapping each
        name to its index in that list. If a name appears more than once, the
        last index wins.
    """
    # Explicit encoding so reading does not depend on the platform's locale.
    with open(path, 'r', encoding='utf-8') as file:
        class_names = json.load(file)
    class_to_num = {name: idx for idx, name in enumerate(class_names)}
    return class_names, class_to_num
def run_model(class_num: int = 10, seed: int = 42, truncation: float = 0.4, biggan=None):
    """Generate a single BigGAN output for one ImageNet class.

    Mirrors the usage example from the pytorch_pretrained_biggan README.

    Args:
        class_num: ImageNet class index, turned into a one-hot class vector.
        seed: Seed forwarded to ``truncated_noise_sample``.
        truncation: Truncation value used both for sampling the noise and for the
            generator call (0.4 is the previously hard-coded value).
        biggan: Model to run. When ``None`` (default), the module-level ``model``
            global is used, which is the original behaviour.

    Returns:
        The raw generator output for a batch of size 1 (convert it to images with
        ``convert_to_images``).
    """
    # The conditional expression only touches the global when no model was passed.
    generator = model if biggan is None else biggan

    class_vector = torch.from_numpy(one_hot_from_int(class_num, batch_size=1))
    noise_vector = torch.from_numpy(
        truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    )

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = generator(noise_vector, class_vector, truncation)
    return output
def update_seed(low: int = 1, high: int = 500):
    """``on_change`` callback for the seed input: normalise ``st.session_state.seed``.

    The stored value is coerced to ``int``. If that fails, an error is shown and a
    random seed in ``[low, high]`` (inclusive) is stored instead. The defaults
    match the seed ``number_input``'s ``min_value``/``max_value`` (1..500), so the
    fallback is always a value the widget accepts.

    Args:
        low: Smallest seed the random fallback may pick.
        high: Largest seed the random fallback may pick.
    """
    try:
        new_seed = int(st.session_state.seed)
    except (TypeError, ValueError):
        # TypeError: the value is None (empty input); ValueError: unparsable value.
        st.error('Something wrong with random seed. It will be chosen randomly')
        new_seed = random.randint(low, high)
    st.session_state.seed = new_seed
def update_class_num(class_to_num: dict):
    """Keep ``st.session_state.class_num`` in sync with the selected class name.

    Intended as the ``on_change`` callback of the class selectbox; the name ->
    index mapping is supplied through the widget's ``args``.
    """
    st.session_state.class_num = class_to_num[st.session_state.class_name]
if __name__ == '__main__':
    title = (
        '\n'
        '### Inference Homework! (HSE Practical_DL course)\n'
        '### BigGan demo. Choose an imagenet class and random seed to generate some images \U0001F643\n'
    )
    st.markdown(title)

    # Streamlit re-executes this script on every widget interaction. Keep the
    # loaded BigGAN in session state so it is built once per browser session
    # rather than on every rerun. `run_model` reads the module-level `model`.
    if 'model' not in st.session_state:
        st.session_state.model = load_model()
    model = st.session_state.model

    # loading ImageNet classes
    classes, class_to_num = load_imagenet_classes()

    # Initialise state before the widgets that own these keys are created, so
    # the widgets pick the values up without passing `value=` explicitly.
    if 'seed' not in st.session_state:
        st.session_state.seed = 42
    if 'class_name' not in st.session_state:
        st.session_state.class_name = classes[0]
    if 'class_num' not in st.session_state:
        # Derive the index from the stored name so the two never disagree.
        st.session_state.class_num = class_to_num[st.session_state.class_name]

    st.number_input(
        'Select random seed (positive integer)',
        min_value=1,
        max_value=500,
        key='seed',
        step=1,
        on_change=update_seed,
    )
    st.selectbox(
        'Select ImageNet category',
        options=classes,
        key='class_name',
        on_change=update_class_num,
        args=(class_to_num, ),
    )

    # inference
    if st.button('Generate'):
        start_time = time.perf_counter()
        generated = run_model(st.session_state.class_num, st.session_state.seed)
        latency = round(time.perf_counter() - start_time, 4)
        st.info(f'Generation latency - {latency} s.', icon='ℹ️')

        img = convert_to_images(generated)[0]
        # Still save a copy to disk, as before.
        img.save(SAVE_PATH, SAVE_EXT)
        # Display the in-memory image. Reading back SAVE_PATH could show another
        # session's result, because all sessions share that one file name.
        st.image(img)