File size: 1,251 Bytes
a7f8f41
 
 
ec30580
 
a7f8f41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gradio as gr
import tensorflow as tf

# Load the two saved Keras models used for inference. compile=False skips
# restoring any training configuration (optimizer/loss) stored with them.
encoder, decoder = (
    tf.keras.models.load_model(weights_path, compile=False)
    for weights_path in ("nst-encoder.h5", "nst-decoder.h5")
)

def get_mean_std(tensor, epsilon=1e-5):
    """Return the mean and standard deviation of *tensor* over axes 1 and 2.

    Both statistics are computed with ``keepdims=True``, so the reduced axes
    are retained with size 1 and the results broadcast against *tensor*.

    Args:
        tensor: Input tensor; presumably a (batch, height, width, channels)
            feature map, so axes 1 and 2 are the spatial ones -- TODO confirm.
        epsilon: Small constant added to the variance before the square root
            so the standard deviation is never exactly zero.

    Returns:
        A ``(mean, std)`` tuple of tensors.
    """
    mean, variance = tf.nn.moments(tensor, axes=[1, 2], keepdims=True)
    return mean, tf.sqrt(variance + epsilon)

def ada_in(style, content, epsilon=1e-5):
    """Re-normalize *content* so its statistics match those of *style*.

    The content tensor is standardized with its own mean and standard
    deviation, then scaled and shifted by the style tensor's standard
    deviation and mean (adaptive instance normalization).

    Args:
        style: Tensor whose mean/std are transferred onto the result.
        content: Tensor to be re-normalized.
        epsilon: Numerical-stability constant forwarded to ``get_mean_std``
            for both tensors; it keeps the content standard deviation used
            as a divisor away from zero.

    Returns:
        The re-normalized tensor, broadcast-compatible with *content*.
    """
    # Forward epsilon explicitly: previously the argument was accepted but
    # ignored, so callers could not actually tune the stabilizer.
    c_mean, c_std = get_mean_std(content, epsilon=epsilon)
    s_mean, s_std = get_mean_std(style, epsilon=epsilon)
    return s_std * (content - c_mean) / c_std + s_mean

def load_resize(image):
    """Convert *image* to float32 and resize it to 224x224.

    Args:
        image: Image data accepted by ``tf.image.convert_image_dtype``
            (presumably the array Gradio passes in -- verify against caller).

    Returns:
        A float32 tensor whose two resized dimensions are 224 x 224.
    """
    as_float = tf.image.convert_image_dtype(image, dtype="float32")
    return tf.image.resize(as_float, (224, 224))

def infer(style, content):
    """Run style transfer and return the stylized image as a NumPy array.

    Both inputs are preprocessed with ``load_resize``, given a leading batch
    axis, encoded, blended with ``ada_in`` and decoded back to an image.

    Args:
        style: Image supplying the style.
        content: Image supplying the content.

    Returns:
        The first (only) element of the decoder output, converted with
        ``.numpy()``.
    """
    # Preprocess each image and prepend a batch axis of size 1.
    style_batch = load_resize(style)[tf.newaxis, ...]
    content_batch = load_resize(content)[tf.newaxis, ...]

    # Encode both images with the shared encoder.
    style_features = encoder(style_batch)
    content_features = encoder(content_batch)

    # Align content feature statistics with the style features.
    stylized_features = ada_in(style=style_features, content=content_features)

    decoded = decoder(stylized_features)
    # Drop the batch axis before handing the result back to the UI.
    return decoded[0].numpy()

# Build the web UI: two image inputs (style, content) feeding ``infer`` and a
# single image output. Construction is kept separate from launching so that
# ``iface`` refers to the Interface object itself rather than to whatever
# ``launch()`` returns.
iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.inputs.Image(label="style"),
        gr.inputs.Image(label="content"),
    ],
    outputs="image",
)
iface.launch()