File size: 8,647 Bytes
b31d3c9
0cb9530
 
 
 
 
 
 
b31d3c9
 
eec16fc
0cb9530
 
 
 
 
 
 
 
 
b31d3c9
 
 
 
0cb9530
 
 
 
 
 
b95247b
0cb9530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b31d3c9
0cb9530
 
 
 
 
 
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cb9530
 
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eec16fc
b31d3c9
 
 
 
 
eec16fc
2f2e4c6
 
4c9d269
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
5695e02
b31d3c9
 
 
 
2f2e4c6
 
 
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d984001
b31d3c9
 
 
d984001
b31d3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cb9530
 
 
 
 
 
 
 
 
 
 
 
 
 
067974b
4c9d269
067974b
 
 
 
0cb9530
067974b
0cb9530
 
067974b
0cb9530
 
 
 
 
 
 
 
5695e02
067974b
5695e02
7eda7b6
0cb9530
 
 
878ecf2
b95247b
0cb9530
 
067974b
0cb9530
 
8b64f9e
 
 
1250311
8b64f9e
1250311
 
 
 
8b64f9e
 
 
0cb9530
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# Import general purpose libraries
import os, sys, re
import streamlit as st
import PIL
from PIL import Image
import cv2
import numpy as np
import uuid
from zipfile import ZipFile, ZIP_DEFLATED
from io import BytesIO
from stqdm import stqdm

# Import util functions from deoldify
# NOTE:  This must be the first call in order to work properly!
from deoldify import device
from deoldify.device_id import DeviceId
#choices:  CPU, GPU0...GPU7
device.set(device=DeviceId.CPU)
from deoldify.visualize import *

# Import util functions from app_utils
from app_utils import get_model_bin



####### INPUT PARAMS ###########
# Directory handed to load_model() as the location of the model weight files.
model_folder = 'models/'
# Size limit in pixels handed to resize_img() (as max_size) before an image is colorized.
max_img_size = 800
################################

@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_dir, option):
    """Fetch the weights for the requested mode and build the colorizer.

    The weight URL and the target path inside ``model_dir`` are handed to
    ``get_model_bin`` (presumably a download-if-missing helper -- confirm in
    ``app_utils``), then ``get_image_colorizer`` is called for that mode.
    The result is cached by Streamlit through the ``st.cache`` decorator.

    Args:
        model_dir: Directory in which the ``.pth`` weight file is placed.
        option: Colorizer mode, ``'artistic'`` or ``'stable'``
            (case-insensitive).

    Returns:
        The colorizer object returned by ``get_image_colorizer``.

    Raises:
        ValueError: If ``option`` is not a supported mode.
    """
    mode = option.lower()

    if mode == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
        colorizer = get_image_colorizer(artistic=True)
    elif mode == 'stable':
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
        colorizer = get_image_colorizer(artistic=False)
    else:
        # Without this branch an unknown mode reached `return colorizer` with
        # the name unbound and failed with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported colorizer mode {option!r}; expected 'Artistic' or 'Stable'"
        )

    return colorizer

def resize_img(input_img, max_size):
    """Return a copy of *input_img* whose longest side is at most *max_size*.

    Images that already fit are returned as an unmodified copy. Larger images
    are scaled down so that the longest side becomes ``max_size`` and the
    other side shrinks by the same factor, preserving the aspect ratio.

    Args:
        input_img: Array-like image exposing ``.copy()`` and ``.shape``, with
            height in ``shape[0]`` and width in ``shape[1]``.
        max_size: Maximum allowed length, in pixels, of the longest side.

    Returns:
        A new array; the input is never modified.
    """
    img = input_img.copy()
    img_height, img_width = img.shape[0], img.shape[1]

    if max(img_height, img_width) <= max_size:
        return img

    if img_height > img_width:
        # Portrait: height is the limiting side.
        new_height = max_size
        new_width = img_width * (max_size / img_height)
    else:
        # Landscape or square: width is the limiting side. The previous code
        # pinned the height to max_size here, which swapped the axes and
        # distorted the picture.
        new_width = max_size
        new_height = img_height * (max_size / img_width)

    # Clamp to 1 px so extreme aspect ratios never request an empty image.
    target_width = max(1, int(new_width))
    target_height = max(1, int(new_height))

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img, (target_width, target_height))

