# ArchitSharma's picture
# Update app.py
# f1705e6
# raw
# history blame
# 4.93 kB
# Based on: https://github.com/jantic/DeOldify
import os, re, time
os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")
import streamlit as st
import PIL
import cv2
import numpy as np
import uuid
from zipfile import ZipFile, ZIP_DEFLATED
from io import BytesIO
from random import randint
from datetime import datetime
from src.deoldify import device
from src.deoldify.device_id import DeviceId
from src.deoldify.visualize import *
from src.app_utils import get_model_bin
# Select the CPU as the DeOldify device. This runs at import time, before the
# colorizer is built further down in this script.
device.set(device=DeviceId.CPU)
@st.cache_resource
def load_model(model_dir, option):
    """Fetch the generator weights for one model variant and build a colorizer.

    Args:
        model_dir: Directory in which the ``.pth`` weights file is placed.
        option: Model variant, ``"artistic"`` or ``"stable"`` (case-insensitive).

    Returns:
        The object returned by ``get_image_colorizer`` for the chosen variant.

    Raises:
        ValueError: If ``option`` names neither supported variant. Previously
            this fell through to ``return colorizer`` with the name unbound.
    """
    variant = option.lower()
    if variant == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        # get_model_bin receives the URL and the target path; presumably it
        # downloads only when the file is missing - confirm in src.app_utils.
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
        colorizer = get_image_colorizer(artistic=True)
    elif variant == 'stable':
        # NOTE(review): a Dropbox link ending in ?dl=0 may serve a preview page
        # rather than the raw file, depending on the client - verify.
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
        colorizer = get_image_colorizer(artistic=False)
    else:
        raise ValueError(
            f"Unknown model option {option!r} (expected 'artistic' or 'stable')"
        )
    return colorizer
def resize_img(input_img, max_size):
    """Shrink an image so its longest side is at most ``max_size`` pixels.

    The aspect ratio is preserved. Images that already fit are not resized.

    Args:
        input_img: Array-like image indexed as ``[row, column, ...]``, i.e.
            ``shape[0]`` is the height and ``shape[1]`` the width.
        max_size: Maximum allowed length, in pixels, of the longest side.

    Returns:
        A new array: a copy of the input when no resize is needed, otherwise
        the output of ``cv2.resize``. The input is never modified.
    """
    img_height, img_width = input_img.shape[0], input_img.shape[1]
    longest_side = max(img_height, img_width)
    if longest_side <= max_size:
        # Keep returning a fresh array so callers may mutate the result freely.
        return input_img.copy()

    scale = max_size / longest_side
    if img_height > img_width:
        # Portrait: height is pinned to max_size, width follows.
        new_height = max_size
        new_width = img_width * scale
    else:
        # Landscape or square: width is pinned to max_size, height follows.
        # The previous code swapped these two values, yielding a portrait-
        # shaped result for wide images.
        new_width = max_size
        new_height = img_height * scale

    # cv2.resize takes its target size as (width, height); clamp to 1 px so an
    # extreme aspect ratio cannot truncate a side to zero.
    target_size = (max(1, int(new_width)), max(1, int(new_height)))
    return cv2.resize(input_img, target_size)
def colorize_image(pil_image, img_size=800, render_factor=35) -> "PIL.Image":
    """Colorize a PIL image with the module-level ``colorizer``.

    Args:
        pil_image: Source image; it is converted to RGB before processing.
        img_size: Longest side, in pixels, that the image is limited to (via
            ``resize_img``) before it is handed to the model.
        render_factor: Value forwarded to the colorizer's ``render_factor``
            argument. Defaults to 35, the value that was previously
            hard-coded.

    Returns:
        The image returned by ``colorizer.plot_transformed_pil_image``.
    """
    # Normalise to a 3-channel RGB array so resize_img sees a plain ndarray.
    rgb_array = np.array(pil_image.convert("RGB"))
    resized_pil_img = PIL.Image.fromarray(resize_img(rgb_array, img_size))

    # Send the image to the model.
    return colorizer.plot_transformed_pil_image(
        resized_pil_img,
        render_factor=render_factor,
        compare=False,
    )
