Keiser41 commited on
Commit
063c371
1 Parent(s): 5502c9b

Update pintar.py

Browse files
Files changed (1) hide show
  1. pintar.py +86 -36
pintar.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
- import argparse
3
  import numpy as np
4
  from skimage import color, io
5
  import torch
6
  import torch.nn.functional as F
7
  from PIL import Image
8
  from models import ColorEncoder, ColorUNet
 
 
9
 
10
  os.environ["CUDA_VISIBLE_DEVICES"] = '0'
11
 
@@ -29,8 +30,6 @@ def Normalize(inputs):
29
  l = inputs[:, :, 0:1]
30
  ab = inputs[:, :, 1:3]
31
  l = l - 50
32
- l = l / 50 # Normalizar L al rango [-1, 1]
33
- ab = ab / 110 # Normalizar ab al rango [-1, 1]
34
  lab = np.concatenate((l, ab), 2)
35
  return lab.astype('float32')
36
 
@@ -38,10 +37,20 @@ def numpy2tensor(inputs):
38
  out = torch.from_numpy(inputs.transpose(2,0,1))
39
  return out
40
 
 
 
 
 
 
 
 
 
 
 
 
41
  if __name__ == "__main__":
42
  parser = argparse.ArgumentParser(description="Colorize manga images.")
43
  parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing manga images.")
