File size: 9,618 Bytes
b31d3c9
3c47651
0cb9530
 
 
 
 
b31d3c9
 
3c47651
0cb9530
 
 
 
 
 
 
 
 
b31d3c9
 
 
 
0cb9530
3c47651
13d150c
3c47651
 
 
 
2ec64cc
3c47651
 
 
0cb9530
 
2ec64cc
 
b95247b
0cb9530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b31d3c9
0cb9530
 
 
 
 
 
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cb9530
 
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eec16fc
b31d3c9
 
 
 
 
eec16fc
2f2e4c6
 
3c47651
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
5695e02
b31d3c9
 
 
 
2f2e4c6
 
 
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d984001
b31d3c9
 
 
d984001
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cb9530
 
 
 
 
 
 
 
 
 
 
 
 
 
067974b
4c9d269
067974b
 
 
 
0cb9530
067974b
0cb9530
 
067974b
0cb9530
 
 
 
 
 
 
 
5695e02
3c47651
5695e02
7eda7b6
0cb9530
 
 
878ecf2
b95247b
0cb9530
 
067974b
2ec64cc
0cb9530
8b64f9e
 
3c47651
 
 
2ec64cc
3c47651
13d150c
 
 
 
 
 
 
cc9c1ce
5cc54b6
 
 
 
13d150c
 
 
0cb9530
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# Import general purpose libraries
import os, re, time
import streamlit as st
import PIL
import cv2
import numpy as np
import uuid
from zipfile import ZipFile, ZIP_DEFLATED
from io import BytesIO
from random import randint

# Import util functions from deoldify
# NOTE:  This must be the first call in order to work properly!
from deoldify import device
from deoldify.device_id import DeviceId
#choices:  CPU, GPU0...GPU7
device.set(device=DeviceId.CPU)
from deoldify.visualize import *

# Import util functions from app_utils
from app_utils import get_model_bin



# Keys this app keeps in Streamlit's per-session state.
SESSION_STATE_VARIABLES = [
    'model_folder',
    'max_img_size',
    'uploaded_file_key',
    'uploaded_files',
]

# Register every key this session has not seen yet, initialised to None,
# so later attribute/item reads on st.session_state never fail.
for state_name in SESSION_STATE_VARIABLES:
    if state_name in st.session_state:
        continue
    st.session_state[state_name] = None

#### SET INPUT PARAMS ###########
# Any input param that is still falsy (None on the first run) gets its default.
for state_name, default in (('model_folder', 'models/'), ('max_img_size', 800)):
    if not st.session_state[state_name]:
        st.session_state[state_name] = default
################################



@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_dir, option):
    """Fetch the weights for the chosen colorizer mode and build the colorizer.

    Args:
        model_dir: Folder the ``.pth`` weights file is written to by
            ``get_model_bin``.
        option: Colorizer mode, ``'artistic'`` or ``'stable'``
            (case-insensitive).

    Returns:
        Whatever deoldify's ``get_image_colorizer`` returns for that mode.

    Raises:
        ValueError: If ``option`` is not a recognised mode (previously this
            surfaced as an ``UnboundLocalError`` on ``colorizer``).
    """
    mode = option.lower()

    if mode == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
        # NOTE(review): model_dir is not passed to get_image_colorizer; it is
        # assumed to load weights from the same folder -- confirm.
        return get_image_colorizer(artistic=True)

    if mode == 'stable':
        # NOTE(review): a Dropbox "?dl=0" link usually serves an HTML preview
        # page rather than the raw file; confirm get_model_bin copes with it
        # (or switch to "?dl=1").
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
        return get_image_colorizer(artistic=False)

    raise ValueError(
        f"Unknown colorizer option {option!r}; expected 'Artistic' or 'Stable'"
    )

def resize_img(input_img, max_size):
    """Downscale an image array so that its longer side is at most ``max_size``.

    The aspect ratio is preserved. The longer side becomes exactly ``max_size``
    and the shorter side is scaled by the same factor. Previously the
    landscape/square branch swapped width and height, which distorted wide
    images.

    Args:
        input_img: Image as a NumPy array indexed ``[row, col, ...]``, so
            ``shape[0]`` is the height and ``shape[1]`` the width.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A new array. It is resized with ``cv2.resize`` when the longer side
        exceeds ``max_size``; otherwise it is a copy of ``input_img``.
    """
    img_height, img_width = input_img.shape[0], input_img.shape[1]

    if max(img_height, img_width) <= max_size:
        return input_img.copy()

    if img_height > img_width:
        # Portrait: height is the long side.
        new_height = max_size
        new_width = img_width * (max_size / img_height)
    else:
        # Landscape or square: width is the long side.
        new_width = max_size
        new_height = img_height * (max_size / img_width)

    # Floor at 1 px so extremely thin images never produce a 0-sized target.
    new_width = max(1, int(new_width))
    new_height = max(1, int(new_height))

    # cv2.resize expects dsize as (width, height).
    return cv2.resize(input_img, (new_width, new_height))

