File size: 5,257 Bytes
47c60f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Based on: https://github.com/jantic/DeOldify
import os, re, time

os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")

import streamlit as st
import PIL
import cv2
import numpy as np
import uuid
from zipfile import ZipFile, ZIP_DEFLATED
from io import BytesIO
from random import randint
from datetime import datetime

from src.deoldify import device
from src.deoldify.device_id import DeviceId
from src.deoldify.visualize import *
from src.app_utils import get_model_bin


# Configure DeOldify to run on the CPU (DeviceId.CPU) before any model is loaded.
device.set(device=DeviceId.CPU)


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_dir, option):
    """Fetch the DeOldify generator weights for *option* and build a colorizer.

    Args:
        model_dir: Directory the weight file is downloaded into (via
            ``get_model_bin``).
        option: Model variant, ``'artistic'`` or ``'stable'``
            (case-insensitive).

    Returns:
        The colorizer object returned by ``get_image_colorizer``.

    Raises:
        ValueError: If *option* names neither supported variant.
    """
    variant = option.lower()
    if variant == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        weights_name = "ColorizeArtistic_gen.pth"
        artistic = True
    elif variant == 'stable':
        # NOTE(review): Dropbox "?dl=0" links normally serve a preview page
        # rather than the raw file; confirm get_model_bin copes with this,
        # otherwise switch to "?dl=1".
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
        weights_name = "ColorizeStable_gen.pth"
        artistic = False
    else:
        # Without this branch an unknown option fell through and the return
        # below failed with an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown model option {option!r} (expected 'artistic' or 'stable')"
        )

    get_model_bin(model_url, os.path.join(model_dir, weights_name))
    # NOTE(review): model_dir is not forwarded to get_image_colorizer; this
    # assumes it reads weights from the same directory — confirm.
    return get_image_colorizer(artistic=artistic)


def resize_img(input_img, max_size):
    """Downscale an image array so its longer side is at most *max_size* pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged as a copy, so the result never aliases *input_img*.

    Args:
        input_img: Image as a NumPy array where ``shape[0]`` is the height
            and ``shape[1]`` is the width (extra channel axes are allowed).
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A new array: the resized image, or a copy of *input_img*.
    """
    height, width = input_img.shape[0], input_img.shape[1]
    longest = max(height, width)
    if longest <= max_size:
        return input_img.copy()

    # One scale factor for both axes keeps the aspect ratio. (The previous
    # landscape branch swapped width and height, distorting wide images.)
    scale = max_size / longest
    # cv2.resize takes dsize as (width, height). round() avoids float
    # truncation (e.g. 799 instead of 800), and the 1px floor stops very
    # thin images from collapsing to a zero-sized dimension.
    new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    # This path only ever shrinks; INTER_AREA is OpenCV's recommended
    # interpolation for decimation.
    return cv2.resize(input_img, new_size, interpolation=cv2.INTER_AREA)


def colorize_image(pil_image, img_size=800, render_factor=35) -> "PIL.Image.Image":
    """Colorize a PIL image with the module-level ``colorizer``.

    The image is converted to RGB and passed through ``resize_img`` so that
    its longer side is at most *img_size* pixels before inference.

    Args:
        pil_image: Source ``PIL.Image.Image`` (any mode; converted to RGB).
        img_size: Maximum length, in pixels, of the longer side fed to the
            model.
        render_factor: Forwarded unchanged to
            ``colorizer.plot_transformed_pil_image`` (default 35).

    Returns:
        The ``PIL.Image.Image`` returned by the colorizer.
        NOTE(review): its size is whatever the colorizer produces — presumably
        the downscaled size — so resize it if the original size is needed.
    """
    # Convert (not open) to RGB and cap the resolution sent to the model.
    rgb_array = np.array(pil_image.convert("RGB"))
    model_input = PIL.Image.fromarray(resize_img(rgb_array, img_size))

    # `colorizer` is the module-level model created at startup.
    return colorizer.plot_transformed_pil_image(
        model_input, render_factor=render_factor, compare=False
    )


