# vibrantGAN / app.py
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image, ImageEnhance
from tensorflow_examples.models.pix2pix import pix2pix

OUTPUT_CHANNELS = 3

# Build the generator (U-Net with instance normalization, as in the TensorFlow
# CycleGAN tutorial) and load the trained vibrantGAN weights.
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_g.load_weights("vibrantGAN-generator-g-final.weights.h5")
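# Note: the weights file above must come from a generator built with this same
# architecture. It was presumably produced during training with a call along
# these lines (hypothetical, not part of this app):
#     generator_g.save_weights("vibrantGAN-generator-g-final.weights.h5")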

def preprocess_single_image(image, target_height=256, target_width=256):
    # Convert the PIL image to a TensorFlow tensor
    image = tf.convert_to_tensor(np.array(image))
    # Ensure the image has 3 channels (RGB)
    if len(image.shape) == 2:  # grayscale
        image = tf.stack([image, image, image], axis=-1)
    elif image.shape[-1] == 4:  # RGBA
        image = image[:, :, :3]
    # Resize to the generator's input resolution
    image = tf.image.resize(image, [target_height, target_width])
    # Normalize pixel values to [-1, 1]
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return image
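
# Example of using the preprocessing step on its own (hypothetical file name):
#     img = Image.open("example.jpg")
#     tensor = preprocess_single_image(img)  # shape (256, 256, 3), float32 in [-1, 1]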

def process_image(input_image, output_resolution=(512, 512)):
    if input_image is None:
        return None
    # Preprocess the image and add a batch dimension
    processed_input = preprocess_single_image(input_image)
    processed_input = tf.expand_dims(processed_input, 0)
    # Run the generator
    prediction = generator_g(processed_input)
    # Denormalize the prediction from [-1, 1] to [0, 1]
    output_image = prediction[0] * 0.5 + 0.5
    output_image = tf.clip_by_value(output_image, 0, 1)
    # Upscale the output to the requested resolution
    output_image = tf.image.resize(output_image, output_resolution, method='bilinear')
    # Convert to a NumPy array and then to a PIL image
    output_array = (output_image.numpy() * 255).astype(np.uint8)
    output_pil = Image.fromarray(output_array)
    # Sharpen to compensate for upscaling blur; adjust the factor as needed
    enhancer = ImageEnhance.Sharpness(output_pil)
    output_pil = enhancer.enhance(4)
    return output_pil
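
# Quick local sanity check (hypothetical paths), useful before wiring up the UI:
#     result = process_image(Image.open("example.jpg"))
#     result.save("example_vibrant.png")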

# Create the Gradio interface
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Generated Output"),
    title="vibrantGAN",
    description="Upload an image to see the output of the vibrantGAN generator.",
)

# Launch the interface
demo.launch(debug=True)
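# When running outside Hugging Face Spaces, a public link can be requested with
# the standard Gradio flag, e.g. demo.launch(debug=True, share=True).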