def get_image_download_link(img, filename, button_text):
    """Return HTML for a styled link that downloads ``img`` as a JPEG.

    Args:
        img: PIL image to encode. Modes JPEG cannot store (e.g. RGBA, P) are
            converted to RGB first.
        filename: Suggested download name. Its extension is replaced with
            ``.jpg`` because the payload is always JPEG.
        button_text: Label shown on the link.

    Returns:
        HTML (a ``<style>`` block plus an ``<a download>`` tag) built by
        ``get_button_html_code``. Render it with ``unsafe_allow_html=True``.
    """
    import base64  # not imported at module level in this file

    # Digits are stripped so the id is a valid CSS identifier
    # (identifiers may not start with a digit).
    button_id = re.sub(r'\d+', '', uuid.uuid4().hex)

    if img.mode != 'RGB':
        img = img.convert('RGB')

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    jpeg_filename = os.path.splitext(filename)[0] + '.jpg'
    return get_button_html_code(img_str, jpeg_filename, 'txt', button_id, button_text)

def get_button_html_code(data_str, filename, filetype, button_id, button_txt='Download file'):
    """Build HTML for a styled download link that carries a base64 data URI.

    Args:
        data_str: Base64-encoded file contents.
        filename: Suggested download name. It is HTML-escaped because it may
            originate from a user upload; the browser decodes the entities, so
            the saved name is unchanged.
        filetype: Inserted into the URI as ``data:file/<filetype>``.
        button_id: Element id used both on the ``<a>`` tag and in the CSS
            selectors, so it must be a valid CSS identifier.
        button_txt: Link text, inserted verbatim (it may contain markup).

    Returns:
        A ``<style>`` block followed by the ``<a download=...>`` tag.
    """
    import html

    custom_css = f""" 
    <style>
        #{button_id} {{
            background-color: rgb(255, 255, 255);
            color: rgb(38, 39, 48);
            padding: 0.25em 0.38em;
            position: relative;
            text-decoration: none;
            border-radius: 4px;
            border-width: 1px;
            border-style: solid;
            border-color: rgb(230, 234, 241);
            border-image: initial;

        }} 
        #{button_id}:hover {{
            border-color: rgb(246, 51, 102);
            color: rgb(246, 51, 102);
        }}
        #{button_id}:active {{
            box-shadow: none;
            background-color: rgb(246, 51, 102);
            color: white;
            }}
    </style> """

    # Escape quotes/angle brackets so an odd filename cannot break out of
    # the download="..." attribute.
    safe_filename = html.escape(filename, quote=True)

    href = custom_css + f'<a href="data:file/{filetype};base64,{data_str}" id="{button_id}" download="{safe_filename}">{button_txt}</a>'
    return href

def display_single_image(uploaded_file, img_size=800):
    """Colorize one uploaded image and show the input, the output and a download link.

    Renders into the module-level Streamlit placeholders (``st_title_message``,
    ``st_input_img``, ``st_output_img``, ``st_download_button``) and uses the
    module-level ``colorizer``.

    Args:
        uploaded_file: File-like upload readable by ``PIL.Image.open``; its
            ``.name`` is used for the download filename.
        img_size: Longest side, in pixels, that the input is downscaled to.
    """
    st_title_message.markdown("**Processing your image, please wait** ⌛")
    img_name = uploaded_file.name

    # Normalise to 3-channel RGB. Palette ("P") PNGs would otherwise become a
    # 2-D array of palette indices, and RGBA/LA/16-bit modes vary in shape.
    pil_img = PIL.Image.open(uploaded_file).convert('RGB')
    resized_pil_img = PIL.Image.fromarray(resize_img(np.array(pil_img), img_size))

    # Send the image to the model
    output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)

    # Plot images
    st_input_img.image(resized_pil_img, 'Input image', use_column_width=True)
    st_output_img.image(output_pil_img, 'Output image', use_column_width=True)

    # Show download button
    st_download_button.markdown(get_image_download_link(output_pil_img, img_name, 'Download Image'), unsafe_allow_html=True)

    # Reset the message
    st_title_message.markdown("**To begin, please upload an image** 👇")

def process_multiple_images(uploaded_files, img_size=800):
    """Colorize several uploads and offer the results as a single ZIP download.

    Uses the module-level ``colorizer`` and the Streamlit placeholders
    ``st_title_message``, ``st_progress_bar`` and ``st_download_button``.

    Args:
        uploaded_files: Sequence of file-like uploads readable by
            ``PIL.Image.open``; each ``.name`` (minus its extension) becomes the
            entry name in the ZIP.
        img_size: Longest side, in pixels, that each input is downscaled to.
    """
    num_imgs = len(uploaded_files)
    output_images_list = []
    img_names_list = []

    st_progress_bar.progress(0)

    for idx, uploaded_file in enumerate(uploaded_files, start=1):
        st_title_message.markdown("**Processing image {}/{}. Please wait** ⌛".format(idx, num_imgs))

        # Normalise to 3-channel RGB so palette/alpha/16-bit uploads reach the
        # model as real pixel data rather than palette indices.
        pil_img = PIL.Image.open(uploaded_file).convert('RGB')
        resized_pil_img = PIL.Image.fromarray(resize_img(np.array(pil_img), img_size))

        # Send the image to the model
        output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
        output_images_list.append(output_pil_img)

        # Drop only the final extension: "my.photo.jpg" -> "my.photo".
        img_names_list.append(os.path.splitext(uploaded_file.name)[0])

        st_progress_bar.progress(int(idx / num_imgs * 100))

    # Zip output files (in memory; zip_name is only the suggested download name)
    zip_name = 'processed_images.zip'
    zip_buf = zip_multiple_images(output_images_list, img_names_list, zip_name)

    st_download_button.download_button(
        label='Download ZIP file',
        # getvalue() returns the whole archive regardless of the cursor position.
        data=zip_buf.getvalue(),
        file_name=zip_name,
        mime="application/zip"
    )

    # Show message
    st_title_message.markdown("**Images are ready for download** 💾")