def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
    """Render a Streamlit button that downloads *pil_image* as ``<filename>.<fmt>``.

    Args:
        pil_image: The ``PIL.Image.Image`` to encode.
        filename: Download file name without extension.
        fmt: ``"jpg"`` or ``"png"`` (case-sensitive).
        label: Button caption.

    Returns:
        Whatever ``st.download_button`` returns.

    Raises:
        ValueError: If *fmt* is not a supported format (a subclass of
            ``Exception``, so existing broad handlers still catch it).
    """
    # fmt -> (PIL encoder name, MIME type)
    formats = {
        "jpg": ("JPEG", "image/jpeg"),
        "png": ("PNG", "image/png"),
    }
    if fmt not in formats:
        raise ValueError(
            f"Unknown image format {fmt!r} (Available: {', '.join(formats)} - case sensitive)"
        )
    pil_format, mime = formats[fmt]

    image = pil_image
    # The JPEG encoder cannot write alpha/palette modes such as RGBA or P,
    # so convert those to RGB first instead of failing inside save().
    if pil_format == "JPEG" and image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")

    buf = BytesIO()
    image.save(buf, format=pil_format)

    return st.download_button(
        label=label,
        data=buf.getvalue(),
        file_name=f"{filename}.{fmt}",
        mime=mime,
    )


###########################
###### STREAMLIT CODE #####
###########################


st_color_option = "Artistic"

# Build the colorizer once when the script starts. If anything goes wrong,
# log the error, show a message in the page, and leave `colorizer` as None.
try:
    with st.spinner("Loading..."):
        print('before loading the model')
        colorizer = load_model('models/', st_color_option)
        print('after loading the model')
except Exception as load_error:
    print('Error while loading the model. Please refresh the page')
    print(load_error)
    st.write("**App loading error. Please try again later.**")
    colorizer = None



if colorizer is not None:
    st.title("AI Photo Colorization")

    # Read the demo image inside a context manager so the handle is closed.
    with open("assets/demo.jpg", "rb") as demo_file:
        st.image(demo_file.read())

    st.markdown(
        """
        Colorizing black & white photo can be expensive and time consuming. We introduce AI that can colorize
        grayscale photo in seconds. **Just upload your grayscale image, then click colorize.**
        """
    )
    
    uploaded_file = st.file_uploader("Upload photo", accept_multiple_files=False, type=["png", "jpg", "jpeg"])

    if uploaded_file is not None:
        bytes_data = uploaded_file.getvalue()
        img_input = PIL.Image.open(BytesIO(bytes_data)).convert("RGB")
        
        with st.expander("Original photo", True):
            st.image(img_input)

        # uploaded_file is already known to be non-None inside this branch.
        if st.button("Colorize!"):
            with st.spinner("AI is doing the magic!"):
                img_output = colorize_image(img_input)
                # colorize_image may run on a downscaled copy; restore the
                # upload's original dimensions for display and download.
                img_output = img_output.resize(img_input.size)

            # NOTE: this writes the uploaded image and its colorized result to
            # ./output/ on the server. (Original author's note: the filesystem
            # is not accessible in the "spaces" hosting environment.)
            # Best-effort: create the directory if needed, and never let a
            # filesystem error hide the result from the user.
            now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
            try:
                os.makedirs("./output", exist_ok=True)
                img_input.convert("RGB").save(f"./output/{now}-input.jpg")
                img_output.convert("RGB").save(f"./output/{now}-output.jpg")
            except OSError as save_error:
                print(f"Could not save images to ./output: {save_error}")

            st.write("AI has finished the job!")
            st.image(img_output)

            # NOTE(review): clicking a Streamlit download button reruns the
            # script, and st.button("Colorize!") is False on that rerun, so the
            # result disappears after downloading; consider st.session_state.
            uploaded_name = os.path.splitext(uploaded_file.name)[0]
            image_download_button(
                pil_image=img_output,
                filename=uploaded_name,
                fmt="jpg",
                label="Download Image"
            )