Shue committed on
Commit 421360e
1 Parent(s): 232aa2a

added Animefy web app files.

Files changed (6)
  1. adjust_brightness.py +69 -0
  2. app.py +162 -0
  3. generator.py +161 -0
  4. requirements.txt +3 -0
  5. style.py +39 -0
  6. utils.py +47 -0
adjust_brightness.py ADDED
@@ -0,0 +1,69 @@
+ import numpy as np
+ import cv2
+ from PIL import Image
+
+
+ # def read_img(image_path):
+ #     img = cv2.imread(image_path)
+ #     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ #     assert len(img.shape) == 3
+ #     return img
+
+ def read_img(image_file_buffer):
+     img = Image.open(image_file_buffer).convert('RGB')
+     img = np.array(img)
+     assert len(img.shape) == 3
+     return img
+
+ # Calculates the average brightness of the given image
+ def calculate_average_brightness(img):
+     # Average value of the three color channels
+     R = img[..., 0].mean()
+     G = img[..., 1].mean()
+     B = img[..., 2].mean()
+
+     brightness = 0.299 * R + 0.587 * G + 0.114 * B
+     return brightness, B, G, R
+
+ # Adjusts the average brightness of the target image to the average brightness of the source image
+ def adjust_brightness_from_src_to_dst(dst, src, path=None, if_show=None, if_info=None):
+     brightness1, B1, G1, R1 = calculate_average_brightness(src)
+     brightness2, B2, G2, R2 = calculate_average_brightness(dst)
+     brightness_difference = brightness1 / brightness2
+
+     if if_info:
+         print('Average brightness of original image', brightness1)
+         print('Average brightness of target', brightness2)
+         print('Brightness difference between original image and target', brightness_difference)
+
+     # Scale according to the average display brightness
+     dstf = dst * brightness_difference
+
+     # Alternative: scale according to the average value of each color channel
+     # dstf = dst.copy().astype(np.float32)
+     # dstf[..., 0] = dst[..., 0] * (B1 / B2)
+     # dstf[..., 1] = dst[..., 1] * (G1 / G2)
+     # dstf[..., 2] = dst[..., 2] * (R1 / R2)
+
+     # To keep the result in range, the array must be clipped and converted to uint8;
+     # otherwise it stays float32 and errors will occur downstream.
+     dstf = np.clip(dstf, 0, 255)
+     dstf = np.uint8(dstf)
+
+     ma, na, _ = src.shape
+     mb, nb, _ = dst.shape
+     result_show_img = np.zeros((max(ma, mb), 3 * max(na, nb), 3))
+     result_show_img[:mb, :nb, :] = dst
+     result_show_img[:ma, nb:nb + na, :] = src
+     result_show_img[:mb, nb + na:nb + na + nb, :] = dstf
+     result_show_img = result_show_img.astype(np.uint8)
+
+     if if_show:
+         cv2.imshow('-', cv2.cvtColor(result_show_img, cv2.COLOR_BGR2RGB))
+         cv2.waitKey(0)
+         cv2.destroyAllWindows()
+
+     if path is not None:
+         cv2.imwrite(path, cv2.cvtColor(result_show_img, cv2.COLOR_BGR2RGB))
+
+     return dstf
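For reference, a minimal sketch of how these helpers fit together (the file name and the stand-in array below are illustrative only; the real call site is in style.py further down):

    import numpy as np
    from adjust_brightness import read_img, adjust_brightness_from_src_to_dst

    src = read_img("photo.jpg")  # hypothetical input; read_img accepts anything PIL.Image.open can read
    stylized = np.uint8(np.random.rand(256, 256, 3) * 255)  # stand-in for a stylized RGB frame
    result = adjust_brightness_from_src_to_dst(stylized, src)  # stylized content, source photo's average brightness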
app.py ADDED
@@ -0,0 +1,162 @@
+ import streamlit as st
+ from PIL import Image
+ import cv2 as cv
+ import numpy as np
+ from random import randint
+ import threading
+
+ # style.py
+ import style
+
+ st.set_page_config(page_title="Animefy", page_icon="images/animefy_logo.png")
+
+ model_lock = threading.Lock()
+
+ # remove "Made with Streamlit" footer text
+ # uncomment "#MainMenu {visibility: hidden;}" to also remove the default Streamlit hamburger menu
+ hide_streamlit = """
+ <style>
+ #MainMenu {
+     visibility: hidden;
+ }
+ footer {
+     visibility: hidden;
+ }
+ header {
+     visibility: hidden;
+ }
+ </style>
+ """
+ st.markdown(hide_streamlit, unsafe_allow_html=True)
+
+ # randomizer: a workaround for clearing the contents of the file_uploader
+ if 'uploader_key' not in st.session_state:
+     st.session_state['uploader_key'] = str(randint(1000, 100000000))
+
+ # home page title and caption
+ st.markdown("""
+ # 📸 Animefy
+
+ Convert your photos into **anime** with _ease_ using **AnimeGAN**.
+ """)
+
+ st.markdown('---')
+
+ # main container of the page
+ page_container = st.empty()
+
+ # store home page contents inside page_container
+ home_page = page_container.container()
+
+ # step #1
+ home_page.markdown("""
+ ### Step #1: Upload the photo that you would like to process!
+ """)
+
+ # just some notes for the user
+ with home_page.expander("📣 Here are some things to take note of...", expanded=True):
+     st.write("""
+     * Do note that AnimeGAN works best with images containing **sceneries without people**.
+     * For best results, use images that **do not** contain human subjects.
+     * Due to server hardware limitations, only upload images with **at most** a resolution of **1980x1080**.
+     * For more information on AnimeGAN, click [here](https://github.com/TonyLianLong/AnimeGAN.js).
+     """)
+
+ # upload image functionality
+ uploaded_image = home_page.file_uploader(
+     "If you're ready, you can now upload your image here:", type=['png', 'jpg', 'jpeg'], key=st.session_state['uploader_key']
+ )
+
+ # if there is an uploaded image, show next steps
+ if uploaded_image is not None:
+     # just a preview of the uploaded image
+     home_page.markdown("""
+     #### Uploaded Image
+
+     Here's your photo! Just upload another one if you would like to change it 😉
+     """)
+     home_page.image(Image.open(uploaded_image))
+
+     home_page.write("---")
+
+     # step #2
+     home_page.markdown("""
+     ### Step #2: Now, select your preferred animation style!
+     """)
+
+     # drop-down list for the anime style to be applied to the image
+     anime_style = home_page.selectbox(
+         'Your preferred animation style:',
+         ('Paprika', 'Shinkai', 'Hayao')
+     )
+
+     # just some more notes for the user
+     with home_page.expander("🤔 What are these animation styles?", expanded=False):
+         st.write("""
+         These styles were derived from the works of various directors! Some of these might be familiar to you:
+         * Satoshi Kon: **Paprika**
+         * Makoto **Shinkai**: Your Name, 5 Centimeters per Second, Weathering with You
+         * **Hayao** Miyazaki: Spirited Away, My Neighbor Totoro, Princess Mononoke
+         """)
+
+     home_page.write("---")
+
+     # stylize image
+     home_page.markdown("If you're all set, then let's proceed! 😄")
+     stylize_btn = home_page.button("Stylize!")
+
+     # if the "Stylize!" button is clicked
+     if stylize_btn:
+         # clear the home page contents
+         page_container.empty()
+
+         with st.spinner('Hold on... Please do not close this tab....'):
+             model_lock.acquire()
+
+         # spinner (while processing image)
+         with st.spinner('Hold on... Processing your image...'):
+             # stylize input image and produce output
+             output_image = style.stylize(anime_style, uploaded_image)
+             model_lock.release()
+
+         # step #3
+         st.markdown("""
+         ### Step #3: Download your image!
+         """)
+
+         # display original and output images
+         st.markdown("""
+         Here's a before and after!
+         """)
+         before_col, after_col = st.columns(2)
+         with before_col:
+             # clamp and channels are used since OpenCV was used in processing the image
+             st.image(uploaded_image, clamp=True, channels='RGB')
+         with after_col:
+             # clamp and channels are used since OpenCV was used in processing the image
+             st.image(output_image, clamp=True, channels='RGB')
+
+         st.write("---")
+
+         # prepare output image for downloading
+         img_encode = cv.imencode('.jpg', output_image)[1]
+         data_encode = np.array(img_encode)
+         byte_encode = data_encode.tobytes()
+
+         # some instructions for downloading
+         st.write("Finally, just click this to download your _anime-fied_ image!")
+         # download button
+         st.download_button('Download Image', byte_encode, 'output.jpg', 'jpg')
+
+         st.write("---")
+
+         # retry message
+         st.markdown('Not satisfied? Click this to retry!')
+         # retry button
+         retry_btn = st.button("Retry!")
+
+         # randomizer: just another workaround.
+         st.session_state['uploader_key'] = str(randint(1000, 100000000))
+
+         if retry_btn:
+             page_container.empty()
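For local testing, the app would presumably be launched with "streamlit run app.py", the standard Streamlit CLI entry point.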
generator.py ADDED
@@ -0,0 +1,161 @@
+ import tensorflow.contrib as tf_contrib
+ import tensorflow as tf
+
+ def layer_norm(x, scope='layer_norm'):
+     return tf_contrib.layers.layer_norm(x,
+                                         center=True, scale=True,
+                                         scope=scope)
+
+ def lrelu(x, alpha=0.2):
+     return tf.nn.leaky_relu(x, alpha)
+
+ def Conv2D(inputs, filters, kernel_size=3, strides=1, padding='VALID', Use_bias=None):
+     if kernel_size == 3 and strides == 1:
+         inputs = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
+     if kernel_size == 7 and strides == 1:
+         inputs = tf.pad(inputs, [[0, 0], [3, 3], [3, 3], [0, 0]], mode="REFLECT")
+     if strides == 2:
+         inputs = tf.pad(inputs, [[0, 0], [0, 1], [0, 1], [0, 0]], mode="REFLECT")
+     return tf.contrib.layers.conv2d(
+         inputs,
+         num_outputs=filters,
+         kernel_size=kernel_size,
+         stride=strides,
+         weights_initializer=tf.contrib.layers.variance_scaling_initializer(),
+         biases_initializer=Use_bias,
+         normalizer_fn=None,
+         activation_fn=None,
+         padding=padding)
+
+
+ def Conv2DNormLReLU(inputs, filters, kernel_size=3, strides=1, padding='VALID', Use_bias=None):
+     x = Conv2D(inputs, filters, kernel_size, strides, padding=padding, Use_bias=Use_bias)
+     x = layer_norm(x, scope=None)
+     return lrelu(x)
+
+ def dwise_conv(input, k_h=3, k_w=3, channel_multiplier=1, strides=[1, 1, 1, 1],
+                padding='VALID', name='dwise_conv', bias=True):
+     input = tf.pad(input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
+     with tf.variable_scope(name):
+         in_channel = input.get_shape().as_list()[-1]
+         w = tf.get_variable('w', [k_h, k_w, in_channel, channel_multiplier], regularizer=None, initializer=tf.contrib.layers.variance_scaling_initializer())
+         conv = tf.nn.depthwise_conv2d(input, w, strides, padding, rate=None, name=name, data_format=None)
+         if bias:
+             biases = tf.get_variable('bias', [in_channel * channel_multiplier], initializer=tf.constant_initializer(0.0))
+             conv = tf.nn.bias_add(conv, biases)
+         return conv
+
+
+ def Unsample(inputs, filters, kernel_size=3):
+     '''
+     An alternative to transposed convolution where we first resize, then convolve.
+     See http://distill.pub/2016/deconv-checkerboard/
+     For some reason the shape needs to be statically known for gradient propagation
+     through tf.image.resize_images, but we only know that for fixed image size, so we
+     plumb through a "training" argument
+     '''
+     new_H, new_W = 2 * tf.shape(inputs)[1], 2 * tf.shape(inputs)[2]
+     inputs = tf.image.resize_images(inputs, [new_H, new_W])
+
+     return Conv2DNormLReLU(filters=filters, kernel_size=kernel_size, inputs=inputs)
+
+
+ class G_net(object):
+
+
+     def __init__(self, inputs):
+
+         with tf.variable_scope('G_MODEL'):
+
+             with tf.variable_scope('A'):
+                 inputs = Conv2DNormLReLU(inputs, 32, 7)
+                 inputs = Conv2DNormLReLU(inputs, 64, strides=2)
+                 inputs = Conv2DNormLReLU(inputs, 64)
+
+             with tf.variable_scope('B'):
+                 inputs = Conv2DNormLReLU(inputs, 128, strides=2)
+                 inputs = Conv2DNormLReLU(inputs, 128)
+
+             with tf.variable_scope('C'):
+                 inputs = Conv2DNormLReLU(inputs, 128)
+                 inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r1')
+                 inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r2')
+                 inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r3')
+                 inputs = self.InvertedRes_block(inputs, 2, 256, 1, 'r4')
+                 inputs = Conv2DNormLReLU(inputs, 128)
+
+             with tf.variable_scope('D'):
+                 inputs = Unsample(inputs, 128)
+                 inputs = Conv2DNormLReLU(inputs, 128)
+
+             with tf.variable_scope('E'):
+                 inputs = Unsample(inputs, 64)
+                 inputs = Conv2DNormLReLU(inputs, 64)
+                 inputs = Conv2DNormLReLU(inputs, 32, 7)
+             with tf.variable_scope('out_layer'):
+                 out = Conv2D(inputs, filters=3, kernel_size=1, strides=1)
+                 self.fake = tf.tanh(out)
+
+
+     def InvertedRes_block(self, input, expansion_ratio, output_dim, stride, name, reuse=False, bias=None):
+         with tf.variable_scope(name, reuse=reuse):
+             # pw
+             bottleneck_dim = round(expansion_ratio * input.get_shape().as_list()[-1])
+             net = Conv2DNormLReLU(input, bottleneck_dim, kernel_size=1, Use_bias=bias)
+
+             # dw
+             net = dwise_conv(net, name=name)
+             net = layer_norm(net, scope='1')
+             net = lrelu(net)
+
+             # pw & linear
+             net = Conv2D(net, output_dim, kernel_size=1)
+             net = layer_norm(net, scope='2')
+
+             # element-wise add, only for stride==1
+             if (int(input.get_shape().as_list()[-1]) == output_dim) and stride == 1:
+                 net = input + net
+
+             return net
+
+ def Downsample(inputs, filters=256, kernel_size=3):
+     '''
+     An alternative to transposed convolution where we first resize, then convolve.
+     See http://distill.pub/2016/deconv-checkerboard/
+     For some reason the shape needs to be statically known for gradient propagation
+     through tf.image.resize_images, but we only know that for fixed image size, so we
+     plumb through a "training" argument
+     '''
+
+     new_H, new_W = tf.shape(inputs)[1] // 2, tf.shape(inputs)[2] // 2
+     inputs = tf.image.resize_images(inputs, [new_H, new_W])
+
+     return Separable_conv2d(filters=filters, kernel_size=kernel_size, inputs=inputs)
+
+ def Conv2DTransposeLReLU(inputs, filters, kernel_size=2, strides=2, padding='SAME', Use_bias=None):
+
+     return tf.contrib.layers.conv2d_transpose(inputs,
+                                                num_outputs=filters,
+                                                kernel_size=kernel_size,
+                                                stride=strides,
+                                                biases_initializer=Use_bias,
+                                                normalizer_fn=tf.contrib.layers.instance_norm,
+                                                activation_fn=lrelu,
+                                                padding=padding)
+
+ def Separable_conv2d(inputs, filters, kernel_size=3, strides=1, padding='VALID', Use_bias=tf.zeros_initializer()):
+     if kernel_size == 3 and strides == 1:
+         inputs = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
+     if strides == 2:
+         inputs = tf.pad(inputs, [[0, 0], [0, 1], [0, 1], [0, 0]], mode="REFLECT")
+     return tf.contrib.layers.separable_conv2d(
+         inputs,
+         num_outputs=filters,
+         kernel_size=kernel_size,
+         depth_multiplier=1,
+         stride=strides,
+         weights_initializer=tf.contrib.layers.variance_scaling_initializer(),
+         biases_initializer=Use_bias,
+         normalizer_fn=tf.contrib.layers.layer_norm,
+         activation_fn=lrelu,
+         padding=padding)
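A minimal sketch of how this generator is meant to be consumed in TF1 graph mode (this mirrors the usage in style.py below; the variable names are illustrative):

    import tensorflow.compat.v1 as tf
    import generator

    inputs = tf.placeholder(tf.float32, [1, None, None, 3])  # NHWC batch of one image, scaled to [-1, 1]
    with tf.variable_scope("generator"):
        stylized = generator.G_net(inputs).fake  # tanh output, also in [-1, 1]
    # weights are then restored from a checkpoint via tf.train.Saver() before running the graph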
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ tensorflow-gpu==1.15.0
+ opencv-python-headless
+ numpy
style.py ADDED
@@ -0,0 +1,39 @@
+ import os
+ import tensorflow.compat.v1 as tf
+ import numpy as np
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+ # utils.py
+ from utils import *
+
+ # generator.py
+ import generator
+
+ # stylize input image with chosen anime style
+ def stylize(anime_style, input_image):
+     test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
+
+     with tf.variable_scope("generator", reuse=False):
+         test_generated = generator.G_net(test_real).fake
+     saver = tf.train.Saver()
+
+     # get model checkpoint folder according to chosen anime style
+     checkpoint_dir = 'models/' + anime_style
+
+     gpu_options = tf.GPUOptions(allow_growth=True)
+     with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)) as sess:
+         # load style model and its weights
+         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+         ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
+         saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
+
+         # load and preprocess input image as a NumPy array
+         image = np.asarray(load_input_image(input_image))
+
+         # stylize image
+         output_image = sess.run(test_generated, feed_dict={test_real: image})
+
+         # adjust brightness of output image
+         output_image = adjust_brightness_from_src_to_dst(inverse_transform(output_image.squeeze()), read_img(input_image))
+
+         return output_image
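Given the selectbox options in app.py, checkpoint_dir resolves to models/Paprika, models/Shinkai, or models/Hayao; each folder is expected to contain a TensorFlow checkpoint that tf.train.get_checkpoint_state can locate. The checkpoints themselves are not among the six files in this commit.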
utils.py ADDED
@@ -0,0 +1,47 @@
+ import tensorflow.compat.v1 as tf
+ from adjust_brightness import adjust_brightness_from_src_to_dst, read_img
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+ # def load_input_image(image_path, size=[256,256]):
+ #     img = cv2.imread(image_path).astype(np.float32)
+ #     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ #     img = preprocessing(img, size)
+ #     img = np.expand_dims(img, axis=0)
+ #     return img
+
+ def load_input_image(image_file_buffer, size=[256, 256]):
+     img = Image.open(image_file_buffer).convert('RGB')
+     img = np.array(img).astype(np.float32)
+     img = preprocessing(img, size)
+     img = np.expand_dims(img, axis=0)
+     return img
+
+ def preprocessing(img, size):
+     h, w = img.shape[:2]
+     if h <= size[0]:
+         h = size[0]
+     else:
+         x = h % 32
+         h = h - x
+
+     if w < size[1]:
+         w = size[1]
+     else:
+         y = w % 32
+         w = w - y
+     # cv2.resize expects dsize in (W, H) order
+     img = cv2.resize(img, (w, h))
+     return img / 127.5 - 1.0
+
+ def inverse_transform(images):
+     images = (images + 1.) / 2 * 255
+     # Floating-point computation is inexact, so the pixel values must be
+     # clipped to the valid range; otherwise image distortion or artifacts
+     # will appear during display.
+     images = np.clip(images, 0, 255)
+     return images.astype(np.uint8)
+
+ # def imsave(images, path):
+ #     return cv2.imwrite(path, cv2.cvtColor(images, cv2.COLOR_BGR2RGB))
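To make the resizing logic in preprocessing concrete, a small sketch (the 720x1280 all-zero frame is purely illustrative):

    import numpy as np
    from utils import preprocessing

    frame = np.zeros((720, 1280, 3), dtype=np.float32)  # hypothetical 720x1280 RGB frame
    out = preprocessing(frame, [256, 256])
    print(out.shape)  # (704, 1280, 3): height is rounded down to the nearest multiple of 32
    print(out.min())  # -1.0: pixel values are rescaled from [0, 255] into [-1, 1]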