|
import streamlit as st |
|
import requests |
|
import os |
|
import sys |
|
from PIL import Image |
|
import io |
|
import time |
|
from pathlib import Path |
|
|
|
|
|
import collections
import collections.abc

# Python 3.10 dropped these ABC aliases from the top-level ``collections``
# namespace. Some older dependencies still look them up there, so re-export
# any that are missing from ``collections.abc``.
_LEGACY_ABC_NAMES = ('Sized', 'Iterable', 'Mapping', 'MutableMapping', 'Sequence', 'MutableSequence')
_missing_abc_names = [_abc_name for _abc_name in _LEGACY_ABC_NAMES
                      if not hasattr(collections, _abc_name)]
for _abc_name in _missing_abc_names:
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
|
|
|
|
|
# Make the bundled DeOldify checkout importable. Streamlit re-executes this
# script on every widget interaction, so guard against appending the same
# entry to sys.path over and over.
_DEOLDIFY_DIR = './DeOldify'
if _DEOLDIFY_DIR not in sys.path:
    sys.path.append(_DEOLDIFY_DIR)

os.makedirs('models', exist_ok=True)

# Expose each DeOldify generator weights file under ./models via a symlink.
for _weights_name in ('ColorizeArtistic_gen.pth',
                      'ColorizeStable_gen.pth',
                      'ColorizeVideo_gen.pth'):
    _link_path = f'models/{_weights_name}'
    # lexists (not exists): os.path.exists is False for a dangling link left
    # by an earlier run, and re-creating it would raise FileExistsError.
    if not os.path.lexists(_link_path):
        os.symlink(os.path.abspath(f'./DeOldify/models/{_weights_name}'), _link_path)
|
|
|
|
|
|
|
# Base URL of the colorization HTTP API that this front end calls.
API_URL = "http://localhost:8000"

# Browser-tab title/icon and a wide layout for the whole page.
_page_settings = dict(
    page_title="Image Colorization App",
    page_icon="🎨",
    layout="wide",
)
st.set_page_config(**_page_settings)

st.title("Black & White Image Colorization")

_intro_markdown = """
Turn your black and white photos into colorized versions using DeOldify technology.
Upload an image to get started!
"""
st.markdown(_intro_markdown)
|
|
|
|
|
# Main-area uploader; None until the user picks a file.
uploaded_file = st.file_uploader(
    "Choose a black and white image...", type=["jpg", "jpeg", "png"])

# Sidebar controls. Their values (model_type, render_factor,
# use_multiple_renders and, when enabled, min_factor/max_factor/step_size)
# are read by the colorization code further down the script.
with st.sidebar:
    st.header("Colorization Options")

    model_type = st.radio(
        "Select Colorization Model",
        options=["Artistic", "Stable"],
        index=0,
        help="Artistic provides more vibrant colors, Stable provides more realistic colors"
    )

    # Used for single renders only; the multi-render request sends the
    # min/max/step values below instead.
    render_factor = st.slider(
        "Render Factor",
        min_value=5,
        max_value=50,
        value=35,
        step=1,
        help="Higher values give better quality but take longer. Recommend 35 for artistic, 20 for stable."
    )

    st.subheader("Generate Multiple Renders")
    use_multiple_renders = st.checkbox(
        "Create multiple renders with different factors", value=False)

    if use_multiple_renders:
        min_factor = st.slider("Minimum Render Factor", 5, 40, 10, 5)
        # The maximum's lower bound tracks the chosen minimum. Clamp the
        # default so it never falls below that bound: with min_factor=40 the
        # old fixed default of 40 sat below min_value=45 and Streamlit raised.
        max_factor = st.slider("Maximum Render Factor",
                               min_factor + 5, 50, max(40, min_factor + 5), 5)
        step_size = st.slider("Step Size", 1, 10, 5, 1)
|
|
|
|
|
# (connect, read) timeouts in seconds for calls to the colorization API.
# The colorize POST may legitimately run for a long time (several renders in
# one request), so once connected it is never cut off. Fetching an
# already-rendered image should be quick.
_COLORIZE_TIMEOUT = (10, None)
_IMAGE_TIMEOUT = (10, 60)

