# Import general purpose libraries import os, re, time import streamlit as st import PIL import cv2 import numpy as np import uuid from zipfile import ZipFile, ZIP_DEFLATED from io import BytesIO from random import randint # Import util functions from deoldify # NOTE: This must be the first call in order to work properly! from deoldify import device from deoldify.device_id import DeviceId #choices: CPU, GPU0...GPU7 device.set(device=DeviceId.CPU) from deoldify.visualize import * # Import util functions from app_utils from app_utils import get_model_bin SESSION_STATE_VARIABLES = [ 'model_folder','max_img_size','uploaded_file_key','uploaded_files' ] for i in SESSION_STATE_VARIABLES: if i not in st.session_state: st.session_state[i] = None #### SET INPUT PARAMS ########### if not st.session_state.model_folder: st.session_state.model_folder = 'models/' if not st.session_state.max_img_size: st.session_state.max_img_size = 800 ################################ @st.cache(allow_output_mutation=True, show_spinner=False) def load_model(model_dir, option): if option.lower() == 'artistic': model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth' get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth")) colorizer = get_image_colorizer(artistic=True) elif option.lower() == 'stable': model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0" get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth")) colorizer = get_image_colorizer(artistic=False) return colorizer def resize_img(input_img, max_size): img = input_img.copy() img_height, img_width = img.shape[0],img.shape[1] if max(img_height, img_width) > max_size: if img_height > img_width: new_width = img_width*(max_size/img_height) new_height = max_size resized_img = cv2.resize(img,(int(new_width), int(new_height))) return resized_img elif img_height <= img_width: new_width = img_height*(max_size/img_width) new_height = max_size resized_img = cv2.resize(img,(int(new_width), int(new_height))) return resized_img return img def get_image_download_link(img, filename, button_text): button_uuid = str(uuid.uuid4()).replace('-', '') button_id = re.sub('\d+', '', button_uuid) buffered = BytesIO() img.save(buffered, format="JPEG") img_str = base64.b64encode(buffered.getvalue()).decode() return get_button_html_code(img_str, filename, 'txt', button_id, button_text) def get_button_html_code(data_str, filename, filetype, button_id, button_txt='Download file'): custom_css = f""" """ href = custom_css + f'{button_txt}' return href def display_single_image(uploaded_file, img_size=800): st_title_message.markdown("**Processing your image, please wait** ⌛") img_name = uploaded_file.name # Open the image pil_img = PIL.Image.open(uploaded_file) img_rgb = np.array(pil_img) resized_img_rgb = resize_img(img_rgb, img_size) resized_pil_img = PIL.Image.fromarray(resized_img_rgb) # Send the image to the model output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False) # Plot images st_input_img.image(resized_pil_img, 'Input image', use_column_width=True) st_output_img.image(output_pil_img, 'Output image', use_column_width=True) # Show download button st_download_button.markdown(get_image_download_link(output_pil_img, img_name, 'Download Image'), unsafe_allow_html=True) # Reset the message st_title_message.markdown("**To begin, please upload an image** 👇") def process_multiple_images(uploaded_files, img_size=800): num_imgs = len(uploaded_files) output_images_list = [] img_names_list = [] idx = 1 st_progress_bar.progress(0) for idx, uploaded_file in enumerate(uploaded_files, start=1): st_title_message.markdown("**Processing image {}/{}. Please wait** ⌛".format(idx, num_imgs)) img_name = uploaded_file.name img_type = uploaded_file.type # Open the image pil_img = PIL.Image.open(uploaded_file) img_rgb = np.array(pil_img) resized_img_rgb = resize_img(img_rgb, img_size) resized_pil_img = PIL.Image.fromarray(resized_img_rgb) # Send the image to the model output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False) output_images_list.append(output_pil_img) img_names_list.append(img_name.split('.')[0]) percent = int((idx / num_imgs)*100) st_progress_bar.progress(percent) # Zip output files zip_path = 'processed_images.zip' zip_buf = zip_multiple_images(output_images_list, img_names_list, zip_path) st_download_button.download_button( label='Download ZIP file', data=zip_buf.read(), file_name=zip_path, mime="application/zip" ) # Show message st_title_message.markdown("**Images are ready for download** 💾") def zip_multiple_images(pil_images_list, img_names_list, dest_path): # Create zip file on memory zip_buf = BytesIO() with ZipFile(zip_buf, 'w', ZIP_DEFLATED) as zipObj: for pil_img, img_name in zip(pil_images_list, img_names_list): with BytesIO() as output: # Save image in memory pil_img.save(output, format="PNG") # Read data contents = output.getvalue() # Write it to zip file zipObj.writestr(img_name+".png", contents) zip_buf.seek(0) return zip_buf ########################### ###### STREAMLIT CODE ##### ########################### # General configuration # st.set_page_config(layout="centered") st.set_page_config(layout="wide") st.set_option('deprecation.showfileUploaderEncoding', False) st.markdown('''