DmitrMakeev committed on
Commit
39946a9
1 Parent(s): a3e6b3a

Upload 9 files

Files changed (9)
  1. .gitattributes +33 -0
  2. LICENSE +21 -0
  3. README.md +13 -0
  4. app.py +91 -0
  5. cog.yaml +41 -0
  6. environment.yml +121 -0
  7. inference.py +105 -0
  8. predict.py +104 -0
  9. requirements.txt +17 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+ 
+ Copyright (c) 2022 Menghan Xia
+ 
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+ 
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+ 
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: colorizator
+ emoji: 🐢
+ colorFrom: pink
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.9
+ app_file: app.py
+ pinned: false
+ license: openrail
+ ---
+ 
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+ import os, requests
+ import numpy as np
+ from inference import setup_model, colorize_grayscale, predict_anchors
+ 
+ ## local | remote
+ RUN_MODE = "remote"
+ if RUN_MODE != "local":
+     os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/disco-beta.pth.rar")
+     os.makedirs("checkpoints", exist_ok=True)  # ensure the target directory exists before moving the checkpoint
+     os.rename("disco-beta.pth.rar", "./checkpoints/disco-beta.pth.rar")
+     ## examples
+     os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/01.jpg")
+     os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/02.jpg")
+     os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/03.jpg")
+     os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/04.jpg")
+ 
+ ## step 1: set up model
+ device = "cpu"
+ checkpt_path = "checkpoints/disco-beta.pth.rar"
+ colorizer, colorLabeler = setup_model(checkpt_path, device=device)
+ 
+ def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
+     if hint_img is None:
+         hint_img = rgb_img
+     ## render both resolutions: high-res (512x512) first, then low-res (256x256)
+     output = colorize_grayscale(colorizer, colorLabeler, rgb_img, hint_img, n_anchors, True, is_editable, device)
+     output1 = colorize_grayscale(colorizer, colorLabeler, rgb_img, hint_img, n_anchors, False, is_editable, device)
+     return output, output1
+ 
+ def click_predanchors(rgb_img, n_anchors, is_high_res, is_editable):
+     output = predict_anchors(colorizer, colorLabeler, rgb_img, n_anchors, is_high_res, is_editable, device)
+     return output
+ 
+ ## step 2: configure interface
+ def switch_states(is_checked):
+     if is_checked:
+         return gr.Image.update(visible=True), gr.Button.update(visible=True)
+     else:
+         return gr.Image.update(visible=False), gr.Button.update(visible=False)
+ 
+ demo = gr.Blocks(title="DISCO")
+ with demo:
+     gr.Markdown(value="""
+         **Gradio demo for DISCO: Disentangled Image Colorization via Global Anchors**. Check out our [project page](https://menghanxia.github.io/projects/disco.html) 😛.
+         """)
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 Image_input = gr.Image(type="numpy", label="Input", interactive=True)
+                 Image_anchor = gr.Image(type="numpy", label="Anchor", tool="color-sketch", interactive=True, visible=False)
+             with gr.Row():
+                 Num_anchor = gr.Number(type="int", value=8, label="Num. of anchors (3~14)")
+                 Radio_resolution = gr.Radio(type="index", choices=["Low (256x256)", "High (512x512)"], \
+                     label="Colorization resolution (Low is more stable)", value="Low (256x256)")
+             with gr.Row():
+                 Checkbox_editable = gr.Checkbox(default=False, label='Show editable anchors')
+                 Button_show_anchor = gr.Button(value="Predict anchors", visible=False)
+             Button_run = gr.Button(value="Colorize")
+         with gr.Column():
+             Image_output = [gr.Image(type="numpy", label="Output").style(height=480), gr.Image(type="numpy", label="Output").style(height=480)]
+ 
+     Checkbox_editable.change(fn=switch_states, inputs=Checkbox_editable, outputs=[Image_anchor, Button_show_anchor])
+     Button_show_anchor.click(fn=click_predanchors, inputs=[Image_input, Num_anchor, Radio_resolution, Checkbox_editable], outputs=Image_anchor)
+     Button_run.click(fn=click_colorize, inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Checkbox_editable], \
+         outputs=Image_output)
+ 
+     ## guideline
+     gr.Markdown(value="""
+         🔔**Guideline**
+         1. Upload your image or select one from the examples.
+         2. Set up the arguments: "Num. of anchors" and "Colorization resolution".
+         3. Run the colorization (two modes supported):
+             - 📀Automatic mode: **Click** "Colorize" to get the automatically colorized output.
+             - ✏️Editable mode: **Check** "Show editable anchors"; **Click** "Predict anchors"; **Redraw** the anchor colors (only the anchor regions are used); **Click** "Colorize" to get the result.
+         """)
+     if RUN_MODE != "local":
+         gr.Examples(examples=[
+                 ['01.jpg', 8, "Low (256x256)"],
+                 ['02.jpg', 8, "Low (256x256)"],
+                 ['03.jpg', 8, "Low (256x256)"],
+                 ['04.jpg', 8, "Low (256x256)"],
+             ],
+             inputs=[Image_input, Num_anchor, Radio_resolution], outputs=[Image_output], label="Examples", cache_examples=False)
+     gr.HTML(value="""
+         <p style="text-align:center; color:orange"><a href='https://menghanxia.github.io/projects/disco.html' target='_blank'>DISCO Project Page</a> | <a href='https://github.com/MenghanXia/DisentangledColorization' target='_blank'>Github Repo</a></p>
+         """)
+ 
+ if RUN_MODE == "local":
+     demo.launch(server_name='9.134.253.83', server_port=7788)
+ else:
+     demo.queue(default_enabled=True, status_update_rate=5)
+     demo.launch()
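
Note: the colorization entry points wired up in app.py can also be driven without Gradio. The sketch below is a minimal example, assuming the checkpoint has already been placed at ./checkpoints/disco-beta.pth.rar (as the remote branch of app.py does) and using a hypothetical local file 01.jpg as input:

import cv2
from inference import setup_model, colorize_grayscale

## load the model once (mirrors "step 1" in app.py), then reuse it per image
colorizer, color_labeler = setup_model("checkpoints/disco-beta.pth.rar", device="cpu")
rgb = cv2.cvtColor(cv2.imread("01.jpg", cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
## automatic mode: the input doubles as its own (unedited) hint image
out = colorize_grayscale(colorizer, color_labeler, rgb, rgb, n_anchors=8,
                         is_high_res=False, is_editable=False, device="cpu")
cv2.imwrite("01_colorized.png", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))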
cog.yaml ADDED
@@ -0,0 +1,41 @@
+ # Configuration for Cog ⚙️
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+ 
+ build:
+   # set to true if your model requires a GPU
+   cuda: "10.2"
+   gpu: true
+ 
+   # a list of ubuntu apt packages to install
+   system_packages:
+     # - "libgl1-mesa-glx"
+     # - "libglib2.0-0"
+     - "libgl1-mesa-dev"
+ 
+   # python version in the form '3.8' or '3.8.12'
+   python_version: "3.8"
+ 
+   # a list of packages in the format <package-name>==<version>
+   python_packages:
+     # - "numpy==1.19.4"
+     # - "torch==1.8.0"
+     # - "torchvision==0.9.0"
+     - "numpy==1.23.1"
+     - "torch==1.8.0"
+     - "torchvision==0.9.0"
+     - "opencv-python==4.6.0.66"
+     - "pandas==1.4.3"
+     - "pillow==9.2.0"
+     - "tqdm==4.64.0"
+     - "scikit-image==0.19.3"
+     - "scikit-learn==1.1.2"
+     - "scipy==1.9.1"
+ 
+   # commands run after the environment is set up
+   # run:
+   #   - "echo env is ready!"
+   #   - "echo another command if needed"
+ 
+ # predict.py defines how predictions are run on your model
+ predict: "predict.py:Predictor"
+ # image: "r8.im/menghanxia/disco"
environment.yml ADDED
@@ -0,0 +1,121 @@
+ name: DISCO
+ channels:
+   - pytorch
+   - defaults
+   - conda-forge
+ dependencies:
+   - blas=1.0=mkl
+   - bzip2=1.0.8=h7b6447c_0
+   - ca-certificates=2022.07.19=h06a4308_0
+   - certifi=2022.6.15=py38h06a4308_0
+   - cudatoolkit=10.2.89=hfd86e86_1
+   - freetype=2.11.0=h70c0345_0
+   - giflib=5.2.1=h7b6447c_0
+   - gmp=6.2.1=h295c915_3
+   - gnutls=3.6.15=he1e5248_0
+   - intel-openmp=2021.4.0=h06a4308_3561
+   - jpeg=9b=h024ee3a_2
+   - lame=3.100=h7b6447c_0
+   - lcms2=2.12=h3be6417_0
+   - ld_impl_linux-64=2.38=h1181459_1
+   - libffi=3.3=he6710b0_2
+   - libgcc-ng=11.2.0=h1234567_1
+   - libiconv=1.16=h7f8727e_2
+   - libidn2=2.3.2=h7f8727e_0
+   - libpng=1.6.37=hbc83047_0
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - libtasn1=4.16.0=h27cfd23_0
+   - libtiff=4.1.0=h2733197_1
+   - libunistring=0.9.10=h27cfd23_0
+   - libuv=1.40.0=h7b6447c_0
+   - libwebp=1.2.0=h89dd481_0
+   - lz4-c=1.9.3=h295c915_1
+   - mkl=2021.4.0=h06a4308_640
+   - mkl-service=2.4.0=py38h7f8727e_0
+   - mkl_fft=1.3.1=py38hd3c417c_0
+   - mkl_random=1.2.2=py38h51133e4_0
+   - ncurses=6.3=h5eee18b_3
+   - nettle=3.7.3=hbbd107a_1
+   - ninja=1.10.2=h06a4308_5
+   - ninja-base=1.10.2=hd09550d_5
+   - numpy=1.23.1=py38h6c91a56_0
+   - numpy-base=1.23.1=py38ha15fc14_0
+   - openh264=2.1.1=h4ff587b_0
+   - openssl=1.1.1q=h7f8727e_0
+   - pillow=9.2.0=py38hace64e9_1
+   - pip=22.1.2=py38h06a4308_0
+   - python=3.8.13=h12debd9_0
+   - readline=8.1.2=h7f8727e_1
+   - setuptools=63.4.1=py38h06a4308_0
+   - six=1.16.0=pyhd3eb1b0_1
+   - sqlite=3.39.2=h5082296_0
+   - tk=8.6.12=h1ccaba5_0
+   - typing_extensions=4.3.0=py38h06a4308_0
+   - wheel=0.37.1=pyhd3eb1b0_0
+   - xz=5.2.5=h7f8727e_1
+   - zlib=1.2.12=h7f8727e_2
+   - zstd=1.4.9=haebb681_0
+   - ffmpeg=4.3=hf484d3e_0
+   - pytorch=1.8.0=py3.8_cuda10.2_cudnn7.6.5_0
+   - torchaudio=0.8.0=py38
+   - torchvision=0.9.0=py38_cu102
+   - pip:
+     - addict==2.4.0
+     - astunparse==1.6.3
+     - cachetools==4.2.4
+     - charset-normalizer==2.0.7
+     - clang==5.0
+     - cycler==0.11.0
+     - flatbuffers==1.12
+     - fonttools==4.37.1
+     - future==0.18.2
+     - gast==0.4.0
+     - google-auth==2.3.2
+     - google-auth-oauthlib==0.4.6
+     - google-pasta==0.2.0
+     - grpcio==1.41.1
+     - h5py==3.1.0
+     - idna==3.3
+     - imageio==2.21.1
+     - joblib==1.1.0
+     - keras==2.6.0
+     - keras-preprocessing==1.1.2
+     - kiwisolver==1.4.4
+     - lpips==0.1.4
+     - markdown==3.3.4
+     - matplotlib==3.5.3
+     - networkx==2.8.6
+     - oauthlib==3.1.1
+     - opencv-python==4.6.0.66
+     - opt-einsum==3.3.0
+     - packaging==21.3
+     - pandas==1.4.3
+     - protobuf==3.19.0
+     - pyasn1==0.4.8
+     - pyasn1-modules==0.2.8
+     - pyparsing==3.0.9
+     - python-dateutil==2.8.2
+     - pytz==2022.2.1
+     - pywavelets==1.3.0
+     - pyyaml==6.0
+     - requests==2.26.0
+     - requests-oauthlib==1.3.0
+     - rsa==4.7.2
+     - scikit-image==0.19.3
+     - scikit-learn==1.1.2
+     - scipy==1.9.1
+     - tensorboard-data-server==0.6.1
+     - tensorboard-plugin-wit==1.8.0
+     - tensorflow-estimator==2.6.0
+     - tensorflow-gpu==2.6.0
+     - termcolor==1.1.0
+     - threadpoolctl==3.1.0
+     - tifffile==2022.8.12
+     - torch==1.8.0
+     - tqdm==4.64.0
+     - urllib3==1.26.7
+     - werkzeug==2.0.2
+     - wrapt==1.12.1
+     - yapf==0.32.0
+ prefix: /root/data/programs/anaconda3/envs/DISCO
+ 
inference.py ADDED
@@ -0,0 +1,105 @@
+ import os, glob, sys, logging
+ import argparse, datetime, time
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models import model, basic
+ from utils import util
+ 
+ 
+ def setup_model(checkpt_path, device="cuda"):
+     """Load the model into memory to make running multiple predictions efficient"""
+     colorLabeler = basic.ColorLabel(device=device)
+     colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True, colorLabeler=colorLabeler)
+     colorizer = colorizer.to(device)
+     assert os.path.exists(checkpt_path), "No checkpoint found!"
+     data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
+     colorizer.load_state_dict(data_dict['state_dict'])
+     colorizer.eval()
+     return colorizer, colorLabeler
+ 
+ 
+ def resize_ab2l(gray_img, lab_imgs, vis=False):
+     ## upsample the predicted ab channels back to the original gray-image resolution
+     H, W = gray_img.shape[:2]
+     resized_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
+     if vis:
+         gray_img = cv2.resize(lab_imgs[:,:,:1], (W,H), interpolation=cv2.INTER_LINEAR)
+         return np.concatenate((gray_img[:,:,np.newaxis], resized_ab), axis=2)
+     else:
+         return np.concatenate((gray_img, resized_ab), axis=2)
+ 
+ def prepare_data(rgb_img, target_res):
+     ## normalize: L to [-1,1] via (L-50)/50, ab to roughly [-1,1] via ab/110
+     rgb_img = np.array(rgb_img / 255., np.float32)
+     lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
+     org_grays = (lab_img[:,:,[0]]-50.) / 50.
+     lab_img = cv2.resize(lab_img, target_res, interpolation=cv2.INTER_LINEAR)
+ 
+     lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
+     gray_img = (lab_img[0:1,:,:]-50.) / 50.
+     ab_chans = lab_img[1:3,:,:] / 110.
+     input_grays = gray_img.unsqueeze(0)
+     input_colors = ab_chans.unsqueeze(0)
+     return input_grays, input_colors, org_grays
+ 
+ 
+ def colorize_grayscale(colorizer, color_class, rgb_img, hint_img, n_anchors, is_high_res, is_editable, device="cuda"):
+     n_anchors = min(max(int(n_anchors), 3), 14)  # clamp to the supported range
+     target_res = (512,512) if is_high_res else (256,256)
+     input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
+     input_grays = input_grays.to(device)
+     input_colors = input_colors.to(device)
+ 
+     if is_editable:
+         print('>>>: editable mode')
+         sampled_T = -1
+         ## take the ab channels from the user-edited hint image instead
+         _, input_colors, _ = prepare_data(hint_img, target_res)
+         input_colors = input_colors.to(device)
+     else:
+         print('>>>: automatic mode')
+         sampled_T = 0
+     pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
+         input_colors, n_anchors, sampled_T)
+ 
+     pred_labs = torch.cat((input_grays, enhanced_ab), dim=1)
+     lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
+     lab_imgs = resize_ab2l(org_grays, lab_imgs)
+ 
+     ## denormalize Lab and convert to RGB
+     lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
+     lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
+     rgb_output = cv2.cvtColor(lab_imgs, cv2.COLOR_LAB2RGB)
+     return (rgb_output*255.0).astype(np.uint8)
+ 
+ 
+ def predict_anchors(colorizer, color_class, rgb_img, n_anchors, is_high_res, is_editable, device="cuda"):
+     n_anchors = min(max(int(n_anchors), 3), 14)  # clamp to the supported range
+     target_res = (512,512) if is_high_res else (256,256)
+     input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
+     input_grays = input_grays.to(device)
+     input_colors = input_colors.to(device)
+ 
+     sampled_T, sp_size = 0, 16
+     pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
+         input_colors, n_anchors, sampled_T)
+     pred_probs = pal_logit
+     ## decode the anchor colors and paint them onto the gray input for visualization
+     guided_colors = color_class.decode_ind2ab(ref_logit, T=0)
+     guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
+     anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
+     marked_labs = basic.mark_color_hints(input_grays, guided_colors, anchor_masks, base_ABs=None)
+     lab_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
+     lab_imgs = resize_ab2l(org_grays, lab_imgs, vis=True)
+ 
+     ## denormalize Lab and convert to RGB
+     lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
+     lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
+     rgb_output = cv2.cvtColor(lab_imgs, cv2.COLOR_LAB2RGB)
+     return (rgb_output*255.0).astype(np.uint8)
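
Note: prepare_data maps L from [0, 100] to [-1, 1] via (L - 50) / 50 and ab to roughly [-1, 1] via ab / 110; colorize_grayscale and predict_anchors invert this before the Lab-to-RGB conversion. A quick self-contained sanity check of that round trip:

import numpy as np

lab = np.array([[[75.0, 20.0, -40.0]]], dtype=np.float32)  # a single Lab pixel
l_norm = (lab[..., :1] - 50.0) / 50.0   # L: [0, 100] -> [-1, 1]
ab_norm = lab[..., 1:] / 110.0          # ab: [-110, 110] -> roughly [-1, 1]
lab_back = np.concatenate([l_norm * 50.0 + 50.0, ab_norm * 110.0], axis=-1)
assert np.allclose(lab, lab_back)       # the normalization round-trips exactly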
predict.py ADDED
@@ -0,0 +1,104 @@
+ # Prediction interface for Cog ⚙️
+ # https://github.com/replicate/cog/blob/main/docs/python.md
+ 
+ from cog import BasePredictor, Input, Path
+ import tempfile
+ import os, glob
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models import model, basic
+ from utils import util
+ 
+ class Predictor(BasePredictor):
+     def setup(self):
+         """Load the model into memory to make running multiple predictions efficient"""
+         seed = 130
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+         self.colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True)
+         self.colorizer = self.colorizer.cuda()
+         checkpt_path = "./checkpoints/disco-beta.pth.rar"
+         assert os.path.exists(checkpt_path), "No checkpoint found!"
+         data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
+         self.colorizer.load_state_dict(data_dict['state_dict'])
+         self.colorizer.eval()
+         self.color_class = basic.ColorLabel(lambda_=0.5, device='cuda')
+ 
+     def resize_ab2l(self, gray_img, lab_imgs):
+         ## upsample the predicted ab channels back to the original gray-image resolution
+         H, W = gray_img.shape[:2]
+         resized_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
+         return np.concatenate((gray_img, resized_ab), axis=2)
+ 
+     def predict(
+         self,
+         image: Path = Input(description="input image. Output will be one or multiple colorized images."),
+         n_anchors: int = Input(
+             description="number of color anchors", ge=3, le=14, default=8
+         ),
+         multi_result: bool = Input(
+             description="to generate diverse results", default=False
+         ),
+         vis_anchors: bool = Input(
+             description="to visualize the anchor locations", default=False
+         )
+     ) -> Path:
+         """Run a single prediction on the model"""
+         bgr_img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+         rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
+         rgb_img = np.array(rgb_img / 255., np.float32)
+         lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
+         org_grays = (lab_img[:,:,[0]]-50.) / 50.
+         lab_img = cv2.resize(lab_img, (256,256), interpolation=cv2.INTER_LINEAR)
+ 
+         lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
+         gray_img = (lab_img[0:1,:,:]-50.) / 50.
+         ab_chans = lab_img[1:3,:,:] / 110.
+         input_grays = gray_img.unsqueeze(0)
+         input_colors = ab_chans.unsqueeze(0)
+         input_grays = input_grays.cuda(non_blocking=True)
+         input_colors = input_colors.cuda(non_blocking=True)
+ 
+         sampled_T = 2 if multi_result else 0
+         pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = self.colorizer(input_grays, \
+             input_colors, n_anchors, True, sampled_T)
+         pred_probs = pal_logit
+         guided_colors = self.color_class.decode_ind2ab(ref_logit, T=0)
+         sp_size = 16
+         guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
+         res_list = []
+         if multi_result:
+             ## keep three diverse colorizations sampled by the model
+             for no in range(3):
+                 pred_labs = torch.cat((input_grays, enhanced_ab[no:no+1,:,:,:]), dim=1)
+                 lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
+                 lab_imgs = self.resize_ab2l(org_grays, lab_imgs)
+                 res_list.append(lab_imgs)
+         else:
+             pred_labs = torch.cat((input_grays, enhanced_ab), dim=1)
+             lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
+             lab_imgs = self.resize_ab2l(org_grays, lab_imgs)
+             res_list.append(lab_imgs)
+ 
+         if vis_anchors:
+             ## visualize anchor locations
+             anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
+             marked_labs = basic.mark_color_hints(input_grays, enhanced_ab, anchor_masks, base_ABs=enhanced_ab)
+             hint_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
+             hint_imgs = self.resize_ab2l(org_grays, hint_imgs)
+             res_list.append(hint_imgs)
+ 
+         ## stack the results vertically, denormalize Lab, and write a BGR image for cv2.imwrite
+         output = cv2.vconcat(res_list)
+         output[:,:,0] = output[:,:,0] * 50.0 + 50.0
+         output[:,:,1:3] = output[:,:,1:3] * 110.0
+         bgr_output = cv2.cvtColor(output, cv2.COLOR_LAB2BGR)
+         out_path = Path(tempfile.mkdtemp()) / "out.png"
+         cv2.imwrite(str(out_path), (bgr_output*255.0).astype(np.uint8))
+         return out_path
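
Note: the predictor is normally invoked through the Cog runtime (cog.yaml points at predict.py:Predictor), but a direct call can serve as a local smoke test. A sketch only, assuming a CUDA device and the checkpoint at ./checkpoints/disco-beta.pth.rar are available; 01.jpg is a placeholder input:

from cog import Path
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads the checkpoint onto the GPU
out_path = predictor.predict(image=Path("01.jpg"), n_anchors=8,
                             multi_result=False, vis_anchors=True)
print("colorized image written to", out_path)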
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ addict
+ future
+ numpy
+ opencv-python
+ pandas
+ Pillow
+ pyyaml
+ requests
+ scikit-image
+ scikit-learn
+ scipy
+ torch>=1.8.0
+ torchvision
+ tensorboardx>=2.4
+ tqdm
+ yapf
+ lpips