44
- parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image for colorization.")
45
  parser.add_argument("-ckpt", "--model_checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
46
  parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder where colorized images will be saved.")
47
  parser.add_argument("-ne", "--no_extractor", action="store_true", help="Do not segment the manga panels.")
@@ -59,41 +68,82 @@ if __name__ == "__main__":
59
  colorUNet.load_state_dict(ckpt["colorUNet"])
60
  colorUNet.eval()
61
 
62
- reference_img = Image.open(args.reference_image).convert("RGB")
63
- reference_img = np.array(reference_img).astype(np.float32) / 255.0 # Asegúrate de que la referencia esté en el rango [0, 1]
64
- reference_img_lab = RGB2Lab(reference_img)
65
- reference_img_lab = Normalize(reference_img_lab)
66
- reference_img_lab = numpy2tensor(reference_img_lab)
67
- reference_img_lab = reference_img_lab.to(device).unsqueeze(0)
68
-
69
- for root, dirs, files in os.walk(args.input_folder):
70
- for file in files:
71
- if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
72
- input_image_path = os.path.join(root, file)
73
-
74
- img = Image.open(input_image_path).convert("RGB")
75
- img = np.array(img).astype(np.float32) / 255.0 # Asegúrate de que la imagen de entrada esté en el rango [0, 1]
76
- img_lab = RGB2Lab(img)
77
- img_lab = Normalize(img_lab)
78
- img_lab = numpy2tensor(img_lab)
79
- img_lab = img_lab.to(device).unsqueeze(0)
 
 
80
 
81
  with torch.no_grad():
82
- img_resize = F.interpolate(img_lab, size=(256, 256), mode='bilinear', align_corners=False)
83
- img_L_resize = F.interpolate(img_resize[:, :1, :, :], size=(256, 256), mode='bilinear', align_corners=False)
84
 
85
- color_vector = colorEncoder(img_resize)
86
- fake_ab = colorUNet((img_L_resize, color_vector))
87
- fake_ab = F.interpolate(fake_ab, size=(img.shape[0], img.shape[1]), mode='bilinear', align_corners=False)
88
 
89
- fake_img = torch.cat((img_lab[:, :1, :, :], fake_ab), 1)
90
  fake_img = Lab2RGB_out(fake_img)
91
- fake_img = (fake_img * 255).astype(np.uint8) # Convierte de nuevo a [0, 255]
92
-
93
- relative_path = os.path.relpath(input_image_path, args.input_folder)
94
- output_subfolder = os.path.join(args.output_folder, os.path.dirname(relative_path), 'color')
95
- mkdirs(output_subfolder)
96
- output_image_path = os.path.join(output_subfolder, f'{os.path.splitext(os.path.basename(input_image_path))[0]}_colorized.png')
97
- io.imsave(output_image_path, fake_img)
98
 
99
- print(f'Colored images have been saved to: {args.output_folder}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  import numpy as np
3
  from skimage import color, io
4
  import torch
5
  import torch.nn.functional as F
6
  from PIL import Image
7
  from models import ColorEncoder, ColorUNet
8
+ from extractor.manga_panel_extractor import PanelExtractor
9
+ import argparse
10
 
11
  os.environ["CUDA_VISIBLE_DEVICES"] = '0'
12
 
 
30
  l = inputs[:, :, 0:1]
31
  ab = inputs[:, :, 1:3]
32
  l = l - 50
 
 
33
  lab = np.concatenate((l, ab), 2)
34
  return lab.astype('float32')
35
 
 
37
  out = torch.from_numpy(inputs.transpose(2,0,1))
38
  return out
39
 
40
+ def tensor2numpy(inputs):
41
+ out = inputs[0,...].detach().cpu().numpy().transpose(1,2,0)
42
+ return out
43
+
44
+ def preprocessing(inputs):
45
+ img_lab = Normalize(RGB2Lab(inputs))
46
+ img = np.array(inputs, 'float32')
47
+ img = numpy2tensor(img)
48
+ img_lab = numpy2tensor(img_lab)
49
+ return img.unsqueeze(0), img_lab.unsqueeze(0)
50
+
51
  if __name__ == "__main__":
52
  parser = argparse.ArgumentParser(description="Colorize manga images.")
53
  parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing manga images.")
 
54
  parser.add_argument("-ckpt", "--model_checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
55
  parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder where colorized images will be saved.")
56
  parser.add_argument("-ne", "--no_extractor", action="store_true", help="Do not segment the manga panels.")
 
68
  colorUNet.load_state_dict(ckpt["colorUNet"])
69
  colorUNet.eval()
70
 
71
+ input_files = os.listdir(args.input_folder)
72
+
73
+ for input_file in input_files:
74
+ input_path = os.path.join(args.input_folder, input_file)
75
+
76
+ if os.path.isfile(input_path):
77
+ if args.no_extractor:
78
+ ref_img_path = input("Please enter the path of the reference image: ")
79
+
80
+ img1 = Image.open(ref_img_path).convert("RGB")
81
+ width, height = img1.size
82
+ img2 = Image.open(input_path).convert("RGB")
83
+
84
+ img1, img1_lab = preprocessing(img1)
85
+ img2, img2_lab = preprocessing(img2)
86
+
87
+ img1 = img1.to(device)
88
+ img1_lab = img1_lab.to(device)
89
+ img2 = img2.to(device)
90
+ img2_lab = img2_lab.to(device)
91
 
92
  with torch.no_grad():
93
+ img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
94
+ img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
95
 
96
+ color_vector = colorEncoder(img2_resize)
97
+ fake_ab = colorUNet((img1_L_resize, color_vector))
98
+ fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
99
 
100
+ fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
101
  fake_img = Lab2RGB_out(fake_img)
 
 
 
 
 
 
 
102
 
103
+ out_folder = os.path.join(args.output_folder, 'color')
104
+ mkdirs(out_folder)
105
+ out_img_path = os.path.join(out_folder, f'{os.path.splitext(input_file)[0]}_color.png')
106
+ io.imsave(out_img_path, fake_img)
107
+
108
+ else:
109
+ panel_extractor = PanelExtractor(min_pct_panel=5, max_pct_panel=90) # You might need to adjust these parameters
110
+ panels, masks, panel_masks = panel_extractor.extract(input_path)
111
+
112
+ ref_img_paths = []
113
+ print("Please enter the name of the reference image in order according to the number prompts on the picture")
114
+ for i in range(len(panels)):
115
+ ref_img_path = input(f"{i+1}/{len(panels)} reference image:")
116
+ ref_img_paths.append(ref_img_path)
117
+
118
+ fake_imgs = []
119
+ for i in range(len(panels)):
120
+ img1 = Image.fromarray(panels[i]).convert("RGB")
121
+ width, height = img1.size
122
+ img2 = Image.open(ref_img_paths[i]).convert("RGB")
123
+
124
+ img1, img1_lab = preprocessing(img1)
125
+ img2, img2_lab = preprocessing(img2)
126
+
127
+ img1 = img1.to(device)
128
+ img1_lab = img1_lab.to(device)
129
+ img2 = img2.to(device)
130
+ img2_lab = img2_lab.to(device)
131
+
132
+ with torch.no_grad():
133
+ img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
134
+ img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
135
+
136
+ color_vector = colorEncoder(img2_resize)
137
+
138
+ fake_ab = colorUNet((img1_L_resize, color_vector))
139
+ fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
140
+
141
+ fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
142
+ fake_img = Lab2RGB_out(fake_img)
143
+
144
+ out_folder = os.path.join(args.output_folder, 'color')
145
+ mkdirs(out_folder)
146
+ out_img_path = os.path.join(out_folder, f'{os.path.splitext(input_file)[0]}_panel_{i}_color.png')
147
+ io.imsave(out_img_path, fake_img)
148
+
149
+ print(f'Colored images have been saved to: {os.path.join(args.output_folder, "color")}')