Sketch2Shoes / app.py
bhadresh-savani's picture
Update app.py
7c82ec5
raw
history blame
No virus
1.47 kB
import tensorflow as tf
import pathlib
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import from_pretrained_keras
import numpy as np
def normalize_test(input_image):
    """Rescale pixel values into the [-1, 1] range as float32.

    Args:
        input_image: image data accepted by ``tf.cast``; presumably pixel
            values in [0, 255] (TODO confirm with callers).

    Returns:
        A float32 tensor computed as ``input_image / 127.5 - 1``.
    """
    as_float = tf.cast(input_image, tf.float32)
    # Dividing by 127.5 maps [0, 255] onto [0, 2]; shifting by 1 gives [-1, 1].
    scaled = as_float / 127.5
    return scaled - 1
def resize(input_image, height, width):
    """Resize an image to ``height`` x ``width`` using nearest-neighbour sampling.

    Args:
        input_image: image tensor/array accepted by ``tf.image.resize``.
        height: target height in pixels.
        width: target width in pixels.

    Returns:
        The resized image as returned by ``tf.image.resize``.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    return tf.image.resize(input_image, target_size, method=nearest)
def load_image_infer(image_file):
    """Prepare one image for the generator: resize to 256x256, then scale to [-1, 1].

    Args:
        image_file: despite its name, this is image data (not a path); it is
            handed directly to ``resize``.

    Returns:
        The resized, normalized image produced by ``normalize_test``.
    """
    resized = resize(image_file, 256, 256)
    return normalize_test(resized)
def generate_images(test_input):
    """Run the pix2pix generator on a single image and return a plot of the result.

    Args:
        test_input: one image without a batch axis, accepted by
            ``load_image_infer``; with the Gradio interface in this file it is
            presumably an HxWx3 array (TODO confirm).

    Returns:
        A ``matplotlib.figure.Figure`` showing the predicted image.
    """
    test_input = load_image_infer(test_input)
    # The Keras model expects a batch, so add a leading axis of size 1.
    batch = np.expand_dims(test_input, axis=0)
    # training=True is kept deliberately: the TensorFlow pix2pix tutorial runs
    # the generator this way at inference so normalization layers use batch
    # statistics. NOTE(review): it also keeps any dropout active, making the
    # output stochastic -- confirm this is intended.
    prediction = generator(batch, training=True)

    # The original used figsize=(128, 128) inches (~12800x12800 px at the
    # default dpi), which is very slow and memory-hungry for a 256x256 image.
    fig = plt.figure(figsize=(6, 6))
    # Draw on this figure's own axes rather than pyplot's implicit "current"
    # axes, so the returned figure is exactly what gets rendered.
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title('Predicted Image')
    # Shift values from [-1, 1] into [0, 1] for plotting (assumes the
    # generator's output lies in [-1, 1] -- TODO confirm).
    ax.imshow(prediction[0] * 0.5 + 0.5)
    ax.axis('off')
    # NOTE(review): figures created via pyplot stay registered until
    # plt.close() is called; in a long-running server they accumulate.
    return fig
# Pretrained pix2pix generator loaded from the Hugging Face Hub; it is read as a
# module-level global by generate_images().
generator = from_pretrained_keras("keras-io/pix2pix-generator")

# Gradio components: an image input declared with shape=(256, 256), and an
# image output that renders the matplotlib figure returned by generate_images().
img = gr.inputs.Image(shape=(256,256))
plot = gr.outputs.Image(type="plot")
description = "Conditional GAN model that translates image-to-image."

demo = gr.Interface(
    generate_images,
    inputs=img,
    outputs=plot,
    title="Pix2Pix Shoes Reconstructor",
    description=description,
    examples=[["./img.png"]],
)
# Launches the web app as soon as this module runs (same as before).
demo.launch()