File size: 5,062 Bytes
8769c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e23189
 
 
 
 
b924eac
8769c8a
 
5e23189
8769c8a
 
 
 
9d87f5f
8769c8a
 
 
 
9d87f5f
8769c8a
 
a1d9395
 
 
 
8769c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e23189
8769c8a
 
 
9d87f5f
 
 
8769c8a
9d87f5f
8769c8a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import tensorflow as tf
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
import numpy as np
import PIL.Image
import gradio as gr
import tensorflow_hub as hub
import matplotlib.pyplot as plt
from real_esrgan_app import *

# NOTE(review): the bare string below is only a reminder of the
# `inference(img, mode)` helper brought in by `from real_esrgan_app import *`;
# perform_neural_transfer calls it for the optional Real-ESRGAN upscaling.
'''
inference(img,mode)
'''

# Arbitrary-image-stylization model from TF-Hub, loaded once at import time
# from the URL below. It is bound as the default `hub_module` argument of
# perform_neural_transfer.
hub_module = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')

def tensor_to_image(tensor):
    """Convert an image tensor with values in ``[0, 1]`` to a ``PIL.Image``.

    Args:
        tensor: Array-like (NumPy array or anything ``np.asarray`` accepts,
            such as an eager TF tensor) holding one image, presumably with
            values in ``[0, 1]``. A leading batch axis of size 1 is allowed
            when the input has more than 3 dimensions.

    Returns:
        A ``PIL.Image.Image`` built from the values scaled to ``0..255``
        and stored as uint8.

    Raises:
        ValueError: If the input has more than 3 dimensions and its leading
            (batch) axis is not of size 1.
    """
    # Scale in float and clip before the uint8 cast: casting out-of-range
    # floats (or multiplying uint8 input by 255) would otherwise produce
    # wrapped or platform-dependent pixel values instead of saturating.
    scaled = np.asarray(tensor, dtype=np.float32) * 255.0
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    if pixels.ndim > 3:
        # Drop a leading batch axis of size 1.
        if pixels.shape[0] != 1:
            raise ValueError(
                f"expected a batch of exactly one image, got shape {pixels.shape}"
            )
        pixels = pixels[0]
    return PIL.Image.fromarray(pixels)


# Display name -> image file for each bundled style reference.
# NOTE: despite the name, the values are relative file paths (they are passed
# directly to gr.Gallery, gr.Image and gr.Examples), not URLs.
style_urls = {
    'Kanagawa great wave': 'The_Great_Wave_off_Kanagawa.jpg',
    'Kandinsky composition 7': 'Kandinsky_Composition_7.jpg',
    'Hubble pillars of creation': 'Pillars_of_creation_2014_HST_WFC3-UVIS_full-res_denoised.jpg',
    'Van gogh starry night': 'Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg',
    'Turner nantes': 'JMW_Turner_-_Nantes_from_the_Ile_Feydeau.jpg',
    'Munch scream': 'Edvard_Munch.jpg',
    'Picasso demoiselles avignon': 'Les_Demoiselles.jpg',
    'Picasso violin': 'picaso_violin.jpg',
    'Picasso bottle of rum': 'picaso_rum.jpg',
    'Fire': 'Large_bonfire.jpg',
    'Derkovits woman head': 'Derkovits_Gyula_Woman_head_1922.jpg',
    'Amadeo style life': 'Amadeo_Souza_Cardoso.jpg',
    'Derkovtis talig': 'Derkovits_Gyula_Talig.jpg',
    'Kadishman': 'kadishman.jpeg'
}


# Display names of the bundled styles, in insertion order. Iterating a dict
# yields its keys, so there is no need to unpack .items().
# NOTE(review): not referenced by the active UI code in this file (only by a
# commented-out gr.Radio).
style_images = list(style_urls)

def image_click(images, evt: gr.SelectData):
    """Return the image name of the gallery entry the user clicked.

    ``evt`` is the Gradio select-event payload (injected via the
    ``gr.SelectData`` annotation); ``evt.index`` is the position of the
    clicked entry in ``images``.

    NOTE(review): assumes each gallery entry is a dict carrying the image
    path under ``"name"`` — this depends on the Gradio version; confirm.
    """
    clicked_entry = images[evt.index]
    return clicked_entry["name"]


