import os

import cv2
import numpy as np
import skimage.transform
import torch
from moviepy.editor import VideoFileClip, AudioFileClip
from moviepy.tools import cvsecs
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from torchvision import transforms
from tqdm import tqdm


def lab_to_rgb(L, ab):
    """
    Convert a batch of Lab tensors in the model's normalized space
    back to RGB images.

    L is expected in [-1, 1] (rescaled here to the Lab range [0, 100])
    and ab in [-1, 1] (rescaled to roughly [-110, 110]). Returns a
    numpy array of shape (batch, height, width, 3) with values in [0, 1].
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = [lab2rgb(img) for img in Lab]
    return np.stack(rgb_imgs, axis=0)
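
# For reference, the expected data contract (e.g. a batch of one
# 256x256 image):
#     L:  tensor of shape (1, 1, 256, 256) in [-1, 1]
#     ab: tensor of shape (1, 2, 256, 256) in [-1, 1]
#     lab_to_rgb(L, ab) -> numpy array of shape (1, 256, 256, 3) in [0, 1]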


SIZE = 256  # resolution the model expects as input


def get_L(img):
    """
    Extract the normalized L (lightness) channel from a PIL RGB image.
    Returns a tensor of shape (1, SIZE, SIZE) with values in [-1, 1].
    """
    img = transforms.Resize(
        (SIZE, SIZE), transforms.InterpolationMode.BICUBIC)(img)
    img = np.array(img)
    img_lab = rgb2lab(img).astype("float32")
    img_lab = transforms.ToTensor()(img_lab)
    L = img_lab[[0], ...] / 50. - 1.  # L is in [0, 100]; rescale to [-1, 1]

    return L


def get_predictions(model, L):
    """
    Run the colorization model on a batch of L channels and return the
    predicted RGB images as a numpy array.
    """
    model.eval()
    with torch.no_grad():
        model.L = L.to(torch.device('cpu'))
        model.forward()
    fake_color = model.fake_color.detach()
    fake_imgs = lab_to_rgb(L, fake_color)

    return fake_imgs


def colorize_img(model, img):
    """
    Colorize a single PIL image and return it as an RGB float array at
    the image's original resolution.
    """
    L = get_L(img)
    L = L[None]  # add a batch dimension
    fake_imgs = get_predictions(model, L)
    fake_img = fake_imgs[0]  # drop the batch dimension
    resized_fake_img = skimage.transform.resize(
        fake_img, img.size[::-1])  # PIL size is (width, height); resize wants (height, width)

    return resized_fake_img
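
# A minimal usage sketch for a single image (assumes `model` is a loaded
# colorization model exposing `.L`, `.forward()` and `.fake_color`; the
# file names are hypothetical):
#
#     img = Image.open("photo_bw.jpg").convert("RGB")
#     rgb = colorize_img(model, img)  # float array in [0, 1]
#     Image.fromarray((rgb * 255).astype(np.uint8)).save("photo_color.jpg")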


def valid_start_end(duration, start_input, end_input):
    """
    Validate and normalize user-supplied start/end times against the
    clip duration. Empty strings default to the full clip.
    """
    start = start_input
    end = end_input
    if start == '':
        start = 0
    if end == '':
        end = duration

    try:
        start = cvsecs(start)
        end = cvsecs(end)
    except Exception as e:
        # start/end could not be parsed as time values
        raise ValueError("Invalid start, end values") from e

    # clamp to the clip's valid range
    start = max(start, 0)
    end = min(duration, end)

    # start must come before end
    if start >= end:
        raise ValueError("Start must be before end.")

    return start, end
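
# For reference, moviepy's cvsecs accepts plain seconds as well as
# clock-style strings, so all of these are valid here:
#
#     valid_start_end(600, '', '')          # -> (0, 600)
#     valid_start_end(600, 15, '1:30')      # -> (15, 90)    'm:ss'
#     valid_start_end(600, '0:00:05', '')   # -> (5, 600)    'h:mm:ss'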


def colorize_vid(path_input, model, fps, start_input, end_input):
    """
    Colorize a video clip frame by frame and return the path to the
    colorized output file, with the original audio track re-attached.
    """
    original_video = VideoFileClip(path_input)

    # validate start, end
    start, end = valid_start_end(
        original_video.duration, start_input, end_input)

    input_video = original_video.subclip(start, end)

    if isinstance(fps, int):
        used_fps = fps
        nframes = int(np.round(fps * input_video.duration))
    else:
        used_fps = input_video.fps
        nframes = input_video.reader.nframes
    print(
        f"Colorizing output with FPS: {used_fps}, "
        f"nframes: {nframes}, resolution: {input_video.size}.")

    frames = input_video.iter_frames(fps=used_fps)

    # create tmp path that is same as input path but with '_tmp.[suffix]'
    base_path, suffix = os.path.splitext(path_input)
    path_video_tmp = base_path + "_tmp" + suffix

    # create video writer for output; cv2 expects a (width, height) tuple
    size = tuple(input_video.size)
    out = cv2.VideoWriter(
        path_video_tmp,
        cv2.VideoWriter_fourcc(*'mp4v'),
        used_fps,
        size)

    for frame in tqdm(frames, total=nframes):
        # get colorized frame (RGB float array in [0, 1])
        color_frame = colorize_img(model, Image.fromarray(frame))

        if color_frame.max() <= 1:
            color_frame = (color_frame * 255).astype(np.uint8)

        # cv2.VideoWriter expects BGR channel order
        color_frame = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
        out.write(color_frame)
    out.release()

    # create output path that is same as input path but with '_out.[suffix]'
    path_output = base_path + "_out" + suffix

    output_video = VideoFileClip(path_video_tmp)

    # subclip doesn't carry audio through the cv2 writer, so extract the
    # audio to a tmp file and re-attach it (skipped for silent clips)
    path_audio_tmp = base_path + "_audio_tmp.mp3"
    if input_video.audio is not None:
        input_video.audio.write_audiofile(path_audio_tmp, logger=None)
        input_audio = AudioFileClip(path_audio_tmp)
        output_video = output_video.set_audio(input_audio)

    output_video.write_videofile(path_output, logger=None)

    os.remove(path_video_tmp)
    if os.path.exists(path_audio_tmp):
        os.remove(path_audio_tmp)

    print("Done.")
    return path_output
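
# A minimal end-to-end sketch, assuming a trained checkpoint and input
# video exist at the paths below (the paths and the MainModel class are
# hypothetical; substitute your own model loading code):
#
# if __name__ == "__main__":
#     model = MainModel()
#     model.load_state_dict(torch.load("colorizer.pt", map_location="cpu"))
#     # colorize the first 30 seconds at the clip's native frame rate
#     path_out = colorize_vid(
#         "clip.mp4", model, fps=None, start_input="0", end_input="0:30")
#     print(path_out)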