Spaces:
Runtime error
Runtime error
#importing the libraries | |
import os, sys, re | |
import streamlit as st | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
import uuid | |
# Import torch libraries | |
import fastai | |
import torch | |
# Import util functions from app_utils | |
from app_utils import download | |
from app_utils import generate_random_filename | |
from app_utils import clean_me | |
from app_utils import clean_all | |
from app_utils import create_directory | |
from app_utils import get_model_bin | |
from app_utils import convertToJPG | |
# Import util functions from deoldify | |
# NOTE: This must be the first call in order to work properly! | |
from deoldify import device | |
from deoldify.device_id import DeviceId | |
#choices: CPU, GPU0...GPU7 | |
device.set(device=DeviceId.CPU) | |
from deoldify.visualize import * | |
####### INPUT PARAMS ###########
# Directory passed to load_model as `model_dir`; the weight-file paths handed
# to get_model_bin are built under it.
model_folder = 'models/'
# Longest side, in pixels, that uploaded images are downscaled to (via
# resize_img) before being colorized.
max_img_size = 800
################################
def load_model(model_dir, option):
    """Fetch the DeOldify weights for ``option`` (via get_model_bin) and build a colorizer.

    Args:
        model_dir: Directory under which the weights file path is built.
        option: Colorizer mode, ``'artistic'`` or ``'stable'`` (case-insensitive).

    Returns:
        The colorizer object returned by deoldify's ``get_image_colorizer``.

    Raises:
        ValueError: If ``option`` is not one of the supported modes.
    """
    mode = option.lower()
    if mode == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        weights_file = "ColorizeArtistic_gen.pth"
        artistic = True
    elif mode == 'stable':
        # dl=1 makes Dropbox serve the raw file; dl=0 returns an HTML preview page.
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=1"
        weights_file = "ColorizeStable_gen.pth"
        artistic = False
    else:
        # Fail loudly instead of reaching the return with no colorizer bound.
        raise ValueError(
            f"Unknown colorizer mode {option!r}; expected 'Artistic' or 'Stable'"
        )
    get_model_bin(model_url, os.path.join(model_dir, weights_file))
    return get_image_colorizer(artistic=artistic)
def resize_img(input_img, max_size):
    """Downscale an image so that its longest side is at most ``max_size`` pixels.

    The aspect ratio is preserved. An image that already fits is returned as a
    copy, so the result never aliases ``input_img``.

    Args:
        input_img: Image as a NumPy array indexed height-first
            (``shape[0]`` = rows, ``shape[1]`` = columns), e.g. ``np.array(pil_img)``.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A new NumPy array: a copy of ``input_img`` or a resized version of it.
    """
    img_height, img_width = input_img.shape[0], input_img.shape[1]
    longest_side = max(img_height, img_width)
    if longest_side <= max_size:
        return input_img.copy()
    # Scale both sides by the same factor so the longest one becomes max_size.
    scale = max_size / longest_side
    # Clamp to 1 px so extreme aspect ratios never produce a zero-sized target,
    # which cv2.resize rejects.
    new_width = max(1, int(round(img_width * scale)))
    new_height = max(1, int(round(img_height * scale)))
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    return cv2.resize(input_img, (new_width, new_height))
def get_image_download_link(img, filename, text):
    """Build an HTML snippet with a styled link that downloads ``img`` as a JPEG.

    The image is embedded directly in the link as a base64 ``data:`` URI.

    Args:
        img: PIL-style image to encode. Modes JPEG cannot store (e.g. RGBA,
            palette) are converted to RGB first.
        filename: Suggested download file name (HTML-escaped before use, since
            it may come from an uploaded file).
        text: Visible link text (HTML-escaped before use).

    Returns:
        An HTML string consisting of a ``<style>`` block followed by an ``<a>``
        element, meant to be rendered with ``unsafe_allow_html=True``.
    """
    import base64
    import html
    from io import BytesIO

    # CSS ids must not start with a digit, so strip every digit from the uuid hex.
    button_id = re.sub(r'\d+', '', uuid.uuid4().hex)
    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                padding: 0.25em 0.38em;
                position: relative;
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
                }}
        </style> """
    # JPEG has no alpha channel or palette; saving such modes raises OSError.
    if img.mode not in ('RGB', 'L'):
        img = img.convert('RGB')
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    # filename/text can originate from user uploads; escape them because the
    # result is rendered as raw HTML.
    href = custom_css + (
        f'<a href="data:image/jpeg;base64,{img_str}" id="{button_id}" '
        f'download="{html.escape(filename)}">{html.escape(text)}</a>'
    )
    return href
# General configuration
st.set_page_config(layout="centered")
# NOTE: the former st.set_option('deprecation.showfileUploaderEncoding', False)
# call was removed; current Streamlit releases no longer know that option and
# raise StreamlitAPIException at startup.
st.markdown('''
<style>
.uploadedFile {display: none}
</style>''',
            unsafe_allow_html=True)

# Main window configuration
st.title("Black and white colorizer")
st.markdown("This app puts color into your black and white pictures")
title_message = st.empty()
# NOTE(review): the trailing characters in these status messages look like
# mis-decoded emoji; confirm the intended symbols and fix the file encoding.
title_message.markdown("**Model loading, please wait** β")

# Sidebar
color_option = st.sidebar.selectbox('Select colorizer mode',
                                    ('Artistic', 'Stable'))

# Load models
# NOTE(review): Streamlit reruns this script on every interaction, so the model
# is rebuilt each time; consider Streamlit's resource caching if available.
colorizer = load_model(model_folder, color_option)
title_message.markdown("**To begin, please upload an image** π")

# Choose your own image
uploaded_file = st.file_uploader("Upload a black and white photo", type=['png', 'jpg', 'jpeg'])
input_img_pos = st.empty()
output_img_pos = st.empty()

if uploaded_file is not None:
    img_name = uploaded_file.name
    # Normalise to 3-channel RGB: np.array() on a palette ('P') image yields raw
    # palette indices rather than colours, and RGBA/LA uploads carry an alpha
    # channel the rest of the pipeline does not expect.
    pil_img = Image.open(uploaded_file).convert('RGB')
    img_rgb = np.array(pil_img)
    resized_img_rgb = resize_img(img_rgb, max_img_size)
    resized_pil_img = Image.fromarray(resized_img_rgb)
    title_message.markdown("**Processing your image, please wait** β")
    output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
    title_message.markdown("**To begin, please upload an image** π")
    # Plot images
    input_img_pos.image(resized_pil_img, 'Input image', use_column_width=True)
    output_img_pos.image(output_pil_img, 'Output image', use_column_width=True)
    st.markdown(get_image_download_link(output_pil_img, img_name, 'Download '+img_name), unsafe_allow_html=True)