File size: 3,668 Bytes
1848a98
ae66579
806da74
 
 
5e21889
 
 
1848a98
5e21889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1848a98
5e21889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1848a98
e0cdb19
 
 
db93a9c
e0cdb19
 
 
 
806da74
 
1848a98
f3bb9da
2f94c4d
ae66579
806da74
 
 
 
 
5e21889
 
 
 
 
 
806da74
 
 
 
 
2f94c4d
806da74
 
edb7fc7
806da74
 
 
2f94c4d
806da74
 
 
1848a98
806da74
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, Concatenate, Conv2DTranspose, Activation, BatchNormalization, Dropout
from tensorflow.keras.initializers import RandomNormal

# U-Net generator: strided encoder, bottleneck, decoder with skip connections
def define_generator(image_shape=(256, 256, 3)):
    """Build the U-Net image-to-image generator.

    Seven encoder blocks (strided 4x4 convolutions) feed a strided-conv
    bottleneck. Seven decoder blocks then upsample, each concatenating the
    matching encoder activation, deepest first. A final transposed
    convolution maps back to ``image_shape[2]`` channels through a tanh, so
    outputs lie in [-1, 1].

    NOTE: this script later loads saved weights into the returned model.
    Keep the layer construction sequence below stable, since layer order and
    auto-generated names may matter when loading.

    Args:
        image_shape: (height, width, channels) of the input tensor.

    Returns:
        An uncompiled ``Model`` from the input image to the generated image.
    """
    weight_init = RandomNormal(stddev=0.02)
    source = Input(shape=image_shape)

    # Encoder: first stage has no batch norm, later stages widen up to 512.
    skips = [define_encoder_block(source, 64, batchnorm=False)]
    for width in (128, 256, 512, 512, 512, 512):
        skips.append(define_encoder_block(skips[-1], width))

    # Bottleneck: one more strided convolution, ReLU-activated.
    x = Conv2D(512, (4, 4), strides=(2, 2), padding='same',
               kernel_initializer=weight_init)(skips[-1])
    x = Activation('relu')(x)

    # Decoder plan as (filters, dropout); skips are consumed deepest first.
    decoder_plan = (
        (512, True),
        (512, True),
        (512, True),
        (512, False),
        (256, False),
        (128, False),
        (64, False),
    )
    for skip, (width, use_dropout) in zip(reversed(skips), decoder_plan):
        x = decoder_block(x, skip, width, dropout=use_dropout)

    # Output head: back to the input channel count, squashed by tanh.
    x = Conv2DTranspose(image_shape[2], (4, 4), strides=(2, 2), padding='same',
                        kernel_initializer=weight_init)(x)
    generated = Activation('tanh')(x)

    return Model(source, generated)

def define_encoder_block(layer_in, n_filters, batchnorm=True):
    """Append one downsampling stage to ``layer_in``.

    The stage is a 4x4 convolution with stride 2 and 'same' padding, then
    optional batch normalization, then LeakyReLU with slope 0.2.

    Args:
        layer_in: Keras tensor to downsample.
        n_filters: Number of convolution filters (output channels).
        batchnorm: Whether to insert batch normalization after the conv.

    Returns:
        The activated output tensor of the stage.
    """
    conv = Conv2D(
        n_filters,
        (4, 4),
        strides=(2, 2),
        padding='same',
        kernel_initializer=RandomNormal(stddev=0.02),
    )
    x = conv(layer_in)
    if batchnorm:
        # training=True: normalize with the current batch's statistics,
        # even when the model is used for prediction.
        x = BatchNormalization()(x, training=True)
    return LeakyReLU(alpha=0.2)(x)

def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    """Append one upsampling stage and merge it with a skip connection.

    The stage is a 4x4 transposed convolution with stride 2 and batch
    normalization, plus optional 50% dropout. The result is concatenated
    with ``skip_in`` along the channel axis and ReLU-activated.

    Args:
        layer_in: Keras tensor to upsample.
        skip_in: Encoder tensor to concatenate with the upsampled output.
        n_filters: Number of transposed-convolution filters.
        dropout: Whether to apply Dropout(0.5) before the concatenation.

    Returns:
        The activated, concatenated output tensor.
    """
    upsample = Conv2DTranspose(
        n_filters,
        (4, 4),
        strides=(2, 2),
        padding='same',
        kernel_initializer=RandomNormal(stddev=0.02),
    )
    x = upsample(layer_in)
    # training=True keeps batch-statistics normalization active at inference.
    x = BatchNormalization()(x, training=True)
    if dropout:
        # Dropout is forced on (training=True) even during prediction.
        x = Dropout(0.5)(x, training=True)
    merged = Concatenate()([x, skip_in])
    return Activation('relu')(merged)

def load_and_preprocess_image(image, size=(256, 256)):
    """Turn an image into a normalized single-image batch for the generator.

    Args:
        image: A PIL image, or anything ``np.array`` accepts. Pixel values
            are assumed to be 8-bit (0-255). A PIL image in a mode other
            than RGB is converted to RGB first, so the model receives three
            channels.
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults
            to ``(256, 256)``, the generator's default input size.

    Returns:
        A float32 array with a leading batch axis of length 1, for example
        ``(1, height, width, 3)`` for RGB input. Values are rescaled from
        [0, 255] to [-1, 1].
    """
    if isinstance(image, Image.Image) and image.mode != 'RGB':
        image = image.convert('RGB')
    pixels = np.array(image)
    # cv2.resize expects dsize as (width, height), not (rows, cols).
    # No BGR->RGB swap is needed: the array comes from PIL, which is RGB.
    pixels = cv2.resize(pixels, size)
    # Map [0, 255] to [-1, 1], the same range as the generator's tanh output.
    pixels = (pixels.astype(np.float32) - 127.5) / 127.5
    # The model predicts on batches, so add a batch axis of length 1.
    return np.expand_dims(pixels, axis=0)

# Streamlit UI
WEIGHTS_PATH = 'model_weights.h5'


@st.cache_resource
def load_generator(weights_path=WEIGHTS_PATH):
    """Build the generator and load its trained weights once per process.

    Streamlit re-executes this whole script on every interaction. Caching
    the model avoids rebuilding the network and re-reading the weight file
    on each rerun.
    """
    generator = define_generator()
    generator.load_weights(weights_path)
    return generator


st.title("Image Processing with Keras Model")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Load the upload as RGB so the model always sees three channels.
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Original Image', use_column_width=True)

    custom_image = load_and_preprocess_image(image)

    # Reuse the cached model. Report load failures in the UI instead of
    # crashing with a raw traceback.
    try:
        model = load_generator()
    except (OSError, ValueError) as err:
        st.error(f"Could not load model weights from '{WEIGHTS_PATH}': {err}")
        st.stop()

    # Generate image
    gen_image = model.predict(custom_image)
    gen_image = np.squeeze(gen_image, axis=0)
    # Map the tanh range [-1, 1] back to [0, 255]. Round and clip rather
    # than truncate, so pixel values are not biased downward.
    gen_image = np.clip(np.rint((gen_image + 1) * 127.5), 0, 255).astype(np.uint8)
    gen_image_pil = Image.fromarray(gen_image)
    st.image(gen_image_pil, caption='Generated Image', use_column_width=True)

    # Side-by-side comparison plot
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    ax[0].imshow(image)
    ax[0].set_title("Original Image")
    ax[0].axis('off')

    ax[1].imshow(gen_image_pil)
    ax[1].set_title("Generated Image")
    ax[1].axis('off')

    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed, and this script
    # reruns repeatedly, so release the figure once it has been rendered.
    plt.close(fig)