photo2monet / app.py
ayaderaghul's picture
Update app.py
ccda766
raw
history blame
2.53 kB
import gradio as gr
import keras
from keras.models import load_model
# from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from tensorflow_addons.layers import InstanceNormalization
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Keras must be told how to rebuild the custom InstanceNormalization layer
# that is stored inside the .h5 generator checkpoint.
cust = {"InstanceNormalization": InstanceNormalization}
model = load_model(
    "g-cycleGAN-photo2monet-500images-epoch10_30_30_30_30.h5",
    custom_objects=cust,
)

# Example inputs offered by the Gradio interface: one single-input row per file.
path = [
    ["ex1.jpg"],
    ["ex2.jpg"],
    ["ex3.jpg"],
    ["ex4.jpg"],
    ["ex5.jpg"],
]

# preprocess
# NOTE(review): AUTOTUNE, BUFFER_SIZE and BATCH_SIZE are not referenced by the
# inference code visible here -- presumably left over from a tf.data training
# pipeline; confirm before removing.
AUTOTUNE = tf.data.AUTOTUNE
BUFFER_SIZE = 400
BATCH_SIZE = 1
# Spatial size every input is resized to before it is fed to the generator.
IMG_HEIGHT = 256
IMG_WIDTH = 256
def resize(image, height, width):
    """Resize ``image`` to exactly ``height`` x ``width`` pixels.

    Nearest-neighbour sampling copies existing pixel values instead of
    interpolating, so the returned tensor keeps the input's dtype. The
    aspect ratio is not preserved; the image is stretched to the target size.

    Args:
        image: Image tensor accepted by ``tf.image.resize``.
        height: Target height in pixels.
        width: Target width in pixels.

    Returns:
        The resized image tensor.
    """
    return tf.image.resize(
        image,
        size=[height, width],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
def normalize(input_image):
    """Rescale pixel values from the [0, 255] range to [-1, 1].

    0 maps to -1, 127.5 maps to 0 and 255 maps to 1. Works on anything that
    supports ``/`` and ``-`` with Python numbers (tensors, numpy arrays,
    scalars).

    Args:
        input_image: Pixel values, nominally in [0, 255].

    Returns:
        The rescaled values.
    """
    return input_image / 127.5 - 1
def load(img_file):
    """Read an image file and return it as a float32 RGB tensor.

    Pixel values are left in the decoded [0, 255] range; scaling to the
    model's input range is done separately (see ``normalize``).

    Args:
        img_file: Path to the image file (Python string or scalar string
            tensor).

    Returns:
        A float32 tensor of shape ``(height, width, 3)``.
    """
    raw = tf.io.read_file(img_file)
    # decode_image detects the format (JPEG, PNG, BMP, GIF), so non-JPEG
    # uploads still load. channels=3 forces RGB output, so grayscale or RGBA
    # files do not produce a tensor with a different channel count (the
    # generator presumably expects 3-channel RGB -- confirm against training).
    # expand_animations=False keeps the result rank-3 (first frame of a GIF),
    # which tf.image.resize downstream relies on.
    decoded = tf.io.decode_image(raw, channels=3, expand_animations=False)
    return tf.cast(decoded, tf.float32)
def load_image_test(image_file):
    """Load an image from disk and prepare it for inference.

    No random jitter or cropping is applied: the image is decoded with
    ``load``, resized to ``IMG_HEIGHT`` x ``IMG_WIDTH`` and rescaled to
    [-1, 1] with ``normalize``.

    Args:
        image_file: Path to the image file.

    Returns:
        A float32 tensor of shape ``(IMG_HEIGHT, IMG_WIDTH, C)``, where ``C``
        is the channel count produced by ``load``.
    """
    image = load(image_file)
    image = resize(image, IMG_HEIGHT, IMG_WIDTH)
    return normalize(image)
def show_preds_image(image_path):
    """Translate one photo with the CycleGAN generator and return the result.

    Args:
        image_path: Path to the input image (the Gradio input component uses
            ``type="filepath"``, so it passes a path on disk).

    Returns:
        A numpy array holding the generated image with values clipped to
        [0, 1]. Its spatial shape is whatever the generator outputs for a
        ``(1, IMG_HEIGHT, IMG_WIDTH, C)`` batch (presumably 256x256x3 --
        depends on the saved model).
    """
    image = load_image_test(image_path)
    # The generator consumes batches; wrap the single image as a batch of one.
    batch = tf.expand_dims(image, axis=0)
    # Make inference mode explicit (relevant if the generator contains
    # training-dependent layers such as dropout).
    generated = model(batch, training=False)[0]
    # Undo the [-1, 1] input scaling to get a displayable [0, 1] image.
    generated = generated * 0.5 + 0.5
    # Clip any overshoot: Gradio rejects float images outside the valid
    # range, and a no-op clip costs nothing when the output is already in
    # range (e.g. a tanh head).
    return np.clip(generated.numpy(), 0.0, 1.0)
# Input: a single image upload; Gradio passes show_preds_image a file path.
inputs_image = [
    gr.components.Image(
        shape=(256, 256),
        type="filepath",
        label="Input Image",
    ),
]

# Output: the array returned by show_preds_image, displayed at 256x256.
outputs_image = [
    gr.components.Image(
        shape=(256, 256),
        type="numpy",
        label="Output Image",
    ).style(width=256, height=256),
]

interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=inputs_image,
    outputs=outputs_image,
    title="photo2monet",
    examples=path,
    # Examples are run on demand rather than precomputed at startup.
    cache_examples=False,
)

# Wrap the single interface in a tabbed layout, enable the request queue and
# start the server (this runs as a side effect of executing the module).
demo = gr.TabbedInterface(
    [interface_image],
    tab_names=["Image inference"],
)
demo.queue().launch()