talha committed
Commit ad250d1
Parent(s): 42b38a5

models added

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitignore +4 -1
  2. README.md +1 -1
  3. app.py +4 -9
  4. image_enhancement.egg-info/PKG-INFO +9 -0
  5. image_enhancement.egg-info/SOURCES.txt +38 -0
  6. image_enhancement.egg-info/dependency_links.txt +1 -0
  7. image_enhancement.egg-info/top_level.txt +2 -0
  8. models/README.md +16 -0
  9. models/__init__.py +0 -0
  10. models/__pycache__/__init__.cpython-310.pyc +0 -0
  11. models/demoire/__init__.py +0 -0
  12. models/demoire/__pycache__/__init__.cpython-310.pyc +0 -0
  13. models/demoire/__pycache__/nets.cpython-310.pyc +0 -0
  14. {model → models/demoire}/nets.py +0 -0
  15. models/llflow/LOL_smallNet.pth +3 -0
  16. models/llflow/LOL_smallNet.yml +125 -0
  17. models/llflow/Measure.py +127 -0
  18. models/llflow/__init__.py +4 -0
  19. models/llflow/__pycache__/Measure.cpython-310.pyc +0 -0
  20. models/llflow/__pycache__/__init__.cpython-310.pyc +0 -0
  21. models/llflow/__pycache__/imresize.cpython-310.pyc +0 -0
  22. models/llflow/__pycache__/inference.cpython-310.pyc +0 -0
  23. models/llflow/__pycache__/option_.cpython-310.pyc +0 -0
  24. models/llflow/__pycache__/util.cpython-310.pyc +0 -0
  25. models/llflow/imresize.py +180 -0
  26. models/llflow/inference.py +157 -0
  27. models/llflow/models/LLFlow_model.py +400 -0
  28. models/llflow/models/__init__.py +52 -0
  29. models/llflow/models/__pycache__/LLFlow_model.cpython-310.pyc +0 -0
  30. models/llflow/models/__pycache__/__init__.cpython-310.pyc +0 -0
  31. models/llflow/models/__pycache__/base_model.cpython-310.pyc +0 -0
  32. models/llflow/models/__pycache__/lr_scheduler.cpython-310.pyc +0 -0
  33. models/llflow/models/__pycache__/networks.cpython-310.pyc +0 -0
  34. models/llflow/models/base_model.py +145 -0
  35. models/llflow/models/lr_scheduler.py +147 -0
  36. models/llflow/models/modules/ConditionEncoder.py +287 -0
  37. models/llflow/models/modules/FlowActNorms.py +128 -0
  38. models/llflow/models/modules/FlowAffineCouplingsAblation.py +169 -0
  39. models/llflow/models/modules/FlowStep.py +136 -0
  40. models/llflow/models/modules/FlowUpsamplerNet.py +328 -0
  41. models/llflow/models/modules/LLFlow_arch.py +248 -0
  42. models/llflow/models/modules/Permutations.py +59 -0
  43. models/llflow/models/modules/RRDBNet_arch.py +147 -0
  44. models/llflow/models/modules/Split.py +88 -0
  45. models/llflow/models/modules/__init__.py +0 -0
  46. models/llflow/models/modules/__pycache__/ConditionEncoder.cpython-310.pyc +0 -0
  47. models/llflow/models/modules/__pycache__/FlowActNorms.cpython-310.pyc +0 -0
  48. models/llflow/models/modules/__pycache__/FlowAffineCouplingsAblation.cpython-310.pyc +0 -0
  49. models/llflow/models/modules/__pycache__/FlowStep.cpython-310.pyc +0 -0
  50. models/llflow/models/modules/__pycache__/FlowUpsamplerNet.cpython-310.pyc +0 -0
.gitignore CHANGED
@@ -1 +1,4 @@
- model/__pycache__/
+ model/__pycache__/
+ models/llflow/LOLdataset.zip
+ models/llflow/dataset_samples
+ models/results
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Screen Image Demoireing
+ title: Image Enhancement
  emoji: 🔥
  sdk: gradio
  sdk_version: 3.10.1
app.py CHANGED
@@ -1,5 +1,5 @@
  import gradio as gr
- from model.nets import my_model
+ from models.demoire.nets import my_model
  import torch
  import cv2
  import torch.utils.data as data
@@ -15,7 +15,9 @@ import torch.nn.functional as F
  from rich.panel import Panel
  from rich.columns import Columns
  from rich.console import Console
- from models.gfpgan import gfpgan_predict
+ # from models.gfpgan import gfpgan_predict
+ from models.llflow.inference import main
+
 
  os.environ["CUDA_VISIBLE_DEVICES"] = "1"
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -89,13 +91,6 @@ def predict_gfpgan(img):
 
      with Console().status("[red] using [green] GFP-GAN v1.4", spinner="aesthetic"):
          # if image already exists with this name then delete it
-         if Path("input_image_gfpgan.jpg").exists():
-             os.remove("input_image_gfpgan.jpg")
-         # save incoming PIL image to disk
-         img.save("input_image_gfpgan.jpg")
-
-         out = gfpgan_predict(img)
-         Console().print(out)
 
          return img

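This commit swaps app.py's GFP-GAN import for the LLFlow inference entry point. For orientation, a minimal sketch (assumed wiring, not the Space's actual app.py; names are taken from this diff and from models/llflow/inference.py below, which returns a BGR uint8 array) of how `main` could back a Gradio interface:

```python
# Hedged sketch: a stand-alone Gradio wrapper around the LLFlow entry point.
import cv2
import gradio as gr

from models.llflow.inference import main


def enhance(image_path: str):
    bgr = main(image_path)                       # BGR uint8 array from inference.py
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # Gradio displays RGB


demo = gr.Interface(fn=enhance, inputs=gr.Image(type="filepath"), outputs="image")

if __name__ == "__main__":
    demo.launch()
```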
image_enhancement.egg-info/PKG-INFO ADDED
@@ -0,0 +1,9 @@
+ Metadata-Version: 2.1
+ Name: image-enhancement
+ Version: 1.0
+ Summary: UNKNOWN
+ License: UNKNOWN
+ Platform: UNKNOWN
+
+ UNKNOWN
+
image_enhancement.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,38 @@
+ README.md
+ setup.py
+ image_enhancement.egg-info/PKG-INFO
+ image_enhancement.egg-info/SOURCES.txt
+ image_enhancement.egg-info/dependency_links.txt
+ image_enhancement.egg-info/top_level.txt
+ model/__init__.py
+ model/nets.py
+ models/__init__.py
+ models/gfpgan.py
+ models/llflow/Measure.py
+ models/llflow/__init__.py
+ models/llflow/imresize.py
+ models/llflow/inference.py
+ models/llflow/option_.py
+ models/llflow/util.py
+ models/llflow/models/LLFlow_model.py
+ models/llflow/models/__init__.py
+ models/llflow/models/base_model.py
+ models/llflow/models/lr_scheduler.py
+ models/llflow/models/networks.py
+ models/llflow/models/modules/ConditionEncoder.py
+ models/llflow/models/modules/FlowActNorms.py
+ models/llflow/models/modules/FlowAffineCouplingsAblation.py
+ models/llflow/models/modules/FlowStep.py
+ models/llflow/models/modules/FlowUpsamplerNet.py
+ models/llflow/models/modules/LLFlow_arch.py
+ models/llflow/models/modules/Permutations.py
+ models/llflow/models/modules/RRDBNet_arch.py
+ models/llflow/models/modules/Split.py
+ models/llflow/models/modules/__init__.py
+ models/llflow/models/modules/base_layers.py
+ models/llflow/models/modules/color_encoder.py
+ models/llflow/models/modules/flow.py
+ models/llflow/models/modules/glow_arch.py
+ models/llflow/models/modules/loss.py
+ models/llflow/models/modules/module_util.py
+ models/llflow/models/modules/thops.py
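The egg-info directory is setuptools output; SOURCES.txt lists a setup.py that this 50-file view does not show. A hypothetical minimal setup.py consistent with the PKG-INFO above (Name: image-enhancement, Version: 1.0) and with top_level.txt below (top-level packages "model" and "models") might look like this; the real file may differ:

```python
# Hypothetical reconstruction of setup.py, inferred from the egg-info metadata
# in this commit. Not the repo's actual file, which is outside the 50-file view.
from setuptools import find_packages, setup

setup(
    name="image-enhancement",
    version="1.0",
    packages=find_packages(include=["model", "model.*", "models", "models.*"]),
)
```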
image_enhancement.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
image_enhancement.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
+ model
+ models
models/README.md ADDED
@@ -0,0 +1,16 @@
+ # References 📚
+
+
+
+ [1] Y. Wang, “[AAAI 2022 Oral] Low-Light Image Enhancement with Normalizing Flow.” Nov. 23, 2022. Accessed: Nov. 24, 2022. [Online]. Available: https://github.com/wyf0912/LLFlow
+ [2] L. Wang and K.-J. Yoon, “Deep Learning for HDR Imaging: State-of-the-Art and Future Trends.” arXiv, Nov. 07, 2021. Accessed: Nov. 24, 2022. [Online]. Available: http://arxiv.org/abs/2110.10394
+ [3] Y. Wang, “Neural Color Operators for Sequential Image Retouching (ECCV2022).” Nov. 10, 2022. Accessed: Nov. 24, 2022. [Online]. Available: https://github.com/amberwangyili/neurop
+ [4] jwhe, “Conditional Sequential Modulation for Efficient Global Image Retouching.” Nov. 23, 2022. Accessed: Nov. 24, 2022. [Online]. Available: https://github.com/hejingwenhejingwen/CSRNet
+ [5] Why, “Local Color Distributions Prior for Image Enhancement [ECCV2022].” Nov. 21, 2022. Accessed: Nov. 23, 2022. [Online]. Available: https://github.com/onpix/LCDPNet
+ [6] “Towards Efficient and Scale-Robust Ultra-High-Definition Image Demoiréing.” CVMI Lab, Nov. 21, 2022. Accessed: Nov. 21, 2022. [Online]. Available: https://github.com/CVMI-Lab/UHDM
+ [7] Z. Wang, “Uformer: A General U-Shaped Transformer for Image Restoration (CVPR 2022).” Nov. 20, 2022. Accessed: Nov. 21, 2022. [Online]. Available: https://github.com/ZhendongWang6/Uformer
+ [8] B. Zheng, “Learnbale_Bandpass_Filter.” Nov. 21, 2022. Accessed: Nov. 21, 2022. [Online]. Available: https://github.com/zhenngbolun/Learnbale_Bandpass_Filter
+ [9] K. Team, “Keras documentation: Enhanced Deep Residual Networks for single-image super-resolution.” https://keras.io/examples/vision/edsr/ (accessed Nov. 21, 2022).
+ [10] B. Lim, S. Son, H. Kim, S. Nah, and K. M. Lee, “Enhanced Deep Residual Networks for Single Image Super-Resolution.” arXiv, Jul. 10, 2017. doi: 10.48550/arXiv.1707.02921.
+ [11] C. Dong, C. C. Loy, K. He, and X. Tang, “Image Super-Resolution Using Deep Convolutional Networks.” arXiv, Jul. 31, 2015. doi: 10.48550/arXiv.1501.00092.
+ [12] Z. Anvari and V. Athitsos, “A Survey on Deep learning based Document Image Enhancement.” arXiv, Jan. 03, 2022. doi: 10.48550/arXiv.2112.02719.
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (147 Bytes).
 
models/demoire/__init__.py ADDED
File without changes
models/demoire/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes).
 
models/demoire/__pycache__/nets.cpython-310.pyc ADDED
Binary file (8.67 kB).
 
{model → models/demoire}/nets.py RENAMED
File without changes
models/llflow/LOL_smallNet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2bf4c9192b401bf7155b2aa0781d9d8eed2e0bcc148286a9e2b224e12777bb38
+ size 21874185
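The checkpoint is stored as a Git LFS pointer, so a plain clone only yields these three lines. A hedged sketch of fetching the real ~21.9 MB file with `huggingface_hub` (the Space id "talha/image-enhancement" is an assumption; substitute the actual repo id):

```python
# Hypothetical sketch: downloading the LFS-tracked checkpoint from the Hub.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="talha/image-enhancement",            # assumption, not confirmed by the diff
    filename="models/llflow/LOL_smallNet.pth",
    repo_type="space",
)
print(ckpt_path)  # local cache path to the resolved .pth file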
models/llflow/LOL_smallNet.yml ADDED
@@ -0,0 +1,125 @@
+ #### general settings
+ name: train_rebuttal_smallNet_ch32_blocks1
+ use_tb_logger: true
+ model: LLFlow
+ distortion: sr
+ scale: 1
+ gpu_ids: [0]
+ dataset: LoL
+ optimize_all_z: false
+ cond_encoder: ConEncoder1
+ train_gt_ratio: 0.2
+ avg_color_map: false
+
+ concat_histeq: true
+ histeq_as_input: false
+ concat_color_map: false
+ gray_map: false # concat 1-input.mean(dim=1) to the input
+
+ align_condition_feature: false
+ align_weight: 0.001
+ align_maxpool: true
+
+ to_yuv: false
+
+ encode_color_map: false
+ le_curve: false
+ # sigmoid_output: true
+
+ #### datasets
+ datasets:
+   train:
+     root: D:\LOLdataset
+     quant: 32
+     use_shuffle: true
+     n_workers: 1 # per GPU
+     batch_size: 16
+     use_flip: true
+     color: RGB
+     use_crop: true
+     GT_size: 160 # 192
+     noise_prob: 0
+     noise_level: 5
+     log_low: true
+     gamma_aug: false
+
+   val:
+     root: D:\LOLdataset
+     n_workers: 1
+     quant: 32
+     n_max: 20
+     batch_size: 1 # must be 1
+     log_low: true
+
+ #### Test Settings
+ # dataroot_GT: D:\LOLdataset\eval15\high
+ # dataroot_LR: D:\LOLdataset\eval15\low
+ dataroot_unpaired: models/llflow/dataset_samples/our485/low
+ # dataroot_unpaired: /home/data/Dataset/LOL_test/Fusion
+ dataroot_GT: D:\Dataset\LOL-v2\LOL-v2\IntegratedTest\Test\high
+ dataroot_LR: D:\Dataset\LOL-v2\LOL-v2\IntegratedTest\Test\low
+ model_path: models/llflow/LOL_smallNet.pth
+ heat: 0 # This is the standard deviation of the latent vectors
+
+ #### network structures
+ network_G:
+   which_model_G: LLFlow
+   in_nc: 3
+   out_nc: 3
+   nf: 32
+   nb: 4 # 12 for our low light encoder, 23 for LLFlow
+   train_RRDB: false
+   train_RRDB_delay: 0.5
+
+   flow:
+     K: 4 # 12 was used for the 24.49 PSNR result # 16
+     L: 3 # 4
+     noInitialInj: true
+     coupling: CondAffineSeparatedAndCond
+     additionalFlowNoAffine: 2
+     conditionInFeaDim: 64
+     split:
+       enable: false
+     fea_up0: true
+     stackRRDB:
+       blocks: [1]
+       concat: true
+
+ #### path
+ path:
+   # pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth
+   strict_load: true
+   resume_state: auto
+
+ #### training settings: learning rate scheme, loss
+ train:
+   manual_seed: 10
+   lr_G: !!float 5e-4 # normalizing flow 5e-4; l1 loss train 5e-5
+   weight_decay_G: 0 # 1e-5 # 5e-5 # 1e-5
+   beta1: 0.9
+   beta2: 0.99
+   lr_scheme: MultiStepLR
+   warmup_iter: -1 # no warm up
+   lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ] # [0.2, 0.35, 0.5, 0.65, 0.8, 0.95] # [ 0.5, 0.75, 0.9, 0.95 ]
+   lr_gamma: 0.5
+
+   weight_l1: 0
+   # flow_warm_up_iter: -1
+   weight_fl: 1
+
+   niter: 45000 # 200000
+   val_freq: 200 # 200
+
+ #### validation settings
+ val:
+   # heats: [ 0.0, 0.5, 0.75, 1.0 ]
+   n_sample: 4
+
+ test:
+   heats: [ 0.0, 0.7, 0.8, 0.9 ]
+
+ #### logger
+ logger:
+   # Debug print_freq: 100
+   print_freq: 100
+   save_checkpoint_freq: !!float 1e3
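At test time, inference.py below reads only a handful of these fields (model_path, heat, dataroot_unpaired, concat_histeq, and datasets.train.log_low). A minimal sketch of inspecting them with PyYAML; this is an assumption for illustration, since the repo parses configs through its own models.llflow.option_ module:

```python
# Hedged sketch: reading the inference-relevant fields from LOL_smallNet.yml.
import yaml

with open("models/llflow/LOL_smallNet.yml") as f:
    opt = yaml.safe_load(f)

print(opt["model_path"])         # models/llflow/LOL_smallNet.pth
print(opt["heat"])               # 0 -> deterministic output (all-zero latent)
print(opt["dataroot_unpaired"])  # folder of unpaired low-light test images
print(opt["concat_histeq"])      # True -> histogram-equalized copy concatenated to input
```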
models/llflow/Measure.py ADDED
@@ -0,0 +1,127 @@
+ import glob
+ import os
+ import time
+ from collections import OrderedDict
+
+ import numpy as np
+ import torch
+ import cv2
+ import argparse
+
+ from natsort import natsort
+ from skimage.metrics import structural_similarity as ssim
+ from skimage.metrics import peak_signal_noise_ratio as psnr
+ import lpips
+
+
+ class Measure():
+     def __init__(self, net='alex', use_gpu=False):
+         self.device = 'cuda' if use_gpu else 'cpu'
+         self.model = lpips.LPIPS(net=net)
+         self.model.to(self.device)
+
+     def measure(self, imgA, imgB):
+         return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]]
+
+     def lpips(self, imgA, imgB, model=None):
+         tA = t(imgA).to(self.device)
+         tB = t(imgB).to(self.device)
+         dist01 = self.model.forward(tA, tB).item()
+         return dist01
+
+     def ssim(self, imgA, imgB, gray_scale=True):
+         if gray_scale:
+             score, diff = ssim(cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(
+                 imgB, cv2.COLOR_RGB2GRAY), full=True, multichannel=True)
+             # multichannel: If True, treat the last dimension of the array as channels. Similarity calculations are done independently for each channel then averaged.
+         else:
+             score, diff = ssim(imgA, imgB, full=True, multichannel=True)
+         return score
+
+     def psnr(self, imgA, imgB):
+         psnr_val = psnr(imgA, imgB)
+         return psnr_val
+
+
+ def t(img):
+     def to_4d(img):
+         assert len(img.shape) == 3
+         assert img.dtype == np.uint8
+         img_new = np.expand_dims(img, axis=0)
+         assert len(img_new.shape) == 4
+         return img_new
+
+     def to_CHW(img):
+         return np.transpose(img, [2, 0, 1])
+
+     def to_tensor(img):
+         return torch.Tensor(img)
+
+     return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
+
+
+ def fiFindByWildcard(wildcard):
+     return natsort.natsorted(glob.glob(wildcard, recursive=True))
+
+
+ def imread(path):
+     return cv2.imread(path)[:, :, [2, 1, 0]]
+
+
+ def format_result(psnr, ssim, lpips):
+     return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}'
+
+
+ def measure_dirs(dirA, dirB, use_gpu, verbose=False):
+     if verbose:
+         def vprint(x): return print(x)
+     else:
+         def vprint(x): return None
+
+     t_init = time.time()
+
+     paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}'))
+     paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}'))
+
+     vprint("Comparing: ")
+     vprint(dirA)
+     vprint(dirB)
+
+     measure = Measure(use_gpu=use_gpu)
+
+     results = []
+     for pathA, pathB in zip(paths_A, paths_B):
+         result = OrderedDict()
+
+         t = time.time()
+         result['psnr'], result['ssim'], result['lpips'] = measure.measure(
+             imread(pathA), imread(pathB))
+         d = time.time() - t
+         vprint(
+             f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}")
+
+         results.append(result)
+
+     psnr = np.mean([result['psnr'] for result in results])
+     ssim = np.mean([result['ssim'] for result in results])
+     lpips = np.mean([result['lpips'] for result in results])
+
+     vprint(
+         f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-dirA', default='', type=str)
+     parser.add_argument('-dirB', default='', type=str)
+     parser.add_argument('-type', default='png')
+     parser.add_argument('--use_gpu', action='store_true', default=False)
+     args = parser.parse_args()
+
+     dirA = args.dirA
+     dirB = args.dirB
+     type = args.type
+     use_gpu = args.use_gpu
+
+     if len(dirA) > 0 and len(dirB) > 0:
+         measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True)
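A minimal usage sketch for the Measure class above (the image paths are hypothetical examples; Measure expects uint8 RGB arrays of identical shape and returns PSNR, SSIM, and LPIPS):

```python
# Hedged sketch: scoring one enhanced/reference pair with Measure.
# Note: Measure.ssim passes multichannel=True to scikit-image, so this relies
# on an older scikit-image release that still accepts that keyword.
from models.llflow.Measure import Measure, imread

imgA = imread("enhanced.png")    # hypothetical paths; imread returns RGB uint8
imgB = imread("reference.png")

m = Measure(net='alex', use_gpu=False)   # LPIPS with an AlexNet backbone
psnr_val, ssim_val, lpips_val = m.measure(imgA, imgB)
print(f"PSNR {psnr_val:0.2f}  SSIM {ssim_val:0.3f}  LPIPS {lpips_val:0.3f}")
```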
models/llflow/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # from util import get_resume_paths, opt_get
+ from .Measure import Measure, psnr
+ from .imresize import imresize
+ from models import *
models/llflow/__pycache__/Measure.cpython-310.pyc ADDED
Binary file (4.74 kB).
 
models/llflow/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (258 Bytes).
 
models/llflow/__pycache__/imresize.cpython-310.pyc ADDED
Binary file (4.69 kB).
 
models/llflow/__pycache__/inference.cpython-310.pyc ADDED
Binary file (5.09 kB).
 
models/llflow/__pycache__/option_.cpython-310.pyc ADDED
Binary file (4.4 kB).
 
models/llflow/__pycache__/util.cpython-310.pyc ADDED
Binary file (7.68 kB).
 
models/llflow/imresize.py ADDED
@@ -0,0 +1,180 @@
+ # https://github.com/fatheral/matlab_imresize
+ #
+ # MIT License
+ #
+ # Copyright (c) 2020 Alex
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+
+ from __future__ import print_function
+ import numpy as np
+ from math import ceil, floor
+
+
+ def deriveSizeFromScale(img_shape, scale):
+     output_shape = []
+     for k in range(2):
+         output_shape.append(int(ceil(scale[k] * img_shape[k])))
+     return output_shape
+
+
+ def deriveScaleFromSize(img_shape_in, img_shape_out):
+     scale = []
+     for k in range(2):
+         scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
+     return scale
+
+
+ def triangle(x):
+     x = np.array(x).astype(np.float64)
+     lessthanzero = np.logical_and((x >= -1), x < 0)
+     greaterthanzero = np.logical_and((x <= 1), x >= 0)
+     f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero)
+     return f
+
+
+ def cubic(x):
+     x = np.array(x).astype(np.float64)
+     absx = np.absolute(x)
+     absx2 = np.multiply(absx, absx)
+     absx3 = np.multiply(absx2, absx)
+     f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2,
+                                                                            (1 < absx) & (absx <= 2))
+     return f
+
+
+ def contributions(in_length, out_length, scale, kernel, k_width):
+     if scale < 1:
+         h = lambda x: scale * kernel(scale * x)
+         kernel_width = 1.0 * k_width / scale
+     else:
+         h = kernel
+         kernel_width = k_width
+     x = np.arange(1, out_length + 1).astype(np.float64)
+     u = x / scale + 0.5 * (1 - 1 / scale)
+     left = np.floor(u - kernel_width / 2)
+     P = int(ceil(kernel_width)) + 2
+     ind = np.expand_dims(left, axis=1) + np.arange(P) - 1  # -1 because indexing from 0
+     indices = ind.astype(np.int32)
+     weights = h(np.expand_dims(u, axis=1) - indices - 1)  # -1 because indexing from 0
+     weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
+     aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
+     indices = aux[np.mod(indices, aux.size)]
+     ind2store = np.nonzero(np.any(weights, axis=0))
+     weights = weights[:, ind2store]
+     indices = indices[:, ind2store]
+     return weights, indices
+
+
+ def imresizemex(inimg, weights, indices, dim):
+     in_shape = inimg.shape
+     w_shape = weights.shape
+     out_shape = list(in_shape)
+     out_shape[dim] = w_shape[0]
+     outimg = np.zeros(out_shape)
+     if dim == 0:
+         for i_img in range(in_shape[1]):
+             for i_w in range(w_shape[0]):
+                 w = weights[i_w, :]
+                 ind = indices[i_w, :]
+                 im_slice = inimg[ind, i_img].astype(np.float64)
+                 outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
+     elif dim == 1:
+         for i_img in range(in_shape[0]):
+             for i_w in range(w_shape[0]):
+                 w = weights[i_w, :]
+                 ind = indices[i_w, :]
+                 im_slice = inimg[i_img, ind].astype(np.float64)
+                 outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
+     if inimg.dtype == np.uint8:
+         outimg = np.clip(outimg, 0, 255)
+         return np.around(outimg).astype(np.uint8)
+     else:
+         return outimg
+
+
+ def imresizevec(inimg, weights, indices, dim):
+     wshape = weights.shape
+     if dim == 0:
+         weights = weights.reshape((wshape[0], wshape[2], 1, 1))
+         outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
+     elif dim == 1:
+         weights = weights.reshape((1, wshape[0], wshape[2], 1))
+         outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
+     if inimg.dtype == np.uint8:
+         outimg = np.clip(outimg, 0, 255)
+         return np.around(outimg).astype(np.uint8)
+     else:
+         return outimg
+
+
+ def resizeAlongDim(A, dim, weights, indices, mode="vec"):
+     if mode == "org":
+         out = imresizemex(A, weights, indices, dim)
+     else:
+         out = imresizevec(A, weights, indices, dim)
+     return out
+
+
+ def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
+     if method == 'bicubic':
+         kernel = cubic
+     elif method == 'bilinear':
+         kernel = triangle
+     else:
+         print('Error: Unidentified method supplied')
+
+     kernel_width = 4.0
+     # Fill scale and output_size
+     if scalar_scale is not None:
+         scalar_scale = float(scalar_scale)
+         scale = [scalar_scale, scalar_scale]
+         output_size = deriveSizeFromScale(I.shape, scale)
+     elif output_shape is not None:
+         scale = deriveScaleFromSize(I.shape, output_shape)
+         output_size = list(output_shape)
+     else:
+         print('Error: scalar_scale OR output_shape should be defined!')
+         return
+     scale_np = np.array(scale)
+     order = np.argsort(scale_np)
+     weights = []
+     indices = []
+     for k in range(2):
+         w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
+         weights.append(w)
+         indices.append(ind)
+     B = np.copy(I)
+     flag2D = False
+     if B.ndim == 2:
+         B = np.expand_dims(B, axis=2)
+         flag2D = True
+     for k in range(2):
+         dim = order[k]
+         B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
+     if flag2D:
+         B = np.squeeze(B, axis=2)
+     return B
+
+
+ def convertDouble2Byte(I):
+     B = np.clip(I, 0.0, 1.0)
+     B = 255 * B
+     return np.around(B).astype(np.uint8)
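A minimal usage sketch for imresize, the MATLAB-compatible bicubic resizer above (the random array stands in for a real image; either a scalar scale or an explicit output shape may be given):

```python
# Hedged sketch: exercising imresize with both calling conventions.
import numpy as np

from models.llflow.imresize import imresize

img = (np.random.rand(64, 96, 3) * 255).astype(np.uint8)  # stand-in for a real image
half = imresize(img, scalar_scale=0.5)                    # -> (32, 48, 3), dtype preserved
fixed = imresize(img, output_shape=(128, 128))            # explicit target H x W
print(half.shape, fixed.shape)
```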
models/llflow/inference.py ADDED
@@ -0,0 +1,157 @@
+ import glob
+ import sys
+ from collections import OrderedDict
+ import tqdm
+ from natsort import natsort
+ import argparse
+ import models.llflow.option_ as option
+ from models.llflow import Measure, psnr
+ from models.llflow import imresize
+ from models import create_model
+ import torch
+ from util import opt_get
+ import numpy as np
+ import pandas as pd
+ import os
+ import cv2
+ from rich.console import Console
+
+ def fiFindByWildcard(wildcard):
+     return natsort.natsorted(glob.glob(wildcard, recursive=True))
+
+
+ def load_model(conf_path):
+     opt = option.parse(conf_path, is_train=False)
+     opt['gpu_ids'] = None
+     opt = option.dict_to_nonedict(opt)
+     model = create_model(opt)
+
+     model_path = opt_get(opt, ['model_path'], None)
+     model.load_network(load_path=model_path, network=model.netG)
+     return model, opt
+
+
+ def predict(model, lr):
+     model.feed_data({"LQ": t(lr)}, need_GT=False)
+     model.test()
+     visuals = model.get_current_visuals(need_GT=False)
+     return visuals.get('rlt', visuals.get('NORMAL'))
+
+
+ def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255
+
+
+ def rgb(t): return (
+     np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype(
+     np.uint8)
+
+
+ def imread(path):
+     return cv2.imread(path)[:, :, [2, 1, 0]]
+
+
+ def imwrite(path, img):
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     cv2.imwrite(path, img[:, :, [2, 1, 0]])
+
+
+ def imCropCenter(img, size):
+     h, w, c = img.shape
+
+     h_start = max(h // 2 - size // 2, 0)
+     h_end = min(h_start + size, h)
+
+     w_start = max(w // 2 - size // 2, 0)
+     w_end = min(w_start + size, w)
+
+     return img[h_start:h_end, w_start:w_end]
+
+
+ def impad(img, top=0, bottom=0, left=0, right=0, color=255):
+     return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect')
+
+
+ def hiseq_color_cv2_img(img):
+     (b, g, r) = cv2.split(img)
+     bH = cv2.equalizeHist(b)
+     gH = cv2.equalizeHist(g)
+     rH = cv2.equalizeHist(r)
+     result = cv2.merge((bH, gH, rH))
+     return result
+
+
+ def auto_padding(img, times=16):
+     # img: numpy image with shape H*W*C
+
+     h, w, _ = img.shape
+     h1, w1 = (times - h % times) // 2, (times - w % times) // 2
+     h2, w2 = (times - h % times) - h1, (times - w % times) - w1
+     img = cv2.copyMakeBorder(img, h1, h2, w1, w2, cv2.BORDER_REFLECT)
+     return img, [h1, h2, w1, w2]
+
+
+ def main(path: str):
+     parser = argparse.ArgumentParser()
+     # parser.add_argument("--opt", default="./confs/LOL_smallNet.yml")
+     parser.add_argument("--opt", default="./models/llflow/LOL_smallNet.yml")
+     parser.add_argument("-n", "--name", default="unpaired")
+
+     # Namespace(opt="./models/llflow/LOL_smallNet.yml", name="unpaired")
+     # args = parser.parse_args()
+     args = parser.parse_args()
+
+     Console().log(f"🛠️\tLoading model from {args.opt}")
+
+     conf_path = args.opt
+     conf = conf_path.split('/')[-1].replace('.yml', '')
+     model, opt = load_model(conf_path)
+     model.netG = model.netG.cuda()
+
+     lr_dir = opt['dataroot_unpaired']
+     # lr_paths = fiFindByWildcard(os.path.join(lr_dir, '*.*'))
+     lr_paths = path
+
+     this_dir = os.path.dirname(os.path.realpath(__file__))
+     test_dir = os.path.join(this_dir, '..', 'results', conf, args.name)
+     print(f"Out dir: {test_dir}")
+
+     # for lr_path, idx_test in tqdm.tqdm(zip(lr_paths, range(len(lr_paths))), colour='green'):
+     lr_path = lr_paths
+     lr = imread(lr_path)
+     raw_shape = lr.shape
+     lr, padding_params = auto_padding(lr)
+     his = hiseq_color_cv2_img(lr)
+     if opt.get("histeq_as_input", False):
+         lr = his
+
+     lr_t = t(lr)
+     if opt["datasets"]["train"].get("log_low", False):
+         lr_t = torch.log(torch.clamp(lr_t + 1e-3, min=1e-3))
+     if opt.get("concat_histeq", False):
+         his = t(his)
+         lr_t = torch.cat([lr_t, his], dim=1)
+     heat = opt['heat']
+     with torch.cuda.amp.autocast():
+         sr_t = model.get_sr(lq=lr_t.cuda(), heat=None)
+
+     sr = rgb(torch.clamp(sr_t, 0, 1)[:, :, padding_params[0]:sr_t.shape[2] - padding_params[1],
+                                      padding_params[2]:sr_t.shape[3] - padding_params[3]])
+     assert raw_shape == sr.shape
+     path_out_sr = os.path.join(test_dir, os.path.basename(lr_path))
+     # imwrite(path_out_sr, sr)
+     # cv2.imwrite(path_out_sr, sr[:, :, [2, 1, 0]])
+
+     return sr[:, :, [2, 1, 0]]
+
+
+ def format_measurements(meas):
+     s_out = []
+     for k, v in meas.items():
+         v = f"{v:0.2f}" if isinstance(v, float) else v
+         s_out.append(f"{k}: {v}")
+     str_out = ", ".join(s_out)
+     return str_out
+
+
+ if __name__ == "__main__":
+     main()
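A minimal sketch of driving this module the way app.py does ("low_light.png" is a hypothetical path; per the code above, `main` returns a BGR uint8 array ready for cv2.imwrite). Two caveats from the code itself: it moves the model and input to CUDA unconditionally, and it calls `parser.parse_args()` internally, so unrecognized command-line flags on the host process will make argparse exit:

```python
# Hedged sketch: one-shot LLFlow enhancement via the main() entry point.
import cv2

from models.llflow.inference import main

enhanced_bgr = main("low_light.png")      # hypothetical input path; requires a GPU
cv2.imwrite("enhanced.png", enhanced_bgr)
```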
models/llflow/models/LLFlow_model.py ADDED
@@ -0,0 +1,400 @@
+ import logging
+ from collections import OrderedDict
+ # from models.llflow.util import get_resume_paths, opt_get
+ # from models.llflow import get_resume_paths, opt_get
+ import glob
+ import os
+ import natsort
+ import torch
+ import torch.nn as nn
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
+ import models.networks as networks
+ import models.lr_scheduler as lr_scheduler
+ from .base_model import BaseModel
+ from torch.cuda.amp import GradScaler, autocast
+
+ logger = logging.getLogger('base')
+
+
+ def get_resume_paths(opt):
+     resume_state_path = None
+     resume_model_path = None
+     ts = opt_get(opt, ['path', 'training_state'])
+     if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None:
+         wildcard = os.path.join(ts, "*")
+         paths = natsort.natsorted(glob.glob(wildcard))
+         if len(paths) > 0:
+             resume_state_path = paths[-1]
+             resume_model_path = resume_state_path.replace(
+                 'training_state', 'models').replace('.state', '_G.pth')
+     else:
+         resume_state_path = opt.get('path', {}).get('resume_state')
+     return resume_state_path, resume_model_path
+
+
+ def opt_get(opt, keys, default=None):
+     if opt is None:
+         return default
+     ret = opt
+     for k in keys:
+         ret = ret.get(k, None)
+         if ret is None:
+             return default
+     return ret
+
+
+ class LLFlowModel(BaseModel):
+     def __init__(self, opt, step):
+         super(LLFlowModel, self).__init__(opt)
+         self.opt = opt
+
+         self.already_print_params_num = False
+
+         self.heats = opt['val']['heats']
+         self.n_sample = opt['val']['n_sample']
+         self.hr_size = opt['datasets']['train']['GT_size']  # opt_get(opt, ['datasets', 'train', 'center_crop_hr_size'])
+         # self.hr_size = 160 if self.hr_size is None else self.hr_size
+         self.lr_size = self.hr_size // opt['scale']
+
+         if opt['dist']:
+             self.rank = torch.distributed.get_rank()
+         else:
+             self.rank = -1  # non dist training
+         train_opt = opt['train']
+
+         # define network and load pretrained models
+         self.netG = networks.define_Flow(opt, step).to(self.device)
+         #
+         weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
+         if weight_l1 and 1:
+             missing_keys, unexpected_keys = self.netG.load_state_dict(torch.load(
+                 '/home/yufei/project/LowLightFlow/experiments/to_pretrain_netG/models/1000_G.pth'),
+                 strict=False)
+             print('missing %d keys, unexpected %d keys' % (len(missing_keys), len(unexpected_keys)))
+         # if self.device.type != 'cpu':
+         if opt['gpu_ids'] is not None and len(opt['gpu_ids']) > 0:
+             if opt['dist']:
+                 self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
+             elif len(opt['gpu_ids']) > 1:
+                 self.netG = DataParallel(self.netG, opt['gpu_ids'])
+             else:
+                 self.netG.cuda()
+         # print network
+         # self.print_network()
+
+         if opt_get(opt, ['path', 'resume_state'], 1) is not None:
+             self.load()
+         else:
+             print("WARNING: skipping initial loading, due to resume_state None")
+
+         if self.is_train:
+             self.netG.train()
+
+             self.init_optimizer_and_scheduler(train_opt)
+             self.log_dict = OrderedDict()
+
+     def to(self, device):
+         self.device = device
+         self.netG.to(device)
+
+     def init_optimizer_and_scheduler(self, train_opt):
+         # optimizers
+         self.optimizers = []
+         wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+         if isinstance(wd_G, str): wd_G = eval(wd_G)
+         optim_params_RRDB = []
+         optim_params_other = []
+         for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+             # print(k, v.requires_grad)
+             if v.requires_grad:
+                 if '.RRDB.' in k:
+                     optim_params_RRDB.append(v)
+                     # print('opt', k)
+                 else:
+                     optim_params_other.append(v)
+                 # if self.rank <= 0:
+                 #     logger.warning('Params [{:s}] will not optimize.'.format(k))
+
+         print('rrdb params', len(optim_params_RRDB))
+
+         self.optimizer_G = torch.optim.Adam(
+             [
+                 {"params": optim_params_other, "lr": train_opt['lr_G'], 'beta1': train_opt['beta1'],
+                  'beta2': train_opt['beta2'], 'weight_decay': wd_G},
+                 {"params": optim_params_RRDB, "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
+                  'beta1': train_opt['beta1'],
+                  'beta2': train_opt['beta2'], 'weight_decay': 1e-5}
+             ]
+         )
+
+         self.scaler = GradScaler()
+
+         self.optimizers.append(self.optimizer_G)
+         # schedulers
+         if train_opt['lr_scheme'] == 'MultiStepLR':
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
+                                                      restarts=train_opt['restarts'],
+                                                      weights=train_opt['restart_weights'],
+                                                      gamma=train_opt['lr_gamma'],
+                                                      clear_state=train_opt['clear_state'],
+                                                      lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
+         elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
+             for optimizer in self.optimizers:
+                 self.schedulers.append(
+                     lr_scheduler.CosineAnnealingLR_Restart(
+                         optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
+                         restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
+         else:
+             raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
+
+     def add_optimizer_and_scheduler_RRDB(self, train_opt):
+         # optimizers
+         assert len(self.optimizers) == 1, self.optimizers
+         assert len(self.optimizer_G.param_groups[1]['params']) == 0, self.optimizer_G.param_groups[1]
+         for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+             if v.requires_grad:
+                 if '.RRDB.' in k:
+                     self.optimizer_G.param_groups[1]['params'].append(v)
+         assert len(self.optimizer_G.param_groups[1]['params']) > 0
+
+     def feed_data(self, data, need_GT=True):
+         self.var_L = data['LQ'].to(self.device)  # LQ
+         if need_GT:
+             self.real_H = data['GT'].to(self.device)  # GT
+
+     def get_module(self, model):
+         if isinstance(model, nn.DataParallel):
+             return model.module
+         else:
+             return model
+
+     def optimize_color_encoder(self, step):
+         self.netG.train()
+         self.log_dict = OrderedDict()
+         self.optimizer_G.zero_grad()
+         color_lr, color_gt = self.fake_H[(0, 0)], logdet = self.netG(lr=self.var_L, gt=self.real_H,
+                                                                      get_color_map=True)
+         losses = {}
+         total_loss = (color_gt - color_lr).abs().mean()
+         # try:
+         total_loss.backward()
+         self.optimizer_G.step()
+         mean = total_loss.item()
+         return mean
+
+     def optimize_parameters(self, step):
+         train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
+         if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
+                 and not self.get_module(self.netG).RRDB_training:
+             if self.get_module(self.netG).set_rrdb_training(True):
+                 self.add_optimizer_and_scheduler_RRDB(self.opt['train'])
+
+         # self.print_rrdb_state()
+
+         self.netG.train()
+         self.log_dict = OrderedDict()
+         self.optimizer_G.zero_grad()
+         # with autocast():
+         losses = {}
+         weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
+         weight_fl = 1 if weight_fl is None else weight_fl
+         weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
+         flow_warm_up_iter = opt_get(self.opt, ['train', 'flow_warm_up_iter'])
+         # print(step, flow_warm_up_iter)
+         if flow_warm_up_iter is not None:
+             if step > flow_warm_up_iter:
+                 weight_fl = 0
+             else:
+                 weight_l1 = 0
+         # print(weight_fl, weight_l1)
+         if weight_fl > 0:
+             if self.opt['optimize_all_z']:
+                 if self.opt['gpu_ids'] is not None and len(self.opt['gpu_ids']) > 0:
+                     epses = [[] for _ in range(len(self.opt['gpu_ids']))]
+                 else:
+                     epses = []
+             else:
+                 epses = None
+             z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False, epses=epses,
+                                          align_condition_feature=opt_get(self.opt,
+                                                                          ['align_condition_feature']) or False)
+             nll_loss = torch.mean(nll)
+             losses['nll_loss'] = nll_loss * weight_fl
+
+         if weight_l1 > 0:
+             z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
+             sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True, reverse_with_grad=True)
+             sr = sr.clamp(0, 1)
+             not_nan_mask = ~torch.isnan(sr)
+             sr[torch.isnan(sr)] = 0
+             l1_loss = ((sr - self.real_H) * not_nan_mask).abs().mean()
+             losses['l1_loss'] = l1_loss * weight_l1
+             if flow_warm_up_iter is not None:
+                 print(l1_loss, not_nan_mask.float().mean())
+         total_loss = sum(losses.values())
+         # try:
+         self.scaler.scale(total_loss).backward()
+         if not self.already_print_params_num:
+             logger.info("Parameters of full network %.4f and encoder %.4f" % (
+                 sum([m.numel() for m in self.netG.parameters() if m.grad is not None]) / 1e6,
+                 sum([m.numel() for m in self.netG.RRDB.parameters() if m.grad is not None]) / 1e6))
+             self.already_print_params_num = True
+         self.scaler.step(self.optimizer_G)
+         self.scaler.update()
+         # except Exception as e:
+         #     print(e)
+         #     print(total_loss)
+
+         mean = total_loss.item()
+         return mean
+
+     def print_rrdb_state(self):
+         for name, param in self.get_module(self.netG).named_parameters():
+             if "RRDB.conv_first.weight" in name:
+                 print(name, param.requires_grad, param.data.abs().sum())
+         print('params', [len(p['params']) for p in self.optimizer_G.param_groups])
+
+     def get_color_map(self):
+         self.netG.eval()
+         z = self.get_z(0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
+         with torch.no_grad():
+             color_lr, color_gt = self.fake_H[(0, 0)], logdet = self.netG(lr=self.var_L, gt=self.real_H,
+                                                                          get_color_map=True)
+         self.netG.train()
+         return color_lr, color_gt
+
+     def test(self):
+         self.netG.eval()
+         self.fake_H = {}
+         if self.heats is not None:
+             for heat in self.heats:
+                 for i in range(self.n_sample):
+                     z = self.get_z(heat, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
+                     with torch.no_grad():
+                         self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L, z=z, eps_std=heat, reverse=True)
+         else:
+             z = self.get_z(0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
+             with torch.no_grad():
+                 # torch.cuda.reset_peak_memory_stats()
+                 self.fake_H[(0, 0)], logdet = self.netG(lr=self.var_L, z=z.to(self.var_L.device), eps_std=0, reverse=True)
+                 # from thop import clever_format, profile
+                 # print(clever_format(profile(self.netG, (None, self.var_L, z.to(self.var_L.device), 0, True))), "%.4")
+                 # print(torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
+                 # import time
+                 # t = time.time()
+                 # for i in range(15):
+                 #     with torch.no_grad():
+                 #         self.fake_H[(0, 0)], logdet = self.netG(lr=self.var_L, z=z.to(self.var_L.device), eps_std=0, reverse=True)
+                 # print((time.time() - t) / 15)
+                 # with torch.no_grad():
+                 #     _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
+         self.netG.train()
+         return None
+         # return nll.mean().item()
+
+     def get_encode_nll(self, lq, gt):
+         self.netG.eval()
+         with torch.no_grad():
+             _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
+         self.netG.train()
+         return nll.mean().item()
+
+     def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
+         return self.get_sr_with_z(lq, heat, seed, z, epses)[0]
+
+     def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
+         self.netG.eval()
+         with torch.no_grad():
+             z, _, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise)
+         self.netG.train()
+         return z
+
+     def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
+         self.netG.eval()
+         with torch.no_grad():
+             z, nll, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise)
+         self.netG.train()
+         return z, nll
+
+     def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
+         self.netG.eval()
+         if heat is None:
+             heat = 0
+         z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z
+
+         with torch.no_grad():
+             sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses)
+         self.netG.train()
+         return sr, z
+
+     def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
+         if seed: torch.manual_seed(seed)
+         if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
+             C = self.get_module(self.netG).flowUpsamplerNet.C
+             H = int(self.opt['scale'] * lr_shape[2] // self.get_module(self.netG).flowUpsamplerNet.scaleH)
+             W = int(self.opt['scale'] * lr_shape[3] // self.get_module(self.netG).flowUpsamplerNet.scaleW)
+             z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)) if heat > 0 else torch.zeros(
+                 (batch_size, C, H, W))
+         else:
+             L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
+             fac = 2 ** L
+             H = int(self.opt['scale'] * lr_shape[2] // self.get_module(self.netG).flowUpsamplerNet.scaleH)
+             W = int(self.opt['scale'] * lr_shape[3] // self.get_module(self.netG).flowUpsamplerNet.scaleW)
+             size = (batch_size, 3 * fac * fac, H, W)
+             z = torch.normal(mean=0, std=heat, size=size) if heat > 0 else torch.zeros(size)
+         return z
+
+     def get_current_log(self):
+         return self.log_dict
+
+     def get_current_visuals(self, need_GT=True):
+         out_dict = OrderedDict()
+         out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
+         if self.heats is not None:
+             for heat in self.heats:
+                 for i in range(self.n_sample):
+                     out_dict[('NORMAL', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu()
+         else:
+             out_dict['NORMAL'] = self.fake_H[(0, 0)].detach()[0].float().cpu()
+         if need_GT:
+             out_dict['GT'] = self.real_H.detach()[0].float().cpu()
+         return out_dict
+
+     def print_network(self):
+         s, n = self.get_network_description(self.netG)
+         if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
+             net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                              self.netG.module.__class__.__name__)
+         else:
+             net_struc_str = '{}'.format(self.netG.__class__.__name__)
+         if self.rank <= 0:
+             logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+             logger.info(s)
+
+     def load(self):
+         _, get_resume_model_path = get_resume_paths(self.opt)
+         if get_resume_model_path is not None:
+             self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None)
+             return
+
+         load_path_G = self.opt['path']['pretrain_model_G']
+         load_submodule = self.opt['path']['load_submodule'] if 'load_submodule' in self.opt['path'].keys() else 'RRDB'
+         if load_path_G is not None:
+             logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
+             self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True),
+                               submodule=load_submodule)
+
+     def save(self, iter_label):
+         self.save_network(self.netG, 'G', iter_label)
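A worked example of the latent shape produced by `get_z` when `flow.split.enable` is false (the LOL_smallNet.yml case): with L = 3 flow levels the spatial dims shrink by fac = 2**L per side and channels grow to 3 * fac * fac. The `scaleH`/`scaleW` values below are assumptions standing in for `flowUpsamplerNet.scaleH`/`scaleW`, taken to equal fac for this config:

```python
# Hedged arithmetic sketch of get_z's tensor size for the no-split config.
scale, L, batch_size = 1, 3, 1
h, w = 400, 600                    # padded low-light input (H x W), hypothetical
fac = 2 ** L                       # 8
scaleH = scaleW = fac              # assumption: uniform squeeze when split is disabled
H, W = scale * h // scaleH, scale * w // scaleW
print((batch_size, 3 * fac * fac, H, W))   # (1, 192, 50, 75)
```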
models/llflow/models/__init__.py ADDED
@@ -0,0 +1,52 @@
+ import importlib
+ import logging
+ import os
+
+ try:
+     import local_config
+ except:
+     local_config = None
+
+
+ logger = logging.getLogger('base')
+
+
+ def find_model_using_name(model_name):
+     # Given the option --model [modelname],
+     # the file "models/modelname_model.py"
+     # will be imported.
+     model_filename = "models." + model_name + "_model"
+     modellib = importlib.import_module(model_filename)
+
+     # In the file, the class called ModelNameModel() will
+     # be instantiated. It has to be a subclass of torch.nn.Module,
+     # and it is case-insensitive.
+     model = None
+     target_model_name = model_name.replace('_', '') + 'Model'
+     for name, cls in modellib.__dict__.items():
+         if name.lower() == target_model_name.lower():
+             model = cls
+
+     if model is None:
+         print(
+             "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
+                 model_filename, target_model_name))
+         exit(0)
+
+     return model
+
+
+ def create_model(opt, step=0, **opt_kwargs):
+     if local_config is not None:
+         opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path, os.path.basename(opt['path']['results_root'] + '.pth'))
+
+     for k, v in opt_kwargs.items():
+         opt[k] = v
+
+     model = opt['model']
+
+     M = find_model_using_name(model)
+
+     m = M(opt, step)
+     logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
+     return m
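A small sketch of the naming convention `create_model` relies on: the config's `model: LLFlow` resolves to module `models.LLFlow_model` and, after dropping underscores and comparing case-insensitively, to class `LLFlowModel`:

```python
# Sketch of find_model_using_name's resolution rule (pure string logic).
model_name = "LLFlow"                                     # opt['model'] from the YAML
model_filename = "models." + model_name + "_model"        # module to import
target_model_name = model_name.replace('_', '') + 'Model' # class to match
print(model_filename, target_model_name)                  # models.LLFlow_model LLFlowModel
```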
models/llflow/models/__pycache__/LLFlow_model.cpython-310.pyc ADDED
Binary file (12.7 kB).
 
models/llflow/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.28 kB).
 
models/llflow/models/__pycache__/base_model.cpython-310.pyc ADDED
Binary file (6.54 kB).
 
models/llflow/models/__pycache__/lr_scheduler.cpython-310.pyc ADDED
Binary file (5.11 kB).
 
models/llflow/models/__pycache__/networks.cpython-310.pyc ADDED
Binary file (1.19 kB).
 
models/llflow/models/base_model.py ADDED
@@ -0,0 +1,145 @@
+
+
+
+ import os
+ from collections import OrderedDict
+ import torch
+ import torch.nn as nn
+ from torch.nn.parallel import DistributedDataParallel
+ import natsort
+ import glob
+
+
+ class BaseModel():
+     def __init__(self, opt):
+         self.opt = opt
+         self.device = torch.device('cuda' if opt.get('gpu_ids', None) is not None else 'cpu')
+         self.is_train = opt['is_train']
+         self.schedulers = []
+         self.optimizers = []
+         self.scaler = None
+
+     def feed_data(self, data):
+         pass
+
+     def optimize_parameters(self):
+         pass
+
+     def get_current_visuals(self):
+         pass
+
+     def get_current_losses(self):
+         pass
+
+     def print_network(self):
+         pass
+
+     def save(self, label):
+         pass
+
+     def load(self):
+         pass
+
+     def _set_lr(self, lr_groups_l):
+         ''' set learning rate for warmup,
+         lr_groups_l: list for lr_groups. each for a optimizer'''
+         for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+             for param_group, lr in zip(optimizer.param_groups, lr_groups):
+                 param_group['lr'] = lr
+
+     def _get_init_lr(self):
+         # get the initial lr, which is set by the scheduler
+         init_lr_groups_l = []
+         for optimizer in self.optimizers:
+             init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
+         return init_lr_groups_l
+
+     def update_learning_rate(self, cur_iter, warmup_iter=-1):
+         for scheduler in self.schedulers:
+             scheduler.step()
+         #### set up warm up learning rate
+         if cur_iter < warmup_iter:
+             # get initial lr for each group
+             init_lr_g_l = self._get_init_lr()
+             # modify warming-up learning rates
+             warm_up_lr_l = []
+             for init_lr_g in init_lr_g_l:
+                 warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
+             # set learning rate
+             self._set_lr(warm_up_lr_l)
+
+     def get_current_learning_rate(self):
+         # return self.schedulers[0].get_lr()[0]
+         return self.optimizers[0].param_groups[0]['lr']
+
+     def get_network_description(self, network):
+         '''Get the string and total parameters of the network'''
+         if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+             network = network.module
+         s = str(network)
+         n = sum(map(lambda x: x.numel(), network.parameters()))
+         return s, n
+
+     def save_network(self, network, network_label, iter_label):
+         paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['models'], "*_{}.pth".format(network_label))),
+                                   reverse=True)
+         paths = [p for p in paths if
+                  "latest_" not in p and not any([str(i * 10000) in p.split("/")[-1].split("_") for i in range(101)])]
+         if len(paths) > 2:
+             for path in paths[2:]:
+                 os.remove(path)
+         save_filename = '{}_{}.pth'.format(iter_label, network_label)
+         save_path = os.path.join(self.opt['path']['models'], save_filename)
+         if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+             network = network.module
+         state_dict = network.state_dict()
+         for key, param in state_dict.items():
+             state_dict[key] = param.cpu()
+         torch.save(state_dict, save_path)
+
+     def load_network(self, load_path, network, strict=True, submodule=None):
+         if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+             network = network.module
+         if not (submodule is None or submodule.lower() == 'none'.lower()):
+             network = network.__getattr__(submodule)
+         load_net = torch.load(load_path)
+         load_net_clean = OrderedDict()  # remove unnecessary 'module.'
+         for k, v in load_net.items():
+             if k.startswith('module.'):
+                 load_net_clean[k[7:]] = v
+             else:
+                 load_net_clean[k] = v
+         network.load_state_dict(load_net_clean, strict=strict)
+
+     def save_training_state(self, epoch, iter_step):
+         '''Saves training state during training, which will be used for resuming'''
+         state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': [], 'scaler': None}
+         for s in self.schedulers:
+             state['schedulers'].append(s.state_dict())
+         for o in self.optimizers:
+             state['optimizers'].append(o.state_dict())
+         state['scaler'] = self.scaler.state_dict()
+         save_filename = '{}.state'.format(iter_step)
+         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
+
+         paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['training_state'], "*.state")),
+                                   reverse=True)
+         paths = [p for p in paths if "latest_" not in p]
+         if len(paths) > 2:
+             for path in paths[2:]:
+                 os.remove(path)
+
+         torch.save(state, save_path)
+
+     def resume_training(self, resume_state):
+         '''Resume the optimizers and schedulers for training'''
+         resume_optimizers = resume_state['optimizers']
+         resume_schedulers = resume_state['schedulers']
+         resume_scaler = resume_state['scaler']
+         assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+         assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+         for i, o in enumerate(resume_optimizers):
+             self.optimizers[i].load_state_dict(o)
+         for i, s in enumerate(resume_schedulers):
+             self.schedulers[i].load_state_dict(s)
+         self.scaler.load_state_dict(resume_scaler)
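A small illustration of the `module.` cleanup in `load_network` above: checkpoints saved from a DataParallel-wrapped model prefix every key with "module.", which must be stripped before loading into a bare module (the keys here are hypothetical):

```python
# Sketch of the key-renaming step performed by BaseModel.load_network.
from collections import OrderedDict

load_net = OrderedDict([("module.conv_first.weight", 0), ("fea_conv.bias", 1)])
load_net_clean = OrderedDict(
    (k[7:] if k.startswith("module.") else k, v) for k, v in load_net.items()
)
print(list(load_net_clean))  # ['conv_first.weight', 'fea_conv.bias']
```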
models/llflow/models/lr_scheduler.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import math
+ from collections import Counter
+ from collections import defaultdict
+ import torch
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class MultiStepLR_Restart(_LRScheduler):
+     def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
+                  clear_state=False, last_epoch=-1, lr_steps_inverse=None):
+         assert lr_steps_inverse is not None, "lr_steps_inverse must be given (pass an empty list if unused)"
+         self.milestones = Counter(milestones)
+         self.lr_steps_inverse = Counter(lr_steps_inverse)
+         self.gamma = gamma
+         self.clear_state = clear_state
+         self.restarts = restarts if restarts else [0]
+         self.restart_weights = weights if weights else [1]
+         assert len(self.restarts) == len(
+             self.restart_weights), 'restarts and their weights do not match.'
+         super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch in self.restarts:
+             if self.clear_state:
+                 self.optimizer.state = defaultdict(dict)
+             weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+             return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+         if self.last_epoch not in self.milestones and self.last_epoch not in self.lr_steps_inverse:
+             return [group['lr'] for group in self.optimizer.param_groups]
+         return [
+             group['lr'] * (self.gamma ** self.milestones[self.last_epoch]) *
+             (self.gamma ** (-self.lr_steps_inverse[self.last_epoch]))
+             for group in self.optimizer.param_groups
+         ]
+
+
+ class CosineAnnealingLR_Restart(_LRScheduler):
+     def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
+         self.T_period = T_period
+         self.T_max = self.T_period[0]  # current T period
+         self.eta_min = eta_min
+         self.restarts = restarts if restarts else [0]
+         self.restart_weights = weights if weights else [1]
+         self.last_restart = 0
+         assert len(self.restarts) == len(
+             self.restart_weights), 'restarts and their weights do not match.'
+         super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch == 0:
+             return self.base_lrs
+         elif self.last_epoch in self.restarts:
+             self.last_restart = self.last_epoch
+             self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
+             weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+             return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+         elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
+             return [
+                 group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
+                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+             ]
+         return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
+                 (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
+                 (group['lr'] - self.eta_min) + self.eta_min
+                 for group in self.optimizer.param_groups]
+
+
+ if __name__ == "__main__":
+     optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
+                                  betas=(0.9, 0.99))
+     ##############################
+     # MultiStepLR_Restart
+     ##############################
+     ## Original
+     lr_steps = [200000, 400000, 600000, 800000]
+     restarts = None
+     restart_weights = None
+
+     ## two
+     lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
+     restarts = [500000]
+     restart_weights = [1]
+
+     ## four
+     lr_steps = [
+         50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
+         600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
+     ]
+     restarts = [250000, 500000, 750000]
+     restart_weights = [1, 1, 1]
+
+     scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
+                                     clear_state=False, lr_steps_inverse=[])
+
+     ##############################
+     # Cosine Annealing Restart
+     ##############################
+     ## two
+     T_period = [500000, 500000]
+     restarts = [500000]
+     restart_weights = [1]
+
+     ## four
+     T_period = [250000, 250000, 250000, 250000]
+     restarts = [250000, 500000, 750000]
+     restart_weights = [1, 1, 1]
+
+     scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
+                                           weights=restart_weights)
+
+     ##############################
+     # Draw figure
+     ##############################
+     N_iter = 1000000
+     lr_l = list(range(N_iter))
+     for i in range(N_iter):
+         scheduler.step()
+         current_lr = optimizer.param_groups[0]['lr']
+         lr_l[i] = current_lr
+
+     import matplotlib as mpl
+     from matplotlib import pyplot as plt
+     import matplotlib.ticker as mtick
+
+     mpl.style.use('default')
+     import seaborn
+
+     seaborn.set(style='whitegrid')
+     seaborn.set_context('paper')
+
+     plt.figure(1)
+     plt.subplot(111)
+     plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
+     plt.title('Title', fontsize=16, color='k')
+     plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
+     legend = plt.legend(loc='upper right', shadow=False)
+     ax = plt.gca()
+     labels = ax.get_xticks().tolist()
+     for k, v in enumerate(labels):
+         labels[k] = str(int(v / 1000)) + 'K'
+     ax.set_xticklabels(labels)
+     ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
+
+     ax.set_ylabel('Learning rate')
+     ax.set_xlabel('Iteration')
+     fig = plt.gcf()
+     plt.show()
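Outside the `__main__` demo above, a small usage sketch of `CosineAnnealingLR_Restart`, assuming the llflow directory is on the import path and using a toy parameter:

import torch
from models.lr_scheduler import CosineAnnealingLR_Restart

param = torch.nn.Parameter(torch.zeros(3, 64, 3, 3))
optimizer = torch.optim.Adam([param], lr=2e-4)
# two cosine cycles of 500k iterations; the restart resets the LR at full weight
scheduler = CosineAnnealingLR_Restart(optimizer, T_period=[500000, 500000],
                                      restarts=[500000], weights=[1], eta_min=1e-7)
for step in range(1000):
    optimizer.step()
    scheduler.step()  # advances last_epoch and updates each param group's lr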
models/llflow/models/modules/ConditionEncoder.py ADDED
@@ -0,0 +1,287 @@
+ from torchvision.utils import save_image
+ import functools
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import models.modules.module_util as mutil
+ # from utils.util import opt_get
+ from models.modules.flow import Conv2dZeros
+
+
+ def opt_get(opt, keys, default=None):
+     if opt is None:
+         return default
+     ret = opt
+     for k in keys:
+         ret = ret.get(k, None)
+         if ret is None:
+             return default
+     return ret
+
+
+ class ResidualDenseBlock_5C(nn.Module):
+     def __init__(self, nf=64, gc=32, bias=True):
+         super(ResidualDenseBlock_5C, self).__init__()
+         # gc: growth channel, i.e. intermediate channels
+         self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+         self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+         self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         return x5 * 0.2 + x
+         # gamma = torch.sigmoid(self.conv5(torch.cat((x, x1, x2, x3, x4), 1)))
+         # x = torch.sigmoid(x)
+         # return x + gamma * x * (1 - x)
+
+
+ class RRDB(nn.Module):
+     '''Residual in Residual Dense Block'''
+
+     def __init__(self, nf, gc=32):
+         super(RRDB, self).__init__()
+         self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+         self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+         self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+     def forward(self, x):
+         out = self.RDB1(x)
+         out = self.RDB2(out)
+         out = self.RDB3(out)
+         return out * 0.2 + x
+
+
+ class ConEncoder1(nn.Module):
+     def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
+         self.opt = opt
+         self.gray_map_bool = False
+         self.concat_color_map = False
+         if opt['concat_histeq']:
+             in_nc = in_nc + 3
+         if opt['concat_color_map']:
+             in_nc = in_nc + 3
+             self.concat_color_map = True
+         if opt['gray_map']:
+             in_nc = in_nc + 1
+             self.gray_map_bool = True
+         in_nc = in_nc + 6
+         super(ConEncoder1, self).__init__()
+         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+         self.scale = scale
+
+         self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+         self.conv_second = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
+         self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         #### downsampling
+         self.downconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.downconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.downconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         # self.downconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+
+         self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         self.awb_para = nn.Linear(nf, 3)
+         self.fine_tune_color_map = nn.Sequential(nn.Conv2d(nf, 3, 1, 1), nn.Sigmoid())
+
+     def forward(self, x, get_steps=False):
+         if self.gray_map_bool:
+             x = torch.cat([x, 1 - x.mean(dim=1, keepdim=True)], dim=1)
+         if self.concat_color_map:
+             x = torch.cat([x, x / (x.sum(dim=1, keepdim=True) + 1e-4)], dim=1)
+
+         raw_low_input = x[:, 0:3].exp()
+         # fea_for_awb = F.adaptive_avg_pool2d(fea_down8, 1).view(-1, 64)
+         awb_weight = 1  # (1 + self.awb_para(fea_for_awb).unsqueeze(2).unsqueeze(3))
+         low_after_awb = raw_low_input * awb_weight
+         # import pdb
+         # pdb.set_trace()
+         color_map = low_after_awb / (low_after_awb.sum(dim=1, keepdims=True) + 1e-4)
+         dx, dy = self.gradient(color_map)
+         noise_map = torch.max(torch.stack([dx.abs(), dy.abs()], dim=0), dim=0)[0]
+         # color_map = self.fine_tune_color_map(torch.cat([color_map, noise_map], dim=1))
+
+         fea = self.conv_first(torch.cat([x, color_map, noise_map], dim=1))
+         fea = self.lrelu(fea)
+         fea = self.conv_second(fea)
+         fea_head = F.max_pool2d(fea, 2)
+
+         block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+         block_results = {}
+         fea = fea_head
+         for idx, m in enumerate(self.RRDB_trunk.children()):
+             fea = m(fea)
+             for b in block_idxs:
+                 if b == idx:
+                     block_results["block_{}".format(idx)] = fea
+         trunk = self.trunk_conv(fea)
+         # fea = F.max_pool2d(fea, 2)
+         fea_down2 = fea_head + trunk
+
+         fea_down4 = self.downconv1(F.interpolate(fea_down2, scale_factor=1 / 2, mode='bilinear', align_corners=False,
+                                                  recompute_scale_factor=True))
+         fea = self.lrelu(fea_down4)
+
+         fea_down8 = self.downconv2(
+             F.interpolate(fea, scale_factor=1 / 2, mode='bilinear', align_corners=False, recompute_scale_factor=True))
+         # fea = self.lrelu(fea_down8)
+
+         # fea_down16 = self.downconv3(
+         #     F.interpolate(fea, scale_factor=1 / 2, mode='bilinear', align_corners=False, recompute_scale_factor=True))
+         # fea = self.lrelu(fea_down16)
+
+         results = {'fea_up0': fea_down8,
+                    'fea_up1': fea_down4,
+                    'fea_up2': fea_down2,
+                    'fea_up4': fea_head,
+                    'last_lr_fea': fea_down4,
+                    'color_map': self.fine_tune_color_map(F.interpolate(fea_down2, scale_factor=2))
+                    }
+
+         # 'color_map': color_map} # raw
+
+         if get_steps:
+             for k, v in block_results.items():
+                 results[k] = v
+             return results
+         else:
+             return None
+
+     def gradient(self, x):
+         def sub_gradient(x):
+             left_shift_x, right_shift_x, grad = torch.zeros_like(
+                 x), torch.zeros_like(x), torch.zeros_like(x)
+             left_shift_x[:, :, 0:-1] = x[:, :, 1:]
+             right_shift_x[:, :, 1:] = x[:, :, 0:-1]
+             grad = 0.5 * (left_shift_x - right_shift_x)
+             return grad
+
+         return sub_gradient(x), sub_gradient(torch.transpose(x, 2, 3)).transpose(2, 3)
+
+
+ class NoEncoder(nn.Module):
+     def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
+         self.opt = opt
+         self.gray_map_bool = False
+         self.concat_color_map = False
+         if opt['concat_histeq']:
+             in_nc = in_nc + 3
+         if opt['concat_color_map']:
+             in_nc = in_nc + 3
+             self.concat_color_map = True
+         if opt['gray_map']:
+             in_nc = in_nc + 1
+             self.gray_map_bool = True
+         in_nc = in_nc + 6
+         super(NoEncoder, self).__init__()
+         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+         self.scale = scale
+
+         self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+         self.conv_second = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
+         self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         #### downsampling
+         self.downconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.downconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.downconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         # self.downconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+
+         self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         self.awb_para = nn.Linear(nf, 3)
+         self.fine_tune_color_map = nn.Sequential(nn.Conv2d(nf, 3, 1, 1), nn.Sigmoid())
+
+     def forward(self, x, get_steps=False):
+         if self.gray_map_bool:
+             x = torch.cat([x, 1 - x.mean(dim=1, keepdim=True)], dim=1)
+         if self.concat_color_map:
+             x = torch.cat([x, x / (x.sum(dim=1, keepdim=True) + 1e-4)], dim=1)
+
+         raw_low_input = x[:, 0:3].exp()
+         # fea_for_awb = F.adaptive_avg_pool2d(fea_down8, 1).view(-1, 64)
+         awb_weight = 1  # (1 + self.awb_para(fea_for_awb).unsqueeze(2).unsqueeze(3))
+         low_after_awb = raw_low_input * awb_weight
+         # import pdb
+         # pdb.set_trace()
+         color_map = low_after_awb / (low_after_awb.sum(dim=1, keepdims=True) + 1e-4)
+         dx, dy = self.gradient(color_map)
+         noise_map = torch.max(torch.stack([dx.abs(), dy.abs()], dim=0), dim=0)[0]
+         # color_map = self.fine_tune_color_map(torch.cat([color_map, noise_map], dim=1))
+
+         fea = self.conv_first(torch.cat([x, color_map, noise_map], dim=1))
+         fea = self.lrelu(fea)
+         fea = self.conv_second(fea)
+         fea_head = F.max_pool2d(fea, 2)
+
+         block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+         block_results = {}
+         fea = fea_head
+         for idx, m in enumerate(self.RRDB_trunk.children()):
+             fea = m(fea)
+             for b in block_idxs:
+                 if b == idx:
+                     block_results["block_{}".format(idx)] = fea
+         trunk = self.trunk_conv(fea)
+         # fea = F.max_pool2d(fea, 2)
+         fea_down2 = fea_head + trunk
+
+         fea_down4 = self.downconv1(F.interpolate(fea_down2, scale_factor=1 / 2, mode='bilinear', align_corners=False,
+                                                  recompute_scale_factor=True))
+         fea = self.lrelu(fea_down4)
+
+         fea_down8 = self.downconv2(
+             F.interpolate(fea, scale_factor=1 / 2, mode='bilinear', align_corners=False, recompute_scale_factor=True))
+         # fea = self.lrelu(fea_down8)
+
+         # fea_down16 = self.downconv3(
+         #     F.interpolate(fea, scale_factor=1 / 2, mode='bilinear', align_corners=False, recompute_scale_factor=True))
+         # fea = self.lrelu(fea_down16)
+
+         results = {'fea_up0': fea_down8 * 0,
+                    'fea_up1': fea_down4 * 0,
+                    'fea_up2': fea_down2 * 0,
+                    'fea_up4': fea_head * 0,
+                    'last_lr_fea': fea_down4 * 0,
+                    'color_map': self.fine_tune_color_map(F.interpolate(fea_down2, scale_factor=2)) * 0
+                    }
+
+         # 'color_map': color_map} # raw
+
+         if get_steps:
+             for k, v in block_results.items():
+                 results[k] = v
+             return results
+         else:
+             return None
+
+     def gradient(self, x):
+         def sub_gradient(x):
+             left_shift_x, right_shift_x, grad = torch.zeros_like(
+                 x), torch.zeros_like(x), torch.zeros_like(x)
+             left_shift_x[:, :, 0:-1] = x[:, :, 1:]
+             right_shift_x[:, :, 1:] = x[:, :, 0:-1]
+             grad = 0.5 * (left_shift_x - right_shift_x)
+             return grad
+
+         return sub_gradient(x), sub_gradient(torch.transpose(x, 2, 3)).transpose(2, 3)
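The `gradient` helper above is a central difference built from two shifted copies of the tensor, and the noise map is the element-wise max of |dx| and |dy|. A self-contained sketch of that computation with made-up data:

import torch

def sub_gradient(x):
    left, right = torch.zeros_like(x), torch.zeros_like(x)
    left[:, :, 0:-1] = x[:, :, 1:]    # shift by -1 along H
    right[:, :, 1:] = x[:, :, 0:-1]   # shift by +1 along H
    return 0.5 * (left - right)       # zero at the borders

color_map = torch.rand(1, 3, 8, 8)
dx = sub_gradient(color_map)
dy = sub_gradient(color_map.transpose(2, 3)).transpose(2, 3)
noise_map = torch.max(torch.stack([dx.abs(), dy.abs()], dim=0), dim=0)[0]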
models/llflow/models/modules/FlowActNorms.py ADDED
@@ -0,0 +1,128 @@
+ import torch
+ from torch import nn as nn
+
+ from models.modules import thops
+
+
+ class _ActNorm(nn.Module):
+     """
+     Activation Normalization.
+     Initializes the bias and scale with a given minibatch,
+     so that each output channel has zero mean and unit variance for that minibatch.
+
+     After initialization, `bias` and `logs` will be trained as parameters.
+     """
+
+     def __init__(self, num_features, scale=1.):
+         super().__init__()
+         # register mean and scale
+         size = [1, num_features, 1, 1]
+         self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
+         self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
+         self.num_features = num_features
+         self.scale = float(scale)
+         self.inited = False
+
+     def _check_input_dim(self, input):
+         raise NotImplementedError
+
+     def initialize_parameters(self, input):
+         self._check_input_dim(input)
+         if not self.training:
+             return
+         if (self.bias != 0).any():
+             self.inited = True
+             return
+         assert input.device == self.bias.device, (input.device, self.bias.device)
+         with torch.no_grad():
+             bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
+             vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
+             logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
+             self.bias.data.copy_(bias.data)
+             self.logs.data.copy_(logs.data)
+             self.inited = True
+
+     def _center(self, input, reverse=False, offset=None):
+         bias = self.bias
+
+         if offset is not None:
+             bias = bias + offset
+
+         if not reverse:
+             return input + bias
+         else:
+             return input - bias
+
+     def _scale(self, input, logdet=None, reverse=False, offset=None):
+         logs = self.logs
+
+         if offset is not None:
+             logs = logs + offset
+
+         if not reverse:
+             input = input * torch.exp(logs)  # should have shape batchsize, n_channels, 1, 1
+             # input = input * torch.exp(logs + logs_offset)
+         else:
+             input = input * torch.exp(-logs)
+         if logdet is not None:
+             """
+             logs is the log_std of the `mean of channels`,
+             so we need to multiply by the number of pixels
+             """
+             dlogdet = thops.sum(logs) * thops.pixels(input)
+             if reverse:
+                 dlogdet *= -1
+             logdet = logdet + dlogdet
+         return input, logdet
+
+     def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
+         if not self.inited:
+             self.initialize_parameters(input)
+         self._check_input_dim(input)
+
+         if offset_mask is not None:
+             logs_offset *= offset_mask
+             bias_offset *= offset_mask
+         # no need to permute dims as in the old version
+         if not reverse:
+             # center and scale
+
+             # self.input = input
+             input = self._center(input, reverse, bias_offset)
+             input, logdet = self._scale(input, logdet, reverse, logs_offset)
+         else:
+             # scale and center
+             input, logdet = self._scale(input, logdet, reverse, logs_offset)
+             input = self._center(input, reverse, bias_offset)
+         return input, logdet
+
+
+ class ActNorm2d(_ActNorm):
+     def __init__(self, num_features, scale=1.):
+         super().__init__(num_features, scale)
+
+     def _check_input_dim(self, input):
+         assert len(input.size()) == 4
+         assert input.size(1) == self.num_features, (
+             "[ActNorm]: input should be in shape as `BCHW`,"
+             " channels should be {} rather than {}".format(
+                 self.num_features, input.size()))
+
+
+ class MaskedActNorm2d(ActNorm2d):
+     def __init__(self, num_features, scale=1.):
+         super().__init__(num_features, scale)
+
+     def forward(self, input, mask, logdet=None, reverse=False):
+
+         assert mask.dtype == torch.bool
+         output, logdet_out = super().forward(input, logdet, reverse)
+
+         input[mask] = output[mask]
+         logdet[mask] = logdet_out[mask]
+
+         return input, logdet
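`ActNorm2d` is data-dependent: the first forward call in training mode sets `bias`/`logs` from the batch so the output is roughly zero-mean, unit-variance per channel. A quick sketch, assuming the llflow models/ tree is importable:

import torch
from models.modules.FlowActNorms import ActNorm2d

actnorm = ActNorm2d(num_features=12)
x = torch.randn(4, 12, 20, 20) * 3 + 1         # arbitrary input statistics
y, logdet = actnorm(x, logdet=torch.zeros(4))  # first call initializes bias/logs
print(y.mean().item(), y.std().item())         # close to 0 and 1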
models/llflow/models/modules/FlowAffineCouplingsAblation.py ADDED
@@ -0,0 +1,169 @@
+ import torch
+ from torch import nn as nn
+
+ from models.modules import thops
+ from models.modules.flow import Conv2d, Conv2dZeros
+ # from utils.util import opt_get
+
+
+ def opt_get(opt, keys, default=None):
+     if opt is None:
+         return default
+     ret = opt
+     for k in keys:
+         ret = ret.get(k, None)
+         if ret is None:
+             return default
+     return ret
+
+
+ class CondAffineSeparatedAndCond(nn.Module):
+     def __init__(self, in_channels, opt):
+         super().__init__()
+         self.need_features = True
+         self.in_channels = in_channels
+         self.in_channels_rrdb = opt_get(opt, ['network_G', 'flow', 'conditionInFeaDim'], 320)
+         self.kernel_hidden = 1
+         self.affine_eps = 0.0001
+         self.n_hidden_layers = 1
+         hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
+         self.hidden_channels = 64 if hidden_channels is None else hidden_channels
+
+         self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
+
+         self.channels_for_nn = self.in_channels // 2
+         self.channels_for_co = self.in_channels - self.channels_for_nn
+
+         if self.channels_for_nn is None:
+             self.channels_for_nn = self.in_channels // 2
+
+         self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
+                               out_channels=self.channels_for_co * 2,
+                               hidden_channels=self.hidden_channels,
+                               kernel_hidden=self.kernel_hidden,
+                               n_hidden_layers=self.n_hidden_layers)
+
+         self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
+                                 out_channels=self.in_channels * 2,
+                                 hidden_channels=self.hidden_channels,
+                                 kernel_hidden=self.kernel_hidden,
+                                 n_hidden_layers=self.n_hidden_layers)
+         self.opt = opt
+         self.le_curve = opt['le_curve'] if opt['le_curve'] is not None else False
+         if self.le_curve:
+             self.fCurve = self.F(in_channels=self.in_channels_rrdb,
+                                  out_channels=self.in_channels,
+                                  hidden_channels=self.hidden_channels,
+                                  kernel_hidden=self.kernel_hidden,
+                                  n_hidden_layers=self.n_hidden_layers)
+
+     def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
+         if not reverse:
+             z = input
+             assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)
+
+             # Feature Conditional
+             scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
+             z = z + shiftFt
+             z = z * scaleFt
+             logdet = logdet + self.get_logdet(scaleFt)
+
+             # Curve conditional
+             if self.le_curve:
+                 # logdet = logdet + thops.sum(torch.log(torch.sigmoid(z) * (1 - torch.sigmoid(z))), dim=[1, 2, 3])
+                 # z = torch.sigmoid(z)
+                 # alpha = self.fCurve(ft)
+                 # alpha = (torch.tanh(alpha + 2.) + self.affine_eps)
+                 # logdet = logdet + thops.sum(torch.log((1 + alpha - 2 * z * alpha).abs()), dim=[1, 2, 3])
+                 # z = z + alpha * z * (1 - z)
+
+                 alpha = self.fCurve(ft)
+                 # alpha = (torch.sigmoid(alpha + 2.) + self.affine_eps)
+                 alpha = torch.relu(alpha) + self.affine_eps
+                 logdet = logdet + thops.sum(torch.log(alpha * torch.pow(z.abs(), alpha - 1)) + self.affine_eps)
+                 z = torch.pow(z.abs(), alpha) * z.sign()
+
+             # Self Conditional
+             z1, z2 = self.split(z)
+             scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
+             self.asserts(scale, shift, z1, z2)
+             z2 = z2 + shift
+             z2 = z2 * scale
+
+             logdet = logdet + self.get_logdet(scale)
+             z = thops.cat_feature(z1, z2)
+             output = z
+         else:
+             z = input
+
+             # Self Conditional
+             z1, z2 = self.split(z)
+             scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
+             self.asserts(scale, shift, z1, z2)
+             z2 = z2 / scale
+             z2 = z2 - shift
+             z = thops.cat_feature(z1, z2)
+             logdet = logdet - self.get_logdet(scale)
+
+             # Curve conditional
+             if self.le_curve:
+                 # alpha = self.fCurve(ft)
+                 # alpha = (torch.sigmoid(alpha + 2.) + self.affine_eps)
+                 # z = (1 + alpha) / alpha - (
+                 #         alpha + torch.pow(2 * alpha - 4 * alpha * z + torch.pow(alpha, 2) + 1, 0.5) + 1) / (
+                 #             2 * alpha)
+                 # z = torch.log((z / (1 - z)).clamp(1 / 1000, 1000))
+
+                 alpha = self.fCurve(ft)
+                 alpha = torch.relu(alpha) + self.affine_eps
+                 # alpha = (torch.sigmoid(alpha + 2.) + self.affine_eps)
+                 z = torch.pow(z.abs(), 1 / alpha) * z.sign()
+
+             # Feature Conditional
+             scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
+             z = z / scaleFt
+             z = z - shiftFt
+             logdet = logdet - self.get_logdet(scaleFt)
+
+             output = z
+         return output, logdet
+
+     def asserts(self, scale, shift, z1, z2):
+         assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn)
+         assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co)
+         assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1])
+         assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1])
+
+     def get_logdet(self, scale):
+         return thops.sum(torch.log(scale), dim=[1, 2, 3])
+
+     def feature_extract(self, z, f):
+         h = f(z)
+         shift, scale = thops.split_feature(h, "cross")
+         scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
+         return scale, shift
+
+     def feature_extract_aff(self, z1, ft, f):
+         z = torch.cat([z1, ft], dim=1)
+         h = f(z)
+         shift, scale = thops.split_feature(h, "cross")
+         scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
+         return scale, shift
+
+     def split(self, z):
+         z1 = z[:, :self.channels_for_nn]
+         z2 = z[:, self.channels_for_nn:]
+         assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1])
+         return z1, z2
+
+     def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1):
+         layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)]
+
+         for _ in range(n_hidden_layers):
+             layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]))
+             layers.append(nn.ReLU(inplace=False))
+         layers.append(Conv2dZeros(hidden_channels, out_channels))
+
+         return nn.Sequential(*layers)
models/llflow/models/modules/FlowStep.py ADDED
@@ -0,0 +1,136 @@
+ import torch
+ from torch import nn as nn
+
+ import models.modules
+ import models.modules.Permutations
+ import models.modules.FlowActNorms
+ from models.modules import flow, thops, FlowAffineCouplingsAblation
+ # from utils.util import opt_get
+
+
+ def opt_get(opt, keys, default=None):
+     if opt is None:
+         return default
+     ret = opt
+     for k in keys:
+         ret = ret.get(k, None)
+         if ret is None:
+             return default
+     return ret
+
+
+ def getConditional(rrdbResults, position):
+     img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position]
+     return img_ft
+
+
+ class FlowStep(nn.Module):
+     FlowPermutation = {
+         "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
+         "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
+         "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
+         "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
+         "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
+         "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
+         "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
+         "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
+         "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
+         "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
+     }
+
+     def __init__(self, in_channels, hidden_channels,
+                  actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
+                  LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
+                  position=None):
+         # check configuration
+         assert flow_permutation in FlowStep.FlowPermutation, \
+             "flow_permutation should be in `{}`".format(
+                 FlowStep.FlowPermutation.keys())
+         super().__init__()
+         self.flow_permutation = flow_permutation
+         self.flow_coupling = flow_coupling
+         self.image_injector = image_injector
+
+         self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d'
+         self.position = normOpt['position'] if normOpt else None
+
+         self.in_shape = in_shape
+         self.position = position
+         self.acOpt = acOpt
+
+         # 1. actnorm
+         self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
+
+         # 2. permute
+         if flow_permutation == "invconv":
+             self.invconv = models.modules.Permutations.InvertibleConv1x1(
+                 in_channels, LU_decomposed=LU_decomposed)
+
+         # 3. coupling
+         if flow_coupling == "CondAffineSeparatedAndCond":
+             self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
+                                                                                                 opt=opt)
+         elif flow_coupling == "noCoupling":
+             pass
+         else:
+             raise RuntimeError("coupling not found:", flow_coupling)
+
+     def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
+         if not reverse:
+             return self.normal_flow(input, logdet, rrdbResults)
+         else:
+             return self.reverse_flow(input, logdet, rrdbResults)
+
+     def normal_flow(self, z, logdet, rrdbResults=None):
+         if self.flow_coupling == "bentIdentityPreAct":
+             z, logdet = self.bentIdentPar(z, logdet, reverse=False)
+
+         # 1. actnorm
+         if self.norm_type == "ConditionalActNormImageInjector":
+             img_ft = getConditional(rrdbResults, self.position)
+             z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False)
+         elif self.norm_type == "noNorm":
+             pass
+         else:
+             z, logdet = self.actnorm(z, logdet=logdet, reverse=False)
+
+         # 2. permute
+         z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
+             self, z, logdet, False)
+
+         need_features = self.affine_need_features()
+
+         # 3. coupling
+         if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
+             img_ft = getConditional(rrdbResults, self.position)
+             z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)
+         return z, logdet
+
+     def reverse_flow(self, z, logdet, rrdbResults=None):
+
+         need_features = self.affine_need_features()
+
+         # 1. coupling
+         if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
+             img_ft = getConditional(rrdbResults, self.position)
+             z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)
+
+         # 2. permute
+         z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
+             self, z, logdet, True)
+
+         # 3. actnorm
+         z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
+
+         return z, logdet
+
+     def affine_need_features(self):
+         need_features = False
+         try:
+             need_features = self.affine.need_features
+         except AttributeError:
+             pass
+         return need_features
models/llflow/models/modules/FlowUpsamplerNet.py ADDED
@@ -0,0 +1,328 @@
+ import numpy as np
+ import torch
+ from torch import nn as nn
+
+ import models.modules.Split
+ from models.modules import flow, thops
+ from models.modules.Split import Split2d
+ from models.modules.glow_arch import f_conv2d_bias
+ from models.modules.FlowStep import FlowStep
+ # from utils.util import opt_get
+
+
+ def opt_get(opt, keys, default=None):
+     if opt is None:
+         return default
+     ret = opt
+     for k in keys:
+         ret = ret.get(k, None)
+         if ret is None:
+             return default
+     return ret
+
+
+ class FlowUpsamplerNet(nn.Module):
+     def __init__(self, image_shape, hidden_channels, K, L=None,
+                  actnorm_scale=1.0,
+                  flow_permutation=None,
+                  flow_coupling="affine",
+                  LU_decomposed=False, opt=None):
+
+         super().__init__()
+         self.hr_size = opt['datasets']['train']['GT_size']
+         self.layers = nn.ModuleList()
+         self.output_shapes = []
+         self.sigmoid_output = opt['sigmoid_output'] if opt['sigmoid_output'] is not None else False
+         self.L = opt_get(opt, ['network_G', 'flow', 'L'])
+         self.K = opt_get(opt, ['network_G', 'flow', 'K'])
+         if isinstance(self.K, int):
+             self.K = [K] * (self.L + 1)
+
+         self.opt = opt
+         H, W, self.C = image_shape
+         self.check_image_shape()
+
+         if opt['scale'] == 16:
+             self.levelToName = {
+                 0: 'fea_up16',
+                 1: 'fea_up8',
+                 2: 'fea_up4',
+                 3: 'fea_up2',
+                 4: 'fea_up1',
+             }
+
+         if opt['scale'] == 8:
+             self.levelToName = {
+                 0: 'fea_up8',
+                 1: 'fea_up4',
+                 2: 'fea_up2',
+                 3: 'fea_up1',
+                 4: 'fea_up0'
+             }
+
+         elif opt['scale'] == 4:
+             self.levelToName = {
+                 0: 'fea_up4',
+                 1: 'fea_up2',
+                 2: 'fea_up1',
+                 3: 'fea_up0',
+                 4: 'fea_up-1'
+             }
+         elif opt['scale'] == 1:
+             self.levelToName = {
+                 # 0: 'fea_up4',
+                 1: 'fea_up2',
+                 2: 'fea_up1',
+                 3: 'fea_up0',
+                 # 4: 'fea_up-1'
+             }
+
+         affineInCh = self.get_affineInCh(opt_get)
+         flow_permutation = self.get_flow_permutation(flow_permutation, opt)
+
+         normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
+
+         conditional_channels = {}
+         n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
+         n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
+         conditional_channels[0] = n_rrdb
+         for level in range(1, self.L + 1):
+             # Level 1 gets conditionals from 2, 3, 4 => L - level
+             # Level 2 gets conditionals from 3, 4
+             # Level 3 gets conditionals from 4
+             # Level 4 gets conditionals from None
+             n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
+             conditional_channels[level] = n_rrdb + n_bypass
+
+         # Upsampler
+         for level in range(1, self.L + 1):
+             # 1. Squeeze
+             H, W = self.arch_squeeze(H, W)
+
+             # 2. K FlowStep
+             self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
+             self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
+                                flow_permutation,
+                                hidden_channels, normOpt, opt, opt_get,
+                                n_conditional_channels=conditional_channels[level])
+             # Split
+             self.arch_split(H, W, level, self.L, opt, opt_get)
+
+         if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
+             self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
+         else:
+             self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
+
+         self.H = H
+         self.W = W
+         self.scaleH = opt['datasets']['train']['GT_size'] / H
+         self.scaleW = opt['datasets']['train']['GT_size'] / W
+
+     def get_n_rrdb_channels(self, opt, opt_get):
+         blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
+         n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
+         return n_rrdb
+
+     def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
+                       hidden_channels, normOpt, opt, opt_get, n_conditional_channels=None):
+         condAff = self.get_condAffSetting(opt, opt_get)
+         if condAff is not None:
+             condAff['in_channels_rrdb'] = n_conditional_channels
+
+         for k in range(K):
+             position_name = get_position_name(H, self.opt['scale'], opt)
+             if normOpt:
+                 normOpt['position'] = position_name
+
+             self.layers.append(
+                 FlowStep(in_channels=self.C,
+                          hidden_channels=hidden_channels,
+                          actnorm_scale=actnorm_scale,
+                          flow_permutation=flow_permutation,
+                          flow_coupling=flow_coupling,
+                          acOpt=condAff,
+                          position=position_name,
+                          LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt))
+             self.output_shapes.append(
+                 [-1, self.C, H, W])
+
+     def get_condAffSetting(self, opt, opt_get):
+         condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
+         condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
+         return condAff
+
+     def arch_split(self, H, W, L, levels, opt, opt_get):
+         correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
+         correction = 0 if correct_splits else 1
+         if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
+             logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
+             consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
+             position_name = get_position_name(H, self.opt['scale'], opt)
+             position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
+             cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
+             cond_channels = 0 if cond_channels is None else cond_channels
+
+             t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
+
+             if t == 'Split2d':
+                 split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
+                                                      cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
+                 self.layers.append(split)
+                 self.output_shapes.append([-1, split.num_channels_pass, H, W])
+                 self.C = split.num_channels_pass
+
+     def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
+         if 'additionalFlowNoAffine' in opt['network_G']['flow']:
+             n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
+             for _ in range(n_additionalFlowNoAffine):
+                 self.layers.append(
+                     FlowStep(in_channels=self.C,
+                              hidden_channels=hidden_channels,
+                              actnorm_scale=actnorm_scale,
+                              flow_permutation='invconv',
+                              flow_coupling='noCoupling',
+                              LU_decomposed=LU_decomposed, opt=opt))
+                 self.output_shapes.append(
+                     [-1, self.C, H, W])
+
+     def arch_squeeze(self, H, W):
+         self.C, H, W = self.C * 4, H // 2, W // 2
+         self.layers.append(flow.SqueezeLayer(factor=2))
+         self.output_shapes.append([-1, self.C, H, W])
+         return H, W
+
+     def get_flow_permutation(self, flow_permutation, opt):
+         flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
+         return flow_permutation
+
+     def get_affineInCh(self, opt_get):
+         affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+         affineInCh = (len(affineInCh) + 1) * 64
+         return affineInCh
+
+     def check_image_shape(self):
+         assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3); "
+                                             "self.C == 1 or self.C == 3")
+
+     def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
+                 y_onehot=None):
+
+         if reverse:
+             epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses
+
+             sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
+             if self.sigmoid_output:
+                 sr = torch.sigmoid(sr)
+             return sr, logdet
+         else:
+             assert gt is not None
+             # assert rrdbResults is not None
+             if self.sigmoid_output:
+                 gt = torch.log((gt / (1 - gt)).clamp(1 / 1000, 1000))
+             z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
+
+             return z, logdet
+
+     def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
+         fl_fea = gt
+         reverse = False
+         level_conditionals = {}
+         bypasses = {}
+
+         L = opt_get(self.opt, ['network_G', 'flow', 'L'])
+
+         for level in range(1, L + 1):
+             bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear',
+                                                               align_corners=False)
+
+         for layer, shape in zip(self.layers, self.output_shapes):
+             size = shape[2]
+             level = int(np.log(self.hr_size / size) / np.log(2))
+             if level > 0 and level not in level_conditionals.keys():
+                 if rrdbResults is None:
+                     level_conditionals[level] = None
+                 else:
+                     level_conditionals[level] = rrdbResults[self.levelToName[level]]
+
+             if isinstance(layer, FlowStep):
+                 fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
+             elif isinstance(layer, Split2d):
+                 fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
+                                                       y_onehot=y_onehot)
+             else:
+                 fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)
+
+         z = fl_fea
+
+         if not isinstance(epses, list):
+             return z, logdet
+
+         epses.append(z)
+         return epses, logdet
+
+     def forward_preFlow(self, fl_fea, logdet, reverse):
+         if hasattr(self, 'preFlow'):
+             for l in self.preFlow:
+                 fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
+         return fl_fea, logdet
+
+     def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
+         ft = None if layer.position is None else rrdbResults[layer.position]
+         fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
+
+         if isinstance(epses, list):
+             epses.append(eps)
+         return fl_fea, logdet
+
+     def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
+         z = epses.pop() if isinstance(epses, list) else z
+
+         fl_fea = z
+         # debug.imwrite("fl_fea", fl_fea)
+         bypasses = {}
+         level_conditionals = {}
+         if opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) is not True:
+             for level in range(self.L + 1):
+                 if level not in self.levelToName.keys():
+                     level_conditionals[level] = None
+                 else:
+                     level_conditionals[level] = rrdbResults[self.levelToName[level]] if rrdbResults else None
+
+         for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
+             size = shape[2]
+             level = int(np.log(self.hr_size / size) / np.log(2))
+             # size = fl_fea.shape[2]
+             # level = int(np.log(160 / size) / np.log(2))
+
+             if isinstance(layer, Split2d):
+                 fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
+                                                               rrdbResults[self.levelToName[level]], logdet=logdet,
+                                                               y_onehot=y_onehot)
+             elif isinstance(layer, FlowStep):
+                 fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level])
+             else:
+                 fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)
+
+         sr = fl_fea
+
+         assert sr.shape[1] == 3
+         return sr, logdet
+
+     def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
+         ft = None if layer.position is None else rrdbResults[layer.position]
+         fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
+                                eps=epses.pop() if isinstance(epses, list) else None,
+                                eps_std=eps_std, ft=ft, y_onehot=y_onehot)
+         return fl_fea, logdet
+
+
+ def get_position_name(H, scale, opt):
+     downscale_factor = opt['datasets']['train']['GT_size'] // H
+     position_name = 'fea_up{}'.format(scale / downscale_factor)
+     return position_name
models/llflow/models/modules/LLFlow_arch.py ADDED
@@ -0,0 +1,248 @@
+ import math
+ import random
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from models.modules.RRDBNet_arch import RRDBNet
+ from models.modules.ConditionEncoder import ConEncoder1, NoEncoder
+ from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
+ import models.modules.thops as thops
+ import models.modules.flow as flow
+ from models.modules.color_encoder import ColorEncoder
+ # from utils.util import opt_get
+ from models.modules.flow import unsqueeze2d, squeeze2d
+ from torch.cuda.amp import autocast
+
+
+ def opt_get(opt, keys, default=None):
+     if opt is None:
+         return default
+     ret = opt
+     for k in keys:
+         ret = ret.get(k, None)
+         if ret is None:
+             return default
+     return ret
+
+
+ class LLFlow(nn.Module):
+     def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
+         super(LLFlow, self).__init__()
+         self.crop_size = opt['datasets']['train']['GT_size']
+         self.opt = opt
+         self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
+                             None else opt_get(opt, ['datasets', 'train', 'quant'])
+         if opt['cond_encoder'] == 'ConEncoder1':
+             self.RRDB = ConEncoder1(in_nc, out_nc, nf, nb, gc, scale, opt)
+         elif opt['cond_encoder'] == 'NoEncoder':
+             self.RRDB = None  # NoEncoder(in_nc, out_nc, nf, nb, gc, scale, opt)
+         elif opt['cond_encoder'] == 'RRDBNet':
+             # if self.opt['encode_color_map']: print('Warning: ''encode_color_map'' is not implemented in RRDBNet')
+             self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
+         else:
+             print('WARNING: Cannot find the conditional encoder %s, select RRDBNet by default.' % opt['cond_encoder'])
+             # if self.opt['encode_color_map']: print('Warning: ''encode_color_map'' is not implemented in RRDBNet')
+             opt['cond_encoder'] = 'RRDBNet'
+             self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
+
+         if self.opt['encode_color_map']:
+             self.color_map_encoder = ColorEncoder(nf=nf, opt=opt)
+
+         hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
+         hidden_channels = hidden_channels or 64
+         self.RRDB_training = True  # Default is true
+
+         train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
+         set_RRDB_to_train = False
+         if set_RRDB_to_train and self.RRDB:
+             self.set_rrdb_training(True)
+
+         self.flowUpsamplerNet = \
+             FlowUpsamplerNet((self.crop_size, self.crop_size, 3), hidden_channels, K,
+                              flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
+         self.i = 0
+         if self.opt['to_yuv']:
+             self.A_rgb2yuv = torch.nn.Parameter(torch.tensor([[0.299, -0.14714119, 0.61497538],
+                                                               [0.587, -0.28886916, -0.51496512],
+                                                               [0.114, 0.43601035, -0.10001026]]), requires_grad=False)
+             self.A_yuv2rgb = torch.nn.Parameter(torch.tensor([[1., 1., 1.],
+                                                               [0., -0.39465, 2.03211],
+                                                               [1.13983, -0.58060, 0]]), requires_grad=False)
+         if self.opt['align_maxpool']:
+             self.max_pool = torch.nn.MaxPool2d(3)
+
+     def set_rrdb_training(self, trainable):
+         if self.RRDB_training != trainable:
+             for p in self.RRDB.parameters():
+                 p.requires_grad = trainable
+             self.RRDB_training = trainable
+             return True
+         return False
+
+     def rgb2yuv(self, rgb):
+         rgb_ = rgb.transpose(1, 3)  # input is 3*n*n default
+         yuv = torch.tensordot(rgb_, self.A_rgb2yuv, 1).transpose(1, 3)
+         return yuv
+
+     def yuv2rgb(self, yuv):
+         yuv_ = yuv.transpose(1, 3)  # input is 3*n*n default
+         rgb = torch.tensordot(yuv_, self.A_yuv2rgb, 1).transpose(1, 3)
+         return rgb
+
+     @autocast()
+     def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
+                 lr_enc=None,
+                 add_gt_noise=False, step=None, y_label=None, align_condition_feature=False, get_color_map=False):
+         if get_color_map:
+             color_lr = self.color_map_encoder(lr)
+             color_gt = nn.functional.avg_pool2d(gt, 11, 1, 5)
+             color_gt = color_gt / torch.sum(color_gt, 1, keepdim=True)
+             return color_lr, color_gt
+         if not reverse:
+             if epses is not None and gt.device.index is not None:
+                 epses = epses[gt.device.index]
+             return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
+                                     y_onehot=y_label, align_condition_feature=align_condition_feature)
+         else:
+             # assert lr.shape[0] == 1
+             assert lr.shape[1] == 3 or lr.shape[1] == 6
+             # assert lr.shape[2] == 20
+             # assert lr.shape[3] == 20
+             # assert z.shape[0] == 1
+             # assert z.shape[1] == 3 * 8 * 8
+             # assert z.shape[2] == 20
+             # assert z.shape[3] == 20
+             if reverse_with_grad:
+                 return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
+                                          add_gt_noise=add_gt_noise)
+             else:
+                 with torch.no_grad():
+                     return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
+                                              add_gt_noise=add_gt_noise)
+
+     def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None,
+                     align_condition_feature=False):
+         if self.opt['to_yuv']:
+             gt = self.rgb2yuv(gt)
+         if lr_enc is None and self.RRDB:
+             lr_enc = self.rrdbPreprocessing(lr)
+
+         logdet = torch.zeros_like(gt[:, 0, 0, 0])
+         pixels = thops.pixels(gt)
+
+         z = gt
+
+         if add_gt_noise:
+             # Setup
+             noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
+             if noiseQuant:
+                 z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
+             logdet = logdet + float(-np.log(self.quant) * pixels)
+
+         # Encode
+         epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
+                                               y_onehot=y_onehot)
+
+         objective = logdet.clone()
+
+         # if isinstance(epses, (list, tuple)):
+         #     z = epses[-1]
+         # else:
+         #     z = epses
+         z = epses
+         if 'avg_color_map' in self.opt.keys() and self.opt['avg_color_map']:
+             if 'avg_pool_color_map' in self.opt.keys() and self.opt['avg_pool_color_map']:
+                 mean = squeeze2d(F.avg_pool2d(lr_enc['color_map'], 7, 1, 3), 8) if random.random() > self.opt[
+                     'train_gt_ratio'] else squeeze2d(F.avg_pool2d(
+                     gt / (gt.sum(dim=1, keepdims=True) + 1e-4), 7, 1, 3), 8)
+         else:
+             if self.RRDB is not None:
+                 mean = squeeze2d(lr_enc['color_map'], 8) if random.random() > self.opt['train_gt_ratio'] else squeeze2d(
+                     gt / (gt.sum(dim=1, keepdims=True) + 1e-4), 8)
+             else:
+                 mean = squeeze2d(lr[:, :3], 8)
+         objective = objective + flow.GaussianDiag.logp(mean, torch.tensor(0.).to(z.device), z)
+
+         nll = (-objective) / float(np.log(2.) * pixels)
+         if self.opt['encode_color_map']:
+             color_map = self.color_map_encoder(lr)
+             color_gt = nn.functional.avg_pool2d(gt, 11, 1, 5)
+             color_gt = color_gt / torch.sum(color_gt, 1, keepdim=True)
+             color_loss = (color_gt - color_map).abs().mean()
+             nll = nll + color_loss
+         if align_condition_feature:
+             with torch.no_grad():
+                 gt_enc = self.rrdbPreprocessing(gt)
+             for k, v in gt_enc.items():
+                 if k in ['fea_up-1']:  # ['fea_up2','fea_up1','fea_up0','fea_up-1']:
+                     if self.opt['align_maxpool']:
+                         nll = nll + (self.max_pool(gt_enc[k]) - self.max_pool(lr_enc[k])).abs().mean() * (
+                             self.opt['align_weight'] if self.opt['align_weight'] is not None else 1)
+                     else:
+                         nll = nll + (gt_enc[k] - lr_enc[k]).abs().mean() * (
+                             self.opt['align_weight'] if self.opt['align_weight'] is not None else 1)
+         if isinstance(epses, list):
+             return epses, nll, logdet
+         return z, nll, logdet
+
+     def rrdbPreprocessing(self, lr):
+         rrdbResults = self.RRDB(lr, get_steps=True)
+         block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+         if len(block_idxs) > 0:
+             low_level_features = [rrdbResults["block_{}".format(idx)] for idx in block_idxs]
+             # low_level_features.append(rrdbResults['color_map'])
+             concat = torch.cat(low_level_features, dim=1)
+
+             if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
+                 keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
+                 if 'fea_up0' in rrdbResults.keys():
+                     keys.append('fea_up0')
+                 if 'fea_up-1' in rrdbResults.keys():
+                     keys.append('fea_up-1')
+                 for k in keys:
+                     h = rrdbResults[k].shape[2]
+                     w = rrdbResults[k].shape[3]
+                     rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
+         return rrdbResults
+
+     def get_score(self, disc_loss_sigma, z):
+         score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \
+                      z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma)
+         return -score_real
+
+     def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
+
+         logdet = torch.zeros_like(lr[:, 0, 0, 0])
+         pixels = thops.pixels(lr) * self.opt['scale'] ** 2
+
+         if add_gt_noise:
+             logdet = logdet - float(-np.log(self.quant) * pixels)
+
+         if lr_enc is None and self.RRDB:
+             lr_enc = self.rrdbPreprocessing(lr)
+         if self.opt['cond_encoder'] == "NoEncoder":
+             z = squeeze2d(lr[:, :3], 8)
+         else:
+             if 'avg_color_map' in self.opt.keys() and self.opt['avg_color_map']:
+                 z = squeeze2d(F.avg_pool2d(lr_enc['color_map'], 7, 1, 3), 8)
+             else:
+                 z = squeeze2d(lr_enc['color_map'], 8)
+         x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
+                                           logdet=logdet)
+         if self.opt['encode_color_map']:
+             color_map = self.color_map_encoder(lr)
+             color_out = nn.functional.avg_pool2d(x, 11, 1, 5)
+             color_out = color_out / torch.sum(color_out, 1, keepdim=True)
+             x = x * (color_map / color_out)
+         if self.opt['to_yuv']:
+             x = self.yuv2rgb(x)
+         return x, logdet
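`rgb2yuv`/`yuv2rgb` above apply a 3x3 color matrix over the channel axis via transpose + `tensordot`; a standalone sketch of the same trick (matrix copied from the constructor):

import torch

A_rgb2yuv = torch.tensor([[0.299, -0.14714119, 0.61497538],
                          [0.587, -0.28886916, -0.51496512],
                          [0.114, 0.43601035, -0.10001026]])

rgb = torch.rand(2, 3, 8, 8)  # NCHW
# NCHW -> NWHC, contract the trailing channel axis with the 3x3 matrix, then back
yuv = torch.tensordot(rgb.transpose(1, 3), A_rgb2yuv, 1).transpose(1, 3)
print(yuv.shape)  # torch.Size([2, 3, 8, 8])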
models/llflow/models/modules/Permutations.py ADDED
@@ -0,0 +1,59 @@
+ import numpy as np
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+
+ from models.modules import thops
+
+
+ class InvertibleConv1x1(nn.Module):
+     def __init__(self, num_channels, LU_decomposed=False):
+         super().__init__()
+         w_shape = [num_channels, num_channels]
+         w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
+         self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
+         self.w_shape = w_shape
+         self.LU = LU_decomposed
+
+     def get_weight(self, input, reverse):
+         w_shape = self.w_shape
+         pixels = thops.pixels(input)
+         dlogdet = torch.tensor(float('inf'))
+         # retry with a lightly noise-regularized weight until slogdet is finite
+         while torch.isinf(dlogdet):
+             try:
+                 dlogdet = torch.slogdet(self.weight)[1] * pixels
+             except Exception as e:
+                 print(e)
+                 dlogdet = torch.slogdet(
+                     self.weight + (self.weight.mean() * torch.randn(*self.w_shape).to(input.device) * 0.001))[1] * pixels
+         if not reverse:
+             weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
+         else:
+             try:
+                 weight = torch.inverse(self.weight.double()).float() \
+                     .view(w_shape[0], w_shape[1], 1, 1)
+             except RuntimeError:
+                 # singular weight: invert a lightly noise-regularized copy instead
+                 weight = torch.inverse(
+                     self.weight.double()
+                     + (self.weight.mean() * torch.randn(*self.w_shape).to(input.device) * 0.001).double()
+                 ).float().view(w_shape[0], w_shape[1], 1, 1)
+         return weight, dlogdet
+
+     def forward(self, input, logdet=None, reverse=False):
+         """
+         log-det = log|det(W)| * pixels
+         """
+         weight, dlogdet = self.get_weight(input, reverse)
+         if not reverse:
+             z = F.conv2d(input, weight)
+             if logdet is not None:
+                 logdet = logdet + dlogdet
+             return z, logdet
+         else:
+             z = F.conv2d(input, weight)
+             if logdet is not None:
+                 logdet = logdet - dlogdet
+             return z, logdet
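
A quick sanity check of `InvertibleConv1x1` — a minimal sketch, assuming this commit's `models/llflow` directory is on `PYTHONPATH` so that `models.modules.thops` resolves; the channel count and feature-map size are arbitrary:

import torch
from models.modules.Permutations import InvertibleConv1x1

conv = InvertibleConv1x1(num_channels=12)
x = torch.randn(2, 12, 16, 16)

# Forward applies W as a 1x1 convolution and adds dlogdet; reverse applies
# W^-1 and subtracts the same dlogdet, so the round trip is the identity.
z, logdet = conv(x, logdet=torch.zeros(2), reverse=False)
x_rec, logdet_rec = conv(z, logdet=logdet, reverse=True)

print(torch.allclose(x, x_rec, atol=1e-4))                    # True
print(torch.allclose(logdet_rec, torch.zeros(2), atol=1e-3))  # True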
models/llflow/models/modules/RRDBNet_arch.py ADDED
@@ -0,0 +1,147 @@
+ import functools
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import models.modules.module_util as mutil
+ # from utils.util import opt_get
+
+
+ def opt_get(opt, keys, default=None):
+     if opt is None:
+         return default
+     ret = opt
+     for k in keys:
+         ret = ret.get(k, None)
+         if ret is None:
+             return default
+     return ret
+
+
+ class ResidualDenseBlock_5C(nn.Module):
+     def __init__(self, nf=64, gc=32, bias=True):
+         super(ResidualDenseBlock_5C, self).__init__()
+         # gc: growth channels, i.e. intermediate channels
+         self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+         self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+         self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         return x5 * 0.2 + x
+
+
+ class RRDB(nn.Module):
+     '''Residual in Residual Dense Block'''
+
+     def __init__(self, nf, gc=32):
+         super(RRDB, self).__init__()
+         self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+         self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+         self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+     def forward(self, x):
+         out = self.RDB1(x)
+         out = self.RDB2(out)
+         out = self.RDB3(out)
+         return out * 0.2 + x
+
+
+ class RRDBNet(nn.Module):
+     def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
+         self.opt = opt
+         super(RRDBNet, self).__init__()
+         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+         self.scale = scale
+
+         self.conv_first = nn.Conv2d(in_nc, nf, 3, 2, 1, bias=True)
+         self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
+         self.trunk_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
+         #### upsampling
+         self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         if self.scale >= 8:
+             self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         if self.scale >= 16:
+             self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         if self.scale >= 32:
+             self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+
+         self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x, get_steps=False):
+         fea = self.conv_first(x)
+
+         block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+         block_results = {}
+
+         for idx, m in enumerate(self.RRDB_trunk.children()):
+             fea = m(fea)
+             for b in block_idxs:
+                 if b == idx:
+                     block_results["block_{}".format(idx)] = fea
+         trunk = self.trunk_conv(fea)
+         fea = F.max_pool2d(fea, 2)
+         last_lr_fea = fea + trunk
+
+         fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
+         fea = self.lrelu(fea_up2)
+
+         fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
+         fea = self.lrelu(fea_up4)
+
+         fea_up8 = None
+         fea_up16 = None
+         fea_up32 = None
+
+         if self.scale >= 8:
+             fea_up8 = self.upconv3(fea)
+             fea = self.lrelu(fea_up8)
+         if self.scale >= 16:
+             fea_up16 = self.upconv4(fea)
+             fea = self.lrelu(fea_up16)
+         if self.scale >= 32:
+             fea_up32 = self.upconv5(fea)
+             fea = self.lrelu(fea_up32)
+
+         out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+         results = {'last_lr_fea': last_lr_fea,
+                    'fea_up1': last_lr_fea,
+                    'fea_up2': fea_up2,
+                    'fea_up4': fea_up4,  # raw
+                    'fea_up8': fea_up8,
+                    'fea_up16': fea_up16,
+                    'fea_up32': fea_up32,
+                    'out': out}  # raw
+
+         fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
+         if fea_up0_en:
+             results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
+         fea_upn1_en = True  # opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
+         if fea_upn1_en:
+             results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
+
+         if get_steps:
+             for k, v in block_results.items():
+                 results[k] = v
+             return results
+         else:
+             return out
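
For orientation, a sketch of what the `get_steps=True` path returns — same `PYTHONPATH` assumption as above; the channel counts, depth, and input size are illustrative, not taken from any config in this commit:

import torch
from models.modules.RRDBNet_arch import RRDBNet

# conv_first (stride 2) and the max-pooled trunk each halve the resolution,
# so a 64x64 input yields a 16x16 'last_lr_fea'; upconv1/upconv2 scale it back.
net = RRDBNet(in_nc=3, out_nc=3, nf=32, nb=4, gc=16, scale=4)
lr = torch.randn(1, 3, 64, 64)

feats = net(lr, get_steps=True)
for k in ('last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4', 'fea_up-1'):
    print(k, tuple(feats[k].shape))
# last_lr_fea (1, 32, 16, 16), fea_up4 (1, 32, 64, 64), fea_up-1 (1, 32, 4, 4)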
models/llflow/models/modules/Split.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ from torch import nn as nn
+
+ from models.modules import thops
+ from models.modules.FlowStep import FlowStep
+ from models.modules.flow import Conv2dZeros, GaussianDiag
+ # from utils.util import opt_get
+
+
+ def opt_get(opt, keys, default=None):
+     if opt is None:
+         return default
+     ret = opt
+     for k in keys:
+         ret = ret.get(k, None)
+         if ret is None:
+             return default
+     return ret
+
+
+ class Split2d(nn.Module):
+     def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
+         super().__init__()
+
+         self.num_channels_consume = int(round(num_channels * consume_ratio))
+         self.num_channels_pass = num_channels - self.num_channels_consume
+
+         self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
+                                 out_channels=self.num_channels_consume * 2)
+         self.logs_eps = logs_eps
+         self.position = position
+         self.opt = opt
+
+     def split2d_prior(self, z, ft):
+         if ft is not None:
+             z = torch.cat([z, ft], dim=1)
+         h = self.conv(z)
+         return thops.split_feature(h, "cross")
+
+     def exp_eps(self, logs):
+         return torch.exp(logs) + self.logs_eps
+
+     def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
+         if not reverse:
+             # self.input = input
+             z1, z2 = self.split_ratio(input)
+             mean, logs = self.split2d_prior(z1, ft)
+
+             eps = (z2 - mean) / self.exp_eps(logs)
+
+             logdet = logdet + self.get_logdet(logs, mean, z2)
+
+             # print(logs.shape, mean.shape, z2.shape)
+             # self.eps = eps
+             # print('split, enc eps:', eps)
+             return z1, logdet, eps
+         else:
+             z1 = input
+             mean, logs = self.split2d_prior(z1, ft)
+
+             if eps is None:
+                 # print("WARNING: eps is None, generating eps untested functionality!")
+                 eps = GaussianDiag.sample_eps(mean.shape, eps_std)
+
+             eps = eps.to(mean.device)
+             z2 = mean + self.exp_eps(logs) * eps
+
+             z = thops.cat_feature(z1, z2)
+             logdet = logdet - self.get_logdet(logs, mean, z2)
+
+             return z, logdet
+             # return z, logdet, eps
+
+     def get_logdet(self, logs, mean, z2):
+         logdet_diff = GaussianDiag.logp(mean, logs, z2)
+         # print("Split2D: logdet diff", logdet_diff.item())
+         return logdet_diff
+
+     def split_ratio(self, input):
+         z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
+         return z1, z2
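
And a matching round-trip check for `Split2d` (same path assumption as the sketches above): the forward direction factors the consumed channels into a normalized latent `eps`, and feeding that `eps` back through the reverse direction reconstructs the input while the prior log-probability added on the way in is subtracted on the way out:

import torch
from models.modules.Split import Split2d

# Arbitrary sizes: 8 channels split 50/50, 4x4 spatial maps, batch of 2.
split = Split2d(num_channels=8, consume_ratio=0.5)
x = torch.randn(2, 8, 4, 4)

z1, logdet, eps = split(x, logdet=torch.zeros(2), reverse=False)
x_rec, logdet_rec = split(z1, logdet=logdet, reverse=True, eps=eps)

print(torch.allclose(x, x_rec, atol=1e-5))                    # True
print(torch.allclose(logdet_rec, torch.zeros(2), atol=1e-4))  # True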
models/llflow/models/modules/__init__.py ADDED
File without changes
models/llflow/models/modules/__pycache__/ConditionEncoder.cpython-310.pyc ADDED
Binary file (6.75 kB)
models/llflow/models/modules/__pycache__/FlowActNorms.cpython-310.pyc ADDED
Binary file (3.91 kB)
models/llflow/models/modules/__pycache__/FlowAffineCouplingsAblation.cpython-310.pyc ADDED
Binary file (4.39 kB)
models/llflow/models/modules/__pycache__/FlowStep.cpython-310.pyc ADDED
Binary file (4.38 kB)
models/llflow/models/modules/__pycache__/FlowUpsamplerNet.cpython-310.pyc ADDED
Binary file (9.45 kB)