jerichosy committed
Commit 1cdece3 · 1 parent: 4e7ebcb
.gitignore ADDED
@@ -0,0 +1,161 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+ .DS_Store
LICENSE ADDED
@@ -0,0 +1,23 @@
+ Copyright (c) 2016, Richard Zhang, Phillip Isola, Alexei A. Efros
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
app.py ADDED
@@ -0,0 +1,55 @@
+ import gradio as gr
+
+ from colorizers import *
+
+ # load colorizers
+ colorizer_eccv16 = eccv16(pretrained=True).eval()
+ colorizer_siggraph17 = siggraph17(pretrained=True).eval()
+
+ title = "Colorize Images"
+ description = """
+ Colorize black-and-white images with the models from Zhang et al.'s ECCV 2016 and SIGGRAPH 2017 colorization papers:
+ - Colorful Image Colorization: https://arxiv.org/abs/1603.08511
+ - Real-Time User-Guided Image Colorization with Learned Deep Priors: https://arxiv.org/abs/1705.02999
+ <br>
+ Reference implementation: https://github.com/richzhang/colorization
+ <br>
+ Adapted to Gradio by DIGIMAP Group 12:
+ - GREGORIO, DALE PONS LEE
+ - SILLONA, JOHN EUGENE JUSTINIANO
+ - SY, MATTHEW JERICHO GO
+ """
+
+
+ def color(image, ver):
+     # default size to process images is 256x256
+     # grab L channel in both original ("orig") and resized ("rs") resolutions
+     (tens_l_orig, tens_l_rs) = preprocess_img(image, HW=(256, 256))
+
+     # colorizer outputs 256x256 ab map
+     # resize and concatenate to original L channel
+     if ver == "eccv16":
+         out_img = postprocess_tens(tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())
+     else:
+         out_img = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())
+
+     return out_img
+
+
+ gr.Interface(
+     fn=color,
+     inputs=[
+         "image",
+         gr.Radio(
+             ["eccv16", "siggraph17"], type="value", value="eccv16", label="version"
+         ),
+     ],
+     # outputs=[gr.Image(label="eccv16"), gr.Image(label="siggraph17")],
+     outputs="image",
+     allow_flagging="never",
+     title=title,
+     description=description,
+     examples=[
+         ["imgs/moon-captured-bw-lg.jpeg", "eccv16"],
+     ],
+ ).launch()
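For reference, the handler above mirrors the upstream repo's demo pipeline; an equivalent standalone sketch (not part of the commit) using the helpers that colorizers/__init__.py re-exports from colorizers/util.py below:

```python
# Standalone equivalent of color(image, "eccv16") above; a sketch only.
import matplotlib.pyplot as plt

from colorizers import eccv16, load_img, postprocess_tens, preprocess_img

model = eccv16(pretrained=True).eval()           # downloads the released weights
img = load_img("imgs/moon-captured-bw-lg.jpeg")  # bundled example image
tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
out_img = postprocess_tens(tens_l_orig, model(tens_l_rs).cpu())
plt.imsave("colorized.png", out_img)             # H x W x 3 RGB floats in [0, 1]
```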
colorizers/__init__.py ADDED
@@ -0,0 +1,6 @@
+
+ from .base_color import *
+ from .eccv16 import *
+ from .siggraph17 import *
+ from .util import *
+
colorizers/base_color.py ADDED
@@ -0,0 +1,24 @@
+
+ import torch
+ from torch import nn
+
+ class BaseColor(nn.Module):
+     def __init__(self):
+         super(BaseColor, self).__init__()
+
+         self.l_cent = 50.
+         self.l_norm = 100.
+         self.ab_norm = 110.
+
+     def normalize_l(self, in_l):
+         return (in_l-self.l_cent)/self.l_norm
+
+     def unnormalize_l(self, in_l):
+         return in_l*self.l_norm + self.l_cent
+
+     def normalize_ab(self, in_ab):
+         return in_ab/self.ab_norm
+
+     def unnormalize_ab(self, in_ab):
+         return in_ab*self.ab_norm
+
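These constants encode the CIE Lab ranges the models assume: L in [0, 100] is shifted and scaled to [-0.5, 0.5], and ab in roughly [-110, 110] to [-1, 1], matching the Tanh output of the SIGGRAPH model below. A quick check of the mapping (a sketch, not part of the commit):

```python
import torch

from colorizers.base_color import BaseColor

bc = BaseColor()
print(bc.normalize_l(torch.tensor([0.0, 50.0, 100.0])))   # tensor([-0.5000, 0.0000, 0.5000])
print(bc.unnormalize_ab(torch.tensor([-1.0, 0.0, 1.0])))  # tensor([-110., 0., 110.])
```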
colorizers/eccv16.py ADDED
@@ -0,0 +1,104 @@
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ from .base_color import *
+
+ class ECCVGenerator(BaseColor):
+     def __init__(self, norm_layer=nn.BatchNorm2d):
+         super(ECCVGenerator, self).__init__()
+
+         model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+         model1+=[nn.ReLU(True),]
+         model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
+         model1+=[nn.ReLU(True),]
+         model1+=[norm_layer(64),]
+
+         model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         model2+=[nn.ReLU(True),]
+         model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
+         model2+=[nn.ReLU(True),]
+         model2+=[norm_layer(128),]
+
+         model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[norm_layer(256),]
+
+         model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[norm_layer(512),]
+
+         model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[norm_layer(512),]
+
+         model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[norm_layer(512),]
+
+         model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[norm_layer(512),]
+
+         model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+         model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+         model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+
+         model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
+
+         self.model1 = nn.Sequential(*model1)
+         self.model2 = nn.Sequential(*model2)
+         self.model3 = nn.Sequential(*model3)
+         self.model4 = nn.Sequential(*model4)
+         self.model5 = nn.Sequential(*model5)
+         self.model6 = nn.Sequential(*model6)
+         self.model7 = nn.Sequential(*model7)
+         self.model8 = nn.Sequential(*model8)
+
+         self.softmax = nn.Softmax(dim=1)
+         self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
+         self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+     def forward(self, input_l):
+         conv1_2 = self.model1(self.normalize_l(input_l))
+         conv2_2 = self.model2(conv1_2)
+         conv3_3 = self.model3(conv2_2)
+         conv4_3 = self.model4(conv3_3)
+         conv5_3 = self.model5(conv4_3)
+         conv6_3 = self.model6(conv5_3)
+         conv7_3 = self.model7(conv6_3)
+         conv8_3 = self.model8(conv7_3)
+         out_reg = self.model_out(self.softmax(conv8_3))
+
+         return self.unnormalize_ab(self.upsample4(out_reg))
+
+ def eccv16(pretrained=True):
+     model = ECCVGenerator()
+     if(pretrained):
+         import torch.utils.model_zoo as model_zoo
+         model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
+     return model
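The 313 channels feeding model_out are the quantized ab color bins from the ECCV 2016 paper; model_out collapses the softmax over bins into a 2-channel ab map, which upsample4 brings back to the input resolution. A shape sanity check (a sketch, not part of the commit; pretrained=False skips the weight download):

```python
import torch

from colorizers.eccv16 import eccv16

model = eccv16(pretrained=False).eval()  # random weights, shapes only
with torch.no_grad():
    ab = model(torch.zeros(1, 1, 256, 256))  # raw L channel; normalized internally
print(ab.shape)  # torch.Size([1, 2, 256, 256])
```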
colorizers/siggraph17.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ import torch.nn as nn
+
+ from .base_color import *
+
+ class SIGGRAPHGenerator(BaseColor):
+     def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
+         super(SIGGRAPHGenerator, self).__init__()
+
+         # Conv1
+         model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+         model1+=[nn.ReLU(True),]
+         model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+         model1+=[nn.ReLU(True),]
+         model1+=[norm_layer(64),]
+         # add a subsampling operation
+
+         # Conv2
+         model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         model2+=[nn.ReLU(True),]
+         model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         model2+=[nn.ReLU(True),]
+         model2+=[norm_layer(128),]
+         # add a subsampling layer operation
+
+         # Conv3
+         model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[norm_layer(256),]
+         # add a subsampling layer operation
+
+         # Conv4
+         model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[norm_layer(512),]
+
+         # Conv5
+         model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[norm_layer(512),]
+
+         # Conv6
+         model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[norm_layer(512),]
+
+         # Conv7
+         model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[norm_layer(512),]
+
+         # Conv8
+         model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
+         model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+
+         model8=[nn.ReLU(True),]
+         model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+         model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+         model8+=[norm_layer(256),]
+
+         # Conv9
+         model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
+         model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         # add the two feature maps above
+
+         model9=[nn.ReLU(True),]
+         model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         model9+=[nn.ReLU(True),]
+         model9+=[norm_layer(128),]
+
+         # Conv10
+         model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
+         model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         # add the two feature maps above
+
+         model10=[nn.ReLU(True),]
+         model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
+         model10+=[nn.LeakyReLU(negative_slope=.2),]
+
+         # classification output
+         model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
+
+         # regression output
+         model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
+         model_out+=[nn.Tanh()]
+
+         self.model1 = nn.Sequential(*model1)
+         self.model2 = nn.Sequential(*model2)
+         self.model3 = nn.Sequential(*model3)
+         self.model4 = nn.Sequential(*model4)
+         self.model5 = nn.Sequential(*model5)
+         self.model6 = nn.Sequential(*model6)
+         self.model7 = nn.Sequential(*model7)
+         self.model8up = nn.Sequential(*model8up)
+         self.model8 = nn.Sequential(*model8)
+         self.model9up = nn.Sequential(*model9up)
+         self.model9 = nn.Sequential(*model9)
+         self.model10up = nn.Sequential(*model10up)
+         self.model10 = nn.Sequential(*model10)
+         self.model3short8 = nn.Sequential(*model3short8)
+         self.model2short9 = nn.Sequential(*model2short9)
+         self.model1short10 = nn.Sequential(*model1short10)
+
+         self.model_class = nn.Sequential(*model_class)
+         self.model_out = nn.Sequential(*model_out)
+
+         self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
+         self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])
+
+     def forward(self, input_A, input_B=None, mask_B=None):
+         if(input_B is None):
+             input_B = torch.cat((input_A*0, input_A*0), dim=1)
+         if(mask_B is None):
+             mask_B = input_A*0
+
+         conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
+         conv2_2 = self.model2(conv1_2[:,:,::2,::2])
+         conv3_3 = self.model3(conv2_2[:,:,::2,::2])
+         conv4_3 = self.model4(conv3_3[:,:,::2,::2])
+         conv5_3 = self.model5(conv4_3)
+         conv6_3 = self.model6(conv5_3)
+         conv7_3 = self.model7(conv6_3)
+
+         conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
+         conv8_3 = self.model8(conv8_up)
+         conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+         conv9_3 = self.model9(conv9_up)
+         conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+         conv10_2 = self.model10(conv10_up)
+         out_reg = self.model_out(conv10_2)
+
+         return self.unnormalize_ab(out_reg)
+
+ def siggraph17(pretrained=True):
+     model = SIGGRAPHGenerator()
+     if(pretrained):
+         import torch.utils.model_zoo as model_zoo
+         model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
+     return model
+
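Unlike the ECCV generator, this model takes optional user hints: input_B holds 2-channel ab hint values and mask_B marks where hints apply; both default to zeros, which is the fully automatic mode app.py uses. A sketch of the hint interface (not part of the commit):

```python
import torch

from colorizers.siggraph17 import siggraph17

model = siggraph17(pretrained=False).eval()  # random weights, shapes only
l_chan = torch.zeros(1, 1, 256, 256)   # L channel
ab_hint = torch.zeros(1, 2, 256, 256)  # user-supplied ab values (none here)
mask = torch.zeros(1, 1, 256, 256)     # nonzero where a hint is given
with torch.no_grad():
    ab = model(l_chan, ab_hint, mask)  # equivalent to model(l_chan) with zero hints
print(ab.shape)  # torch.Size([1, 2, 256, 256])
```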
colorizers/util.py ADDED
@@ -0,0 +1,46 @@
+
+ from PIL import Image
+ import numpy as np
+ from skimage import color
+ import torch
+ import torch.nn.functional as F
+
+ def load_img(img_path):
+     out_np = np.asarray(Image.open(img_path))
+     if(out_np.ndim==2):
+         out_np = np.tile(out_np[:,:,None],3)
+     return out_np
+
+ def resize_img(img, HW=(256,256), resample=3):
+     return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
+
+ def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
+     # return original size L and resized L as torch Tensors
+     img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
+
+     img_lab_orig = color.rgb2lab(img_rgb_orig)
+     img_lab_rs = color.rgb2lab(img_rgb_rs)
+
+     img_l_orig = img_lab_orig[:,:,0]
+     img_l_rs = img_lab_rs[:,:,0]
+
+     tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
+     tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
+
+     return (tens_orig_l, tens_rs_l)
+
+ def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
+     # tens_orig_l     1 x 1 x H_orig x W_orig
+     # out_ab          1 x 2 x H x W
+
+     HW_orig = tens_orig_l.shape[2:]
+     HW = out_ab.shape[2:]
+
+     # call resize function if needed
+     if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
+         out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode)
+     else:
+         out_ab_orig = out_ab
+
+     out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
+     return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
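The shape contract of these helpers, as a sketch with a synthetic image (not part of the commit): preprocess_img returns the L channel at both the original and the resized resolution, and postprocess_tens upsamples a model's ab output back to the original size before converting Lab to RGB.

```python
import numpy as np
import torch

from colorizers.util import postprocess_tens, preprocess_img

img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # H x W x 3 uint8 RGB
tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
print(tens_l_orig.shape, tens_l_rs.shape)  # [1, 1, 480, 640] and [1, 1, 256, 256]
fake_ab = torch.zeros(1, 2, 256, 256)  # stand-in for a colorizer's ab output
out = postprocess_tens(tens_l_orig, fake_ab)
print(out.shape)  # (480, 640, 3), float RGB in [0, 1]
```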
imgs/moon-captured-bw-lg.jpeg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ scikit-image
+ numpy
+ matplotlib
+ Pillow
+ gradio
+ black