Keiser41 commited on
Commit
df23063
1 Parent(s): 0004858

Update pintar.py

Browse files
Files changed (1) hide show
  1. pintar.py +135 -37
pintar.py CHANGED
@@ -20,7 +20,7 @@ def Lab2RGB_out(img_lab):
20
  img_ab = img_lab[:,1:,:,:]
21
  img_l = img_l + 50
22
  pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
23
- out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8")
24
  return out
25
 
26
  def RGB2Lab(inputs):
@@ -34,11 +34,11 @@ def Normalize(inputs):
34
  return lab.astype('float32')
35
 
36
  def numpy2tensor(inputs):
37
- out = torch.from_numpy(inputs.transpose(2, 0, 1))
38
  return out
39
 
40
  def tensor2numpy(inputs):
41
- out = inputs[0, ...].detach().cpu().numpy().transpose(1, 2, 0)
42
  return out
43
 
44
  def preprocessing(inputs):
@@ -49,41 +49,139 @@ def preprocessing(inputs):
49
  return img.unsqueeze(0), img_lab.unsqueeze(0)
50
 
51
  if __name__ == "__main__":
52
- parser = argparse.ArgumentParser()
53
- parser.add_argument("-r", "--reference", type=str, help="ruta de la imagen de referencia")
54
- parser.add_argument("-o", "--output", type=str, help="carpeta de salida para las imágenes coloreadas")
55
- parser.add_argument("-ckpt", "--model_checkpoint", type=str, help="ruta del modelo de checkpoint")
56
- args = parser.parse_args()
57
-
58
  device = "cuda"
59
 
60
- ckpt_path = args.model_checkpoint or 'experiments/Color2Manga_gray/074000_gray.pt'
61
- test_dir_path = 'test_datasets/gray_test'
62
- no_extractor = False
63
-
64
- # ... (resto del código)
65
-
66
- while True:
67
- # ... (resto del código)
68
-
69
- with torch.no_grad():
70
- img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
71
- img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
72
-
73
- color_vector = colorEncoder(img2_resize)
74
-
75
- fake_ab = colorUNet((img1_L_resize, color_vector))
76
- fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
77
-
78
- fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
79
- fake_img = Lab2RGB_out(fake_img)
80
 
81
- out_folder = os.path.join(output_folder, 'color')
82
- if not os.path.exists(out_folder):
83
- os.makedirs(out_folder)
84
- out_img_path = os.path.join(out_folder, f'{img_name}_color.png')
85
 
86
- # show image
87
- Image.fromarray(fake_img).show()
88
- # save image
89
- io.imsave(out_img_path, fake_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  img_ab = img_lab[:,1:,:,:]
21
  img_l = img_l + 50
22
  pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
23
+ out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1)* 255).astype("uint8")
24
  return out
25
 
26
  def RGB2Lab(inputs):
 
34
  return lab.astype('float32')
35
 
36
  def numpy2tensor(inputs):
37
+ out = torch.from_numpy(inputs.transpose(2,0,1))
38
  return out
39
 
40
  def tensor2numpy(inputs):
41
+ out = inputs[0,...].detach().cpu().numpy().transpose(1,2,0)
42
  return out
43
 
44
  def preprocessing(inputs):
 
49
  return img.unsqueeze(0), img_lab.unsqueeze(0)
50
 
51
  if __name__ == "__main__":
 
 
 
 
 
 
52
  device = "cuda"
53
 
