# Neural Style Transfer — Streamlit application.
import streamlit as st
from PIL import Image
import numpy as np
import tensorflow as tf
# Initialise the one-shot "Generate Art" flag on the first script run;
# Streamlit re-executes this file on every interaction, so guard the default.
if 'clicked' not in st.session_state:
    st.session_state.clicked = False
def click_button():
    """Button callback: mark that the user asked for a generation run."""
    st.session_state["clicked"] = True
# Square side length (px) both images are resized to; fixes the VGG input shape.
img_size = 400
# VGG19 convolutional base used as a frozen feature extractor.
# NOTE(review): weights are loaded from a local .h5 file — the app assumes the
# file sits in the working directory; confirm it ships with the deployment.
vgg = tf.keras.applications.VGG19(include_top=False, input_shape=(img_size, img_size, 3), weights='vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')
vgg.trainable = False
def get_layer_outputs(vgg, layer_names):
    """Build a model mapping vgg's input to the outputs of selected layers.

    ``layer_names`` is a sequence of ``(layer_name, weight)`` pairs; only the
    name element is used here — the weight is consumed elsewhere.
    """
    selected = []
    for entry in layer_names:
        selected.append(vgg.get_layer(entry[0]).output)
    return tf.keras.Model([vgg.input], selected)
# (layer_name, weight) pairs: style is matched at five shallow-to-deep conv
# layers, equally weighted.
STYLE_LAYERS = [('block1_conv1', 0.2), ('block2_conv1', 0.2), ('block3_conv1', 0.2), ('block4_conv1', 0.2), ('block5_conv1', 0.2)]
# Content is matched at a single deep layer.
content_layer = [('block5_conv4', 1)]
# Multi-output extractor: first five outputs are style activations,
# the last element is the content activation.
vgg_model_outputs = get_layer_outputs(vgg, STYLE_LAYERS + content_layer)
# --- Page chrome and run controls -------------------------------------------
st.set_page_config(layout="wide")
st.markdown("<h1 style='text-align: center;'>Neural Style Transfer</h1>", unsafe_allow_html=True)
st.divider()
co1, co2, co3, co4 = st.columns(4)
with co2:
    # Number of optimisation steps for the generation loop below.
    epochs = st.number_input("Input number of epochs", min_value=200, max_value=20000, step=50)
with co3:
    # Blank writes act as vertical spacers so the button lines up with the input.
    st.write(" ")
    st.write(" ")
    st.button('Generate Art', on_click=click_button, type="primary", use_container_width=True)
col1, col2, col3 = st.columns(3)
with col1:
    content_img = st.file_uploader("Input Content Image")
    if content_img is not None:
        # Resize to the VGG input size; result is uint8 HxWx3.
        content_image = np.array(Image.open(content_img).resize((img_size, img_size)))
        content_image = np.expand_dims(content_image, axis=0)  # Add batch dimension
        # Start the generated image from the content image plus uniform noise,
        # clipped back into [0, 1].
        generated_image = tf.Variable(tf.image.convert_image_dtype(content_image, tf.float32))
        noise = tf.random.uniform(tf.shape(generated_image), 0, 0.5)
        generated_image = tf.add(generated_image, noise)
        generated_image = tf.clip_by_value(generated_image, clip_value_min=0.0, clip_value_max=1.0)
        # NOTE(review): content_target is computed from the raw uint8 array and
        # never read afterwards — looks dead; confirm before removing.
        content_target = vgg_model_outputs(content_image)
        preprocessed_content = tf.Variable(tf.image.convert_image_dtype(content_image, tf.float32))
        # a_C / a_G are module globals later read by train_step.
        a_C = vgg_model_outputs(preprocessed_content)
        a_G = vgg_model_outputs(generated_image)
        st.image(content_img, caption="CONTENT IMAGE", use_column_width=True)
with col2:
    style_img = st.file_uploader("Input Style Image")
    if style_img is not None:
        style_image = np.array(Image.open(style_img).resize((img_size, img_size)))
        style_image = np.expand_dims(style_image, axis=0)  # Add batch dimension
        # NOTE(review): style_targets (from uint8 input) also appears unused;
        # a_S below, computed from the [0,1] float version, is what train_step uses.
        style_targets = vgg_model_outputs(style_image)
        preprocessed_style = tf.Variable(tf.image.convert_image_dtype(style_image, tf.float32))
        a_S = vgg_model_outputs(preprocessed_style)
        st.image(style_img, caption="STYLE IMAGE", use_column_width=True)
