gokaygokay commited on
Commit
ff6da4a
1 Parent(s): 8d6a3e4

Florence2Flux

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +73 -160
  2. configs/inference/controlnet_c_3b_canny.yaml +0 -14
  3. configs/inference/controlnet_c_3b_identity.yaml +0 -17
  4. configs/inference/controlnet_c_3b_inpainting.yaml +0 -15
  5. configs/inference/controlnet_c_3b_sr.yaml +0 -15
  6. configs/inference/lora_c_3b.yaml +0 -15
  7. configs/inference/stage_b_1b.yaml +0 -13
  8. configs/inference/stage_b_3b.yaml +0 -13
  9. configs/inference/stage_c_1b.yaml +0 -7
  10. configs/inference/stage_c_3b.yaml +0 -7
  11. configs/training/cfg_control_lr.yaml +0 -47
  12. configs/training/lora_personalization.yaml +0 -37
  13. configs/training/t2i.yaml +0 -29
  14. core/__init__.py +0 -372
  15. core/data/__init__.py +0 -69
  16. core/data/bucketeer.py +0 -88
  17. core/data/bucketeer_deg.py +0 -91
  18. core/data/deg_kair_utils/utils_alignfaces.py +0 -263
  19. core/data/deg_kair_utils/utils_blindsr.py +0 -631
  20. core/data/deg_kair_utils/utils_bnorm.py +0 -91
  21. core/data/deg_kair_utils/utils_deblur.py +0 -655
  22. core/data/deg_kair_utils/utils_dist.py +0 -201
  23. core/data/deg_kair_utils/utils_googledownload.py +0 -93
  24. core/data/deg_kair_utils/utils_image.py +0 -1016
  25. core/data/deg_kair_utils/utils_lmdb.py +0 -205
  26. core/data/deg_kair_utils/utils_logger.py +0 -66
  27. core/data/deg_kair_utils/utils_mat.py +0 -88
  28. core/data/deg_kair_utils/utils_matconvnet.py +0 -197
  29. core/data/deg_kair_utils/utils_model.py +0 -330
  30. core/data/deg_kair_utils/utils_modelsummary.py +0 -485
  31. core/data/deg_kair_utils/utils_option.py +0 -255
  32. core/data/deg_kair_utils/utils_params.py +0 -135
  33. core/data/deg_kair_utils/utils_receptivefield.py +0 -62
  34. core/data/deg_kair_utils/utils_regularizers.py +0 -104
  35. core/data/deg_kair_utils/utils_sisr.py +0 -848
  36. core/data/deg_kair_utils/utils_video.py +0 -493
  37. core/data/deg_kair_utils/utils_videoio.py +0 -555
  38. core/scripts/__init__.py +0 -0
  39. core/scripts/cli.py +0 -41
  40. core/templates/__init__.py +0 -1
  41. core/templates/diffusion.py +0 -236
  42. core/utils/__init__.py +0 -9
  43. core/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  44. core/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  45. core/utils/__pycache__/base_dto.cpython-310.pyc +0 -0
  46. core/utils/__pycache__/base_dto.cpython-39.pyc +0 -0
  47. core/utils/__pycache__/save_and_load.cpython-310.pyc +0 -0
  48. core/utils/__pycache__/save_and_load.cpython-39.pyc +0 -0
  49. core/utils/base_dto.py +0 -56
  50. core/utils/save_and_load.py +0 -59
app.py CHANGED
@@ -1,161 +1,74 @@
1
- import spaces
2
- import os
3
- import requests
4
- import yaml
5
- import torch
6
  import gradio as gr
