File size: 1,852 Bytes
6a8dfee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import numpy as np
import tensorflow as tf
from module import encoder, decoder
from glob import glob
import runway


@runway.setup(options={"styleCheckpoint": runway.file(is_directory=True)})
def setup(opts):
    """Build the stylization graph and restore its weights from a checkpoint.

    ``opts["styleCheckpoint"]`` is a directory whose first sub-directory (in
    sorted order) must contain a ``checkpoint_long`` folder with a TensorFlow
    checkpoint.

    Returns:
        dict with ``sess`` (the ``tf.Session``), ``input_photo`` (the input
        placeholder) and ``output_photo`` (the decoder output tensor), which
        the ``stylize`` command consumes.

    Raises:
        FileNotFoundError: if no model sub-directory or no checkpoint is found.
    """
    with tf.name_scope("placeholder"):
        # Batch of exactly one 3-channel image of any height/width.
        input_photo = tf.placeholder(
            dtype=tf.float32, shape=[1, None, None, 3], name="photo"
        )
    input_photo_features = encoder(
        image=input_photo, options={"gf_dim": 32}, reuse=False
    )
    output_photo = decoder(
        features=input_photo_features, options={"gf_dim": 32}, reuse=False
    )

    sess = tf.Session()
    # Create the initializer only after the graph exists so it actually covers
    # the encoder/decoder variables; the restore below then overwrites every
    # variable stored in the checkpoint.
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    path = opts["styleCheckpoint"]
    # Sort so the chosen model is deterministic (os.listdir order is arbitrary).
    model_dirs = sorted(
        p for p in os.listdir(path) if os.path.isdir(os.path.join(path, p))
    )
    if not model_dirs:
        raise FileNotFoundError(
            "No model sub-directory found in {!r}".format(path)
        )
    checkpoint_dir = os.path.join(path, model_dirs[0], "checkpoint_long")

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError(
            "No TensorFlow checkpoint found in {!r}".format(checkpoint_dir)
        )
    # The recorded checkpoint path may be absolute on the machine that trained
    # the model; keep only the file name and resolve it inside checkpoint_dir.
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
    return dict(sess=sess, input_photo=input_photo, output_photo=output_photo)


@runway.command(
    "stylize",
    inputs={"contentImage": runway.image},
    outputs={"stylizedImage": runway.image},
)
def stylize(model, inp):
    """Stylize one content image with the network restored by ``setup``.

    Args:
        model: the dict returned by ``setup`` (``sess``, ``input_photo``,
            ``output_photo``).
        inp: Runway inputs; ``inp["contentImage"]`` is converted with
            ``np.array`` and fed to a ``[1, None, None, 3]`` placeholder, so it
            must yield an H x W x 3 array.

    Returns:
        dict with ``stylizedImage``: the first (only) batch element of the
        network output as a ``uint8`` array.
    """
    pixels = np.array(inp["contentImage"])
    # Map [0, 255] -> [-1, 1] and add the batch axis the placeholder expects.
    batch = np.expand_dims(pixels / 127.5 - 1.0, axis=0)
    result = model["sess"].run(
        model["output_photo"], feed_dict={model["input_photo"]: batch}
    )
    # Map back to [0, 255]. Clip before casting: out-of-range floats would
    # otherwise wrap around (e.g. 256 -> 0) instead of saturating, and round
    # rather than truncate. NOTE(review): the decoder's actual output range is
    # defined in `module`, not visible here.
    restored = np.clip((result[0] + 1.0) * 127.5, 0.0, 255.0)
    return dict(stylizedImage=np.rint(restored).astype("uint8"))


if __name__ == "__main__":
    # Default the Runway server port (RW_PORT) to 7860, but keep any value the
    # deployment environment already set instead of overwriting it.
    os.environ.setdefault("RW_PORT", "7860")
    runway.run()