Keiser41 commited on
Commit
f6df16f
1 Parent(s): 063c371

Update pintar.py

Browse files
Files changed (1) hide show
  1. pintar.py +33 -82
pintar.py CHANGED
@@ -1,12 +1,13 @@
1
  import os
2
  import numpy as np
3
  from skimage import color, io
 
4
  import torch
5
  import torch.nn.functional as F
 
6
  from PIL import Image
7
  from models import ColorEncoder, ColorUNet
8
  from extractor.manga_panel_extractor import PanelExtractor
9
- import argparse
10
 
11
  os.environ["CUDA_VISIBLE_DEVICES"] = '0'
12
 
@@ -20,7 +21,7 @@ def Lab2RGB_out(img_lab):
20
  img_ab = img_lab[:,1:,:,:]
21
  img_l = img_l + 50
22
  pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
23
- out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8")
24
  return out
25
 
26
  def RGB2Lab(inputs):
@@ -49,16 +50,16 @@ def preprocessing(inputs):
49
  return img.unsqueeze(0), img_lab.unsqueeze(0)
50
 
51
  if __name__ == "__main__":
52
- parser = argparse.ArgumentParser(description="Colorize manga images.")
53
- parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing manga images.")
54
- parser.add_argument("-ckpt", "--model_checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
55
- parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder where colorized images will be saved.")
56
- parser.add_argument("-ne", "--no_extractor", action="store_true", help="Do not segment the manga panels.")
57
- args = parser.parse_args()
58
-
59
  device = "cuda"
60
 
61
- ckpt = torch.load(args.model_checkpoint, map_location=lambda storage, loc: storage)
 
 
 
 
 
 
 
62
 
63
  colorEncoder = ColorEncoder().to(device)
64
  colorEncoder.load_state_dict(ckpt["colorEncoder"])
@@ -68,82 +69,32 @@ if __name__ == "__main__":
68
  colorUNet.load_state_dict(ckpt["colorUNet"])
69
  colorUNet.eval()
70
 
71
- input_files = os.listdir(args.input_folder)
72
-
73
- for input_file in input_files:
74
- input_path = os.path.join(args.input_folder, input_file)
75
-
76
- if os.path.isfile(input_path):
77
- if args.no_extractor:
78
- ref_img_path = input("Please enter the path of the reference image: ")
79
-
80
- img1 = Image.open(ref_img_path).convert("RGB")
81
- width, height = img1.size
82
- img2 = Image.open(input_path).convert("RGB")
83
-
84
- img1, img1_lab = preprocessing(img1)
85
- img2, img2_lab = preprocessing(img2)
86
-
87
- img1 = img1.to(device)
88
- img1_lab = img1_lab.to(device)
89
- img2 = img2.to(device)
90
- img2_lab = img2_lab.to(device)
91
-
92
- with torch.no_grad():
93
- img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
94
- img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
95
-
96
- color_vector = colorEncoder(img2_resize)
97
- fake_ab = colorUNet((img1_L_resize, color_vector))
98
- fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
99
-
100
- fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
101
- fake_img = Lab2RGB_out(fake_img)
102
-
103
- out_folder = os.path.join(args.output_folder, 'color')
104
- mkdirs(out_folder)
105
- out_img_path = os.path.join(out_folder, f'{os.path.splitext(input_file)[0]}_color.png')
106
- io.imsave(out_img_path, fake_img)
107
-
108
- else:
109
- panel_extractor = PanelExtractor(min_pct_panel=5, max_pct_panel=90) # You might need to adjust these parameters
110
- panels, masks, panel_masks = panel_extractor.extract(input_path)
111
-
112
- ref_img_paths = []
113
- print("Please enter the name of the reference image in order according to the number prompts on the picture")
114
- for i in range(len(panels)):
115
- ref_img_path = input(f"{i+1}/{len(panels)} reference image:")
116
- ref_img_paths.append(ref_img_path)
117
-
118
- fake_imgs = []
119
- for i in range(len(panels)):
120
- img1 = Image.fromarray(panels[i]).convert("RGB")
121
- width, height = img1.size
122
- img2 = Image.open(ref_img_paths[i]).convert("RGB")
123
-
124
- img1, img1_lab = preprocessing(img1)
125
- img2, img2_lab = preprocessing(img2)
126
 
127
- img1 = img1.to(device)
128
- img1_lab = img1_lab.to(device)
129
- img2 = img2.to(device)
130
- img2_lab = img2_lab.to(device)
131
 
132
- with torch.no_grad():
133
- img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
134
- img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
135
 
136
- color_vector = colorEncoder(img2_resize)
137
 
138
- fake_ab = colorUNet((img1_L_resize, color_vector))
139
- fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
140
 
141
- fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
142
- fake_img = Lab2RGB_out(fake_img)
143
 
144
- out_folder = os.path.join(args.output_folder, 'color')
145
- mkdirs(out_folder)
146
- out_img_path = os.path.join(out_folder, f'{os.path.splitext(input_file)[0]}_panel_{i}_color.png')
147
- io.imsave(out_img_path, fake_img)
148
 
149
- print(f'Colored images have been saved to: {os.path.join(args.output_folder, "color")}')
 
1
  import os
2
  import numpy as np
3
  from skimage import color, io
4
+
5
  import torch
6
  import torch.nn.functional as F
7
+
8
  from PIL import Image
9
  from models import ColorEncoder, ColorUNet
10
  from extractor.manga_panel_extractor import PanelExtractor
 
11
 
12
  os.environ["CUDA_VISIBLE_DEVICES"] = '0'
13
 
 
21
  img_ab = img_lab[:,1:,:,:]
22
  img_l = img_l + 50
23
  pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
24
+ out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1)* 255).astype("uint8")
25
  return out
