|
import os.path |
|
|
|
import streamlit as st |
|
from models.deep_colorization.colorizers import * |
|
import cv2 |
|
from PIL import Image |
|
import tempfile |
|
import moviepy.editor as mp |
|
import time |
|
from tqdm import tqdm |
|
|
|
|
|
def format_time(seconds: float) -> str:
    """Format a duration in seconds as a human-readable string.

    Fractional seconds are truncated, and only units up to the largest one
    needed are shown, e.g. ``format_time(65.5) == "1 minutes and 5 seconds"``.

    Args:
        seconds: Duration in seconds (may be fractional).

    Returns:
        A string such as "42 seconds", "3 minutes and 7 seconds",
        "2 hours, 0 minutes, and 5 seconds" or
        "1 days, 1 hours, 1 minutes, and 1 seconds".
    """
    # Truncate once up front so every unit below is an int; floor-dividing the
    # raw float would render values like "1.0 minutes".
    total = int(seconds)
    if total < 60:
        return f"{total} seconds"

    minutes, secs = divmod(total, 60)
    if total < 3600:
        return f"{minutes} minutes and {secs} seconds"

    hours, minutes = divmod(minutes, 60)
    if total < 86400:
        return f"{hours} hours, {minutes} minutes, and {secs} seconds"

    days, hours = divmod(hours, 24)
    return f"{days} days, {hours} hours, {minutes} minutes, and {secs} seconds"
|
|
|
|
|
|
|
def colorize_frame(frame, colorizer) -> np.ndarray:
    """Colorize a single video frame.

    Args:
        frame: A frame as returned by ``cv2.VideoCapture.read()``; OpenCV
            delivers these in BGR channel order.
        colorizer: A colorizer model (e.g. ``eccv16(...)`` or
            ``siggraph17(...)`` in eval mode). It is called on the resized
            tensor produced by ``preprocess_img``.

    Returns:
        The ``postprocess_tens`` result for the frame. The video pipeline in
        this file scales it by 255 and treats it as RGB.
    """
    import torch

    # The rest of the pipeline is RGB: the image path feeds PIL-loaded RGB
    # arrays to preprocess_img, and the caller converts this function's output
    # with COLOR_RGB2BGR. OpenCV frames are BGR, so convert them first or red
    # and blue are swapped in the model input.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    tens_l_orig, tens_l_rs = preprocess_img(frame_rgb, HW=(256, 256))
    # Pure inference: skip autograd bookkeeping that would otherwise be built
    # for every frame of the video.
    with torch.no_grad():
        predicted = colorizer(tens_l_rs).cpu()
    return postprocess_tens(tens_l_orig, predicted)
|
|
|
# --- Page header: heading text on the left, app logo on the right ----------
image = Image.open(r'img/streamlit.png')

title_col, logo_col = st.columns([0.8, 0.2])

# Defines the ".font" CSS class; the sidebar title below reuses it.
title_col.markdown(""" <style> .font {
font-size:35px ; font-family: 'Cooper Black'; color: #FF4B4B;}
</style> """, unsafe_allow_html=True)
title_col.markdown('<p class="font">Upload your photo or video here...</p>', unsafe_allow_html=True)

logo_col.image(image, width=100)

# --- Sidebar: app title and short description ------------------------------
st.sidebar.markdown('<p class="font">Color Revive App</p>', unsafe_allow_html=True)

about_section = st.sidebar.expander("About the App")
about_section.write("""
Use this simple app to colorize your black and white images and videos with state of the art models.
""")
|
|
|
|
|
IMAGE_EXTENSIONS = ('jpg', 'png', 'jpeg')
VIDEO_EXTENSIONS = ('mp4',)
MODEL_OPTIONS = ['Original', 'ECCV 16', 'SIGGRAPH 17']
# OpenCV reports 0 FPS when a container lacks frame-rate metadata; a
# VideoWriter needs a positive rate, so fall back to this value.
DEFAULT_FPS = 30.0


def _load_colorizer(model_name):
    """Return the pretrained colorizer for *model_name* in eval mode, or None for 'Original'."""
    if model_name == 'ECCV 16':
        return eccv16(pretrained=True).eval()
    if model_name == 'SIGGRAPH 17':
        return siggraph17(pretrained=True).eval()
    return None


def _colorize_image(image, colorizer):
    """Colorize a PIL image and return the ``postprocess_tens`` result."""
    # Normalize grayscale ('L'), palette ('P') and RGBA uploads to plain
    # 3-channel RGB so the model always receives the same input layout.
    img_rgb = np.asarray(image.convert('RGB'))
    tens_l_orig, tens_l_rs = preprocess_img(img_rgb, HW=(256, 256))
    return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())


