#importing the libraries
import os, sys, re
import streamlit as st
import PIL
from PIL import Image
import cv2
import numpy as np
import uuid
# Import torch libraries
import fastai
import torch
# Import util functions from app_utils
from app_utils import download
from app_utils import generate_random_filename
from app_utils import clean_me
from app_utils import clean_all
from app_utils import create_directory
from app_utils import get_model_bin
from app_utils import convertToJPG
# Import util functions from deoldify
# NOTE: Selecting the device must be the first deoldify call in order to work
# properly, which is why device.set() runs before the deoldify.visualize
# import below instead of all imports being grouped together.
from deoldify import device
from deoldify.device_id import DeviceId
# Choices: CPU, GPU0...GPU7
device.set(device=DeviceId.CPU)
from deoldify.visualize import *
####### INPUT PARAMS ###########
# Folder for the model weight files (presumably passed to load_model as
# model_dir -- confirm against the caller).
model_folder = 'models/'
# Size limit in pixels (presumably passed to resize_img as max_size --
# confirm against the caller).
max_img_size = 800
################################
@st.cache(allow_output_mutation=True)
def load_model(model_dir, option):
    """Fetch the generator weights for *option* and build the colorizer.

    Args:
        model_dir: Directory that is joined with the weight file name and
            handed to ``get_model_bin`` as the target path.
        option: Model flavour, ``'artistic'`` or ``'stable'``
            (case-insensitive).

    Returns:
        The object returned by ``get_image_colorizer`` for the chosen flavour.

    Raises:
        ValueError: If *option* is not a supported flavour. Previously this
            surfaced as an UnboundLocalError on the ``return`` statement.
    """
    # flavour -> (weights URL, weight file name, `artistic` flag)
    model_specs = {
        'artistic': (
            'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth',
            "ColorizeArtistic_gen.pth",
            True,
        ),
        'stable': (
            "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0",
            "ColorizeStable_gen.pth",
            False,
        ),
    }

    flavour = option.lower()
    if flavour not in model_specs:
        supported = ", ".join(repr(name) for name in model_specs)
        raise ValueError(
            f"Unknown model option {option!r}; expected one of: {supported}"
        )

    model_url, weights_name, artistic = model_specs[flavour]
    get_model_bin(model_url, os.path.join(model_dir, weights_name))
    return get_image_colorizer(artistic=artistic)
def resize_img(input_img, max_size):
    """Return a copy of *input_img* whose longer side is at most *max_size*.

    The aspect ratio is preserved. Images that already fit are returned as an
    unmodified copy.

    Args:
        input_img: Image array; ``shape[0]`` is read as the height and
            ``shape[1]`` as the width.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A copy of the image, downscaled with ``cv2.resize`` when its longer
        side exceeds *max_size*.
    """
    img = input_img.copy()
    img_height, img_width = img.shape[0], img.shape[1]

    if max(img_height, img_width) <= max_size:
        return img

    if img_height > img_width:
        # Portrait: the height is the limiting side.
        new_height = max_size
        new_width = img_width * (max_size / img_height)
    else:
        # Landscape or square: the width is the limiting side. The earlier
        # code assigned the scaled height to the width and max_size to the
        # height, which transposed and distorted wide images.
        new_width = max_size
        new_height = img_height * (max_size / img_width)

    # cv2.resize expects dsize as (width, height). Clamp to 1 px so extreme
    # aspect ratios never request a zero-sized output.
    target_size = (max(1, int(new_width)), max(1, int(new_height)))
    return cv2.resize(img, target_size)
def get_image_download_link(img, filename, text):
    """Build an HTML snippet that offers *img* for download as a JPEG.

    The image is JPEG-encoded in memory, base64-encoded and embedded as a
    data URI in an ``<a download=...>`` element, preceded by a small
    ``<style>`` block that renders the link as a button.

    Args:
        img: Image object exposing ``mode``, ``convert()`` and
            ``save(fp, format=...)`` (a PIL image in this file).
        filename: File name suggested to the browser for the download.
        text: Inner HTML of the link. It is inserted verbatim, so callers
            must not pass untrusted markup here.

    Returns:
        The HTML string (style block plus anchor element).
    """
    # Imported locally so the function does not depend on names leaking in
    # through the module's wildcard import.
    import base64
    import html
    from io import BytesIO

    # Random per-link id; digits are stripped so the id never starts with a
    # digit, which would make it awkward to use as a CSS selector.
    button_uuid = str(uuid.uuid4()).replace('-', '')
    button_id = re.sub(r'\d+', '', button_uuid)

    custom_css = f"""
<style>
    #{button_id} {{
        background-color: rgb(255, 255, 255);
        color: rgb(38, 39, 48);
        padding: 0.25em 0.38em;
        text-decoration: none;
        border: 1px solid rgb(230, 234, 241);
        border-radius: 4px;
    }}
    #{button_id}:hover {{
        border-color: rgb(246, 51, 102);
        color: rgb(246, 51, 102);
    }}
</style>
"""

    # JPEG has no alpha channel or palette; convert such images first so
    # save() does not fail on them.
    if img.mode in ("RGBA", "LA", "P", "PA"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # The file name lands inside an attribute value, so escape quotes.
    safe_name = html.escape(str(filename), quote=True)
    href = custom_css + (
        f'<a download="{safe_name}" id="{button_id}" '
        f'href="data:image/jpeg;base64,{img_str}">{text}</a>'
    )
    return href
# General page configuration.
# Alternative layout, currently disabled: st.set_page_config(layout="centered")
st.set_page_config(layout="wide")
# Set the 'deprecation.showfileUploaderEncoding' option to False (presumably
# this silences the file-uploader encoding deprecation warning -- confirm
# against the installed Streamlit version).
st.set_option('deprecation.showfileUploaderEncoding', False)
st.markdown('''