26
 
27
  def RGB2Lab(inputs):
 
50
  return img.unsqueeze(0), img_lab.unsqueeze(0)
51
 
52
  if __name__ == "__main__":
 
 
 
 
 
 
 
53
  device = "cuda"
54
 
55
+ # Specify the paths here
56
+ img_path = 'path/to/your/input/image.jpg'
57
+ ckpt_path = 'path/to/your/model_checkpoint.pt'
58
+ reference_image_path = 'path/to/your/reference/image.jpg'
59
+
60
+ imgsize = 256
61
+
62
+ ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
63
 
64
  colorEncoder = ColorEncoder().to(device)
65
  colorEncoder.load_state_dict(ckpt["colorEncoder"])
 
69
  colorUNet.load_state_dict(ckpt["colorUNet"])
70
  colorUNet.eval()
71
 
72
+ img_name = os.path.splitext(os.path.basename(img_path))[0]
73
+ img1 = Image.open(img_path).convert("RGB")
74
+ width, height = img1.size
75
+ img1, img1_lab = preprocessing(img1)
76
+ img2, img2_lab = preprocessing(Image.open(reference_image_path).convert("RGB"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ img1 = img1.to(device)
79
+ img1_lab = img1_lab.to(device)
80
+ img2 = img2.to(device)
81
+ img2_lab = img2_lab.to(device)
82
 
83
+ with torch.no_grad():
84
+ img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
85
+ img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
86
 
87
+ color_vector = colorEncoder(img2_resize)
88
 
89
+ fake_ab = colorUNet((img1_L_resize, color_vector))
90
+ fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
91
 
92
+ fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
93
+ fake_img = Lab2RGB_out(fake_img)
94
 
95
+ out_folder = os.path.dirname(img_path)
96
+ mkdirs(out_folder)
97
+ out_img_path = os.path.join(out_folder, f'{img_name}_color.png')
98
+ io.imsave(out_img_path, fake_img)
99
 
100
+ print(f'Colored image has been saved to {out_img_path}.')