File size: 4,496 Bytes
0caed3c
 
 
fd5f9f3
0caed3c
 
fd5f9f3
0caed3c
 
063c371
8ed6625
0caed3c
 
 
 
 
 
 
 
 
 
 
 
 
fd5f9f3
0caed3c
 
 
 
 
 
 
 
 
 
 
 
 
df23063
0caed3c
 
063c371
 
 
 
 
 
 
 
 
 
 
0caed3c
fd5f9f3
 
 
 
 
8ed6625
f6df16f
8ed6625
fd5f9f3
 
 
 
 
 
f6df16f
 
 
df23063
 
 
 
 
 
 
 
 
fd5f9f3
 
 
 
 
 
 
 
 
 
 
063c371
fd5f9f3
 
 
 
063c371
fd5f9f3
 
 
063c371
fd5f9f3
063c371
fd5f9f3
 
063c371
fd5f9f3
 
063c371
fd5f9f3
 
 
 
 
063c371
fd5f9f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import numpy as np
from skimage import color, io

import torch
import torch.nn.functional as F

from PIL import Image
from models import ColorEncoder, ColorUNet
from extractor.manga_panel_extractor import PanelExtractor
import argparse

# Pin the script to the first GPU by default, but respect a
# CUDA_VISIBLE_DEVICES value the user has already exported instead of
# silently overwriting it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '0')

def mkdirs(path):
    """Create directory *path*, including any missing parents, if needed.

    ``exist_ok=True`` replaces the former ``os.path.exists`` pre-check, which
    raced with concurrent creators: if another process created the directory
    between the check and ``os.makedirs``, the call raised FileExistsError.
    Calling this on an existing directory is a no-op.

    Raises:
        FileExistsError: if *path* exists but is not a directory (the caller
            could not write into it anyway).
    """
    os.makedirs(path, exist_ok=True)

def Lab2RGB_out(img_lab):
    """Convert the first image of a batched Lab tensor to an 8-bit RGB array.

    Args:
        img_lab: 4-D tensor indexed as (batch, channel, H, W). Channel 0 is
            lightness shifted by -50, as produced by ``Normalize``; it is
            shifted back by +50 here. The remaining channels are passed
            through unchanged.

    Returns:
        ``uint8`` NumPy array in H x W x C layout, computed from
        ``skimage.color.lab2rgb`` clipped to [0, 1] and scaled by 255
        (truncating, not rounding).
    """
    lab = img_lab.detach().cpu()
    # Undo the -50 lightness shift applied during preprocessing.
    restored = torch.cat((lab[:, :1, :, :] + 50, lab[:, 1:, :, :]), dim=1)
    # Only batch element 0 is converted; skimage expects channels-last.
    lab_hwc = restored[0].numpy().transpose(1, 2, 0)
    rgb_unit = np.clip(color.lab2rgb(lab_hwc), 0, 1)
    return (rgb_unit * 255).astype("uint8")

def RGB2Lab(inputs):
    """Return *inputs* (an RGB image) converted to CIE-Lab via ``skimage.color.rgb2lab``."""
    converted = color.rgb2lab(inputs)
    return converted

def Normalize(inputs):
    """Center the lightness channel of a channels-last Lab array.

    Channel 0 (L) of axis 2 is shifted by -50. Channels 1-2 (a, b) are copied
    unchanged, and any channels beyond index 2 are dropped.

    Args:
        inputs: NumPy array indexed as (H, W, C), e.g. the output of ``RGB2Lab``.

    Returns:
        New ``float32`` array containing at most 3 channels on axis 2.
    """
    chroma = inputs[:, :, 1:3]
    lightness = inputs[:, :, 0:1] - 50
    stacked = np.concatenate([lightness, chroma], axis=2)
    return stacked.astype('float32')

def numpy2tensor(inputs):
    """Expose an (H, W, C) NumPy array as a (C, H, W) torch tensor.

    ``transpose`` yields a view and ``torch.from_numpy`` shares memory with
    it, so no pixel data is copied.
    """
    chw = inputs.transpose((2, 0, 1))
    return torch.from_numpy(chw)

def tensor2numpy(inputs):
    """Return batch element 0 of an (N, C, H, W) tensor as an (H, W, C) NumPy array.

    The tensor is detached from autograd and moved to the CPU first.
    """
    first = inputs[0].detach()
    hwc = first.permute(1, 2, 0)
    return hwc.cpu().numpy()