# st.session_state key for the last successful colorization response.
# Streamlit reruns the script on every widget change and st.button is only
# True on the click's own run, so results must be persisted here. Otherwise
# moving the render-factor slider would make them disappear.
_RESULT_STATE_KEY = "colorize_result"


def _fetch_colorized_image(output_path):
    """Return the raw bytes of a colorized image served by the API.

    Raises requests.HTTPError on a non-2xx response, so a missing image is
    reported as an HTTP error instead of an unreadable-image error from PIL.
    """
    response = requests.get(f"{API_URL}/image/{output_path}", timeout=_IMAGE_TIMEOUT)
    response.raise_for_status()
    return response.content


def _show_colorized(output_path, factor, file_name):
    """Show one colorized image, labelled with its render factor, plus a download button."""
    st.subheader(f"Colorized (Render Factor: {factor})")
    colorized_img = _fetch_colorized_image(output_path)
    st.image(Image.open(io.BytesIO(colorized_img)), use_column_width=True)
    st.download_button(
        label="Download Colorized Image",
        data=colorized_img,
        file_name=file_name,
        mime="image/jpeg"
    )


if uploaded_file is not None:
    file_bytes = uploaded_file.getvalue()

    # Identifies the current upload. A stored result produced for a
    # different image is stale and is dropped.
    file_signature = (uploaded_file.name, hash(file_bytes))
    stored = st.session_state.get(_RESULT_STATE_KEY)
    if stored is not None and stored["file"] != file_signature:
        del st.session_state[_RESULT_STATE_KEY]
        stored = None

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Original Image")
        image = Image.open(uploaded_file)
        st.image(image, use_column_width=True)

    if st.button("Colorize Image"):
        # A new request replaces any previous result, even if it fails.
        st.session_state.pop(_RESULT_STATE_KEY, None)
        stored = None

        artistic_param = model_type == "Artistic"
        if use_multiple_renders:
            endpoint = "colorize_multiple"
            params = {
                "min_render_factor": min_factor,
                "max_render_factor": max_factor,
                "step": step_size,
                "artistic": artistic_param
            }
        else:
            endpoint = "colorize"
            params = {
                "render_factor": render_factor,
                "artistic": artistic_param
            }
        # Send the upload's real file name and MIME type; previously every
        # upload (including PNGs) was labelled image.jpg / image/jpeg.
        files = {
            "file": (uploaded_file.name, file_bytes, uploaded_file.type or "image/jpeg")}

        with st.spinner("Colorizing your image... Please wait."):
            try:
                response = requests.post(
                    f"{API_URL}/{endpoint}", files=files, params=params,
                    timeout=_COLORIZE_TIMEOUT)
                if response.status_code == 200:
                    stored = {
                        "file": file_signature,
                        # Remember the mode used, so toggling the checkbox
                        # later does not change how this result is shown.
                        "multiple": use_multiple_renders,
                        "result": response.json(),
                    }
                    st.session_state[_RESULT_STATE_KEY] = stored
                    if use_multiple_renders:
                        st.success("Multiple renders completed!")
                else:
                    st.error(f"Error: {response.text}")
            except Exception as e:
                # Top-level UI boundary: report any failure to the user.
                st.error(f"An error occurred: {str(e)}")

    if stored is not None:
        result = stored["result"]
        try:
            if stored["multiple"]:
                st.subheader("Select Render Factor")
                selected_factor = st.select_slider(
                    "Choose the render factor that looks best:",
                    options=result["render_factors"]
                )
                # output_paths is assumed to be parallel to render_factors.
                index = result["render_factors"].index(selected_factor)
                with col2:
                    _show_colorized(result["output_paths"][index], selected_factor,
                                    f"colorized_rf{selected_factor}.jpg")
            else:
                with col2:
                    _show_colorized(result["output_path"], result["render_factor"],
                                    "colorized.jpg")
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
|
|
|
|
|
# Page footer: a horizontal rule followed by the attribution line.
for _footer_line in ("---", "Powered by DeOldify - Image Colorization Project"):
    st.markdown(_footer_line)
|
|