File size: 2,308 Bytes
a41434f
 
 
b00682a
a41434f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97cd984
a41434f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f72cbc0
 
 
a41434f
 
 
 
1c850d7
 
 
 
a41434f
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image, ImageEnhance
from tensorflow_examples.models.pix2pix import pix2pix

# Number of channels the generator emits (3, i.e. one per RGB channel).
OUTPUT_CHANNELS = 3

# Build the U-Net generator with instance normalization, then restore the
# trained weights from the checkpoint file (path is relative to the current
# working directory).
generator_g = pix2pix.unet_generator(
    OUTPUT_CHANNELS,
    norm_type='instancenorm',
)
generator_g.load_weights("vibrantGAN-generator-g-final.weights.h5")

def preprocess_single_image(image, target_height=256, target_width=256):
    """Convert an input image into a normalized 3-channel float tensor.

    Args:
        image: A PIL image, or any array-like accepted by ``np.array``
            laid out as ``(H, W)`` or ``(H, W, C)``.
        target_height: Height, in pixels, the image is resized to.
        target_width: Width, in pixels, the image is resized to.

    Returns:
        A ``tf.float32`` tensor of shape
        ``(target_height, target_width, 3)`` computed as
        ``pixel / 127.5 - 1`` (i.e. [-1, 1] for 8-bit input in [0, 255]).
    """
    # Let PIL handle colour-mode conversion: inspecting the array shape alone
    # cannot tell palette ("P") from grayscale or CMYK from RGBA, and modes
    # such as "LA" or "1" would otherwise reach the model with the wrong
    # channel count or dtype.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    image = tf.convert_to_tensor(np.array(image))

    # Fallback for plain array inputs: coerce to exactly 3 channels.
    if len(image.shape) == 2:  # (H, W) grayscale
        image = tf.stack([image, image, image], axis=-1)
    elif image.shape[-1] == 1:  # (H, W, 1) grayscale
        image = tf.concat([image, image, image], axis=-1)
    elif image.shape[-1] == 4:  # assumed RGBA: drop the alpha channel
        image = image[:, :, :3]

    # Resize to the requested spatial size
    image = tf.image.resize(image, [target_height, target_width])

    # Normalize to [-1, 1]
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1

    return image

def process_image(input_image, output_resolution=(512, 512), sharpness_factor=4):
    """Run the generator on one image and return a sharpened PIL result.

    Args:
        input_image: Image to transform, or ``None`` (e.g. nothing was
            uploaded).
        output_resolution: ``(height, width)`` the generated image is
            resized to before being returned.
        sharpness_factor: Factor passed to
            ``ImageEnhance.Sharpness.enhance``; 1.0 leaves the image
            unchanged and larger values sharpen more.

    Returns:
        The generated image as a PIL image, or ``None`` when
        ``input_image`` is ``None``.
    """
    if input_image is None:
        return None

    # Preprocess, then add a leading batch axis of size 1 for the model call.
    model_input = tf.expand_dims(preprocess_single_image(input_image), 0)

    # Inference only: state it explicitly so any train-time-only layers
    # (e.g. dropout) stay disabled regardless of framework defaults.
    prediction = generator_g(model_input, training=False)

    # Denormalize the first (only) batch item to [0, 1] and clamp.
    output_image = tf.clip_by_value(prediction[0] * 0.5 + 0.5, 0, 1)

    # Resize the output image to the desired resolution
    output_image = tf.image.resize(output_image, output_resolution, method='bilinear')

    # Round to the nearest 8-bit level; a bare astype(uint8) truncates and
    # biases every pixel slightly darker. The clip guards against float
    # error pushing values marginally outside [0, 255].
    output_array = np.clip(np.rint(output_image.numpy() * 255), 0, 255).astype(np.uint8)
    output_pil = Image.fromarray(output_array)

    return ImageEnhance.Sharpness(output_pil).enhance(sharpness_factor)

# Gradio UI: a single PIL image goes in, a single PIL image comes out.
_input_component = gr.Image(type="pil", label="Input Image")
_output_component = gr.Image(type="pil", label="Generated Output")

demo = gr.Interface(
    fn=process_image,
    inputs=_input_component,
    outputs=_output_component,
    title="Image Processing Model",
    description="Upload an image to see the model's output.",
)

# Start serving the interface (debug mode enabled).
demo.launch(debug=True)