"""Utilities for colorizing images and videos with a Lab-space colorization model."""

import os

import cv2
import numpy as np
import skimage.transform
import torch
from moviepy.editor import VideoFileClip, AudioFileClip
from moviepy.tools import cvsecs
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from torchvision import transforms
from tqdm import tqdm


def lab_to_rgb(L, ab):
    """Convert a batch of (L, ab) tensors back to a batch of RGB images.

    L is expected in [-1, 1] and ab in [-1, 1]; both are rescaled to the
    standard Lab ranges (L in [0, 100], ab roughly in [-110, 110]) before
    conversion.
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)


# Input resolution expected by the model.
SIZE = 256


def get_L(img):
    """Resize a PIL RGB image to SIZE x SIZE and return its normalized L channel as a 1 x H x W tensor."""
    img = transforms.Resize(
        (SIZE, SIZE), transforms.InterpolationMode.BICUBIC)(img)
    img = np.array(img)
    img_lab = rgb2lab(img).astype("float32")
    img_lab = transforms.ToTensor()(img_lab)
    # Keep only the L channel and rescale it from [0, 100] to [-1, 1].
    L = img_lab[[0], ...] / 50. - 1.

    return L


def get_predictions(model, L):
    """Run the model on a batch of L channels and return the colorized RGB images."""
    model.eval()
    with torch.no_grad():
        # Inference is run on the CPU.
        model.L = L.to(torch.device('cpu'))
        model.forward()
        fake_color = model.fake_color.detach()
    fake_imgs = lab_to_rgb(L, fake_color)

    return fake_imgs


def colorize_img(model, img):
    """Colorize a single PIL image and return an RGB float array at the original resolution."""
    L = get_L(img)
    L = L[None]  # add a batch dimension
    fake_imgs = get_predictions(model, L)
    fake_img = fake_imgs[0]
    # img.size is (width, height); skimage expects (height, width).
    resized_fake_img = skimage.transform.resize(
        fake_img, img.size[::-1])

    return resized_fake_img


def valid_start_end(duration, start_input, end_input):
    """Validate start/end inputs, convert them to seconds, and clamp them to [0, duration]."""
    start = start_input
    end = end_input
    if start == '':
        start = 0
    if end == '':
        end = duration

    try:
        start = cvsecs(start)
        end = cvsecs(end)
    except Exception:
        raise ValueError("Invalid start/end values.")

    start = max(start, 0)
    end = min(duration, end)

    if start >= end:
        raise ValueError("Start must be before end.")

    return start, end


def colorize_vid(path_input, model, fps, start_input, end_input):
    """Colorize a video clip frame by frame and re-attach its audio track."""
    original_video = VideoFileClip(path_input)

    start, end = valid_start_end(
        original_video.duration, start_input, end_input)

    input_video = original_video.subclip(start, end)

    if isinstance(fps, int):
        used_fps = fps
    else:
        used_fps = input_video.fps
    # Progress-bar total, estimated from the (possibly trimmed) clip duration.
    nframes = int(np.round(used_fps * input_video.duration))
    print(
        f"Colorizing output with FPS: {used_fps}, nframes: {nframes}, resolution: {input_video.size}.")

    frames = input_video.iter_frames(fps=used_fps)

    base_path, suffix = os.path.splitext(path_input)
    path_video_tmp = base_path + "_tmp" + suffix

    # OpenCV expects the frame size as a (width, height) tuple.
    size = tuple(input_video.size)
    out = cv2.VideoWriter(
        path_video_tmp,
        cv2.VideoWriter_fourcc(*'mp4v'),
        used_fps,
        size)

    for frame in tqdm(frames, total=nframes):
        color_frame = colorize_img(model, Image.fromarray(frame))

        # skimage returns floats in [0, 1]; convert to 8-bit for the writer.
        if color_frame.max() <= 1:
            color_frame = (color_frame * 255).astype(np.uint8)

        # colorize_img returns RGB, but cv2.VideoWriter expects BGR.
        color_frame = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
        out.write(color_frame)
    out.release()

    path_output = base_path + "_out" + suffix

    output_video = VideoFileClip(path_video_tmp)

    # Re-attach the original audio track, if the clip has one.
    path_audio_tmp = None
    if input_video.audio is not None:
        path_audio_tmp = base_path + "_audio_tmp.mp3"
        input_video.audio.write_audiofile(path_audio_tmp, logger=None)
        input_audio = AudioFileClip(path_audio_tmp)
        output_video = output_video.set_audio(input_audio)

    output_video.write_videofile(path_output, logger=None)

    os.remove(path_video_tmp)
    if path_audio_tmp is not None:
        os.remove(path_audio_tmp)

    print("Done.")
    return path_output
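

# Example usage (sketch). This block is not part of the original script: the
# checkpoint path and torch.load call below are assumptions, since the script
# does not show how the colorization model is created or loaded. It simply
# demonstrates how colorize_vid might be driven from the command line.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Colorize a video clip.")
    parser.add_argument("video", help="path to the input video file")
    parser.add_argument("--fps", type=int, default=None,
                        help="output FPS (defaults to the source FPS)")
    parser.add_argument("--start", default="", help="start time, e.g. '0:05'")
    parser.add_argument("--end", default="", help="end time, e.g. '1:30'")
    args = parser.parse_args()

    # Assumed checkpoint path and loading convention (a whole model object
    # saved with torch.save); replace with the project's actual loader.
    model = torch.load("colorization_model.pt", map_location="cpu")

    out_path = colorize_vid(args.video, model, args.fps, args.start, args.end)
    print(f"Colorized video written to {out_path}")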