Spaces:
Runtime error
Runtime error
import sys | |
import pickle | |
import os | |
import numpy as np | |
import PIL.Image | |
import IPython.display | |
from IPython.display import Image | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
sys.path.insert(0, "/StyleGAN2-GANbanales") | |
import dnnlib | |
import dnnlib.tflib as tflib | |
############################################################################## | |
# Generation functions | |
def seed2vec(Gs, seed):
    """Return one latent vector drawn deterministically from ``seed``.

    Args:
        Gs: Generator network; only ``Gs.input_shape`` is read, and its
            first (batch) entry is dropped.
        seed: Seed for a fresh ``np.random.RandomState``.

    Returns:
        ``np.ndarray`` of standard-normal samples with shape
        ``(1, *Gs.input_shape[1:])``.
    """
    generator = np.random.RandomState(seed)
    latent_dims = Gs.input_shape[1:]
    return generator.randn(1, *latent_dims)
def init_random_state(Gs, seed):
    """Overwrite every ``noise*`` variable of the synthesis network from ``seed``.

    Each variable whose name starts with ``'noise'`` receives standard-normal
    samples of its own shape. Samples are drawn from a single
    ``np.random.RandomState(seed)`` in the order that
    ``Gs.components.synthesis.vars.items()`` yields the variables. The values
    are then written with one ``tflib.set_vars`` call.

    Args:
        Gs: Generator network; only ``Gs.components.synthesis.vars`` is read.
        seed: Seed for the noise sampler.
    """
    generator = np.random.RandomState(seed)
    noise_values = {}
    for var_name, variable in Gs.components.synthesis.vars.items():
        if not var_name.startswith('noise'):
            continue
        # Each noise variable is noted as shaped [height, width] (TODO confirm).
        noise_values[variable] = generator.randn(*variable.shape.as_list())
    tflib.set_vars(noise_values)
def generate_image(Gs, z, truncation_psi, prefix="image", save=False, show=False):
    """Run the generator on latent ``z``, then optionally save and/or return the first image.

    Args:
        Gs: Generator network; ``Gs.run`` is called with ``z`` and an
            all-zero label shaped from ``Gs.input_shapes[1]``.
        z: Latent input passed straight to ``Gs.run`` (e.g. from ``seed2vec``).
        truncation_psi: Forwarded to ``Gs.run`` as ``truncation_psi``. When
            ``None`` the keyword is omitted, so the network's own default applies.
        prefix: File-name stem used when saving (``f"{prefix}.png"``).
        save: If true, write the first image to ``f"{prefix}.png"`` as RGB.
        show: If true, return the first image.

    Returns:
        ``images[0]`` (the first generated image) when ``show`` is true,
        otherwise ``None``.
    """
    Gs_kwargs = {
        # Convert output to uint8 and NCHW -> NHWC so PIL can consume it directly.
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        # Keep the noise variables as currently set (see init_random_state)
        # instead of resampling them on every run.
        'randomize_noise': False,
    }
    if truncation_psi is not None:
        Gs_kwargs['truncation_psi'] = truncation_psi
    # All-zero label matching the network's second input; presumably the model
    # is unconditional - TODO confirm.
    label = np.zeros([1] + Gs.input_shapes[1][1:])
    images = Gs.run(z, label, **Gs_kwargs)  # [minibatch, height, width, channel]
    if save:
        PIL.Image.fromarray(images[0], 'RGB').save(f"{prefix}.png")
    if show:
        return images[0]
    return None
