SkinDemarker / app.py
AdarshRavis's picture
Update app.py
db93a9c verified
raw
history blame
No virus
3.67 kB
import streamlit as st
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, Concatenate, Conv2DTranspose, Activation, BatchNormalization, Dropout
from tensorflow.keras.initializers import RandomNormal
# Define the generator model architecture
def define_generator(image_shape=(256, 256, 3)):
    """Build the U-Net-style encoder/decoder generator as an uncompiled Keras Model.

    The network has:
      * seven stride-2 encoder blocks (64, 128, 256, 512, 512, 512, 512
        filters; the first block has no batch normalization);
      * one stride-2 bottleneck convolution with ReLU;
      * seven stride-2 decoder blocks, each concatenated with the mirrored
        encoder output (dropout on the three deepest only);
      * a final stride-2 transposed convolution with tanh.

    The layers are created in a fixed order that trained weights are loaded
    onto, so the order of construction here must not change.

    Args:
        image_shape: Input shape as (height, width, channels). The output has
            ``image_shape[2]`` channels. With eight stride-2 downsamplings the
            spatial dims presumably need to be multiples of 256 for the skip
            concatenations to line up — TODO confirm for non-default shapes.

    Returns:
        A ``tensorflow.keras.Model`` mapping the input image to a same-sized
        image in tanh range.
    """
    weight_init = RandomNormal(stddev=0.02)
    inputs = Input(shape=image_shape)

    # Encoder: collect every block's output so the decoder can reuse it as a skip.
    encoder_filters = (64, 128, 256, 512, 512, 512, 512)
    skips = []
    x = inputs
    for depth, filters in enumerate(encoder_filters):
        x = define_encoder_block(x, filters, batchnorm=depth > 0)
        skips.append(x)

    # Bottleneck: one more downsampling step, ReLU, no batch norm.
    x = Conv2D(512, (4, 4), strides=(2, 2), padding='same', kernel_initializer=weight_init)(x)
    x = Activation('relu')(x)

    # Decoder: walk the skips deepest-first; only the first three blocks use dropout.
    decoder_filters = (512, 512, 512, 512, 256, 128, 64)
    for depth, (filters, skip) in enumerate(zip(decoder_filters, reversed(skips))):
        x = decoder_block(x, skip, filters, dropout=depth < 3)

    # Output head: upsample to input resolution, one filter per input channel.
    x = Conv2DTranspose(image_shape[2], (4, 4), strides=(2, 2), padding='same', kernel_initializer=weight_init)(x)
    outputs = Activation('tanh')(x)
    return Model(inputs, outputs)
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    """Downsample ``layer_in`` by 2x with a strided convolution.

    Layer sequence: Conv2D(4x4, stride 2, 'same') -> optional
    BatchNormalization -> LeakyReLU with slope 0.2.

    Args:
        layer_in: Keras tensor to downsample.
        n_filters: Number of convolution filters (output channels).
        batchnorm: Whether to insert a BatchNormalization layer after the conv.

    Returns:
        The block's output Keras tensor.
    """
    init = RandomNormal(stddev=0.02)
    g = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(layer_in)
    if batchnorm:
        # training=True makes the layer normalize with the current batch's
        # statistics even at inference time, rather than its moving averages.
        g = BatchNormalization()(g, training=True)
    # Slope is passed positionally. The keyword is `alpha` in Keras 2 but
    # `negative_slope` in Keras 3, where `alpha` is deprecated and warns.
    # The first positional argument is the slope in both.
    g = LeakyReLU(0.2)(g)
    return g
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    """Upsample ``layer_in`` by 2x and merge it with an encoder skip tensor.

    Layer sequence: Conv2DTranspose(4x4, stride 2, 'same') ->
    BatchNormalization -> optional Dropout(0.5) -> concatenate with
    ``skip_in`` -> ReLU.

    Args:
        layer_in: Keras tensor coming up from the deeper decoder level.
        skip_in: Encoder tensor concatenated onto the upsampled features; its
            spatial size must match the upsampled result.
        n_filters: Number of transposed-convolution filters.
        dropout: Whether to apply Dropout(0.5) before the concatenation.

    Returns:
        The block's output Keras tensor.
    """
    upsampled = Conv2DTranspose(
        n_filters,
        (4, 4),
        strides=(2, 2),
        padding='same',
        kernel_initializer=RandomNormal(stddev=0.02),
    )(layer_in)
    # Both BatchNormalization and Dropout are forced into training mode, so
    # they keep that behaviour when the model is used for prediction.
    normalized = BatchNormalization()(upsampled, training=True)
    if dropout:
        normalized = Dropout(0.5)(normalized, training=True)
    merged = Concatenate()([normalized, skip_in])
    return Activation('relu')(merged)
def load_and_preprocess_image(image, size=(256, 256)):
    """Turn an image into a normalized, single-image batch for the generator.

    Args:
        image: A PIL image or anything ``np.array`` accepts. RGB is expected.
            A 2-D (grayscale) array is replicated to three channels, and a
            4-channel (RGBA) array has its alpha channel dropped, so the
            result always has three channels.
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults to
            ``(256, 256)``, matching ``define_generator``'s default input.

    Returns:
        A float32 array of shape ``(1, height, width, 3)``. Pixel values are
        mapped as ``(v - 127.5) / 127.5``, i.e. to [-1, 1] assuming 8-bit input.
    """
    pixels = np.array(image)
    if pixels.ndim == 2:
        # Grayscale: replicate the single channel for the 3-channel model input.
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[-1] == 4:
        # RGBA: drop alpha. Make the slice contiguous for OpenCV.
        pixels = np.ascontiguousarray(pixels[..., :3])
    # No BGR->RGB conversion is needed: PIL already yields RGB channel order.
    pixels = cv2.resize(pixels, size)
    pixels = (pixels.astype(np.float32) - 127.5) / 127.5
    return np.expand_dims(pixels, axis=0)
# Streamlit UI
WEIGHTS_PATH = 'model_weights.h5'


@st.cache_resource
def _load_generator(weights_path=WEIGHTS_PATH):
    """Build the generator and load its trained weights, cached by Streamlit.

    Streamlit re-executes this whole script on every interaction.
    ``st.cache_resource`` keeps the built model across reruns and sessions, so
    the network is not rebuilt and the weights file not re-read each time.
    """
    generator = define_generator()
    generator.load_weights(weights_path)
    return generator


st.title("Image Processing with Keras Model")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Load and preprocess image
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Original Image', use_column_width=True)
    custom_image = load_and_preprocess_image(image)

    # Reuse the cached model instead of rebuilding it on every rerun.
    model = _load_generator()

    # Generate image; map the tanh output from [-1, 1] back to [0, 255].
    # Round (rather than truncate) and clamp before casting to uint8.
    gen_image = model.predict(custom_image)
    gen_image = np.squeeze(gen_image, axis=0)
    gen_image = np.clip(np.rint((gen_image + 1) * 127.5), 0, 255).astype(np.uint8)
    gen_image_pil = Image.fromarray(gen_image)
    st.image(gen_image_pil, caption='Generated Image', use_column_width=True)

    # Plotting results side by side
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image)
    ax[0].set_title("Original Image")
    ax[0].axis('off')
    ax[1].imshow(gen_image_pil)
    ax[1].set_title("Generated Image")
    ax[1].axis('off')
    st.pyplot(fig)
    # pyplot keeps every figure registered until closed. Without this,
    # each Streamlit rerun would leak another figure.
    plt.close(fig)