def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
    """Render a Streamlit download button serving ``pil_image`` as a file.

    Args:
        pil_image: Image to encode; it is saved into an in-memory buffer.
        filename: Download file name without extension.
        fmt: Output format, ``"jpg"`` or ``"png"`` (case sensitive). It is
            also used as the file extension.
        label: Text shown on the button.

    Returns:
        The value returned by ``st.download_button``.

    Raises:
        ValueError: If ``fmt`` is not one of the supported formats.
    """
    # fmt -> (Pillow format name, MIME type)
    supported = {
        "jpg": ("JPEG", "image/jpeg"),
        "png": ("PNG", "image/png"),
    }
    if fmt not in supported:
        # The old message printed the rejected value as the "available" one.
        raise ValueError(
            f"Unknown image format {fmt!r} "
            f"(Available: {', '.join(supported)} - case sensitive)"
        )
    pil_format, mime = supported[fmt]

    # JPEG has no alpha channel or palette; convert such images first so the
    # save below does not fail.
    if pil_format == "JPEG" and pil_image.mode in ("RGBA", "LA", "P"):
        pil_image = pil_image.convert("RGB")

    buf = BytesIO()
    pil_image.save(buf, format=pil_format)
    return st.download_button(
        label=label,
        data=buf.getvalue(),
        file_name=f'{filename}.{fmt}',
        mime=mime,
    )
###########################
###### STREAMLIT CODE #####
###########################
st_color_option = "Artistic"

# Build the colorizer once per script run. Any failure is reported both on
# stdout and in the page, and leaves `colorizer` as None so the UI section
# below is skipped.
try:
    with st.spinner("Loading..."):
        print('before loading the model')
        _loaded_colorizer = load_model('models/', st_color_option)
        print('after loading the model')
except Exception as load_error:
    colorizer = None
    print('Error while loading the model. Please refresh the page')
    print(load_error)
    st.write("**App loading error. Please try again later.**")
else:
    colorizer = _loaded_colorizer
# Main page: upload a photo, colorize it on demand, show and offer the result.
if colorizer is not None:
    st.title("Digital Photo Color Restoration")

    uploaded_file = st.file_uploader("Upload photo", accept_multiple_files=False, type=["png", "jpg", "jpeg"])

    if uploaded_file is not None:
        bytes_data = uploaded_file.getvalue()
        img_input = PIL.Image.open(BytesIO(bytes_data)).convert("RGB")

        with st.expander("Original photo", True):
            st.image(img_input)

        # This branch already guarantees an upload, so no second check here.
        if st.button("Restore Color!"):
            with st.spinner("AI is doing the magic!"):
                img_output = colorize_image(img_input)
                # The model works on a downscaled copy; restore the original size.
                img_output = img_output.resize(img_input.size)

            # NOTE(review): an earlier comment here said inputs and outputs are
            # not logged, yet both images are written to ./output below.
            # Whether those files persist depends on the hosting environment -
            # confirm this is intended.
            now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
            try:
                # Create the directory on demand and treat the copies as
                # best-effort, so a missing or read-only ./output no longer
                # aborts the request before the result is shown.
                os.makedirs("./output", exist_ok=True)
                img_input.save(f"./output/{now}-input.jpg")
                img_output.convert("RGB").save(f"./output/{now}-output.jpg")
            except OSError as save_error:
                print(f"Could not save images to ./output: {save_error}")

            st.write("AI has finished the job!")
            st.image(img_output)

            uploaded_name = os.path.splitext(uploaded_file.name)[0]
            image_download_button(
                pil_image=img_output,
                filename=uploaded_name,
                fmt="jpg",
                label="Download Image"
            )