7
- from PIL import Image
8
- import sys
9
- sys.path.append(os.path.abspath('./'))
10
- from inference.utils import *
11
- from core.utils import load_or_fail
12
- from train import WurstCoreB
13
- from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
14
- from train import WurstCore_t2i as WurstCoreC
15
- import torch.nn.functional as F
16
- from core.utils import load_or_fail
17
- import numpy as np
18
- import random
19
- import math
20
- from einops import rearrange
21
-
22
- def download_file(url, folder_path, filename):
23
- if not os.path.exists(folder_path):
24
- os.makedirs(folder_path)
25
- file_path = os.path.join(folder_path, filename)
26
-
27
- if os.path.isfile(file_path):
28
- print(f"File already exists: {file_path}")
29
- else:
30
- response = requests.get(url, stream=True)
31
- if response.status_code == 200:
32
- with open(file_path, 'wb') as file:
33
- for chunk in response.iter_content(chunk_size=1024):
34
- file.write(chunk)
35
- print(f"File successfully downloaded and saved: {file_path}")
36
- else:
37
- print(f"Error downloading the file. Status code: {response.status_code}")
38
-
39
- def download_models():
40
- models = {
41
- "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/", "stage_a.safetensors"),
42
- "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/", "previewer.safetensors"),
43
- "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/", "effnet_encoder.safetensors"),
44
- "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/", "stage_b_lite_bf16.safetensors"),
45
- "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/", "stage_c_bf16.safetensors"),
46
- "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/", "ultrapixel_t2i.safetensors"),
47
- "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/", "lora_cat.safetensors"),
48
- }
49
-
50
- for model, (url, folder, filename) in models.items():
51
- download_file(url, folder, filename)
52
-
53
- download_models()
54
-
55
- # Global variables
56
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
- dtype = torch.bfloat16
58
-
59
- # Load configs and setup models
60
- with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
61
- config_c = yaml.safe_load(file)
62
-
63
- with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
64
- config_b = yaml.safe_load(file)
65
-
66
- core = WurstCoreC(config_dict=config_c, device=device, training=False)
67
- core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
68
-
69
- extras = core.setup_extras_pre()
70
- models = core.setup_models(extras)
71
- models.generator.eval().requires_grad_(False)
72
-
73
- extras_b = core_b.setup_extras_pre()
74
- models_b = core_b.setup_models(extras_b, skip_clip=True)
75
- models_b = WurstCoreB.Models(
76
- **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
77
- )
78
- models_b.generator.bfloat16().eval().requires_grad_(False)
79
-
80
- # Load pretrained model
81
- pretrained_path = "models/ultrapixel_t2i.safetensors"
82
- sdd = torch.load(pretrained_path, map_location='cpu')
83
- collect_sd = {k[7:]: v for k, v in sdd.items()}
84
- models.train_norm.load_state_dict(collect_sd)
85
- models.generator.eval()
86
- models.train_norm.eval()
87
-
88
- # Set up sampling configurations
89
- extras.sampling_configs.update({
90
- 'cfg': 4,
91
- 'shift': 1,
92
- 'timesteps': 20,
93
- 't_start': 1.0,
94
- 'sampler': DDPMSampler(extras.gdf)
95
- })
96
-
97
- extras_b.sampling_configs.update({
98
- 'cfg': 1.1,
99
- 'shift': 1,
100
- 'timesteps': 10,
101
- 't_start': 1.0
102
- })
103
-
104
- @spaces.GPU(duration=180)
105
- def generate_images(prompt, height, width, seed, num_images):
106
- torch.manual_seed(seed)
107
- random.seed(seed)
108
- np.random.seed(seed)
109
-
110
- batch_size = num_images
111
- height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
112
- stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
113
- stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
114
-
115
- batch = {'captions': [prompt] * batch_size}
116
- conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
117
- unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
118
-
119
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
120
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
121
-
122
- with torch.no_grad():
123
- models.generator.cuda()
124
- with torch.cuda.amp.autocast(dtype=dtype):
125
- sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
126
-
127
- models.generator.cpu()
128
- torch.cuda.empty_cache()
129
-
130
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
131
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
132
- conditions_b['effnet'] = sampled_c
133
- unconditions_b['effnet'] = torch.zeros_like(sampled_c)
134
-
135
- with torch.cuda.amp.autocast(dtype=dtype):
136
- sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
137
-
138
- torch.cuda.empty_cache()
139
- imgs = show_images(sampled)
140
- return imgs
141
-
142
- iface = gr.Interface(
143
- fn=generate_images,
144
- inputs=[
145
- gr.Textbox(label="Prompt"),
146
- gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
147
- gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
148
- gr.Number(label="Seed", value=42),
149
- gr.Slider(minimum=1, maximum=10, step=1, label="Number of Images", value=1)
150
- ],
151
- outputs=gr.Gallery(label="Generated Images", columns=5, rows=2),
152
- title="UltraPixel Image Generation",
153
- description="Generate high-resolution images using UltraPixel model.",
154
- theme='bethecloud/storj_theme',
155
- examples=[
156
- ["The image features a snow-covered mountain range with a large, snow-covered mountain in the background. The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the winter season, with snow covering the ground and the trees.", 1024, 1024, 42, 1]
157
- ],
158
- cache_examples=True
159
- )
160
-
161
- iface.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ import spaces
4
+ from PIL import Image
5
+
6
+ import subprocess
7
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
8
+
9
+ models = {
10
+ 'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True).eval(),
11
+ 'gokaygokay/Florence-2-Flux': AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True).eval(),
12
+ }
13
+
14
+ processors = {
15
+ 'gokaygokay/Florence-2-Flux-Large': AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True),
16
+ 'gokaygokay/Florence-2-Flux': AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True),
17
+ }
18
+
19
+ title = """<h1 align="center">Florence-2 Captioner for Flux Prompts</h1>
20
+ <p><center>
21
+ <a href="https://huggingface.co/gokaygokay/Florence-2-Flux-Large" target="_blank">[Florence-2 Flux Large]</a>
22
+ <a href="https://huggingface.co/gokaygokay/Florence-2-Flux" target="_blank">[Florence-2 Flux Base]</a>
23
+ </center></p>
24
+ """
25
+
26
+ @spaces.GPU
27
+ def run_example(image, model_name='gokaygokay/Florence-2-Flux-Large'):
28
+ image = Image.fromarray(image)
29
+ task_prompt = "<DESCRIPTION>"
30
+ prompt = task_prompt + "Describe this image in great detail."
31
+
32
+ if image.mode != "RGB":
33
+ image = image.convert("RGB")
34
+
35
+ model = models[model_name]
36
+ processor = processors[model_name]
37
+
38
+ inputs = processor(text=prompt, images=image, return_tensors="pt")
39
+ generated_ids = model.generate(
40
+ input_ids=inputs["input_ids"],
41
+ pixel_values=inputs["pixel_values"],
42
+ max_new_tokens=1024,
43
+ num_beams=3,
44
+ repetition_penalty=1.10,
45
+ )
46
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
47
+ parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
48
+ return parsed_answer["<DESCRIPTION>"]
49
+
50
+ with gr.Blocks(theme='bethecloud/storj_theme') as demo:
51
+ gr.HTML(title)
52
+
53
+ with gr.Row():
54
+ with gr.Column():
55
+ input_img = gr.Image(label="Input Picture")
56
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='gokaygokay/Florence-2-Flux-Large')
57
+ submit_btn = gr.Button(value="Submit")
58
+ with gr.Column():
59
+ output_text = gr.Textbox(label="Output Text")
60
+
61
+ gr.Examples(
62
+ [["image1.jpg"],
63
+ ["image2.jpg"],
64
+ ["image3.png"],
65
+ ["image5.jpg"]],
66
+ inputs=[input_img, model_selector],
67
+ outputs=[output_text],
68
+ fn=run_example,
69
+ label='Try captioning on below examples'
70
+ )
71
+
72
+ submit_btn.click(run_example, [input_img, model_selector], [output_text])
73
+
74
+ demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/inference/controlnet_c_3b_canny.yaml DELETED
@@ -1,14 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 3.6B
3
- dtype: bfloat16
4
-
5
- # ControlNet specific
6
- controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
7
- controlnet_filter: CannyFilter
8
- controlnet_filter_params:
9
- resize: 224
10
-
11
- effnet_checkpoint_path: models/effnet_encoder.safetensors
12
- previewer_checkpoint_path: models/previewer.safetensors
13
- generator_checkpoint_path: models/stage_c_bf16.safetensors
14
- controlnet_checkpoint_path: models/canny.safetensors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/inference/controlnet_c_3b_identity.yaml DELETED
@@ -1,17 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 3.6B
3
- dtype: bfloat16
4
-
5
- # ControlNet specific
6
- controlnet_bottleneck_mode: 'simple'
7
- controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
8
- controlnet_filter: IdentityFilter
9
- controlnet_filter_params:
10
- max_faces: 4
11
- p_drop: 0.00
12
- p_full: 0.0
13
-
14
- effnet_checkpoint_path: models/effnet_encoder.safetensors
15
- previewer_checkpoint_path: models/previewer.safetensors
16
- generator_checkpoint_path: models/stage_c_bf16.safetensors
17
- controlnet_checkpoint_path:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/inference/controlnet_c_3b_inpainting.yaml DELETED
@@ -1,15 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 3.6B
3
- dtype: bfloat16
4
-
5
- # ControlNet specific
6
- controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
7
- controlnet_filter: InpaintFilter
8
- controlnet_filter_params:
9
- thresold: [0.04, 0.4]
10
- p_outpaint: 0.4
11
-
12
- effnet_checkpoint_path: models/effnet_encoder.safetensors
13
- previewer_checkpoint_path: models/previewer.safetensors
14
- generator_checkpoint_path: models/stage_c_bf16.safetensors
15
- controlnet_checkpoint_path: models/inpainting.safetensors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/inference/controlnet_c_3b_sr.yaml DELETED
@@ -1,15 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 3.6B
3
- dtype: bfloat16
4
-
5
- # ControlNet specific
6
- controlnet_bottleneck_mode: 'large'
7
- controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
8
- controlnet_filter: SREffnetFilter
9
- controlnet_filter_params:
10
- scale_factor: 0.5
11
-
12
- effnet_checkpoint_path: models/effnet_encoder.safetensors
13
- previewer_checkpoint_path: models/previewer.safetensors
14
- generator_checkpoint_path: models/stage_c_bf16.safetensors
15
- controlnet_checkpoint_path: models/super_resolution.safetensors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/inference/lora_c_3b.yaml DELETED
@@ -1,15 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 3.6B
3
- dtype: bfloat16
4
-
5
- # LoRA specific
6
- module_filters: ['.attn']
7
- rank: 4
8
- train_tokens:
9
- # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
10
- - ['[fernando]', '^dog</w>'] # custom token [snail], initialize as avg of snail & snails
11
-
12
- effnet_checkpoint_path: models/effnet_encoder.safetensors
13
- previewer_checkpoint_path: models/previewer.safetensors
14
- generator_checkpoint_path: models/stage_c_bf16.safetensors
15
- lora_checkpoint_path: models/lora_fernando_10k.safetensors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/inference/stage_b_1b.yaml DELETED
@@ -1,13 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 700M
3
- dtype: bfloat16
4
-
5
- # For demonstration purposes in reconstruct_images.ipynb
6
- webdataset_path: path to your dataset
7
- batch_size: 1
8
- image_size: 2048
9
- grad_accum_steps: 1
10
-
11
- effnet_checkpoint_path: models/effnet_encoder.safetensors
12
- stage_a_checkpoint_path: models/stage_a.safetensors
13
- generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/inference/stage_b_3b.yaml DELETED
@@ -1,13 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 3B
3
- dtype: bfloat16
4
-
5
- # For demonstration purposes in reconstruct_images.ipynb
6
- webdataset_path: path to your dataset
7
- batch_size: 4
8
- image_size: 1024
9
- grad_accum_steps: 1
10
-
11
- effnet_checkpoint_path: models/effnet_encoder.safetensors
12
- stage_a_checkpoint_path: models/stage_a.safetensors
13
- generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/inference/stage_c_1b.yaml DELETED
@@ -1,7 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 1B
3
- dtype: bfloat16
4
-
5
- effnet_checkpoint_path: models/effnet_encoder.safetensors
6
- previewer_checkpoint_path: models/previewer.safetensors
7
- generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
 
 
 
 
 
 
 
 
configs/inference/stage_c_3b.yaml DELETED
@@ -1,7 +0,0 @@
1
- # GLOBAL STUFF
2
- model_version: 3.6B
3
- dtype: bfloat16
4
-
5
- effnet_checkpoint_path: models/effnet_encoder.safetensors
6
- previewer_checkpoint_path: models/previewer.safetensors
7
- generator_checkpoint_path: models/stage_c_bf16.safetensors
 
 
 
 
 
 
 
 
configs/training/cfg_control_lr.yaml DELETED
@@ -1,47 +0,0 @@
1
- # GLOBAL STUFF
2
- experiment_id: Ultrapixel_controlnet
3
-
4
- checkpoint_path: checkpoint output path
5
- output_path: visual results output path
6
- model_version: 3.6B
7
- dtype: float32
8
- # # WandB
9
- # wandb_project: StableCascade
10
- # wandb_entity: wandb_username
11
- #module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ]
12
- #rank: 32
13
- # TRAINING PARAMS
14
- lr: 1.0e-4
15
- batch_size: 12
16
- #image_size: [1536, 2048, 2560, 3072, 4096]
17
- image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
18
- #image_size: [ 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
19
- #image_size: [ 1024, 1280]
20
- multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
21
- grad_accum_steps: 2
22
- updates: 40000
23
- backup_every: 5000
24
- save_every: 256
25
- warmup_updates: 1
26
- use_fsdp: True
27
-
28
- # ControlNet specific
29
- controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
30
- controlnet_filter: CannyFilter
31
- controlnet_filter_params:
32
- resize: 224
33
- # offset_noise: 0.1
34
-
35
- # GDF
36
- adaptive_loss_weight: True
37
-
38
- ema_start_iters: 10
39
- ema_iters: 50
40
- ema_beta: 0.9
41
-
42
- webdataset_path: path to your training dataset
43
- effnet_checkpoint_path: models/effnet_encoder.safetensors
44
- previewer_checkpoint_path: models/previewer.safetensors
45
- generator_checkpoint_path: models/stage_c_bf16.safetensors
46
- controlnet_checkpoint_path: pretrained controlnet path
47
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/training/lora_personalization.yaml DELETED
@@ -1,37 +0,0 @@
1
- # GLOBAL STUFF
2
- experiment_id: roubao_cat_personalized
3
-
4
- checkpoint_path: checkpoint output path
5
- output_path: visual results output path
6
- model_version: 3.6B
7
- dtype: float32
8
-
9
- module_filters: [ '.attn']
10
- rank: 4
11
- train_tokens:
12
- # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
13
- - ['[roubaobao]', '^cat</w>'] # custom token [snail], initialize as avg of snail & snails
14
- # TRAINING PARAMS
15
- lr: 1.0e-4
16
- batch_size: 4
17
-
18
- image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
19
- multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
20
- grad_accum_steps: 2
21
- updates: 40000
22
- backup_every: 5000
23
- save_every: 512
24
- warmup_updates: 1
25
- use_ddp: True
26
-
27
- # GDF
28
- adaptive_loss_weight: True
29
-
30
-
31
- tmp_prompt: a photo of a cat [roubaobao]
32
- webdataset_path: path to your personalized training dataset
33
- effnet_checkpoint_path: models/effnet_encoder.safetensors
34
- previewer_checkpoint_path: models/previewer.safetensors
35
- generator_checkpoint_path: models/stage_c_bf16.safetensors
36
- ultrapixel_path: models/ultrapixel_t2i.safetensors
37
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/training/t2i.yaml DELETED
@@ -1,29 +0,0 @@
1
- # GLOBAL STUFF
2
- experiment_id: ultrapixel_t2i
3
- #strc_fixlrt_norm3_lite_1024_hrft_newdata
4
- checkpoint_path: checkpoint output path #output model directory
5
- output_path: visual results output path #experiment output directory
6
- model_version: 3.6B # finetune large stage c model of stablecascade
7
- dtype: float32
8
-
9
-
10
- # TRAINING PARAMS
11
- lr: 1.0e-4
12
- batch_size: 4 # gpu_number * num_per_gpu * grad_accum_steps
13
- image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] # possible image resolution
14
- multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
15
- grad_accum_steps: 2
16
- updates: 40000
17
- backup_every: 5000
18
- save_every: 256
19
- warmup_updates: 1
20
- use_ddp: True
21
-
22
- # GDF
23
- adaptive_loss_weight: True
24
-
25
-
26
- webdataset_path: path to your personalized training dataset
27
- effnet_checkpoint_path: models/effnet_encoder.safetensors
28
- previewer_checkpoint_path: models/previewer.safetensors
29
- generator_checkpoint_path: models/stage_c_bf16.safetensors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/__init__.py DELETED
@@ -1,372 +0,0 @@
1
- import os
2
- import yaml
3
- import torch
4
- from torch import nn
5
- import wandb
6
- import json
7
- from abc import ABC, abstractmethod
8
- from dataclasses import dataclass
9
- from torch.utils.data import Dataset, DataLoader
10
-
11
- from torch.distributed import init_process_group, destroy_process_group, barrier
12
- from torch.distributed.fsdp import (
13
- FullyShardedDataParallel as FSDP,
14
- FullStateDictConfig,
15
- MixedPrecision,
16
- ShardingStrategy,
17
- StateDictType
18
- )
19
-
20
- from .utils import Base, EXPECTED, EXPECTED_TRAIN
21
- from .utils import create_folder_if_necessary, safe_save, load_or_fail
22
-
23
- # pylint: disable=unused-argument
24
- class WarpCore(ABC):
25
- @dataclass(frozen=True)
26
- class Config(Base):
27
- experiment_id: str = EXPECTED_TRAIN
28
- checkpoint_path: str = EXPECTED_TRAIN
29
- output_path: str = EXPECTED_TRAIN
30
- checkpoint_extension: str = "safetensors"
31
- dist_file_subfolder: str = ""
32
- allow_tf32: bool = True
33
-
34
- wandb_project: str = None
35
- wandb_entity: str = None
36
-
37
- @dataclass() # not frozen, means that fields are mutable
38
- class Info(): # not inheriting from Base, because we don't want to enforce the default fields
39
- wandb_run_id: str = None
40
- total_steps: int = 0
41
- iter: int = 0
42
-
43
- @dataclass(frozen=True)
44
- class Data(Base):
45
- dataset: Dataset = EXPECTED
46
- dataloader: DataLoader = EXPECTED
47
- iterator: any = EXPECTED
48
-
49
- @dataclass(frozen=True)
50
- class Models(Base):
51
- pass
52
-
53
- @dataclass(frozen=True)
54
- class Optimizers(Base):
55
- pass
56
-
57
- @dataclass(frozen=True)
58
- class Schedulers(Base):
59
- pass
60
-
61
- @dataclass(frozen=True)
62
- class Extras(Base):
63
- pass
64
- # ---------------------------------------
65
- info: Info
66
- config: Config
67
-
68
- # FSDP stuff
69
- fsdp_defaults = {
70
- "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
71
- "cpu_offload": None,
72
- "mixed_precision": MixedPrecision(
73
- param_dtype=torch.bfloat16,
74
- reduce_dtype=torch.bfloat16,
75
- buffer_dtype=torch.bfloat16,
76
- ),
77
- "limit_all_gathers": True,
78
- }
79
- fsdp_fullstate_save_policy = FullStateDictConfig(
80
- offload_to_cpu=True, rank0_only=True
81
- )
82
- # ------------
83
-
84
- # OVERRIDEABLE METHODS
85
-
86
- # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
87
- def setup_extras_pre(self) -> Extras:
88
- return self.Extras()
89
-
90
- # setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator
91
- @abstractmethod
92
- def setup_data(self, extras: Extras) -> Data:
93
- raise NotImplementedError("This method needs to be overriden")
94
-
95
- # return a dict with all models that are going to be used in the training
96
- @abstractmethod
97
- def setup_models(self, extras: Extras) -> Models:
98
- raise NotImplementedError("This method needs to be overriden")
99
-
100
- # return a dict with all optimizers that are going to be used in the training
101
- @abstractmethod
102
- def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
103
- raise NotImplementedError("This method needs to be overriden")
104
-
105
- # [optionally] return a dict with all schedulers that are going to be used in the training
106
- def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
107
- return self.Schedulers()
108
-
109
- # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
110
- def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
111
- return self.Extras.from_dict(extras.to_dict())
112
-
113
- # perform the training here
114
- @abstractmethod
115
- def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
116
- raise NotImplementedError("This method needs to be overriden")
117
- # ------------
118
-
119
- def setup_info(self, full_path=None) -> Info:
120
- if full_path is None:
121
- full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json")
122
- info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
123
- info_dto = self.Info(**info_dict)
124
- if info_dto.total_steps > 0 and self.is_main_node:
125
- print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
126
- return info_dto
127
-
128
- def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
129
- if config_file_path is not None:
130
- if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
131
- with open(config_file_path, "r", encoding="utf-8") as file:
132
- loaded_config = yaml.safe_load(file)
133
- elif config_file_path.endswith(".json"):
134
- with open(config_file_path, "r", encoding="utf-8") as file:
135
- loaded_config = json.load(file)
136
- else:
137
- raise ValueError("Config file must be either a .yml|.yaml or .json file")
138
- return self.Config.from_dict({**loaded_config, 'training': training})
139
- if config_dict is not None:
140
- return self.Config.from_dict({**config_dict, 'training': training})
141
- return self.Config(training=training)
142
-
143
- def setup_ddp(self, experiment_id, single_gpu=False):
144
- if not single_gpu:
145
- local_rank = int(os.environ.get("SLURM_LOCALID"))
146
- process_id = int(os.environ.get("SLURM_PROCID"))
147
- world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()
148
-
149
- self.process_id = process_id
150
- self.is_main_node = process_id == 0
151
- self.device = torch.device(local_rank)
152
- self.world_size = world_size
153
-
154
- dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
155
- # if os.path.exists(dist_file_path) and self.is_main_node:
156
- # os.remove(dist_file_path)
157
-
158
- torch.cuda.set_device(local_rank)
159
- init_process_group(
160
- backend="nccl",
161
- rank=process_id,
162
- world_size=world_size,
163
- init_method=f"file://{dist_file_path}",
164
- )
165
- print(f"[GPU {process_id}] READY")
166
- else:
167
- print("Running in single thread, DDP not enabled.")
168
-
169
- def setup_wandb(self):
170
- if self.is_main_node and self.config.wandb_project is not None:
171
- self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
172
- wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())
173
-
174
- if self.info.total_steps > 0:
175
- wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
176
- else:
177
- wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")
178
-
179
- # LOAD UTILITIES ----------
180
- def load_model(self, model, model_id=None, full_path=None, strict=True):
181
- print('in line 181 load model', type(model), model_id, full_path, strict)
182
- if model_id is not None and full_path is None:
183
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
184
- elif full_path is None and model_id is None:
185
- raise ValueError(
186
- "This method expects either 'model_id' or 'full_path' to be defined"
187
- )
188
-
189
- checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
190
- if checkpoint is not None:
191
- model.load_state_dict(checkpoint, strict=strict)
192
- del checkpoint
193
-
194
- return model
195
-
196
- def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
197
- if optim_id is not None and full_path is None:
198
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
199
- elif full_path is None and optim_id is None:
200
- raise ValueError(
201
- "This method expects either 'optim_id' or 'full_path' to be defined"
202
- )
203
-
204
- checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
205
- if checkpoint is not None:
206
- try:
207
- if fsdp_model is not None:
208
- sharded_optimizer_state_dict = (
209
- FSDP.scatter_full_optim_state_dict( # <---- FSDP
210
- checkpoint
211
- if (
212
- self.is_main_node
213
- or self.fsdp_defaults["sharding_strategy"]
214
- == ShardingStrategy.NO_SHARD
215
- )
216
- else None,
217
- fsdp_model,
218
- )
219
- )
220
- optim.load_state_dict(sharded_optimizer_state_dict)
221
- del checkpoint, sharded_optimizer_state_dict
222
- else:
223
- optim.load_state_dict(checkpoint)
224
- # pylint: disable=broad-except
225
- except Exception as e:
226
- print("!!! Failed loading optimizer, skipping... Exception:", e)
227
-
228
- return optim
229
-
230
- # SAVE UTILITIES ----------
231
- def save_info(self, info, suffix=""):
232
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
233
- create_folder_if_necessary(full_path)
234
- if self.is_main_node:
235
- safe_save(vars(self.info), full_path)
236
-
237
- def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
238
- if model_id is not None and full_path is None:
239
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
240
- elif full_path is None and model_id is None:
241
- raise ValueError(
242
- "This method expects either 'model_id' or 'full_path' to be defined"
243
- )
244
- create_folder_if_necessary(full_path)
245
- if is_fsdp:
246
- with FSDP.summon_full_params(model):
247
- pass
248
- with FSDP.state_dict_type(
249
- model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
250
- ):
251
- checkpoint = model.state_dict()
252
- if self.is_main_node:
253
- safe_save(checkpoint, full_path)
254
- del checkpoint
255
- else:
256
- if self.is_main_node:
257
- checkpoint = model.state_dict()
258
- safe_save(checkpoint, full_path)
259
- del checkpoint
260
-
261
- def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
262
- if optim_id is not None and full_path is None:
263
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
264
- elif full_path is None and optim_id is None:
265
- raise ValueError(
266
- "This method expects either 'optim_id' or 'full_path' to be defined"
267
- )
268
- create_folder_if_necessary(full_path)
269
- if fsdp_model is not None:
270
- optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
271
- if self.is_main_node:
272
- safe_save(optim_statedict, full_path)
273
- del optim_statedict
274
- else:
275
- if self.is_main_node:
276
- checkpoint = optim.state_dict()
277
- safe_save(checkpoint, full_path)
278
- del checkpoint
279
- # -----
280
-
281
- def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
282
- # Temporary setup, will be overriden by setup_ddp if required
283
- self.device = device
284
- self.process_id = 0
285
- self.is_main_node = True
286
- self.world_size = 1
287
- # ----
288
-
289
- self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
290
- self.info: self.Info = self.setup_info()
291
-
292
- def __call__(self, single_gpu=False):
293
- self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank
294
- self.setup_wandb()
295
- if self.config.allow_tf32:
296
- torch.backends.cuda.matmul.allow_tf32 = True
297
- torch.backends.cudnn.allow_tf32 = True
298
-
299
- if self.is_main_node:
300
- print()
301
- print("**STARTIG JOB WITH CONFIG:**")
302
- print(yaml.dump(self.config.to_dict(), default_flow_style=False))
303
- print("------------------------------------")
304
- print()
305
- print("**INFO:**")
306
- print(yaml.dump(vars(self.info), default_flow_style=False))
307
- print("------------------------------------")
308
- print()
309
-
310
- # SETUP STUFF
311
- extras = self.setup_extras_pre()
312
- assert extras is not None, "setup_extras_pre() must return a DTO"
313
-
314
- data = self.setup_data(extras)
315
- assert data is not None, "setup_data() must return a DTO"
316
- if self.is_main_node:
317
- print("**DATA:**")
318
- print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
319
- print("------------------------------------")
320
- print()
321
-
322
- models = self.setup_models(extras)
323
- assert models is not None, "setup_models() must return a DTO"
324
- if self.is_main_node:
325
- print("**MODELS:**")
326
- print(yaml.dump({
327
- k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
328
- }, default_flow_style=False))
329
- print("------------------------------------")
330
- print()
331
-
332
- optimizers = self.setup_optimizers(extras, models)
333
- assert optimizers is not None, "setup_optimizers() must return a DTO"
334
- if self.is_main_node:
335
- print("**OPTIMIZERS:**")
336
- print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
337
- print("------------------------------------")
338
- print()
339
-
340
- schedulers = self.setup_schedulers(extras, models, optimizers)
341
- assert schedulers is not None, "setup_schedulers() must return a DTO"
342
- if self.is_main_node:
343
- print("**SCHEDULERS:**")
344
- print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
345
- print("------------------------------------")
346
- print()
347
-
348
- post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
349
- assert post_extras is not None, "setup_extras_post() must return a DTO"
350
- extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
351
- if self.is_main_node:
352
- print("**EXTRAS:**")
353
- print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
354
- print("------------------------------------")
355
- print()
356
- # -------
357
-
358
- # TRAIN
359
- if self.is_main_node:
360
- print("**TRAINING STARTING...**")
361
- self.train(data, extras, models, optimizers, schedulers)
362
-
363
- if single_gpu is False:
364
- barrier()
365
- destroy_process_group()
366
- if self.is_main_node:
367
- print()
368
- print("------------------------------------")
369
- print()
370
- print("**TRAINING COMPLETE**")
371
- if self.config.wandb_project is not None:
372
- wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/__init__.py DELETED
@@ -1,69 +0,0 @@
1
- import json
2
- import subprocess
3
- import yaml
4
- import os
5
- from .bucketeer import Bucketeer
6
-
7
- class MultiFilter():
8
- def __init__(self, rules, default=False):
9
- self.rules = rules
10
- self.default = default
11
-
12
- def __call__(self, x):
13
- try:
14
- x_json = x['json']
15
- if isinstance(x_json, bytes):
16
- x_json = json.loads(x_json)
17
- validations = []
18
- for k, r in self.rules.items():
19
- if isinstance(k, tuple):
20
- v = r(*[x_json[kv] for kv in k])
21
- else:
22
- v = r(x_json[k])
23
- validations.append(v)
24
- return all(validations)
25
- except Exception:
26
- return False
27
-
28
- class MultiGetter():
29
- def __init__(self, rules):
30
- self.rules = rules
31
-
32
- def __call__(self, x_json):
33
- if isinstance(x_json, bytes):
34
- x_json = json.loads(x_json)
35
- outputs = []
36
- for k, r in self.rules.items():
37
- if isinstance(k, tuple):
38
- v = r(*[x_json[kv] for kv in k])
39
- else:
40
- v = r(x_json[k])
41
- outputs.append(v)
42
- if len(outputs) == 1:
43
- outputs = outputs[0]
44
- return outputs
45
-
46
- def setup_webdataset_path(paths, cache_path=None):
47
- if cache_path is None or not os.path.exists(cache_path):
48
- tar_paths = []
49
- if isinstance(paths, str):
50
- paths = [paths]
51
- for path in paths:
52
- if path.strip().endswith(".tar"):
53
- # Avoid looking up s3 if we already have a tar file
54
- tar_paths.append(path)
55
- continue
56
- bucket = "/".join(path.split("/")[:3])
57
- result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
58
- files = result.stdout.decode('utf-8').split()
59
- files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
60
- tar_paths += files
61
-
62
- with open(cache_path, 'w', encoding='utf-8') as outfile:
63
- yaml.dump(tar_paths, outfile, default_flow_style=False)
64
- else:
65
- with open(cache_path, 'r', encoding='utf-8') as file:
66
- tar_paths = yaml.safe_load(file)
67
-
68
- tar_paths_str = ",".join([f"{p}" for p in tar_paths])
69
- return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/bucketeer.py DELETED
@@ -1,88 +0,0 @@
1
- import torch
2
- import torchvision
3
- import numpy as np
4
- from torchtools.transforms import SmartCrop
5
- import math
6
-
7
- class Bucketeer():
8
- def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
9
- assert crop_mode in ['center', 'random', 'smart']
10
- self.crop_mode = crop_mode
11
- self.ratios = ratios
12
- if reverse_list:
13
- for r in list(ratios):
14
- if 1/r not in self.ratios:
15
- self.ratios.append(1/r)
16
- self.sizes = {}
17
- for dd in density:
18
- self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
19
-
20
- self.batch_size = dataloader.batch_size
21
- self.iterator = iter(dataloader)
22
- all_sizes = []
23
- for k, vs in self.sizes.items():
24
- all_sizes += vs
25
- self.buckets = {s: [] for s in all_sizes}
26
- self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
27
- self.p_random_ratio = p_random_ratio
28
- self.interpolate_nearest = interpolate_nearest
29
-
30
- def get_available_batch(self):
31
- for b in self.buckets:
32
- if len(self.buckets[b]) >= self.batch_size:
33
- batch = self.buckets[b][:self.batch_size]
34
- self.buckets[b] = self.buckets[b][self.batch_size:]
35
- return batch
36
- return None
37
-
38
- def get_closest_size(self, x):
39
- w, h = x.size(-1), x.size(-2)
40
-
41
-
42
- best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
43
- find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
44
- min_ = find_dict[list(find_dict.keys())[0]]
45
- find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
46
- for dd, val in find_dict.items():
47
- if val < min_:
48
- min_ = val
49
- find_size = self.sizes[dd][best_size_idx]
50
-
51
- return find_size
52
-
53
- def get_resize_size(self, orig_size, tgt_size):
54
- if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
55
- alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
56
- resize_size = max(alt_min, min(tgt_size))
57
- else:
58
- alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
59
- resize_size = max(alt_max, max(tgt_size))
60
-
61
- return resize_size
62
-
63
- def __next__(self):
64
- batch = self.get_available_batch()
65
- while batch is None:
66
- elements = next(self.iterator)
67
- for dct in elements:
68
- img = dct['images']
69
- size = self.get_closest_size(img)
70
- resize_size = self.get_resize_size(img.shape[-2:], size)
71
-
72
- if self.interpolate_nearest:
73
- img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
74
- else:
75
- img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
76
- if self.crop_mode == 'center':
77
- img = torchvision.transforms.functional.center_crop(img, size)
78
- elif self.crop_mode == 'random':
79
- img = torchvision.transforms.RandomCrop(size)(img)
80
- elif self.crop_mode == 'smart':
81
- self.smartcrop.output_size = size
82
- img = self.smartcrop(img)
83
-
84
- self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
85
- batch = self.get_available_batch()
86
-
87
- out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
88
- return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/bucketeer_deg.py DELETED
@@ -1,91 +0,0 @@
1
- import torch
2
- import torchvision
3
- import numpy as np
4
- from torchtools.transforms import SmartCrop
5
- import math
6
-
7
- class Bucketeer():
8
- def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
9
- assert crop_mode in ['center', 'random', 'smart']
10
- self.crop_mode = crop_mode
11
- self.ratios = ratios
12
- if reverse_list:
13
- for r in list(ratios):
14
- if 1/r not in self.ratios:
15
- self.ratios.append(1/r)
16
- self.sizes = {}
17
- for dd in density:
18
- self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
19
- print('in line 17 buckteer', self.sizes)
20
- self.batch_size = dataloader.batch_size
21
- self.iterator = iter(dataloader)
22
- all_sizes = []
23
- for k, vs in self.sizes.items():
24
- all_sizes += vs
25
- self.buckets = {s: [] for s in all_sizes}
26
- self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
27
- self.p_random_ratio = p_random_ratio
28
- self.interpolate_nearest = interpolate_nearest
29
-
30
- def get_available_batch(self):
31
- for b in self.buckets:
32
- if len(self.buckets[b]) >= self.batch_size:
33
- batch = self.buckets[b][:self.batch_size]
34
- self.buckets[b] = self.buckets[b][self.batch_size:]
35
- return batch
36
- return None
37
-
38
- def get_closest_size(self, x):
39
- w, h = x.size(-1), x.size(-2)
40
- #if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
41
- # best_size_idx = np.random.randint(len(self.ratios))
42
- #print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio)
43
- #else:
44
-
45
- best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
46
- find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
47
- min_ = find_dict[list(find_dict.keys())[0]]
48
- find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
49
- for dd, val in find_dict.items():
50
- if val < min_:
51
- min_ = val
52
- find_size = self.sizes[dd][best_size_idx]
53
-
54
- return find_size
55
-
56
- def get_resize_size(self, orig_size, tgt_size):
57
- if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
58
- alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
59
- resize_size = max(alt_min, min(tgt_size))
60
- else:
61
- alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
62
- resize_size = max(alt_max, max(tgt_size))
63
- #print('in line 50', orig_size, tgt_size, resize_size)
64
- return resize_size
65
-
66
- def __next__(self):
67
- batch = self.get_available_batch()
68
- while batch is None:
69
- elements = next(self.iterator)
70
- for dct in elements:
71
- img = dct['images']
72
- size = self.get_closest_size(img)
73
- resize_size = self.get_resize_size(img.shape[-2:], size)
74
- #print('in line 74', img.size(), resize_size)
75
- if self.interpolate_nearest:
76
- img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
77
- else:
78
- img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
79
- if self.crop_mode == 'center':
80
- img = torchvision.transforms.functional.center_crop(img, size)
81
- elif self.crop_mode == 'random':
82
- img = torchvision.transforms.RandomCrop(size)(img)
83
- elif self.crop_mode == 'smart':
84
- self.smartcrop.output_size = size
85
- img = self.smartcrop(img)
86
- print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img))
87
- self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
88
- batch = self.get_available_batch()
89
-
90
- out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
91
- return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_alignfaces.py DELETED
@@ -1,263 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Mon Apr 24 15:43:29 2017
4
- @author: zhaoy
5
- """
6
- import cv2
7
- import numpy as np
8
- from skimage import transform as trans
9
-
10
- # reference facial points, a list of coordinates (x,y)
11
- REFERENCE_FACIAL_POINTS = [
12
- [30.29459953, 51.69630051],
13
- [65.53179932, 51.50139999],
14
- [48.02519989, 71.73660278],
15
- [33.54930115, 92.3655014],
16
- [62.72990036, 92.20410156]
17
- ]
18
-
19
- DEFAULT_CROP_SIZE = (96, 112)
20
-
21
-
22
- def _umeyama(src, dst, estimate_scale=True, scale=1.0):
23
- """Estimate N-D similarity transformation with or without scaling.
24
- Parameters
25
- ----------
26
- src : (M, N) array
27
- Source coordinates.
28
- dst : (M, N) array
29
- Destination coordinates.
30
- estimate_scale : bool
31
- Whether to estimate scaling factor.
32
- Returns
33
- -------
34
- T : (N + 1, N + 1)
35
- The homogeneous similarity transformation matrix. The matrix contains
36
- NaN values only if the problem is not well-conditioned.
37
- References
38
- ----------
39
- .. [1] "Least-squares estimation of transformation parameters between two
40
- point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
41
- """
42
-
43
- num = src.shape[0]
44
- dim = src.shape[1]
45
-
46
- # Compute mean of src and dst.
47
- src_mean = src.mean(axis=0)
48
- dst_mean = dst.mean(axis=0)
49
-
50
- # Subtract mean from src and dst.
51
- src_demean = src - src_mean
52
- dst_demean = dst - dst_mean
53
-
54
- # Eq. (38).
55
- A = dst_demean.T @ src_demean / num
56
-
57
- # Eq. (39).
58
- d = np.ones((dim,), dtype=np.double)
59
- if np.linalg.det(A) < 0:
60
- d[dim - 1] = -1
61
-
62
- T = np.eye(dim + 1, dtype=np.double)
63
-
64
- U, S, V = np.linalg.svd(A)
65
-
66
- # Eq. (40) and (43).
67
- rank = np.linalg.matrix_rank(A)
68
- if rank == 0:
69
- return np.nan * T
70
- elif rank == dim - 1:
71
- if np.linalg.det(U) * np.linalg.det(V) > 0:
72
- T[:dim, :dim] = U @ V
73
- else:
74
- s = d[dim - 1]
75
- d[dim - 1] = -1
76
- T[:dim, :dim] = U @ np.diag(d) @ V
77
- d[dim - 1] = s
78
- else:
79
- T[:dim, :dim] = U @ np.diag(d) @ V
80
-
81
- if estimate_scale:
82
- # Eq. (41) and (42).
83
- scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
84
- else:
85
- scale = scale
86
-
87
- T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
88
- T[:dim, :dim] *= scale
89
-
90
- return T, scale
91
-
92
-
93
- class FaceWarpException(Exception):
94
- def __str__(self):
95
- return 'In File {}:{}'.format(
96
- __file__, super.__str__(self))
97
-
98
-
99
- def get_reference_facial_points(output_size=None,
100
- inner_padding_factor=0.0,
101
- outer_padding=(0, 0),
102
- default_square=False):
103
- tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
104
- tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
105
-
106
- # 0) make the inner region a square
107
- if default_square:
108
- size_diff = max(tmp_crop_size) - tmp_crop_size
109
- tmp_5pts += size_diff / 2
110
- tmp_crop_size += size_diff
111
-
112
- if (output_size and
113
- output_size[0] == tmp_crop_size[0] and
114
- output_size[1] == tmp_crop_size[1]):
115
- print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
116
- return tmp_5pts
117
-
118
- if (inner_padding_factor == 0 and
119
- outer_padding == (0, 0)):
120
- if output_size is None:
121
- print('No paddings to do: return default reference points')
122
- return tmp_5pts
123
- else:
124
- raise FaceWarpException(
125
- 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
126
-
127
- # check output size
128
- if not (0 <= inner_padding_factor <= 1.0):
129
- raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
130
-
131
- if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
132
- and output_size is None):
133
- output_size = tmp_crop_size * \
134
- (1 + inner_padding_factor * 2).astype(np.int32)
135
- output_size += np.array(outer_padding)
136
- print(' deduced from paddings, output_size = ', output_size)
137
-
138
- if not (outer_padding[0] < output_size[0]
139
- and outer_padding[1] < output_size[1]):
140
- raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
141
- 'and outer_padding[1] < output_size[1])')
142
-
143
- # 1) pad the inner region according inner_padding_factor
144
- # print('---> STEP1: pad the inner region according inner_padding_factor')
145
- if inner_padding_factor > 0:
146
- size_diff = tmp_crop_size * inner_padding_factor * 2
147
- tmp_5pts += size_diff / 2
148
- tmp_crop_size += np.round(size_diff).astype(np.int32)
149
-
150
- # print(' crop_size = ', tmp_crop_size)
151
- # print(' reference_5pts = ', tmp_5pts)
152
-
153
- # 2) resize the padded inner region
154
- # print('---> STEP2: resize the padded inner region')
155
- size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
156
- # print(' crop_size = ', tmp_crop_size)
157
- # print(' size_bf_outer_pad = ', size_bf_outer_pad)
158
-
159
- if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
160
- raise FaceWarpException('Must have (output_size - outer_padding)'
161
- '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
162
-
163
- scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
164
- # print(' resize scale_factor = ', scale_factor)
165
- tmp_5pts = tmp_5pts * scale_factor
166
- # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
167
- # tmp_5pts = tmp_5pts + size_diff / 2
168
- tmp_crop_size = size_bf_outer_pad
169
- # print(' crop_size = ', tmp_crop_size)
170
- # print(' reference_5pts = ', tmp_5pts)
171
-
172
- # 3) add outer_padding to make output_size
173
- reference_5point = tmp_5pts + np.array(outer_padding)
174
- tmp_crop_size = output_size
175
- # print('---> STEP3: add outer_padding to make output_size')
176
- # print(' crop_size = ', tmp_crop_size)
177
- # print(' reference_5pts = ', tmp_5pts)
178
- #
179
- # print('===> end get_reference_facial_points\n')
180
-
181
- return reference_5point
182
-
183
-
184
- def get_affine_transform_matrix(src_pts, dst_pts):
185
- tfm = np.float32([[1, 0, 0], [0, 1, 0]])
186
- n_pts = src_pts.shape[0]
187
- ones = np.ones((n_pts, 1), src_pts.dtype)
188
- src_pts_ = np.hstack([src_pts, ones])
189
- dst_pts_ = np.hstack([dst_pts, ones])
190
-
191
- A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
192
-
193
- if rank == 3:
194
- tfm = np.float32([
195
- [A[0, 0], A[1, 0], A[2, 0]],
196
- [A[0, 1], A[1, 1], A[2, 1]]
197
- ])
198
- elif rank == 2:
199
- tfm = np.float32([
200
- [A[0, 0], A[1, 0], 0],
201
- [A[0, 1], A[1, 1], 0]
202
- ])
203
-
204
- return tfm
205
-
206
-
207
- def warp_and_crop_face(src_img,
208
- facial_pts,
209
- reference_pts=None,
210
- crop_size=(96, 112),
211
- align_type='smilarity'): #smilarity cv2_affine affine
212
- if reference_pts is None:
213
- if crop_size[0] == 96 and crop_size[1] == 112:
214
- reference_pts = REFERENCE_FACIAL_POINTS
215
- else:
216
- default_square = False
217
- inner_padding_factor = 0
218
- outer_padding = (0, 0)
219
- output_size = crop_size
220
-
221
- reference_pts = get_reference_facial_points(output_size,
222
- inner_padding_factor,
223
- outer_padding,
224
- default_square)
225
-
226
- ref_pts = np.float32(reference_pts)
227
- ref_pts_shp = ref_pts.shape
228
- if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
229
- raise FaceWarpException(
230
- 'reference_pts.shape must be (K,2) or (2,K) and K>2')
231
-
232
- if ref_pts_shp[0] == 2:
233
- ref_pts = ref_pts.T
234
-
235
- src_pts = np.float32(facial_pts)
236
- src_pts_shp = src_pts.shape
237
- if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
238
- raise FaceWarpException(
239
- 'facial_pts.shape must be (K,2) or (2,K) and K>2')
240
-
241
- if src_pts_shp[0] == 2:
242
- src_pts = src_pts.T
243
-
244
- if src_pts.shape != ref_pts.shape:
245
- raise FaceWarpException(
246
- 'facial_pts and reference_pts must have the same shape')
247
-
248
- if align_type is 'cv2_affine':
249
- tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
250
- tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
251
- elif align_type is 'affine':
252
- tfm = get_affine_transform_matrix(src_pts, ref_pts)
253
- tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
254
- else:
255
- params, scale = _umeyama(src_pts, ref_pts)
256
- tfm = params[:2, :]
257
-
258
- params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale)
259
- tfm_inv = params[:2, :]
260
-
261
- face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)
262
-
263
- return face_img, tfm_inv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_blindsr.py DELETED
@@ -1,631 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- import numpy as np
3
- import cv2
4
- import torch
5
-
6
- from core.data.deg_kair_utils import utils_image as util
7
-
8
- import random
9
- from scipy import ndimage
10
- import scipy
11
- import scipy.stats as ss
12
- from scipy.interpolate import interp2d
13
- from scipy.linalg import orth
14
-
15
-
16
-
17
-
18
- """
19
- # --------------------------------------------
20
- # Super-Resolution
21
- # --------------------------------------------
22
- #
23
- # Kai Zhang (cskaizhang@gmail.com)
24
- # https://github.com/cszn
25
- # From 2019/03--2021/08
26
- # --------------------------------------------
27
- """
28
-
29
- def modcrop_np(img, sf):
30
- '''
31
- Args:
32
- img: numpy image, WxH or WxHxC
33
- sf: scale factor
34
-
35
- Return:
36
- cropped image
37
- '''
38
- w, h = img.shape[:2]
39
- im = np.copy(img)
40
- return im[:w - w % sf, :h - h % sf, ...]
41
-
42
-
43
- """
44
- # --------------------------------------------
45
- # anisotropic Gaussian kernels
46
- # --------------------------------------------
47
- """
48
- def analytic_kernel(k):
49
- """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
50
- k_size = k.shape[0]
51
- # Calculate the big kernels size
52
- big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
53
- # Loop over the small kernel to fill the big one
54
- for r in range(k_size):
55
- for c in range(k_size):
56
- big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
57
- # Crop the edges of the big kernel to ignore very small values and increase run time of SR
58
- crop = k_size // 2
59
- cropped_big_k = big_k[crop:-crop, crop:-crop]
60
- # Normalize to 1
61
- return cropped_big_k / cropped_big_k.sum()
62
-
63
-
64
- def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
65
- """ generate an anisotropic Gaussian kernel
66
- Args:
67
- ksize : e.g., 15, kernel size
68
- theta : [0, pi], rotation angle range
69
- l1 : [0.1,50], scaling of eigenvalues
70
- l2 : [0.1,l1], scaling of eigenvalues
71
- If l1 = l2, will get an isotropic Gaussian kernel.
72
-
73
- Returns:
74
- k : kernel
75
- """
76
-
77
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
78
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
79
- D = np.array([[l1, 0], [0, l2]])
80
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
81
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
82
-
83
- return k
84
-
85
-
86
- def gm_blur_kernel(mean, cov, size=15):
87
- center = size / 2.0 + 0.5
88
- k = np.zeros([size, size])
89
- for y in range(size):
90
- for x in range(size):
91
- cy = y - center + 1
92
- cx = x - center + 1
93
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
94
-
95
- k = k / np.sum(k)
96
- return k
97
-
98
-
99
- def shift_pixel(x, sf, upper_left=True):
100
- """shift pixel for super-resolution with different scale factors
101
- Args:
102
- x: WxHxC or WxH
103
- sf: scale factor
104
- upper_left: shift direction
105
- """
106
- h, w = x.shape[:2]
107
- shift = (sf-1)*0.5
108
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
109
- if upper_left:
110
- x1 = xv + shift
111
- y1 = yv + shift
112
- else:
113
- x1 = xv - shift
114
- y1 = yv - shift
115
-
116
- x1 = np.clip(x1, 0, w-1)
117
- y1 = np.clip(y1, 0, h-1)
118
-
119
- if x.ndim == 2:
120
- x = interp2d(xv, yv, x)(x1, y1)
121
- if x.ndim == 3:
122
- for i in range(x.shape[-1]):
123
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
124
-
125
- return x
126
-
127
-
128
- def blur(x, k):
129
- '''
130
- x: image, NxcxHxW
131
- k: kernel, Nx1xhxw
132
- '''
133
- n, c = x.shape[:2]
134
- p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2
135
- x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
136
- k = k.repeat(1,c,1,1)
137
- k = k.view(-1, 1, k.shape[2], k.shape[3])
138
- x = x.view(1, -1, x.shape[2], x.shape[3])
139
- x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c)
140
- x = x.view(n, c, x.shape[2], x.shape[3])
141
-
142
- return x
143
-
144
-
145
-
146
- def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
147
- """"
148
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
149
- # Kai Zhang
150
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
151
- # max_var = 2.5 * sf
152
- """
153
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
154
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
155
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
156
- theta = np.random.rand() * np.pi # random theta
157
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
158
-
159
- # Set COV matrix using Lambdas and Theta
160
- LAMBDA = np.diag([lambda_1, lambda_2])
161
- Q = np.array([[np.cos(theta), -np.sin(theta)],
162
- [np.sin(theta), np.cos(theta)]])
163
- SIGMA = Q @ LAMBDA @ Q.T
164
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
165
-
166
- # Set expectation position (shifting kernel for aligned image)
167
- MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
168
- MU = MU[None, None, :, None]
169
-
170
- # Create meshgrid for Gaussian
171
- [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
172
- Z = np.stack([X, Y], 2)[:, :, :, None]
173
-
174
- # Calcualte Gaussian for every pixel of the kernel
175
- ZZ = Z-MU
176
- ZZ_t = ZZ.transpose(0,1,3,2)
177
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
178
-
179
- # shift the kernel so it will be centered
180
- #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
181
-
182
- # Normalize the kernel and return
183
- #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
184
- kernel = raw_kernel / np.sum(raw_kernel)
185
- return kernel
186
-
187
-
188
- def fspecial_gaussian(hsize, sigma):
189
- hsize = [hsize, hsize]
190
- siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
191
- std = sigma
192
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
193
- arg = -(x*x + y*y)/(2*std*std)
194
- h = np.exp(arg)
195
- h[h < scipy.finfo(float).eps * h.max()] = 0
196
- sumh = h.sum()
197
- if sumh != 0:
198
- h = h/sumh
199
- return h
200
-
201
-
202
- def fspecial_laplacian(alpha):
203
- alpha = max([0, min([alpha,1])])
204
- h1 = alpha/(alpha+1)
205
- h2 = (1-alpha)/(alpha+1)
206
- h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
207
- h = np.array(h)
208
- return h
209
-
210
-
211
- def fspecial(filter_type, *args, **kwargs):
212
- '''
213
- python code from:
214
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
215
- '''
216
- if filter_type == 'gaussian':
217
- return fspecial_gaussian(*args, **kwargs)
218
- if filter_type == 'laplacian':
219
- return fspecial_laplacian(*args, **kwargs)
220
-
221
- """
222
- # --------------------------------------------
223
- # degradation models
224
- # --------------------------------------------
225
- """
226
-
227
-
228
- def bicubic_degradation(x, sf=3):
229
- '''
230
- Args:
231
- x: HxWxC image, [0, 1]
232
- sf: down-scale factor
233
-
234
- Return:
235
- bicubicly downsampled LR image
236
- '''
237
- x = util.imresize_np(x, scale=1/sf)
238
- return x
239
-
240
-
241
- def srmd_degradation(x, k, sf=3):
242
- ''' blur + bicubic downsampling
243
-
244
- Args:
245
- x: HxWxC image, [0, 1]
246
- k: hxw, double
247
- sf: down-scale factor
248
-
249
- Return:
250
- downsampled LR image
251
-
252
- Reference:
253
- @inproceedings{zhang2018learning,
254
- title={Learning a single convolutional super-resolution network for multiple degradations},
255
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
256
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
257
- pages={3262--3271},
258
- year={2018}
259
- }
260
- '''
261
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
262
- x = bicubic_degradation(x, sf=sf)
263
- return x
264
-
265
-
266
- def dpsr_degradation(x, k, sf=3):
267
-
268
- ''' bicubic downsampling + blur
269
-
270
- Args:
271
- x: HxWxC image, [0, 1]
272
- k: hxw, double
273
- sf: down-scale factor
274
-
275
- Return:
276
- downsampled LR image
277
-
278
- Reference:
279
- @inproceedings{zhang2019deep,
280
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
281
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
282
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
283
- pages={1671--1681},
284
- year={2019}
285
- }
286
- '''
287
- x = bicubic_degradation(x, sf=sf)
288
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
289
- return x
290
-
291
-
292
- def classical_degradation(x, k, sf=3):
293
- ''' blur + downsampling
294
-
295
- Args:
296
- x: HxWxC image, [0, 1]/[0, 255]
297
- k: hxw, double
298
- sf: down-scale factor
299
-
300
- Return:
301
- downsampled LR image
302
- '''
303
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
304
- #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
305
- st = 0
306
- return x[st::sf, st::sf, ...]
307
-
308
-
309
- def add_sharpening(img, weight=0.5, radius=50, threshold=10):
310
- """USM sharpening. borrowed from real-ESRGAN
311
- Input image: I; Blurry image: B.
312
- 1. K = I + weight * (I - B)
313
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
314
- 3. Blur mask:
315
- 4. Out = Mask * K + (1 - Mask) * I
316
- Args:
317
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
318
- weight (float): Sharp weight. Default: 1.
319
- radius (float): Kernel size of Gaussian blur. Default: 50.
320
- threshold (int):
321
- """
322
- if radius % 2 == 0:
323
- radius += 1
324
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
325
- residual = img - blur
326
- mask = np.abs(residual) * 255 > threshold
327
- mask = mask.astype('float32')
328
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
329
-
330
- K = img + weight * residual
331
- K = np.clip(K, 0, 1)
332
- return soft_mask * K + (1 - soft_mask) * img
333
-
334
-
335
- def add_blur(img, sf=4):
336
- wd2 = 4.0 + sf
337
- wd = 2.0 + 0.2*sf
338
- if random.random() < 0.5:
339
- l1 = wd2*random.random()
340
- l2 = wd2*random.random()
341
- k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2)
342
- else:
343
- k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random())
344
- img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
345
-
346
- return img
347
-
348
-
349
- def add_resize(img, sf=4):
350
- rnum = np.random.rand()
351
- if rnum > 0.8: # up
352
- sf1 = random.uniform(1, 2)
353
- elif rnum < 0.7: # down
354
- sf1 = random.uniform(0.5/sf, 1)
355
- else:
356
- sf1 = 1.0
357
- img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3]))
358
- img = np.clip(img, 0.0, 1.0)
359
-
360
- return img
361
-
362
-
363
- def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
364
- noise_level = random.randint(noise_level1, noise_level2)
365
- rnum = np.random.rand()
366
- if rnum > 0.6: # add color Gaussian noise
367
- img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
368
- elif rnum < 0.4: # add grayscale Gaussian noise
369
- img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
370
- else: # add noise
371
- L = noise_level2/255.
372
- D = np.diag(np.random.rand(3))
373
- U = orth(np.random.rand(3,3))
374
- conv = np.dot(np.dot(np.transpose(U), D), U)
375
- img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
376
- img = np.clip(img, 0.0, 1.0)
377
- return img
378
-
379
-
380
- def add_speckle_noise(img, noise_level1=2, noise_level2=25):
381
- noise_level = random.randint(noise_level1, noise_level2)
382
- img = np.clip(img, 0.0, 1.0)
383
- rnum = random.random()
384
- if rnum > 0.6:
385
- img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
386
- elif rnum < 0.4:
387
- img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
388
- else:
389
- L = noise_level2/255.
390
- D = np.diag(np.random.rand(3))
391
- U = orth(np.random.rand(3,3))
392
- conv = np.dot(np.dot(np.transpose(U), D), U)
393
- img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
394
- img = np.clip(img, 0.0, 1.0)
395
- return img
396
-
397
-
398
- def add_Poisson_noise(img):
399
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
400
- vals = 10**(2*random.random()+2.0) # [2, 4]
401
- if random.random() < 0.5:
402
- img = np.random.poisson(img * vals).astype(np.float32) / vals
403
- else:
404
- img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114])
405
- img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
406
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
407
- img += noise_gray[:, :, np.newaxis]
408
- img = np.clip(img, 0.0, 1.0)
409
- return img
410
-
411
-
412
- def add_JPEG_noise(img):
413
- quality_factor = random.randint(30, 95)
414
- img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
415
- result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
416
- img = cv2.imdecode(encimg, 1)
417
- img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
418
- return img
419
-
420
-
421
- def random_crop(lq, hq, sf=4, lq_patchsize=64):
422
- h, w = lq.shape[:2]
423
- rnd_h = random.randint(0, h-lq_patchsize)
424
- rnd_w = random.randint(0, w-lq_patchsize)
425
- lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
426
-
427
- rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
428
- hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :]
429
- return lq, hq
430
-
431
-
432
- def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
433
- """
434
- This is the degradation model of BSRGAN from the paper
435
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
436
- ----------
437
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
438
- sf: scale factor
439
- isp_model: camera ISP model
440
-
441
- Returns
442
- -------
443
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
444
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
445
- """
446
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
447
- sf_ori = sf
448
-
449
- h1, w1 = img.shape[:2]
450
- img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
451
- h, w = img.shape[:2]
452
-
453
- if h < lq_patchsize*sf or w < lq_patchsize*sf:
454
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
455
-
456
- hq = img.copy()
457
-
458
- if sf == 4 and random.random() < scale2_prob: # downsample1
459
- if np.random.rand() < 0.5:
460
- img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3]))
461
- else:
462
- img = util.imresize_np(img, 1/2, True)
463
- img = np.clip(img, 0.0, 1.0)
464
- sf = 2
465
-
466
- shuffle_order = random.sample(range(7), 7)
467
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
468
- if idx1 > idx2: # keep downsample3 last
469
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
470
-
471
- for i in shuffle_order:
472
-
473
- if i == 0:
474
- img = add_blur(img, sf=sf)
475
-
476
- elif i == 1:
477
- img = add_blur(img, sf=sf)
478
-
479
- elif i == 2:
480
- a, b = img.shape[1], img.shape[0]
481
- # downsample2
482
- if random.random() < 0.75:
483
- sf1 = random.uniform(1,2*sf)
484
- img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3]))
485
- else:
486
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf))
487
- k_shifted = shift_pixel(k, sf)
488
- k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel
489
- img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
490
- img = img[0::sf, 0::sf, ...] # nearest downsampling
491
- img = np.clip(img, 0.0, 1.0)
492
-
493
- elif i == 3:
494
- # downsample3
495
- img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3]))
496
- img = np.clip(img, 0.0, 1.0)
497
-
498
- elif i == 4:
499
- # add Gaussian noise
500
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
501
-
502
- elif i == 5:
503
- # add JPEG noise
504
- if random.random() < jpeg_prob:
505
- img = add_JPEG_noise(img)
506
-
507
- elif i == 6:
508
- # add processed camera sensor noise
509
- if random.random() < isp_prob and isp_model is not None:
510
- with torch.no_grad():
511
- img, hq = isp_model.forward(img.copy(), hq)
512
-
513
- # add final JPEG compression noise
514
- img = add_JPEG_noise(img)
515
-
516
- # random crop
517
- img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
518
-
519
- return img, hq
520
-
521
-
522
-
523
-
524
- def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
525
- """
526
- This is an extended degradation model by combining
527
- the degradation models of BSRGAN and Real-ESRGAN
528
- ----------
529
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
530
- sf: scale factor
531
- use_shuffle: the degradation shuffle
532
- use_sharp: sharpening the img
533
-
534
- Returns
535
- -------
536
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
537
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
538
- """
539
-
540
- h1, w1 = img.shape[:2]
541
- img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
542
- h, w = img.shape[:2]
543
-
544
- if h < lq_patchsize*sf or w < lq_patchsize*sf:
545
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
546
-
547
- if use_sharp:
548
- img = add_sharpening(img)
549
- hq = img.copy()
550
-
551
- if random.random() < shuffle_prob:
552
- shuffle_order = random.sample(range(13), 13)
553
- else:
554
- shuffle_order = list(range(13))
555
- # local shuffle for noise, JPEG is always the last one
556
- shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
557
- shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
558
-
559
- poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
560
-
561
- for i in shuffle_order:
562
- if i == 0:
563
- img = add_blur(img, sf=sf)
564
- elif i == 1:
565
- img = add_resize(img, sf=sf)
566
- elif i == 2:
567
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
568
- elif i == 3:
569
- if random.random() < poisson_prob:
570
- img = add_Poisson_noise(img)
571
- elif i == 4:
572
- if random.random() < speckle_prob:
573
- img = add_speckle_noise(img)
574
- elif i == 5:
575
- if random.random() < isp_prob and isp_model is not None:
576
- with torch.no_grad():
577
- img, hq = isp_model.forward(img.copy(), hq)
578
- elif i == 6:
579
- img = add_JPEG_noise(img)
580
- elif i == 7:
581
- img = add_blur(img, sf=sf)
582
- elif i == 8:
583
- img = add_resize(img, sf=sf)
584
- elif i == 9:
585
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
586
- elif i == 10:
587
- if random.random() < poisson_prob:
588
- img = add_Poisson_noise(img)
589
- elif i == 11:
590
- if random.random() < speckle_prob:
591
- img = add_speckle_noise(img)
592
- elif i == 12:
593
- if random.random() < isp_prob and isp_model is not None:
594
- with torch.no_grad():
595
- img, hq = isp_model.forward(img.copy(), hq)
596
- else:
597
- print('check the shuffle!')
598
-
599
- # resize to desired size
600
- img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3]))
601
-
602
- # add final JPEG compression noise
603
- img = add_JPEG_noise(img)
604
-
605
- # random crop
606
- img, hq = random_crop(img, hq, sf, lq_patchsize)
607
-
608
- return img, hq
609
-
610
-
611
-
612
- if __name__ == '__main__':
613
- img = util.imread_uint('utils/test.png', 3)
614
- img = util.uint2single(img)
615
- sf = 4
616
-
617
- for i in range(20):
618
- img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
619
- print(i)
620
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
621
- img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
622
- util.imsave(img_concat, str(i)+'.png')
623
-
624
- # for i in range(10):
625
- # img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
626
- # print(i)
627
- # lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
628
- # img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
629
- # util.imsave(img_concat, str(i)+'.png')
630
-
631
- # run utils/utils_blindsr.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_bnorm.py DELETED
@@ -1,91 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- """
6
- # --------------------------------------------
7
- # Batch Normalization
8
- # --------------------------------------------
9
-
10
- # Kai Zhang (cskaizhang@gmail.com)
11
- # https://github.com/cszn
12
- # 01/Jan/2019
13
- # --------------------------------------------
14
- """
15
-
16
-
17
- # --------------------------------------------
18
- # remove/delete specified layer
19
- # --------------------------------------------
20
- def deleteLayer(model, layer_type=nn.BatchNorm2d):
21
- ''' Kai Zhang, 11/Jan/2019.
22
- '''
23
- for k, m in list(model.named_children()):
24
- if isinstance(m, layer_type):
25
- del model._modules[k]
26
- deleteLayer(m, layer_type)
27
-
28
-
29
- # --------------------------------------------
30
- # merge bn, "conv+bn" --> "conv"
31
- # --------------------------------------------
32
- def merge_bn(model):
33
- ''' Kai Zhang, 11/Jan/2019.
34
- merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv')
35
- based on https://github.com/pytorch/pytorch/pull/901
36
- '''
37
- prev_m = None
38
- for k, m in list(model.named_children()):
39
- if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)):
40
-
41
- w = prev_m.weight.data
42
-
43
- if prev_m.bias is None:
44
- zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type())
45
- prev_m.bias = nn.Parameter(zeros)
46
- b = prev_m.bias.data
47
-
48
- invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
49
- if isinstance(prev_m, nn.ConvTranspose2d):
50
- w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w))
51
- else:
52
- w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
53
- b.add_(-m.running_mean).mul_(invstd)
54
- if m.affine:
55
- if isinstance(prev_m, nn.ConvTranspose2d):
56
- w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w))
57
- else:
58
- w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
59
- b.mul_(m.weight.data).add_(m.bias.data)
60
-
61
- del model._modules[k]
62
- prev_m = m
63
- merge_bn(m)
64
-
65
-
66
- # --------------------------------------------
67
- # add bn, "conv" --> "conv+bn"
68
- # --------------------------------------------
69
- def add_bn(model):
70
- ''' Kai Zhang, 11/Jan/2019.
71
- '''
72
- for k, m in list(model.named_children()):
73
- if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)):
74
- b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
75
- b.weight.data.fill_(1)
76
- new_m = nn.Sequential(model._modules[k], b)
77
- model._modules[k] = new_m
78
- add_bn(m)
79
-
80
-
81
- # --------------------------------------------
82
- # tidy model after removing bn
83
- # --------------------------------------------
84
- def tidy_sequential(model):
85
- ''' Kai Zhang, 11/Jan/2019.
86
- '''
87
- for k, m in list(model.named_children()):
88
- if isinstance(m, nn.Sequential):
89
- if m.__len__() == 1:
90
- model._modules[k] = m.__getitem__(0)
91
- tidy_sequential(m)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_deblur.py DELETED
@@ -1,655 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- import numpy as np
3
- import scipy
4
- from scipy import fftpack
5
- import torch
6
-
7
- from math import cos, sin
8
- from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
9
- from numpy.random import randn, rand
10
- from scipy.signal import convolve2d
11
- import cv2
12
- import random
13
- # import utils_image as util
14
-
15
- '''
16
- modified by Kai Zhang (github: https://github.com/cszn)
17
- 03/03/2019
18
- '''
19
-
20
-
21
- def get_uperleft_denominator(img, kernel):
22
- '''
23
- img: HxWxC
24
- kernel: hxw
25
- denominator: HxWx1
26
- upperleft: HxWxC
27
- '''
28
- V = psf2otf(kernel, img.shape[:2])
29
- denominator = np.expand_dims(np.abs(V)**2, axis=2)
30
- upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
31
- return upperleft, denominator
32
-
33
-
34
- def get_uperleft_denominator_pytorch(img, kernel):
35
- '''
36
- img: NxCxHxW
37
- kernel: Nx1xhxw
38
- denominator: Nx1xHxW
39
- upperleft: NxCxHxWx2
40
- '''
41
- V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2
42
- denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW
43
- upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2
44
- return upperleft, denominator
45
-
46
-
47
- def c2c(x):
48
- return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
49
-
50
-
51
- def r2c(x):
52
- return torch.stack([x, torch.zeros_like(x)], -1)
53
-
54
-
55
- def cdiv(x, y):
56
- a, b = x[..., 0], x[..., 1]
57
- c, d = y[..., 0], y[..., 1]
58
- cd2 = c**2 + d**2
59
- return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
60
-
61
-
62
- def cabs(x):
63
- return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
64
-
65
-
66
- def cmul(t1, t2):
67
- '''
68
- complex multiplication
69
- t1: NxCxHxWx2
70
- output: NxCxHxWx2
71
- '''
72
- real1, imag1 = t1[..., 0], t1[..., 1]
73
- real2, imag2 = t2[..., 0], t2[..., 1]
74
- return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
75
-
76
-
77
- def cconj(t, inplace=False):
78
- '''
79
- # complex's conjugation
80
- t: NxCxHxWx2
81
- output: NxCxHxWx2
82
- '''
83
- c = t.clone() if not inplace else t
84
- c[..., 1] *= -1
85
- return c
86
-
87
-
88
- def rfft(t):
89
- return torch.rfft(t, 2, onesided=False)
90
-
91
-
92
- def irfft(t):
93
- return torch.irfft(t, 2, onesided=False)
94
-
95
-
96
- def fft(t):
97
- return torch.fft(t, 2)
98
-
99
-
100
- def ifft(t):
101
- return torch.ifft(t, 2)
102
-
103
-
104
- def p2o(psf, shape):
105
- '''
106
- # psf: NxCxhxw
107
- # shape: [H,W]
108
- # otf: NxCxHxWx2
109
- '''
110
- otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
111
- otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
112
- for axis, axis_size in enumerate(psf.shape[2:]):
113
- otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
114
- otf = torch.rfft(otf, 2, onesided=False)
115
- n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
116
- otf[...,1][torch.abs(otf[...,1])<n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
117
- return otf
118
-
119
-
120
-
121
- # otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math.
122
- def otf2psf(otf, outsize=None):
123
- insize = np.array(otf.shape)
124
- psf = np.fft.ifftn(otf, axes=(0, 1))
125
- for axis, axis_size in enumerate(insize):
126
- psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
127
- if type(outsize) != type(None):
128
- insize = np.array(otf.shape)
129
- outsize = np.array(outsize)
130
- n = max(np.size(outsize), np.size(insize))
131
- # outsize = postpad(outsize(:), n, 1);
132
- # insize = postpad(insize(:) , n, 1);
133
- colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
134
- colvec_in = insize.flatten().reshape((np.size(insize), 1))
135
- outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
136
- insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")
137
-
138
- pad = (insize - outsize) / 2
139
- if np.any(pad < 0):
140
- print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
141
- prepad = np.floor(pad)
142
- postpad = np.ceil(pad)
143
- dims_start = prepad.astype(int)
144
- dims_end = (insize - postpad).astype(int)
145
- for i in range(len(dims_start.shape)):
146
- psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
147
- n_ops = np.sum(otf.size * np.log2(otf.shape))
148
- psf = np.real_if_close(psf, tol=n_ops)
149
- return psf
150
-
151
-
152
- # psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py
153
- def psf2otf(psf, shape=None):
154
- """
155
- Convert point-spread function to optical transfer function.
156
- Compute the Fast Fourier Transform (FFT) of the point-spread
157
- function (PSF) array and creates the optical transfer function (OTF)
158
- array that is not influenced by the PSF off-centering.
159
- By default, the OTF array is the same size as the PSF array.
160
- To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
161
- post-pads the PSF array (down or to the right) with zeros to match
162
- dimensions specified in OUTSIZE, then circularly shifts the values of
163
- the PSF array up (or to the left) until the central pixel reaches (1,1)
164
- position.
165
- Parameters
166
- ----------
167
- psf : `numpy.ndarray`
168
- PSF array
169
- shape : int
170
- Output shape of the OTF array
171
- Returns
172
- -------
173
- otf : `numpy.ndarray`
174
- OTF array
175
- Notes
176
- -----
177
- Adapted from MATLAB psf2otf function
178
- """
179
- if type(shape) == type(None):
180
- shape = psf.shape
181
- shape = np.array(shape)
182
- if np.all(psf == 0):
183
- # return np.zeros_like(psf)
184
- return np.zeros(shape)
185
- if len(psf.shape) == 1:
186
- psf = psf.reshape((1, psf.shape[0]))
187
- inshape = psf.shape
188
- psf = zero_pad(psf, shape, position='corner')
189
- for axis, axis_size in enumerate(inshape):
190
- psf = np.roll(psf, -int(axis_size / 2), axis=axis)
191
- # Compute the OTF
192
- otf = np.fft.fft2(psf, axes=(0, 1))
193
- # Estimate the rough number of operations involved in the FFT
194
- # and discard the PSF imaginary part if within roundoff error
195
- # roundoff error = machine epsilon = sys.float_info.epsilon
196
- # or np.finfo().eps
197
- n_ops = np.sum(psf.size * np.log2(psf.shape))
198
- otf = np.real_if_close(otf, tol=n_ops)
199
- return otf
200
-
201
-
202
- def zero_pad(image, shape, position='corner'):
203
- """
204
- Extends image to a certain size with zeros
205
- Parameters
206
- ----------
207
- image: real 2d `numpy.ndarray`
208
- Input image
209
- shape: tuple of int
210
- Desired output shape of the image
211
- position : str, optional
212
- The position of the input image in the output one:
213
- * 'corner'
214
- top-left corner (default)
215
- * 'center'
216
- centered
217
- Returns
218
- -------
219
- padded_img: real `numpy.ndarray`
220
- The zero-padded image
221
- """
222
- shape = np.asarray(shape, dtype=int)
223
- imshape = np.asarray(image.shape, dtype=int)
224
- if np.alltrue(imshape == shape):
225
- return image
226
- if np.any(shape <= 0):
227
- raise ValueError("ZERO_PAD: null or negative shape given")
228
- dshape = shape - imshape
229
- if np.any(dshape < 0):
230
- raise ValueError("ZERO_PAD: target size smaller than source one")
231
- pad_img = np.zeros(shape, dtype=image.dtype)
232
- idx, idy = np.indices(imshape)
233
- if position == 'center':
234
- if np.any(dshape % 2 != 0):
235
- raise ValueError("ZERO_PAD: source and target shapes "
236
- "have different parity.")
237
- offx, offy = dshape // 2
238
- else:
239
- offx, offy = (0, 0)
240
- pad_img[idx + offx, idy + offy] = image
241
- return pad_img
242
-
243
-
244
- '''
245
- Reducing boundary artifacts
246
- '''
247
-
248
-
249
- def opt_fft_size(n):
250
- '''
251
- Kai Zhang (github: https://github.com/cszn)
252
- 03/03/2019
253
- # opt_fft_size.m
254
- # compute an optimal data length for Fourier transforms
255
- # written by Sunghyun Cho (sodomau@postech.ac.kr)
256
- # persistent opt_fft_size_LUT;
257
- '''
258
-
259
- LUT_size = 2048
260
- # print("generate opt_fft_size_LUT")
261
- opt_fft_size_LUT = np.zeros(LUT_size)
262
-
263
- e2 = 1
264
- while e2 <= LUT_size:
265
- e3 = e2
266
- while e3 <= LUT_size:
267
- e5 = e3
268
- while e5 <= LUT_size:
269
- e7 = e5
270
- while e7 <= LUT_size:
271
- if e7 <= LUT_size:
272
- opt_fft_size_LUT[e7-1] = e7
273
- if e7*11 <= LUT_size:
274
- opt_fft_size_LUT[e7*11-1] = e7*11
275
- if e7*13 <= LUT_size:
276
- opt_fft_size_LUT[e7*13-1] = e7*13
277
- e7 = e7 * 7
278
- e5 = e5 * 5
279
- e3 = e3 * 3
280
- e2 = e2 * 2
281
-
282
- nn = 0
283
- for i in range(LUT_size, 0, -1):
284
- if opt_fft_size_LUT[i-1] != 0:
285
- nn = i-1
286
- else:
287
- opt_fft_size_LUT[i-1] = nn+1
288
-
289
- m = np.zeros(len(n))
290
- for c in range(len(n)):
291
- nn = n[c]
292
- if nn <= LUT_size:
293
- m[c] = opt_fft_size_LUT[nn-1]
294
- else:
295
- m[c] = -1
296
- return m
297
-
298
-
299
- def wrap_boundary_liu(img, img_size):
300
-
301
- """
302
- Reducing boundary artifacts in image deconvolution
303
- Renting Liu, Jiaya Jia
304
- ICIP 2008
305
- """
306
- if img.ndim == 2:
307
- ret = wrap_boundary(img, img_size)
308
- elif img.ndim == 3:
309
- ret = [wrap_boundary(img[:, :, i], img_size) for i in range(3)]
310
- ret = np.stack(ret, 2)
311
- return ret
312
-
313
-
314
- def wrap_boundary(img, img_size):
315
-
316
- """
317
- python code from:
318
- https://github.com/ys-koshelev/nla_deblur/blob/90fe0ab98c26c791dcbdf231fe6f938fca80e2a0/boundaries.py
319
- Reducing boundary artifacts in image deconvolution
320
- Renting Liu, Jiaya Jia
321
- ICIP 2008
322
- """
323
- (H, W) = np.shape(img)
324
- H_w = int(img_size[0]) - H
325
- W_w = int(img_size[1]) - W
326
-
327
- # ret = np.zeros((img_size[0], img_size[1]));
328
- alpha = 1
329
- HG = img[:, :]
330
-
331
- r_A = np.zeros((alpha*2+H_w, W))
332
- r_A[:alpha, :] = HG[-alpha:, :]
333
- r_A[-alpha:, :] = HG[:alpha, :]
334
- a = np.arange(H_w)/(H_w-1)
335
- # r_A(alpha+1:end-alpha, 1) = (1-a)*r_A(alpha,1) + a*r_A(end-alpha+1,1)
336
- r_A[alpha:-alpha, 0] = (1-a)*r_A[alpha-1, 0] + a*r_A[-alpha, 0]
337
- # r_A(alpha+1:end-alpha, end) = (1-a)*r_A(alpha,end) + a*r_A(end-alpha+1,end)
338
- r_A[alpha:-alpha, -1] = (1-a)*r_A[alpha-1, -1] + a*r_A[-alpha, -1]
339
-
340
- r_B = np.zeros((H, alpha*2+W_w))
341
- r_B[:, :alpha] = HG[:, -alpha:]
342
- r_B[:, -alpha:] = HG[:, :alpha]
343
- a = np.arange(W_w)/(W_w-1)
344
- r_B[0, alpha:-alpha] = (1-a)*r_B[0, alpha-1] + a*r_B[0, -alpha]
345
- r_B[-1, alpha:-alpha] = (1-a)*r_B[-1, alpha-1] + a*r_B[-1, -alpha]
346
-
347
- if alpha == 1:
348
- A2 = solve_min_laplacian(r_A[alpha-1:, :])
349
- B2 = solve_min_laplacian(r_B[:, alpha-1:])
350
- r_A[alpha-1:, :] = A2
351
- r_B[:, alpha-1:] = B2
352
- else:
353
- A2 = solve_min_laplacian(r_A[alpha-1:-alpha+1, :])
354
- r_A[alpha-1:-alpha+1, :] = A2
355
- B2 = solve_min_laplacian(r_B[:, alpha-1:-alpha+1])
356
- r_B[:, alpha-1:-alpha+1] = B2
357
- A = r_A
358
- B = r_B
359
-
360
- r_C = np.zeros((alpha*2+H_w, alpha*2+W_w))
361
- r_C[:alpha, :] = B[-alpha:, :]
362
- r_C[-alpha:, :] = B[:alpha, :]
363
- r_C[:, :alpha] = A[:, -alpha:]
364
- r_C[:, -alpha:] = A[:, :alpha]
365
-
366
- if alpha == 1:
367
- C2 = C2 = solve_min_laplacian(r_C[alpha-1:, alpha-1:])
368
- r_C[alpha-1:, alpha-1:] = C2
369
- else:
370
- C2 = solve_min_laplacian(r_C[alpha-1:-alpha+1, alpha-1:-alpha+1])
371
- r_C[alpha-1:-alpha+1, alpha-1:-alpha+1] = C2
372
- C = r_C
373
- # return C
374
- A = A[alpha-1:-alpha-1, :]
375
- B = B[:, alpha:-alpha]
376
- C = C[alpha:-alpha, alpha:-alpha]
377
- ret = np.vstack((np.hstack((img, B)), np.hstack((A, C))))
378
- return ret
379
-
380
-
381
- def solve_min_laplacian(boundary_image):
382
- (H, W) = np.shape(boundary_image)
383
-
384
- # Laplacian
385
- f = np.zeros((H, W))
386
- # boundary image contains image intensities at boundaries
387
- boundary_image[1:-1, 1:-1] = 0
388
- j = np.arange(2, H)-1
389
- k = np.arange(2, W)-1
390
- f_bp = np.zeros((H, W))
391
- f_bp[np.ix_(j, k)] = -4*boundary_image[np.ix_(j, k)] + boundary_image[np.ix_(j, k+1)] + boundary_image[np.ix_(j, k-1)] + boundary_image[np.ix_(j-1, k)] + boundary_image[np.ix_(j+1, k)]
392
-
393
- del(j, k)
394
- f1 = f - f_bp # subtract boundary points contribution
395
- del(f_bp, f)
396
-
397
- # DST Sine Transform algo starts here
398
- f2 = f1[1:-1,1:-1]
399
- del(f1)
400
-
401
- # compute sine tranform
402
- if f2.shape[1] == 1:
403
- tt = fftpack.dst(f2, type=1, axis=0)/2
404
- else:
405
- tt = fftpack.dst(f2, type=1)/2
406
-
407
- if tt.shape[0] == 1:
408
- f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1, axis=0)/2)
409
- else:
410
- f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1)/2)
411
- del(f2)
412
-
413
- # compute Eigen Values
414
- [x, y] = np.meshgrid(np.arange(1, W-1), np.arange(1, H-1))
415
- denom = (2*np.cos(np.pi*x/(W-1))-2) + (2*np.cos(np.pi*y/(H-1)) - 2)
416
-
417
- # divide
418
- f3 = f2sin/denom
419
- del(f2sin, x, y)
420
-
421
- # compute Inverse Sine Transform
422
- if f3.shape[0] == 1:
423
- tt = fftpack.idst(f3*2, type=1, axis=1)/(2*(f3.shape[1]+1))
424
- else:
425
- tt = fftpack.idst(f3*2, type=1, axis=0)/(2*(f3.shape[0]+1))
426
- del(f3)
427
- if tt.shape[1] == 1:
428
- img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1)/(2*(tt.shape[0]+1)))
429
- else:
430
- img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1, axis=0)/(2*(tt.shape[1]+1)))
431
- del(tt)
432
-
433
- # put solution in inner points; outer points obtained from boundary image
434
- img_direct = boundary_image
435
- img_direct[1:-1, 1:-1] = 0
436
- img_direct[1:-1, 1:-1] = img_tt
437
- return img_direct
438
-
439
-
440
- """
441
- Created on Thu Jan 18 15:36:32 2018
442
- @author: italo
443
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
444
- """
445
-
446
- """
447
- Syntax
448
- h = fspecial(type)
449
- h = fspecial('average',hsize)
450
- h = fspecial('disk',radius)
451
- h = fspecial('gaussian',hsize,sigma)
452
- h = fspecial('laplacian',alpha)
453
- h = fspecial('log',hsize,sigma)
454
- h = fspecial('motion',len,theta)
455
- h = fspecial('prewitt')
456
- h = fspecial('sobel')
457
- """
458
-
459
-
460
- def fspecial_average(hsize=3):
461
- """Smoothing filter"""
462
- return np.ones((hsize, hsize))/hsize**2
463
-
464
-
465
- def fspecial_disk(radius):
466
- """Disk filter"""
467
- raise(NotImplemented)
468
- rad = 0.6
469
- crad = np.ceil(rad-0.5)
470
- [x, y] = np.meshgrid(np.arange(-crad, crad+1), np.arange(-crad, crad+1))
471
- maxxy = np.zeros(x.shape)
472
- maxxy[abs(x) >= abs(y)] = abs(x)[abs(x) >= abs(y)]
473
- maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)]
474
- minxy = np.zeros(x.shape)
475
- minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)]
476
- minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)]
477
- m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\
478
- (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\
479
- np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2)
480
- m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\
481
- (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\
482
- np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2)
483
- h = None
484
- return h
485
-
486
-
487
- def fspecial_gaussian(hsize, sigma):
488
- hsize = [hsize, hsize]
489
- siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
490
- std = sigma
491
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
492
- arg = -(x*x + y*y)/(2*std*std)
493
- h = np.exp(arg)
494
- h[h < scipy.finfo(float).eps * h.max()] = 0
495
- sumh = h.sum()
496
- if sumh != 0:
497
- h = h/sumh
498
- return h
499
-
500
-
501
- def fspecial_laplacian(alpha):
502
- alpha = max([0, min([alpha,1])])
503
- h1 = alpha/(alpha+1)
504
- h2 = (1-alpha)/(alpha+1)
505
- h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
506
- h = np.array(h)
507
- return h
508
-
509
-
510
- def fspecial_log(hsize, sigma):
511
- raise(NotImplemented)
512
-
513
-
514
- def fspecial_motion(motion_len, theta):
515
- raise(NotImplemented)
516
-
517
-
518
- def fspecial_prewitt():
519
- return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])
520
-
521
-
522
- def fspecial_sobel():
523
- return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
524
-
525
-
526
- def fspecial(filter_type, *args, **kwargs):
527
- '''
528
- python code from:
529
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
530
- '''
531
- if filter_type == 'average':
532
- return fspecial_average(*args, **kwargs)
533
- if filter_type == 'disk':
534
- return fspecial_disk(*args, **kwargs)
535
- if filter_type == 'gaussian':
536
- return fspecial_gaussian(*args, **kwargs)
537
- if filter_type == 'laplacian':
538
- return fspecial_laplacian(*args, **kwargs)
539
- if filter_type == 'log':
540
- return fspecial_log(*args, **kwargs)
541
- if filter_type == 'motion':
542
- return fspecial_motion(*args, **kwargs)
543
- if filter_type == 'prewitt':
544
- return fspecial_prewitt(*args, **kwargs)
545
- if filter_type == 'sobel':
546
- return fspecial_sobel(*args, **kwargs)
547
-
548
-
549
- def fspecial_gauss(size, sigma):
550
- x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1]
551
- g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
552
- return g / g.sum()
553
-
554
-
555
- def blurkernel_synthesis(h=37, w=None):
556
- # https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py
557
- w = h if w is None else w
558
- kdims = [h, w]
559
- x = randomTrajectory(250)
560
- k = None
561
- while k is None:
562
- k = kernelFromTrajectory(x)
563
-
564
- # center pad to kdims
565
- pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2)
566
- pad_width = [(pad_width[0],), (pad_width[1],)]
567
-
568
- if pad_width[0][0]<0 or pad_width[1][0]<0:
569
- k = k[0:h, 0:h]
570
- else:
571
- k = pad(k, pad_width, "constant")
572
- x1,x2 = k.shape
573
- if np.random.randint(0, 4) == 1:
574
- k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR)
575
- y1, y2 = k.shape
576
- k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2]
577
-
578
- if sum(k)<0.1:
579
- k = fspecial_gaussian(h, 0.1+6*np.random.rand(1))
580
- k = k / sum(k)
581
- # import matplotlib.pyplot as plt
582
- # plt.imshow(k, interpolation="nearest", cmap="gray")
583
- # plt.show()
584
- return k
585
-
586
-
587
- def kernelFromTrajectory(x):
588
- h = 5 - log(rand()) / 0.15
589
- h = round(min([h, 27])).astype(int)
590
- h = h + 1 - h % 2
591
- w = h
592
- k = zeros((h, w))
593
-
594
- xmin = min(x[0])
595
- xmax = max(x[0])
596
- ymin = min(x[1])
597
- ymax = max(x[1])
598
- xthr = arange(xmin, xmax, (xmax - xmin) / w)
599
- ythr = arange(ymin, ymax, (ymax - ymin) / h)
600
-
601
- for i in range(1, xthr.size):
602
- for j in range(1, ythr.size):
603
- idx = (
604
- (x[0, :] >= xthr[i - 1])
605
- & (x[0, :] < xthr[i])
606
- & (x[1, :] >= ythr[j - 1])
607
- & (x[1, :] < ythr[j])
608
- )
609
- k[i - 1, j - 1] = sum(idx)
610
- if sum(k) == 0:
611
- return
612
- k = k / sum(k)
613
- k = convolve2d(k, fspecial_gauss(3, 1), "same")
614
- k = k / sum(k)
615
- return k
616
-
617
-
618
- def randomTrajectory(T):
619
- x = zeros((3, T))
620
- v = randn(3, T)
621
- r = zeros((3, T))
622
- trv = 1 / 1
623
- trr = 2 * pi / T
624
- for t in range(1, T):
625
- F_rot = randn(3) / (t + 1) + r[:, t - 1]
626
- F_trans = randn(3) / (t + 1)
627
- r[:, t] = r[:, t - 1] + trr * F_rot
628
- v[:, t] = v[:, t - 1] + trv * F_trans
629
- st = v[:, t]
630
- st = rot3D(st, r[:, t])
631
- x[:, t] = x[:, t - 1] + st
632
- return x
633
-
634
-
635
- def rot3D(x, r):
636
- Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]])
637
- Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]])
638
- Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]])
639
- R = Rz @ Ry @ Rx
640
- x = R @ x
641
- return x
642
-
643
-
644
- if __name__ == '__main__':
645
- a = opt_fft_size([111])
646
- print(a)
647
-
648
- print(fspecial('gaussian', 5, 1))
649
-
650
- print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape)
651
-
652
- k = blurkernel_synthesis(11)
653
- import matplotlib.pyplot as plt
654
- plt.imshow(k, interpolation="nearest", cmap="gray")
655
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_dist.py DELETED
@@ -1,201 +0,0 @@
1
- # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
- import functools
3
- import os
4
- import subprocess
5
- import torch
6
- import torch.distributed as dist
7
- import torch.multiprocessing as mp
8
-
9
-
10
- # ----------------------------------
11
- # init
12
- # ----------------------------------
13
- def init_dist(launcher, backend='nccl', **kwargs):
14
- if mp.get_start_method(allow_none=True) is None:
15
- mp.set_start_method('spawn')
16
- if launcher == 'pytorch':
17
- _init_dist_pytorch(backend, **kwargs)
18
- elif launcher == 'slurm':
19
- _init_dist_slurm(backend, **kwargs)
20
- else:
21
- raise ValueError(f'Invalid launcher type: {launcher}')
22
-
23
-
24
- def _init_dist_pytorch(backend, **kwargs):
25
- rank = int(os.environ['RANK'])
26
- num_gpus = torch.cuda.device_count()
27
- torch.cuda.set_device(rank % num_gpus)
28
- dist.init_process_group(backend=backend, **kwargs)
29
-
30
-
31
- def _init_dist_slurm(backend, port=None):
32
- """Initialize slurm distributed training environment.
33
- If argument ``port`` is not specified, then the master port will be system
34
- environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
35
- environment variable, then a default port ``29500`` will be used.
36
- Args:
37
- backend (str): Backend of torch.distributed.
38
- port (int, optional): Master port. Defaults to None.
39
- """
40
- proc_id = int(os.environ['SLURM_PROCID'])
41
- ntasks = int(os.environ['SLURM_NTASKS'])
42
- node_list = os.environ['SLURM_NODELIST']
43
- num_gpus = torch.cuda.device_count()
44
- torch.cuda.set_device(proc_id % num_gpus)
45
- addr = subprocess.getoutput(
46
- f'scontrol show hostname {node_list} | head -n1')
47
- # specify master port
48
- if port is not None:
49
- os.environ['MASTER_PORT'] = str(port)
50
- elif 'MASTER_PORT' in os.environ:
51
- pass # use MASTER_PORT in the environment variable
52
- else:
53
- # 29500 is torch.distributed default port
54
- os.environ['MASTER_PORT'] = '29500'
55
- os.environ['MASTER_ADDR'] = addr
56
- os.environ['WORLD_SIZE'] = str(ntasks)
57
- os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
58
- os.environ['RANK'] = str(proc_id)
59
- dist.init_process_group(backend=backend)
60
-
61
-
62
-
63
- # ----------------------------------
64
- # get rank and world_size
65
- # ----------------------------------
66
- def get_dist_info():
67
- if dist.is_available():
68
- initialized = dist.is_initialized()
69
- else:
70
- initialized = False
71
- if initialized:
72
- rank = dist.get_rank()
73
- world_size = dist.get_world_size()
74
- else:
75
- rank = 0
76
- world_size = 1
77
- return rank, world_size
78
-
79
-
80
- def get_rank():
81
- if not dist.is_available():
82
- return 0
83
-
84
- if not dist.is_initialized():
85
- return 0
86
-
87
- return dist.get_rank()
88
-
89
-
90
- def get_world_size():
91
- if not dist.is_available():
92
- return 1
93
-
94
- if not dist.is_initialized():
95
- return 1
96
-
97
- return dist.get_world_size()
98
-
99
-
100
- def master_only(func):
101
-
102
- @functools.wraps(func)
103
- def wrapper(*args, **kwargs):
104
- rank, _ = get_dist_info()
105
- if rank == 0:
106
- return func(*args, **kwargs)
107
-
108
- return wrapper
109
-
110
-
111
-
112
-
113
-
114
-
115
- # ----------------------------------
116
- # operation across ranks
117
- # ----------------------------------
118
- def reduce_sum(tensor):
119
- if not dist.is_available():
120
- return tensor
121
-
122
- if not dist.is_initialized():
123
- return tensor
124
-
125
- tensor = tensor.clone()
126
- dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
127
-
128
- return tensor
129
-
130
-
131
- def gather_grad(params):
132
- world_size = get_world_size()
133
-
134
- if world_size == 1:
135
- return
136
-
137
- for param in params:
138
- if param.grad is not None:
139
- dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
140
- param.grad.data.div_(world_size)
141
-
142
-
143
- def all_gather(data):
144
- world_size = get_world_size()
145
-
146
- if world_size == 1:
147
- return [data]
148
-
149
- buffer = pickle.dumps(data)
150
- storage = torch.ByteStorage.from_buffer(buffer)
151
- tensor = torch.ByteTensor(storage).to('cuda')
152
-
153
- local_size = torch.IntTensor([tensor.numel()]).to('cuda')
154
- size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
155
- dist.all_gather(size_list, local_size)
156
- size_list = [int(size.item()) for size in size_list]
157
- max_size = max(size_list)
158
-
159
- tensor_list = []
160
- for _ in size_list:
161
- tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
162
-
163
- if local_size != max_size:
164
- padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
165
- tensor = torch.cat((tensor, padding), 0)
166
-
167
- dist.all_gather(tensor_list, tensor)
168
-
169
- data_list = []
170
-
171
- for size, tensor in zip(size_list, tensor_list):
172
- buffer = tensor.cpu().numpy().tobytes()[:size]
173
- data_list.append(pickle.loads(buffer))
174
-
175
- return data_list
176
-
177
-
178
- def reduce_loss_dict(loss_dict):
179
- world_size = get_world_size()
180
-
181
- if world_size < 2:
182
- return loss_dict
183
-
184
- with torch.no_grad():
185
- keys = []
186
- losses = []
187
-
188
- for k in sorted(loss_dict.keys()):
189
- keys.append(k)
190
- losses.append(loss_dict[k])
191
-
192
- losses = torch.stack(losses, 0)
193
- dist.reduce(losses, dst=0)
194
-
195
- if dist.get_rank() == 0:
196
- losses /= world_size
197
-
198
- reduced_losses = {k: v for k, v in zip(keys, losses)}
199
-
200
- return reduced_losses
201
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_googledownload.py DELETED
@@ -1,93 +0,0 @@
1
- import math
2
- import requests
3
- from tqdm import tqdm
4
-
5
-
6
- '''
7
- borrowed from
8
- https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
9
- '''
10
-
11
-
12
- def sizeof_fmt(size, suffix='B'):
13
- """Get human readable file size.
14
- Args:
15
- size (int): File size.
16
- suffix (str): Suffix. Default: 'B'.
17
- Return:
18
- str: Formated file siz.
19
- """
20
- for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
21
- if abs(size) < 1024.0:
22
- return f'{size:3.1f} {unit}{suffix}'
23
- size /= 1024.0
24
- return f'{size:3.1f} Y{suffix}'
25
-
26
-
27
- def download_file_from_google_drive(file_id, save_path):
28
- """Download files from google drive.
29
- Ref:
30
- https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
31
- Args:
32
- file_id (str): File id.
33
- save_path (str): Save path.
34
- """
35
-
36
- session = requests.Session()
37
- URL = 'https://docs.google.com/uc?export=download'
38
- params = {'id': file_id}
39
-
40
- response = session.get(URL, params=params, stream=True)
41
- token = get_confirm_token(response)
42
- if token:
43
- params['confirm'] = token
44
- response = session.get(URL, params=params, stream=True)
45
-
46
- # get file size
47
- response_file_size = session.get(
48
- URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
49
- if 'Content-Range' in response_file_size.headers:
50
- file_size = int(
51
- response_file_size.headers['Content-Range'].split('/')[1])
52
- else:
53
- file_size = None
54
-
55
- save_response_content(response, save_path, file_size)
56
-
57
-
58
- def get_confirm_token(response):
59
- for key, value in response.cookies.items():
60
- if key.startswith('download_warning'):
61
- return value
62
- return None
63
-
64
-
65
- def save_response_content(response,
66
- destination,
67
- file_size=None,
68
- chunk_size=32768):
69
- if file_size is not None:
70
- pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
71
-
72
- readable_file_size = sizeof_fmt(file_size)
73
- else:
74
- pbar = None
75
-
76
- with open(destination, 'wb') as f:
77
- downloaded_size = 0
78
- for chunk in response.iter_content(chunk_size):
79
- downloaded_size += chunk_size
80
- if pbar is not None:
81
- pbar.update(1)
82
- pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
83
- f'/ {readable_file_size}')
84
- if chunk: # filter out keep-alive new chunks
85
- f.write(chunk)
86
- if pbar is not None:
87
- pbar.close()
88
-
89
-
90
- if __name__ == "__main__":
91
- file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
92
- save_path = 'BSRGAN.pth'
93
- download_file_from_google_drive(file_id, save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_image.py DELETED
@@ -1,1016 +0,0 @@
1
- import os
2
- import math
3
- import random
4
- import numpy as np
5
- import torch
6
- import cv2
7
- from torchvision.utils import make_grid
8
- from datetime import datetime
9
- # import torchvision.transforms as transforms
10
- import matplotlib.pyplot as plt
11
- from mpl_toolkits.mplot3d import Axes3D
12
- os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
13
-
14
-
15
- '''
16
- # --------------------------------------------
17
- # Kai Zhang (github: https://github.com/cszn)
18
- # 03/Mar/2019
19
- # --------------------------------------------
20
- # https://github.com/twhui/SRGAN-pyTorch
21
- # https://github.com/xinntao/BasicSR
22
- # --------------------------------------------
23
- '''
24
-
25
-
26
- IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
27
-
28
-
29
- def is_image_file(filename):
30
- return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
31
-
32
-
33
- def get_timestamp():
34
- return datetime.now().strftime('%y%m%d-%H%M%S')
35
-
36
-
37
- def imshow(x, title=None, cbar=False, figsize=None):
38
- plt.figure(figsize=figsize)
39
- plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
40
- if title:
41
- plt.title(title)
42
- if cbar:
43
- plt.colorbar()
44
- plt.show()
45
-
46
-
47
- def surf(Z, cmap='rainbow', figsize=None):
48
- plt.figure(figsize=figsize)
49
- ax3 = plt.axes(projection='3d')
50
-
51
- w, h = Z.shape[:2]
52
- xx = np.arange(0,w,1)
53
- yy = np.arange(0,h,1)
54
- X, Y = np.meshgrid(xx, yy)
55
- ax3.plot_surface(X,Y,Z,cmap=cmap)
56
- #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
57
- plt.show()
58
-
59
-
60
- '''
61
- # --------------------------------------------
62
- # get image pathes
63
- # --------------------------------------------
64
- '''
65
-
66
-
67
- def get_image_paths(dataroot):
68
- paths = None # return None if dataroot is None
69
- if isinstance(dataroot, str):
70
- paths = sorted(_get_paths_from_images(dataroot))
71
- elif isinstance(dataroot, list):
72
- paths = []
73
- for i in dataroot:
74
- paths += sorted(_get_paths_from_images(i))
75
- return paths
76
-
77
-
78
- def _get_paths_from_images(path):
79
- assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
80
- images = []
81
- for dirpath, _, fnames in sorted(os.walk(path)):
82
- for fname in sorted(fnames):
83
- if is_image_file(fname):
84
- img_path = os.path.join(dirpath, fname)
85
- images.append(img_path)
86
- assert images, '{:s} has no valid image file'.format(path)
87
- return images
88
-
89
-
90
- '''
91
- # --------------------------------------------
92
- # split large images into small images
93
- # --------------------------------------------
94
- '''
95
-
96
-
97
- def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
98
- w, h = img.shape[:2]
99
- patches = []
100
- if w > p_max and h > p_max:
101
- w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
102
- h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
103
- w1.append(w-p_size)
104
- h1.append(h-p_size)
105
- # print(w1)
106
- # print(h1)
107
- for i in w1:
108
- for j in h1:
109
- patches.append(img[i:i+p_size, j:j+p_size,:])
110
- else:
111
- patches.append(img)
112
-
113
- return patches
114
-
115
-
116
- def imssave(imgs, img_path):
117
- """
118
- imgs: list, N images of size WxHxC
119
- """
120
- img_name, ext = os.path.splitext(os.path.basename(img_path))
121
- for i, img in enumerate(imgs):
122
- if img.ndim == 3:
123
- img = img[:, :, [2, 1, 0]]
124
- new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png')
125
- cv2.imwrite(new_path, img)
126
-
127
-
128
- def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
129
- """
130
- split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
131
- and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
132
- will be splitted.
133
-
134
- Args:
135
- original_dataroot:
136
- taget_dataroot:
137
- p_size: size of small images
138
- p_overlap: patch size in training is a good choice
139
- p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
140
- """
141
- paths = get_image_paths(original_dataroot)
142
- for img_path in paths:
143
- # img_name, ext = os.path.splitext(os.path.basename(img_path))
144
- img = imread_uint(img_path, n_channels=n_channels)
145
- patches = patches_from_image(img, p_size, p_overlap, p_max)
146
- imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
147
- #if original_dataroot == taget_dataroot:
148
- #del img_path
149
-
150
- '''
151
- # --------------------------------------------
152
- # makedir
153
- # --------------------------------------------
154
- '''
155
-
156
-
157
- def mkdir(path):
158
- if not os.path.exists(path):
159
- os.makedirs(path)
160
-
161
-
162
- def mkdirs(paths):
163
- if isinstance(paths, str):
164
- mkdir(paths)
165
- else:
166
- for path in paths:
167
- mkdir(path)
168
-
169
-
170
- def mkdir_and_rename(path):
171
- if os.path.exists(path):
172
- new_name = path + '_archived_' + get_timestamp()
173
- print('Path already exists. Rename it to [{:s}]'.format(new_name))
174
- os.rename(path, new_name)
175
- os.makedirs(path)
176
-
177
-
178
- '''
179
- # --------------------------------------------
180
- # read image from path
181
- # opencv is fast, but read BGR numpy image
182
- # --------------------------------------------
183
- '''
184
-
185
-
186
- # --------------------------------------------
187
- # get uint8 image of size HxWxn_channles (RGB)
188
- # --------------------------------------------
189
- def imread_uint(path, n_channels=3):
190
- # input: path
191
- # output: HxWx3(RGB or GGG), or HxWx1 (G)
192
- if n_channels == 1:
193
- img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
194
- img = np.expand_dims(img, axis=2) # HxWx1
195
- elif n_channels == 3:
196
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
197
- if img.ndim == 2:
198
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
199
- else:
200
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
201
- return img
202
-
203
-
204
- # --------------------------------------------
205
- # matlab's imwrite
206
- # --------------------------------------------
207
- def imsave(img, img_path):
208
- img = np.squeeze(img)
209
- if img.ndim == 3:
210
- img = img[:, :, [2, 1, 0]]
211
- cv2.imwrite(img_path, img)
212
-
213
- def imwrite(img, img_path):
214
- img = np.squeeze(img)
215
- if img.ndim == 3:
216
- img = img[:, :, [2, 1, 0]]
217
- cv2.imwrite(img_path, img)
218
-
219
-
220
-
221
- # --------------------------------------------
222
- # get single image of size HxWxn_channles (BGR)
223
- # --------------------------------------------
224
- def read_img(path):
225
- # read image by cv2
226
- # return: Numpy float32, HWC, BGR, [0,1]
227
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
228
- img = img.astype(np.float32) / 255.
229
- if img.ndim == 2:
230
- img = np.expand_dims(img, axis=2)
231
- # some images have 4 channels
232
- if img.shape[2] > 3:
233
- img = img[:, :, :3]
234
- return img
235
-
236
-
237
- '''
238
- # --------------------------------------------
239
- # image format conversion
240
- # --------------------------------------------
241
- # numpy(single) <---> numpy(uint)
242
- # numpy(single) <---> tensor
243
- # numpy(uint) <---> tensor
244
- # --------------------------------------------
245
- '''
246
-
247
-
248
- # --------------------------------------------
249
- # numpy(single) [0, 1] <---> numpy(uint)
250
- # --------------------------------------------
251
-
252
-
253
- def uint2single(img):
254
-
255
- return np.float32(img/255.)
256
-
257
-
258
- def single2uint(img):
259
-
260
- return np.uint8((img.clip(0, 1)*255.).round())
261
-
262
-
263
- def uint162single(img):
264
-
265
- return np.float32(img/65535.)
266
-
267
-
268
- def single2uint16(img):
269
-
270
- return np.uint16((img.clip(0, 1)*65535.).round())
271
-
272
-
273
- # --------------------------------------------
274
- # numpy(uint) (HxWxC or HxW) <---> tensor
275
- # --------------------------------------------
276
-
277
-
278
- # convert uint to 4-dimensional torch tensor
279
- def uint2tensor4(img):
280
- if img.ndim == 2:
281
- img = np.expand_dims(img, axis=2)
282
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
283
-
284
-
285
- # convert uint to 3-dimensional torch tensor
286
- def uint2tensor3(img):
287
- if img.ndim == 2:
288
- img = np.expand_dims(img, axis=2)
289
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
290
-
291
-
292
- # convert 2/3/4-dimensional torch tensor to uint
293
- def tensor2uint(img):
294
- img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
295
- if img.ndim == 3:
296
- img = np.transpose(img, (1, 2, 0))
297
- return np.uint8((img*255.0).round())
298
-
299
-
300
- # --------------------------------------------
301
- # numpy(single) (HxWxC) <---> tensor
302
- # --------------------------------------------
303
-
304
-
305
- # convert single (HxWxC) to 3-dimensional torch tensor
306
- def single2tensor3(img):
307
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
308
-
309
-
310
- # convert single (HxWxC) to 4-dimensional torch tensor
311
- def single2tensor4(img):
312
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
313
-
314
-
315
- # convert torch tensor to single
316
- def tensor2single(img):
317
- img = img.data.squeeze().float().cpu().numpy()
318
- if img.ndim == 3:
319
- img = np.transpose(img, (1, 2, 0))
320
-
321
- return img
322
-
323
- # convert torch tensor to single
324
- def tensor2single3(img):
325
- img = img.data.squeeze().float().cpu().numpy()
326
- if img.ndim == 3:
327
- img = np.transpose(img, (1, 2, 0))
328
- elif img.ndim == 2:
329
- img = np.expand_dims(img, axis=2)
330
- return img
331
-
332
-
333
- def single2tensor5(img):
334
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
335
-
336
-
337
- def single32tensor5(img):
338
- return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
339
-
340
-
341
- def single42tensor4(img):
342
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
343
-
344
-
345
- # from skimage.io import imread, imsave
346
- def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
347
- '''
348
- Converts a torch Tensor into an image Numpy array of BGR channel order
349
- Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
350
- Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
351
- '''
352
- tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
353
- tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
354
- n_dim = tensor.dim()
355
- if n_dim == 4:
356
- n_img = len(tensor)
357
- img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
358
- img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
359
- elif n_dim == 3:
360
- img_np = tensor.numpy()
361
- img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
362
- elif n_dim == 2:
363
- img_np = tensor.numpy()
364
- else:
365
- raise TypeError(
366
- 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
367
- if out_type == np.uint8:
368
- img_np = (img_np * 255.0).round()
369
- # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
370
- return img_np.astype(out_type)
371
-
372
-
373
- '''
374
- # --------------------------------------------
375
- # Augmentation, flipe and/or rotate
376
- # --------------------------------------------
377
- # The following two are enough.
378
- # (1) augmet_img: numpy image of WxHxC or WxH
379
- # (2) augment_img_tensor4: tensor image 1xCxWxH
380
- # --------------------------------------------
381
- '''
382
-
383
-
384
- def augment_img(img, mode=0):
385
- '''Kai Zhang (github: https://github.com/cszn)
386
- '''
387
- if mode == 0:
388
- return img
389
- elif mode == 1:
390
- return np.flipud(np.rot90(img))
391
- elif mode == 2:
392
- return np.flipud(img)
393
- elif mode == 3:
394
- return np.rot90(img, k=3)
395
- elif mode == 4:
396
- return np.flipud(np.rot90(img, k=2))
397
- elif mode == 5:
398
- return np.rot90(img)
399
- elif mode == 6:
400
- return np.rot90(img, k=2)
401
- elif mode == 7:
402
- return np.flipud(np.rot90(img, k=3))
403
-
404
-
405
- def augment_img_tensor4(img, mode=0):
406
- '''Kai Zhang (github: https://github.com/cszn)
407
- '''
408
- if mode == 0:
409
- return img
410
- elif mode == 1:
411
- return img.rot90(1, [2, 3]).flip([2])
412
- elif mode == 2:
413
- return img.flip([2])
414
- elif mode == 3:
415
- return img.rot90(3, [2, 3])
416
- elif mode == 4:
417
- return img.rot90(2, [2, 3]).flip([2])
418
- elif mode == 5:
419
- return img.rot90(1, [2, 3])
420
- elif mode == 6:
421
- return img.rot90(2, [2, 3])
422
- elif mode == 7:
423
- return img.rot90(3, [2, 3]).flip([2])
424
-
425
-
426
- def augment_img_tensor(img, mode=0):
427
- '''Kai Zhang (github: https://github.com/cszn)
428
- '''
429
- img_size = img.size()
430
- img_np = img.data.cpu().numpy()
431
- if len(img_size) == 3:
432
- img_np = np.transpose(img_np, (1, 2, 0))
433
- elif len(img_size) == 4:
434
- img_np = np.transpose(img_np, (2, 3, 1, 0))
435
- img_np = augment_img(img_np, mode=mode)
436
- img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
437
- if len(img_size) == 3:
438
- img_tensor = img_tensor.permute(2, 0, 1)
439
- elif len(img_size) == 4:
440
- img_tensor = img_tensor.permute(3, 2, 0, 1)
441
-
442
- return img_tensor.type_as(img)
443
-
444
-
445
- def augment_img_np3(img, mode=0):
446
- if mode == 0:
447
- return img
448
- elif mode == 1:
449
- return img.transpose(1, 0, 2)
450
- elif mode == 2:
451
- return img[::-1, :, :]
452
- elif mode == 3:
453
- img = img[::-1, :, :]
454
- img = img.transpose(1, 0, 2)
455
- return img
456
- elif mode == 4:
457
- return img[:, ::-1, :]
458
- elif mode == 5:
459
- img = img[:, ::-1, :]
460
- img = img.transpose(1, 0, 2)
461
- return img
462
- elif mode == 6:
463
- img = img[:, ::-1, :]
464
- img = img[::-1, :, :]
465
- return img
466
- elif mode == 7:
467
- img = img[:, ::-1, :]
468
- img = img[::-1, :, :]
469
- img = img.transpose(1, 0, 2)
470
- return img
471
-
472
-
473
- def augment_imgs(img_list, hflip=True, rot=True):
474
- # horizontal flip OR rotate
475
- hflip = hflip and random.random() < 0.5
476
- vflip = rot and random.random() < 0.5
477
- rot90 = rot and random.random() < 0.5
478
-
479
- def _augment(img):
480
- if hflip:
481
- img = img[:, ::-1, :]
482
- if vflip:
483
- img = img[::-1, :, :]
484
- if rot90:
485
- img = img.transpose(1, 0, 2)
486
- return img
487
-
488
- return [_augment(img) for img in img_list]
489
-
490
-
491
- '''
492
- # --------------------------------------------
493
- # modcrop and shave
494
- # --------------------------------------------
495
- '''
496
-
497
-
498
- def modcrop(img_in, scale):
499
- # img_in: Numpy, HWC or HW
500
- img = np.copy(img_in)
501
- if img.ndim == 2:
502
- H, W = img.shape
503
- H_r, W_r = H % scale, W % scale
504
- img = img[:H - H_r, :W - W_r]
505
- elif img.ndim == 3:
506
- H, W, C = img.shape
507
- H_r, W_r = H % scale, W % scale
508
- img = img[:H - H_r, :W - W_r, :]
509
- else:
510
- raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
511
- return img
512
-
513
-
514
- def shave(img_in, border=0):
515
- # img_in: Numpy, HWC or HW
516
- img = np.copy(img_in)
517
- h, w = img.shape[:2]
518
- img = img[border:h-border, border:w-border]
519
- return img
520
-
521
-
522
- '''
523
- # --------------------------------------------
524
- # image processing process on numpy image
525
- # channel_convert(in_c, tar_type, img_list):
526
- # rgb2ycbcr(img, only_y=True):
527
- # bgr2ycbcr(img, only_y=True):
528
- # ycbcr2rgb(img):
529
- # --------------------------------------------
530
- '''
531
-
532
-
533
- def rgb2ycbcr(img, only_y=True):
534
- '''same as matlab rgb2ycbcr
535
- only_y: only return Y channel
536
- Input:
537
- uint8, [0, 255]
538
- float, [0, 1]
539
- '''
540
- in_img_type = img.dtype
541
- img.astype(np.float32)
542
- if in_img_type != np.uint8:
543
- img *= 255.
544
- # convert
545
- if only_y:
546
- rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
547
- else:
548
- rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
549
- [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
550
- if in_img_type == np.uint8:
551
- rlt = rlt.round()
552
- else:
553
- rlt /= 255.
554
- return rlt.astype(in_img_type)
555
-
556
-
557
- def ycbcr2rgb(img):
558
- '''same as matlab ycbcr2rgb
559
- Input:
560
- uint8, [0, 255]
561
- float, [0, 1]
562
- '''
563
- in_img_type = img.dtype
564
- img.astype(np.float32)
565
- if in_img_type != np.uint8:
566
- img *= 255.
567
- # convert
568
- rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
569
- [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
570
- rlt = np.clip(rlt, 0, 255)
571
- if in_img_type == np.uint8:
572
- rlt = rlt.round()
573
- else:
574
- rlt /= 255.
575
- return rlt.astype(in_img_type)
576
-
577
-
578
- def bgr2ycbcr(img, only_y=True):
579
- '''bgr version of rgb2ycbcr
580
- only_y: only return Y channel
581
- Input:
582
- uint8, [0, 255]
583
- float, [0, 1]
584
- '''
585
- in_img_type = img.dtype
586
- img.astype(np.float32)
587
- if in_img_type != np.uint8:
588
- img *= 255.
589
- # convert
590
- if only_y:
591
- rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
592
- else:
593
- rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
594
- [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
595
- if in_img_type == np.uint8:
596
- rlt = rlt.round()
597
- else:
598
- rlt /= 255.
599
- return rlt.astype(in_img_type)
600
-
601
-
602
- def channel_convert(in_c, tar_type, img_list):
603
- # conversion among BGR, gray and y
604
- if in_c == 3 and tar_type == 'gray': # BGR to gray
605
- gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
606
- return [np.expand_dims(img, axis=2) for img in gray_list]
607
- elif in_c == 3 and tar_type == 'y': # BGR to y
608
- y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
609
- return [np.expand_dims(img, axis=2) for img in y_list]
610
- elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
611
- return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
612
- else:
613
- return img_list
614
-
615
-
616
- '''
617
- # --------------------------------------------
618
- # metric, PSNR, SSIM and PSNRB
619
- # --------------------------------------------
620
- '''
621
-
622
-
623
- # --------------------------------------------
624
- # PSNR
625
- # --------------------------------------------
626
- def calculate_psnr(img1, img2, border=0):
627
- # img1 and img2 have range [0, 255]
628
- #img1 = img1.squeeze()
629
- #img2 = img2.squeeze()
630
- if not img1.shape == img2.shape:
631
- raise ValueError('Input images must have the same dimensions.')
632
- h, w = img1.shape[:2]
633
- img1 = img1[border:h-border, border:w-border]
634
- img2 = img2[border:h-border, border:w-border]
635
-
636
- img1 = img1.astype(np.float64)
637
- img2 = img2.astype(np.float64)
638
- mse = np.mean((img1 - img2)**2)
639
- if mse == 0:
640
- return float('inf')
641
- return 20 * math.log10(255.0 / math.sqrt(mse))
642
-
643
-
644
- # --------------------------------------------
645
- # SSIM
646
- # --------------------------------------------
647
- def calculate_ssim(img1, img2, border=0):
648
- '''calculate SSIM
649
- the same outputs as MATLAB's
650
- img1, img2: [0, 255]
651
- '''
652
- #img1 = img1.squeeze()
653
- #img2 = img2.squeeze()
654
- if not img1.shape == img2.shape:
655
- raise ValueError('Input images must have the same dimensions.')
656
- h, w = img1.shape[:2]
657
- img1 = img1[border:h-border, border:w-border]
658
- img2 = img2[border:h-border, border:w-border]
659
-
660
- if img1.ndim == 2:
661
- return ssim(img1, img2)
662
- elif img1.ndim == 3:
663
- if img1.shape[2] == 3:
664
- ssims = []
665
- for i in range(3):
666
- ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
667
- return np.array(ssims).mean()
668
- elif img1.shape[2] == 1:
669
- return ssim(np.squeeze(img1), np.squeeze(img2))
670
- else:
671
- raise ValueError('Wrong input image dimensions.')
672
-
673
-
674
- def ssim(img1, img2):
675
- C1 = (0.01 * 255)**2
676
- C2 = (0.03 * 255)**2
677
-
678
- img1 = img1.astype(np.float64)
679
- img2 = img2.astype(np.float64)
680
- kernel = cv2.getGaussianKernel(11, 1.5)
681
- window = np.outer(kernel, kernel.transpose())
682
-
683
- mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
684
- mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
685
- mu1_sq = mu1**2
686
- mu2_sq = mu2**2
687
- mu1_mu2 = mu1 * mu2
688
- sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
689
- sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
690
- sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
691
-
692
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
693
- (sigma1_sq + sigma2_sq + C2))
694
- return ssim_map.mean()
695
-
696
-
697
- def _blocking_effect_factor(im):
698
- block_size = 8
699
-
700
- block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
701
- block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
702
-
703
- horizontal_block_difference = (
704
- (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
705
- 3).sum(2).sum(1)
706
- vertical_block_difference = (
707
- (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
708
- 2).sum(1)
709
-
710
- nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
711
- nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
712
-
713
- horizontal_nonblock_difference = (
714
- (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
715
- 3).sum(2).sum(1)
716
- vertical_nonblock_difference = (
717
- (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
718
- 3).sum(2).sum(1)
719
-
720
- n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
721
- n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
722
- boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
723
- n_boundary_horiz + n_boundary_vert)
724
-
725
- n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
726
- n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
727
- nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
728
- n_nonboundary_horiz + n_nonboundary_vert)
729
-
730
- scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
731
- bef = scaler * (boundary_difference - nonboundary_difference)
732
-
733
- bef[boundary_difference <= nonboundary_difference] = 0
734
- return bef
735
-
736
-
737
- def calculate_psnrb(img1, img2, border=0):
738
- """Calculate PSNR-B (Peak Signal-to-Noise Ratio).
739
- Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
740
- # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
741
- Args:
742
- img1 (ndarray): Images with range [0, 255].
743
- img2 (ndarray): Images with range [0, 255].
744
- border (int): Cropped pixels in each edge of an image. These
745
- pixels are not involved in the PSNR calculation.
746
- test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
747
- Returns:
748
- float: psnr result.
749
- """
750
-
751
- if not img1.shape == img2.shape:
752
- raise ValueError('Input images must have the same dimensions.')
753
-
754
- if img1.ndim == 2:
755
- img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)
756
-
757
- h, w = img1.shape[:2]
758
- img1 = img1[border:h-border, border:w-border]
759
- img2 = img2[border:h-border, border:w-border]
760
-
761
- img1 = img1.astype(np.float64)
762
- img2 = img2.astype(np.float64)
763
-
764
- # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
765
- img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
766
- img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.
767
-
768
- total = 0
769
- for c in range(img1.shape[1]):
770
- mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
771
- bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])
772
-
773
- mse = mse.view(mse.shape[0], -1).mean(1)
774
- total += 10 * torch.log10(1 / (mse + bef))
775
-
776
- return float(total) / img1.shape[1]
777
-
778
- '''
779
- # --------------------------------------------
780
- # matlab's bicubic imresize (numpy and torch) [0, 1]
781
- # --------------------------------------------
782
- '''
783
-
784
-
785
- # matlab 'imresize' function, now only support 'bicubic'
786
- def cubic(x):
787
- absx = torch.abs(x)
788
- absx2 = absx**2
789
- absx3 = absx**3
790
- return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
791
- (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
792
-
793
-
794
- def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
795
- if (scale < 1) and (antialiasing):
796
- # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
797
- kernel_width = kernel_width / scale
798
-
799
- # Output-space coordinates
800
- x = torch.linspace(1, out_length, out_length)
801
-
802
- # Input-space coordinates. Calculate the inverse mapping such that 0.5
803
- # in output space maps to 0.5 in input space, and 0.5+scale in output
804
- # space maps to 1.5 in input space.
805
- u = x / scale + 0.5 * (1 - 1 / scale)
806
-
807
- # What is the left-most pixel that can be involved in the computation?
808
- left = torch.floor(u - kernel_width / 2)
809
-
810
- # What is the maximum number of pixels that can be involved in the
811
- # computation? Note: it's OK to use an extra pixel here; if the
812
- # corresponding weights are all zero, it will be eliminated at the end
813
- # of this function.
814
- P = math.ceil(kernel_width) + 2
815
-
816
- # The indices of the input pixels involved in computing the k-th output
817
- # pixel are in row k of the indices matrix.
818
- indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
819
- 1, P).expand(out_length, P)
820
-
821
- # The weights used to compute the k-th output pixel are in row k of the
822
- # weights matrix.
823
- distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
824
- # apply cubic kernel
825
- if (scale < 1) and (antialiasing):
826
- weights = scale * cubic(distance_to_center * scale)
827
- else:
828
- weights = cubic(distance_to_center)
829
- # Normalize the weights matrix so that each row sums to 1.
830
- weights_sum = torch.sum(weights, 1).view(out_length, 1)
831
- weights = weights / weights_sum.expand(out_length, P)
832
-
833
- # If a column in weights is all zero, get rid of it. only consider the first and last column.
834
- weights_zero_tmp = torch.sum((weights == 0), 0)
835
- if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
836
- indices = indices.narrow(1, 1, P - 2)
837
- weights = weights.narrow(1, 1, P - 2)
838
- if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
839
- indices = indices.narrow(1, 0, P - 2)
840
- weights = weights.narrow(1, 0, P - 2)
841
- weights = weights.contiguous()
842
- indices = indices.contiguous()
843
- sym_len_s = -indices.min() + 1
844
- sym_len_e = indices.max() - in_length
845
- indices = indices + sym_len_s - 1
846
- return weights, indices, int(sym_len_s), int(sym_len_e)
847
-
848
-
849
- # --------------------------------------------
850
- # imresize for tensor image [0, 1]
851
- # --------------------------------------------
852
- def imresize(img, scale, antialiasing=True):
853
- # Now the scale should be the same for H and W
854
- # input: img: pytorch tensor, CHW or HW [0,1]
855
- # output: CHW or HW [0,1] w/o round
856
- need_squeeze = True if img.dim() == 2 else False
857
- if need_squeeze:
858
- img.unsqueeze_(0)
859
- in_C, in_H, in_W = img.size()
860
- out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
861
- kernel_width = 4
862
- kernel = 'cubic'
863
-
864
- # Return the desired dimension order for performing the resize. The
865
- # strategy is to perform the resize first along the dimension with the
866
- # smallest scale factor.
867
- # Now we do not support this.
868
-
869
- # get weights and indices
870
- weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
871
- in_H, out_H, scale, kernel, kernel_width, antialiasing)
872
- weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
873
- in_W, out_W, scale, kernel, kernel_width, antialiasing)
874
- # process H dimension
875
- # symmetric copying
876
- img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
877
- img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
878
-
879
- sym_patch = img[:, :sym_len_Hs, :]
880
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
881
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
882
- img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
883
-
884
- sym_patch = img[:, -sym_len_He:, :]
885
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
886
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
887
- img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
888
-
889
- out_1 = torch.FloatTensor(in_C, out_H, in_W)
890
- kernel_width = weights_H.size(1)
891
- for i in range(out_H):
892
- idx = int(indices_H[i][0])
893
- for j in range(out_C):
894
- out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
895
-
896
- # process W dimension
897
- # symmetric copying
898
- out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
899
- out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
900
-
901
- sym_patch = out_1[:, :, :sym_len_Ws]
902
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
903
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
904
- out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
905
-
906
- sym_patch = out_1[:, :, -sym_len_We:]
907
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
908
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
909
- out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
910
-
911
- out_2 = torch.FloatTensor(in_C, out_H, out_W)
912
- kernel_width = weights_W.size(1)
913
- for i in range(out_W):
914
- idx = int(indices_W[i][0])
915
- for j in range(out_C):
916
- out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
917
- if need_squeeze:
918
- out_2.squeeze_()
919
- return out_2
920
-
921
-
922
- # --------------------------------------------
923
- # imresize for numpy image [0, 1]
924
- # --------------------------------------------
925
- def imresize_np(img, scale, antialiasing=True):
926
- # Now the scale should be the same for H and W
927
- # input: img: Numpy, HWC or HW [0,1]
928
- # output: HWC or HW [0,1] w/o round
929
- img = torch.from_numpy(img)
930
- need_squeeze = True if img.dim() == 2 else False
931
- if need_squeeze:
932
- img.unsqueeze_(2)
933
-
934
- in_H, in_W, in_C = img.size()
935
- out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
936
- kernel_width = 4
937
- kernel = 'cubic'
938
-
939
- # Return the desired dimension order for performing the resize. The
940
- # strategy is to perform the resize first along the dimension with the
941
- # smallest scale factor.
942
- # Now we do not support this.
943
-
944
- # get weights and indices
945
- weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
946
- in_H, out_H, scale, kernel, kernel_width, antialiasing)
947
- weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
948
- in_W, out_W, scale, kernel, kernel_width, antialiasing)
949
- # process H dimension
950
- # symmetric copying
951
- img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
952
- img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
953
-
954
- sym_patch = img[:sym_len_Hs, :, :]
955
- inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
956
- sym_patch_inv = sym_patch.index_select(0, inv_idx)
957
- img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
958
-
959
- sym_patch = img[-sym_len_He:, :, :]
960
- inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
961
- sym_patch_inv = sym_patch.index_select(0, inv_idx)
962
- img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
963
-
964
- out_1 = torch.FloatTensor(out_H, in_W, in_C)
965
- kernel_width = weights_H.size(1)
966
- for i in range(out_H):
967
- idx = int(indices_H[i][0])
968
- for j in range(out_C):
969
- out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
970
-
971
- # process W dimension
972
- # symmetric copying
973
- out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
974
- out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
975
-
976
- sym_patch = out_1[:, :sym_len_Ws, :]
977
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
978
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
979
- out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
980
-
981
- sym_patch = out_1[:, -sym_len_We:, :]
982
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
983
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
984
- out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
985
-
986
- out_2 = torch.FloatTensor(out_H, out_W, in_C)
987
- kernel_width = weights_W.size(1)
988
- for i in range(out_W):
989
- idx = int(indices_W[i][0])
990
- for j in range(out_C):
991
- out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
992
- if need_squeeze:
993
- out_2.squeeze_()
994
-
995
- return out_2.numpy()
996
-
997
-
998
- if __name__ == '__main__':
999
- img = imread_uint('test.bmp', 3)
1000
- # img = uint2single(img)
1001
- # img_bicubic = imresize_np(img, 1/4)
1002
- # imshow(single2uint(img_bicubic))
1003
- #
1004
- # img_tensor = single2tensor4(img)
1005
- # for i in range(8):
1006
- # imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1))
1007
-
1008
- # patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200)
1009
- # imssave(patches,'a.png')
1010
-
1011
-
1012
-
1013
-
1014
-
1015
-
1016
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_lmdb.py DELETED
@@ -1,205 +0,0 @@
1
- import cv2
2
- import lmdb
3
- import sys
4
- from multiprocessing import Pool
5
- from os import path as osp
6
- from tqdm import tqdm
7
-
8
-
9
- def make_lmdb_from_imgs(data_path,
10
- lmdb_path,
11
- img_path_list,
12
- keys,
13
- batch=5000,
14
- compress_level=1,
15
- multiprocessing_read=False,
16
- n_thread=40,
17
- map_size=None):
18
- """Make lmdb from images.
19
-
20
- Contents of lmdb. The file structure is:
21
- example.lmdb
22
- ├── data.mdb
23
- ├── lock.mdb
24
- ├── meta_info.txt
25
-
26
- The data.mdb and lock.mdb are standard lmdb files and you can refer to
27
- https://lmdb.readthedocs.io/en/release/ for more details.
28
-
29
- The meta_info.txt is a specified txt file to record the meta information
30
- of our datasets. It will be automatically created when preparing
31
- datasets by our provided dataset tools.
32
- Each line in the txt file records 1)image name (with extension),
33
- 2)image shape, and 3)compression level, separated by a white space.
34
-
35
- For example, the meta information could be:
36
- `000_00000000.png (720,1280,3) 1`, which means:
37
- 1) image name (with extension): 000_00000000.png;
38
- 2) image shape: (720,1280,3);
39
- 3) compression level: 1
40
-
41
- We use the image name without extension as the lmdb key.
42
-
43
- If `multiprocessing_read` is True, it will read all the images to memory
44
- using multiprocessing. Thus, your server needs to have enough memory.
45
-
46
- Args:
47
- data_path (str): Data path for reading images.
48
- lmdb_path (str): Lmdb save path.
49
- img_path_list (str): Image path list.
50
- keys (str): Used for lmdb keys.
51
- batch (int): After processing batch images, lmdb commits.
52
- Default: 5000.
53
- compress_level (int): Compress level when encoding images. Default: 1.
54
- multiprocessing_read (bool): Whether use multiprocessing to read all
55
- the images to memory. Default: False.
56
- n_thread (int): For multiprocessing.
57
- map_size (int | None): Map size for lmdb env. If None, use the
58
- estimated size from images. Default: None
59
- """
60
-
61
- assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
62
- f'but got {len(img_path_list)} and {len(keys)}')
63
- print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
64
- print(f'Totoal images: {len(img_path_list)}')
65
- if not lmdb_path.endswith('.lmdb'):
66
- raise ValueError("lmdb_path must end with '.lmdb'.")
67
- if osp.exists(lmdb_path):
68
- print(f'Folder {lmdb_path} already exists. Exit.')
69
- sys.exit(1)
70
-
71
- if multiprocessing_read:
72
- # read all the images to memory (multiprocessing)
73
- dataset = {} # use dict to keep the order for multiprocessing
74
- shapes = {}
75
- print(f'Read images with multiprocessing, #thread: {n_thread} ...')
76
- pbar = tqdm(total=len(img_path_list), unit='image')
77
-
78
- def callback(arg):
79
- """get the image data and update pbar."""
80
- key, dataset[key], shapes[key] = arg
81
- pbar.update(1)
82
- pbar.set_description(f'Read {key}')
83
-
84
- pool = Pool(n_thread)
85
- for path, key in zip(img_path_list, keys):
86
- pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
87
- pool.close()
88
- pool.join()
89
- pbar.close()
90
- print(f'Finish reading {len(img_path_list)} images.')
91
-
92
- # create lmdb environment
93
- if map_size is None:
94
- # obtain data size for one image
95
- img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
96
- _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
97
- data_size_per_img = img_byte.nbytes
98
- print('Data size per image is: ', data_size_per_img)
99
- data_size = data_size_per_img * len(img_path_list)
100
- map_size = data_size * 10
101
-
102
- env = lmdb.open(lmdb_path, map_size=map_size)
103
-
104
- # write data to lmdb
105
- pbar = tqdm(total=len(img_path_list), unit='chunk')
106
- txn = env.begin(write=True)
107
- txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
108
- for idx, (path, key) in enumerate(zip(img_path_list, keys)):
109
- pbar.update(1)
110
- pbar.set_description(f'Write {key}')
111
- key_byte = key.encode('ascii')
112
- if multiprocessing_read:
113
- img_byte = dataset[key]
114
- h, w, c = shapes[key]
115
- else:
116
- _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
117
- h, w, c = img_shape
118
-
119
- txn.put(key_byte, img_byte)
120
- # write meta information
121
- txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
122
- if idx % batch == 0:
123
- txn.commit()
124
- txn = env.begin(write=True)
125
- pbar.close()
126
- txn.commit()
127
- env.close()
128
- txt_file.close()
129
- print('\nFinish writing lmdb.')
130
-
131
-
132
- def read_img_worker(path, key, compress_level):
133
- """Read image worker.
134
-
135
- Args:
136
- path (str): Image path.
137
- key (str): Image key.
138
- compress_level (int): Compress level when encoding images.
139
-
140
- Returns:
141
- str: Image key.
142
- byte: Image byte.
143
- tuple[int]: Image shape.
144
- """
145
-
146
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
147
- # deal with `libpng error: Read Error`
148
- if img is None:
149
- print(f'To deal with `libpng error: Read Error`, use PIL to load {path}')
150
- from PIL import Image
151
- import numpy as np
152
- img = Image.open(path)
153
- img = np.asanyarray(img)
154
- img = img[:, :, [2, 1, 0]]
155
-
156
- if img.ndim == 2:
157
- h, w = img.shape
158
- c = 1
159
- else:
160
- h, w, c = img.shape
161
- _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
162
- return (key, img_byte, (h, w, c))
163
-
164
-
165
- class LmdbMaker():
166
- """LMDB Maker.
167
-
168
- Args:
169
- lmdb_path (str): Lmdb save path.
170
- map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
171
- batch (int): After processing batch images, lmdb commits.
172
- Default: 5000.
173
- compress_level (int): Compress level when encoding images. Default: 1.
174
- """
175
-
176
- def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
177
- if not lmdb_path.endswith('.lmdb'):
178
- raise ValueError("lmdb_path must end with '.lmdb'.")
179
- if osp.exists(lmdb_path):
180
- print(f'Folder {lmdb_path} already exists. Exit.')
181
- sys.exit(1)
182
-
183
- self.lmdb_path = lmdb_path
184
- self.batch = batch
185
- self.compress_level = compress_level
186
- self.env = lmdb.open(lmdb_path, map_size=map_size)
187
- self.txn = self.env.begin(write=True)
188
- self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
189
- self.counter = 0
190
-
191
- def put(self, img_byte, key, img_shape):
192
- self.counter += 1
193
- key_byte = key.encode('ascii')
194
- self.txn.put(key_byte, img_byte)
195
- # write meta information
196
- h, w, c = img_shape
197
- self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
198
- if self.counter % self.batch == 0:
199
- self.txn.commit()
200
- self.txn = self.env.begin(write=True)
201
-
202
- def close(self):
203
- self.txn.commit()
204
- self.env.close()
205
- self.txt_file.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_logger.py DELETED
@@ -1,66 +0,0 @@
1
- import sys
2
- import datetime
3
- import logging
4
-
5
-
6
- '''
7
- # --------------------------------------------
8
- # Kai Zhang (github: https://github.com/cszn)
9
- # 03/Mar/2019
10
- # --------------------------------------------
11
- # https://github.com/xinntao/BasicSR
12
- # --------------------------------------------
13
- '''
14
-
15
-
16
- def log(*args, **kwargs):
17
- print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)
18
-
19
-
20
- '''
21
- # --------------------------------------------
22
- # logger
23
- # --------------------------------------------
24
- '''
25
-
26
-
27
- def logger_info(logger_name, log_path='default_logger.log'):
28
- ''' set up logger
29
- modified by Kai Zhang (github: https://github.com/cszn)
30
- '''
31
- log = logging.getLogger(logger_name)
32
- if log.hasHandlers():
33
- print('LogHandlers exist!')
34
- else:
35
- print('LogHandlers setup!')
36
- level = logging.INFO
37
- formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
38
- fh = logging.FileHandler(log_path, mode='a')
39
- fh.setFormatter(formatter)
40
- log.setLevel(level)
41
- log.addHandler(fh)
42
- # print(len(log.handlers))
43
-
44
- sh = logging.StreamHandler()
45
- sh.setFormatter(formatter)
46
- log.addHandler(sh)
47
-
48
-
49
- '''
50
- # --------------------------------------------
51
- # print to file and std_out simultaneously
52
- # --------------------------------------------
53
- '''
54
-
55
-
56
- class logger_print(object):
57
- def __init__(self, log_path="default.log"):
58
- self.terminal = sys.stdout
59
- self.log = open(log_path, 'a')
60
-
61
- def write(self, message):
62
- self.terminal.write(message)
63
- self.log.write(message) # write the message
64
-
65
- def flush(self):
66
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_mat.py DELETED
@@ -1,88 +0,0 @@
1
- import os
2
- import json
3
- import scipy.io as spio
4
- import pandas as pd
5
-
6
-
7
- def loadmat(filename):
8
- '''
9
- this function should be called instead of direct spio.loadmat
10
- as it cures the problem of not properly recovering python dictionaries
11
- from mat files. It calls the function check keys to cure all entries
12
- which are still mat-objects
13
- '''
14
- data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
15
- return dict_to_nonedict(_check_keys(data))
16
-
17
- def _check_keys(dict):
18
- '''
19
- checks if entries in dictionary are mat-objects. If yes
20
- todict is called to change them to nested dictionaries
21
- '''
22
- for key in dict:
23
- if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
24
- dict[key] = _todict(dict[key])
25
- return dict
26
-
27
- def _todict(matobj):
28
- '''
29
- A recursive function which constructs from matobjects nested dictionaries
30
- '''
31
- dict = {}
32
- for strg in matobj._fieldnames:
33
- elem = matobj.__dict__[strg]
34
- if isinstance(elem, spio.matlab.mio5_params.mat_struct):
35
- dict[strg] = _todict(elem)
36
- else:
37
- dict[strg] = elem
38
- return dict
39
-
40
-
41
- def dict_to_nonedict(opt):
42
- if isinstance(opt, dict):
43
- new_opt = dict()
44
- for key, sub_opt in opt.items():
45
- new_opt[key] = dict_to_nonedict(sub_opt)
46
- return NoneDict(**new_opt)
47
- elif isinstance(opt, list):
48
- return [dict_to_nonedict(sub_opt) for sub_opt in opt]
49
- else:
50
- return opt
51
-
52
-
53
- class NoneDict(dict):
54
- def __missing__(self, key):
55
- return None
56
-
57
-
58
- def mat2json(mat_path=None, filepath = None):
59
- """
60
- Converts .mat file to .json and writes new file
61
- Parameters
62
- ----------
63
- mat_path: Str
64
- path/filename .mat存放路径
65
- filepath: Str
66
- 如果需要保存成json, 添加这一路径. 否则不保存
67
- Returns
68
- 返回转化的字典
69
- -------
70
- None
71
- Examples
72
- --------
73
- >>> mat2json(blah blah)
74
- """
75
-
76
- matlabFile = loadmat(mat_path)
77
- #pop all those dumb fields that don't let you jsonize file
78
- matlabFile.pop('__header__')
79
- matlabFile.pop('__version__')
80
- matlabFile.pop('__globals__')
81
- #jsonize the file - orientation is 'index'
82
- matlabFile = pd.Series(matlabFile).to_json()
83
-
84
- if filepath:
85
- json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json'
86
- with open(json_path, 'w') as f:
87
- f.write(matlabFile)
88
- return matlabFile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_matconvnet.py DELETED
@@ -1,197 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- import numpy as np
3
- import torch
4
- from collections import OrderedDict
5
-
6
- # import scipy.io as io
7
- import hdf5storage
8
-
9
- """
10
- # --------------------------------------------
11
- # Convert matconvnet SimpleNN model into pytorch model
12
- # --------------------------------------------
13
- # Kai Zhang (cskaizhang@gmail.com)
14
- # https://github.com/cszn
15
- # 28/Nov/2019
16
- # --------------------------------------------
17
- """
18
-
19
-
20
- def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
21
- """Modified version of https://github.com/albanie/pytorch-mcn
22
- Adjust memory layout and load weights as torch tensor
23
- Args:
24
- x (ndaray): a numpy array, corresponding to a set of network weights
25
- stored in column major order
26
- squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
27
- singletons from the trailing dimensions. So after converting to
28
- pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1)
29
- it will be reshaped to a matrix with shape (A,B).
30
- in_features (int :: None): used to reshape weights for a linear block.
31
- out_features (int :: None): used to reshape weights for a linear block.
32
- Returns:
33
- torch.tensor: a permuted sets of weights, matching the pytorch layout
34
- convention
35
- """
36
- if x.ndim == 4:
37
- x = x.transpose((3, 2, 0, 1))
38
- # for FFDNet, pixel-shuffle layer
39
- # if x.shape[1]==13:
40
- # x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:]
41
- # if x.shape[0]==12:
42
- # x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
43
- # if x.shape[1]==5:
44
- # x=x[:,[0,2,1,3, 4],:,:]
45
- # if x.shape[0]==4:
46
- # x=x[[0,2,1,3],:,:,:]
47
- ## for SRMD, pixel-shuffle layer
48
- # if x.shape[0]==12:
49
- # x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
50
- # if x.shape[0]==27:
51
- # x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:]
52
- # if x.shape[0]==48:
53
- # x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:]
54
-
55
- elif x.ndim == 3: # add by Kai
56
- x = x[:,:,:,None]
57
- x = x.transpose((3, 2, 0, 1))
58
- elif x.ndim == 2:
59
- if x.shape[1] == 1:
60
- x = x.flatten()
61
- if squeeze:
62
- if in_features and out_features:
63
- x = x.reshape((out_features, in_features))
64
- x = np.squeeze(x)
65
- return torch.from_numpy(np.ascontiguousarray(x))
66
-
67
-
68
- def save_model(network, save_path):
69
- state_dict = network.state_dict()
70
- for key, param in state_dict.items():
71
- state_dict[key] = param.cpu()
72
- torch.save(state_dict, save_path)
73
-
74
-
75
- if __name__ == '__main__':
76
-
77
-
78
- # from utils import utils_logger
79
- # import logging
80
- # utils_logger.logger_info('a', 'a.log')
81
- # logger = logging.getLogger('a')
82
- #
83
- # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
84
- mcn = hdf5storage.loadmat('models/modelcolor.mat')
85
-
86
-
87
- #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0])
88
-
89
- mat_net = OrderedDict()
90
- for idx in range(25):
91
- mat_net[str(idx)] = OrderedDict()
92
- count = -1
93
-
94
- print(idx)
95
- for i in range(13):
96
-
97
- if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv':
98
-
99
- count += 1
100
- w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0]
101
- # print(w.shape)
102
- w = weights2tensor(w)
103
- # print(w.shape)
104
-
105
- b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1]
106
- b = weights2tensor(b)
107
- print(b.shape)
108
-
109
- mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w
110
- mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b
111
-
112
- torch.save(mat_net, 'model_zoo/modelcolor.pth')
113
-
114
-
115
-
116
- # from models.network_dncnn import IRCNN as net
117
- # network = net(in_nc=3, out_nc=3, nc=64)
118
- # state_dict = network.state_dict()
119
- #
120
- # #show_kv(state_dict)
121
- #
122
- # for i in range(len(mcn['net'][0][0][0])):
123
- # print(mcn['net'][0][0][0][i][0][0][0][0])
124
- #
125
- # count = -1
126
- # mat_net = OrderedDict()
127
- # for i in range(len(mcn['net'][0][0][0])):
128
- # if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
129
- #
130
- # count += 1
131
- # w = mcn['net'][0][0][0][i][0][1][0][0]
132
- # print(w.shape)
133
- # w = weights2tensor(w)
134
- # print(w.shape)
135
- #
136
- # b = mcn['net'][0][0][0][i][0][1][0][1]
137
- # b = weights2tensor(b)
138
- # print(b.shape)
139
- #
140
- # mat_net['model.{:d}.weight'.format(count*2)] = w
141
- # mat_net['model.{:d}.bias'.format(count*2)] = b
142
- #
143
- # torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
144
- #
145
- #
146
- #
147
- # crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
148
- # def show_kv(net):
149
- # for k, v in net.items():
150
- # print(k)
151
- #
152
- # show_kv(crt_net)
153
-
154
-
155
- # from models.network_dncnn import DnCNN as net
156
- # network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
157
-
158
- # from models.network_srmd import SRMD as net
159
- # #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
160
- # network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
161
- #
162
- # from models.network_rrdb import RRDB as net
163
- # network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
164
- #
165
- # state_dict = network.state_dict()
166
- # for key, param in state_dict.items():
167
- # print(key)
168
- # from models.network_imdn import IMDN as net
169
- # network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
170
- # state_dict = network.state_dict()
171
- # mat_net = OrderedDict()
172
- # for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()):
173
- # mat_net[key] = param2
174
- # torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
175
- #
176
-
177
- # net_old = torch.load('net_old.pth')
178
- # def show_kv(net):
179
- # for k, v in net.items():
180
- # print(k)
181
- #
182
- # show_kv(net_old)
183
- # from models.network_dpsr import MSRResNet_prior as net
184
- # model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
185
- # state_dict = network.state_dict()
186
- # net_new = OrderedDict()
187
- # for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()):
188
- # net_new[key] = param_old
189
- # torch.save(net_new, 'net_new.pth')
190
-
191
-
192
- # print(key)
193
- # print(param.size())
194
-
195
-
196
-
197
- # run utils/utils_matconvnet.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_model.py DELETED
@@ -1,330 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- import numpy as np
3
- import torch
4
- from utils import utils_image as util
5
- import re
6
- import glob
7
- import os
8
-
9
-
10
- '''
11
- # --------------------------------------------
12
- # Model
13
- # --------------------------------------------
14
- # Kai Zhang (github: https://github.com/cszn)
15
- # 03/Mar/2019
16
- # --------------------------------------------
17
- '''
18
-
19
-
20
- def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
21
- """
22
- # ---------------------------------------
23
- # Kai Zhang (github: https://github.com/cszn)
24
- # 03/Mar/2019
25
- # ---------------------------------------
26
- Args:
27
- save_dir: model folder
28
- net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
29
- pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
30
-
31
- Return:
32
- init_iter: iteration number
33
- init_path: model path
34
- # ---------------------------------------
35
- """
36
-
37
- file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
38
- if file_list:
39
- iter_exist = []
40
- for file_ in file_list:
41
- iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
42
- iter_exist.append(int(iter_current[0]))
43
- init_iter = max(iter_exist)
44
- init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
45
- else:
46
- init_iter = 0
47
- init_path = pretrained_path
48
- return init_iter, init_path
49
-
50
-
51
- def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
52
- '''
53
- # ---------------------------------------
54
- # Kai Zhang (github: https://github.com/cszn)
55
- # 03/Mar/2019
56
- # ---------------------------------------
57
- Args:
58
- model: trained model
59
- L: input Low-quality image
60
- mode:
61
- (0) normal: test(model, L)
62
- (1) pad: test_pad(model, L, modulo=16)
63
- (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
64
- (3) x8: test_x8(model, L, modulo=1) ^_^
65
- (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
66
- refield: effective receptive filed of the network, 32 is enough
67
- useful when split, i.e., mode=2, 4
68
- min_size: min_sizeXmin_size image, e.g., 256X256 image
69
- useful when split, i.e., mode=2, 4
70
- sf: scale factor for super-resolution, otherwise 1
71
- modulo: 1 if split
72
- useful when pad, i.e., mode=1
73
-
74
- Returns:
75
- E: estimated image
76
- # ---------------------------------------
77
- '''
78
- if mode == 0:
79
- E = test(model, L)
80
- elif mode == 1:
81
- E = test_pad(model, L, modulo, sf)
82
- elif mode == 2:
83
- E = test_split(model, L, refield, min_size, sf, modulo)
84
- elif mode == 3:
85
- E = test_x8(model, L, modulo, sf)
86
- elif mode == 4:
87
- E = test_split_x8(model, L, refield, min_size, sf, modulo)
88
- return E
89
-
90
-
91
- '''
92
- # --------------------------------------------
93
- # normal (0)
94
- # --------------------------------------------
95
- '''
96
-
97
-
98
- def test(model, L):
99
- E = model(L)
100
- return E
101
-
102
-
103
- '''
104
- # --------------------------------------------
105
- # pad (1)
106
- # --------------------------------------------
107
- '''
108
-
109
-
110
- def test_pad(model, L, modulo=16, sf=1):
111
- h, w = L.size()[-2:]
112
- paddingBottom = int(np.ceil(h/modulo)*modulo-h)
113
- paddingRight = int(np.ceil(w/modulo)*modulo-w)
114
- L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
115
- E = model(L)
116
- E = E[..., :h*sf, :w*sf]
117
- return E
118
-
119
-
120
- '''
121
- # --------------------------------------------
122
- # split (function)
123
- # --------------------------------------------
124
- '''
125
-
126
-
127
- def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
128
- """
129
- Args:
130
- model: trained model
131
- L: input Low-quality image
132
- refield: effective receptive filed of the network, 32 is enough
133
- min_size: min_sizeXmin_size image, e.g., 256X256 image
134
- sf: scale factor for super-resolution, otherwise 1
135
- modulo: 1 if split
136
-
137
- Returns:
138
- E: estimated result
139
- """
140
- h, w = L.size()[-2:]
141
- if h*w <= min_size**2:
142
- L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
143
- E = model(L)
144
- E = E[..., :h*sf, :w*sf]
145
- else:
146
- top = slice(0, (h//2//refield+1)*refield)
147
- bottom = slice(h - (h//2//refield+1)*refield, h)
148
- left = slice(0, (w//2//refield+1)*refield)
149
- right = slice(w - (w//2//refield+1)*refield, w)
150
- Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
151
-
152
- if h * w <= 4*(min_size**2):
153
- Es = [model(Ls[i]) for i in range(4)]
154
- else:
155
- Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
156
-
157
- b, c = Es[0].size()[:2]
158
- E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
159
-
160
- E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
161
- E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
162
- E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
163
- E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
164
- return E
165
-
166
-
167
- '''
168
- # --------------------------------------------
169
- # split (2)
170
- # --------------------------------------------
171
- '''
172
-
173
-
174
- def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
175
- E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
176
- return E
177
-
178
-
179
- '''
180
- # --------------------------------------------
181
- # x8 (3)
182
- # --------------------------------------------
183
- '''
184
-
185
-
186
- def test_x8(model, L, modulo=1, sf=1):
187
- E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)]
188
- for i in range(len(E_list)):
189
- if i == 3 or i == 5:
190
- E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
191
- else:
192
- E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
193
- output_cat = torch.stack(E_list, dim=0)
194
- E = output_cat.mean(dim=0, keepdim=False)
195
- return E
196
-
197
-
198
- '''
199
- # --------------------------------------------
200
- # split and x8 (4)
201
- # --------------------------------------------
202
- '''
203
-
204
-
205
- def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
206
- E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
207
- for k, i in enumerate(range(len(E_list))):
208
- if i==3 or i==5:
209
- E_list[k] = util.augment_img_tensor4(E_list[k], mode=8-i)
210
- else:
211
- E_list[k] = util.augment_img_tensor4(E_list[k], mode=i)
212
- output_cat = torch.stack(E_list, dim=0)
213
- E = output_cat.mean(dim=0, keepdim=False)
214
- return E
215
-
216
-
217
- '''
218
- # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
219
- # _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
220
- # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
221
- '''
222
-
223
-
224
- '''
225
- # --------------------------------------------
226
- # print
227
- # --------------------------------------------
228
- '''
229
-
230
-
231
- # --------------------------------------------
232
- # print model
233
- # --------------------------------------------
234
- def print_model(model):
235
- msg = describe_model(model)
236
- print(msg)
237
-
238
-
239
- # --------------------------------------------
240
- # print params
241
- # --------------------------------------------
242
- def print_params(model):
243
- msg = describe_params(model)
244
- print(msg)
245
-
246
-
247
- '''
248
- # --------------------------------------------
249
- # information
250
- # --------------------------------------------
251
- '''
252
-
253
-
254
- # --------------------------------------------
255
- # model inforation
256
- # --------------------------------------------
257
- def info_model(model):
258
- msg = describe_model(model)
259
- return msg
260
-
261
-
262
- # --------------------------------------------
263
- # params inforation
264
- # --------------------------------------------
265
- def info_params(model):
266
- msg = describe_params(model)
267
- return msg
268
-
269
-
270
- '''
271
- # --------------------------------------------
272
- # description
273
- # --------------------------------------------
274
- '''
275
-
276
-
277
- # --------------------------------------------
278
- # model name and total number of parameters
279
- # --------------------------------------------
280
- def describe_model(model):
281
- if isinstance(model, torch.nn.DataParallel):
282
- model = model.module
283
- msg = '\n'
284
- msg += 'models name: {}'.format(model.__class__.__name__) + '\n'
285
- msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n'
286
- msg += 'Net structure:\n{}'.format(str(model)) + '\n'
287
- return msg
288
-
289
-
290
- # --------------------------------------------
291
- # parameters description
292
- # --------------------------------------------
293
- def describe_params(model):
294
- if isinstance(model, torch.nn.DataParallel):
295
- model = model.module
296
- msg = '\n'
297
- msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
298
- for name, param in model.state_dict().items():
299
- if not 'num_batches_tracked' in name:
300
- v = param.data.clone().float()
301
- msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
302
- return msg
303
-
304
-
305
- if __name__ == '__main__':
306
-
307
- class Net(torch.nn.Module):
308
- def __init__(self, in_channels=3, out_channels=3):
309
- super(Net, self).__init__()
310
- self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
311
-
312
- def forward(self, x):
313
- x = self.conv(x)
314
- return x
315
-
316
- start = torch.cuda.Event(enable_timing=True)
317
- end = torch.cuda.Event(enable_timing=True)
318
-
319
- model = Net()
320
- model = model.eval()
321
- print_model(model)
322
- print_params(model)
323
- x = torch.randn((2,3,401,401))
324
- torch.cuda.empty_cache()
325
- with torch.no_grad():
326
- for mode in range(5):
327
- y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
328
- print(y.shape)
329
-
330
- # run utils/utils_model.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_modelsummary.py DELETED
@@ -1,485 +0,0 @@
1
- import torch.nn as nn
2
- import torch
3
- import numpy as np
4
-
5
- '''
6
- ---- 1) FLOPs: floating point operations
7
- ---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
8
- ---- 3) #Conv2d: the number of ‘Conv2d’ layers
9
- # --------------------------------------------
10
- # Kai Zhang (github: https://github.com/cszn)
11
- # 21/July/2020
12
- # --------------------------------------------
13
- # Reference
14
- https://github.com/sovrasov/flops-counter.pytorch.git
15
-
16
- # If you use this code, please consider the following citation:
17
-
18
- @inproceedings{zhang2020aim, %
19
- title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
20
- author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
21
- booktitle={European Conference on Computer Vision Workshops},
22
- year={2020}
23
- }
24
- # --------------------------------------------
25
- '''
26
-
27
- def get_model_flops(model, input_res, print_per_layer_stat=True,
28
- input_constructor=None):
29
- assert type(input_res) is tuple, 'Please provide the size of the input image.'
30
- assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
31
- flops_model = add_flops_counting_methods(model)
32
- flops_model.eval().start_flops_count()
33
- if input_constructor:
34
- input = input_constructor(input_res)
35
- _ = flops_model(**input)
36
- else:
37
- device = list(flops_model.parameters())[-1].device
38
- batch = torch.FloatTensor(1, *input_res).to(device)
39
- _ = flops_model(batch)
40
-
41
- if print_per_layer_stat:
42
- print_model_with_flops(flops_model)
43
- flops_count = flops_model.compute_average_flops_cost()
44
- flops_model.stop_flops_count()
45
-
46
- return flops_count
47
-
48
- def get_model_activation(model, input_res, input_constructor=None):
49
- assert type(input_res) is tuple, 'Please provide the size of the input image.'
50
- assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
51
- activation_model = add_activation_counting_methods(model)
52
- activation_model.eval().start_activation_count()
53
- if input_constructor:
54
- input = input_constructor(input_res)
55
- _ = activation_model(**input)
56
- else:
57
- device = list(activation_model.parameters())[-1].device
58
- batch = torch.FloatTensor(1, *input_res).to(device)
59
- _ = activation_model(batch)
60
-
61
- activation_count, num_conv = activation_model.compute_average_activation_cost()
62
- activation_model.stop_activation_count()
63
-
64
- return activation_count, num_conv
65
-
66
-
67
- def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
68
- input_constructor=None):
69
- assert type(input_res) is tuple
70
- assert len(input_res) >= 3
71
- flops_model = add_flops_counting_methods(model)
72
- flops_model.eval().start_flops_count()
73
- if input_constructor:
74
- input = input_constructor(input_res)
75
- _ = flops_model(**input)
76
- else:
77
- batch = torch.FloatTensor(1, *input_res)
78
- _ = flops_model(batch)
79
-
80
- if print_per_layer_stat:
81
- print_model_with_flops(flops_model)
82
- flops_count = flops_model.compute_average_flops_cost()
83
- params_count = get_model_parameters_number(flops_model)
84
- flops_model.stop_flops_count()
85
-
86
- if as_strings:
87
- return flops_to_string(flops_count), params_to_string(params_count)
88
-
89
- return flops_count, params_count
90
-
91
-
92
- def flops_to_string(flops, units='GMac', precision=2):
93
- if units is None:
94
- if flops // 10**9 > 0:
95
- return str(round(flops / 10.**9, precision)) + ' GMac'
96
- elif flops // 10**6 > 0:
97
- return str(round(flops / 10.**6, precision)) + ' MMac'
98
- elif flops // 10**3 > 0:
99
- return str(round(flops / 10.**3, precision)) + ' KMac'
100
- else:
101
- return str(flops) + ' Mac'
102
- else:
103
- if units == 'GMac':
104
- return str(round(flops / 10.**9, precision)) + ' ' + units
105
- elif units == 'MMac':
106
- return str(round(flops / 10.**6, precision)) + ' ' + units
107
- elif units == 'KMac':
108
- return str(round(flops / 10.**3, precision)) + ' ' + units
109
- else:
110
- return str(flops) + ' Mac'
111
-
112
-
113
- def params_to_string(params_num):
114
- if params_num // 10 ** 6 > 0:
115
- return str(round(params_num / 10 ** 6, 2)) + ' M'
116
- elif params_num // 10 ** 3:
117
- return str(round(params_num / 10 ** 3, 2)) + ' k'
118
- else:
119
- return str(params_num)
120
-
121
-
122
- def print_model_with_flops(model, units='GMac', precision=3):
123
- total_flops = model.compute_average_flops_cost()
124
-
125
- def accumulate_flops(self):
126
- if is_supported_instance(self):
127
- return self.__flops__ / model.__batch_counter__
128
- else:
129
- sum = 0
130
- for m in self.children():
131
- sum += m.accumulate_flops()
132
- return sum
133
-
134
- def flops_repr(self):
135
- accumulated_flops_cost = self.accumulate_flops()
136
- return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
137
- '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
138
- self.original_extra_repr()])
139
-
140
- def add_extra_repr(m):
141
- m.accumulate_flops = accumulate_flops.__get__(m)
142
- flops_extra_repr = flops_repr.__get__(m)
143
- if m.extra_repr != flops_extra_repr:
144
- m.original_extra_repr = m.extra_repr
145
- m.extra_repr = flops_extra_repr
146
- assert m.extra_repr != m.original_extra_repr
147
-
148
- def del_extra_repr(m):
149
- if hasattr(m, 'original_extra_repr'):
150
- m.extra_repr = m.original_extra_repr
151
- del m.original_extra_repr
152
- if hasattr(m, 'accumulate_flops'):
153
- del m.accumulate_flops
154
-
155
- model.apply(add_extra_repr)
156
- print(model)
157
- model.apply(del_extra_repr)
158
-
159
-
160
- def get_model_parameters_number(model):
161
- params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
162
- return params_num
163
-
164
-
165
- def add_flops_counting_methods(net_main_module):
166
- # adding additional methods to the existing module object,
167
- # this is done this way so that each function has access to self object
168
- # embed()
169
- net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
170
- net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
171
- net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
172
- net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
173
-
174
- net_main_module.reset_flops_count()
175
- return net_main_module
176
-
177
-
178
- def compute_average_flops_cost(self):
179
- """
180
- A method that will be available after add_flops_counting_methods() is called
181
- on a desired net object.
182
-
183
- Returns current mean flops consumption per image.
184
-
185
- """
186
-
187
- flops_sum = 0
188
- for module in self.modules():
189
- if is_supported_instance(module):
190
- flops_sum += module.__flops__
191
-
192
- return flops_sum
193
-
194
-
195
- def start_flops_count(self):
196
- """
197
- A method that will be available after add_flops_counting_methods() is called
198
- on a desired net object.
199
-
200
- Activates the computation of mean flops consumption per image.
201
- Call it before you run the network.
202
-
203
- """
204
- self.apply(add_flops_counter_hook_function)
205
-
206
-
207
- def stop_flops_count(self):
208
- """
209
- A method that will be available after add_flops_counting_methods() is called
210
- on a desired net object.
211
-
212
- Stops computing the mean flops consumption per image.
213
- Call whenever you want to pause the computation.
214
-
215
- """
216
- self.apply(remove_flops_counter_hook_function)
217
-
218
-
219
- def reset_flops_count(self):
220
- """
221
- A method that will be available after add_flops_counting_methods() is called
222
- on a desired net object.
223
-
224
- Resets statistics computed so far.
225
-
226
- """
227
- self.apply(add_flops_counter_variable_or_reset)
228
-
229
-
230
- def add_flops_counter_hook_function(module):
231
- if is_supported_instance(module):
232
- if hasattr(module, '__flops_handle__'):
233
- return
234
-
235
- if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
236
- handle = module.register_forward_hook(conv_flops_counter_hook)
237
- elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
238
- handle = module.register_forward_hook(relu_flops_counter_hook)
239
- elif isinstance(module, nn.Linear):
240
- handle = module.register_forward_hook(linear_flops_counter_hook)
241
- elif isinstance(module, (nn.BatchNorm2d)):
242
- handle = module.register_forward_hook(bn_flops_counter_hook)
243
- else:
244
- handle = module.register_forward_hook(empty_flops_counter_hook)
245
- module.__flops_handle__ = handle
246
-
247
-
248
- def remove_flops_counter_hook_function(module):
249
- if is_supported_instance(module):
250
- if hasattr(module, '__flops_handle__'):
251
- module.__flops_handle__.remove()
252
- del module.__flops_handle__
253
-
254
-
255
- def add_flops_counter_variable_or_reset(module):
256
- if is_supported_instance(module):
257
- module.__flops__ = 0
258
-
259
-
260
- # ---- Internal functions
261
- def is_supported_instance(module):
262
- if isinstance(module,
263
- (
264
- nn.Conv2d, nn.ConvTranspose2d,
265
- nn.BatchNorm2d,
266
- nn.Linear,
267
- nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
268
- )):
269
- return True
270
-
271
- return False
272
-
273
-
274
- def conv_flops_counter_hook(conv_module, input, output):
275
- # Can have multiple inputs, getting the first one
276
- # input = input[0]
277
-
278
- batch_size = output.shape[0]
279
- output_dims = list(output.shape[2:])
280
-
281
- kernel_dims = list(conv_module.kernel_size)
282
- in_channels = conv_module.in_channels
283
- out_channels = conv_module.out_channels
284
- groups = conv_module.groups
285
-
286
- filters_per_channel = out_channels // groups
287
- conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel
288
-
289
- active_elements_count = batch_size * np.prod(output_dims)
290
- overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count)
291
-
292
- # overall_flops = overall_conv_flops
293
-
294
- conv_module.__flops__ += int(overall_conv_flops)
295
- # conv_module.__output_dims__ = output_dims
296
-
297
-
298
- def relu_flops_counter_hook(module, input, output):
299
- active_elements_count = output.numel()
300
- module.__flops__ += int(active_elements_count)
301
- # print(module.__flops__, id(module))
302
- # print(module)
303
-
304
-
305
- def linear_flops_counter_hook(module, input, output):
306
- input = input[0]
307
- if len(input.shape) == 1:
308
- batch_size = 1
309
- module.__flops__ += int(batch_size * input.shape[0] * output.shape[0])
310
- else:
311
- batch_size = input.shape[0]
312
- module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])
313
-
314
-
315
- def bn_flops_counter_hook(module, input, output):
316
- # input = input[0]
317
- # TODO: need to check here
318
- # batch_flops = np.prod(input.shape)
319
- # if module.affine:
320
- # batch_flops *= 2
321
- # module.__flops__ += int(batch_flops)
322
- batch = output.shape[0]
323
- output_dims = output.shape[2:]
324
- channels = module.num_features
325
- batch_flops = batch * channels * np.prod(output_dims)
326
- if module.affine:
327
- batch_flops *= 2
328
- module.__flops__ += int(batch_flops)
329
-
330
-
331
- # ---- Count the number of convolutional layers and the activation
332
- def add_activation_counting_methods(net_main_module):
333
- # adding additional methods to the existing module object,
334
- # this is done this way so that each function has access to self object
335
- # embed()
336
- net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
337
- net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
338
- net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
339
- net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)
340
-
341
- net_main_module.reset_activation_count()
342
- return net_main_module
343
-
344
-
345
- def compute_average_activation_cost(self):
346
- """
347
- A method that will be available after add_activation_counting_methods() is called
348
- on a desired net object.
349
-
350
- Returns current mean activation consumption per image.
351
-
352
- """
353
-
354
- activation_sum = 0
355
- num_conv = 0
356
- for module in self.modules():
357
- if is_supported_instance_for_activation(module):
358
- activation_sum += module.__activation__
359
- num_conv += module.__num_conv__
360
- return activation_sum, num_conv
361
-
362
-
363
- def start_activation_count(self):
364
- """
365
- A method that will be available after add_activation_counting_methods() is called
366
- on a desired net object.
367
-
368
- Activates the computation of mean activation consumption per image.
369
- Call it before you run the network.
370
-
371
- """
372
- self.apply(add_activation_counter_hook_function)
373
-
374
-
375
- def stop_activation_count(self):
376
- """
377
- A method that will be available after add_activation_counting_methods() is called
378
- on a desired net object.
379
-
380
- Stops computing the mean activation consumption per image.
381
- Call whenever you want to pause the computation.
382
-
383
- """
384
- self.apply(remove_activation_counter_hook_function)
385
-
386
-
387
- def reset_activation_count(self):
388
- """
389
- A method that will be available after add_activation_counting_methods() is called
390
- on a desired net object.
391
-
392
- Resets statistics computed so far.
393
-
394
- """
395
- self.apply(add_activation_counter_variable_or_reset)
396
-
397
-
398
- def add_activation_counter_hook_function(module):
399
- if is_supported_instance_for_activation(module):
400
- if hasattr(module, '__activation_handle__'):
401
- return
402
-
403
- if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
404
- handle = module.register_forward_hook(conv_activation_counter_hook)
405
- module.__activation_handle__ = handle
406
-
407
-
408
- def remove_activation_counter_hook_function(module):
409
- if is_supported_instance_for_activation(module):
410
- if hasattr(module, '__activation_handle__'):
411
- module.__activation_handle__.remove()
412
- del module.__activation_handle__
413
-
414
-
415
- def add_activation_counter_variable_or_reset(module):
416
- if is_supported_instance_for_activation(module):
417
- module.__activation__ = 0
418
- module.__num_conv__ = 0
419
-
420
-
421
- def is_supported_instance_for_activation(module):
422
- if isinstance(module,
423
- (
424
- nn.Conv2d, nn.ConvTranspose2d,
425
- )):
426
- return True
427
-
428
- return False
429
-
430
- def conv_activation_counter_hook(module, input, output):
431
- """
432
- Calculate the activations in the convolutional operation.
433
- Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
434
- :param module:
435
- :param input:
436
- :param output:
437
- :return:
438
- """
439
- module.__activation__ += output.numel()
440
- module.__num_conv__ += 1
441
-
442
-
443
- def empty_flops_counter_hook(module, input, output):
444
- module.__flops__ += 0
445
-
446
-
447
- def upsample_flops_counter_hook(module, input, output):
448
- output_size = output[0]
449
- batch_size = output_size.shape[0]
450
- output_elements_count = batch_size
451
- for val in output_size.shape[1:]:
452
- output_elements_count *= val
453
- module.__flops__ += int(output_elements_count)
454
-
455
-
456
- def pool_flops_counter_hook(module, input, output):
457
- input = input[0]
458
- module.__flops__ += int(np.prod(input.shape))
459
-
460
-
461
- def dconv_flops_counter_hook(dconv_module, input, output):
462
- input = input[0]
463
-
464
- batch_size = input.shape[0]
465
- output_dims = list(output.shape[2:])
466
-
467
- m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
468
- out_channels, _, kernel_dim2, _, = dconv_module.projection.shape
469
- # groups = dconv_module.groups
470
-
471
- # filters_per_channel = out_channels // groups
472
- conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
473
- conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
474
- active_elements_count = batch_size * np.prod(output_dims)
475
-
476
- overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
477
- overall_flops = overall_conv_flops
478
-
479
- dconv_module.__flops__ += int(overall_flops)
480
- # dconv_module.__output_dims__ = output_dims
481
-
482
-
483
-
484
-
485
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_option.py DELETED
@@ -1,255 +0,0 @@
1
- import os
2
- from collections import OrderedDict
3
- from datetime import datetime
4
- import json
5
- import re
6
- import glob
7
-
8
-
9
- '''
10
- # --------------------------------------------
11
- # Kai Zhang (github: https://github.com/cszn)
12
- # 03/Mar/2019
13
- # --------------------------------------------
14
- # https://github.com/xinntao/BasicSR
15
- # --------------------------------------------
16
- '''
17
-
18
-
19
- def get_timestamp():
20
- return datetime.now().strftime('_%y%m%d_%H%M%S')
21
-
22
-
23
- def parse(opt_path, is_train=True):
24
-
25
- # ----------------------------------------
26
- # remove comments starting with '//'
27
- # ----------------------------------------
28
- json_str = ''
29
- with open(opt_path, 'r') as f:
30
- for line in f:
31
- line = line.split('//')[0] + '\n'
32
- json_str += line
33
-
34
- # ----------------------------------------
35
- # initialize opt
36
- # ----------------------------------------
37
- opt = json.loads(json_str, object_pairs_hook=OrderedDict)
38
-
39
- opt['opt_path'] = opt_path
40
- opt['is_train'] = is_train
41
-
42
- # ----------------------------------------
43
- # set default
44
- # ----------------------------------------
45
- if 'merge_bn' not in opt:
46
- opt['merge_bn'] = False
47
- opt['merge_bn_startpoint'] = -1
48
-
49
- if 'scale' not in opt:
50
- opt['scale'] = 1
51
-
52
- # ----------------------------------------
53
- # datasets
54
- # ----------------------------------------
55
- for phase, dataset in opt['datasets'].items():
56
- phase = phase.split('_')[0]
57
- dataset['phase'] = phase
58
- dataset['scale'] = opt['scale'] # broadcast
59
- dataset['n_channels'] = opt['n_channels'] # broadcast
60
- if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None:
61
- dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H'])
62
- if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None:
63
- dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L'])
64
-
65
- # ----------------------------------------
66
- # path
67
- # ----------------------------------------
68
- for key, path in opt['path'].items():
69
- if path and key in opt['path']:
70
- opt['path'][key] = os.path.expanduser(path)
71
-
72
- path_task = os.path.join(opt['path']['root'], opt['task'])
73
- opt['path']['task'] = path_task
74
- opt['path']['log'] = path_task
75
- opt['path']['options'] = os.path.join(path_task, 'options')
76
-
77
- if is_train:
78
- opt['path']['models'] = os.path.join(path_task, 'models')
79
- opt['path']['images'] = os.path.join(path_task, 'images')
80
- else: # test
81
- opt['path']['images'] = os.path.join(path_task, 'test_images')
82
-
83
- # ----------------------------------------
84
- # network
85
- # ----------------------------------------
86
- opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1
87
-
88
- # ----------------------------------------
89
- # GPU devices
90
- # ----------------------------------------
91
- gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
92
- os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
93
- print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
94
-
95
- # ----------------------------------------
96
- # default setting for distributeddataparallel
97
- # ----------------------------------------
98
- if 'find_unused_parameters' not in opt:
99
- opt['find_unused_parameters'] = True
100
- if 'use_static_graph' not in opt:
101
- opt['use_static_graph'] = False
102
- if 'dist' not in opt:
103
- opt['dist'] = False
104
- opt['num_gpu'] = len(opt['gpu_ids'])
105
- print('number of GPUs is: ' + str(opt['num_gpu']))
106
-
107
- # ----------------------------------------
108
- # default setting for perceptual loss
109
- # ----------------------------------------
110
- if 'F_feature_layer' not in opt['train']:
111
- opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34]
112
- if 'F_weights' not in opt['train']:
113
- opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0]
114
- if 'F_lossfn_type' not in opt['train']:
115
- opt['train']['F_lossfn_type'] = 'l1'
116
- if 'F_use_input_norm' not in opt['train']:
117
- opt['train']['F_use_input_norm'] = True
118
- if 'F_use_range_norm' not in opt['train']:
119
- opt['train']['F_use_range_norm'] = False
120
-
121
- # ----------------------------------------
122
- # default setting for optimizer
123
- # ----------------------------------------
124
- if 'G_optimizer_type' not in opt['train']:
125
- opt['train']['G_optimizer_type'] = "adam"
126
- if 'G_optimizer_betas' not in opt['train']:
127
- opt['train']['G_optimizer_betas'] = [0.9,0.999]
128
- if 'G_scheduler_restart_weights' not in opt['train']:
129
- opt['train']['G_scheduler_restart_weights'] = 1
130
- if 'G_optimizer_wd' not in opt['train']:
131
- opt['train']['G_optimizer_wd'] = 0
132
- if 'G_optimizer_reuse' not in opt['train']:
133
- opt['train']['G_optimizer_reuse'] = False
134
- if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']:
135
- opt['train']['D_optimizer_reuse'] = False
136
-
137
- # ----------------------------------------
138
- # default setting of strict for model loading
139
- # ----------------------------------------
140
- if 'G_param_strict' not in opt['train']:
141
- opt['train']['G_param_strict'] = True
142
- if 'netD' in opt and 'D_param_strict' not in opt['path']:
143
- opt['train']['D_param_strict'] = True
144
- if 'E_param_strict' not in opt['path']:
145
- opt['train']['E_param_strict'] = True
146
-
147
- # ----------------------------------------
148
- # Exponential Moving Average
149
- # ----------------------------------------
150
- if 'E_decay' not in opt['train']:
151
- opt['train']['E_decay'] = 0
152
-
153
- # ----------------------------------------
154
- # default setting for discriminator
155
- # ----------------------------------------
156
- if 'netD' in opt:
157
- if 'net_type' not in opt['netD']:
158
- opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet
159
- if 'in_nc' not in opt['netD']:
160
- opt['netD']['in_nc'] = 3
161
- if 'base_nc' not in opt['netD']:
162
- opt['netD']['base_nc'] = 64
163
- if 'n_layers' not in opt['netD']:
164
- opt['netD']['n_layers'] = 3
165
- if 'norm_type' not in opt['netD']:
166
- opt['netD']['norm_type'] = 'spectral'
167
-
168
-
169
- return opt
170
-
171
-
172
- def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
173
- """
174
- Args:
175
- save_dir: model folder
176
- net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
177
- pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
178
-
179
- Return:
180
- init_iter: iteration number
181
- init_path: model path
182
- """
183
- file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
184
- if file_list:
185
- iter_exist = []
186
- for file_ in file_list:
187
- iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
188
- iter_exist.append(int(iter_current[0]))
189
- init_iter = max(iter_exist)
190
- init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
191
- else:
192
- init_iter = 0
193
- init_path = pretrained_path
194
- return init_iter, init_path
195
-
196
-
197
- '''
198
- # --------------------------------------------
199
- # convert the opt into json file
200
- # --------------------------------------------
201
- '''
202
-
203
-
204
- def save(opt):
205
- opt_path = opt['opt_path']
206
- opt_path_copy = opt['path']['options']
207
- dirname, filename_ext = os.path.split(opt_path)
208
- filename, ext = os.path.splitext(filename_ext)
209
- dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext)
210
- with open(dump_path, 'w') as dump_file:
211
- json.dump(opt, dump_file, indent=2)
212
-
213
-
214
- '''
215
- # --------------------------------------------
216
- # dict to string for logger
217
- # --------------------------------------------
218
- '''
219
-
220
-
221
- def dict2str(opt, indent_l=1):
222
- msg = ''
223
- for k, v in opt.items():
224
- if isinstance(v, dict):
225
- msg += ' ' * (indent_l * 2) + k + ':[\n'
226
- msg += dict2str(v, indent_l + 1)
227
- msg += ' ' * (indent_l * 2) + ']\n'
228
- else:
229
- msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
230
- return msg
231
-
232
-
233
- '''
234
- # --------------------------------------------
235
- # convert OrderedDict to NoneDict,
236
- # return None for missing key
237
- # --------------------------------------------
238
- '''
239
-
240
-
241
- def dict_to_nonedict(opt):
242
- if isinstance(opt, dict):
243
- new_opt = dict()
244
- for key, sub_opt in opt.items():
245
- new_opt[key] = dict_to_nonedict(sub_opt)
246
- return NoneDict(**new_opt)
247
- elif isinstance(opt, list):
248
- return [dict_to_nonedict(sub_opt) for sub_opt in opt]
249
- else:
250
- return opt
251
-
252
-
253
- class NoneDict(dict):
254
- def __missing__(self, key):
255
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_params.py DELETED
@@ -1,135 +0,0 @@
1
- import torch
2
-
3
- import torchvision
4
-
5
- from models import basicblock as B
6
-
7
- def show_kv(net):
8
- for k, v in net.items():
9
- print(k)
10
-
11
- # should run train debug mode first to get an initial model
12
- #crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
13
- #
14
- #for k, v in crt_net.items():
15
- # print(k)
16
- #for k, v in crt_net.items():
17
- # if k in pretrained_net:
18
- # crt_net[k] = pretrained_net[k]
19
- # print('replace ... ', k)
20
-
21
- # x2 -> x4
22
- #crt_net['model.5.weight'] = pretrained_net['model.2.weight']
23
- #crt_net['model.5.bias'] = pretrained_net['model.2.bias']
24
- #crt_net['model.8.weight'] = pretrained_net['model.5.weight']
25
- #crt_net['model.8.bias'] = pretrained_net['model.5.bias']
26
- #crt_net['model.10.weight'] = pretrained_net['model.7.weight']
27
- #crt_net['model.10.bias'] = pretrained_net['model.7.bias']
28
- #torch.save(crt_net, '../pretrained_tmp.pth')
29
-
30
- # x2 -> x3
31
- '''
32
- in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
33
- new_filter = torch.Tensor(576, 64, 3, 3)
34
- new_filter[0:256, :, :, :] = in_filter
35
- new_filter[256:512, :, :, :] = in_filter
36
- new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
37
- crt_net['model.2.weight'] = new_filter
38
-
39
- in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3
40
- new_bias = torch.Tensor(576)
41
- new_bias[0:256] = in_bias
42
- new_bias[256:512] = in_bias
43
- new_bias[512:] = in_bias[0:576 - 512]
44
- crt_net['model.2.bias'] = new_bias
45
-
46
- torch.save(crt_net, '../pretrained_tmp.pth')
47
- '''
48
-
49
- # x2 -> x8
50
- '''
51
- crt_net['model.5.weight'] = pretrained_net['model.2.weight']
52
- crt_net['model.5.bias'] = pretrained_net['model.2.bias']
53
- crt_net['model.8.weight'] = pretrained_net['model.2.weight']
54
- crt_net['model.8.bias'] = pretrained_net['model.2.bias']
55
- crt_net['model.11.weight'] = pretrained_net['model.5.weight']
56
- crt_net['model.11.bias'] = pretrained_net['model.5.bias']
57
- crt_net['model.13.weight'] = pretrained_net['model.7.weight']
58
- crt_net['model.13.bias'] = pretrained_net['model.7.bias']
59
- torch.save(crt_net, '../pretrained_tmp.pth')
60
- '''
61
-
62
- # x3/4/8 RGB -> Y
63
-
64
- def rgb2gray_net(net, only_input=True):
65
-
66
- if only_input:
67
- in_filter = net['0.weight']
68
- in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114
69
- in_new_filter.unsqueeze_(1)
70
- net['0.weight'] = in_new_filter
71
-
72
- # out_filter = pretrained_net['model.13.weight']
73
- # out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
74
- # out_filter[2, :, :, :] * 0.114
75
- # out_new_filter.unsqueeze_(0)
76
- # crt_net['model.13.weight'] = out_new_filter
77
- # out_bias = pretrained_net['model.13.bias']
78
- # out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
79
- # out_new_bias = torch.Tensor(1).fill_(out_new_bias)
80
- # crt_net['model.13.bias'] = out_new_bias
81
-
82
- # torch.save(crt_net, '../pretrained_tmp.pth')
83
-
84
- return net
85
-
86
-
87
-
88
- if __name__ == '__main__':
89
-
90
- net = torchvision.models.vgg19(pretrained=True)
91
- for k,v in net.features.named_parameters():
92
- if k=='0.weight':
93
- in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114
94
- in_new_filter.unsqueeze_(1)
95
- v = in_new_filter
96
- print(v.shape)
97
- print(v[0,0,0,0])
98
- if k=='0.bias':
99
- in_new_bias = v
100
- print(v[0])
101
-
102
- print(net.features[0])
103
-
104
- net.features[0] = B.conv(1, 64, mode='C')
105
-
106
- print(net.features[0])
107
- net.features[0].weight.data=in_new_filter
108
- net.features[0].bias.data=in_new_bias
109
-
110
- for k,v in net.features.named_parameters():
111
- if k=='0.weight':
112
- print(v[0,0,0,0])
113
- if k=='0.bias':
114
- print(v[0])
115
-
116
- # transfer parameters of old model to new one
117
- model_old = torch.load(model_path)
118
- state_dict = model.state_dict()
119
- for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()):
120
- state_dict[key2] = param
121
- print([key, key2])
122
- # print([param.size(), param2.size()])
123
- torch.save(state_dict, 'model_new.pth')
124
-
125
-
126
- # rgb2gray_net(net)
127
-
128
-
129
-
130
-
131
-
132
-
133
-
134
-
135
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_receptivefield.py DELETED
@@ -1,62 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # online calculation: https://fomoro.com/research/article/receptive-field-calculator#
4
-
5
- # [filter size, stride, padding]
6
- #Assume the two dimensions are the same
7
- #Each kernel requires the following parameters:
8
- # - k_i: kernel size
9
- # - s_i: stride
10
- # - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow)
11
- #
12
- #Each layer i requires the following parameters to be fully represented:
13
- # - n_i: number of feature (data layer has n_1 = imagesize )
14
- # - j_i: distance (projected to image pixel distance) between center of two adjacent features
15
- # - r_i: receptive field of a feature in layer i
16
- # - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding)
17
-
18
- import math
19
-
20
- def outFromIn(conv, layerIn):
21
- n_in = layerIn[0]
22
- j_in = layerIn[1]
23
- r_in = layerIn[2]
24
- start_in = layerIn[3]
25
- k = conv[0]
26
- s = conv[1]
27
- p = conv[2]
28
-
29
- n_out = math.floor((n_in - k + 2*p)/s) + 1
30
- actualP = (n_out-1)*s - n_in + k
31
- pR = math.ceil(actualP/2)
32
- pL = math.floor(actualP/2)
33
-
34
- j_out = j_in * s
35
- r_out = r_in + (k - 1)*j_in
36
- start_out = start_in + ((k-1)/2 - pL)*j_in
37
- return n_out, j_out, r_out, start_out
38
-
39
- def printLayer(layer, layer_name):
40
- print(layer_name + ":")
41
- print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3]))
42
-
43
-
44
-
45
- layerInfos = []
46
- if __name__ == '__main__':
47
-
48
- convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]]
49
- layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12']
50
- imsize = 128
51
-
52
- print ("-------Net summary------")
53
- currentLayer = [imsize, 1, 1, 0.5]
54
- printLayer(currentLayer, "input image")
55
- for i in range(len(convnet)):
56
- currentLayer = outFromIn(convnet[i], currentLayer)
57
- layerInfos.append(currentLayer)
58
- printLayer(currentLayer, layer_names[i])
59
-
60
-
61
- # run utils/utils_receptivefield.py
62
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_regularizers.py DELETED
@@ -1,104 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- '''
6
- # --------------------------------------------
7
- # Kai Zhang (github: https://github.com/cszn)
8
- # 03/Mar/2019
9
- # --------------------------------------------
10
- '''
11
-
12
-
13
- # --------------------------------------------
14
- # SVD Orthogonal Regularization
15
- # --------------------------------------------
16
- def regularizer_orth(m):
17
- """
18
- # ----------------------------------------
19
- # SVD Orthogonal Regularization
20
- # ----------------------------------------
21
- # Applies regularization to the training by performing the
22
- # orthogonalization technique described in the paper
23
- # This function is to be called by the torch.nn.Module.apply() method,
24
- # which applies svd_orthogonalization() to every layer of the model.
25
- # usage: net.apply(regularizer_orth)
26
- # ----------------------------------------
27
- """
28
- classname = m.__class__.__name__
29
- if classname.find('Conv') != -1:
30
- w = m.weight.data.clone()
31
- c_out, c_in, f1, f2 = w.size()
32
- # dtype = m.weight.data.type()
33
- w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
34
- # self.netG.apply(svd_orthogonalization)
35
- u, s, v = torch.svd(w)
36
- s[s > 1.5] = s[s > 1.5] - 1e-4
37
- s[s < 0.5] = s[s < 0.5] + 1e-4
38
- w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
39
- m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
40
- else:
41
- pass
42
-
43
-
44
- # --------------------------------------------
45
- # SVD Orthogonal Regularization
46
- # --------------------------------------------
47
- def regularizer_orth2(m):
48
- """
49
- # ----------------------------------------
50
- # Applies regularization to the training by performing the
51
- # orthogonalization technique described in the paper
52
- # This function is to be called by the torch.nn.Module.apply() method,
53
- # which applies svd_orthogonalization() to every layer of the model.
54
- # usage: net.apply(regularizer_orth2)
55
- # ----------------------------------------
56
- """
57
- classname = m.__class__.__name__
58
- if classname.find('Conv') != -1:
59
- w = m.weight.data.clone()
60
- c_out, c_in, f1, f2 = w.size()
61
- # dtype = m.weight.data.type()
62
- w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
63
- u, s, v = torch.svd(w)
64
- s_mean = s.mean()
65
- s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
66
- s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
67
- w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
68
- m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
69
- else:
70
- pass
71
-
72
-
73
-
74
- def regularizer_clip(m):
75
- """
76
- # ----------------------------------------
77
- # usage: net.apply(regularizer_clip)
78
- # ----------------------------------------
79
- """
80
- eps = 1e-4
81
- c_min = -1.5
82
- c_max = 1.5
83
-
84
- classname = m.__class__.__name__
85
- if classname.find('Conv') != -1 or classname.find('Linear') != -1:
86
- w = m.weight.data.clone()
87
- w[w > c_max] -= eps
88
- w[w < c_min] += eps
89
- m.weight.data = w
90
-
91
- if m.bias is not None:
92
- b = m.bias.data.clone()
93
- b[b > c_max] -= eps
94
- b[b < c_min] += eps
95
- m.bias.data = b
96
-
97
- # elif classname.find('BatchNorm2d') != -1:
98
- #
99
- # rv = m.running_var.data.clone()
100
- # rm = m.running_mean.data.clone()
101
- #
102
- # if m.affine:
103
- # m.weight.data
104
- # m.bias.data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_sisr.py DELETED
@@ -1,848 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- from utils import utils_image as util
3
- import random
4
-
5
- import scipy
6
- import scipy.stats as ss
7
- import scipy.io as io
8
- from scipy import ndimage
9
- from scipy.interpolate import interp2d
10
-
11
- import numpy as np
12
- import torch
13
-
14
-
15
- """
16
- # --------------------------------------------
17
- # Super-Resolution
18
- # --------------------------------------------
19
- #
20
- # Kai Zhang (cskaizhang@gmail.com)
21
- # https://github.com/cszn
22
- # modified by Kai Zhang (github: https://github.com/cszn)
23
- # 03/03/2020
24
- # --------------------------------------------
25
- """
26
-
27
-
28
- """
29
- # --------------------------------------------
30
- # anisotropic Gaussian kernels
31
- # --------------------------------------------
32
- """
33
-
34
-
35
- def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
36
- """ generate an anisotropic Gaussian kernel
37
- Args:
38
- ksize : e.g., 15, kernel size
39
- theta : [0, pi], rotation angle range
40
- l1 : [0.1,50], scaling of eigenvalues
41
- l2 : [0.1,l1], scaling of eigenvalues
42
- If l1 = l2, will get an isotropic Gaussian kernel.
43
- Returns:
44
- k : kernel
45
- """
46
-
47
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
48
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
49
- D = np.array([[l1, 0], [0, l2]])
50
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
51
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
52
-
53
- return k
54
-
55
-
56
- def gm_blur_kernel(mean, cov, size=15):
57
- center = size / 2.0 + 0.5
58
- k = np.zeros([size, size])
59
- for y in range(size):
60
- for x in range(size):
61
- cy = y - center + 1
62
- cx = x - center + 1
63
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
64
-
65
- k = k / np.sum(k)
66
- return k
67
-
68
-
69
- """
70
- # --------------------------------------------
71
- # calculate PCA projection matrix
72
- # --------------------------------------------
73
- """
74
-
75
-
76
- def get_pca_matrix(x, dim_pca=15):
77
- """
78
- Args:
79
- x: 225x10000 matrix
80
- dim_pca: 15
81
- Returns:
82
- pca_matrix: 15x225
83
- """
84
- C = np.dot(x, x.T)
85
- w, v = scipy.linalg.eigh(C)
86
- pca_matrix = v[:, -dim_pca:].T
87
-
88
- return pca_matrix
89
-
90
-
91
- def show_pca(x):
92
- """
93
- x: PCA projection matrix, e.g., 15x225
94
- """
95
- for i in range(x.shape[0]):
96
- xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F")
97
- util.surf(xc)
98
-
99
-
100
- def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
101
- kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
102
- for i in range(num_samples):
103
-
104
- theta = np.pi*np.random.rand(1)
105
- l1 = 0.1+l_max*np.random.rand(1)
106
- l2 = 0.1+(l1-0.1)*np.random.rand(1)
107
-
108
- k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0])
109
-
110
- # util.imshow(k)
111
-
112
- kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F')
113
-
114
- # io.savemat('k.mat', {'k': kernels})
115
-
116
- pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
117
-
118
- io.savemat(path, {'p': pca_matrix})
119
-
120
- return pca_matrix
121
-
122
-
123
- """
124
- # --------------------------------------------
125
- # shifted anisotropic Gaussian kernels
126
- # --------------------------------------------
127
- """
128
-
129
-
130
- def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
131
- """"
132
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
133
- # Kai Zhang
134
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
135
- # max_var = 2.5 * sf
136
- """
137
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
138
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
139
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
140
- theta = np.random.rand() * np.pi # random theta
141
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
142
-
143
- # Set COV matrix using Lambdas and Theta
144
- LAMBDA = np.diag([lambda_1, lambda_2])
145
- Q = np.array([[np.cos(theta), -np.sin(theta)],
146
- [np.sin(theta), np.cos(theta)]])
147
- SIGMA = Q @ LAMBDA @ Q.T
148
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
149
-
150
- # Set expectation position (shifting kernel for aligned image)
151
- MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
152
- MU = MU[None, None, :, None]
153
-
154
- # Create meshgrid for Gaussian
155
- [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
156
- Z = np.stack([X, Y], 2)[:, :, :, None]
157
-
158
- # Calcualte Gaussian for every pixel of the kernel
159
- ZZ = Z-MU
160
- ZZ_t = ZZ.transpose(0,1,3,2)
161
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
162
-
163
- # shift the kernel so it will be centered
164
- #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
165
-
166
- # Normalize the kernel and return
167
- #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
168
- kernel = raw_kernel / np.sum(raw_kernel)
169
- return kernel
170
-
171
-
172
- def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
173
- """"
174
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
175
- # Kai Zhang
176
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
177
- # max_var = 2.5 * sf
178
- """
179
- sf = random.choice([1, 2, 3, 4])
180
- scale_factor = np.array([sf, sf])
181
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
182
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
183
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
184
- theta = np.random.rand() * np.pi # random theta
185
- noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2
186
-
187
- # Set COV matrix using Lambdas and Theta
188
- LAMBDA = np.diag([lambda_1, lambda_2])
189
- Q = np.array([[np.cos(theta), -np.sin(theta)],
190
- [np.sin(theta), np.cos(theta)]])
191
- SIGMA = Q @ LAMBDA @ Q.T
192
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
193
-
194
- # Set expectation position (shifting kernel for aligned image)
195
- MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
196
- MU = MU[None, None, :, None]
197
-
198
- # Create meshgrid for Gaussian
199
- [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
200
- Z = np.stack([X, Y], 2)[:, :, :, None]
201
-
202
- # Calcualte Gaussian for every pixel of the kernel
203
- ZZ = Z-MU
204
- ZZ_t = ZZ.transpose(0,1,3,2)
205
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
206
-
207
- # shift the kernel so it will be centered
208
- #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
209
-
210
- # Normalize the kernel and return
211
- #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
212
- kernel = raw_kernel / np.sum(raw_kernel)
213
- return kernel
214
-
215
-
216
- """
217
- # --------------------------------------------
218
- # degradation models
219
- # --------------------------------------------
220
- """
221
-
222
-
223
- def bicubic_degradation(x, sf=3):
224
- '''
225
- Args:
226
- x: HxWxC image, [0, 1]
227
- sf: down-scale factor
228
- Return:
229
- bicubicly downsampled LR image
230
- '''
231
- x = util.imresize_np(x, scale=1/sf)
232
- return x
233
-
234
-
235
- def srmd_degradation(x, k, sf=3):
236
- ''' blur + bicubic downsampling
237
- Args:
238
- x: HxWxC image, [0, 1]
239
- k: hxw, double
240
- sf: down-scale factor
241
- Return:
242
- downsampled LR image
243
- Reference:
244
- @inproceedings{zhang2018learning,
245
- title={Learning a single convolutional super-resolution network for multiple degradations},
246
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
247
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
248
- pages={3262--3271},
249
- year={2018}
250
- }
251
- '''
252
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
253
- x = bicubic_degradation(x, sf=sf)
254
- return x
255
-
256
-
257
- def dpsr_degradation(x, k, sf=3):
258
-
259
- ''' bicubic downsampling + blur
260
- Args:
261
- x: HxWxC image, [0, 1]
262
- k: hxw, double
263
- sf: down-scale factor
264
- Return:
265
- downsampled LR image
266
- Reference:
267
- @inproceedings{zhang2019deep,
268
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
269
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
270
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
271
- pages={1671--1681},
272
- year={2019}
273
- }
274
- '''
275
- x = bicubic_degradation(x, sf=sf)
276
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
277
- return x
278
-
279
-
280
- def classical_degradation(x, k, sf=3):
281
- ''' blur + downsampling
282
-
283
- Args:
284
- x: HxWxC image, [0, 1]/[0, 255]
285
- k: hxw, double
286
- sf: down-scale factor
287
-
288
- Return:
289
- downsampled LR image
290
- '''
291
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
292
- #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
293
- st = 0
294
- return x[st::sf, st::sf, ...]
295
-
296
-
297
- def modcrop_np(img, sf):
298
- '''
299
- Args:
300
- img: numpy image, WxH or WxHxC
301
- sf: scale factor
302
- Return:
303
- cropped image
304
- '''
305
- w, h = img.shape[:2]
306
- im = np.copy(img)
307
- return im[:w - w % sf, :h - h % sf, ...]
308
-
309
-
310
- '''
311
- # =================
312
- # Numpy
313
- # =================
314
- '''
315
-
316
-
317
- def shift_pixel(x, sf, upper_left=True):
318
- """shift pixel for super-resolution with different scale factors
319
- Args:
320
- x: WxHxC or WxH, image or kernel
321
- sf: scale factor
322
- upper_left: shift direction
323
- """
324
- h, w = x.shape[:2]
325
- shift = (sf-1)*0.5
326
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
327
- if upper_left:
328
- x1 = xv + shift
329
- y1 = yv + shift
330
- else:
331
- x1 = xv - shift
332
- y1 = yv - shift
333
-
334
- x1 = np.clip(x1, 0, w-1)
335
- y1 = np.clip(y1, 0, h-1)
336
-
337
- if x.ndim == 2:
338
- x = interp2d(xv, yv, x)(x1, y1)
339
- if x.ndim == 3:
340
- for i in range(x.shape[-1]):
341
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
342
-
343
- return x
344
-
345
-
346
- '''
347
- # =================
348
- # pytorch
349
- # =================
350
- '''
351
-
352
-
353
- def splits(a, sf):
354
- '''
355
- a: tensor NxCxWxHx2
356
- sf: scale factor
357
- out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
358
- '''
359
- b = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
360
- b = torch.cat(torch.chunk(b, sf, dim=3), dim=5)
361
- return b
362
-
363
-
364
- def c2c(x):
365
- return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
366
-
367
-
368
- def r2c(x):
369
- return torch.stack([x, torch.zeros_like(x)], -1)
370
-
371
-
372
- def cdiv(x, y):
373
- a, b = x[..., 0], x[..., 1]
374
- c, d = y[..., 0], y[..., 1]
375
- cd2 = c**2 + d**2
376
- return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
377
-
378
-
379
- def csum(x, y):
380
- return torch.stack([x[..., 0] + y, x[..., 1]], -1)
381
-
382
-
383
- def cabs(x):
384
- return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
385
-
386
-
387
- def cmul(t1, t2):
388
- '''
389
- complex multiplication
390
- t1: NxCxHxWx2
391
- output: NxCxHxWx2
392
- '''
393
- real1, imag1 = t1[..., 0], t1[..., 1]
394
- real2, imag2 = t2[..., 0], t2[..., 1]
395
- return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
396
-
397
-
398
- def cconj(t, inplace=False):
399
- '''
400
- # complex's conjugation
401
- t: NxCxHxWx2
402
- output: NxCxHxWx2
403
- '''
404
- c = t.clone() if not inplace else t
405
- c[..., 1] *= -1
406
- return c
407
-
408
-
409
- def rfft(t):
410
- return torch.rfft(t, 2, onesided=False)
411
-
412
-
413
- def irfft(t):
414
- return torch.irfft(t, 2, onesided=False)
415
-
416
-
417
- def fft(t):
418
- return torch.fft(t, 2)
419
-
420
-
421
- def ifft(t):
422
- return torch.ifft(t, 2)
423
-
424
-
425
- def p2o(psf, shape):
426
- '''
427
- Args:
428
- psf: NxCxhxw
429
- shape: [H,W]
430
-
431
- Returns:
432
- otf: NxCxHxWx2
433
- '''
434
- otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
435
- otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
436
- for axis, axis_size in enumerate(psf.shape[2:]):
437
- otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
438
- otf = torch.rfft(otf, 2, onesided=False)
439
- n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
440
- otf[...,1][torch.abs(otf[...,1])<n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
441
- return otf
442
-
443
-
444
- '''
445
- # =================
446
- PyTorch
447
- # =================
448
- '''
449
-
450
- def INVLS_pytorch(FB, FBC, F2B, FR, tau, sf=2):
451
- '''
452
- FB: NxCxWxHx2
453
- F2B: NxCxWxHx2
454
-
455
- x1 = FB.*FR;
456
- FBR = BlockMM(nr,nc,Nb,m,x1);
457
- invW = BlockMM(nr,nc,Nb,m,F2B);
458
- invWBR = FBR./(invW + tau*Nb);
459
- fun = @(block_struct) block_struct.data.*invWBR;
460
- FCBinvWBR = blockproc(FBC,[nr,nc],fun);
461
- FX = (FR-FCBinvWBR)/tau;
462
- Xest = real(ifft2(FX));
463
- '''
464
- x1 = cmul(FB, FR)
465
- FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
466
- invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
467
- invWBR = cdiv(FBR, csum(invW, tau))
468
- FCBinvWBR = cmul(FBC, invWBR.repeat(1,1,sf,sf,1))
469
- FX = (FR-FCBinvWBR)/tau
470
- Xest = torch.irfft(FX, 2, onesided=False)
471
- return Xest
472
-
473
-
474
- def real2complex(x):
475
- return torch.stack([x, torch.zeros_like(x)], -1)
476
-
477
-
478
- def modcrop(img, sf):
479
- '''
480
- img: tensor image, NxCxWxH or CxWxH or WxH
481
- sf: scale factor
482
- '''
483
- w, h = img.shape[-2:]
484
- im = img.clone()
485
- return im[..., :w - w % sf, :h - h % sf]
486
-
487
-
488
- def upsample(x, sf=3, center=False):
489
- '''
490
- x: tensor image, NxCxWxH
491
- '''
492
- st = (sf-1)//2 if center else 0
493
- z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
494
- z[..., st::sf, st::sf].copy_(x)
495
- return z
496
-
497
-
498
- def downsample(x, sf=3, center=False):
499
- st = (sf-1)//2 if center else 0
500
- return x[..., st::sf, st::sf]
501
-
502
-
503
- def circular_pad(x, pad):
504
- '''
505
- # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (pariodic padding)
506
- '''
507
- x = torch.cat([x, x[:, :, 0:pad, :]], dim=2)
508
- x = torch.cat([x, x[:, :, :, 0:pad]], dim=3)
509
- x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2)
510
- x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3)
511
- return x
512
-
513
-
514
- def pad_circular(input, padding):
515
- # type: (Tensor, List[int]) -> Tensor
516
- """
517
- Arguments
518
- :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
519
- :param padding: (tuple): m-elem tuple where m is the degree of convolution
520
- Returns
521
- :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
522
- H + 2 * padding[1]], W + 2 * padding[2]))`
523
- """
524
- offset = 3
525
- for dimension in range(input.dim() - offset + 1):
526
- input = dim_pad_circular(input, padding[dimension], dimension + offset)
527
- return input
528
-
529
-
530
- def dim_pad_circular(input, padding, dimension):
531
- # type: (Tensor, int, int) -> Tensor
532
- input = torch.cat([input, input[[slice(None)] * (dimension - 1) +
533
- [slice(0, padding)]]], dim=dimension - 1)
534
- input = torch.cat([input[[slice(None)] * (dimension - 1) +
535
- [slice(-2 * padding, -padding)]], input], dim=dimension - 1)
536
- return input
537
-
538
-
539
- def imfilter(x, k):
540
- '''
541
- x: image, NxcxHxW
542
- k: kernel, cx1xhxw
543
- '''
544
- x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2))
545
- x = torch.nn.functional.conv2d(x, k, groups=x.shape[1])
546
- return x
547
-
548
-
549
- def G(x, k, sf=3, center=False):
550
- '''
551
- x: image, NxcxHxW
552
- k: kernel, cx1xhxw
553
- sf: scale factor
554
- center: the first one or the moddle one
555
-
556
- Matlab function:
557
- tmp = imfilter(x,h,'circular');
558
- y = downsample2(tmp,K);
559
- '''
560
- x = downsample(imfilter(x, k), sf=sf, center=center)
561
- return x
562
-
563
-
564
- def Gt(x, k, sf=3, center=False):
565
- '''
566
- x: image, NxcxHxW
567
- k: kernel, cx1xhxw
568
- sf: scale factor
569
- center: the first one or the moddle one
570
-
571
- Matlab function:
572
- tmp = upsample2(x,K);
573
- y = imfilter(tmp,h,'circular');
574
- '''
575
- x = imfilter(upsample(x, sf=sf, center=center), k)
576
- return x
577
-
578
-
579
- def interpolation_down(x, sf, center=False):
580
- mask = torch.zeros_like(x)
581
- if center:
582
- start = torch.tensor((sf-1)//2)
583
- mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
584
- LR = x[..., start::sf, start::sf]
585
- else:
586
- mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x)
587
- LR = x[..., ::sf, ::sf]
588
- y = x.mul(mask)
589
-
590
- return LR, y, mask
591
-
592
-
593
- '''
594
- # =================
595
- Numpy
596
- # =================
597
- '''
598
-
599
-
600
- def blockproc(im, blocksize, fun):
601
- xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0)
602
- xblocks_proc = []
603
- for xb in xblocks:
604
- yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1)
605
- yblocks_proc = []
606
- for yb in yblocks:
607
- yb_proc = fun(yb)
608
- yblocks_proc.append(yb_proc)
609
- xblocks_proc.append(np.concatenate(yblocks_proc, axis=1))
610
-
611
- proc = np.concatenate(xblocks_proc, axis=0)
612
-
613
- return proc
614
-
615
-
616
- def fun_reshape(a):
617
- return np.reshape(a, (-1,1,a.shape[-1]), order='F')
618
-
619
-
620
- def fun_mul(a, b):
621
- return a*b
622
-
623
-
624
- def BlockMM(nr, nc, Nb, m, x1):
625
- '''
626
- myfun = @(block_struct) reshape(block_struct.data,m,1);
627
- x1 = blockproc(x1,[nr nc],myfun);
628
- x1 = reshape(x1,m,Nb);
629
- x1 = sum(x1,2);
630
- x = reshape(x1,nr,nc);
631
- '''
632
- fun = fun_reshape
633
- x1 = blockproc(x1, blocksize=(nr, nc), fun=fun)
634
- x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F')
635
- x1 = np.sum(x1, 1)
636
- x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F')
637
- return x
638
-
639
-
640
- def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
641
- '''
642
- x1 = FB.*FR;
643
- FBR = BlockMM(nr,nc,Nb,m,x1);
644
- invW = BlockMM(nr,nc,Nb,m,F2B);
645
- invWBR = FBR./(invW + tau*Nb);
646
- fun = @(block_struct) block_struct.data.*invWBR;
647
- FCBinvWBR = blockproc(FBC,[nr,nc],fun);
648
- FX = (FR-FCBinvWBR)/tau;
649
- Xest = real(ifft2(FX));
650
- '''
651
- x1 = FB*FR
652
- FBR = BlockMM(nr, nc, Nb, m, x1)
653
- invW = BlockMM(nr, nc, Nb, m, F2B)
654
- invWBR = FBR/(invW + tau*Nb)
655
- FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR))
656
- FX = (FR-FCBinvWBR)/tau
657
- Xest = np.real(np.fft.ifft2(FX, axes=(0, 1)))
658
- return Xest
659
-
660
-
661
- def psf2otf(psf, shape=None):
662
- """
663
- Convert point-spread function to optical transfer function.
664
- Compute the Fast Fourier Transform (FFT) of the point-spread
665
- function (PSF) array and creates the optical transfer function (OTF)
666
- array that is not influenced by the PSF off-centering.
667
- By default, the OTF array is the same size as the PSF array.
668
- To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
669
- post-pads the PSF array (down or to the right) with zeros to match
670
- dimensions specified in OUTSIZE, then circularly shifts the values of
671
- the PSF array up (or to the left) until the central pixel reaches (1,1)
672
- position.
673
- Parameters
674
- ----------
675
- psf : `numpy.ndarray`
676
- PSF array
677
- shape : int
678
- Output shape of the OTF array
679
- Returns
680
- -------
681
- otf : `numpy.ndarray`
682
- OTF array
683
- Notes
684
- -----
685
- Adapted from MATLAB psf2otf function
686
- """
687
- if type(shape) == type(None):
688
- shape = psf.shape
689
- shape = np.array(shape)
690
- if np.all(psf == 0):
691
- # return np.zeros_like(psf)
692
- return np.zeros(shape)
693
- if len(psf.shape) == 1:
694
- psf = psf.reshape((1, psf.shape[0]))
695
- inshape = psf.shape
696
- psf = zero_pad(psf, shape, position='corner')
697
- for axis, axis_size in enumerate(inshape):
698
- psf = np.roll(psf, -int(axis_size / 2), axis=axis)
699
- # Compute the OTF
700
- otf = np.fft.fft2(psf, axes=(0, 1))
701
- # Estimate the rough number of operations involved in the FFT
702
- # and discard the PSF imaginary part if within roundoff error
703
- # roundoff error = machine epsilon = sys.float_info.epsilon
704
- # or np.finfo().eps
705
- n_ops = np.sum(psf.size * np.log2(psf.shape))
706
- otf = np.real_if_close(otf, tol=n_ops)
707
- return otf
708
-
709
-
710
- def zero_pad(image, shape, position='corner'):
711
- """
712
- Extends image to a certain size with zeros
713
- Parameters
714
- ----------
715
- image: real 2d `numpy.ndarray`
716
- Input image
717
- shape: tuple of int
718
- Desired output shape of the image
719
- position : str, optional
720
- The position of the input image in the output one:
721
- * 'corner'
722
- top-left corner (default)
723
- * 'center'
724
- centered
725
- Returns
726
- -------
727
- padded_img: real `numpy.ndarray`
728
- The zero-padded image
729
- """
730
- shape = np.asarray(shape, dtype=int)
731
- imshape = np.asarray(image.shape, dtype=int)
732
- if np.alltrue(imshape == shape):
733
- return image
734
- if np.any(shape <= 0):
735
- raise ValueError("ZERO_PAD: null or negative shape given")
736
- dshape = shape - imshape
737
- if np.any(dshape < 0):
738
- raise ValueError("ZERO_PAD: target size smaller than source one")
739
- pad_img = np.zeros(shape, dtype=image.dtype)
740
- idx, idy = np.indices(imshape)
741
- if position == 'center':
742
- if np.any(dshape % 2 != 0):
743
- raise ValueError("ZERO_PAD: source and target shapes "
744
- "have different parity.")
745
- offx, offy = dshape // 2
746
- else:
747
- offx, offy = (0, 0)
748
- pad_img[idx + offx, idy + offy] = image
749
- return pad_img
750
-
751
-
752
- def upsample_np(x, sf=3, center=False):
753
- st = (sf-1)//2 if center else 0
754
- z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2]))
755
- z[st::sf, st::sf, ...] = x
756
- return z
757
-
758
-
759
- def downsample_np(x, sf=3, center=False):
760
- st = (sf-1)//2 if center else 0
761
- return x[st::sf, st::sf, ...]
762
-
763
-
764
- def imfilter_np(x, k):
765
- '''
766
- x: image, NxcxHxW
767
- k: kernel, cx1xhxw
768
- '''
769
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
770
- return x
771
-
772
-
773
- def G_np(x, k, sf=3, center=False):
774
- '''
775
- x: image, NxcxHxW
776
- k: kernel, cx1xhxw
777
-
778
- Matlab function:
779
- tmp = imfilter(x,h,'circular');
780
- y = downsample2(tmp,K);
781
- '''
782
- x = downsample_np(imfilter_np(x, k), sf=sf, center=center)
783
- return x
784
-
785
-
786
- def Gt_np(x, k, sf=3, center=False):
787
- '''
788
- x: image, NxcxHxW
789
- k: kernel, cx1xhxw
790
-
791
- Matlab function:
792
- tmp = upsample2(x,K);
793
- y = imfilter(tmp,h,'circular');
794
- '''
795
- x = imfilter_np(upsample_np(x, sf=sf, center=center), k)
796
- return x
797
-
798
-
799
- if __name__ == '__main__':
800
- img = util.imread_uint('test.bmp', 3)
801
-
802
- img = util.uint2single(img)
803
- k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
804
- util.imshow(k*10)
805
-
806
-
807
- for sf in [2, 3, 4]:
808
-
809
- # modcrop
810
- img = modcrop_np(img, sf=sf)
811
-
812
- # 1) bicubic degradation
813
- img_b = bicubic_degradation(img, sf=sf)
814
- print(img_b.shape)
815
-
816
- # 2) srmd degradation
817
- img_s = srmd_degradation(img, k, sf=sf)
818
- print(img_s.shape)
819
-
820
- # 3) dpsr degradation
821
- img_d = dpsr_degradation(img, k, sf=sf)
822
- print(img_d.shape)
823
-
824
- # 4) classical degradation
825
- img_d = classical_degradation(img, k, sf=sf)
826
- print(img_d.shape)
827
-
828
- k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
829
- #print(k)
830
- # util.imshow(k*10)
831
-
832
- k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
833
- # util.imshow(k*10)
834
-
835
-
836
- # PCA
837
- # pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
838
- # print(pca_matrix.shape)
839
- # show_pca(pca_matrix)
840
- # run utils/utils_sisr.py
841
- # run utils_sisr.py
842
-
843
-
844
-
845
-
846
-
847
-
848
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_video.py DELETED
@@ -1,493 +0,0 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import torch
5
- import random
6
- from os import path as osp
7
- from torch.nn import functional as F
8
- from abc import ABCMeta, abstractmethod
9
-
10
-
11
- def scandir(dir_path, suffix=None, recursive=False, full_path=False):
12
- """Scan a directory to find the interested files.
13
-
14
- Args:
15
- dir_path (str): Path of the directory.
16
- suffix (str | tuple(str), optional): File suffix that we are
17
- interested in. Default: None.
18
- recursive (bool, optional): If set to True, recursively scan the
19
- directory. Default: False.
20
- full_path (bool, optional): If set to True, include the dir_path.
21
- Default: False.
22
-
23
- Returns:
24
- A generator for all the interested files with relative paths.
25
- """
26
-
27
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
28
- raise TypeError('"suffix" must be a string or tuple of strings')
29
-
30
- root = dir_path
31
-
32
- def _scandir(dir_path, suffix, recursive):
33
- for entry in os.scandir(dir_path):
34
- if not entry.name.startswith('.') and entry.is_file():
35
- if full_path:
36
- return_path = entry.path
37
- else:
38
- return_path = osp.relpath(entry.path, root)
39
-
40
- if suffix is None:
41
- yield return_path
42
- elif return_path.endswith(suffix):
43
- yield return_path
44
- else:
45
- if recursive:
46
- yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
47
- else:
48
- continue
49
-
50
- return _scandir(dir_path, suffix=suffix, recursive=recursive)
51
-
52
-
53
- def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
54
- """Read a sequence of images from a given folder path.
55
-
56
- Args:
57
- path (list[str] | str): List of image paths or image folder path.
58
- require_mod_crop (bool): Require mod crop for each image.
59
- Default: False.
60
- scale (int): Scale factor for mod_crop. Default: 1.
61
- return_imgname(bool): Whether return image names. Default False.
62
-
63
- Returns:
64
- Tensor: size (t, c, h, w), RGB, [0, 1].
65
- list[str]: Returned image name list.
66
- """
67
- if isinstance(path, list):
68
- img_paths = path
69
- else:
70
- img_paths = sorted(list(scandir(path, full_path=True)))
71
- imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
72
-
73
- if require_mod_crop:
74
- imgs = [mod_crop(img, scale) for img in imgs]
75
- imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
76
- imgs = torch.stack(imgs, dim=0)
77
-
78
- if return_imgname:
79
- imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
80
- return imgs, imgnames
81
- else:
82
- return imgs
83
-
84
-
85
- def img2tensor(imgs, bgr2rgb=True, float32=True):
86
- """Numpy array to tensor.
87
-
88
- Args:
89
- imgs (list[ndarray] | ndarray): Input images.
90
- bgr2rgb (bool): Whether to change bgr to rgb.
91
- float32 (bool): Whether to change to float32.
92
-
93
- Returns:
94
- list[tensor] | tensor: Tensor images. If returned results only have
95
- one element, just return tensor.
96
- """
97
-
98
- def _totensor(img, bgr2rgb, float32):
99
- if img.shape[2] == 3 and bgr2rgb:
100
- if img.dtype == 'float64':
101
- img = img.astype('float32')
102
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103
- img = torch.from_numpy(img.transpose(2, 0, 1))
104
- if float32:
105
- img = img.float()
106
- return img
107
-
108
- if isinstance(imgs, list):
109
- return [_totensor(img, bgr2rgb, float32) for img in imgs]
110
- else:
111
- return _totensor(imgs, bgr2rgb, float32)
112
-
113
-
114
- def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
115
- """Convert torch Tensors into image numpy arrays.
116
-
117
- After clamping to [min, max], values will be normalized to [0, 1].
118
-
119
- Args:
120
- tensor (Tensor or list[Tensor]): Accept shapes:
121
- 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
122
- 2) 3D Tensor of shape (3/1 x H x W);
123
- 3) 2D Tensor of shape (H x W).
124
- Tensor channel should be in RGB order.
125
- rgb2bgr (bool): Whether to change rgb to bgr.
126
- out_type (numpy type): output types. If ``np.uint8``, transform outputs
127
- to uint8 type with range [0, 255]; otherwise, float type with
128
- range [0, 1]. Default: ``np.uint8``.
129
- min_max (tuple[int]): min and max values for clamp.
130
-
131
- Returns:
132
- (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
133
- shape (H x W). The channel order is BGR.
134
- """
135
- if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
136
- raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
137
-
138
- if torch.is_tensor(tensor):
139
- tensor = [tensor]
140
- result = []
141
- for _tensor in tensor:
142
- _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
143
- _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
144
-
145
- n_dim = _tensor.dim()
146
- if n_dim == 4:
147
- img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
148
- img_np = img_np.transpose(1, 2, 0)
149
- if rgb2bgr:
150
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
151
- elif n_dim == 3:
152
- img_np = _tensor.numpy()
153
- img_np = img_np.transpose(1, 2, 0)
154
- if img_np.shape[2] == 1: # gray image
155
- img_np = np.squeeze(img_np, axis=2)
156
- else:
157
- if rgb2bgr:
158
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
159
- elif n_dim == 2:
160
- img_np = _tensor.numpy()
161
- else:
162
- raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
163
- if out_type == np.uint8:
164
- # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
165
- img_np = (img_np * 255.0).round()
166
- img_np = img_np.astype(out_type)
167
- result.append(img_np)
168
- if len(result) == 1:
169
- result = result[0]
170
- return result
171
-
172
-
173
- def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
174
- """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
175
-
176
- We use vertical flip and transpose for rotation implementation.
177
- All the images in the list use the same augmentation.
178
-
179
- Args:
180
- imgs (list[ndarray] | ndarray): Images to be augmented. If the input
181
- is an ndarray, it will be transformed to a list.
182
- hflip (bool): Horizontal flip. Default: True.
183
- rotation (bool): Ratotation. Default: True.
184
- flows (list[ndarray]: Flows to be augmented. If the input is an
185
- ndarray, it will be transformed to a list.
186
- Dimension is (h, w, 2). Default: None.
187
- return_status (bool): Return the status of flip and rotation.
188
- Default: False.
189
-
190
- Returns:
191
- list[ndarray] | ndarray: Augmented images and flows. If returned
192
- results only have one element, just return ndarray.
193
-
194
- """
195
- hflip = hflip and random.random() < 0.5
196
- vflip = rotation and random.random() < 0.5
197
- rot90 = rotation and random.random() < 0.5
198
-
199
- def _augment(img):
200
- if hflip: # horizontal
201
- cv2.flip(img, 1, img)
202
- if vflip: # vertical
203
- cv2.flip(img, 0, img)
204
- if rot90:
205
- img = img.transpose(1, 0, 2)
206
- return img
207
-
208
- def _augment_flow(flow):
209
- if hflip: # horizontal
210
- cv2.flip(flow, 1, flow)
211
- flow[:, :, 0] *= -1
212
- if vflip: # vertical
213
- cv2.flip(flow, 0, flow)
214
- flow[:, :, 1] *= -1
215
- if rot90:
216
- flow = flow.transpose(1, 0, 2)
217
- flow = flow[:, :, [1, 0]]
218
- return flow
219
-
220
- if not isinstance(imgs, list):
221
- imgs = [imgs]
222
- imgs = [_augment(img) for img in imgs]
223
- if len(imgs) == 1:
224
- imgs = imgs[0]
225
-
226
- if flows is not None:
227
- if not isinstance(flows, list):
228
- flows = [flows]
229
- flows = [_augment_flow(flow) for flow in flows]
230
- if len(flows) == 1:
231
- flows = flows[0]
232
- return imgs, flows
233
- else:
234
- if return_status:
235
- return imgs, (hflip, vflip, rot90)
236
- else:
237
- return imgs
238
-
239
-
240
- def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
241
- """Paired random crop. Support Numpy array and Tensor inputs.
242
-
243
- It crops lists of lq and gt images with corresponding locations.
244
-
245
- Args:
246
- img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
247
- should have the same shape. If the input is an ndarray, it will
248
- be transformed to a list containing itself.
249
- img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
250
- should have the same shape. If the input is an ndarray, it will
251
- be transformed to a list containing itself.
252
- gt_patch_size (int): GT patch size.
253
- scale (int): Scale factor.
254
- gt_path (str): Path to ground-truth. Default: None.
255
-
256
- Returns:
257
- list[ndarray] | ndarray: GT images and LQ images. If returned results
258
- only have one element, just return ndarray.
259
- """
260
-
261
- if not isinstance(img_gts, list):
262
- img_gts = [img_gts]
263
- if not isinstance(img_lqs, list):
264
- img_lqs = [img_lqs]
265
-
266
- # determine input type: Numpy array or Tensor
267
- input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
268
-
269
- if input_type == 'Tensor':
270
- h_lq, w_lq = img_lqs[0].size()[-2:]
271
- h_gt, w_gt = img_gts[0].size()[-2:]
272
- else:
273
- h_lq, w_lq = img_lqs[0].shape[0:2]
274
- h_gt, w_gt = img_gts[0].shape[0:2]
275
- lq_patch_size = gt_patch_size // scale
276
-
277
- if h_gt != h_lq * scale or w_gt != w_lq * scale:
278
- raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
279
- f'multiplication of LQ ({h_lq}, {w_lq}).')
280
- if h_lq < lq_patch_size or w_lq < lq_patch_size:
281
- raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
282
- f'({lq_patch_size}, {lq_patch_size}). '
283
- f'Please remove {gt_path}.')
284
-
285
- # randomly choose top and left coordinates for lq patch
286
- top = random.randint(0, h_lq - lq_patch_size)
287
- left = random.randint(0, w_lq - lq_patch_size)
288
-
289
- # crop lq patch
290
- if input_type == 'Tensor':
291
- img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
292
- else:
293
- img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
294
-
295
- # crop corresponding gt patch
296
- top_gt, left_gt = int(top * scale), int(left * scale)
297
- if input_type == 'Tensor':
298
- img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
299
- else:
300
- img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
301
- if len(img_gts) == 1:
302
- img_gts = img_gts[0]
303
- if len(img_lqs) == 1:
304
- img_lqs = img_lqs[0]
305
- return img_gts, img_lqs
306
-
307
-
308
- # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
309
- class BaseStorageBackend(metaclass=ABCMeta):
310
- """Abstract class of storage backends.
311
-
312
- All backends need to implement two apis: ``get()`` and ``get_text()``.
313
- ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
314
- as texts.
315
- """
316
-
317
- @abstractmethod
318
- def get(self, filepath):
319
- pass
320
-
321
- @abstractmethod
322
- def get_text(self, filepath):
323
- pass
324
-
325
-
326
- class MemcachedBackend(BaseStorageBackend):
327
- """Memcached storage backend.
328
-
329
- Attributes:
330
- server_list_cfg (str): Config file for memcached server list.
331
- client_cfg (str): Config file for memcached client.
332
- sys_path (str | None): Additional path to be appended to `sys.path`.
333
- Default: None.
334
- """
335
-
336
- def __init__(self, server_list_cfg, client_cfg, sys_path=None):
337
- if sys_path is not None:
338
- import sys
339
- sys.path.append(sys_path)
340
- try:
341
- import mc
342
- except ImportError:
343
- raise ImportError('Please install memcached to enable MemcachedBackend.')
344
-
345
- self.server_list_cfg = server_list_cfg
346
- self.client_cfg = client_cfg
347
- self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
348
- # mc.pyvector servers as a point which points to a memory cache
349
- self._mc_buffer = mc.pyvector()
350
-
351
- def get(self, filepath):
352
- filepath = str(filepath)
353
- import mc
354
- self._client.Get(filepath, self._mc_buffer)
355
- value_buf = mc.ConvertBuffer(self._mc_buffer)
356
- return value_buf
357
-
358
- def get_text(self, filepath):
359
- raise NotImplementedError
360
-
361
-
362
- class HardDiskBackend(BaseStorageBackend):
363
- """Raw hard disks storage backend."""
364
-
365
- def get(self, filepath):
366
- filepath = str(filepath)
367
- with open(filepath, 'rb') as f:
368
- value_buf = f.read()
369
- return value_buf
370
-
371
- def get_text(self, filepath):
372
- filepath = str(filepath)
373
- with open(filepath, 'r') as f:
374
- value_buf = f.read()
375
- return value_buf
376
-
377
-
378
- class LmdbBackend(BaseStorageBackend):
379
- """Lmdb storage backend.
380
-
381
- Args:
382
- db_paths (str | list[str]): Lmdb database paths.
383
- client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
384
- readonly (bool, optional): Lmdb environment parameter. If True,
385
- disallow any write operations. Default: True.
386
- lock (bool, optional): Lmdb environment parameter. If False, when
387
- concurrent access occurs, do not lock the database. Default: False.
388
- readahead (bool, optional): Lmdb environment parameter. If False,
389
- disable the OS filesystem readahead mechanism, which may improve
390
- random read performance when a database is larger than RAM.
391
- Default: False.
392
-
393
- Attributes:
394
- db_paths (list): Lmdb database path.
395
- _client (list): A list of several lmdb envs.
396
- """
397
-
398
- def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
399
- try:
400
- import lmdb
401
- except ImportError:
402
- raise ImportError('Please install lmdb to enable LmdbBackend.')
403
-
404
- if isinstance(client_keys, str):
405
- client_keys = [client_keys]
406
-
407
- if isinstance(db_paths, list):
408
- self.db_paths = [str(v) for v in db_paths]
409
- elif isinstance(db_paths, str):
410
- self.db_paths = [str(db_paths)]
411
- assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
412
- f'but received {len(client_keys)} and {len(self.db_paths)}.')
413
-
414
- self._client = {}
415
- for client, path in zip(client_keys, self.db_paths):
416
- self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
417
-
418
- def get(self, filepath, client_key):
419
- """Get values according to the filepath from one lmdb named client_key.
420
-
421
- Args:
422
- filepath (str | obj:`Path`): Here, filepath is the lmdb key.
423
- client_key (str): Used for distinguishing different lmdb envs.
424
- """
425
- filepath = str(filepath)
426
- assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
427
- client = self._client[client_key]
428
- with client.begin(write=False) as txn:
429
- value_buf = txn.get(filepath.encode('ascii'))
430
- return value_buf
431
-
432
- def get_text(self, filepath):
433
- raise NotImplementedError
434
-
435
-
436
- class FileClient(object):
437
- """A general file client to access files in different backend.
438
-
439
- The client loads a file or text in a specified backend from its path
440
- and return it as a binary file. it can also register other backend
441
- accessor with a given name and backend class.
442
-
443
- Attributes:
444
- backend (str): The storage backend type. Options are "disk",
445
- "memcached" and "lmdb".
446
- client (:obj:`BaseStorageBackend`): The backend object.
447
- """
448
-
449
- _backends = {
450
- 'disk': HardDiskBackend,
451
- 'memcached': MemcachedBackend,
452
- 'lmdb': LmdbBackend,
453
- }
454
-
455
- def __init__(self, backend='disk', **kwargs):
456
- if backend not in self._backends:
457
- raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
458
- f' are {list(self._backends.keys())}')
459
- self.backend = backend
460
- self.client = self._backends[backend](**kwargs)
461
-
462
- def get(self, filepath, client_key='default'):
463
- # client_key is used only for lmdb, where different fileclients have
464
- # different lmdb environments.
465
- if self.backend == 'lmdb':
466
- return self.client.get(filepath, client_key)
467
- else:
468
- return self.client.get(filepath)
469
-
470
- def get_text(self, filepath):
471
- return self.client.get_text(filepath)
472
-
473
-
474
- def imfrombytes(content, flag='color', float32=False):
475
- """Read an image from bytes.
476
-
477
- Args:
478
- content (bytes): Image bytes got from files or other streams.
479
- flag (str): Flags specifying the color type of a loaded image,
480
- candidates are `color`, `grayscale` and `unchanged`.
481
- float32 (bool): Whether to change to float32., If True, will also norm
482
- to [0, 1]. Default: False.
483
-
484
- Returns:
485
- ndarray: Loaded image array.
486
- """
487
- img_np = np.frombuffer(content, np.uint8)
488
- imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
489
- img = cv2.imdecode(img_np, imread_flags[flag])
490
- if float32:
491
- img = img.astype(np.float32) / 255.
492
- return img
493
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/data/deg_kair_utils/utils_videoio.py DELETED
@@ -1,555 +0,0 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import torch
5
- import random
6
- from os import path as osp
7
- from torchvision.utils import make_grid
8
- import sys
9
- from pathlib import Path
10
- import six
11
- from collections import OrderedDict
12
- import math
13
- import glob
14
- import av
15
- import io
16
- from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
17
- CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
18
- CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
19
-
20
- if sys.version_info <= (3, 3):
21
- FileNotFoundError = IOError
22
- else:
23
- FileNotFoundError = FileNotFoundError
24
-
25
-
26
- def is_str(x):
27
- """Whether the input is an string instance."""
28
- return isinstance(x, six.string_types)
29
-
30
-
31
- def is_filepath(x):
32
- return is_str(x) or isinstance(x, Path)
33
-
34
-
35
- def fopen(filepath, *args, **kwargs):
36
- if is_str(filepath):
37
- return open(filepath, *args, **kwargs)
38
- elif isinstance(filepath, Path):
39
- return filepath.open(*args, **kwargs)
40
- raise ValueError('`filepath` should be a string or a Path')
41
-
42
-
43
- def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
44
- if not osp.isfile(filename):
45
- raise FileNotFoundError(msg_tmpl.format(filename))
46
-
47
-
48
- def mkdir_or_exist(dir_name, mode=0o777):
49
- if dir_name == '':
50
- return
51
- dir_name = osp.expanduser(dir_name)
52
- os.makedirs(dir_name, mode=mode, exist_ok=True)
53
-
54
-
55
- def symlink(src, dst, overwrite=True, **kwargs):
56
- if os.path.lexists(dst) and overwrite:
57
- os.remove(dst)
58
- os.symlink(src, dst, **kwargs)
59
-
60
-
61
- def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
62
- """Scan a directory to find the interested files.
63
- Args:
64
- dir_path (str | :obj:`Path`): Path of the directory.
65
- suffix (str | tuple(str), optional): File suffix that we are
66
- interested in. Default: None.
67
- recursive (bool, optional): If set to True, recursively scan the
68
- directory. Default: False.
69
- case_sensitive (bool, optional) : If set to False, ignore the case of
70
- suffix. Default: True.
71
- Returns:
72
- A generator for all the interested files with relative paths.
73
- """
74
- if isinstance(dir_path, (str, Path)):
75
- dir_path = str(dir_path)
76
- else:
77
- raise TypeError('"dir_path" must be a string or Path object')
78
-
79
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
80
- raise TypeError('"suffix" must be a string or tuple of strings')
81
-
82
- if suffix is not None and not case_sensitive:
83
- suffix = suffix.lower() if isinstance(suffix, str) else tuple(
84
- item.lower() for item in suffix)
85
-
86
- root = dir_path
87
-
88
- def _scandir(dir_path, suffix, recursive, case_sensitive):
89
- for entry in os.scandir(dir_path):
90
- if not entry.name.startswith('.') and entry.is_file():
91
- rel_path = osp.relpath(entry.path, root)
92
- _rel_path = rel_path if case_sensitive else rel_path.lower()
93
- if suffix is None or _rel_path.endswith(suffix):
94
- yield rel_path
95
- elif recursive and os.path.isdir(entry.path):
96
- # scan recursively if entry.path is a directory
97
- yield from _scandir(entry.path, suffix, recursive,
98
- case_sensitive)
99
-
100
- return _scandir(dir_path, suffix, recursive, case_sensitive)
101
-
102
-
103
- class Cache:
104
-
105
- def __init__(self, capacity):
106
- self._cache = OrderedDict()
107
- self._capacity = int(capacity)
108
- if capacity <= 0:
109
- raise ValueError('capacity must be a positive integer')
110
-
111
- @property
112
- def capacity(self):
113
- return self._capacity
114
-
115
- @property
116
- def size(self):
117
- return len(self._cache)
118
-
119
- def put(self, key, val):
120
- if key in self._cache:
121
- return
122
- if len(self._cache) >= self.capacity:
123
- self._cache.popitem(last=False)
124
- self._cache[key] = val
125
-
126
- def get(self, key, default=None):
127
- val = self._cache[key] if key in self._cache else default
128
- return val
129
-
130
-
131
- class VideoReader:
132
- """Video class with similar usage to a list object.
133
-
134
- This video warpper class provides convenient apis to access frames.
135
- There exists an issue of OpenCV's VideoCapture class that jumping to a
136
- certain frame may be inaccurate. It is fixed in this class by checking
137
- the position after jumping each time.
138
- Cache is used when decoding videos. So if the same frame is visited for
139
- the second time, there is no need to decode again if it is stored in the
140
- cache.
141
-
142
- """
143
-
144
- def __init__(self, filename, cache_capacity=10):
145
- # Check whether the video path is a url
146
- if not filename.startswith(('https://', 'http://')):
147
- check_file_exist(filename, 'Video file not found: ' + filename)
148
- self._vcap = cv2.VideoCapture(filename)
149
- assert cache_capacity > 0
150
- self._cache = Cache(cache_capacity)
151
- self._position = 0
152
- # get basic info
153
- self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
154
- self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
155
- self._fps = self._vcap.get(CAP_PROP_FPS)
156
- self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
157
- self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
158
-
159
- @property
160
- def vcap(self):
161
- """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
162
- return self._vcap
163
-
164
- @property
165
- def opened(self):
166
- """bool: Indicate whether the video is opened."""
167
- return self._vcap.isOpened()
168
-
169
- @property
170
- def width(self):
171
- """int: Width of video frames."""
172
- return self._width
173
-
174
- @property
175
- def height(self):
176
- """int: Height of video frames."""
177
- return self._height
178
-
179
- @property
180
- def resolution(self):
181
- """tuple: Video resolution (width, height)."""
182
- return (self._width, self._height)
183
-
184
- @property
185
- def fps(self):
186
- """float: FPS of the video."""
187
- return self._fps
188
-
189
- @property
190
- def frame_cnt(self):
191
- """int: Total frames of the video."""
192
- return self._frame_cnt
193
-
194
- @property
195
- def fourcc(self):
196
- """str: "Four character code" of the video."""
197
- return self._fourcc
198
-
199
- @property
200
- def position(self):
201
- """int: Current cursor position, indicating frame decoded."""
202
- return self._position
203
-
204
- def _get_real_position(self):
205
- return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
206
-
207
- def _set_real_position(self, frame_id):
208
- self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
209
- pos = self._get_real_position()
210
- for _ in range(frame_id - pos):
211
- self._vcap.read()
212
- self._position = frame_id
213
-
214
- def read(self):
215
- """Read the next frame.
216
-
217
- If the next frame have been decoded before and in the cache, then
218
- return it directly, otherwise decode, cache and return it.
219
-
220
- Returns:
221
- ndarray or None: Return the frame if successful, otherwise None.
222
- """
223
- # pos = self._position
224
- if self._cache:
225
- img = self._cache.get(self._position)
226
- if img is not None:
227
- ret = True
228
- else:
229
- if self._position != self._get_real_position():
230
- self._set_real_position(self._position)
231
- ret, img = self._vcap.read()
232
- if ret:
233
- self._cache.put(self._position, img)
234
- else:
235
- ret, img = self._vcap.read()
236
- if ret:
237
- self._position += 1
238
- return img
239
-
240
- def get_frame(self, frame_id):
241
- """Get frame by index.
242
-
243
- Args:
244
- frame_id (int): Index of the expected frame, 0-based.
245
-
246
- Returns:
247
- ndarray or None: Return the frame if successful, otherwise None.
248
- """
249
- if frame_id < 0 or frame_id >= self._frame_cnt:
250
- raise IndexError(
251
- f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
252
- if frame_id == self._position:
253
- return self.read()
254
- if self._cache:
255
- img = self._cache.get(frame_id)
256
- if img is not None:
257
- self._position = frame_id + 1
258
- return img
259
- self._set_real_position(frame_id)
260
- ret, img = self._vcap.read()
261
- if ret:
262
- if self._cache:
263
- self._cache.put(self._position, img)
264
- self._position += 1
265
- return img
266
-
267
- def current_frame(self):
268
- """Get the current frame (frame that is just visited).
269
-
270
- Returns:
271
- ndarray or None: If the video is fresh, return None, otherwise
272
- return the frame.
273
- """
274
- if self._position == 0:
275
- return None
276
- return self._cache.get(self._position - 1)
277
-
278
- def cvt2frames(self,
279
- frame_dir,
280
- file_start=0,
281
- filename_tmpl='{:06d}.jpg',
282
- start=0,
283
- max_num=0,
284
- show_progress=False):
285
- """Convert a video to frame images.
286
-
287
- Args:
288
- frame_dir (str): Output directory to store all the frame images.
289
- file_start (int): Filenames will start from the specified number.
290
- filename_tmpl (str): Filename template with the index as the
291
- placeholder.
292
- start (int): The starting frame index.
293
- max_num (int): Maximum number of frames to be written.
294
- show_progress (bool): Whether to show a progress bar.
295
- """
296
- mkdir_or_exist(frame_dir)
297
- if max_num == 0:
298
- task_num = self.frame_cnt - start
299
- else:
300
- task_num = min(self.frame_cnt - start, max_num)
301
- if task_num <= 0:
302
- raise ValueError('start must be less than total frame number')
303
- if start > 0:
304
- self._set_real_position(start)
305
-
306
- def write_frame(file_idx):
307
- img = self.read()
308
- if img is None:
309
- return
310
- filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
311
- cv2.imwrite(filename, img)
312
-
313
- if show_progress:
314
- pass
315
- #track_progress(write_frame, range(file_start,file_start + task_num))
316
- else:
317
- for i in range(task_num):
318
- write_frame(file_start + i)
319
-
320
- def __len__(self):
321
- return self.frame_cnt
322
-
323
- def __getitem__(self, index):
324
- if isinstance(index, slice):
325
- return [
326
- self.get_frame(i)
327
- for i in range(*index.indices(self.frame_cnt))
328
- ]
329
- # support negative indexing
330
- if index < 0:
331
- index += self.frame_cnt
332
- if index < 0:
333
- raise IndexError('index out of range')
334
- return self.get_frame(index)
335
-
336
- def __iter__(self):
337
- self._set_real_position(0)
338
- return self
339
-
340
- def __next__(self):
341
- img = self.read()
342
- if img is not None:
343
- return img
344
- else:
345
- raise StopIteration
346
-
347
- next = __next__
348
-
349
- def __enter__(self):
350
- return self
351
-
352
- def __exit__(self, exc_type, exc_value, traceback):
353
- self._vcap.release()
354
-
355
-
356
- def frames2video(frame_dir,
357
- video_file,
358
- fps=30,
359
- fourcc='XVID',
360
- filename_tmpl='{:06d}.jpg',
361
- start=0,
362
- end=0,
363
- show_progress=False):
364
- """Read the frame images from a directory and join them as a video.
365
-
366
- Args:
367
- frame_dir (str): The directory containing video frames.
368
- video_file (str): Output filename.
369
- fps (float): FPS of the output video.
370
- fourcc (str): Fourcc of the output video, this should be compatible
371
- with the output file type.
372
- filename_tmpl (str): Filename template with the index as the variable.
373
- start (int): Starting frame index.
374
- end (int): Ending frame index.
375
- show_progress (bool): Whether to show a progress bar.
376
- """
377
- if end == 0:
378
- ext = filename_tmpl.split('.')[-1]
379
- end = len([name for name in scandir(frame_dir, ext)])
380
- first_file = osp.join(frame_dir, filename_tmpl.format(start))
381
- check_file_exist(first_file, 'The start frame not found: ' + first_file)
382
- img = cv2.imread(first_file)
383
- height, width = img.shape[:2]
384
- resolution = (width, height)
385
- vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
386
- resolution)
387
-
388
- def write_frame(file_idx):
389
- filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
390
- img = cv2.imread(filename)
391
- vwriter.write(img)
392
-
393
- if show_progress:
394
- pass
395
- # track_progress(write_frame, range(start, end))
396
- else:
397
- for i in range(start, end):
398
- write_frame(i)
399
- vwriter.release()
400
-
401
-
402
- def video2images(video_path, output_dir):
403
- vidcap = cv2.VideoCapture(video_path)
404
- in_fps = vidcap.get(cv2.CAP_PROP_FPS)
405
- print('video fps:', in_fps)
406
- if not os.path.isdir(output_dir):
407
- os.makedirs(output_dir)
408
- loaded, frame = vidcap.read()
409
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
410
- print(f'number of total frames is: {total_frames:06}')
411
- for i_frame in range(total_frames):
412
- if i_frame % 100 == 0:
413
- print(f'{i_frame:06} / {total_frames:06}')
414
- frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
415
- cv2.imwrite(frame_name, frame)
416
- loaded, frame = vidcap.read()
417
-
418
-
419
- def images2video(image_dir, video_path, fps=24, image_ext='png'):
420
- '''
421
- #codec = cv2.VideoWriter_fourcc(*'XVID')
422
- #codec = cv2.VideoWriter_fourcc('A','V','C','1')
423
- #codec = cv2.VideoWriter_fourcc('Y','U','V','1')
424
- #codec = cv2.VideoWriter_fourcc('P','I','M','1')
425
- #codec = cv2.VideoWriter_fourcc('M','J','P','G')
426
- codec = cv2.VideoWriter_fourcc('M','P','4','2')
427
- #codec = cv2.VideoWriter_fourcc('D','I','V','3')
428
- #codec = cv2.VideoWriter_fourcc('D','I','V','X')
429
- #codec = cv2.VideoWriter_fourcc('U','2','6','3')
430
- #codec = cv2.VideoWriter_fourcc('I','2','6','3')
431
- #codec = cv2.VideoWriter_fourcc('F','L','V','1')
432
- #codec = cv2.VideoWriter_fourcc('H','2','6','4')
433
- #codec = cv2.VideoWriter_fourcc('A','Y','U','V')
434
- #codec = cv2.VideoWriter_fourcc('I','U','Y','V')
435
- 编码器常用的几种:
436
- cv2.VideoWriter_fourcc("I", "4", "2", "0")
437
- 压缩的yuv颜色编码器,4:2:0色彩度子采样 兼容性好,产生很大的视频 avi
438
- cv2.VideoWriter_fourcc("P", I", "M", "1")
439
- 采用mpeg-1编码,文件为avi
440
- cv2.VideoWriter_fourcc("X", "V", "T", "D")
441
- 采用mpeg-4编码,得到视频大小平均 拓展名avi
442
- cv2.VideoWriter_fourcc("T", "H", "E", "O")
443
- Ogg Vorbis, 拓展名为ogv
444
- cv2.VideoWriter_fourcc("F", "L", "V", "1")
445
- FLASH视频,拓展名为.flv
446
- '''
447
- image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
448
- print(len(image_files))
449
- height, width, _ = cv2.imread(image_files[0]).shape
450
- out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') # cv2.VideoWriter_fourcc(*'MP4V')
451
- out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))
452
-
453
- for image_file in image_files:
454
- img = cv2.imread(image_file)
455
- img = cv2.resize(img, (width, height), interpolation=3)
456
- out_video.write(img)
457
- out_video.release()
458
-
459
-
460
- def add_video_compression(imgs):
461
- codec_type = ['libx264', 'h264', 'mpeg4']
462
- codec_prob = [1 / 3., 1 / 3., 1 / 3.]
463
- codec = random.choices(codec_type, codec_prob)[0]
464
- # codec = 'mpeg4'
465
- bitrate = [1e4, 1e5]
466
- bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
467
-
468
- buf = io.BytesIO()
469
- with av.open(buf, 'w', 'mp4') as container:
470
- stream = container.add_stream(codec, rate=1)
471
- stream.height = imgs[0].shape[0]
472
- stream.width = imgs[0].shape[1]
473
- stream.pix_fmt = 'yuv420p'
474
- stream.bit_rate = bitrate
475
-
476
- for img in imgs:
477
- img = np.uint8((img.clip(0, 1)*255.).round())
478
- frame = av.VideoFrame.from_ndarray(img, format='rgb24')
479
- frame.pict_type = 'NONE'
480
- # pdb.set_trace()
481
- for packet in stream.encode(frame):
482
- container.mux(packet)
483
-
484
- # Flush stream
485
- for packet in stream.encode():
486
- container.mux(packet)
487
-
488
- outputs = []
489
- with av.open(buf, 'r', 'mp4') as container:
490
- if container.streams.video:
491
- for frame in container.decode(**{'video': 0}):
492
- outputs.append(
493
- frame.to_rgb().to_ndarray().astype(np.float32) / 255.)
494
-
495
- #outputs = np.stack(outputs, axis=0)
496
- return outputs
497
-
498
-
499
- if __name__ == '__main__':
500
-
501
- # -----------------------------------
502
- # test VideoReader(filename, cache_capacity=10)
503
- # -----------------------------------
504
- # video_reader = VideoReader('utils/test.mp4')
505
- # from utils import utils_image as util
506
- # inputs = []
507
- # for frame in video_reader:
508
- # print(frame.dtype)
509
- # util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
510
- # #util.imshow(np.flip(frame, axis=2))
511
-
512
- # -----------------------------------
513
- # test video2images(video_path, output_dir)
514
- # -----------------------------------
515
- # video2images('utils/test.mp4', 'frames')
516
-
517
- # -----------------------------------
518
- # test images2video(image_dir, video_path, fps=24, image_ext='png')
519
- # -----------------------------------
520
- # images2video('frames', 'video_02.mp4', fps=30, image_ext='png')
521
-
522
-
523
- # -----------------------------------
524
- # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
525
- # -----------------------------------
526
- # frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')
527
-
528
-
529
- # -----------------------------------
530
- # test add_video_compression(imgs)
531
- # -----------------------------------
532
- # imgs = []
533
- # image_ext = 'png'
534
- # frames = 'frames'
535
- # from utils import utils_image as util
536
- # image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
537
- # for i, image_file in enumerate(image_files):
538
- # if i < 7:
539
- # img = util.imread_uint(image_file, 3)
540
- # img = util.uint2single(img)
541
- # imgs.append(img)
542
- #
543
- # results = add_video_compression(imgs)
544
- # for i, img in enumerate(results):
545
- # util.imshow(util.single2uint(img))
546
- # util.imsave(util.single2uint(img),f'{i:05}.png')
547
-
548
- # run utils/utils_video.py
549
-
550
-
551
-
552
-
553
-
554
-
555
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/scripts/__init__.py DELETED
File without changes
core/scripts/cli.py DELETED
@@ -1,41 +0,0 @@
1
- import sys
2
- import argparse
3
- from .. import WarpCore
4
- from .. import templates
5
-
6
-
7
- def template_init(args):
8
- return ''''
9
-
10
-
11
- '''.strip()
12
-
13
-
14
- def init_template(args):
15
- parser = argparse.ArgumentParser(description='WarpCore template init tool')
16
- parser.add_argument('-t', '--template', type=str, default='WarpCore')
17
- args = parser.parse_args(args)
18
-
19
- if args.template == 'WarpCore':
20
- template_cls = WarpCore
21
- else:
22
- try:
23
- template_cls = __import__(args.template)
24
- except ModuleNotFoundError:
25
- template_cls = getattr(templates, args.template)
26
- print(template_cls)
27
-
28
-
29
- def main():
30
- if len(sys.argv) < 2:
31
- print('Usage: core <command>')
32
- sys.exit(1)
33
- if sys.argv[1] == 'init':
34
- init_template(sys.argv[2:])
35
- else:
36
- print('Unknown command')
37
- sys.exit(1)
38
-
39
-
40
- if __name__ == '__main__':
41
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/templates/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .diffusion import DiffusionCore
 
 
core/templates/diffusion.py DELETED
@@ -1,236 +0,0 @@
1
- from .. import WarpCore
2
- from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
3
- from abc import abstractmethod
4
- from dataclasses import dataclass
5
- import torch
6
- from torch import nn
7
- from torch.utils.data import DataLoader
8
- from gdf import GDF
9
- import numpy as np
10
- from tqdm import tqdm
11
- import wandb
12
-
13
- import webdataset as wds
14
- from webdataset.handlers import warn_and_continue
15
- from torch.distributed import barrier
16
- from enum import Enum
17
-
18
- class TargetReparametrization(Enum):
19
- EPSILON = 'epsilon'
20
- X0 = 'x0'
21
-
22
- class DiffusionCore(WarpCore):
23
- @dataclass(frozen=True)
24
- class Config(WarpCore.Config):
25
- # TRAINING PARAMS
26
- lr: float = EXPECTED_TRAIN
27
- grad_accum_steps: int = EXPECTED_TRAIN
28
- batch_size: int = EXPECTED_TRAIN
29
- updates: int = EXPECTED_TRAIN
30
- warmup_updates: int = EXPECTED_TRAIN
31
- save_every: int = 500
32
- backup_every: int = 20000
33
- use_fsdp: bool = True
34
-
35
- # EMA UPDATE
36
- ema_start_iters: int = None
37
- ema_iters: int = None
38
- ema_beta: float = None
39
-
40
- # GDF setting
41
- gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
42
-
43
- @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
44
- class Info(WarpCore.Info):
45
- ema_loss: float = None
46
-
47
- @dataclass(frozen=True)
48
- class Models(WarpCore.Models):
49
- generator : nn.Module = EXPECTED
50
- generator_ema : nn.Module = None # optional
51
-
52
- @dataclass(frozen=True)
53
- class Optimizers(WarpCore.Optimizers):
54
- generator : any = EXPECTED
55
-
56
- @dataclass(frozen=True)
57
- class Schedulers(WarpCore.Schedulers):
58
- generator: any = None
59
-
60
- @dataclass(frozen=True)
61
- class Extras(WarpCore.Extras):
62
- gdf: GDF = EXPECTED
63
- sampling_configs: dict = EXPECTED
64
-
65
- # --------------------------------------------
66
- info: Info
67
- config: Config
68
-
69
- @abstractmethod
70
- def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
71
- raise NotImplementedError("This method needs to be overriden")
72
-
73
- @abstractmethod
74
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
75
- raise NotImplementedError("This method needs to be overriden")
76
-
77
- @abstractmethod
78
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
79
- raise NotImplementedError("This method needs to be overriden")
80
-
81
- @abstractmethod
82
- def webdataset_path(self, extras: Extras):
83
- raise NotImplementedError("This method needs to be overriden")
84
-
85
- @abstractmethod
86
- def webdataset_filters(self, extras: Extras):
87
- raise NotImplementedError("This method needs to be overriden")
88
-
89
- @abstractmethod
90
- def webdataset_preprocessors(self, extras: Extras):
91
- raise NotImplementedError("This method needs to be overriden")
92
-
93
- @abstractmethod
94
- def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
95
- raise NotImplementedError("This method needs to be overriden")
96
- # -------------
97
-
98
- def setup_data(self, extras: Extras) -> WarpCore.Data:
99
- # SETUP DATASET
100
- dataset_path = self.webdataset_path(extras)
101
- preprocessors = self.webdataset_preprocessors(extras)
102
- filters = self.webdataset_filters(extras)
103
-
104
- handler = warn_and_continue # None
105
- # handler = None
106
- dataset = wds.WebDataset(
107
- dataset_path, resampled=True, handler=handler
108
- ).select(filters).shuffle(690, handler=handler).decode(
109
- "pilrgb", handler=handler
110
- ).to_tuple(
111
- *[p[0] for p in preprocessors], handler=handler
112
- ).map_tuple(
113
- *[p[1] for p in preprocessors], handler=handler
114
- ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})
115
-
116
- # SETUP DATALOADER
117
- real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
118
- dataloader = DataLoader(
119
- dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
120
- )
121
-
122
- return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
123
-
124
- def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
125
- batch = next(data.iterator)
126
-
127
- with torch.no_grad():
128
- conditions = self.get_conditions(batch, models, extras)
129
- latents = self.encode_latents(batch, models, extras)
130
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
131
-
132
- # FORWARD PASS
133
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
134
- pred = models.generator(noised, noise_cond, **conditions)
135
- if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
136
- pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
137
- target = noise
138
- elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
139
- pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
140
- target = latents
141
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
142
- loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
143
-
144
- return loss, loss_adjusted
145
-
146
- def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
147
- start_iter = self.info.iter+1
148
- max_iters = self.config.updates * self.config.grad_accum_steps
149
- if self.is_main_node:
150
- print(f"STARTING AT STEP: {start_iter}/{max_iters}")
151
-
152
- pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
153
- models.generator.train()
154
- for i in pbar:
155
- # FORWARD PASS
156
- loss, loss_adjusted = self.forward_pass(data, extras, models)
157
-
158
- # BACKWARD PASS
159
- if i % self.config.grad_accum_steps == 0 or i == max_iters:
160
- loss_adjusted.backward()
161
- grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
162
- optimizers_dict = optimizers.to_dict()
163
- for k in optimizers_dict:
164
- optimizers_dict[k].step()
165
- schedulers_dict = schedulers.to_dict()
166
- for k in schedulers_dict:
167
- schedulers_dict[k].step()
168
- models.generator.zero_grad(set_to_none=True)
169
- self.info.total_steps += 1
170
- else:
171
- with models.generator.no_sync():
172
- loss_adjusted.backward()
173
- self.info.iter = i
174
-
175
- # UPDATE EMA
176
- if models.generator_ema is not None and i % self.config.ema_iters == 0:
177
- update_weights_ema(
178
- models.generator_ema, models.generator,
179
- beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
180
- )
181
-
182
- # UPDATE LOSS METRICS
183
- self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
184
-
185
- if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
186
- wandb.alert(
187
- title=f"NaN value encountered in training run {self.info.wandb_run_id}",
188
- text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
189
- wait_duration=60*30
190
- )
191
-
192
- if self.is_main_node:
193
- logs = {
194
- 'loss': self.info.ema_loss,
195
- 'raw_loss': loss.mean().item(),
196
- 'grad_norm': grad_norm.item(),
197
- 'lr': optimizers.generator.param_groups[0]['lr'],
198
- 'total_steps': self.info.total_steps,
199
- }
200
-
201
- pbar.set_postfix(logs)
202
- if self.config.wandb_project is not None:
203
- wandb.log(logs)
204
-
205
- if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
206
- # SAVE AND CHECKPOINT STUFF
207
- if np.isnan(loss.mean().item()):
208
- if self.is_main_node and self.config.wandb_project is not None:
209
- tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
210
- wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
211
- else:
212
- self.save_checkpoints(models, optimizers)
213
- if self.is_main_node:
214
- create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
215
- self.sample(models, data, extras)
216
-
217
- def models_to_save(self):
218
- return ['generator', 'generator_ema']
219
-
220
- def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
221
- barrier()
222
- suffix = '' if suffix is None else suffix
223
- self.save_info(self.info, suffix=suffix)
224
- models_dict = models.to_dict()
225
- optimizers_dict = optimizers.to_dict()
226
- for key in self.models_to_save():
227
- model = models_dict[key]
228
- if model is not None:
229
- self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
230
- for key in optimizers_dict:
231
- optimizer = optimizers_dict[key]
232
- if optimizer is not None:
233
- self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
234
- if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
235
- self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
236
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/utils/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
2
- from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
3
-
4
- # MOVE IT SOMERWHERE ELSE
5
- def update_weights_ema(tgt_model, src_model, beta=0.999):
6
- for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
7
- self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta)
8
- for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()):
9
- self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta)
 
 
 
 
 
 
 
 
 
 
core/utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (763 Bytes)
 
core/utils/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (804 Bytes)
 
core/utils/__pycache__/base_dto.cpython-310.pyc DELETED
Binary file (3.09 kB)
 
core/utils/__pycache__/base_dto.cpython-39.pyc DELETED
Binary file (3.11 kB)
 
core/utils/__pycache__/save_and_load.cpython-310.pyc DELETED
Binary file (2.19 kB)
 
core/utils/__pycache__/save_and_load.cpython-39.pyc DELETED
Binary file (2.2 kB)
 
core/utils/base_dto.py DELETED
@@ -1,56 +0,0 @@
1
- import dataclasses
2
- from dataclasses import dataclass, _MISSING_TYPE
3
- from munch import Munch
4
-
5
- EXPECTED = "___REQUIRED___"
6
- EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
7
-
8
- # pylint: disable=invalid-field-call
9
- def nested_dto(x, raw=False):
10
- return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x))
11
-
12
- @dataclass(frozen=True)
13
- class Base:
14
- training: bool = None
15
- def __new__(cls, **kwargs):
16
- training = kwargs.get('training', True)
17
- setteable_fields = cls.setteable_fields(**kwargs)
18
- mandatory_fields = cls.mandatory_fields(**kwargs)
19
- invalid_kwargs = [
20
- {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
21
- ]
22
- print(mandatory_fields)
23
- assert (
24
- len(invalid_kwargs) == 0
25
- ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
26
- missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
27
- assert (
28
- len(missing_kwargs) == 0
29
- ), f"Required fields missing initializing this DTO: {missing_kwargs}."
30
- return object.__new__(cls)
31
-
32
-
33
- @classmethod
34
- def setteable_fields(cls, **kwargs):
35
- return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]
36
-
37
- @classmethod
38
- def mandatory_fields(cls, **kwargs):
39
- training = kwargs.get('training', True)
40
- return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]
41
-
42
- @classmethod
43
- def from_dict(cls, kwargs):
44
- for k in kwargs:
45
- if isinstance(kwargs[k], (dict, list, tuple)):
46
- kwargs[k] = Munch.fromDict(kwargs[k])
47
- return cls(**kwargs)
48
-
49
- def to_dict(self):
50
- # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
51
- selfdict = {}
52
- for k in dataclasses.fields(self):
53
- selfdict[k.name] = getattr(self, k.name)
54
- if isinstance(selfdict[k.name], Munch):
55
- selfdict[k.name] = selfdict[k.name].toDict()
56
- return selfdict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/utils/save_and_load.py DELETED
@@ -1,59 +0,0 @@
1
- import os
2
- import torch
3
- import json
4
- from pathlib import Path
5
- import safetensors
6
- import wandb
7
-
8
-
9
- def create_folder_if_necessary(path):
10
- path = "/".join(path.split("/")[:-1])
11
- Path(path).mkdir(parents=True, exist_ok=True)
12
-
13
-
14
- def safe_save(ckpt, path):
15
- try:
16
- os.remove(f"{path}.bak")
17
- except OSError:
18
- pass
19
- try:
20
- os.rename(path, f"{path}.bak")
21
- except OSError:
22
- pass
23
- if path.endswith(".pt") or path.endswith(".ckpt"):
24
- torch.save(ckpt, path)
25
- elif path.endswith(".json"):
26
- with open(path, "w", encoding="utf-8") as f:
27
- json.dump(ckpt, f, indent=4)
28
- elif path.endswith(".safetensors"):
29
- safetensors.torch.save_file(ckpt, path)
30
- else:
31
- raise ValueError(f"File extension not supported: {path}")
32
-
33
-
34
- def load_or_fail(path, wandb_run_id=None):
35
- accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
36
- try:
37
- assert any(
38
- [path.endswith(ext) for ext in accepted_extensions]
39
- ), f"Automatic loading not supported for this extension: {path}"
40
- if not os.path.exists(path):
41
- checkpoint = None
42
- elif path.endswith(".pt") or path.endswith(".ckpt"):
43
- checkpoint = torch.load(path, map_location="cpu")
44
- elif path.endswith(".json"):
45
- with open(path, "r", encoding="utf-8") as f:
46
- checkpoint = json.load(f)
47
- elif path.endswith(".safetensors"):
48
- checkpoint = {}
49
- with safetensors.safe_open(path, framework="pt", device="cpu") as f:
50
- for key in f.keys():
51
- checkpoint[key] = f.get_tensor(key)
52
- return checkpoint
53
- except Exception as e:
54
- if wandb_run_id is not None:
55
- wandb.alert(
56
- title=f"Corrupt checkpoint for run {wandb_run_id}",
57
- text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
58
- )
59
- raise e