def preprocessing(inputs):
    """Turn one RGB image into batched RGB and normalized-Lab tensors.

    Args:
        inputs: RGB image, e.g. a PIL image after ``.convert("RGB")``. It must
            be accepted by both ``np.array`` and ``RGB2Lab``.

    Returns:
        ``(rgb, lab)``, each with a leading batch dimension of 1 and channels
        first. ``rgb`` holds the pixel values unchanged as float32; ``lab`` is
        the Lab conversion with L shifted by -50 (see ``Normalize``).
    """
    lab_tensor = numpy2tensor(Normalize(RGB2Lab(inputs)))
    rgb_tensor = numpy2tensor(np.array(inputs, 'float32'))
    # Indexing with None prepends the batch axis, same as unsqueeze(0).
    return rgb_tensor[None], lab_tensor[None]

# Lower-case file extensions the batch loop treats as input images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')


def _load_models(ckpt_path, device):
    """Build ColorEncoder and ColorUNet, load their weights, and switch to eval mode.

    The checkpoint must contain state dicts under the keys ``"colorEncoder"``
    and ``"colorUNet"``.

    Returns:
        ``(colorEncoder, colorUNet)``, both moved to *device*.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources. Tensors are mapped to CPU first, then .to(device) moves them.
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

    colorEncoder = ColorEncoder().to(device)
    colorEncoder.load_state_dict(ckpt["colorEncoder"])
    colorEncoder.eval()

    colorUNet = ColorUNet().to(device)
    colorUNet.load_state_dict(ckpt["colorUNet"])
    colorUNet.eval()
    return colorEncoder, colorUNet


def _iter_images(input_folder):
    """Recursively yield ``(root, path)`` for each supported image under *input_folder*."""
    for root, _dirs, files in os.walk(input_folder):
        for file in files:
            if file.lower().endswith(_IMAGE_EXTENSIONS):
                yield root, os.path.join(root, file)


def _resize(tensor, size):
    """Bilinearly resize a 4-D (N, C, H, W) tensor to *size* = (height, width)."""
    return F.interpolate(tensor, size=size, mode='bilinear',
                         recompute_scale_factor=False, align_corners=False)


def main(argv=None):
    """Colorize every image under --input_folder using a single reference image.

    Results are written to
    ``<output_folder>/<path relative to input_folder>/color/<name>_color.png``.

    Args:
        argv: Optional argument list for argparse. Defaults to the process
            command line, so the script behaves as before.
    """
    parser = argparse.ArgumentParser(description="Colorize manga images based on a single reference image.")
    parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing images to be colorized.")
    parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image.")
    parser.add_argument("-c", "--ckpt", type=str, required=True, help="Path to the model checkpoint file.")
    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder to save colorized images.")
    args = parser.parse_args(argv)

    # Fall back to CPU rather than crashing in .to("cuda") on GPU-less machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    input_folder = args.input_folder
    output_folder = args.output_folder
    # Spatial size the networks run at; predictions are upsampled back afterwards.
    imgsize = 256

    colorEncoder, colorUNet = _load_models(args.ckpt, device)

    # The reference image is the same for every input, so load and encode it
    # once, not once per file. `with` closes the file (matters for multi-frame GIFs).
    with Image.open(args.reference_image) as ref:
        ref_rgb, _ = preprocessing(ref.convert("RGB"))
    with torch.no_grad():
        # RGB pixel values are scaled from [0, 255] to [0, 1] for the encoder.
        color_vector = colorEncoder(_resize(ref_rgb.to(device) / 255., (imgsize, imgsize)))

    for root, input_image_path in _iter_images(input_folder):
        img_name = os.path.splitext(os.path.basename(input_image_path))[0]
        try:
            with Image.open(input_image_path) as src:
                img = src.convert("RGB")
        except OSError as err:
            # Unreadable or corrupt file: report it and keep processing the rest.
            print(f'Skipping {input_image_path}: {err}')
            continue
        width, height = img.size
        # Only the lightness channel of the input image is used below.
        _, img_lab = preprocessing(img)
        img_l = img_lab[:, :1, :, :].to(device)

        with torch.no_grad():
            # Normalize shifted L by -50; dividing by 50 rescales it for the U-Net.
            img_l_resize = _resize(img_l / 50., (imgsize, imgsize))
            fake_ab = colorUNet((img_l_resize, color_vector))
            # NOTE(review): *110 presumably maps the normalized ab prediction
            # back to Lab ab units; confirm against the training code.
            fake_ab = _resize(fake_ab * 110, (height, width))
            # Pair the full-resolution original L with the predicted ab channels.
            fake_img = Lab2RGB_out(torch.cat((img_l, fake_ab), 1))

        out_folder = os.path.join(output_folder, os.path.relpath(root, input_folder), 'color')
        mkdirs(out_folder)
        io.imsave(os.path.join(out_folder, f'{img_name}_color.png'), fake_img)

    print(f'Colored images have been saved to {output_folder}.')


if __name__ == "__main__":
    main()