# Repository page metadata captured with the file (not part of the program):
# Daniel Verdu - "first commit in hf_spaces" - 9e08039 - raw / history / blame - 5.37 kB
#importing the libraries
import os, sys, re
import streamlit as st
from PIL import Image
import cv2
import numpy as np
import uuid
# Import torch libraries
import fastai
import torch
# Import util functions from app_utils
from app_utils import download
from app_utils import generate_random_filename
from app_utils import clean_me
from app_utils import clean_all
from app_utils import create_directory
from app_utils import get_model_bin
from app_utils import convertToJPG
# Import util functions from deoldify
# NOTE: This must be the first call in order to work properly!
from deoldify import device
from deoldify.device_id import DeviceId
#choices: CPU, GPU0...GPU7
device.set(device=DeviceId.CPU)
from deoldify.visualize import *
####### INPUT PARAMS ###########
# Folder passed to load_model() as the destination directory for the
# downloaded model weights.
model_folder = 'models/'
# Size limit (in pixels) passed to resize_img(); images whose larger side
# exceeds it are scaled down before colorizing.
max_img_size = 800
################################
@st.cache(allow_output_mutation=True)
def load_model(model_dir, option):
    """Fetch the weights for the requested colorizer mode and build the colorizer.

    Args:
        model_dir: Directory in which the downloaded ``.pth`` file is stored.
        option: Colorizer mode, ``'artistic'`` or ``'stable'``
            (case-insensitive).

    Returns:
        The object returned by ``get_image_colorizer`` for the chosen mode.

    Raises:
        ValueError: If ``option`` is not a supported mode. Previously this
            case fell through and failed with ``UnboundLocalError``.
    """
    # mode -> (weights URL, weights file name, value for ``artistic=``)
    model_specs = {
        'artistic': (
            'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth',
            "ColorizeArtistic_gen.pth",
            True,
        ),
        # NOTE(review): Dropbox share links ending in ``?dl=0`` usually serve
        # an HTML preview page rather than the file; confirm get_model_bin
        # really ends up with the binary (``?dl=1`` may be required).
        'stable': (
            "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0",
            "ColorizeStable_gen.pth",
            False,
        ),
    }

    mode = option.lower()
    if mode not in model_specs:
        supported = ', '.join(sorted(model_specs))
        raise ValueError(
            f"Unknown colorizer mode {option!r}; expected one of: {supported}"
        )

    model_url, weights_name, artistic = model_specs[mode]
    get_model_bin(model_url, os.path.join(model_dir, weights_name))
    return get_image_colorizer(artistic=artistic)
def resize_img(input_img, max_size):
    """Scale an image down so that its longer side equals ``max_size``.

    The aspect ratio is preserved. Images whose longer side is already
    ``<= max_size`` are not resized.

    Args:
        input_img: Image array whose first two axes are (height, width).
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A new array: the resized image, or a copy of ``input_img`` when no
        resizing is needed. The input array is never modified.
    """
    img_height, img_width = input_img.shape[0], input_img.shape[1]

    if max(img_height, img_width) <= max_size:
        # Already small enough; hand back a copy so the caller's array is
        # never aliased by the result.
        return input_img.copy()

    if img_height > img_width:
        # Portrait: height is the limiting side.
        new_height = max_size
        new_width = img_width * (max_size / img_height)
    else:
        # Landscape or square: width is the limiting side. (The previous
        # code assigned the scaled value to the width and max_size to the
        # height, which distorted wide images.)
        new_width = max_size
        new_height = img_height * (max_size / img_width)

    # cv2.resize expects the target size as (width, height); clamp to 1 so
    # extreme aspect ratios cannot truncate to a 0-pixel side.
    target_size = (max(1, int(new_width)), max(1, int(new_height)))
    return cv2.resize(input_img, target_size)