def get_image_download_link(img, filename, button_text):
    """Build the HTML for a download button that embeds *img* as a JPEG.

    The image is JPEG-encoded in memory, base64-encoded, and handed to
    ``get_button_html_code`` together with a freshly generated element id.

    Args:
        img: PIL image to offer for download.
        filename: Name placed in the link's ``download`` attribute.
        button_text: Visible label of the link.

    Returns:
        The HTML string produced by ``get_button_html_code``.
    """
    # Imported here because the module's import list never names base64
    # explicitly; it was only reachable through a star import.
    import base64

    # Stripping the digits from the uuid hex leaves only the letters a-f, so
    # the id never starts with a digit when used in the `#id` CSS selectors.
    button_id = re.sub(r'\d+', '', uuid.uuid4().hex)

    # The JPEG writer rejects modes such as RGBA or P; convert those to RGB.
    if img.mode not in ('L', 'RGB', 'CMYK'):
        img = img.convert('RGB')

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # NOTE(review): 'txt' only labels the data URI as `data:file/txt`; the
    # payload is JPEG. Confirm whether the label matters to any consumer.
    return get_button_html_code(img_str, filename, 'txt', button_id, button_text)

def get_button_html_code(data_str, filename, filetype, button_id, button_txt='Download file'):
    """Return CSS plus an ``<a>`` tag that downloads an embedded data URI.

    Args:
        data_str: Base64-encoded payload placed in the data URI.
        filename: Suggested download name. It is HTML-escaped before being
            placed in the ``download`` attribute.
        filetype: Suffix used in the ``data:file/<filetype>`` prefix.
        button_id: Element id of the link, also used in the CSS selectors.
        button_txt: Visible label of the link (inserted as given).

    Returns:
        A string holding a ``<style>`` block followed by the ``<a>`` element.
    """
    # Local import keeps this function self-contained.
    import html

    custom_css = f""" 
    <style>
        #{button_id} {{
            background-color: rgb(255, 255, 255);
            color: rgb(38, 39, 48);
            padding: 0.25em 0.38em;
            position: relative;
            text-decoration: none;
            border-radius: 4px;
            border-width: 1px;
            border-style: solid;
            border-color: rgb(230, 234, 241);
            border-image: initial;

        }} 
        #{button_id}:hover {{
            border-color: rgb(246, 51, 102);
            color: rgb(246, 51, 102);
        }}
        #{button_id}:active {{
            box-shadow: none;
            background-color: rgb(246, 51, 102);
            color: white;
            }}
    </style> """

    # The file name originates from the uploaded file; a quote or angle
    # bracket in it would otherwise break out of the attribute value.
    safe_filename = html.escape(filename, quote=True)

    href = custom_css + f'<a href="data:file/{filetype};base64,{data_str}" id="{button_id}" download="{safe_filename}">{button_txt}</a>'
    return href

def display_single_image(uploaded_file, img_size=800):
    """Colorize one uploaded image and show input, output and a download link.

    Writes to the module-level Streamlit placeholders ``st_title_message``,
    ``st_input_img``, ``st_output_img`` and ``st_download_button`` and uses
    the module-level ``colorizer``.

    Args:
        uploaded_file: Uploaded file object; must expose ``.name`` and be
            readable by ``PIL.Image.open``.
        img_size: Size limit in pixels passed to ``resize_img``.
    """
    st_title_message.markdown("**Processing your image, please wait** ⌛")

    # Normalise to RGB first: palette ('P') images would otherwise turn into
    # an array of palette indices and come back from fromarray() as garbage.
    pil_img = PIL.Image.open(uploaded_file).convert('RGB')
    resized_img_rgb = resize_img(np.array(pil_img), img_size)
    resized_pil_img = PIL.Image.fromarray(resized_img_rgb)

    # Send the image to the model
    output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)

    # Plot images
    st_input_img.image(resized_pil_img, 'Input image', use_column_width=True)
    st_output_img.image(output_pil_img, 'Output image', use_column_width=True)

    # get_image_download_link always encodes JPEG, so give the download a
    # matching extension instead of reusing e.g. ".png" from the upload.
    download_name = os.path.splitext(uploaded_file.name)[0] + '.jpg'
    st_download_button.markdown(
        get_image_download_link(output_pil_img, download_name, 'Download Image'),
        unsafe_allow_html=True,
    )

    # Reset the message
    st_title_message.markdown("**To begin, please upload an image** 👇")

