Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
"""TF-Hub: Fast Style Transfer for Arbitrary Styles.ipynb | |
Automatically generated by Colaboratory. | |
Original file is located at | |
https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf2_arbitrary_image_stylization.ipynb | |
##### Copyright 2019 The TensorFlow Hub Authors. | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
""" | |
# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""# Fast Style Transfer for Arbitrary Styles | |
<table class="tfo-notebook-buttons" align="left"> | |
<td> | |
<a target="_blank" href="https://www.tensorflow.org/hub/tutorials/tf2_arbitrary_image_stylization"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a> | |
</td> | |
<td> | |
<a target="_blank" href="https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf2_arbitrary_image_stylization.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a> | |
</td> | |
<td> | |
<a target="_blank" href="https://github.com/tensorflow/hub/blob/master/examples/colab/tf2_arbitrary_image_stylization.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View on GitHub</a> | |
</td> | |
<td> | |
<a href="https://storage.googleapis.com/tensorflow_docs/hub/examples/colab/tf2_arbitrary_image_stylization.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a> | |
</td> | |
<td> | |
<a href="https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"><img src="https://www.tensorflow.org/images/hub_logo_32px.png" />See TF Hub model</a> | |
</td> | |
</table> | |
Based on the model code in [magenta](https://github.com/tensorflow/magenta/tree/master/magenta/models/arbitrary_image_stylization) and the publication: | |
[Exploring the structure of a real-time, arbitrary neural artistic stylization | |
network](https://arxiv.org/abs/1705.06830). | |
*Golnaz Ghiasi, Honglak Lee, | |
Manjunath Kudlur, Vincent Dumoulin, Jonathon Shlens*, | |
Proceedings of the British Machine Vision Conference (BMVC), 2017. | |
## Setup | |
Let's start with importing TF2 and all relevant dependencies. | |
""" | |
import functools | |
import os | |
from PIL import Image | |
from matplotlib import gridspec | |
import matplotlib.pylab as plt | |
import numpy as np | |
import tensorflow as tf | |
import tensorflow_hub as hub | |
import gradio as gr | |
# @title Define image loading and visualization functions { display-mode: "form" } | |
def crop_center(image):
    """Crop a batched image to a centered square.

    The square's side is the smaller of ``image.shape[1]`` and
    ``image.shape[2]``; the longer of those two axes is trimmed by the
    same amount on both ends (integer division, so an odd surplus leaves
    one extra pixel on the far side).

    Args:
        image: Image batch whose axes 1 and 2 are height and width, as
            expected by ``tf.image.crop_to_bounding_box``.

    Returns:
        The cropped image produced by ``tf.image.crop_to_bounding_box``.
    """
    height, width = image.shape[1], image.shape[2]
    side = min(height, width)
    # Only the longer axis gets a non-zero offset; the other stays at 0.
    top = max(height - width, 0) // 2
    left = max(width - height, 0) // 2
    return tf.image.crop_to_bounding_box(image, top, left, side, side)
def load_image(image, image_size=(256, 256), preserve_aspect_ratio=True):
    """Resize an already-decoded image to the requested size.

    Despite its name, this function no longer downloads or decodes
    anything: it only forwards ``image`` to ``tf.image.resize``.

    Args:
        image: Image data accepted by ``tf.image.resize``.
        image_size: Target ``(height, width)`` passed as the resize size.
        preserve_aspect_ratio: Forwarded to ``tf.image.resize``. When true
            the image is scaled to fit inside ``image_size`` without
            distortion; when false it is stretched to exactly that size.
            Previously this argument was ignored and ``True`` was always
            used.

    Returns:
        The resized image as returned by ``tf.image.resize``.
    """
    return tf.image.resize(
        image, image_size, preserve_aspect_ratio=preserve_aspect_ratio)
def show_n(images, titles=('',)):
    """Plot several batched images side by side in one matplotlib row.

    Args:
        images: Sequence of images; ``shape[1]`` of each is used to size
            its column and element ``[0]`` of each is what gets drawn.
        titles: Titles matched to images by position. Images without a
            corresponding entry get an empty title.
    """
    count = len(images)
    sizes = [img.shape[1] for img in images]
    # Figure scale is derived from the first image only.
    unit = (sizes[0] * 6) // 320
    plt.figure(figsize=(unit * count, unit))
    grid = gridspec.GridSpec(1, count, width_ratios=sizes)
    for idx, img in enumerate(images):
        plt.subplot(grid[idx])
        plt.imshow(img[0], aspect='equal')
        plt.axis('off')
        plt.title(titles[idx] if idx < len(titles) else '')
    plt.show()
"""Let's get as well some images to play with.""" | |
# @title Load example images { display-mode: "form" } | |
#content_image_url = 'https://live.staticflickr.com/65535/52032998695_f57c61746c_c.jpg' # @param {type:"string"} | |
#style_image_url = 'https://live.staticflickr.com/65535/52032731604_a815a0b19f_c.jpg' # @param {type:"string"} | |
output_image_size = 384  # @param {type:"integer"}

# The content image size can be arbitrary.
content_img_size = (output_image_size, output_image_size)
# The style prediction model was trained with image size 256 and it's the
# recommended image size for the style image (though, other sizes work as
# well but will lead to different results).
style_img_size = (256, 256)  # Recommended to keep it at 256.

# Image inputs for the Gradio app.
# The ``gr.inputs`` namespace was removed in Gradio 4, where referencing it
# raises AttributeError at import time. Keep using it when it exists (same
# behaviour as before) and fall back to the top-level component otherwise.
if hasattr(gr, "inputs"):
    content_image_input = gr.inputs.Image(label="Content Image")
    style_image_input = gr.inputs.Image(shape=style_img_size, label="Style Image")
else:
    content_image_input = gr.Image(label="Content Image")
    # NOTE(review): newer gr.Image has no ``shape`` argument, so the style
    # image reaches the model at its uploaded size on this path. The model
    # accepts that, but results may differ from the 256x256 setting.
    style_image_input = gr.Image(label="Style Image")
#content_image = load_image(content_image_input, content_img_size) | |
#style_image = load_image(style_image_input, style_img_size) | |
#style_image = tf.nn.avg_pool(style_image, ksize=[3,3], strides=[1,1], padding='SAME') | |
#show_n([content_image, style_image], ['Content image', 'Style image']) | |
"""## Import TF Hub module""" | |
# Handle of the pretrained arbitrary-image-stylization model on TF Hub.
hub_handle = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
# Loaded once at import time so every request reuses the same module.
hub_module = hub.load(hub_handle)
"""The signature of this hub module for image stylization is: | |
``` | |
outputs = hub_module(content_image, style_image) | |
stylized_image = outputs[0] | |
``` | |
Where `content_image`, `style_image`, and `stylized_image` are expected to be 4-D Tensors with shapes `[batch_size, image_height, image_width, 3]`. | |
In the current example we provide only single images and therefore the batch dimension is 1, but one can use the same module to process more images at the same time. | |
The input and output values of the images should be in the range [0, 1]. | |
The shapes of content and style image don't have to match. Output image shape | |
is the same as the content image shape. | |
## Demonstrate image stylization | |
""" | |
# Stylize content image with given style image. | |
# This is pretty fast: it takes only a few milliseconds on a GPU.
''' | |
def modify(imageinput,style_input): | |
content_image = load_image(imageinput, content_img_size) | |
style_image = load_image(style_input, style_img_size) | |
style_image = tf.nn.avg_pool(style_image, ksize=[3,3], strides=[1,1], padding='SAME') | |
#show_n([content_image, style_image], ['Content image', 'Style image']) | |
outputs = hub_module(tf.constant(imageinput), tf.constant(style_input)) | |
return outputs[0] | |
''' | |
def perform_style_transfer(content_image, style_image):
    """Stylize ``content_image`` with the look of ``style_image``.

    Both inputs are converted to float32 tensors, given a leading batch
    axis and divided by 255 before being passed to ``hub_module``. The
    first image of the module's first output is scaled back by 255 and
    returned as a PIL image.

    Args:
        content_image: Image to stylize, convertible by
            ``tf.convert_to_tensor`` (presumably an HxWx3 array with values
            in 0..255 as delivered by the Gradio input; verify against the
            input component configuration).
        style_image: Image providing the style, same format as
            ``content_image``.

    Returns:
        A ``PIL.Image.Image`` built from the stylized result.
    """
    # Same preprocessing for both images: float32, add batch axis, scale by 1/255.
    content_batch, style_batch = (
        tf.convert_to_tensor(raw, np.float32)[tf.newaxis, ...] / 255.
        for raw in (content_image, style_image)
    )
    stylized_batch = hub_module(content_batch, style_batch)[0]
    pixels = np.uint8(stylized_batch[0] * 255)
    return Image.fromarray(pixels)
#stylized_image = outputs[0] | |
# Visualize input images and the generated stylized image. | |
#show_n([content_image, style_image, stylized_image], titles=['Original content image', 'Style image', 'Stylized image']) | |
# Gradio app | |
#label = gr.outputs.Image(modify(content_image_input, style_image_input)) | |
# Output component for the stylized image. The ``gr.outputs`` namespace was
# removed in Gradio 4; use it when present (unchanged behaviour) and fall
# back to the top-level ``gr.Image`` component otherwise.
_stylized_image_output = gr.outputs.Image() if hasattr(gr, "outputs") else gr.Image()

# Wire the two image inputs to the style-transfer function.
# NOTE(review): the description mentions references "listed below", but no
# article/reference text is passed to the interface here.
app_interface = gr.Interface(
    perform_style_transfer,
    inputs=[content_image_input, style_image_input],
    outputs=_stylized_image_output,
    title="Fast Neural Style Transfer",
    description="Gradio demo for Fast Neural Style Transfer using a pretrained Image Stylization model from TensorFlow Hub. To use it, simply upload a content image and style image. To learn more about the project, please find the references listed below.",
)
# debug=True keeps the process attached and prints errors to the console.
app_interface.launch(debug=True)