def compute_content_cost(content_output, generated_output):
    """Content cost between the deepest activations of two extractor outputs.

    Both arguments are lists of layer activations; the last element is the
    content layer. Returns a scalar tensor:
    sum((a_C - a_G)^2) / (4 * H * W * C).
    """
    content_act = content_output[-1]
    generated_act = generated_output[-1]
    batch, height, width, channels = generated_act.get_shape().as_list()
    flat_content = tf.transpose(tf.reshape(content_act, shape=[batch, -1, channels]))
    flat_generated = tf.transpose(tf.reshape(generated_act, shape=[batch, -1, channels]))
    scale = 1 / (4 * height * width * channels)
    return scale * tf.reduce_sum(tf.square(flat_content - flat_generated))
def gram_matrix(A):
    """Return the Gram matrix A @ A^T of a 2-D activation matrix."""
    return tf.matmul(A, tf.transpose(A))
def compute_layer_style_cost(a_S, a_G):
    """Style cost for one layer: squared Frobenius distance of Gram matrices.

    a_S and a_G are 4-D activations (batch, H, W, C) of the style and
    generated images at the same layer. Returns a scalar tensor.
    """
    batch, height, width, channels = a_G.get_shape().as_list()
    # Flatten spatial positions, keep channels, then transpose to (C, H*W).
    style_flat = tf.transpose(tf.reshape(a_S, shape=[-1, channels]))
    generated_flat = tf.transpose(tf.reshape(a_G, shape=[-1, channels]))
    gram_style = gram_matrix(style_flat)
    gram_generated = gram_matrix(generated_flat)
    norm = 1 / (4 * channels ** 2 * (height * width) ** 2)
    return norm * tf.reduce_sum(tf.square(gram_style - gram_generated))
def compute_style_cost(style_image_output, generated_image_output, STYLE_LAYERS=STYLE_LAYERS):
    """Weighted sum of per-layer style costs over the configured style layers.

    The last element of each output list is the content activation, so it is
    excluded. Each remaining pair of activations contributes its layer cost
    scaled by the layer's weight from STYLE_LAYERS.
    """
    a_S = style_image_output[:-1]
    a_G = generated_image_output[:-1]
    J_style = 0
    # Direct zip over the three sequences replaces the hand-rolled
    # zip(range(len(...)), ...) indexing; zip truncates to the shortest input,
    # matching the original pairing.
    for style_act, generated_act, (_, coeff) in zip(a_S, a_G, STYLE_LAYERS):
        J_style += coeff * compute_layer_style_cost(style_act, generated_act)
    return J_style
@tf.function()
def total_cost(J_content, J_style, alpha=10, beta=40):
    """Blend content and style costs into the single optimisation target."""
    weighted_content = alpha * J_content
    weighted_style = beta * J_style
    return weighted_content + weighted_style
def clip_0_1(image):
    """Clamp pixel values into the valid [0, 1] range."""
    return tf.clip_by_value(image, 0.0, 1.0)
def tensor_to_image(tensor):
    """Convert a [0, 1] float tensor (optionally with a size-1 batch axis) to a PIL Image."""
    pixels = np.array(tensor * 255, dtype=np.uint8)
    if pixels.ndim > 3:
        # Only single-image batches are supported; drop the batch axis.
        assert pixels.shape[0] == 1
        pixels = pixels[0]
    return Image.fromarray(pixels)
# Single Adam optimizer shared across every train_step invocation.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
@tf.function()
def train_step(generated_image):
    """One optimisation step: mutate generated_image in place to lower the NST cost.

    Reads module globals vgg_model_outputs, a_S, a_C and optimizer.
    NOTE(review): a_S / a_C only exist after both images have been uploaded —
    confirm the Generate button cannot trigger this before then.
    """
    with tf.GradientTape() as tape:
        # Forward pass of the current generated image through the extractor.
        a_G = vgg_model_outputs(generated_image)
        J_style = compute_style_cost(a_S, a_G)
        J_content = compute_content_cost(a_C, a_G)
        J = total_cost(J_content, J_style)
    grad = tape.gradient(J, generated_image)
    optimizer.apply_gradients([(grad, generated_image)])
    # Keep pixels valid after the gradient step.
    generated_image.assign(clip_0_1(generated_image))
    return J
with col3:
    st.write("Generated Image")
    if st.session_state.clicked:
        # Re-seed the optimisation variable from the content image (no noise here,
        # unlike the initialisation in the upload column).
        # NOTE(review): content_image is only defined when a content upload has
        # happened — clicking Generate without uploads raises NameError; confirm
        # this path is guarded elsewhere or intended to fail.
        generated_image = tf.Variable(tf.image.convert_image_dtype(content_image, tf.float32))
        st.write(" ")
        st.write(" ")
        st.write(" ")
        placeholder_1 = st.empty()
        placeholder_2 = st.empty()
        for I in range(epochs):
            train_step(generated_image)
            # NOTE(review): `I % 1 == 0` is always true, so every epoch re-renders;
            # a larger interval (e.g. I % 10) was likely intended.
            if I % 1 == 0:
                image = tensor_to_image(generated_image)
                placeholder_1.image(image)
                placeholder_2.write(f"Epoch {I}")