############################################################################## | |
# Function concatenate | |
def concatenate(img_array):
    """Tile a list of equally sized images into a single grid image.

    Grid layout by image count:
      * 1-2 images: one row, no padding;
      * 3-4 images: 2 columns x 2 rows;
      * 5-6 images: 3 columns x 2 rows;
      * 7+ images: 3 columns, as many rows as needed (7-9 -> 3x3).

    Unused grid cells are filled with white images that have the same shape
    and dtype as the inputs.

    Args:
        img_array: Non-empty sequence of ``np.ndarray`` images, all with the
            same shape ``(height, width, channels)`` and dtype.

    Returns:
        ``np.ndarray`` holding the tiled grid.

    Raises:
        ValueError: If ``img_array`` is empty.
    """
    n_imgs = len(img_array)
    if n_imgs == 0:
        raise ValueError("concatenate() needs at least one image")
    if n_imgs <= 2:
        n_cols = n_imgs
    elif n_imgs <= 4:
        n_cols = 2
    else:
        n_cols = 3
    n_rows = -(-n_imgs // n_cols)  # ceiling division
    # Padding cell derived from the inputs rather than a hard-coded 256x256x3,
    # so any generator resolution/dtype stacks cleanly.
    white_img = np.full_like(img_array[0], 255)
    cells = list(img_array) + [white_img] * (n_rows * n_cols - n_imgs)
    rows = [np.hstack(cells[r * n_cols:(r + 1) * n_cols]) for r in range(n_rows)]
    return np.vstack(rows)
############################################################################## | |
# Function initiate | |
def initiate(seed, n_imgs, text):
    """Gradio callback: generate ``n_imgs`` images from consecutive seeds and tile them.

    Args:
        seed: First seed; images use seeds ``seed`` .. ``seed + n_imgs - 1``.
            Slider inputs may arrive as floats, so the value is cast to ``int``.
        n_imgs: Number of images to generate (cast to ``int``).
        text: Free-text input from the interface. It is not used for
            generation and is kept so the callback matches the three inputs.

    Returns:
        ``(grid_image, status_message)`` for the interface's image and text outputs.
    """
    # range() and np.random.RandomState reject floats, which sliders may send.
    seed = int(seed)
    n_imgs = int(n_imgs)
    pkl_file = "networks/experimento_2.pkl"
    tflib.init_tf()
    # NOTE(review): the network is unpickled on every call, which is slow and
    # may grow the TF graph; consider caching it, but check how TF sessions
    # interact with Gradio's request threads first.
    # pickle can execute code on load: only point this at the trusted bundled file.
    with open(pkl_file, 'rb') as pickle_file:
        _G, _D, Gs = pickle.load(pickle_file)
    img_array = []
    for current_seed in range(seed, seed + n_imgs):
        # Noise is re-seeded with the same constant (10) for every image, so
        # images differ only through their latent z.
        init_random_state(Gs, 10)
        z = seed2vec(Gs, current_seed)
        img_array.append(generate_image(Gs, z, 1.0, show=True))
    final_img = concatenate(img_array)
    return final_img, "Imágenes generadas"
############################################################################## | |
# Gradio code | |
iface = gr.Interface(
    fn=initiate,
    # gr.inputs.Slider's positional order is (minimum, maximum, step, ...),
    # so step and label are passed by keyword; step=1 keeps seeds/counts integral.
    inputs=[
        gr.inputs.Slider(0, 99999999, step=1, label="image"),
        gr.inputs.Slider(1, 9, step=1, label="images"),
        "text",
    ],
    outputs=["image", "text"],
    # Each example: [seed, number of images, caption shown in the text box].
    examples=[
        [40, 1, "Edificios al anochecer"],
        [37, 1, "Fuente de día"],
        [426, 1, "Edificios con cielo oscuro"],
        [397, 1, "Edificios de día"],
        [395, 1, "Edificios desde anfiteatro"],
        [281, 1, "Edificios con luces encendidas"],
        [230, 1, "Edificios con luces encendidas y vegetación"],
        [221, 1, "Edificios con vegetación"],
        [214, 1, "Edificios al atardecer con luces encendidas"],
        [198, 1, "Edificio al anochecer con luces en el pasillo"],
    ],
    title="GANbanales",
    description="Una GAN para generar imágenes del campus universitario de Rabanales, Córdoba.",
)

if __name__ == "__main__":
    app, local_url, share_url = iface.launch(debug=True)