def _write_colorized_frames(video, colorizer, out_path, fps):
    """Colorize every frame of *video* into an mp4v file at *out_path*.

    Frames are streamed straight to the writer rather than buffered in memory.
    Returns the number of frames written (0 if no frame could be decoded).
    """
    # CAP_PROP_FRAME_COUNT is only an estimate (it may even be 0), so it drives
    # the progress display but never limits how many frames are read.
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Separate placeholders: writing text into the progress placeholder would
    # replace the bar itself.
    progress_bar = st.progress(0)
    status = st.empty()
    writer = None
    frames_done = 0
    start_time = time.time()
    try:
        with tqdm(total=total_frames if total_frames > 0 else None,
                  unit='frame', desc="Progress") as pbar:
            while True:
                ret, frame = video.read()
                if not ret:
                    break

                colorized = colorize_frame(frame, colorizer)
                frame_rgb = np.clip(colorized * 255, 0, 255).astype(np.uint8)

                if writer is None:
                    # The colorized frame size is only known once one exists.
                    height, width = frame_rgb.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                    writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
                writer.write(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))

                frames_done += 1
                pbar.update(1)
                if total_frames > 0:
                    progress_bar.progress(min(frames_done / total_frames, 1.0))
                    frames_remaining = total_frames - frames_done
                    if frames_remaining > 0:
                        elapsed_time = time.time() - start_time
                        time_remaining = (frames_remaining / frames_done) * elapsed_time
                        status.text(f"Time Remaining: {format_time(time_remaining)}")
    finally:
        if writer is not None:
            writer.release()
        progress_bar.empty()
        status.empty()
    return frames_done


def _reencode_h264(src_path, dst_path):
    """Re-encode *src_path* to H.264 at *dst_path* (presumably so browsers can play the mp4v output)."""
    clip = mp.VideoFileClip(src_path)
    try:
        clip.write_videofile(dst_path, codec="libx264")
    finally:
        # Release moviepy's reader so the scratch directory can be removed.
        clip.close()


def _colorize_video(video_path, colorizer):
    """Colorize the video at *video_path*.

    Returns the H.264 MP4 bytes, or None if no frame could be decoded.
    """
    video = cv2.VideoCapture(video_path)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also catches NaN
            fps = DEFAULT_FPS
        # Per-run scratch directory: concurrent sessions no longer clobber
        # shared files in the CWD, and everything is cleaned up afterwards.
        with tempfile.TemporaryDirectory() as work_dir:
            raw_path = os.path.join(work_dir, "output.mp4")
            converted_path = os.path.join(work_dir, "converted_output.mp4")
            with st.spinner("Colorizing frames..."):
                frame_count = _write_colorized_frames(video, colorizer, raw_path, fps)
            if frame_count == 0:
                return None
            with st.spinner("Merging frames to video..."):
                _reencode_h264(raw_path, converted_path)
            with open(converted_path, "rb") as fh:
                return fh.read()
    finally:
        video.release()


uploaded_file = st.file_uploader("", type=[*IMAGE_EXTENSIONS, *VIDEO_EXTENSIONS])

if uploaded_file is not None:
    # splitext keeps the leading dot ('.jpg'); strip it so the comparisons
    # against the bare extensions below can match at all.
    file_extension = os.path.splitext(uploaded_file.name)[1].lower().lstrip('.')

    if file_extension in IMAGE_EXTENSIONS:
        image = Image.open(uploaded_file)

        col1, col2 = st.columns([0.5, 0.5])
        with col1:
            st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
            st.image(image, width=300)

        with col2:
            st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
            model_choice = st.sidebar.radio('Colorize your image with:', MODEL_OPTIONS)
            colorizer = _load_colorizer(model_choice)
            if colorizer is None:
                st.image(image, width=300)
            else:
                st.image(_colorize_image(image, colorizer), width=300)

    elif file_extension in VIDEO_EXTENSIONS:
        # getvalue() returns the whole upload regardless of the read position.
        video_bytes = uploaded_file.getvalue()

        col1, col2 = st.columns([0.5, 0.5])
        with col1:
            st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
            st.video(video_bytes)

        with col2:
            st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
            model_choice = st.sidebar.radio('Colorize your video with:', MODEL_OPTIONS)
            colorizer = _load_colorizer(model_choice)

            if colorizer is None:
                # Mirror the image branch: 'Original' shows the input unchanged.
                st.video(video_bytes)
            else:
                # Close the temp file before OpenCV opens it by path so all
                # bytes are flushed to disk (and reopening works on Windows).
                with tempfile.NamedTemporaryFile(suffix=f".{file_extension}",
                                                 delete=False) as temp_file:
                    temp_file.write(video_bytes)
                try:
                    colorized_video = _colorize_video(temp_file.name, colorizer)
                finally:
                    os.remove(temp_file.name)

                if colorized_video is None:
                    st.error("Could not read any frames from the uploaded video.")
                else:
                    st.video(colorized_video)
                    st.download_button(
                        label="Download Colorized Video",
                        data=colorized_video,
                        file_name="colorized_video.mp4",
                        mime="video/mp4",
                    )
|
|
|
|
|
# --- Sidebar feedback form -------------------------------------------------
# Empty title/markdown elements act as vertical spacers above the form.
st.sidebar.title(' ')
st.sidebar.markdown(' ')
st.sidebar.subheader('Please help us improve!')

feedback_form = st.sidebar.form(key='columns_in_form', clear_on_submit=True)

rating = feedback_form.slider(
    "Please rate the app",
    min_value=1,
    max_value=5,
    value=3,
    help='Drag the slider to rate the app. This is a 1-5 rating scale where 5 is the highest rating',
)
text = feedback_form.text_input(label='Please leave your feedback here')
submitted = feedback_form.form_submit_button('Submit')

# On submit, acknowledge and echo the feedback back inside the form.
if submitted:
    feedback_form.write('Thanks for your feedback!')
    for heading, answer in (('Your Rating:', rating), ('Your Feedback:', text)):
        feedback_form.markdown(heading)
        feedback_form.markdown(answer)
|
|