File size: 1,481 Bytes
a7f8f41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import tensorflow as tf
import wget

# Hugging Face serves a file's raw bytes from `/resolve/<revision>/...`;
# `/blob/<revision>/...` is the HTML file-viewer page, which is not an
# HDF5 model and cannot be loaded by Keras.
enc_url = 'https://huggingface.co/ariG23498/nst/resolve/main/nst-encoder.h5'
enc_filename = wget.download(enc_url)  # local path the file was saved to

dec_url = 'https://huggingface.co/ariG23498/nst/resolve/main/nst-decoder.h5'
dec_filename = wget.download(dec_url)

# compile=False: the models are only called for inference in this file,
# so optimizer/loss configuration is not needed.
encoder = tf.keras.models.load_model(enc_filename, compile=False)
decoder = tf.keras.models.load_model(dec_filename, compile=False)

def get_mean_std(tensor, epsilon=1e-5):
    """Return the mean and standard deviation of ``tensor`` over axes 1 and 2.

    Both statistics keep the reduced axes (``keepdims=True``) so they
    broadcast back against ``tensor``. ``epsilon`` is added to the variance
    before the square root so the returned std is strictly positive.

    Assumes NHWC layout, i.e. axes 1 and 2 are height and width and the
    statistics are per sample and per channel — TODO confirm against the
    encoder's output layout.
    """
    spatial_axes = [1, 2]
    mean, variance = tf.nn.moments(tensor, axes=spatial_axes, keepdims=True)
    return mean, tf.sqrt(variance + epsilon)

def ada_in(style, content, epsilon=1e-5):
    """Apply Adaptive Instance Normalization (AdaIN) to ``content``.

    Re-normalizes ``content`` so that its mean/std over axes 1 and 2 match
    those of ``style``:

        t = s_std * (content - c_mean) / c_std + s_mean

    Args:
        style: style feature tensor.
        content: content feature tensor.
        epsilon: variance stabilizer passed through to ``get_mean_std``.

    Returns:
        The re-normalized tensor; shaped like ``content`` when the style
        statistics broadcast against it.
    """
    # Forward ``epsilon`` so the caller's value actually takes effect
    # instead of always falling back to get_mean_std's own default.
    c_mean, c_std = get_mean_std(content, epsilon=epsilon)
    s_mean, s_std = get_mean_std(style, epsilon=epsilon)
    return s_std * (content - c_mean) / c_std + s_mean

def load_resize(image):
    """Convert ``image`` to float32 and resize it to 224x224.

    ``tf.image.convert_image_dtype`` rescales integer inputs into [0, 1]
    (e.g. uint8 0..255 -> 0.0..1.0); float inputs are cast as-is.
    """
    as_float = tf.image.convert_image_dtype(image, dtype="float32")
    return tf.image.resize(as_float, (224, 224))

def infer(style, content):
    """Render ``content`` in the style of ``style``.

    Each image is preprocessed with ``load_resize``, given a leading batch
    axis, and passed through the encoder. The content features are then
    re-normalized to the style statistics with AdaIN and decoded.

    Args:
        style: style image array as supplied by the Gradio image input.
        content: content image array as supplied by the Gradio image input.

    Returns:
        The decoded image (batch axis removed) as a NumPy array with values
        clipped to [0, 1].
    """
    style_batch = load_resize(style)[tf.newaxis, ...]
    content_batch = load_resize(content)[tf.newaxis, ...]

    style_enc = encoder(style_batch)
    content_enc = encoder(content_batch)

    t = ada_in(style=style_enc, content=content_enc)
    recons_image = decoder(t)

    # Nothing in this file bounds the decoder's output (its final activation
    # is not visible here); clip so the result is a valid [0, 1] float image
    # when the image output converts it for display.
    return tf.clip_by_value(recons_image[0], 0.0, 1.0).numpy()

# Keep ``iface`` bound to the Interface itself and launch it separately:
# ``launch()`` returns server handles (app / URLs), not the Interface.
iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.inputs.Image(label="style"),
        gr.inputs.Image(label="content"),
    ],
    outputs="image",
)
iface.launch()