Keiser41 commited on
Commit
fd5f9f3
1 Parent(s): 84022c3

Update pintar.py

Browse files
Files changed (1) hide show
  1. pintar.py +43 -33
pintar.py CHANGED
@@ -1,8 +1,10 @@
1
  import os
2
  import numpy as np
3
  from skimage import color, io
 
4
  import torch
5
  import torch.nn.functional as F
 
6
  from PIL import Image
7
  from models import ColorEncoder, ColorUNet
8
  from extractor.manga_panel_extractor import PanelExtractor
@@ -20,7 +22,7 @@ def Lab2RGB_out(img_lab):
20
  img_ab = img_lab[:,1:,:,:]
21
  img_l = img_l + 50
22
  pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
23
- out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8")
24
  return out
25
 
26
  def RGB2Lab(inputs):
@@ -49,19 +51,20 @@ def preprocessing(inputs):
49
  return img.unsqueeze(0), img_lab.unsqueeze(0)
50
 
51
  if __name__ == "__main__":
52
- parser = argparse.ArgumentParser(description="Colorize manga images.")
53
- parser.add_argument("-i", "--input", type=str, required=True, help="Path to input image directory")
54
- parser.add_argument("-r", "--reference", type=str, required=True, help="Path to reference image")
55
- parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
56
- parser.add_argument("-ckpt", "--checkpoint", type=str, required=True, help="Path to model checkpoint")
57
-
58
  args = parser.parse_args()
59
 
60
  device = "cuda"
61
- input_image_dir = args.input
62
- output_directory = args.output
63
- ckpt_path = args.checkpoint
64
- reference_image_path = args.reference
 
 
65
  imgsize = 256
66
 
67
  ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
@@ -74,32 +77,39 @@ if __name__ == "__main__":
74
  colorUNet.load_state_dict(ckpt["colorUNet"])
75
  colorUNet.eval()
76
 
77
- img_name = os.path.splitext(os.path.basename(img_path))[0]
78
- img1 = Image.open(img_path).convert("RGB")
79
- width, height = img1.size
80
- img1, img1_lab = preprocessing(img1)
81
- img2, img2_lab = preprocessing(Image.open(reference_image_path).convert("RGB"))
 
 
 
 
 
 
82
 
83
- img1 = img1.to(device)
84
- img1_lab = img1_lab.to(device)
85
- img2 = img2.to(device)
86
- img2_lab = img2_lab.to(device)
87
 
88
- with torch.no_grad():
89
- img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
90
- img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
91
 
92
- color_vector = colorEncoder(img2_resize)
93
 
94
- fake_ab = colorUNet((img1_L_resize, color_vector))
95
- fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
96
 
97
- fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
98
- fake_img = Lab2RGB_out(fake_img)
99
 
100
- out_folder = os.path.dirname(img_path)
101
- mkdirs(out_folder)
102
- out_img_path = os.path.join(out_folder, f'{img_name}_color.png')
103
- io.imsave(out_img_path, fake_img)
 
104
 
105
- print(f'Colored image has been saved to {out_img_path}.')
 
1
  import os
2
  import numpy as np
3
  from skimage import color, io
4
+
5
  import torch
6
  import torch.nn.functional as F
7
+
8
  from PIL import Image
9
  from models import ColorEncoder, ColorUNet
10
  from extractor.manga_panel_extractor import PanelExtractor
 
22
  img_ab = img_lab[:,1:,:,:]
23
  img_l = img_l + 50
24
  pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
25
+ out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1)* 255).astype("uint8")
26
  return out
27
 
28
  def RGB2Lab(inputs):
 
51
  return img.unsqueeze(0), img_lab.unsqueeze(0)
52
 
53
  if __name__ == "__main__":
54
+ parser = argparse.ArgumentParser(description="Colorize manga images based on a single reference image.")
55
+ parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing images to be colorized.")
56
+ parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image.")
57
+ parser.add_argument("-c", "--ckpt", type=str, required=True, help="Path to the model checkpoint file.")
58
+ parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder to save colorized images.")
 
59
  args = parser.parse_args()
60
 
61
  device = "cuda"
62
+
63
+ input_folder = args.input_folder
64
+ reference_image_path = args.reference_image
65
+ ckpt_path = args.ckpt
66
+ output_folder = args.output_folder
67
+
68
  imgsize = 256
69
 
70
  ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
 
77
  colorUNet.load_state_dict(ckpt["colorUNet"])
78
  colorUNet.eval()
79
 
80
+ # Recorre recursivamente el directorio de entrada y procesa cada imagen encontrada
81
+ for root, dirs, files in os.walk(input_folder):
82
+ for file in files:
83
+ if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
84
+ input_image_path = os.path.join(root, file)
85
+ img_name = os.path.splitext(os.path.basename(input_image_path))[0]
86
+
87
+ img1 = Image.open(input_image_path).convert("RGB")
88
+ width, height = img1.size
89
+ img1, img1_lab = preprocessing(img1)
90
+ img2, img2_lab = preprocessing(Image.open(reference_image_path).convert("RGB"))
91
 
92
+ img1 = img1.to(device)
93
+ img1_lab = img1_lab.to(device)
94
+ img2 = img2.to(device)
95
+ img2_lab = img2_lab.to(device)
96
 
97
+ with torch.no_grad():
98
+ img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
99
+ img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
100
 
101
+ color_vector = colorEncoder(img2_resize)
102
 
103
+ fake_ab = colorUNet((img1_L_resize, color_vector))
104
+ fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
105
 
106
+ fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
107
+ fake_img = Lab2RGB_out(fake_img)
108
 
109
+ out_subfolder = os.path.join(output_folder, os.path.relpath(root, input_folder))
110
+ out_folder = os.path.join(out_subfolder, 'color')
111
+ mkdirs(out_folder)
112
+ out_img_path = os.path.join(out_folder, f'{img_name}_color.png')
113
+ io.imsave(out_img_path, fake_img)
114
 
115
+ print(f'Colored images have been saved to {output_folder}.')