File size: 4,496 Bytes
0caed3c fd5f9f3 0caed3c fd5f9f3 0caed3c 063c371 8ed6625 0caed3c fd5f9f3 0caed3c df23063 0caed3c 063c371 0caed3c fd5f9f3 8ed6625 f6df16f 8ed6625 fd5f9f3 f6df16f df23063 fd5f9f3 063c371 fd5f9f3 063c371 fd5f9f3 063c371 fd5f9f3 063c371 fd5f9f3 063c371 fd5f9f3 063c371 fd5f9f3 063c371 fd5f9f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import os
import numpy as np
from skimage import color, io
import torch
import torch.nn.functional as F
from PIL import Image
from models import ColorEncoder, ColorUNet
from extractor.manga_panel_extractor import PanelExtractor
import argparse
# Restrict the process to GPU index 0. This unconditionally overwrites any
# CUDA_VISIBLE_DEVICES value already present in the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
def mkdirs(path):
    """Create directory *path*, including missing parents, if it is absent.

    ``exist_ok=True`` replaces the previous "check, then create" sequence,
    which could raise ``FileExistsError`` if another process created the
    directory between the check and the ``makedirs`` call.

    Raises:
        FileExistsError: if *path* exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
def Lab2RGB_out(img_lab):
    """Convert the first item of a batched Lab tensor to an 8-bit RGB image.

    Channel 0 is treated as the L channel shifted down by 50 (the shift that
    ``Normalize`` applies), so 50 is added back before conversion; the
    remaining channels are passed through as a/b.

    Args:
        img_lab: 4-D tensor indexed as ``[batch, channel, H, W]``. Only
            batch item 0 is converted.

    Returns:
        ``uint8`` numpy array in channel-last layout ``(H, W, C)``. Values
        are clipped to [0, 1], scaled by 255 and truncated (not rounded).
    """
    lab = img_lab.detach().cpu()
    lightness = lab[:, :1, :, :] + 50
    chroma = lab[:, 1:, :, :]
    first_item = torch.cat((lightness, chroma), 1)[0, ...].numpy()
    # skimage expects channel-last input, hence CHW -> HWC.
    rgb = color.lab2rgb(first_item.transpose(1, 2, 0))
    return (np.clip(rgb, 0, 1) * 255).astype("uint8")
def RGB2Lab(inputs):
    """Convert an RGB image to CIE Lab with ``skimage.color.rgb2lab``.

    *inputs* is handed to skimage unchanged, so anything skimage accepts as
    an RGB image is valid; skimage's Lab array is returned as-is.
    """
    lab = color.rgb2lab(inputs)
    return lab
def Normalize(inputs):
    """Shift the L channel of a channel-last Lab array down by 50.

    Channel 0 has 50 subtracted (``Lab2RGB_out`` adds it back); channels 1-2
    are kept unchanged and any channels beyond index 2 are dropped. The
    result is a new ``float32`` array with 3 channels on axis 2.
    """
    shifted_l = inputs[:, :, 0:1] - 50
    ab_channels = inputs[:, :, 1:3]
    stacked = np.concatenate([shifted_l, ab_channels], axis=2)
    return stacked.astype('float32')
def numpy2tensor(inputs):
    """Wrap a 3-D channel-last numpy array as a channel-first torch tensor.

    Axis 2 is moved to the front, i.e. ``(H, W, C) -> (C, H, W)``. Nothing
    is copied: ``torch.from_numpy`` shares memory with *inputs*, so writes
    to either are visible in the other.
    """
    channels_first = inputs.transpose(2, 0, 1)
    return torch.from_numpy(channels_first)
def tensor2numpy(inputs):
    """Return batch item 0 of a 4-D ``(N, C, H, W)`` tensor as ``(H, W, C)`` numpy.

    The tensor is detached from autograd and moved to the CPU first.
    """
    sample = inputs[0, ...].detach().cpu()
    return sample.numpy().transpose(1, 2, 0)
def preprocessing(inputs):
    """Turn an RGB image into batched RGB and normalized-Lab tensors.

    Args:
        inputs: RGB image (the script passes a PIL image converted to "RGB").

    Returns:
        ``(rgb, lab)``: both 4-D tensors of shape ``(1, C, H, W)``.
        ``rgb`` holds the raw pixel values as float32 (not rescaled);
        ``lab`` is the Lab conversion with L shifted by ``Normalize``.
    """
    pixels = np.asarray(inputs)
    lab = numpy2tensor(Normalize(RGB2Lab(pixels)))
    rgb = numpy2tensor(pixels.astype('float32'))
    return rgb.unsqueeze(0), lab.unsqueeze(0)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Colorize manga images based on a single reference image.")
    parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing images to be colorized.")
    parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image.")
    parser.add_argument("-c", "--ckpt", type=str, required=True, help="Path to the model checkpoint file.")
    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder to save colorized images.")
    args = parser.parse_args()

    # Use the GPU when one is visible; otherwise fall back to the CPU instead
    # of crashing on the first .to("cuda") call.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_folder = args.input_folder
    reference_image_path = args.reference_image
    ckpt_path = args.ckpt
    output_folder = args.output_folder
    imgsize = 256  # both networks run at imgsize x imgsize; the ab output is resized back
    image_exts = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

    # NOTE(review): torch.load unpickles the checkpoint file -- only pass
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    colorEncoder = ColorEncoder().to(device)
    colorEncoder.load_state_dict(ckpt["colorEncoder"])
    colorEncoder.eval()
    colorUNet = ColorUNet().to(device)
    colorUNet.load_state_dict(ckpt["colorUNet"])
    colorUNet.eval()

    # The reference image is the same for every page, so load and encode it
    # once rather than once per input file.
    with Image.open(reference_image_path) as ref_file:
        img2, _ = preprocessing(ref_file.convert("RGB"))
    with torch.no_grad():
        img2_resize = F.interpolate(img2.to(device) / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
        color_vector = colorEncoder(img2_resize)

    # Recursively collect every image under the input folder before writing
    # anything, so files produced by this run are never treated as inputs.
    # If the output folder sits inside the input folder, do not descend into
    # it, so results from earlier runs are not colorized again.
    out_root = os.path.abspath(output_folder)
    jobs = []
    for root, dirs, files in os.walk(input_folder):
        dirs[:] = sorted(d for d in dirs if os.path.abspath(os.path.join(root, d)) != out_root)
        jobs.extend((root, name) for name in sorted(files) if name.lower().endswith(image_exts))

    with torch.no_grad():
        for root, file in jobs:
            input_image_path = os.path.join(root, file)
            img_name = os.path.splitext(file)[0]
            # Close the file handle deterministically instead of leaving it
            # to garbage collection.
            with Image.open(input_image_path) as page:
                img1 = page.convert("RGB")
            width, height = img1.size
            _, img1_lab = preprocessing(img1)  # the RGB tensor is not needed here
            img1_L = img1_lab[:, :1, :, :].to(device)
            img1_L_resize = F.interpolate(img1_L / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
            fake_ab = colorUNet((img1_L_resize, color_vector))
            # Upsample the predicted ab back to the page's original size.
            fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
            fake_img = Lab2RGB_out(torch.cat((img1_L, fake_ab), 1))
            # Mirror the input folder layout under output_folder, with results
            # placed in a 'color' subfolder of each mirrored directory.
            out_folder = os.path.join(output_folder, os.path.relpath(root, input_folder), 'color')
            mkdirs(out_folder)
            out_img_path = os.path.join(out_folder, f'{img_name}_color.png')
            io.imsave(out_img_path, fake_img)
    print(f'Colored images have been saved to {output_folder}.')
|