Keiser41 committed on
Commit 8ed6625
1 Parent(s): f6df16f

Update pintar.py

Files changed (1)
  pintar.py +13 -8
pintar.py CHANGED
@@ -1,13 +1,12 @@
  import os
  import numpy as np
  from skimage import color, io
-
  import torch
  import torch.nn.functional as F
-
  from PIL import Image
  from models import ColorEncoder, ColorUNet
  from extractor.manga_panel_extractor import PanelExtractor
+ import argparse

  os.environ["CUDA_VISIBLE_DEVICES"] = '0'

@@ -21,7 +20,7 @@ def Lab2RGB_out(img_lab):
      img_ab = img_lab[:,1:,:,:]
      img_l = img_l + 50
      pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
-     out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1)* 255).astype("uint8")
+     out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8")
      return out

  def RGB2Lab(inputs):
@@ -50,13 +49,19 @@ def preprocessing(inputs):
      return img.unsqueeze(0), img_lab.unsqueeze(0)

  if __name__ == "__main__":
-     device = "cuda"
+     parser = argparse.ArgumentParser(description="Colorize manga images.")
+     parser.add_argument("-i", "--input", type=str, required=True, help="Path to input image directory")
+     parser.add_argument("-r", "--reference", type=str, required=True, help="Path to reference image")
+     parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
+     parser.add_argument("-ckpt", "--checkpoint", type=str, required=True, help="Path to model checkpoint")

-     # Specify the paths here
-     img_path = 'path/to/your/input/image.jpg'
-     ckpt_path = 'path/to/your/model_checkpoint.pt'
-     reference_image_path = 'path/to/your/reference/image.jpg'
+     args = parser.parse_args()

+     device = "cuda"
+     input_image_dir = args.input
+     output_directory = args.output
+     model_checkpoint_path = args.checkpoint
+     reference_image_path = args.reference
      imgsize = 256

      ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
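
For context, here is a minimal, self-contained sketch of how the argparse interface introduced in this commit could drive the rest of pintar.py. The run_colorization helper is hypothetical and the example paths are placeholders; note that the unchanged context line above still calls torch.load(ckpt_path, ...), so routing the parsed --checkpoint value into the load (as model_checkpoint_path here) is an assumption for illustration, not something this commit does.

# Sketch only: mirrors the flags added in this commit. run_colorization() is a
# hypothetical helper; the real script builds ColorEncoder/ColorUNet and runs
# the PanelExtractor inline instead.
import argparse
import torch

def run_colorization(input_dir, reference_path, output_dir, checkpoint_path, device="cuda"):
    # Assumption: the --checkpoint value is what torch.load should receive after
    # this change (the unchanged diff line still references the old ckpt_path name).
    ckpt = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    print(f"Loaded checkpoint object of type {type(ckpt).__name__}")
    # ... model construction and per-panel colorization would follow here.

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Colorize manga images.")
    parser.add_argument("-i", "--input", type=str, required=True, help="Path to input image directory")
    parser.add_argument("-r", "--reference", type=str, required=True, help="Path to reference image")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
    parser.add_argument("-ckpt", "--checkpoint", type=str, required=True, help="Path to model checkpoint")
    args = parser.parse_args()

    # Example invocation (placeholder paths):
    #   python pintar.py -i ./panels -r ./reference.jpg -o ./out -ckpt ./model.pt
    run_colorization(args.input, args.reference, args.output, args.checkpoint)

Under this reading, the remaining parsed values (input_image_dir, output_directory, reference_image_path) would be consumed by the panel-extraction and saving code further down the file, which this commit leaves unchanged.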