Spaces:
Runtime error
Runtime error
ArchitSharma
commited on
Commit
·
a408126
1
Parent(s):
c716076
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on: https://github.com/jantic/DeOldify
|
2 |
+
import os, re, time
|
3 |
+
|
4 |
+
os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
|
5 |
+
os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")
|
6 |
+
|
7 |
+
import streamlit as st
|
8 |
+
import PIL
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import uuid
|
12 |
+
from zipfile import ZipFile, ZIP_DEFLATED
|
13 |
+
from io import BytesIO
|
14 |
+
from random import randint
|
15 |
+
from datetime import datetime
|
16 |
+
|
17 |
+
from src.deoldify import device
|
18 |
+
from src.deoldify.device_id import DeviceId
|
19 |
+
from src.deoldify.visualize import *
|
20 |
+
from src.app_utils import get_model_bin
|
21 |
+
|
22 |
+
|
23 |
+
device.set(device=DeviceId.CPU)
|
24 |
+
|
25 |
+
|
26 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
|
27 |
+
def load_model(model_dir, option):
|
28 |
+
if option.lower() == 'artistic':
|
29 |
+
model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
|
30 |
+
get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
|
31 |
+
colorizer = get_image_colorizer(artistic=True)
|
32 |
+
elif option.lower() == 'stable':
|
33 |
+
model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
|
34 |
+
get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
|
35 |
+
colorizer = get_image_colorizer(artistic=False)
|
36 |
+
|
37 |
+
return colorizer
|
38 |
+
|
39 |
+
|
40 |
+
def resize_img(input_img, max_size):
|
41 |
+
img = input_img.copy()
|
42 |
+
img_height, img_width = img.shape[0],img.shape[1]
|
43 |
+
|
44 |
+
if max(img_height, img_width) > max_size:
|
45 |
+
if img_height > img_width:
|
46 |
+
new_width = img_width*(max_size/img_height)
|
47 |
+
new_height = max_size
|
48 |
+
resized_img = cv2.resize(img,(int(new_width), int(new_height)))
|
49 |
+
return resized_img
|
50 |
+
|
51 |
+
elif img_height <= img_width:
|
52 |
+
new_width = img_height*(max_size/img_width)
|
53 |
+
new_height = max_size
|
54 |
+
resized_img = cv2.resize(img,(int(new_width), int(new_height)))
|
55 |
+
return resized_img
|
56 |
+
|
57 |
+
return img
|
58 |
+
|
59 |
+
|
60 |
+
def colorize_image(pil_image, img_size=800) -> "PIL.Image":
|
61 |
+
# Open the image
|
62 |
+
pil_img = pil_image.convert("RGB")
|
63 |
+
img_rgb = np.array(pil_img)
|
64 |
+
resized_img_rgb = resize_img(img_rgb, img_size)
|
65 |
+
resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
|
66 |
+
|
67 |
+
# Send the image to the model
|
68 |
+
output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
|
69 |
+
|
70 |
+
return output_pil_img
|
71 |
+
|
72 |
+
|
73 |
+
def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
|
74 |
+
if fmt not in ["jpg", "png"]:
|
75 |
+
raise Exception(f"Unknown image format (Available: {fmt} - case sensitive)")
|
76 |
+
|
77 |
+
pil_format = "JPEG" if fmt == "jpg" else "PNG"
|
78 |
+
file_format = "jpg" if fmt == "jpg" else "png"
|
79 |
+
mime = "image/jpeg" if fmt == "jpg" else "image/png"
|
80 |
+
|
81 |
+
buf = BytesIO()
|
82 |
+
pil_image.save(buf, format=pil_format)
|
83 |
+
|
84 |
+
return st.download_button(
|
85 |
+
label=label,
|
86 |
+
data=buf.getvalue(),
|
87 |
+
file_name=f'{filename}.{file_format}',
|
88 |
+
mime=mime,
|
89 |
+
)
|
90 |
+
|
91 |
+
|
92 |
+
###########################
|
93 |
+
###### STREAMLIT CODE #####
|
94 |
+
###########################
|
95 |
+
|
96 |
+
|
97 |
+
st_color_option = "Artistic"
|
98 |
+
|
99 |
+
# Load models
|
100 |
+
try:
|
101 |
+
with st.spinner("Loading..."):
|
102 |
+
print('before loading the model')
|
103 |
+
colorizer = load_model('models/', st_color_option)
|
104 |
+
print('after loading the model')
|
105 |
+
|
106 |
+
except Exception as e:
|
107 |
+
colorizer = None
|
108 |
+
print('Error while loading the model. Please refresh the page')
|
109 |
+
print(e)
|
110 |
+
st.write("**App loading error. Please try again later.**")
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
if colorizer is not None:
|
115 |
+
st.title("Digital Photo Color Restoration")
|
116 |
+
|
117 |
+
uploaded_file = st.file_uploader("Upload photo", accept_multiple_files=False, type=["png", "jpg", "jpeg"])
|
118 |
+
|
119 |
+
if uploaded_file is not None:
|
120 |
+
bytes_data = uploaded_file.getvalue()
|
121 |
+
img_input = PIL.Image.open(BytesIO(bytes_data)).convert("RGB")
|
122 |
+
|
123 |
+
with st.expander("Original photo", True):
|
124 |
+
st.image(img_input)
|
125 |
+
|
126 |
+
if st.button("Colorize!") and uploaded_file is not None:
|
127 |
+
|
128 |
+
with st.spinner("AI is doing the magic!"):
|
129 |
+
img_output = colorize_image(img_input)
|
130 |
+
img_output = img_output.resize(img_input.size)
|
131 |
+
|
132 |
+
# NOTE: Calm! I'm not logging the input and outputs.
|
133 |
+
# It is impossible to access the filesystem in spaces environment.
|
134 |
+
now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
|
135 |
+
img_input.convert("RGB").save(f"./output/{now}-input.jpg")
|
136 |
+
img_output.convert("RGB").save(f"./output/{now}-output.jpg")
|
137 |
+
|
138 |
+
st.write("AI has finished the job!")
|
139 |
+
st.image(img_output)
|
140 |
+
# reuse = st.button('Edit again (Re-use this image)', on_click=set_image, args=(inpainted_img, ))
|
141 |
+
|
142 |
+
uploaded_name = os.path.splitext(uploaded_file.name)[0]
|
143 |
+
image_download_button(
|
144 |
+
pil_image=img_output,
|
145 |
+
filename=uploaded_name,
|
146 |
+
fmt="jpg",
|
147 |
+
label="Download Image"
|
148 |
+
)
|
149 |
+
|