Spaces:
Sleeping
Sleeping
from pathlib import Path

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageEnhance
from tensorflow_examples.models.pix2pix import pix2pix
# Number of channels the generator produces (RGB output).
OUTPUT_CHANNELS = 3

# Resolve the checkpoint next to this script instead of against the current
# working directory, so the app still starts when launched from elsewhere
# (e.g. ``python path/to/app.py``).
try:
    _APP_DIR = Path(__file__).resolve().parent
except NameError:  # ``__file__`` is undefined in interactive sessions
    _APP_DIR = Path.cwd()
WEIGHTS_PATH = _APP_DIR / "vibrantGAN-generator-g-final.weights.h5"

# U-Net generator with instance normalisation. NOTE(review): the architecture
# must match the one the checkpoint was trained with — confirm if changed.
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type="instancenorm")
# Pass a plain string for compatibility with Keras versions that expect one.
generator_g.load_weights(str(WEIGHTS_PATH))
def preprocess_single_image(image, target_height=256, target_width=256):
    """Convert an input image into a normalized 3-channel tensor.

    Args:
        image: A PIL image, or anything ``np.asarray`` turns into an
            ``(H, W)`` or ``(H, W, C)`` array.
        target_height: Height in pixels the image is resized to.
        target_width: Width in pixels the image is resized to.

    Returns:
        A float32 tensor of shape ``(target_height, target_width, 3)`` whose
        values are mapped from [0, 255] to [-1, 1].
    """
    # For PIL inputs let PIL handle colour modes: this resolves palette ("P")
    # indices into real colours and drops alpha ("RGBA"/"LA"), which slicing
    # the raw array cannot do correctly for every mode.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    array = np.asarray(image)

    # Normalise the channel layout of non-PIL inputs to exactly 3 channels.
    if array.ndim == 2:  # grayscale (H, W) -> (H, W, 1)
        array = array[..., np.newaxis]
    channels = array.shape[-1]
    if channels in (1, 2):  # gray or gray+alpha: replicate luminance as RGB
        array = np.repeat(array[..., :1], 3, axis=-1)
    elif channels > 3:  # RGBA and similar: keep only the colour channels
        array = array[..., :3]

    tensor = tf.convert_to_tensor(array)
    tensor = tf.image.resize(tensor, [target_height, target_width])
    tensor = tf.cast(tensor, tf.float32)
    # Scale [0, 255] -> [-1, 1].
    return (tensor / 127.5) - 1
def process_image(input_image, output_resolution=(512, 512), *, sharpness=4.0):
    """Run the generator on one image and return a sharpened PIL result.

    Args:
        input_image: PIL image from the Gradio ``Image`` input, or ``None``
            when the input component is empty.
        output_resolution: ``(height, width)`` of the returned image. Pass
            ``None`` to match the input image's own size (PIL input only).
        sharpness: Factor passed to ``ImageEnhance.Sharpness.enhance``;
            1.0 leaves the image unchanged.

    Returns:
        The generated ``PIL.Image.Image``, or ``None`` if no input was given.
    """
    if input_image is None:
        return None

    if output_resolution is None:
        # PIL reports (width, height); tf.image.resize expects (height, width).
        width, height = input_image.size
        output_resolution = (height, width)

    # Preprocess and add a batch dimension of 1.
    batch = tf.expand_dims(preprocess_single_image(input_image), 0)

    # Explicit inference mode so any training-only layers (e.g. dropout)
    # stay inactive regardless of the global learning phase.
    prediction = generator_g(batch, training=False)

    # Undo the [-1, 1] scaling used for the input, back to [0, 1].
    output_image = prediction[0] * 0.5 + 0.5
    output_image = tf.clip_by_value(output_image, 0, 1)
    output_image = tf.image.resize(output_image, output_resolution, method="bilinear")

    # Round rather than truncate so values are not biased toward darker levels.
    output_array = np.round(output_image.numpy() * 255).astype(np.uint8)
    output_pil = Image.fromarray(output_array)
    return ImageEnhance.Sharpness(output_pil).enhance(sharpness)
# Gradio interface: one PIL image in, one generated PIL image out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Generated Output"),
    title="Image Processing Model",
    description="Upload an image to see the model's output.",
)

# Launch only when run as a script, so importing this module from another
# module or a test does not start a blocking web server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)