ArchitSharma commited on
Commit
a408126
·
1 Parent(s): c716076

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -0
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on: https://github.com/jantic/DeOldify
2
+ import os, re, time
3
+
4
+ os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
5
+ os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")
6
+
7
+ import streamlit as st
8
+ import PIL
9
+ import cv2
10
+ import numpy as np
11
+ import uuid
12
+ from zipfile import ZipFile, ZIP_DEFLATED
13
+ from io import BytesIO
14
+ from random import randint
15
+ from datetime import datetime
16
+
17
+ from src.deoldify import device
18
+ from src.deoldify.device_id import DeviceId
19
+ from src.deoldify.visualize import *
20
+ from src.app_utils import get_model_bin
21
+
22
+
23
+ device.set(device=DeviceId.CPU)
24
+
25
+
26
+ @st.cache(allow_output_mutation=True, show_spinner=False)
27
+ def load_model(model_dir, option):
28
+ if option.lower() == 'artistic':
29
+ model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
30
+ get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
31
+ colorizer = get_image_colorizer(artistic=True)
32
+ elif option.lower() == 'stable':
33
+ model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
34
+ get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
35
+ colorizer = get_image_colorizer(artistic=False)
36
+
37
+ return colorizer
38
+
39
+
40
+ def resize_img(input_img, max_size):
41
+ img = input_img.copy()
42
+ img_height, img_width = img.shape[0],img.shape[1]
43
+
44
+ if max(img_height, img_width) > max_size:
45
+ if img_height > img_width:
46
+ new_width = img_width*(max_size/img_height)
47
+ new_height = max_size
48
+ resized_img = cv2.resize(img,(int(new_width), int(new_height)))
49
+ return resized_img
50
+
51
+ elif img_height <= img_width:
52
+ new_width = img_height*(max_size/img_width)
53
+ new_height = max_size
54
+ resized_img = cv2.resize(img,(int(new_width), int(new_height)))
55
+ return resized_img
56
+
57
+ return img
58
+
59
+
60
+ def colorize_image(pil_image, img_size=800) -> "PIL.Image":
61
+ # Open the image
62
+ pil_img = pil_image.convert("RGB")
63
+ img_rgb = np.array(pil_img)
64
+ resized_img_rgb = resize_img(img_rgb, img_size)
65
+ resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
66
+
67
+ # Send the image to the model
68
+ output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
69
+
70
+ return output_pil_img
71
+
72
+
73
+ def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
74
+ if fmt not in ["jpg", "png"]:
75
+ raise Exception(f"Unknown image format (Available: {fmt} - case sensitive)")
76
+
77
+ pil_format = "JPEG" if fmt == "jpg" else "PNG"
78
+ file_format = "jpg" if fmt == "jpg" else "png"
79
+ mime = "image/jpeg" if fmt == "jpg" else "image/png"
80
+
81
+ buf = BytesIO()
82
+ pil_image.save(buf, format=pil_format)
83
+
84
+ return st.download_button(
85
+ label=label,
86
+ data=buf.getvalue(),
87
+ file_name=f'{filename}.{file_format}',
88
+ mime=mime,
89
+ )
90
+
91
+
92
+ ###########################
93
+ ###### STREAMLIT CODE #####
94
+ ###########################
95
+
96
+
97
+ st_color_option = "Artistic"
98
+
99
+ # Load models
100
+ try:
101
+ with st.spinner("Loading..."):
102
+ print('before loading the model')
103
+ colorizer = load_model('models/', st_color_option)
104
+ print('after loading the model')
105
+
106
+ except Exception as e:
107
+ colorizer = None
108
+ print('Error while loading the model. Please refresh the page')
109
+ print(e)
110
+ st.write("**App loading error. Please try again later.**")
111
+
112
+
113
+
114
+ if colorizer is not None:
115
+ st.title("Digital Photo Color Restoration")
116
+
117
+ uploaded_file = st.file_uploader("Upload photo", accept_multiple_files=False, type=["png", "jpg", "jpeg"])
118
+
119
+ if uploaded_file is not None:
120
+ bytes_data = uploaded_file.getvalue()
121
+ img_input = PIL.Image.open(BytesIO(bytes_data)).convert("RGB")
122
+
123
+ with st.expander("Original photo", True):
124
+ st.image(img_input)
125
+
126
+ if st.button("Colorize!") and uploaded_file is not None:
127
+
128
+ with st.spinner("AI is doing the magic!"):
129
+ img_output = colorize_image(img_input)
130
+ img_output = img_output.resize(img_input.size)
131
+
132
+ # NOTE: Calm! I'm not logging the input and outputs.
133
+ # It is impossible to access the filesystem in spaces environment.
134
+ now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
135
+ img_input.convert("RGB").save(f"./output/{now}-input.jpg")
136
+ img_output.convert("RGB").save(f"./output/{now}-output.jpg")
137
+
138
+ st.write("AI has finished the job!")
139
+ st.image(img_output)
140
+ # reuse = st.button('Edit again (Re-use this image)', on_click=set_image, args=(inpainted_img, ))
141
+
142
+ uploaded_name = os.path.splitext(uploaded_file.name)[0]
143
+ image_download_button(
144
+ pil_image=img_output,
145
+ filename=uploaded_name,
146
+ fmt="jpg",
147
+ label="Download Image"
148
+ )
149
+