def process_multiple_images(uploaded_files, img_size=800):
    """Colorize several uploaded images and offer them as one ZIP download.

    Writes to the module-level Streamlit placeholders ``st_title_message``,
    ``st_progress_bar`` and ``st_download_button`` and uses the module-level
    ``colorizer``.

    Args:
        uploaded_files: Sequence of uploaded file objects; each must expose
            ``.name`` and be readable by ``PIL.Image.open``.
        img_size: Size limit in pixels passed to ``resize_img``.
    """
    num_imgs = len(uploaded_files)

    output_images_list = []
    img_names_list = []

    st_progress_bar.progress(0)

    for idx, uploaded_file in stqdm(enumerate(uploaded_files, start=1), st_container=st_progress_bar):
        st_title_message.markdown("**Processing image {}/{}. Please wait** ⌛".format(idx, num_imgs))

        # Normalise to RGB first: palette ('P') images would otherwise turn
        # into an array of palette indices and come back as garbage.
        pil_img = PIL.Image.open(uploaded_file).convert('RGB')
        resized_img_rgb = resize_img(np.array(pil_img), img_size)
        resized_pil_img = PIL.Image.fromarray(resized_img_rgb)

        # Send the image to the model
        output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)

        output_images_list.append(output_pil_img)
        # splitext drops only the final extension, so "my.photo.png" keeps
        # the stem "my.photo" rather than being cut down to "my".
        img_names_list.append(os.path.splitext(uploaded_file.name)[0])

        st_progress_bar.progress(int((idx / num_imgs) * 100))

    # Zip output files
    zip_path = 'processed_images.zip'
    zip_buf = zip_multiple_images(output_images_list, img_names_list, zip_path)

    st_download_button.download_button(
        label='Download ZIP file',
        data=zip_buf.read(),
        file_name=zip_path,
        mime="application/zip"
    )

    # Show message
    st_title_message.markdown("**Images are ready for download** 💾")

def zip_multiple_images(pil_images_list, img_names_list, dest_path):
    """Pack images as PNG files into an in-memory ZIP archive.

    Args:
        pil_images_list: Images exposing ``save(fp, format=...)``.
        img_names_list: Entry names without extension, paired positionally
            with ``pil_images_list``; ``".png"`` is appended to each.
        dest_path: Unused; kept so existing callers keep working.

    Returns:
        A ``BytesIO`` holding the archive, rewound to position 0.
    """
    zip_buf = BytesIO()
    used_names = set()

    with ZipFile(zip_buf, 'w', ZIP_DEFLATED) as zip_obj:
        for pil_img, img_name in zip(pil_images_list, img_names_list):
            # Two uploads can share a stem (e.g. "a.jpg" and "a.png"); give
            # repeats a numeric suffix so no archive member shadows another.
            entry_name = img_name
            suffix = 0
            while entry_name in used_names:
                suffix += 1
                entry_name = f"{img_name}_{suffix}"
            used_names.add(entry_name)

            with BytesIO() as output:
                # Encode the image in memory, then copy the bytes into the zip.
                pil_img.save(output, format="PNG")
                zip_obj.writestr(entry_name + ".png", output.getvalue())

    # Rewind so callers can read the archive from the start.
    zip_buf.seek(0)
    return zip_buf



###########################
###### STREAMLIT CODE #####
###########################

# General configuration
# st.set_page_config(layout="centered")
st.set_page_config(layout="wide")
st.set_option('deprecation.showfileUploaderEncoding', False)
# Hide elements carrying the .uploadedFile class. The style block is now
# closed with </style>; it used to end with a second opening <style> tag.
st.markdown('''
<style>
    .uploadedFile {display: none}
</style>''',
unsafe_allow_html=True)

# Main window configuration
st.title("Black and white colorizer")
st.markdown("This app puts color into your black and white pictures")
# Placeholders that the processing functions fill in later, in page order.
st_title_message = st.empty()
st_progress_bar = st.empty()
st_file_uploader = st.empty()
st_input_img = st.empty()
st_output_img = st.empty()
st_download_button = st.empty()

st_title_message.markdown("**Model loading, please wait** ⌛")

# # Sidebar
st_color_option = st.sidebar.selectbox('Select colorizer mode',
                                    ('Artistic', 'Stable'))

# st.sidebar.title('Model parameters')
# det_conf_thres = st.sidebar.slider("Detector confidence threshold", 0.1, 0.9, value=0.5, step=0.1)
# det_nms_thres = st.sidebar.slider("Non-maximum supression IoU", 0.1, 0.9, value=0.4, step=0.1)

# Load models
try:
    print('before loading the model')
    colorizer = load_model(model_folder, st_color_option)
    print('after loading the model')

except Exception as e:
    # Top-level boundary: report the failure in the UI and keep the page up.
    colorizer = None
    print('Error while loading the model. Please refresh the page')
    print(e)
    st_title_message.markdown("**Error while loading the model. Please refresh the page**")

if colorizer is not None:
    st_title_message.markdown("**To begin, please upload an image** 👇")

    #Choose your own image
    uploaded_files = st_file_uploader.file_uploader("Upload a black and white photo",
                                            type=['png', 'jpg', 'jpeg'],
                                            accept_multiple_files=True)

    # With accept_multiple_files=True the uploader yields a list, which is
    # empty until something is uploaded.
    if uploaded_files:
        if len(uploaded_files) == 1:
            display_single_image(uploaded_files[0], max_img_size)
        else:
            process_multiple_images(uploaded_files, max_img_size)

        # A list has no seek(); the old `uploaded_files.seek(0)` raised
        # AttributeError. Rewind each uploaded file individually instead.
        for uploaded_file in uploaded_files:
            uploaded_file.seek(0)