54
+ parser = argparse.ArgumentParser()
55
+ parser.add_argument("--path", type=str, default=None, help="path of input image")
56
+ parser.add_argument("--size", type=int, default=None)
57
+ parser.add_argument("--ckpt", type=str, default=None, help="path of model weight")
58
+ parser.add_argument("-ne", "--no_extractor", action='store_true', help="Do not segment the manga panels.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ args = parser.parse_args()
 
 
 
61
 
62
+ if args.path:
63
+ test_dir_path = args.path
64
+ if args.size:
65
+ imgsize = args.size
66
+ if args.ckpt:
67
+ ckpt_path = args.ckpt
68
+ if args.no_extractor:
69
+ no_extractor = args.no_extractor
70
+
71
+ ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
72
+
73
+ colorEncoder = ColorEncoder().to(device)
74
+ colorEncoder.load_state_dict(ckpt["colorEncoder"])
75
+ colorEncoder.eval()
76
+
77
+ colorUNet = ColorUNet().to(device)
78
+ colorUNet.load_state_dict(ckpt["colorUNet"])
79
+ colorUNet.eval()
80
+
81
+ imgs = []
82
+ imgs_lab = []
83
+
84
+ while 1:
85
+ print(f'make sure both manga image and reference images are under this path {test_dir_path}')
86
+ img_path = input("please input the name of image needed to be colorized (with file extension): ")
87
+ img_path = os.path.join(test_dir_path, img_path)
88
+ img_name = os.path.basename(img_path)
89
+ img_name = os.path.splitext(img_name)[0]
90
+
91
+ if no_extractor:
92
+ ref_img_path = os.path.join(test_dir_path, input(f"Enter the reference image path: "))
93
+
94
+ img1 = Image.open(img_path).convert("RGB")
95
+ width, height = img1.size
96
+ img2 = Image.open(ref_img_path).convert("RGB")
97
+
98
+ img1, img1_lab = preprocessing(img1)
99
+ img2, img2_lab = preprocessing(img2)
100
+
101
+ img1 = img1.to(device)
102
+ img1_lab = img1_lab.to(device)
103
+ img2 = img2.to(device)
104
+ img2_lab = img2_lab.to(device)
105
+
106
+ with torch.no_grad():
107
+ img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear',
108
+ recompute_scale_factor=False, align_corners=False)
109
+ img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(imgsize, imgsize), mode='bilinear',
110
+ recompute_scale_factor=False, align_corners=False)
111
+
112
+ color_vector = colorEncoder(img2_resize)
113
+
114
+ fake_ab = colorUNet((img1_L_resize, color_vector))
115
+ fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear',
116
+ recompute_scale_factor=False, align_corners=False)
117
+
118
+ fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
119
+ fake_img = Lab2RGB_out(fake_img)
120
+
121
+ out_folder = os.path.dirname(img_path)
122
+ out_name = os.path.basename(img_path)
123
+ out_name = os.path.splitext(out_name)[0]
124
+ out_img_path = os.path.join(out_folder, 'color', f'{out_name}_color.png')
125
+
126
+ # show image
127
+ Image.fromarray(fake_img).show()
128
+ # save image
129
+ folder_path = os.path.join(out_folder, 'color')
130
+ if not os.path.exists(folder_path):
131
+ os.makedirs(folder_path)
132
+ io.imsave(out_img_path, fake_img)
133
+
134
+ continue
135
+
136
+ panel_extractor = PanelExtractor(min_pct_panel=5, max_pct_panel=90)
137
+ panels, masks, panel_masks = panel_extractor.extract(img_path)
138
+ panel_num = len(panels)
139
+
140
+ ref_img_paths = []
141
+ print("Please enter the name of the reference image in order according to the number prompts on the picture")
142
+ for i in range(panel_num):
143
+ ref_img_path = os.path.join(test_dir_path, input(f"{i+1}/{panel_num} reference image:"))
144
+ ref_img_paths.append(ref_img_path)
145
+
146
+ fake_imgs = []
147
+ for i in range(panel_num):
148
+ img1 = Image.fromarray(panels[i]).convert("RGB")
149
+ width, height = img1.size
150
+ img2 = Image.open(ref_img_paths[i]).convert("RGB")
151
+
152
+ img1, img1_lab = preprocessing(img1)
153
+ img2, img2_lab = preprocessing(img2)
154
+
155
+ img1 = img1.to(device)
156
+ img1_lab = img1_lab.to(device)
157
+ img2 = img2.to(device)
158
+ img2_lab = img2_lab.to(device)
159
+
160
+ with torch.no_grad():
161
+ img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
162
+ img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
163
+
164
+ color_vector = colorEncoder(img2_resize)
165
+
166
+ fake_ab = colorUNet((img1_L_resize, color_vector))
167
+ fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
168
+
169
+ fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
170
+ fake_img = Lab2RGB_out(fake_img)
171
+ fake_imgs.append(fake_img)
172
+
173
+ if panel_num == 1:
174
+ out_folder = os.path.dirname(img_path)
175
+ out_name = os.path.basename(img_path)
176
+ out_name = os.path.splitext(out_name)[0]
177
+ out_img_path = os.path.join(out_folder,'color',f'{out_name}_color.png')
178
+
179
+ Image.fromarray(fake_imgs[0]).show()
180
+ folder_path = os.path.join(out_folder, 'color')
181
+ if not os.path.exists(folder_path):
182
+ os.makedirs(folder_path)
183
+ io.imsave(out_img_path, fake_imgs[0])
184
+ else:
185
+ panel_extractor.concatPanels(img_path, fake_imgs, masks, panel_masks)
186
+
187
+ print(f'Colored images have been saved to: {os.path.join(test_dir_path, "color")}')