def zip_multiple_images(pil_images_list, img_names_list, dest_path):
    """Pack images into an in-memory ZIP archive, one PNG entry per image.

    Args:
        pil_images_list: Images exposing ``save(fp, format=...)`` (PIL images).
        img_names_list: Entry base names, paired positionally with the images;
            ``.png`` is appended. A repeated name gets a ``_1``, ``_2``, ...
            suffix so that later images do not shadow earlier ones in the
            archive.
        dest_path: Unused. It is kept for interface compatibility; the archive
            is never written to disk here.

    Returns:
        A ``BytesIO`` holding the archive, rewound to position 0.
    """
    zip_buf = BytesIO()
    used_names = set()

    with ZipFile(zip_buf, 'w', ZIP_DEFLATED) as zip_obj:
        for pil_img, img_name in zip(pil_images_list, img_names_list):
            # De-duplicate entry names; ZipFile would otherwise store two
            # members with the same name.
            entry_name = img_name
            suffix = 1
            while entry_name in used_names:
                entry_name = f"{img_name}_{suffix}"
                suffix += 1
            used_names.add(entry_name)

            # Encode the image as PNG in memory, then store it in the archive.
            with BytesIO() as output:
                pil_img.save(output, format="PNG")
                zip_obj.writestr(entry_name + ".png", output.getvalue())

    zip_buf.seek(0)
    return zip_buf



###########################
###### STREAMLIT CODE #####
###########################

# General configuration
# st.set_page_config(layout="centered")
st.set_page_config(layout="wide")
st.set_option('deprecation.showfileUploaderEncoding', False)
# Inject CSS hiding elements with the `uploadedFile` class (presumably the
# per-file list Streamlit renders under the uploader). Note the proper
# closing `</style>` tag.
st.markdown('''
<style>
    .uploadedFile {display: none}
</style>''',
unsafe_allow_html=True)

# Main window configuration
st.title("Black and white colorizer")
st.markdown("This app puts color into your black and white pictures")
# Placeholders filled in later by display_single_image / process_multiple_images.
st_title_message = st.empty()
st_progress_bar = st.empty()
st_file_uploader = st.empty()
st_input_img = st.empty()
st_output_img = st.empty()
st_download_button = st.empty()

st_title_message.markdown("**Model loading, please wait** ⌛")

# # Sidebar
st_color_option = st.sidebar.selectbox('Select colorizer mode',
                                    ('Artistic', 'Stable'))

# st.sidebar.title('Model parameters')
# det_conf_thres = st.sidebar.slider("Detector confidence threshold", 0.1, 0.9, value=0.5, step=0.1)
# det_nms_thres = st.sidebar.slider("Non-maximum supression IoU", 0.1, 0.9, value=0.4, step=0.1)

# Load models
try:
    print('before loading the model')
    colorizer = load_model(st.session_state.model_folder, st_color_option)
    print('after loading the model')

except Exception as e:
    # Top-level boundary: report the failure in the UI instead of crashing.
    colorizer = None
    print('Error while loading the model. Please refresh the page')
    print(e)
    st_title_message.markdown("**Error while loading the model. Please refresh the page**")

if colorizer is not None:
    st_title_message.markdown("**To begin, please upload an image** 👇")

    #Choose your own image
    uploaded_files = st_file_uploader.file_uploader("Upload a black and white photo", 
                                            type=['png', 'jpg', 'jpeg'],
                                            accept_multiple_files=True,
                                            key=f"{st.session_state['uploaded_file_key']}"
                                            )

    if uploaded_files:
        # Copy images to a session state
        st.session_state['uploaded_files'] = uploaded_files

        st.session_state['uploaded_file_key'] = str(randint(1000, 100000000))  # remove the uploaded file from the UI
        st.experimental_rerun()  # Force rerun to reload the file_uploader object with new key

    # If session state is not empty, we will process stored images
    if st.session_state['uploaded_files']:
        if len(st.session_state['uploaded_files']) == 1:
            display_single_image(st.session_state['uploaded_files'][0], st.session_state.max_img_size)
        elif len(st.session_state['uploaded_files']) > 1:
            process_multiple_images(st.session_state['uploaded_files'], st.session_state.max_img_size)

        # Reset session state variable
        st.session_state['uploaded_files'] = None