Spaces:
Sleeping
Sleeping
import gradio as gr | |
import tensorflow as tf | |
import numpy as np | |
from PIL import Image, ImageEnhance | |
from tensorflow_examples.models.pix2pix import pix2pix | |
# Channel count handed to the U-Net generator; 3 matches the RGB arrays
# fed to it by preprocess_single_image and rendered by process_image.
OUTPUT_CHANNELS = 3

# Trained weights for the generator. This is a relative path, so it is
# resolved against the current working directory at load time.
_GENERATOR_WEIGHTS = "vibrantGAN-generator-g-final.weights.h5"

# U-Net generator from the TensorFlow Examples pix2pix module, built with
# instance normalization, then restored from the weights file above.
generator_g = pix2pix.unet_generator(
    OUTPUT_CHANNELS,
    norm_type="instancenorm",
)
generator_g.load_weights(_GENERATOR_WEIGHTS)
def preprocess_single_image(image, target_height=256, target_width=256):
    """Turn one image into a normalized 3-channel tensor for the generator.

    Args:
        image: A PIL image, or an array-like of pixel values in [0, 255]
            shaped (H, W) or (H, W, C) with C in {1, 2, 3, 4}.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        A float32 tensor of shape (target_height, target_width, 3) whose
        values are scaled from [0, 255] to [-1, 1].
    """
    # For PIL inputs, let PIL do the colour-mode conversion. Without it:
    #   - palette ("P") images would be read as raw palette indices;
    #   - "LA" images would keep 2 channels and fail downstream;
    #   - "CMYK" images would have their C/M/Y planes treated as R/G/B;
    #   - bilevel ("1") images would arrive as booleans.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    image = tf.convert_to_tensor(np.asarray(image))

    # Coerce array inputs to exactly 3 channels.
    if len(image.shape) == 2:  # grayscale (H, W): replicate into R, G, B
        image = tf.stack([image, image, image], axis=-1)
    elif image.shape[-1] in (1, 2):  # grayscale, optionally with alpha
        image = tf.tile(image[..., :1], [1, 1, 3])
    elif image.shape[-1] == 4:  # RGBA: drop the alpha channel
        image = image[..., :3]

    # Resize (bilinear by default), then map [0, 255] -> [-1, 1].
    image = tf.image.resize(image, [target_height, target_width])
    image = tf.cast(image, tf.float32)
    return image / 127.5 - 1.0
def process_image(input_image, output_resolution=(512, 512), sharpness=4.0):
    """Run the generator on one image and return a sharpened PIL result.

    Args:
        input_image: The PIL image from the Gradio input component, or None
            when no image was supplied.
        output_resolution: (height, width) of the returned image, passed to
            ``tf.image.resize``. Pass None to keep the input image's size.
        sharpness: Factor for ``ImageEnhance.Sharpness``; 1.0 leaves the
            image unchanged and larger values sharpen more.

    Returns:
        The generated image as a PIL ``Image``, or None if ``input_image``
        is None.
    """
    if input_image is None:
        return None

    if output_resolution is None:
        # PIL reports size as (width, height); tf.image.resize wants (height, width).
        width, height = input_image.size
        output_resolution = (height, width)

    # The generator takes a batch, so add a leading batch axis of size 1.
    batch = tf.expand_dims(preprocess_single_image(input_image), 0)
    # Make inference mode explicit (e.g. so any dropout layers are inactive).
    prediction = generator_g(batch, training=False)

    # Undo the [-1, 1] input scaling to get [0, 1], clipping any overshoot.
    output_image = tf.clip_by_value(prediction[0] * 0.5 + 0.5, 0.0, 1.0)
    output_image = tf.image.resize(output_image, output_resolution, method="bilinear")

    # Round rather than truncate to 8-bit, so values are not biased darker.
    output_array = np.round(output_image.numpy() * 255).astype(np.uint8)
    output_pil = Image.fromarray(output_array)
    return ImageEnhance.Sharpness(output_pil).enhance(sharpness)
# Gradio UI: one PIL image in, one PIL image out, both handled by
# process_image with its default arguments.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Generated Output"),
    title="Image Processing Model",
    description="Upload an image to see the model's output.",
)

# Start the server only when this file is executed as a script, so that
# importing the module (tests, tooling, reload mode) does not block on launch().
if __name__ == "__main__":
    demo.launch(debug=True)