def get_image_download_link(img, filename, text):
    """Build an HTML snippet with a styled link that downloads ``img`` as JPEG.

    The image is JPEG-encoded in memory and embedded in the link as a
    base64 ``data:`` URI, preceded by a ``<style>`` block for the link.

    Args:
        img: Image object exposing ``mode``, ``convert()`` and
            ``save(fp, format=...)`` (e.g. a PIL image).
        filename: File name suggested to the browser for the download.
        text: Visible link text.

    Returns:
        A string of HTML (CSS plus an ``<a>`` element).
    """
    # BytesIO and base64 are not imported explicitly at file level
    # (presumably they arrive through a star import); import them here so
    # this function does not rely on that.
    import base64
    import html
    from io import BytesIO

    # Strip digits from a random UUID so the id is a plain alphabetic CSS
    # identifier; the fixed prefix keeps it non-empty in every case.
    button_id = 'btn' + re.sub(r'\d+', '', uuid.uuid4().hex)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                padding: 0.25em 0.38em;
                position: relative;
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    # JPEG cannot store alpha or palette images; normalise those to RGB so
    # the save below does not fail.
    if img.mode not in ('L', 'RGB', 'CMYK'):
        img = img.convert('RGB')

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # filename/text may originate from an uploaded file's name, so escape
    # them before placing them into markup.
    safe_filename = html.escape(filename, quote=True)
    safe_text = html.escape(text)

    link = (
        f'<a href="data:image/jpeg;base64,{img_str}" id="{button_id}" '
        f'download="{safe_filename}">{safe_text}</a>'
    )
    return custom_css + link
# General configuration
st.set_page_config(layout="centered")
st.set_option('deprecation.showfileUploaderEncoding', False)
# Inject CSS that hides elements carrying the ``uploadedFile`` class.
# The block must be terminated with </style>; it was previously left open
# by a second <style> tag.
st.markdown(
    '''
<style>
.uploadedFile {display: none}
</style>''',
    unsafe_allow_html=True,
)
# Main window configuration
st.title("Black and white colorizer")
st.markdown("This app puts color into your black and white pictures")

# Placeholder whose text is swapped as the app moves through its states.
title_message = st.empty()
title_message.markdown("**Model loading, please wait** ⌛")

# Sidebar
color_option = st.sidebar.selectbox('Select colorizer mode',
                                    ('Artistic', 'Stable'))

# st.sidebar.title('Model parameters')
# det_conf_thres = st.sidebar.slider("Detector confidence threshold", 0.1, 0.9, value=0.5, step=0.1)
# det_nms_thres = st.sidebar.slider("Non-maximum supression IoU", 0.1, 0.9, value=0.4, step=0.1)

# Load models
colorizer = load_model(model_folder, color_option)
title_message.markdown("**To begin, please upload an image** 👇")

# Choose your own image
uploaded_file = st.file_uploader("Upload a black and white photo", type=['png', 'jpg', 'jpeg'])

# show = st.image(use_column_width='auto')
# Placeholders filled with the input and output images once a file is processed.
input_img_pos = st.empty()
output_img_pos = st.empty()
if uploaded_file is not None:
    img_name = uploaded_file.name

    # Normalise to 3-channel RGB first: palette ('P') or alpha ('RGBA')
    # uploads would otherwise yield index/4-channel arrays downstream.
    pil_img = Image.open(uploaded_file).convert('RGB')
    img_rgb = np.array(pil_img)

    # Shrink large uploads before colorizing.
    resized_img_rgb = resize_img(img_rgb, max_img_size)
    resized_pil_img = Image.fromarray(resized_img_rgb)

    title_message.markdown("**Processing your image, please wait** ⌛")
    output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
    title_message.markdown("**To begin, please upload an image** 👇")

    # Plot images
    input_img_pos.image(resized_pil_img, 'Input image', use_column_width=True)
    output_img_pos.image(output_pil_img, 'Output image', use_column_width=True)

    # The download link always carries JPEG-encoded bytes, so offer the
    # file under a .jpg name even when the upload was e.g. a .png.
    download_name = os.path.splitext(img_name)[0] + '.jpg'
    st.markdown(
        get_image_download_link(output_pil_img, download_name, 'Download ' + download_name),
        unsafe_allow_html=True,
    )