#radio_style = gr.Radio(style_images, label="Choose Style")

def perform_neural_transfer(content_image_input, style_image_input, super_resolution_type, hub_module = hub_module):
    """Stylize the content image with the style image, optionally upscaling it.

    Args:
        content_image_input: Content image from the ``gr.Image`` component,
            presumably an ``(H, W, 3)`` uint8 array with values in 0..255
            (TODO confirm against the Gradio version in use).
        style_image_input: Style image in the same format as the content image.
        super_resolution_type: Value of the Real-ESRGAN radio button. ``"base"``
            or ``"anime"`` passes the result through ``inference``; any other
            value (e.g. ``"none"``) returns the stylized image as is.
        hub_module: Stylization model callable; defaults to the module-level
            TF-Hub model.

    Returns:
        The stylized ``PIL.Image`` resized to the content image's size, or
        whatever ``inference`` returns when super resolution is requested.

    Raises:
        gr.Error: If the content or style image is missing.
    """
    # Gradio passes None for an empty image component; fail with a
    # user-visible message instead of an AttributeError on None.astype.
    if content_image_input is None or style_image_input is None:
        raise gr.Error("Please provide both a content image and a style image.")

    # Normalize to float32 in [0, 1] with a leading batch axis, then resize:
    # content to a fixed 400x600 working size, style to 256x256.
    content_image = content_image_input.astype(np.float32)[np.newaxis, ...] / 255.
    content_image = tf.image.resize(content_image, (400, 600))

    style_image = style_image_input.astype(np.float32)[np.newaxis, ...] / 255.
    style_image = tf.image.resize(style_image, (256, 256))

    outputs = hub_module(tf.constant(content_image), tf.constant(style_image))
    stylized_image = tensor_to_image(outputs[0])

    # Restore the original content resolution. PIL sizes are (width, height)
    # while NumPy shapes are (height, width, ...). Reading the shape directly
    # avoids pushing the uint8 input through tensor_to_image (a *255 that
    # overflows) just to learn its size.
    height, width = content_image_input.shape[:2]
    stylized_image = stylized_image.resize((width, height))

    print("super_resolution_type :", super_resolution_type)

    if super_resolution_type not in ("base", "anime"):
        return stylized_image
    return inference(stylized_image, super_resolution_type)

with gr.Blocks() as demo:
    gr.HTML("<h1><center> 🐑 Art Generation with Neural Style Transfer Fixed by Real-ESRGAN </center></h1>")

    with gr.Row():
        # Left: clickable gallery of the bundled style images.
        style_reference_input_gallery = gr.Gallery(
            list(style_urls.values()),
            height=768 + 128,
            label="Style Image gallery (click to use)",
        )
        # Right: controls, inputs and the output image.
        with gr.Column():
            super_resolution_type = gr.Radio(
                choices=["base", "anime", "none"],
                value="base",
                label="choose Real-ESRGAN model type used to super resolution the Image Transformed",
                interactive=True,
            )
            style_reference_input_image = gr.Image(
                label="Style Image (you can upload yourself or click from left gallery)",
                interactive=True,
                value=style_urls["Kanagawa great wave"],
            )
            content_image_input = gr.Image(label="Content Image", interactive=True)
            trans_image_output = gr.Image(label="Image Transformed", interactive=True)
            # A gr.Button's caption is its `value`; passing the text as `label`
            # left the button showing the default caption instead.
            trans_button = gr.Button(value="transform Content image style from Style Image")

    # Clicking a gallery thumbnail copies that image into the style input.
    style_reference_input_gallery.select(
        image_click, style_reference_input_gallery, style_reference_input_image
    )

    trans_button.click(
        perform_neural_transfer,
        [content_image_input, style_reference_input_image, super_resolution_type],
        trans_image_output,
    )

    # Each example row fills: [style image, content image, super-resolution type].
    gr.Examples(
        [
            [style_urls["Kanagawa great wave"], style_urls["Kadishman"], "none"],
            [style_urls["Derkovits woman head"], style_urls["Kadishman"], "base"],
            [style_urls["Kadishman"], style_urls["Kadishman"], "anime"],
        ],
        inputs=[style_reference_input_image, content_image_input, super_resolution_type],
        label="Transform Examples",
    )

demo.launch()