# app.py | commit 4fd82b8 "Update app.py" | breynolds1247 | 6.29 kB
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
#Imports
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow_hub as hub
from PIL import Image
import gradio as gr
from helper_functions import *
# Style reference painting for each supported artist, keyed by the exact label
# the UI sends: (cache file name, download URL, max_dim handed to load_img).
_ARTIST_STYLES = {
    "Vincent van Gogh": (
        'Starry-Night-canvas-Vincent-van-Gogh-New-1889.jpg',
        'https://cdn.britannica.com/78/43678-050-F4DC8D93/Starry-Night-canvas-Vincent-van-Gogh-New-1889.jpg',
        442,
    ),
    "Claude Monet": (
        'Claude_Monet_-_Water_Lilies_-_Google_Art_Project_%28462013%29.jpg',
        'https://upload.wikimedia.org/wikipedia/commons/a/af/Claude_Monet_-_Water_Lilies_-_Google_Art_Project_%28462013%29.jpg',
        256,
    ),
    "Leonardo da Vinci": (
        'Leonardo_da_Vinci_-_Mona_Lisa_%28La_Gioconda%29_-_WGA12711.jpg',
        'https://upload.wikimedia.org/wikipedia/commons/f/f2/Leonardo_da_Vinci_-_Mona_Lisa_%28La_Gioconda%29_-_WGA12711.jpg',
        442,
    ),
    "Rembrandt": (
        '1259px-The_Nightwatch_by_Rembrandt_-_Rijksmuseum.jpg',
        'https://upload.wikimedia.org/wikipedia/commons/thumb/9/94/The_Nightwatch_by_Rembrandt_-_Rijksmuseum.jpg/1259px-The_Nightwatch_by_Rembrandt_-_Rijksmuseum.jpg',
        256,
    ),
    "Pablo Picasso": (
        'Picasso_The_Weeping_Woman_Tate_identifier_T05010_10.jpg',
        'https://upload.wikimedia.org/wikipedia/en/1/14/Picasso_The_Weeping_Woman_Tate_identifier_T05010_10.jpg',
        256,
    ),
    "Salvador Dali": (
        'The_Persistence_of_Memory.jpg',
        'https://upload.wikimedia.org/wikipedia/en/d/dd/The_Persistence_of_Memory.jpg',
        512,
    ),
}

# max_dim handed to load_img for the user's content image.
_CONTENT_MAX_DIM = 1080

# Magenta Arbitrary Image Stylization network on TensorFlow Hub.
_STYLIZATION_MODEL_URL = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/1'

# Holds the loaded hub module so it is fetched once instead of on every request.
_stylization_model = None


def _get_stylization_model():
    """Return the stylization hub module, loading it on first use."""
    global _stylization_model
    if _stylization_model is None:
        _stylization_model = hub.load(_STYLIZATION_MODEL_URL)
    return _stylization_model


def style_transfer(input_image, artist):
    """Render ``input_image`` in the style of the chosen artist's painting.

    Args:
        input_image: The content image; forwarded unchanged to ``load_img``
            (the interface below supplies a PIL image).
        artist: One of the keys of ``_ARTIST_STYLES``, e.g. ``"Claude Monet"``.

    Returns:
        The result of ``img_to_array`` applied to the first element of the
        model's first output.

    Raises:
        ValueError: If ``artist`` is ``None`` or not a supported artist name.
            Previously this case fell through every branch and failed later
            with an ``UnboundLocalError``.
    """
    try:
        style_fname, style_url, style_max_dim = _ARTIST_STYLES[artist]
    except (KeyError, TypeError):
        supported = ", ".join(_ARTIST_STYLES)
        raise ValueError(
            f"Unsupported artist {artist!r}; choose one of: {supported}"
        ) from None

    # Only the selected painting is resolved (get_file downloads or reuses its cache).
    style_path = keras.utils.get_file(style_fname, style_url)

    #load content and style images
    content_image = load_img(input_image, content=True, max_dim=_CONTENT_MAX_DIM)
    style_image = load_img(style_path, content=False, max_dim=style_max_dim)

    #Pass content and style images as arguments in TensorFlow Constant object format
    hub_module = _get_stylization_model()
    stylized_image = hub_module(tf.constant(content_image), tf.constant(style_image))[0]

    return tf.keras.preprocessing.image.img_to_array(stylized_image[0])
# Artist labels offered in the UI; style_transfer compares its `artist`
# argument against these exact strings.
artist_choices = [
    "Vincent van Gogh",
    "Claude Monet",
    "Leonardo da Vinci",
    "Rembrandt",
    "Pablo Picasso",
    "Salvador Dali",
]

# Build the web UI: an uploaded image plus an artist choice go in, one image comes out.
app = gr.Interface(
    fn=style_transfer,
    inputs=[gr.Image(type='pil'), gr.Radio(artist_choices)],
    outputs=gr.Image(type='pil'),
    title="Artist Style Transfer Tool",
    description="Make your own art in the style of six famous artists using pretrained neural networks and deep learning!",
    #article="https://arxiv.org/abs/1705.06830"
)

# Start serving the interface.
app.launch()
"""
def inference(input_image):
preprocess = transforms.Compose([
transforms.Resize(260),
transforms.CenterCrop(224),
transforms.ToTensor(),
#transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
# move the input and model to GPU for speed if available
if torch.cuda.is_available():
input_batch = input_batch.to('cuda')
model.to('cuda')
else:
model.to('cpu')
with torch.no_grad():
output = model(input_batch)
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# Read the categories
with open("artist_classes.txt", "r") as f:
categories = [s.strip() for s in f.readlines()]
categories = {
0:"vanGogh",
1:"Monet",
2:"Leonardo da Vinci",
3:"Rembrandt",
4:"Pablo Picasso",
5:"Salvador Dali"
}
# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 6)
result = {}
for i in range(top5_prob.size(0)):
result[categories[top5_catid[i].item()]] = top5_prob[i].item()
return result"""
##inputs = gr.Image(type='pil')
##outputs = gr.Label(type="confidences",num_top_classes=5)
##title = "Artist Classifier"
##description = "Gradio demo for MOBILENET V2, Efficient networks optimized for speed and memory, with residual blocks. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
##article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1801.04381'>MobileNetV2: Inverted Residuals and Linear Bottlenecks</a> | <a href='https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py'>Github Repo</a></p>"
# NOTE: The triple-quoted string below is also disabled code: a bare string
# literal holding a minimal text-in/text-out Gradio `greet` demo. It is never
# executed.
"""
def greet(name):
return "Hello " + name + "!!"
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
"""
#examples = [
# ['dog.jpg']
#]
#gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()
#gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, analytics_enabled=False).launch()