menghanxia committed
Commit 6e70c4a
1 Parent(s): 7f58cb0

created the space
LICENSE ADDED
@@ -0,0 +1,26 @@
+ Deep Halftoning with Reversible Binary Pattern
+
+ Copyright (c) 2021 The Chinese University of Hong Kong
+
+ Copyright and License Information: The source code, the binary executable, and all data files (hereafter, Software) are copyrighted by The Chinese University of Hong Kong and Tien-Tsin Wong (hereafter, Author), Copyright (c) 2021 The Chinese University of Hong Kong. All Rights Reserved.
+
+ The Author grants to you ("Licensee") a non-exclusive license to use the Software for academic, research and commercial purposes, without fee. For commercial use, Licensee should submit a WRITTEN NOTICE to the Author. The notice should clearly identify the software package/system/hardware (name, version, and/or model number) using the Software. Licensee may distribute the Software to third parties provided that the copyright notice and this statement appear on all copies. Licensee agrees that the copyright notice and this statement will appear on all copies of the Software, or portions thereof. The Author retains exclusive ownership of the Software.
+
+ Licensee may make derivatives of the Software, provided that such derivatives can only be used for the purposes specified in the license grant above.
+
+ THE AUTHOR MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. THE AUTHOR SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE OR ITS DERIVATIVES.
+
+ By using the source code, Licensee agrees to cite the following paper in Licensee's publication/work:
+
+ Menghan Xia, Wenbo Hu, Xueting Liu and Tien-Tsin Wong
+ "Deep Halftoning with Reversible Binary Pattern"
+ IEEE International Conference on Computer Vision (ICCV), 2021.
+
+
+ By using or copying the Software, Licensee agrees to abide by the intellectual property laws, and all other applicable laws of the U.S., and the terms of this license.
+
+ Author shall have the right to terminate this license immediately by written notice upon Licensee's breach of, or non-compliance with, any of its terms.
+ Licensee may be held legally responsible for any infringement that is caused or encouraged by Licensee's failure to abide by the terms of this license.
+
+ For more information or comments, send mail to: ttwong@acm.org
app.py ADDED
@@ -0,0 +1,80 @@
+ import gradio as gr
+ import os, requests
+ import numpy as np
+ import torch.nn.functional as F
+ from model.model import ResHalf
+ from inference import Inferencer
+ from utils import util
+
+ ## local | remote
+ RUN_MODE = "remote"
+ if RUN_MODE != "local":
+     os.system("wget https://huggingface.co/menghanxia/ReversibleHalftoning/resolve/main/model_best.pth.tar")
+     os.makedirs("checkpoints", exist_ok=True)  # make sure the target directory exists
+     os.rename("model_best.pth.tar", "./checkpoints/model_best.pth.tar")
+     ## examples
+     os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/girl.png")
+     os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/wave.png")
+     os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/painting.png")
+
+ ## step 1: set up model
+ device = "cpu"
+ checkpt_path = "checkpoints/model_best.pth.tar"
+ invhalfer = Inferencer(checkpoint_path=checkpt_path, model=ResHalf(train=False), use_cuda=False)
+
+
+ def prepare_data(input_img, decoding_only=False):
+     input_img = np.array(input_img / 255., np.float32)
+     if decoding_only:
+         input_img = input_img[:, :, :1]
+     input_img = util.img2tensor(input_img * 2. - 1.)
+     return input_img
+
+
+ def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
+     input_img = prepare_data(input_img, decoding_only)
+     input_img = input_img.to(device)
+     if decoding_only:
+         print('>>>:restoration mode')
+         resColor = invhalfer(input_img, decoding_only=decoding_only)
+         output = util.tensor2img(resColor / 2. + 0.5) * 255.
+     else:
+         print('>>>:halftoning mode')
+         resHalftone, resColor = invhalfer(input_img, decoding_only=decoding_only)
+         output = util.tensor2img(resHalftone / 2. + 0.5) * 255.
+     return (output + 0.5).astype(np.uint8)
+
+
+ def click_run(input_img, decoding_only):
+     output = run_invhalf(invhalfer, input_img, decoding_only, device)
+     return output
+
+
+ ## step 2: configure interface
+ demo = gr.Blocks(title="ReversibleHalftoning")
+ with demo:
+     gr.Markdown(value="""
+     **Gradio demo for ReversibleHalftoning: Deep Halftoning with Reversible Binary Pattern**. Check our [github page](https://github.com/MenghanXia/ReversibleHalftoning) 😛.
+     """)
+     with gr.Row():
+         with gr.Column():
+             Image_input = gr.Image(type="numpy", label="Input", interactive=True)
+             with gr.Row():
+                 Radio_mode = gr.Radio(type="index", choices=["Halftoning (Photo2Halftone)", "Restoration (Halftone2Photo)"],
+                                       label="Choose a running mode", value="Halftoning (Photo2Halftone)")
+                 Button_run = gr.Button(value="Run")
+         with gr.Column():
+             Image_output = gr.Image(type="numpy", label="Output").style(height=480)
+
+     Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)
+
+     if RUN_MODE == "local":
+         gr.Examples(examples=[
+             ['girl.png', "Halftoning (Photo2Halftone)"],
+             ['wave.png', "Halftoning (Photo2Halftone)"],
+             ['painting.png', "Restoration (Halftone2Photo)"],
+         ],
+             inputs=[Image_input, Radio_mode], outputs=[Image_output], label="Examples")
+
+ if RUN_MODE != "local":
+     demo.launch(server_name='9.134.253.83', server_port=7788)
+ else:
+     demo.launch()
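For readers who want the model without the Gradio layer, here is a minimal sketch (not part of the commit) mirroring what `click_run` does in halftoning mode. It assumes you run from the repo root with the checkpoint downloaded as above; `input.png` is a placeholder file name.

```python
# Hypothetical standalone usage of the space's model, assuming the repo layout above.
import cv2
import numpy as np
from model.model import ResHalf
from inference import Inferencer
from utils import util

invhalfer = Inferencer(checkpoint_path="checkpoints/model_best.pth.tar",
                       model=ResHalf(train=False), use_cuda=False)

img = cv2.imread("input.png", flags=cv2.IMREAD_COLOR)        # BGR, uint8
tensor = util.img2tensor(np.float32(img) / 127.5 - 1.)       # scale to [-1, 1]
halftone, restored = invhalfer(tensor, decoding_only=False)  # halftoning mode
out = (util.tensor2img(halftone / 2. + 0.5) * 255. + 0.5).astype(np.uint8)
cv2.imwrite("halftone.png", out)
```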
inference.py ADDED
@@ -0,0 +1,83 @@
+ import numpy as np
+ import cv2
+ import os, argparse, json
+ from os.path import join
+ from glob import glob
+
+ import torch
+ import torch.nn.functional as F
+
+ from model.model import ResHalf
+ from model.model import Quantize
+ from model.loss import l1_loss
+ from utils import util
+ from utils.dct import DCT_Lowfrequency
+ from utils.filters_tensor import bgr2gray
+
+
+ class Inferencer:
+     def __init__(self, checkpoint_path, model, use_cuda=True, multi_gpu=True):
+         # load weights onto the CPU when CUDA is disabled or unavailable
+         self.checkpoint = torch.load(checkpoint_path, map_location=None if use_cuda else 'cpu')
+         self.use_cuda = use_cuda
+         self.model = model.eval()
+         if multi_gpu:
+             self.model = torch.nn.DataParallel(self.model)
+         if self.use_cuda:
+             self.model = self.model.cuda()
+         self.model.load_state_dict(self.checkpoint['state_dict'])
+
+     def __call__(self, input_img, decoding_only=False):
+         with torch.no_grad():
+             scale = 8
+             _, _, H, W = input_img.shape
+             if H % scale != 0 or W % scale != 0:
+                 input_img = F.pad(input_img, [0, scale - W % scale, 0, scale - H % scale], mode='reflect')
+             if self.use_cuda:
+                 input_img = input_img.cuda()
+             if decoding_only:
+                 resColor = self.model(input_img, decoding_only)
+                 if H % scale != 0 or W % scale != 0:
+                     resColor = resColor[:, :, :H, :W]
+                 return resColor
+             else:
+                 resHalftone, resColor = self.model(input_img, decoding_only)
+                 resHalftone = Quantize.apply((resHalftone + 1.0) * 0.5) * 2.0 - 1.
+                 if H % scale != 0 or W % scale != 0:
+                     resHalftone = resHalftone[:, :, :H, :W]
+                     resColor = resColor[:, :, :H, :W]
+                 return resHalftone, resColor
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='invHalf')
+     parser.add_argument('--model', default=None, type=str,
+                         help='model weight file path')
+     parser.add_argument('--decoding', action='store_true', default=False, help='restoration from halftone input')
+     parser.add_argument('--data_dir', default=None, type=str,
+                         help='where to load input data (RGB images)')
+     parser.add_argument('--save_dir', default=None, type=str,
+                         help='where to save the result')
+     args = parser.parse_args()
+
+     invhalfer = Inferencer(
+         checkpoint_path=args.model,
+         model=ResHalf(train=False)
+     )
+     save_dir = os.path.join(args.save_dir)
+     util.ensure_dir(save_dir)
+     test_imgs = glob(join(args.data_dir, '*.*g'))
+     print('------loaded %d images.' % len(test_imgs))
+     for img in test_imgs:
+         print('[*] processing %s ...' % img)
+         if args.decoding:
+             input_img = cv2.imread(img, flags=cv2.IMREAD_GRAYSCALE) / 127.5 - 1.
+             c = invhalfer(util.img2tensor(input_img), decoding_only=True)
+             c = util.tensor2img(c / 2. + 0.5) * 255.
+             cv2.imwrite(join(save_dir, 'restored_' + img.split('/')[-1].split('.')[0] + '.png'), c)
+         else:
+             input_img = cv2.imread(img, flags=cv2.IMREAD_COLOR) / 127.5 - 1.
+             h, c = invhalfer(util.img2tensor(input_img), decoding_only=False)
+             h = util.tensor2img(h / 2. + 0.5) * 255.
+             c = util.tensor2img(c / 2. + 0.5) * 255.
+             cv2.imwrite(join(save_dir, 'halftone_' + img.split('/')[-1].split('.')[0] + '.png'), h)
+             cv2.imwrite(join(save_dir, 'restored_' + img.split('/')[-1].split('.')[0] + '.png'), c)
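A standalone illustration of the padding trick in `Inferencer.__call__`: the hourglass downsamples three times (factor 8), so H and W are reflect-padded up to the next multiple of 8 and the network output is cropped back to the original size. This sketch uses only torch:

```python
# Pad-to-multiple-of-8 and crop-back, as done in Inferencer.__call__.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 250, 333)          # dims not divisible by 8
scale = 8
_, _, H, W = x.shape
if H % scale != 0 or W % scale != 0:
    # pad list is [left, right, top, bottom] for a 4D tensor
    x = F.pad(x, [0, scale - W % scale, 0, scale - H % scale], mode='reflect')
print(x.shape)                            # torch.Size([1, 3, 256, 336])
y = x[:, :, :H, :W]                       # crop back after the network
print(y.shape)                            # torch.Size([1, 3, 250, 333])
```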
model/__init__.py ADDED
File without changes
model/base_module.py ADDED
@@ -0,0 +1,81 @@
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import torch
+ import numpy as np
+
+ def tensor2array(tensors):
+     arrays = tensors.detach().to("cpu").numpy()
+     return np.transpose(arrays, (0, 2, 3, 1))
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels):
+         super(ResidualBlock, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+         )
+
+     def forward(self, x):
+         residual = self.conv(x)
+         return x + residual
+
+
+ class DownsampleBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, withConvRelu=True):
+         super(DownsampleBlock, self).__init__()
+         if withConvRelu:
+             self.conv = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2),
+                 nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+                 nn.ReLU(inplace=True)
+             )
+         else:
+             self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2)
+
+     def forward(self, x):
+         return self.conv(x)
+
+
+ class ConvBlock(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum):
+         super(ConvBlock, self).__init__()
+         self.inConv = nn.Sequential(
+             nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+         layers = []
+         for _ in range(convNum - 1):
+             layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.inConv(x)
+         x = self.conv(x)
+         return x
+
+
+ class UpsampleBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(UpsampleBlock, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         x = F.interpolate(x, scale_factor=2, mode='nearest')
+         return self.conv(x)
+
+
+ class SkipConnection(nn.Module):
+     def __init__(self, channels):
+         super(SkipConnection, self).__init__()
+         self.conv = nn.Conv2d(2 * channels, channels, 1, bias=False)
+
+     def forward(self, x, y):
+         x = torch.cat((x, y), 1)
+         return self.conv(x)
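A quick shape check of these building blocks (a sketch, not part of the commit): `DownsampleBlock` halves the spatial resolution, `UpsampleBlock` doubles it, and `SkipConnection` fuses two equally shaped feature maps back down to `channels`.

```python
# Assumes the repo is on PYTHONPATH; shapes follow from stride-2 conv / 2x interpolation.
import torch
from model.base_module import DownsampleBlock, UpsampleBlock, SkipConnection

x = torch.randn(1, 64, 64, 64)
down = DownsampleBlock(64, 128)(x)     # -> (1, 128, 32, 32)
up = UpsampleBlock(128, 64)(down)      # -> (1, 64, 64, 64)
fused = SkipConnection(64)(up, x)      # concat to 128 ch, 1x1 conv back to 64
print(down.shape, up.shape, fused.shape)
```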
model/hourglass.py ADDED
@@ -0,0 +1,70 @@
+ import torch.nn as nn
+ from .base_module import ConvBlock, DownsampleBlock, ResidualBlock, SkipConnection, UpsampleBlock
+
+
+ class HourGlass(nn.Module):
+     def __init__(self, convNum=4, resNum=4, inChannel=6, outChannel=3):
+         super(HourGlass, self).__init__()
+         self.inConv = ConvBlock(inChannel, 64, convNum=2)
+         self.down1 = nn.Sequential(*[DownsampleBlock(64, 128, withConvRelu=False), ConvBlock(128, 128, convNum=2)])
+         self.down2 = nn.Sequential(
+             *[DownsampleBlock(128, 256, withConvRelu=False), ConvBlock(256, 256, convNum=convNum)])
+         self.down3 = nn.Sequential(
+             *[DownsampleBlock(256, 512, withConvRelu=False), ConvBlock(512, 512, convNum=convNum)])
+         self.residual = nn.Sequential(*[ResidualBlock(512) for _ in range(resNum)])
+         self.up3 = nn.Sequential(*[UpsampleBlock(512, 256), ConvBlock(256, 256, convNum=convNum)])
+         self.skip3 = SkipConnection(256)
+         self.up2 = nn.Sequential(*[UpsampleBlock(256, 128), ConvBlock(128, 128, convNum=2)])
+         self.skip2 = SkipConnection(128)
+         self.up1 = nn.Sequential(*[UpsampleBlock(128, 64), ConvBlock(64, 64, convNum=2)])
+         self.skip1 = SkipConnection(64)
+         self.outConv = nn.Sequential(
+             nn.Conv2d(64, 64, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(64, outChannel, kernel_size=1, padding=0)
+         )
+
+     def forward(self, x):
+         f1 = self.inConv(x)
+         f2 = self.down1(f1)
+         f3 = self.down2(f2)
+         f4 = self.down3(f3)
+         r4 = self.residual(f4)
+         r3 = self.skip3(self.up3(r4), f3)
+         r2 = self.skip2(self.up2(r3), f2)
+         r1 = self.skip1(self.up1(r2), f1)
+         y = self.outConv(r1)
+         return y
+
+
+ class ResidualHourGlass(nn.Module):
+     def __init__(self, resNum=4, inChannel=6, outChannel=3):
+         super(ResidualHourGlass, self).__init__()
+         self.inConv = nn.Conv2d(inChannel, 64, kernel_size=3, padding=1)
+         self.residualBefore = nn.Sequential(*[ResidualBlock(64) for _ in range(2)])
+         self.down1 = nn.Sequential(
+             *[DownsampleBlock(64, 128, withConvRelu=False), ConvBlock(128, 128, convNum=2)])
+         self.down2 = nn.Sequential(
+             *[DownsampleBlock(128, 256, withConvRelu=False), ConvBlock(256, 256, convNum=2)])
+         self.residual = nn.Sequential(*[ResidualBlock(256) for _ in range(resNum)])
+         self.up2 = nn.Sequential(*[UpsampleBlock(256, 128), ConvBlock(128, 128, convNum=2)])
+         self.skip2 = SkipConnection(128)
+         self.up1 = nn.Sequential(*[UpsampleBlock(128, 64), ConvBlock(64, 64, convNum=2)])
+         self.skip1 = SkipConnection(64)
+         self.residualAfter = nn.Sequential(*[ResidualBlock(64) for _ in range(2)])
+         self.outConv = nn.Sequential(
+             nn.Conv2d(64, outChannel, kernel_size=3, padding=1),
+             nn.Tanh()
+         )
+
+     def forward(self, x):
+         f1 = self.inConv(x)
+         f1 = self.residualBefore(f1)
+         f2 = self.down1(f1)
+         f3 = self.down2(f2)
+         r3 = self.residual(f3)
+         r2 = self.skip2(self.up2(r3), f2)
+         r1 = self.skip1(self.up1(r2), f1)
+         y = self.residualAfter(r1)
+         y = self.outConv(y)
+         return y
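A sanity-check sketch for `HourGlass`: with three 2x downsamplings inside, inputs whose sides are multiples of 8 come back at full resolution with `outChannel` channels. The 4-in/1-out and 1-in/3-out settings below match the encoder/decoder instantiated in model.py.

```python
# Assumes the repo is on PYTHONPATH.
import torch
from model.hourglass import HourGlass

encoder = HourGlass(inChannel=4, outChannel=1)   # color + noise channel -> halftone
decoder = HourGlass(inChannel=1, outChannel=3)   # halftone -> color

x = torch.randn(1, 4, 64, 64)
h = encoder(x)
print(h.shape)               # torch.Size([1, 1, 64, 64])
print(decoder(h).shape)      # torch.Size([1, 3, 64, 64])
```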
model/loss.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import torch.nn.functional as F
+ import torch
+ from utils.filters_tensor import GaussianSmoothing, bgr2gray
+ from utils import pytorch_ssim
+ from torch import nn
+ from .hourglass import HourGlass
+ from torchvision.models.vgg import vgg19
+
+
+ def l2_loss(y_input, y_target):
+     return F.mse_loss(y_input, y_target)
+
+
+ def l1_loss(y_input, y_target):
+     return F.l1_loss(y_input, y_target)
+
+
+ def gaussianL2(yInput, yTarget):
+     # data range [-1,1]
+     smoother = GaussianSmoothing(channels=1, kernel_size=11, sigma=2.0)
+     gaussianInput = smoother(yInput)
+     gaussianTarget = smoother(bgr2gray(yTarget))
+     return F.mse_loss(gaussianInput, gaussianTarget)
+
+
+ def binL1(yInput):
+     # data range is [-1,1]
+     return (yInput.abs() - 1.0).abs().mean()
+
+
+ def ssimLoss(yInput, yTarget):
+     # data range is [-1,1]
+     ssim = pytorch_ssim.ssim(yInput / 2. + 0.5, bgr2gray(yTarget / 2. + 0.5), window_size=11)
+     return 1. - ssim
+
+
+ class InverseHalf(nn.Module):
+     def __init__(self):
+         super(InverseHalf, self).__init__()
+         self.net = HourGlass(inChannel=1, outChannel=1)
+
+     def forward(self, x):
+         grayscale = self.net(x)
+         return grayscale
+
+
+ class FeatureLoss:
+     def __init__(self, pretrainedPath, requireGrad=False, multiGpu=True):
+         self.featureExactor = InverseHalf()
+         if multiGpu:
+             self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
+         print("-loading feature extractor: {} ...".format(pretrainedPath))
+         checkpoint = torch.load(pretrainedPath)
+         self.featureExactor.load_state_dict(checkpoint['state_dict'])
+         print("-feature network loaded")
+         if not requireGrad:
+             for param in self.featureExactor.parameters():
+                 param.requires_grad = False
+
+     def __call__(self, yInput, yTarget):
+         inFeature = self.featureExactor(yInput)
+         return l2_loss(inFeature, yTarget)
+
+
+ class Vgg19Loss:
+     def __init__(self, multiGpu=True):
+         os.environ['TORCH_HOME'] = '~/bigdata/0ProgramS/checkpoints'
+         # data in BGR format, [0,1] range
+         self.mean = [0.485, 0.456, 0.406]
+         self.mean.reverse()
+         self.std = [0.229, 0.224, 0.225]
+         self.std.reverse()
+         vgg = vgg19(pretrained=True)
+         # maxpool after conv4_4
+         self.featureExactor = nn.Sequential(*list(vgg.features)[:28]).eval()
+         for param in self.featureExactor.parameters():
+             param.requires_grad = False
+         if multiGpu:
+             self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
+         print('[*] Vgg19Loss init!')
+
+     def normalize(self, tensor):
+         tensor = tensor.clone()
+         mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
+         std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
+         tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
+         return tensor
+
+     def __call__(self, yInput, yTarget):
+         inFeature = self.featureExactor(self.normalize(yInput).flip(1))
+         targetFeature = self.featureExactor(self.normalize(yTarget).flip(1))
+         return l2_loss(inFeature, targetFeature)
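A tiny numerical sketch of `binL1`: it measures how far values in [-1, 1] are from the binary set {-1, 1}, so it is 0 for a perfect halftone and maximal for mid-gray.

```python
# binL1(y) = mean(||y| - 1|); values verified by hand below.
import torch
from model.loss import binL1

print(binL1(torch.tensor([-1., 1., 1., -1.])))   # tensor(0.)
print(binL1(torch.tensor([0., 0.5, -0.5])))      # tensor(0.6667): mean of [1.0, 0.5, 0.5]
```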
model/model.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.autograd import Function
+
+ from .hourglass import HourGlass
+ from utils.dct import DCT_Lowfrequency
+ from utils.filters_tensor import bgr2gray
+
+ from collections import OrderedDict
+ import numpy as np
+
+
+ class Quantize(Function):
+     @staticmethod
+     def forward(ctx, x):
+         ctx.save_for_backward(x)
+         y = x.round()
+         return y
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         inputX = ctx.saved_tensors
+         # straight-through estimator: rounding is treated as identity in the backward pass
+         return grad_output
+
+
+ class ResHalf(nn.Module):
+     def __init__(self, train=True, warm_stage=False):
+         super(ResHalf, self).__init__()
+         self.encoder = HourGlass(inChannel=4, outChannel=1, resNum=4, convNum=4)
+         self.decoder = HourGlass(inChannel=1, outChannel=3, resNum=4, convNum=4)
+         self.dcter = DCT_Lowfrequency(size=256, fLimit=50)
+         # quantize [-1,1] data to be {-1,1}
+         self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
+         self.isTrain = train
+         if warm_stage:
+             for name, param in self.decoder.named_parameters():
+                 param.requires_grad = False
+
+     def add_impluse_noise(self, input_halfs, p=0.0):
+         N, C, H, W = input_halfs.shape
+         SNR = 1 - p
+         np_input_halfs = input_halfs.detach().to("cpu").numpy()
+         np_input_halfs = np.transpose(np_input_halfs, (0, 2, 3, 1))
+         for i in range(N):
+             mask = np.random.choice((0, 1, 2), size=(H, W, 1), p=[SNR, (1 - SNR) / 2., (1 - SNR) / 2.])
+             np_input_halfs[i, mask == 1] = 1.0
+             np_input_halfs[i, mask == 2] = -1.0
+         return torch.from_numpy(np_input_halfs.transpose((0, 3, 1, 2))).to(input_halfs.device)
+
+     def forward(self, input_img, decoding_only=False):
+         if decoding_only:
+             halfResQ = self.quantizer(input_img)
+             restored = self.decoder(halfResQ)
+             return restored
+
+         noise = torch.randn_like(input_img) * 0.3
+         halfRes = self.encoder(torch.cat((input_img, noise[:, :1, :, :]), dim=1))
+         halfResQ = self.quantizer(halfRes)
+         restored = self.decoder(halfResQ)
+         if self.isTrain:
+             halfDCT = self.dcter(halfRes / 2. + 0.5)
+             refDCT = self.dcter(bgr2gray(input_img / 2. + 0.5))
+             return halfRes, halfDCT, refDCT, restored
+         else:
+             return halfRes, restored
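A sketch of the straight-through quantizer defined above: the forward pass rounds to {-1, 1}, while the backward pass copies the incoming gradient, so the encoder can be trained through the binarization.

```python
# Same lambda as in ResHalf.__init__; gradients verified below.
import torch
from model.model import Quantize

quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.

x = torch.tensor([-0.9, -0.2, 0.3, 0.8], requires_grad=True)
y = quantizer(x)
print(y)          # tensor([-1., -1.,  1.,  1.], grad_fn=...)
y.sum().backward()
print(x.grad)     # tensor([1., 1., 1., 1.]): gradient passes straight through
```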
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ addict
+ future
+ numpy
+ opencv-python
+ pandas
+ Pillow
+ pyyaml
+ requests
+ scikit-image
+ scikit-learn
+ scipy
+ torch>=1.8.0
+ torchvision
+ tensorboardx>=2.4
+ tqdm
+ yapf
+ lpips
scripts/invhalf_full.json ADDED
@@ -0,0 +1,42 @@
+ {
+     "name": "invhalf_full",
+     "initial_ckpt": "checkpoints/model_warm.pth.tar",
+     "model": "ResHalf",
+     "data_dir": "dataset/",
+     "save_dir": "./",
+     "trainer": {
+         "epochs": 1000,
+         "save_epochs": 5
+     },
+     "data_loader": {
+         "dataset": "HalftoneVOC2012.json",
+         "special_set": "special_color.json",
+         "batch_size": 1,
+         "shuffle": true,
+         "num_workers": 32
+     },
+     "quantizeLoss": "binL1",
+     "quantizeLossWeight": 0.1,
+     "toneLoss": "gaussianL2",
+     "toneLossWeight": 0.6,
+     "structureLoss": "ssimLoss",
+     "structureLossWeight": 0.0,
+     "restoreLoss": "l2_loss",
+     "restoreLossWeight": 1.0,
+     "blueNoiseLossWeight": 0.3,
+     "vggLossWeight": 0.0002,
+     "cuda": true,
+     "multi-gpus": true,
+     "optimizer_type": "Adam",
+     "optimizer": {
+         "lr": 0.0001,
+         "weight_decay": 0
+     },
+     "lr_sheduler": {
+         "factor": 0.5,
+         "patience": 3,
+         "threshold": 1e-05,
+         "cooldown": 0
+     },
+     "seed": 131
+ }
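For context, train.py below resolves the loss names in this config by name and scales each term by its matching `*Weight` key. A minimal sketch of that wiring (train.py uses `eval()`; the explicit lookup table `loss_fns` here is our own, equivalent for these names):

```python
# How the loss-related keys above are consumed, assuming the repo is on PYTHONPATH.
import json
from model.loss import binL1, gaussianL2, ssimLoss, l2_loss

config = json.load(open('scripts/invhalf_full.json'))
loss_fns = {'binL1': binL1, 'gaussianL2': gaussianL2, 'ssimLoss': ssimLoss, 'l2_loss': l2_loss}

quantize_loss = loss_fns[config['quantizeLoss']]   # binL1, weighted by quantizeLossWeight (0.1)
tone_loss = loss_fns[config['toneLoss']]           # gaussianL2, weighted by toneLossWeight (0.6)
restore_loss = loss_fns[config['restoreLoss']]     # l2_loss, weighted by restoreLossWeight (1.0)
print(config['quantizeLossWeight'], config['toneLossWeight'], config['restoreLossWeight'])
```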
scripts/invhalf_warm.json ADDED
@@ -0,0 +1,39 @@
+ {
+     "name": "invhalf_warmup",
+     "model": "ResHalf",
+     "data_dir": "dataset/",
+     "save_dir": "./",
+     "trainer": {
+         "epochs": 1000,
+         "save_epochs": 5
+     },
+     "data_loader": {
+         "dataset": "HalftoneVOC2012.json",
+         "special_set": "special_color.json",
+         "batch_size": 8,
+         "shuffle": true,
+         "num_workers": 32
+     },
+     "quantizeLoss": "binL1",
+     "quantizeLossWeight": 0.2,
+     "toneLoss": "gaussianL2",
+     "toneLossWeight": 0.6,
+     "structureLoss": "ssimLoss",
+     "structureLossWeight": 0.0,
+     "blueNoiseLossWeight": 0.3,
+     "featureLossWeight": 1.0,
+     "cuda": true,
+     "multi-gpus": true,
+     "optimizer_type": "Adam",
+     "optimizer": {
+         "lr": 0.0001,
+         "weight_decay": 0
+     },
+     "lr_sheduler": {
+         "factor": 0.5,
+         "patience": 3,
+         "threshold": 1e-05,
+         "cooldown": 0
+     },
+     "seed": 131
+ }
train.py ADDED
@@ -0,0 +1,272 @@
+ import os, glob, datetime, time
+ import argparse, json
+
+ import torch
+ import torch.optim as optim
+ from torch.autograd import Variable
+ import torchvision
+ from torch.utils.data import DataLoader
+ from torch.backends import cudnn
+
+ from model.base_module import tensor2array
+ from model.model import ResHalf, Quantize  # Quantize is required by self.quantizer below
+ from model.loss import *
+ from utils.dataset import HalftoneVOC2012 as Dataset
+ from utils.util import ensure_dir, save_list, save_images_from_batch
+
+
+ class Trainer():
+     def __init__(self, config, resume):
+         self.config = config
+         self.name = config['name']
+         self.resume_path = resume
+         self.n_epochs = config['trainer']['epochs']
+         self.with_cuda = config['cuda'] and torch.cuda.is_available()
+         self.seed = config['seed']
+         self.start_epoch = 0
+         self.save_freq = config['trainer']['save_epochs']
+         self.checkpoint_dir = os.path.join(config['save_dir'], self.name)
+         ensure_dir(self.checkpoint_dir)
+         json.dump(config, open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'),
+                   indent=4, sort_keys=False)
+         print("@Workspace: %s *************" % self.checkpoint_dir)
+         self.cache = os.path.join(self.checkpoint_dir, 'train_cache')
+         self.val_halftone = os.path.join(self.cache, 'halftone')
+         self.val_restored = os.path.join(self.cache, 'restored')
+         ensure_dir(self.val_halftone)
+         ensure_dir(self.val_restored)
+
+         ## model
+         self.model = eval(config['model'])()
+         if self.config['multi-gpus']:
+             self.model = torch.nn.DataParallel(self.model).cuda()
+         elif self.with_cuda:
+             self.model = self.model.cuda()
+
+         ## optimizer
+         self.optimizer = getattr(optim, config['optimizer_type'])(self.model.parameters(), **config['optimizer'])
+         self.lr_sheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **config['lr_sheduler'])
+
+         ## dataset loader
+         with open(os.path.join(config['data_dir'], config['data_loader']['dataset'])) as f:
+             dataset = json.load(f)
+         train_set = Dataset(dataset['train'])
+         self.train_data_loader = DataLoader(train_set, batch_size=config['data_loader']['batch_size'],
+                                             shuffle=config['data_loader']['shuffle'],
+                                             num_workers=config['data_loader']['num_workers'])
+         val_set = Dataset(dataset['val'])
+         self.valid_data_loader = DataLoader(val_set, batch_size=config['data_loader']['batch_size'],
+                                             shuffle=False,
+                                             num_workers=config['data_loader']['num_workers'])
+         # special dataloader: constant color images
+         with open(os.path.join(config['data_dir'], config['data_loader']['special_set'])) as f:
+             dataset = json.load(f)
+         specialSet = Dataset(dataset['train'])
+         self.specialDataloader = DataLoader(specialSet, batch_size=config['data_loader']['batch_size'],
+                                             shuffle=config['data_loader']['shuffle'],
+                                             num_workers=config['data_loader']['num_workers'])
+
+         ## loss function
+         self.quantizeLoss = eval(config['quantizeLoss'])
+         self.quantizeLossWeight = config['quantizeLossWeight']
+         self.toneLoss = eval(config['toneLoss'])
+         self.toneLossWeight = config['toneLossWeight']
+         self.structureLoss = eval(config['structureLoss'])
+         self.structureLossWeight = config['structureLossWeight']
+         self.restoreLoss = eval(config['restoreLoss'])
+         self.restoreLossWeight = config['restoreLossWeight']
+         # quantize [-1,1] data to be {-1,1}
+         self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
+         self.blueNoiseLossWeight = config['blueNoiseLossWeight']
+         self.vggloss = Vgg19Loss()
+         self.vggLossWeight = config['vggLossWeight']
+
+         # resume checkpoint or load warm-up checkpoint
+         checkpt_path = self.config['initial_ckpt']
+         if self.resume_path:
+             checkpt_path = self.resume_path
+         assert os.path.exists(checkpt_path), 'Invalid checkpoint Path: %s' % checkpt_path
+         self.load_checkpoint(checkpt_path)
+
+     def _train(self):
+         torch.manual_seed(self.config['seed'])
+         torch.cuda.manual_seed(self.config['seed'])
+         cudnn.benchmark = True
+
+         start_time = time.time()
+         self.monitor_best = 999.
+         for epoch in range(self.start_epoch, self.n_epochs + 1):
+             ep_st = time.time()
+             epoch_loss = self._train_epoch(epoch)
+             # perform lr_sheduler
+             self.lr_sheduler.step(epoch_loss['total_loss'])
+             epoch_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+             epoch_metric = self._valid_epoch(epoch)
+             print("[*] --- epoch: %d/%d | loss: %4.4f | metric: %4.4f | time-consumed: %4.2f ---" %
+                   (epoch + 1, self.n_epochs, epoch_loss['total_loss'], epoch_metric, (time.time() - ep_st)))
+
+             # save losses and learning rate
+             epoch_loss['metric'] = epoch_metric
+             epoch_loss['lr'] = epoch_lr
+             self.save_loss(epoch_loss, epoch)
+             if ((epoch + 1) % self.save_freq == 0 or epoch == (self.n_epochs - 1)):
+                 print('---------- saving model ...')
+                 self.save_checkpoint(epoch)
+             if self.monitor_best > epoch_metric:
+                 self.monitor_best = epoch_metric
+                 self.save_checkpoint(epoch, save_best=True)
+
+         print("Training finished! consumed %f sec" % (time.time() - start_time))
+
+     def _to_variable(self, data, target):
+         data, target = Variable(data), Variable(target)
+         if self.with_cuda:
+             data, target = data.cuda(), target.cuda()
+         return data, target
+
+     def _train_epoch(self, epoch):
+         self.model.train()
+         total_loss, quantize_loss, restore_loss = 0, 0, 0
+         tone_loss, structure_loss, blue_noise_loss = 0, 0, 0
+
+         specialIter = iter(self.specialDataloader)
+         time_stamp = time.time()
+         for batch_idx, (color, halftone) in enumerate(self.train_data_loader):
+             color, halftone = self._to_variable(color, halftone)
+             # special data
+             try:
+                 specialColor, specialHalftone = next(specialIter)
+             except StopIteration:
+                 # reinitialize data loader
+                 specialIter = iter(self.specialDataloader)
+                 specialColor, specialHalftone = next(specialIter)
+             specialColor, specialHalftone = self._to_variable(specialColor, specialHalftone)
+             self.optimizer.zero_grad()
+             output = self.model(color, halftone)
+             quantizeLoss = self.quantizeLoss(output[0])
+             toneLoss = self.toneLoss(output[0], color)
+             structureLoss = self.structureLoss(output[0], color)
+
+             # restore
+             restoredColor = output[-1]
+             restoreLoss = self.restoreLoss(restoredColor, color)
+             vggLoss = self.vggloss(restoredColor / 2. + 0.5, color / 2. + 0.5)
+
+             # special data
+             output = self.model(specialColor, specialHalftone)
+             toneLossSpecial = self.toneLoss(output[0], specialColor)
+             blueNoiseLoss = l1_loss(output[1], output[2])
+             quantizeLossSpecial = self.quantizeLoss(output[0])
+             loss = (self.toneLossWeight * toneLoss + self.blueNoiseLossWeight * toneLossSpecial) \
+                    + self.quantizeLossWeight * (0.5 * quantizeLoss + 0.5 * quantizeLossSpecial) \
+                    + self.structureLossWeight * structureLoss \
+                    + self.blueNoiseLossWeight * blueNoiseLoss \
+                    + self.vggLossWeight * vggLoss \
+                    + self.restoreLossWeight * restoreLoss
+
+             loss.backward()
+             # apply grad clip to make training robust
+             # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.0001)
+             self.optimizer.step()
+
+             total_loss += loss.item()
+             quantize_loss += quantizeLoss.item()
+             restore_loss += (self.restoreLossWeight * restoreLoss + self.vggLossWeight * vggLoss).item()
+             tone_loss += toneLoss.item()
+             structure_loss += structureLoss.item()
+             blue_noise_loss += blueNoiseLoss.item()
+             if batch_idx % 100 == 0:
+                 tm = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+                 print("%s >> [%d/%d] iter:%d loss:%4.4f " % (tm, epoch + 1, self.n_epochs, batch_idx + 1, loss.item()))
+
+         epoch_loss = dict()
+         epoch_loss['total_loss'] = total_loss / (batch_idx + 1)
+         epoch_loss['quantize_loss'] = quantize_loss / (batch_idx + 1)
+         epoch_loss['tone_loss'] = tone_loss / (batch_idx + 1)
+         epoch_loss['structure_loss'] = structure_loss / (batch_idx + 1)
+         epoch_loss['bluenoise_loss'] = blue_noise_loss / (batch_idx + 1)
+         epoch_loss['restore_loss'] = restore_loss / (batch_idx + 1)
+
+         return epoch_loss
+
+     def _valid_epoch(self, epoch):
+         self.model.eval()
+         total_loss = 0
+         with torch.no_grad():
+             for batch_idx, (color, halftone) in enumerate(self.valid_data_loader):
+                 color, halftone = self._to_variable(color, halftone)
+                 output = self.model(color, halftone)
+                 quantizeLoss = self.quantizeLoss(output[0])
+                 toneLoss = self.toneLoss(output[0], color)
+                 structureLoss = self.structureLoss(output[0], color)
+                 # restore
+                 restoredColor = output[-1]
+                 restoreLoss = self.restoreLoss(restoredColor, color)
+                 vggLoss = self.vggloss(restoredColor / 2. + 0.5, color / 2. + 0.5)
+
+                 loss = self.toneLossWeight * toneLoss \
+                        + self.quantizeLossWeight * quantizeLoss \
+                        + self.structureLossWeight * structureLoss \
+                        + self.vggLossWeight * vggLoss \
+                        + self.restoreLossWeight * restoreLoss
+
+                 total_loss += loss.item()
+                 #! save intermediate images
+                 gray_imgs = tensor2array(output[0])
+                 color_imgs = tensor2array(output[-1])
+                 save_images_from_batch(gray_imgs, self.val_halftone, None, batch_idx)
+                 save_images_from_batch(color_imgs, self.val_restored, None, batch_idx)
+
+         return total_loss
+
+     def save_loss(self, epoch_loss, epoch):
+         if epoch == 0:
+             for key in epoch_loss:
+                 save_list(os.path.join(self.cache, key), [epoch_loss[key]], append_mode=False)
+         else:
+             for key in epoch_loss:
+                 save_list(os.path.join(self.cache, key), [epoch_loss[key]], append_mode=True)
+
+     def load_checkpoint(self, checkpt_path):
+         print("-loading checkpoint from: {} ...".format(checkpt_path))
+         if self.resume_path:
+             checkpoint = torch.load(checkpt_path)
+             self.start_epoch = checkpoint['epoch'] + 1
+             self.monitor_best = checkpoint['monitor_best']
+             self.model.load_state_dict(checkpoint['state_dict'])
+             self.optimizer.load_state_dict(checkpoint['optimizer'])
+         else:
+             checkpoint = torch.load(checkpt_path)
+             self.model.load_state_dict(checkpoint['state_dict'], strict=False)
+         print("-pretrained checkpoint loaded.")
+
+     def save_checkpoint(self, epoch, save_best=False):
+         state = {
+             'epoch': epoch,
+             'state_dict': self.model.state_dict(),
+             'optimizer': self.optimizer.state_dict(),
+             'monitor_best': self.monitor_best
+         }
+         save_path = os.path.join(self.checkpoint_dir, 'model_last.pth.tar')
+         if save_best:
+             save_path = os.path.join(self.checkpoint_dir, 'model_best.pth.tar')
+         torch.save(state, save_path)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='InvHalf')
+     parser.add_argument('-c', '--config', default=None, type=str,
+                         help='config file path (default: None)')
+     parser.add_argument('-r', '--resume', default=None, type=str,
+                         help='path to latest checkpoint (default: None)')
+     args = parser.parse_args()
+     config_dict = json.load(open(args.config))
+     node = Trainer(config_dict, resume=args.resume)
+     node._train()
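A generic sketch of the "special dataloader" cycling pattern used in `_train_epoch` above: a second, smaller loader is drained alongside the main one and simply restarted on `StopIteration`, so every main batch is paired with a constant-color batch.

```python
# Standalone torch-only illustration of the iterator-restart pattern.
import torch
from torch.utils.data import DataLoader, TensorDataset

main_loader = DataLoader(TensorDataset(torch.arange(10.)), batch_size=2)
special_loader = DataLoader(TensorDataset(torch.arange(4.)), batch_size=2)

special_iter = iter(special_loader)
for (batch,) in main_loader:
    try:
        (special_batch,) = next(special_iter)
    except StopIteration:
        special_iter = iter(special_loader)      # restart the exhausted loader
        (special_batch,) = next(special_iter)
    print(batch.tolist(), special_batch.tolist())
```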
train_warm.py ADDED
@@ -0,0 +1,256 @@
+ import os, glob, datetime, time
+ import argparse, json
+
+ import torch
+ import torch.optim as optim
+ from torch.autograd import Variable
+ import torchvision
+ from torch.utils.data import DataLoader
+ from torch.backends import cudnn
+
+ from model.base_module import tensor2array
+ from model.model import ResHalf, Quantize  # Quantize is required by self.quantizer below
+ from model.loss import *
+ from utils.dataset import HalftoneVOC2012 as Dataset
+ from utils.util import ensure_dir, save_list, save_images_from_batch
+
+
+ class Trainer():
+     def __init__(self, config, resume):
+         self.config = config
+         self.name = config['name']
+         self.resume_path = resume
+         self.n_epochs = config['trainer']['epochs']
+         self.with_cuda = config['cuda'] and torch.cuda.is_available()
+         self.seed = config['seed']
+         self.start_epoch = 0
+         self.save_freq = config['trainer']['save_epochs']
+         self.checkpoint_dir = os.path.join(config['save_dir'], self.name)
+         ensure_dir(self.checkpoint_dir)
+         json.dump(config, open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'),
+                   indent=4, sort_keys=False)
+         print("@Workspace: %s *************" % self.checkpoint_dir)
+         self.cache = os.path.join(self.checkpoint_dir, 'train_cache')
+         self.val_halftone = os.path.join(self.cache, 'halftone')
+         self.val_restored = os.path.join(self.cache, 'restored')
+         ensure_dir(self.val_halftone)
+         ensure_dir(self.val_restored)
+
+         ## model
+         self.model = eval(config['model'])(train=True, warm_stage=True)
+         if self.config['multi-gpus']:
+             self.model = torch.nn.DataParallel(self.model).cuda()
+         elif self.with_cuda:
+             self.model = self.model.cuda()
+
+         ## optimizer
+         self.optimizer = getattr(optim, config['optimizer_type'])(self.model.parameters(), **config['optimizer'])
+         self.lr_sheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **config['lr_sheduler'])
+
+         ## dataset loader
+         with open(os.path.join(config['data_dir'], config['data_loader']['dataset'])) as f:
+             dataset = json.load(f)
+         train_set = Dataset(dataset['train'])
+         self.train_data_loader = DataLoader(train_set, batch_size=config['data_loader']['batch_size'],
+                                             shuffle=config['data_loader']['shuffle'],
+                                             num_workers=config['data_loader']['num_workers'])
+         val_set = Dataset(dataset['val'])
+         self.valid_data_loader = DataLoader(val_set, batch_size=config['data_loader']['batch_size'],
+                                             shuffle=False,
+                                             num_workers=config['data_loader']['num_workers'])
+         # special dataloader: constant color images
+         with open(os.path.join(config['data_dir'], config['data_loader']['special_set'])) as f:
+             dataset = json.load(f)
+         specialSet = Dataset(dataset['train'])
+         self.specialDataloader = DataLoader(specialSet, batch_size=config['data_loader']['batch_size'],
+                                             shuffle=config['data_loader']['shuffle'],
+                                             num_workers=config['data_loader']['num_workers'])
+
+         ## loss function
+         self.quantizeLoss = eval(config['quantizeLoss'])
+         self.quantizeLossWeight = config['quantizeLossWeight']
+         self.toneLoss = eval(config['toneLoss'])
+         self.toneLossWeight = config['toneLossWeight']
+         self.structureLoss = eval(config['structureLoss'])
+         self.structureLossWeight = config['structureLossWeight']
+         # quantize [-1,1] data to be {-1,1}
+         self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
+         self.blueNoiseLossWeight = config['blueNoiseLossWeight']
+         self.featureLoss = FeatureLoss(
+             requireGrad=False, pretrainedPath='checkpoints/invhalftone_checkpt/model_best.pth.tar')
+         self.featureLossWeight = config['featureLossWeight']
+
+         # resume checkpoint
+         if self.resume_path:
+             assert os.path.exists(self.resume_path), 'Invalid checkpoint Path: %s' % self.resume_path
+             self.load_checkpoint(self.resume_path)
+
+     def _train(self):
+         torch.manual_seed(self.config['seed'])
+         torch.cuda.manual_seed(self.config['seed'])
+         cudnn.benchmark = True
+
+         start_time = time.time()
+         self.monitor_best = 999.
+         for epoch in range(self.start_epoch, self.n_epochs + 1):
+             ep_st = time.time()
+             epoch_loss = self._train_epoch(epoch)
+             # perform lr_sheduler
+             self.lr_sheduler.step(epoch_loss['total_loss'])
+             epoch_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+             epoch_metric = self._valid_epoch(epoch)
+             print("[*] --- epoch: %d/%d | loss: %4.4f | metric: %4.4f | time-consumed: %4.2f ---" %
+                   (epoch + 1, self.n_epochs, epoch_loss['total_loss'], epoch_metric, (time.time() - ep_st)))
+
+             # save losses and learning rate
+             epoch_loss['metric'] = epoch_metric
+             epoch_loss['lr'] = epoch_lr
+             self.save_loss(epoch_loss, epoch)
+             if ((epoch + 1) % self.save_freq == 0 or epoch == (self.n_epochs - 1)):
+                 print('---------- saving model ...')
+                 self.save_checkpoint(epoch)
+             if self.monitor_best > epoch_metric:
+                 self.monitor_best = epoch_metric
+                 self.save_checkpoint(epoch, save_best=True)
+
+         print("Training finished! consumed %f sec" % (time.time() - start_time))
+
+     def _to_variable(self, data, target):
+         data, target = Variable(data), Variable(target)
+         if self.with_cuda:
+             data, target = data.cuda(), target.cuda()
+         return data, target
+
+     def _train_epoch(self, epoch):
+         self.model.train()
+         total_loss, quantize_loss, feature_loss = 0, 0, 0
+         tone_loss, structure_loss, blue_noise_loss = 0, 0, 0
+
+         specialIter = iter(self.specialDataloader)
+         time_stamp = time.time()
+         for batch_idx, (color, halftone) in enumerate(self.train_data_loader):
+             color, halftone = self._to_variable(color, halftone)
+             # special data
+             try:
+                 specialColor, specialHalftone = next(specialIter)
+             except StopIteration:
+                 # reinitialize data loader
+                 specialIter = iter(self.specialDataloader)
+                 specialColor, specialHalftone = next(specialIter)
+             specialColor, specialHalftone = self._to_variable(specialColor, specialHalftone)
+             self.optimizer.zero_grad()
+             output = self.model(color, halftone)
+             quantizeLoss = self.quantizeLoss(output[0])
+             toneLoss = self.toneLoss(output[0], color)
+             structureLoss = self.structureLoss(output[0], color)
+             featureLoss = self.featureLoss(output[0], bgr2gray(color))
+
+             # special data
+             output = self.model(specialColor, specialHalftone)
+             toneLossSpecial = self.toneLoss(output[0], specialColor)
+             blueNoiseLoss = l1_loss(output[1], output[2])
+             quantizeLossSpecial = self.quantizeLoss(output[0])
+             loss = (self.toneLossWeight * toneLoss + self.blueNoiseLossWeight * toneLossSpecial) \
+                    + self.quantizeLossWeight * (0.5 * quantizeLoss + 0.5 * quantizeLossSpecial) \
+                    + self.structureLossWeight * structureLoss \
+                    + self.blueNoiseLossWeight * blueNoiseLoss \
+                    + self.featureLossWeight * featureLoss
+
+             loss.backward()
+             # apply grad clip to make training robust
+             # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.0001)
+             self.optimizer.step()
+
+             total_loss += loss.item()
+             quantize_loss += quantizeLoss.item()
+             feature_loss += featureLoss.item()
+             tone_loss += toneLoss.item()
+             structure_loss += structureLoss.item()
+             blue_noise_loss += blueNoiseLoss.item()
+             if batch_idx % 100 == 0:
+                 tm = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+                 print("%s >> [%d/%d] iter:%d loss:%4.4f " % (tm, epoch + 1, self.n_epochs, batch_idx + 1, loss.item()))
+
+         epoch_loss = dict()
+         epoch_loss['total_loss'] = total_loss / (batch_idx + 1)
+         epoch_loss['quantize_loss'] = quantize_loss / (batch_idx + 1)
+         epoch_loss['tone_loss'] = tone_loss / (batch_idx + 1)
+         epoch_loss['structure_loss'] = structure_loss / (batch_idx + 1)
+         epoch_loss['bluenoise_loss'] = blue_noise_loss / (batch_idx + 1)
+         epoch_loss['feature_loss'] = feature_loss / (batch_idx + 1)
+
+         return epoch_loss
+
+     def _valid_epoch(self, epoch):
+         self.model.eval()
+         total_loss = 0
+         with torch.no_grad():
+             for batch_idx, (color, halftone) in enumerate(self.valid_data_loader):
+                 color, halftone = self._to_variable(color, halftone)
+                 output = self.model(color, halftone)
+                 quantizeLoss = self.quantizeLoss(output[0])
+                 toneLoss = self.toneLoss(output[0], color)
+                 structureLoss = self.structureLoss(output[0], color)
+                 featureLoss = self.featureLoss(output[0], bgr2gray(color))
+
+                 loss = self.toneLossWeight * toneLoss \
+                        + self.quantizeLossWeight * quantizeLoss \
+                        + self.structureLossWeight * structureLoss \
+                        + self.featureLossWeight * featureLoss
+
+                 total_loss += loss.item()
+                 #! save intermediate images
+                 gray_imgs = tensor2array(output[0])
+                 color_imgs = tensor2array(output[-1])
+                 save_images_from_batch(gray_imgs, self.val_halftone, None, batch_idx)
+                 save_images_from_batch(color_imgs, self.val_restored, None, batch_idx)
+
+         return total_loss
+
+     def save_loss(self, epoch_loss, epoch):
+         if epoch == 0:
+             for key in epoch_loss:
+                 save_list(os.path.join(self.cache, key), [epoch_loss[key]], append_mode=False)
+         else:
+             for key in epoch_loss:
+                 save_list(os.path.join(self.cache, key), [epoch_loss[key]], append_mode=True)
+
+     def load_checkpoint(self, checkpt_path):
+         print("-loading checkpoint from: {} ...".format(checkpt_path))
+         checkpoint = torch.load(checkpt_path)
+         self.start_epoch = checkpoint['epoch'] + 1
+         self.monitor_best = checkpoint['monitor_best']
+         self.model.load_state_dict(checkpoint['state_dict'])
+         self.optimizer.load_state_dict(checkpoint['optimizer'])
+         print("-pretrained checkpoint loaded.")
+
+     def save_checkpoint(self, epoch, save_best=False):
+         state = {
+             'epoch': epoch,
+             'state_dict': self.model.state_dict(),
+             'optimizer': self.optimizer.state_dict(),
+             'monitor_best': self.monitor_best
+         }
+         save_path = os.path.join(self.checkpoint_dir, 'model_last.pth.tar')
+         if save_best:
+             save_path = os.path.join(self.checkpoint_dir, 'model_best.pth.tar')
+         torch.save(state, save_path)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='InvHalf')
+     parser.add_argument('-c', '--config', default=None, type=str,
+                         help='config file path (default: None)')
+     parser.add_argument('-r', '--resume', default=None, type=str,
+                         help='path to latest checkpoint (default: None)')
+     args = parser.parse_args()
+     config_dict = json.load(open(args.config))
+     node = Trainer(config_dict, resume=args.resume)
+     node._train()
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .util import *
utils/_dct.py ADDED
@@ -0,0 +1,268 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def dct1(x):
8
+ """
9
+ Discrete Cosine Transform, Type I
10
+
11
+ :param x: the input signal
12
+ :return: the DCT-I of the signal over the last dimension
13
+ """
14
+ x_shape = x.shape
15
+ x = x.view(-1, x_shape[-1])
16
+
17
+ #return torch.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), 1)[:, :, 0].view(*x_shape)
18
+ return torch.fft.fft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), 1)[:, :, 0].view(*x_shape)
19
+
20
+
21
+ def idct1(X):
22
+ """
23
+ The inverse of DCT-I, which is just a scaled DCT-I
24
+
25
+ Our definition if idct1 is such that idct1(dct1(x)) == x
26
+
27
+ :param X: the input signal
28
+ :return: the inverse DCT-I of the signal over the last dimension
29
+ """
30
+ n = X.shape[-1]
31
+ return dct1(X) / (2 * (n - 1))
32
+
33
+
34
+ def dct(x, norm=None):
35
+ """
36
+ Discrete Cosine Transform, Type II (a.k.a. the DCT)
37
+
38
+ For the meaning of the parameter `norm`, see:
39
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
40
+
41
+ :param x: the input signal
42
+ :param norm: the normalization, None or 'ortho'
43
+ :return: the DCT-II of the signal over the last dimension
44
+ """
45
+ x_shape = x.shape
46
+ N = x_shape[-1]
47
+ x = x.contiguous().view(-1, N)
48
+
49
+ v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
50
+
51
+ #Vc = torch.fft.rfft(v, 1, onesided=False)
52
+ Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
53
+
54
+ k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
55
+ W_r = torch.cos(k)
56
+ W_i = torch.sin(k)
57
+
58
+ V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
59
+
60
+ if norm == 'ortho':
61
+ V[:, 0] /= np.sqrt(N) * 2
62
+ V[:, 1:] /= np.sqrt(N / 2) * 2
63
+
64
+ V = 2 * V.view(*x_shape)
65
+
66
+ return V
67
+
68
+
69
+ def idct(X, norm=None):
70
+ """
71
+ The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
72
+
73
+ Our definition of idct is that idct(dct(x)) == x
74
+
75
+ For the meaning of the parameter `norm`, see:
76
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
77
+
78
+ :param X: the input signal
79
+ :param norm: the normalization, None or 'ortho'
80
+ :return: the inverse DCT-II of the signal over the last dimension
81
+ """
82
+
83
+ x_shape = X.shape
84
+ N = x_shape[-1]
85
+
86
+ X_v = X.contiguous().view(-1, x_shape[-1]) / 2
87
+
88
+ if norm == 'ortho':
89
+ X_v[:, 0] *= np.sqrt(N) * 2
90
+ X_v[:, 1:] *= np.sqrt(N / 2) * 2
91
+
92
+ k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
93
+ W_r = torch.cos(k)
94
+ W_i = torch.sin(k)
95
+
96
+ V_t_r = X_v
97
+ V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
98
+
99
+ V_r = V_t_r * W_r - V_t_i * W_i
100
+ V_i = V_t_r * W_i + V_t_i * W_r
101
+
102
+ V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
103
+
104
+ #v = torch.irfft(V, 1, onesided=False)
105
+ v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
106
+ x = v.new_zeros(v.shape)
107
+ x[:, ::2] += v[:, :N - (N // 2)]
108
+ x[:, 1::2] += v.flip([1])[:, :N // 2]
109
+
110
+ return x.view(*x_shape)
111
+
112
+
113
+ def dct_2d(x, norm=None):
114
+ """
115
+ 2-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)
116
+
117
+ For the meaning of the parameter `norm`, see:
118
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
119
+
120
+ :param x: the input signal
121
+ :param norm: the normalization, None or 'ortho'
122
+ :return: the DCT-II of the signal over the last 2 dimensions
123
+ """
124
+ X1 = dct(x, norm=norm)
125
+ X2 = dct(X1.transpose(-1, -2), norm=norm)
126
+ return X2.transpose(-1, -2)
127
+
128
+
129
+ def idct_2d(X, norm=None):
130
+ """
131
+ The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
132
+
133
+ Our definition of idct is that idct_2d(dct_2d(x)) == x
134
+
135
+ For the meaning of the parameter `norm`, see:
136
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
137
+
138
+ :param X: the input signal
139
+ :param norm: the normalization, None or 'ortho'
140
+ :return: the DCT-II of the signal over the last 2 dimensions
141
+ """
142
+ x1 = idct(X, norm=norm)
143
+ x2 = idct(x1.transpose(-1, -2), norm=norm)
144
+ return x2.transpose(-1, -2)
145
+
146
+
147
+ def dct_3d(x, norm=None):
148
+ """
149
+ 3-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)
150
+
151
+ For the meaning of the parameter `norm`, see:
152
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
153
+
154
+ :param x: the input signal
155
+ :param norm: the normalization, None or 'ortho'
156
+ :return: the DCT-II of the signal over the last 3 dimensions
157
+ """
158
+ X1 = dct(x, norm=norm)
159
+ X2 = dct(X1.transpose(-1, -2), norm=norm)
160
+ X3 = dct(X2.transpose(-1, -3), norm=norm)
161
+ return X3.transpose(-1, -3).transpose(-1, -2)
162
+
163
+
164
+ def idct_3d(X, norm=None):
165
+ """
166
+ The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
167
+
168
+ Our definition of idct is that idct_3d(dct_3d(x)) == x
169
+
170
+ For the meaning of the parameter `norm`, see:
171
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
172
+
173
+ :param X: the input signal
174
+ :param norm: the normalization, None or 'ortho'
175
+ :return: the DCT-II of the signal over the last 3 dimensions
176
+ """
177
+ x1 = idct(X, norm=norm)
178
+ x2 = idct(x1.transpose(-1, -2), norm=norm)
179
+ x3 = idct(x2.transpose(-1, -3), norm=norm)
180
+ return x3.transpose(-1, -3).transpose(-1, -2)
181
+
182
+
183
+ # class LinearDCT(nn.Linear):
184
+ # """Implement any DCT as a linear layer; in practice this executes around
185
+ # 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
186
+ # increase memory usage.
187
+ # :param in_features: size of expected input
188
+ # :param type: which dct function in this file to use"""
189
+ #
190
+ # def __init__(self, in_features, type, norm=None, bias=False):
191
+ # self.type = type
192
+ # self.N = in_features
193
+ # self.norm = norm
194
+ # super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
195
+ #
196
+ # def reset_parameters(self):
197
+ # # initialise using dct function
198
+ # I = torch.eye(self.N)
199
+ # if self.type == 'dct1':
200
+ # self.weight.data = dct1(I).data.t()
201
+ # elif self.type == 'idct1':
202
+ # self.weight.data = idct1(I).data.t()
203
+ # elif self.type == 'dct':
204
+ # self.weight.data = dct(I, norm=self.norm).data.t()
205
+ # elif self.type == 'idct':
206
+ # self.weight.data = idct(I, norm=self.norm).data.t()
207
+ # self.weight.require_grad = False # don't learn this!
208
+
209
+ class LinearDCT(nn.Module):
210
+ """Implement any DCT as a linear layer; in practice this executes around
211
+ 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
212
+ increase memory usage.
213
+ :param in_features: size of expected input
214
+ :param type: which dct function in this file to use"""
215
+
216
+ def __init__(self, in_features, type, norm=None):
217
+ super(LinearDCT, self).__init__()
218
+ self.type = type
219
+ self.N = in_features
220
+ self.norm = norm
221
+ I = torch.eye(self.N)
222
+ if self.type == 'dct1':
223
+ self.weight = dct1(I).data.t()
224
+ elif self.type == 'idct1':
225
+ self.weight = idct1(I).data.t()
226
+ elif self.type == 'dct':
227
+ self.weight = dct(I, norm=self.norm).data.t()
228
+ elif self.type == 'idct':
229
+ self.weight = idct(I, norm=self.norm).data.t()
230
+ # self.register_buffer('weight', kernel)
231
+ # self.weight = kernel
232
+
233
+ def forward(self, x):
234
+ return F.linear(x, weight=self.weight.cuda(x.get_device()))
235
+
+ def apply_linear_2d(x, linear_layer):
+     """Can be used with a LinearDCT layer to do a 2D DCT.
+     :param x: the input signal
+     :param linear_layer: any PyTorch Linear layer
+     :return: result of linear layer applied to last 2 dimensions
+     """
+     X1 = linear_layer(x)
+     X2 = linear_layer(X1.transpose(-1, -2))
+     return X2.transpose(-1, -2)
+
+
+ def apply_linear_3d(x, linear_layer):
+     """Can be used with a LinearDCT layer to do a 3D DCT.
+     :param x: the input signal
+     :param linear_layer: any PyTorch Linear layer
+     :return: result of linear layer applied to last 3 dimensions
+     """
+     X1 = linear_layer(x)
+     X2 = linear_layer(X1.transpose(-1, -2))
+     X3 = linear_layer(X2.transpose(-1, -3))
+     return X3.transpose(-1, -3).transpose(-1, -2)
+
+
+ if __name__ == '__main__':
+     x = torch.Tensor(1000, 4096)
+     x.normal_(0, 1)
+     linear_dct = LinearDCT(4096, 'dct')
+     error = torch.abs(dct(x) - linear_dct(x))
+     assert error.max() < 1e-3, (error, error.max())
+     linear_idct = LinearDCT(4096, 'idct')
+     error = torch.abs(idct(x) - linear_idct(x))
+     assert error.max() < 1e-3, (error, error.max())
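
Since `apply_linear_2d` pairs naturally with `LinearDCT`, a round-trip sanity check of the 2D path is also possible (a minimal sketch, not part of the committed file; it only uses the names defined above). With `norm='ortho'` the 'dct' and 'idct' layers are exact inverses, so the reconstruction should match the input up to floating-point error:

    import torch

    x = torch.randn(8, 64, 64)                       # a batch of 64x64 patches
    dct_layer = LinearDCT(64, 'dct', norm='ortho')   # forward DCT as a matrix multiply
    idct_layer = LinearDCT(64, 'idct', norm='ortho')
    coeffs = apply_linear_2d(x, dct_layer)           # 2D DCT over the last two dims
    recon = apply_linear_2d(coeffs, idct_layer)      # 2D inverse DCT
    assert torch.allclose(x, recon, atol=1e-4)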
utils/dataset.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+ import torch.utils.data as data
+ import cv2
+ import numpy as np
+ from os.path import join
+
+
+ class HalftoneVOC2012(data.Dataset):
+     # data range is [-1,1]; color images are in BGR format
+     def __init__(self, data_list):
+         super(HalftoneVOC2012, self).__init__()
+         self.inputs = [join('Data', x) for x in data_list['inputs']]
+         self.labels = [join('Data', x) for x in data_list['labels']]
+
+     @staticmethod
+     def load_input(name):
+         img = cv2.imread(name, flags=cv2.IMREAD_COLOR)
+         # HWC -> CHW
+         img = img.transpose((2, 0, 1))
+         # to Tensor, scaled from [0,255] to [-1,1]
+         img = torch.from_numpy(img.astype(np.float32) / 127.5 - 1.0)
+         return img
+
+     @staticmethod
+     def load_label(name):
+         img = cv2.imread(name, flags=cv2.IMREAD_GRAYSCALE)
+         # add the channel dimension
+         img = img[np.newaxis, :, :]
+         # to Tensor, scaled from [0,255] to [-1,1]
+         img = torch.from_numpy(img.astype(np.float32) / 127.5 - 1.0)
+         return img
+
+     def __getitem__(self, index):
+         input_data = self.load_input(self.inputs[index])
+         label_data = self.load_label(self.labels[index])
+         return input_data, label_data
+
+     def __len__(self):
+         return len(self.inputs)
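
For context, a minimal sketch of how this dataset would typically be consumed (the `train.json` file name and batch size are illustrative assumptions, not taken from the commit):

    import json
    from torch.utils.data import DataLoader

    # hypothetical listing of relative 'inputs'/'labels' paths under Data/
    with open('train.json') as f:
        data_list = json.load(f)

    dataset = HalftoneVOC2012(data_list)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
    for color, halftone in loader:
        # color: (B, 3, H, W) in [-1, 1], BGR; halftone: (B, 1, H, W) in [-1, 1]
        pass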
utils/dct.py ADDED
@@ -0,0 +1,29 @@
+ #!/usr/bin/env python
+ """
+ -------------------------------------------------
+    File Name:   dct
+    Author:      wenbo
+    Date:        12/4/2019
+    Description: low-frequency DCT feature extraction
+ -------------------------------------------------
+ """
+ __author__ = 'wenbo'
+
+ from torch import nn
+ from ._dct import LinearDCT, apply_linear_2d
+
+
+ class DCT_Lowfrequency(nn.Module):
+     def __init__(self, size=256, fLimit=50):
+         super(DCT_Lowfrequency, self).__init__()
+         self.fLimit = fLimit
+         self.dct = LinearDCT(size, type='dct', norm='ortho')
+         self.dctTransformer = lambda x: apply_linear_2d(x, self.dct)
+
+     def forward(self, x):
+         # 2D DCT, then keep only the top-left (low-frequency) fLimit x fLimit block
+         x = self.dctTransformer(x)
+         x = x[:, :, :self.fLimit, :self.fLimit]
+         return x
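
Keeping only the low-frequency corner of the spectrum lets a distance compare overall tone while ignoring high-frequency structure such as halftone dots. A minimal usage sketch (the shapes and the MSE pairing are illustrative assumptions, not the repository's training code):

    import torch
    import torch.nn.functional as F

    dct_low = DCT_Lowfrequency(size=256, fLimit=50)
    a = torch.randn(1, 1, 256, 256)                 # NCHW, 256x256 inputs assumed
    b = torch.randn(1, 1, 256, 256)
    loss = F.mse_loss(dct_low(a), dct_low(b))       # compares (1,1,50,50) coefficient blocks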
utils/filters_tensor.py ADDED
@@ -0,0 +1,81 @@
+ import math
+ import numbers
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class GaussianSmoothing(nn.Module):
+     """
+     Apply gaussian smoothing on a
+     1d, 2d or 3d tensor. Filtering is performed separately for each channel
+     in the input using a depthwise convolution.
+     Arguments:
+         channels (int, sequence): Number of channels of the input tensors. Output will
+             have this number of channels as well.
+         kernel_size (int, sequence): Size of the gaussian kernel.
+         sigma (float, sequence): Standard deviation of the gaussian kernel.
+         dim (int, optional): The number of dimensions of the data.
+             Default value is 2 (spatial).
+     """
+
+     def __init__(self, channels, kernel_size, sigma, dim=2, cuda=True):
+         super(GaussianSmoothing, self).__init__()
+         # `cuda` is kept for API compatibility; the kernel is moved to the
+         # input's device in forward() instead.
+         if isinstance(kernel_size, numbers.Number):
+             kernel_size = [kernel_size] * dim
+         if isinstance(sigma, numbers.Number):
+             sigma = [sigma] * dim
+
+         # The gaussian kernel is the product of the
+         # gaussian function of each dimension.
+         kernel = 1
+         meshgrids = torch.meshgrid(
+             [
+                 torch.arange(size, dtype=torch.float32)
+                 for size in kernel_size
+             ]
+         )
+         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+             mean = (size - 1) / 2
+             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / std) ** 2 / 2)
+
+         # Make sure sum of values in gaussian kernel equals 1.
+         kernel = kernel / torch.sum(kernel)
+
+         # Reshape to depthwise convolutional weight
+         kernel = kernel.view(1, 1, *kernel.size())
+         kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+         self.weight = kernel
+         self.groups = channels
+
+         if dim == 1:
+             self.conv = F.conv1d
+         elif dim == 2:
+             self.conv = F.conv2d
+         elif dim == 3:
+             self.conv = F.conv3d
+         else:
+             raise RuntimeError(
+                 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
+             )
+
+     def forward(self, input):
+         """
+         Apply gaussian filter to input.
+         Arguments:
+             input (torch.Tensor): Input to apply gaussian filter on.
+         Returns:
+             filtered (torch.Tensor): Filtered output.
+         """
+         # move the kernel to the input's device (works on CPU and GPU)
+         return self.conv(input, weight=self.weight.to(input.device), groups=self.groups)
+
+
+ def bgr2gray(color):
+     # gray = 0.299⋅R + 0.587⋅G + 0.114⋅B (input channels are in BGR order)
+     gray = color[:, 0, ...] * 0.114 + color[:, 1, ...] * 0.587 + color[:, 2, ...] * 0.299
+     gray = gray.unsqueeze_(1)
+     return gray
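
A minimal usage sketch (shapes and sigma are illustrative). Note that the module applies the convolution without padding, so the caller pads beforehand if the spatial size must be preserved:

    import torch
    import torch.nn.functional as F

    smoother = GaussianSmoothing(channels=3, kernel_size=11, sigma=2.0)
    img = torch.rand(1, 3, 64, 64)
    img = F.pad(img, (5, 5, 5, 5), mode='reflect')  # pad by kernel_size // 2 on each side
    out = smoother(img)                             # -> (1, 3, 64, 64)
    gray = bgr2gray(torch.rand(1, 3, 64, 64))       # -> (1, 1, 64, 64)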
utils/pytorch_ssim.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ import torch.nn.functional as F
+ from math import exp
+
+
+ def gaussian(window_size, sigma):
+     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+     return gauss / gauss.sum()
+
+
+ def create_window(window_size, channel):
+     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+     # the outer product of two 1D gaussians gives the 2D window
+     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+     return window
+
+
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
+     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+     # stability constants, assuming a data range of [0, 1]
+     C1 = 0.01 ** 2
+     C2 = 0.03 ** 2
+
+     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+     if size_average:
+         return ssim_map.mean()
+     else:
+         return ssim_map.mean(1).mean(1).mean(1)
+
+
+ class SSIM(torch.nn.Module):
+     def __init__(self, window_size=11, size_average=True):
+         super(SSIM, self).__init__()
+         self.window_size = window_size
+         self.size_average = size_average
+         self.channel = 1
+         self.window = create_window(window_size, self.channel)
+
+     def forward(self, img1, img2):
+         (_, channel, _, _) = img1.size()
+
+         # rebuild the window only when the channel count or dtype changes
+         if channel == self.channel and self.window.data.type() == img1.data.type():
+             window = self.window
+         else:
+             window = create_window(self.window_size, channel)
+
+             if img1.is_cuda:
+                 window = window.cuda(img1.get_device())
+             window = window.type_as(img1)
+
+             self.window = window
+             self.channel = channel
+
+         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+
+ def ssim(img1, img2, window_size=11, size_average=True):
+     (_, channel, _, _) = img1.size()
+     window = create_window(window_size, channel)
+
+     if img1.is_cuda:
+         window = window.cuda(img1.get_device())
+     window = window.type_as(img1)
+
+     return _ssim(img1, img2, window, window_size, channel, size_average)
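
For reference, a minimal sketch of both entry points (shapes are illustrative; recall the constants assume inputs roughly in [0, 1]):

    import torch

    img1 = torch.rand(1, 1, 128, 128)
    img2 = torch.rand(1, 1, 128, 128)

    score = ssim(img1, img2)            # functional form, returns a scalar tensor
    ssim_module = SSIM(window_size=11)  # module form, caches the window between calls
    loss = 1 - ssim_module(img1, img2)  # common way to use SSIM as a training loss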
utils/util.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import glob
+ import numpy as np
+ import cv2
+ import torch
+ import matplotlib.pyplot as plt
+
+
+ def ensure_dir(path):
+     if not os.path.exists(path):
+         os.makedirs(path)
+
+
+ def get_filelist(data_dir):
+     file_list = glob.glob(os.path.join(data_dir, '*.*'))
+     file_list.sort()
+     return file_list
+
+
+ def collect_filenames(data_dir):
+     file_list = get_filelist(data_dir)
+     name_list = []
+     for file_path in file_list:
+         _, file_name = os.path.split(file_path)
+         name_list.append(file_name)
+     name_list.sort()
+     return name_list
+
+
+ def save_list(save_path, data_list, append_mode=False):
+     n = len(data_list)
+     if append_mode:
+         # append only the newest entry
+         with open(save_path, 'a') as f:
+             f.writelines([str(data_list[i]) + '\n' for i in range(n - 1, n)])
+     else:
+         with open(save_path, 'w') as f:
+             f.writelines([str(data_list[i]) + '\n' for i in range(n)])
+     return None
+
+
+ def save_images_from_batch(img_batch, save_dir, filename_list, batch_no=-1):
+     N, H, W, C = img_batch.shape
+     if C == 3:
+         #! color image (BGR, as written by OpenCV)
+         for i in range(N):
+             # [-1,1] >>> [0,255]
+             img_batch_i = np.clip(img_batch[i, :, :, :] * 0.5 + 0.5, 0, 1)
+             image = (255.0 * img_batch_i).astype(np.uint8)
+             save_name = filename_list[i] if batch_no == -1 else '%05d.png' % (batch_no * N + i)
+             cv2.imwrite(os.path.join(save_dir, save_name), image)
+     elif C == 1:
+         #! single-channel gray image
+         for i in range(N):
+             # [-1,1] >>> [0,255]
+             img_batch_i = np.clip(img_batch[i, :, :, 0] * 0.5 + 0.5, 0, 1)
+             image = (255.0 * img_batch_i).astype(np.uint8)
+             save_name = filename_list[i] if batch_no == -1 else '%05d.png' % (batch_no * N + i)
+             cv2.imwrite(os.path.join(save_dir, save_name), image)
+     return None
+
+
+ def imagesc(nd_array):
+     plt.imshow(nd_array)
+     plt.colorbar()
+     plt.show()
+
+
+ def img2tensor(img):
+     if len(img.shape) == 2:
+         img = img[..., np.newaxis]
+     img_t = np.expand_dims(img.transpose(2, 0, 1), axis=0)
+     img_t = torch.from_numpy(img_t.astype(np.float32))
+     return img_t
+
+
+ def tensor2img(img_t):
+     img = img_t[0].detach().to("cpu").numpy()
+     img = np.transpose(img, (1, 2, 0))
+     if img.shape[-1] == 1:
+         img = img[..., 0]
+     return img
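
To illustrate the conventions these helpers assume (a sketch with a made-up file name): images travel as HWC numpy arrays on disk and as NCHW tensors in the model, and `save_images_from_batch` expects an NHWC batch in [-1, 1]:

    import cv2
    import numpy as np

    img = cv2.imread('input.png', cv2.IMREAD_COLOR)        # HWC uint8, BGR
    t = img2tensor(img.astype(np.float32) / 127.5 - 1.0)   # -> (1, 3, H, W) in [-1, 1]
    # ... run the model on t ...
    out = tensor2img(t)                                    # -> (H, W, 3) numpy in [-1, 1]
    ensure_dir('./results')
    save_images_from_batch(out[np.newaxis], './results', ['input_restored.png'])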