#importing the libraries import os, sys, re import streamlit as st import PIL from PIL import Image import cv2 import numpy as np import uuid import ssl ssl._create_default_https_context = ssl._create_unverified_context # Import torch libraries import fastai import torch # Import util functions from app_utils from app_utils import download from app_utils import generate_random_filename from app_utils import clean_me from app_utils import clean_all from app_utils import get_model_bin from app_utils import convertToJPG # Import util functions from deoldify # NOTE: This must be the first call in order to work properly! from deoldify import device from deoldify.device_id import DeviceId #choices: CPU, GPU0...GPU7 device.set(device=DeviceId.CPU) from deoldify.visualize import * ####### INPUT PARAMS ########### model_folder = 'models/' max_img_size = 800 ################################ @st.cache(allow_output_mutation=True) def load_model(model_dir, option): if option.lower() == 'artistic': model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth' get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth")) colorizer = get_image_colorizer(artistic=True) elif option.lower() == 'stable': model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0" get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth")) colorizer = get_image_colorizer(artistic=False) return colorizer def resize_img(input_img, max_size): img = input_img.copy() img_height, img_width = img.shape[0],img.shape[1] if max(img_height, img_width) > max_size: if img_height > img_width: new_width = img_width*(max_size/img_height) new_height = max_size resized_img = cv2.resize(img,(int(new_width), int(new_height))) return resized_img elif img_height <= img_width: new_width = img_height*(max_size/img_width) new_height = max_size resized_img = cv2.resize(img,(int(new_width), int(new_height))) return resized_img return img def get_image_download_link(img,filename,text): button_uuid = str(uuid.uuid4()).replace('-', '') button_id = re.sub('\d+', '', button_uuid) custom_css = f""" """ buffered = BytesIO() img.save(buffered, format="JPEG") img_str = base64.b64encode(buffered.getvalue()).decode() href = custom_css + f'{text}' return href # General configuration # st.set_page_config(layout="centered") st.set_page_config(layout="wide") st.set_option('deprecation.showfileUploaderEncoding', False) st.markdown('''