# NST-Demo / app.py
import gradio as gr
import spaces
import tensorflow as tf
import numpy as np
from PIL import Image
import logging
import time
from tqdm import tqdm
# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("style_transfer_app")
# Set TensorFlow threading options
tf.config.threading.set_inter_op_parallelism_threads(8)
tf.config.threading.set_intra_op_parallelism_threads(8)
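# Eight threads per pool is presumably sized to the host CPU; adjust to match
# your hardware.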
def load_img(image):
"""Load and preprocess image for style transfer"""
max_dim = 512
# Convert PIL Image to tensor
img = tf.convert_to_tensor(np.array(image))
img = tf.image.convert_image_dtype(img, tf.float32)
shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    long_dim = tf.reduce_max(shape)
scale = max_dim / long_dim
new_shape = tf.cast(shape * scale, tf.int32)
img = tf.image.resize(img, new_shape)
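    # Add a leading batch dimension: (H, W, C) -> (1, H, W, C).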
img = img[tf.newaxis, :]
return img
def vgg_layers(layer_names):
"""Create VGG model with specified layers"""
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
vgg.trainable = False
outputs = [vgg.get_layer(name).output for name in layer_names]
model = tf.keras.Model([vgg.input], outputs)
return model
def gram_matrix(input_tensor):
"""Calculate Gram matrix"""
result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
input_shape = tf.shape(input_tensor)
num_locations = tf.cast(input_shape[1]*input_shape[2], tf.float32)
return result / num_locations
class StyleContentModel(tf.keras.models.Model):
def __init__(self, style_layers, content_layers):
super(StyleContentModel, self).__init__()
self.vgg = vgg_layers(style_layers + content_layers)
self.style_layers = style_layers
self.content_layers = content_layers
self.num_style_layers = len(style_layers)
self.vgg.trainable = False
def call(self, inputs):
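        # Inputs arrive in [0, 1]; vgg19.preprocess_input expects [0, 255]
        # pixels, which it converts to BGR and shifts by the ImageNet means.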
inputs = inputs * 255.0
preprocessed_input = tf.keras.applications.vgg19.preprocess_input(inputs)
outputs = self.vgg(preprocessed_input)
style_outputs, content_outputs = (outputs[:self.num_style_layers],
outputs[self.num_style_layers:])
style_outputs = [gram_matrix(style_output)
for style_output in style_outputs]
content_dict = {content_name: value
for content_name, value
in zip(self.content_layers, content_outputs)}
style_dict = {style_name: value
for style_name, value
in zip(self.style_layers, style_outputs)}
return {'content': content_dict, 'style': style_dict}
def clip_0_1(image):
"""Clip tensor values between 0 and 1"""
return tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)
def style_content_loss(outputs, style_targets, content_targets, style_weight, content_weight):
"""Calculate style and content loss"""
style_outputs = outputs['style']
content_outputs = outputs['content']
style_loss = tf.add_n([tf.reduce_mean((style_outputs[name]-style_targets[name])**2)
for name in style_outputs.keys()])
style_loss *= style_weight / len(style_outputs)
content_loss = tf.add_n([tf.reduce_mean((content_outputs[name]-content_targets[name])**2)
for name in content_outputs.keys()])
content_loss *= content_weight / len(content_outputs)
loss = style_loss + content_loss
return loss
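# Compiled to a TensorFlow graph. Python-object arguments (extractor, opt) are
# part of the trace key, so passing new instances triggers a retrace.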
@tf.function
def train_step(image, extractor, style_targets, content_targets, opt, style_weight, content_weight, total_variation_weight):
"""Perform one training step"""
with tf.GradientTape() as tape:
outputs = extractor(image)
loss = style_content_loss(outputs, style_targets, content_targets, style_weight, content_weight)
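        # A total variation penalty discourages high-frequency noise in the
        # generated image, acting as an explicit smoothness regularizer.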
loss += total_variation_weight * tf.image.total_variation(image)
grad = tape.gradient(loss, image)
opt.apply_gradients([(grad, image)])
image.assign(clip_0_1(image))
return loss
def tensor_to_image(tensor):
"""Convert tensor to PIL Image"""
tensor = tensor * 255
tensor = np.array(tensor, dtype=np.uint8)
if np.ndim(tensor) > 3:
tensor = tensor[0]
return Image.fromarray(tensor)
# Initialize the style-content model
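# Style is read from the first conv layer of each VGG19 block (multi-scale
# texture statistics); content comes from a single deeper layer, following
# Gatys et al.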
style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']
content_layers = ['block5_conv2']
extractor = StyleContentModel(style_layers, content_layers)
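# gr.Progress(track_tqdm=True) mirrors the tqdm bars below in the Gradio UI.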
@spaces.GPU(duration=120) # Style transfer typically needs more than 60s
def style_transfer_fn(content_image, style_image, progress=gr.Progress(track_tqdm=True)):
"""Main style transfer function for Gradio interface"""
try:
# Preprocess images
content_img = load_img(content_image)
style_img = load_img(style_image)
# Extract style and content features
style_targets = extractor(style_img)['style']
content_targets = extractor(content_img)['content']
image = tf.Variable(content_img)
# Set optimization parameters
        opt = tf.keras.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)
        # Build the optimizer's slot variables eagerly (Optimizer.build in
        # Keras 2.11+/Keras 3); tf.function disallows variable creation during
        # a retrace, and each request passes train_step a fresh Adam instance.
        opt.build([image])
style_weight = 1e-2
content_weight = 1e4
total_variation_weight = 30
epochs = 10
steps_per_epoch = 100
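        # 10 epochs x 100 steps = 1000 updates; these hyperparameters match
        # the TensorFlow neural style transfer tutorial defaults.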
start_time = time.time()
# Training loop
for n in tqdm(range(epochs), desc="Epochs"):
for m in tqdm(range(steps_per_epoch), desc="Steps", leave=False):
loss = train_step(image, extractor, style_targets, content_targets,
opt, style_weight, content_weight, total_variation_weight)
# Convert result to image
result_image = tensor_to_image(image)
return result_image
except Exception as e:
logger.error(f"Error during style transfer: {e}")
raise gr.Error("An error occurred during style transfer.")
# Create Gradio interface
iface = gr.Interface(
fn=style_transfer_fn,
inputs=[
gr.Image(label="Content Image", type="pil"),
gr.Image(label="Style Image", type="pil")
],
outputs=gr.Image(label="Stylized Image"),
title="Neural Style Transfer - Ty Chermsirivatana",
description="Upload a content image and a style image to create a stylized image in context.",
)
# Launch the interface
if __name__ == "__main__":
iface.launch()