Spaces:
Runtime error
Runtime error
DeepCoreB4
commited on
Commit
•
6c10d0d
1
Parent(s):
0943f5a
Upload 1369 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +56 -0
- __pycache__/app.cpython-310.pyc +0 -0
- __pycache__/launch.cpython-310.pyc +0 -0
- __pycache__/webui.cpython-310.pyc +0 -0
- configs/alt-diffusion-inference.yaml +72 -0
- configs/instruct-pix2pix.yaml +98 -0
- configs/v1-inference.yaml +70 -0
- configs/v1-inpainting-inference.yaml +70 -0
- embeddings/Place Textual Inversion embeddings here.txt +0 -0
- extensions-builtin/LDSR/__pycache__/ldsr_model_arch.cpython-310.pyc +0 -0
- extensions-builtin/LDSR/__pycache__/preload.cpython-310.pyc +0 -0
- extensions-builtin/LDSR/__pycache__/sd_hijack_autoencoder.cpython-310.pyc +0 -0
- extensions-builtin/LDSR/__pycache__/sd_hijack_ddpm_v1.cpython-310.pyc +0 -0
- extensions-builtin/LDSR/ldsr_model_arch.py +253 -0
- extensions-builtin/LDSR/preload.py +6 -0
- extensions-builtin/LDSR/scripts/__pycache__/ldsr_model.cpython-310.pyc +0 -0
- extensions-builtin/LDSR/scripts/ldsr_model.py +69 -0
- extensions-builtin/LDSR/sd_hijack_autoencoder.py +286 -0
- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +1449 -0
- extensions-builtin/Lora/__pycache__/extra_networks_lora.cpython-310.pyc +0 -0
- extensions-builtin/Lora/__pycache__/lora.cpython-310.pyc +0 -0
- extensions-builtin/Lora/__pycache__/preload.cpython-310.pyc +0 -0
- extensions-builtin/Lora/__pycache__/ui_extra_networks_lora.cpython-310.pyc +0 -0
- extensions-builtin/Lora/extra_networks_lora.py +26 -0
- extensions-builtin/Lora/lora.py +226 -0
- extensions-builtin/Lora/preload.py +6 -0
- extensions-builtin/Lora/scripts/__pycache__/lora_script.cpython-310.pyc +0 -0
- extensions-builtin/Lora/scripts/lora_script.py +38 -0
- extensions-builtin/Lora/ui_extra_networks_lora.py +31 -0
- extensions-builtin/ScuNET/__pycache__/preload.cpython-310.pyc +0 -0
- extensions-builtin/ScuNET/__pycache__/scunet_model_arch.cpython-310.pyc +0 -0
- extensions-builtin/ScuNET/preload.py +6 -0
- extensions-builtin/ScuNET/scripts/__pycache__/scunet_model.cpython-310.pyc +0 -0
- extensions-builtin/ScuNET/scripts/scunet_model.py +87 -0
- extensions-builtin/ScuNET/scunet_model_arch.py +265 -0
- extensions-builtin/SwinIR/__pycache__/preload.cpython-310.pyc +0 -0
- extensions-builtin/SwinIR/__pycache__/swinir_model_arch.cpython-310.pyc +0 -0
- extensions-builtin/SwinIR/__pycache__/swinir_model_arch_v2.cpython-310.pyc +0 -0
- extensions-builtin/SwinIR/preload.py +6 -0
- extensions-builtin/SwinIR/scripts/__pycache__/swinir_model.cpython-310.pyc +0 -0
- extensions-builtin/SwinIR/scripts/swinir_model.py +178 -0
- extensions-builtin/SwinIR/swinir_model_arch.py +867 -0
- extensions-builtin/SwinIR/swinir_model_arch_v2.py +1017 -0
- extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +110 -0
- extensions/gif2gif/README.md +47 -0
- extensions/gif2gif/instructions.txt +19 -0
- extensions/gif2gif/javascript/gif2gif_hints.js +40 -0
- extensions/gif2gif/scripts/__pycache__/gif2gif.cpython-310.pyc +0 -0
- extensions/gif2gif/scripts/gif2gif.py +355 -0
- extensions/put extensions here.txt +0 -0
.gitattributes
CHANGED
@@ -32,3 +32,59 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
html/diffusion_banner[[:space:]]-[[:space:]]Kopie.gif filter=lfs diff=lfs merge=lfs -text
|
36 |
+
html/diffusion_banner.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
outputs/img2img-images/2023-03-22/00000-2069306538.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
outputs/txt2img-grids/2023-03-21/grid-0000.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
outputs/txt2img-grids/2023-03-22/grid-0000.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
outputs/txt2img-grids/2023-03-22/prompt_matrix-0000-4155696721.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
outputs/txt2img-grids/2023-03-22/prompt_matrix-0001-2769304212.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
outputs/txt2img-grids/2023-03-22/prompt_matrix-0002-3709775125.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
outputs/txt2img-grids/2023-03-22/prompt_matrix-0003-3558263618.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
outputs/txt2img-grids/2023-03-22/prompt_matrix-0004-1174613831.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
outputs/txt2img-grids/2023-03-25/grid-0000.png filter=lfs diff=lfs merge=lfs -text
|
46 |
+
outputs/txt2img-grids/2023-03-25/grid-0001.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
outputs/txt2img-grids/2023-03-25/grid-0002.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
outputs/txt2img-grids/2023-03-25/grid-0003.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
outputs/txt2img-grids/2023-03-25/grid-0004.png filter=lfs diff=lfs merge=lfs -text
|
50 |
+
outputs/txt2img-grids/2023-03-25/grid-0005.png filter=lfs diff=lfs merge=lfs -text
|
51 |
+
outputs/txt2img-grids/2023-03-25/grid-0006.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
outputs/txt2img-grids/2023-03-25/grid-0007.png filter=lfs diff=lfs merge=lfs -text
|
53 |
+
outputs/txt2img-grids/2023-03-25/grid-0008.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
outputs/txt2img-grids/2023-03-25/grid-0009.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
outputs/txt2img-images/2023-03-25/00015-457796679.png filter=lfs diff=lfs merge=lfs -text
|
56 |
+
outputs/txt2img-images/2023-03-25/00016-457796680.png filter=lfs diff=lfs merge=lfs -text
|
57 |
+
outputs/txt2img-images/2023-03-25/00017-457796681.png filter=lfs diff=lfs merge=lfs -text
|
58 |
+
outputs/txt2img-images/2023-03-25/00018-457796682.png filter=lfs diff=lfs merge=lfs -text
|
59 |
+
outputs/txt2img-images/2023-03-25/00070-4269172638.png filter=lfs diff=lfs merge=lfs -text
|
60 |
+
repositories/BLIP/BLIP.gif filter=lfs diff=lfs merge=lfs -text
|
61 |
+
repositories/stable-diffusion-stability-ai/assets/stable-inpainting/merged-leopards.png filter=lfs diff=lfs merge=lfs -text
|
62 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/depth2img/d2i.gif filter=lfs diff=lfs merge=lfs -text
|
63 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/depth2img/depth2img01.png filter=lfs diff=lfs merge=lfs -text
|
64 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/depth2img/depth2img02.png filter=lfs diff=lfs merge=lfs -text
|
65 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/depth2img/merged-0000.png filter=lfs diff=lfs merge=lfs -text
|
66 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/depth2img/merged-0004.png filter=lfs diff=lfs merge=lfs -text
|
67 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/depth2img/merged-0005.png filter=lfs diff=lfs merge=lfs -text
|
68 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/img2img/upscaling-in.png filter=lfs diff=lfs merge=lfs -text
|
69 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/img2img/upscaling-out.png filter=lfs diff=lfs merge=lfs -text
|
70 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/768/merged-0001.png filter=lfs diff=lfs merge=lfs -text
|
71 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/768/merged-0002.png filter=lfs diff=lfs merge=lfs -text
|
72 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/768/merged-0003.png filter=lfs diff=lfs merge=lfs -text
|
73 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/768/merged-0004.png filter=lfs diff=lfs merge=lfs -text
|
74 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/768/merged-0005.png filter=lfs diff=lfs merge=lfs -text
|
75 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/768/merged-0006.png filter=lfs diff=lfs merge=lfs -text
|
76 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/merged-0001.png filter=lfs diff=lfs merge=lfs -text
|
77 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/merged-0003.png filter=lfs diff=lfs merge=lfs -text
|
78 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/merged-0005.png filter=lfs diff=lfs merge=lfs -text
|
79 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/merged-0006.png filter=lfs diff=lfs merge=lfs -text
|
80 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/txt2img/merged-0007.png filter=lfs diff=lfs merge=lfs -text
|
81 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/upscaling/merged-dog.png filter=lfs diff=lfs merge=lfs -text
|
82 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/upscaling/sampled-bear-x4.png filter=lfs diff=lfs merge=lfs -text
|
83 |
+
repositories/stable-diffusion-stability-ai/assets/stable-samples/upscaling/snow-leopard-x4.png filter=lfs diff=lfs merge=lfs -text
|
84 |
+
repositories/taming-transformers/assets/birddrawnbyachild.png filter=lfs diff=lfs merge=lfs -text
|
85 |
+
repositories/taming-transformers/assets/first_stage_mushrooms.png filter=lfs diff=lfs merge=lfs -text
|
86 |
+
repositories/taming-transformers/assets/first_stage_squirrels.png filter=lfs diff=lfs merge=lfs -text
|
87 |
+
repositories/taming-transformers/assets/imagenet.png filter=lfs diff=lfs merge=lfs -text
|
88 |
+
repositories/taming-transformers/data/open_images_annotations_100/train/000b1b3b85edd850.jpg filter=lfs diff=lfs merge=lfs -text
|
89 |
+
repositories/taming-transformers/data/open_images_annotations_100/validation/0a600f1148d1023c.jpg filter=lfs diff=lfs merge=lfs -text
|
90 |
+
repositories/taming-transformers/scripts/reconstruction_usage.ipynb filter=lfs diff=lfs merge=lfs -text
|
__pycache__/app.cpython-310.pyc
ADDED
Binary file (14.2 kB). View file
|
|
__pycache__/launch.cpython-310.pyc
ADDED
Binary file (14.3 kB). View file
|
|
__pycache__/webui.cpython-310.pyc
ADDED
Binary file (10.4 kB). View file
|
|
configs/alt-diffusion-inference.yaml
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 10000 ]
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 4
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: modules.xlmr.BertSeriesModelWithTransformation
|
71 |
+
params:
|
72 |
+
name: "XLMR-Large"
|
configs/instruct-pix2pix.yaml
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
model:
|
5 |
+
base_learning_rate: 1.0e-04
|
6 |
+
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
|
7 |
+
params:
|
8 |
+
linear_start: 0.00085
|
9 |
+
linear_end: 0.0120
|
10 |
+
num_timesteps_cond: 1
|
11 |
+
log_every_t: 200
|
12 |
+
timesteps: 1000
|
13 |
+
first_stage_key: edited
|
14 |
+
cond_stage_key: edit
|
15 |
+
# image_size: 64
|
16 |
+
# image_size: 32
|
17 |
+
image_size: 16
|
18 |
+
channels: 4
|
19 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
20 |
+
conditioning_key: hybrid
|
21 |
+
monitor: val/loss_simple_ema
|
22 |
+
scale_factor: 0.18215
|
23 |
+
use_ema: false
|
24 |
+
|
25 |
+
scheduler_config: # 10000 warmup steps
|
26 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
27 |
+
params:
|
28 |
+
warm_up_steps: [ 0 ]
|
29 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
30 |
+
f_start: [ 1.e-6 ]
|
31 |
+
f_max: [ 1. ]
|
32 |
+
f_min: [ 1. ]
|
33 |
+
|
34 |
+
unet_config:
|
35 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32 # unused
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions: [ 4, 2, 1 ]
|
42 |
+
num_res_blocks: 2
|
43 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
44 |
+
num_heads: 8
|
45 |
+
use_spatial_transformer: True
|
46 |
+
transformer_depth: 1
|
47 |
+
context_dim: 768
|
48 |
+
use_checkpoint: True
|
49 |
+
legacy: False
|
50 |
+
|
51 |
+
first_stage_config:
|
52 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
53 |
+
params:
|
54 |
+
embed_dim: 4
|
55 |
+
monitor: val/rec_loss
|
56 |
+
ddconfig:
|
57 |
+
double_z: true
|
58 |
+
z_channels: 4
|
59 |
+
resolution: 256
|
60 |
+
in_channels: 3
|
61 |
+
out_ch: 3
|
62 |
+
ch: 128
|
63 |
+
ch_mult:
|
64 |
+
- 1
|
65 |
+
- 2
|
66 |
+
- 4
|
67 |
+
- 4
|
68 |
+
num_res_blocks: 2
|
69 |
+
attn_resolutions: []
|
70 |
+
dropout: 0.0
|
71 |
+
lossconfig:
|
72 |
+
target: torch.nn.Identity
|
73 |
+
|
74 |
+
cond_stage_config:
|
75 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
76 |
+
|
77 |
+
data:
|
78 |
+
target: main.DataModuleFromConfig
|
79 |
+
params:
|
80 |
+
batch_size: 128
|
81 |
+
num_workers: 1
|
82 |
+
wrap: false
|
83 |
+
validation:
|
84 |
+
target: edit_dataset.EditDataset
|
85 |
+
params:
|
86 |
+
path: data/clip-filtered-dataset
|
87 |
+
cache_dir: data/
|
88 |
+
cache_name: data_10k
|
89 |
+
split: val
|
90 |
+
min_text_sim: 0.2
|
91 |
+
min_image_sim: 0.75
|
92 |
+
min_direction_sim: 0.2
|
93 |
+
max_samples_per_prompt: 1
|
94 |
+
min_resize_res: 512
|
95 |
+
max_resize_res: 512
|
96 |
+
crop_res: 512
|
97 |
+
output_as_edit: False
|
98 |
+
real_input: True
|
configs/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 10000 ]
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 4
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
configs/v1-inpainting-inference.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 7.5e-05
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: hybrid # important
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
finetune_keys: null
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
embeddings/Place Textual Inversion embeddings here.txt
ADDED
File without changes
|
extensions-builtin/LDSR/__pycache__/ldsr_model_arch.cpython-310.pyc
ADDED
Binary file (6.76 kB). View file
|
|
extensions-builtin/LDSR/__pycache__/preload.cpython-310.pyc
ADDED
Binary file (514 Bytes). View file
|
|
extensions-builtin/LDSR/__pycache__/sd_hijack_autoencoder.cpython-310.pyc
ADDED
Binary file (8.84 kB). View file
|
|
extensions-builtin/LDSR/__pycache__/sd_hijack_ddpm_v1.cpython-310.pyc
ADDED
Binary file (43.1 kB). View file
|
|
extensions-builtin/LDSR/ldsr_model_arch.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import time
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
from PIL import Image
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
import safetensors.torch
|
12 |
+
|
13 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
14 |
+
from ldm.util import instantiate_from_config, ismap
|
15 |
+
from modules import shared, sd_hijack
|
16 |
+
|
17 |
+
cached_ldsr_model: torch.nn.Module = None
|
18 |
+
|
19 |
+
|
20 |
+
# Create LDSR Class
|
21 |
+
class LDSR:
|
22 |
+
def load_model_from_config(self, half_attention):
|
23 |
+
global cached_ldsr_model
|
24 |
+
|
25 |
+
if shared.opts.ldsr_cached and cached_ldsr_model is not None:
|
26 |
+
print("Loading model from cache")
|
27 |
+
model: torch.nn.Module = cached_ldsr_model
|
28 |
+
else:
|
29 |
+
print(f"Loading model from {self.modelPath}")
|
30 |
+
_, extension = os.path.splitext(self.modelPath)
|
31 |
+
if extension.lower() == ".safetensors":
|
32 |
+
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
|
33 |
+
else:
|
34 |
+
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
35 |
+
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
|
36 |
+
config = OmegaConf.load(self.yamlPath)
|
37 |
+
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
|
38 |
+
model: torch.nn.Module = instantiate_from_config(config.model)
|
39 |
+
model.load_state_dict(sd, strict=False)
|
40 |
+
model = model.to(shared.device)
|
41 |
+
if half_attention:
|
42 |
+
model = model.half()
|
43 |
+
if shared.cmd_opts.opt_channelslast:
|
44 |
+
model = model.to(memory_format=torch.channels_last)
|
45 |
+
|
46 |
+
sd_hijack.model_hijack.hijack(model) # apply optimization
|
47 |
+
model.eval()
|
48 |
+
|
49 |
+
if shared.opts.ldsr_cached:
|
50 |
+
cached_ldsr_model = model
|
51 |
+
|
52 |
+
return {"model": model}
|
53 |
+
|
54 |
+
def __init__(self, model_path, yaml_path):
|
55 |
+
self.modelPath = model_path
|
56 |
+
self.yamlPath = yaml_path
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def run(model, selected_path, custom_steps, eta):
|
60 |
+
example = get_cond(selected_path)
|
61 |
+
|
62 |
+
n_runs = 1
|
63 |
+
guider = None
|
64 |
+
ckwargs = None
|
65 |
+
ddim_use_x0_pred = False
|
66 |
+
temperature = 1.
|
67 |
+
eta = eta
|
68 |
+
custom_shape = None
|
69 |
+
|
70 |
+
height, width = example["image"].shape[1:3]
|
71 |
+
split_input = height >= 128 and width >= 128
|
72 |
+
|
73 |
+
if split_input:
|
74 |
+
ks = 128
|
75 |
+
stride = 64
|
76 |
+
vqf = 4 #
|
77 |
+
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
|
78 |
+
"vqf": vqf,
|
79 |
+
"patch_distributed_vq": True,
|
80 |
+
"tie_braker": False,
|
81 |
+
"clip_max_weight": 0.5,
|
82 |
+
"clip_min_weight": 0.01,
|
83 |
+
"clip_max_tie_weight": 0.5,
|
84 |
+
"clip_min_tie_weight": 0.01}
|
85 |
+
else:
|
86 |
+
if hasattr(model, "split_input_params"):
|
87 |
+
delattr(model, "split_input_params")
|
88 |
+
|
89 |
+
x_t = None
|
90 |
+
logs = None
|
91 |
+
for n in range(n_runs):
|
92 |
+
if custom_shape is not None:
|
93 |
+
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
94 |
+
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
95 |
+
|
96 |
+
logs = make_convolutional_sample(example, model,
|
97 |
+
custom_steps=custom_steps,
|
98 |
+
eta=eta, quantize_x0=False,
|
99 |
+
custom_shape=custom_shape,
|
100 |
+
temperature=temperature, noise_dropout=0.,
|
101 |
+
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
|
102 |
+
ddim_use_x0_pred=ddim_use_x0_pred
|
103 |
+
)
|
104 |
+
return logs
|
105 |
+
|
106 |
+
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
|
107 |
+
model = self.load_model_from_config(half_attention)
|
108 |
+
|
109 |
+
# Run settings
|
110 |
+
diffusion_steps = int(steps)
|
111 |
+
eta = 1.0
|
112 |
+
|
113 |
+
down_sample_method = 'Lanczos'
|
114 |
+
|
115 |
+
gc.collect()
|
116 |
+
if torch.cuda.is_available:
|
117 |
+
torch.cuda.empty_cache()
|
118 |
+
|
119 |
+
im_og = image
|
120 |
+
width_og, height_og = im_og.size
|
121 |
+
# If we can adjust the max upscale size, then the 4 below should be our variable
|
122 |
+
down_sample_rate = target_scale / 4
|
123 |
+
wd = width_og * down_sample_rate
|
124 |
+
hd = height_og * down_sample_rate
|
125 |
+
width_downsampled_pre = int(np.ceil(wd))
|
126 |
+
height_downsampled_pre = int(np.ceil(hd))
|
127 |
+
|
128 |
+
if down_sample_rate != 1:
|
129 |
+
print(
|
130 |
+
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
131 |
+
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
132 |
+
else:
|
133 |
+
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
134 |
+
|
135 |
+
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
136 |
+
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
137 |
+
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
138 |
+
|
139 |
+
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
140 |
+
|
141 |
+
sample = logs["sample"]
|
142 |
+
sample = sample.detach().cpu()
|
143 |
+
sample = torch.clamp(sample, -1., 1.)
|
144 |
+
sample = (sample + 1.) / 2. * 255
|
145 |
+
sample = sample.numpy().astype(np.uint8)
|
146 |
+
sample = np.transpose(sample, (0, 2, 3, 1))
|
147 |
+
a = Image.fromarray(sample[0])
|
148 |
+
|
149 |
+
# remove padding
|
150 |
+
a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
|
151 |
+
|
152 |
+
del model
|
153 |
+
gc.collect()
|
154 |
+
if torch.cuda.is_available:
|
155 |
+
torch.cuda.empty_cache()
|
156 |
+
|
157 |
+
return a
|
158 |
+
|
159 |
+
|
160 |
+
def get_cond(selected_path):
|
161 |
+
example = dict()
|
162 |
+
up_f = 4
|
163 |
+
c = selected_path.convert('RGB')
|
164 |
+
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
165 |
+
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
|
166 |
+
antialias=True)
|
167 |
+
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
168 |
+
c = rearrange(c, '1 c h w -> 1 h w c')
|
169 |
+
c = 2. * c - 1.
|
170 |
+
|
171 |
+
c = c.to(shared.device)
|
172 |
+
example["LR_image"] = c
|
173 |
+
example["image"] = c_up
|
174 |
+
|
175 |
+
return example
|
176 |
+
|
177 |
+
|
178 |
+
@torch.no_grad()
|
179 |
+
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
|
180 |
+
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
|
181 |
+
corrector_kwargs=None, x_t=None
|
182 |
+
):
|
183 |
+
ddim = DDIMSampler(model)
|
184 |
+
bs = shape[0]
|
185 |
+
shape = shape[1:]
|
186 |
+
print(f"Sampling with eta = {eta}; steps: {steps}")
|
187 |
+
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
|
188 |
+
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
|
189 |
+
mask=mask, x0=x0, temperature=temperature, verbose=False,
|
190 |
+
score_corrector=score_corrector,
|
191 |
+
corrector_kwargs=corrector_kwargs, x_t=x_t)
|
192 |
+
|
193 |
+
return samples, intermediates
|
194 |
+
|
195 |
+
|
196 |
+
@torch.no_grad()
|
197 |
+
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
198 |
+
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
199 |
+
log = dict()
|
200 |
+
|
201 |
+
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
202 |
+
return_first_stage_outputs=True,
|
203 |
+
force_c_encode=not (hasattr(model, 'split_input_params')
|
204 |
+
and model.cond_stage_key == 'coordinates_bbox'),
|
205 |
+
return_original_cond=True)
|
206 |
+
|
207 |
+
if custom_shape is not None:
|
208 |
+
z = torch.randn(custom_shape)
|
209 |
+
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
210 |
+
|
211 |
+
z0 = None
|
212 |
+
|
213 |
+
log["input"] = x
|
214 |
+
log["reconstruction"] = xrec
|
215 |
+
|
216 |
+
if ismap(xc):
|
217 |
+
log["original_conditioning"] = model.to_rgb(xc)
|
218 |
+
if hasattr(model, 'cond_stage_key'):
|
219 |
+
log[model.cond_stage_key] = model.to_rgb(xc)
|
220 |
+
|
221 |
+
else:
|
222 |
+
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
223 |
+
if model.cond_stage_model:
|
224 |
+
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
225 |
+
if model.cond_stage_key == 'class_label':
|
226 |
+
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
227 |
+
|
228 |
+
with model.ema_scope("Plotting"):
|
229 |
+
t0 = time.time()
|
230 |
+
|
231 |
+
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
|
232 |
+
eta=eta,
|
233 |
+
quantize_x0=quantize_x0, mask=None, x0=z0,
|
234 |
+
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
|
235 |
+
x_t=x_T)
|
236 |
+
t1 = time.time()
|
237 |
+
|
238 |
+
if ddim_use_x0_pred:
|
239 |
+
sample = intermediates['pred_x0'][-1]
|
240 |
+
|
241 |
+
x_sample = model.decode_first_stage(sample)
|
242 |
+
|
243 |
+
try:
|
244 |
+
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
245 |
+
log["sample_noquant"] = x_sample_noquant
|
246 |
+
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
247 |
+
except:
|
248 |
+
pass
|
249 |
+
|
250 |
+
log["sample"] = x_sample
|
251 |
+
log["time"] = t1 - t0
|
252 |
+
|
253 |
+
return log
|
extensions-builtin/LDSR/preload.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules import paths
|
3 |
+
|
4 |
+
|
5 |
+
def preload(parser):
|
6 |
+
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
|
extensions-builtin/LDSR/scripts/__pycache__/ldsr_model.cpython-310.pyc
ADDED
Binary file (2.77 kB). View file
|
|
extensions-builtin/LDSR/scripts/ldsr_model.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
from basicsr.utils.download_util import load_file_from_url
|
6 |
+
|
7 |
+
from modules.upscaler import Upscaler, UpscalerData
|
8 |
+
from ldsr_model_arch import LDSR
|
9 |
+
from modules import shared, script_callbacks
|
10 |
+
import sd_hijack_autoencoder, sd_hijack_ddpm_v1
|
11 |
+
|
12 |
+
|
13 |
+
class UpscalerLDSR(Upscaler):
|
14 |
+
def __init__(self, user_path):
|
15 |
+
self.name = "LDSR"
|
16 |
+
self.user_path = user_path
|
17 |
+
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
18 |
+
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
19 |
+
super().__init__()
|
20 |
+
scaler_data = UpscalerData("LDSR", None, self)
|
21 |
+
self.scalers = [scaler_data]
|
22 |
+
|
23 |
+
def load_model(self, path: str):
|
24 |
+
# Remove incorrect project.yaml file if too big
|
25 |
+
yaml_path = os.path.join(self.model_path, "project.yaml")
|
26 |
+
old_model_path = os.path.join(self.model_path, "model.pth")
|
27 |
+
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
28 |
+
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
|
29 |
+
if os.path.exists(yaml_path):
|
30 |
+
statinfo = os.stat(yaml_path)
|
31 |
+
if statinfo.st_size >= 10485760:
|
32 |
+
print("Removing invalid LDSR YAML file.")
|
33 |
+
os.remove(yaml_path)
|
34 |
+
if os.path.exists(old_model_path):
|
35 |
+
print("Renaming model from model.pth to model.ckpt")
|
36 |
+
os.rename(old_model_path, new_model_path)
|
37 |
+
if os.path.exists(safetensors_model_path):
|
38 |
+
model = safetensors_model_path
|
39 |
+
else:
|
40 |
+
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
41 |
+
file_name="model.ckpt", progress=True)
|
42 |
+
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
43 |
+
file_name="project.yaml", progress=True)
|
44 |
+
|
45 |
+
try:
|
46 |
+
return LDSR(model, yaml)
|
47 |
+
|
48 |
+
except Exception:
|
49 |
+
print("Error importing LDSR:", file=sys.stderr)
|
50 |
+
print(traceback.format_exc(), file=sys.stderr)
|
51 |
+
return None
|
52 |
+
|
53 |
+
def do_upscale(self, img, path):
|
54 |
+
ldsr = self.load_model(path)
|
55 |
+
if ldsr is None:
|
56 |
+
print("NO LDSR!")
|
57 |
+
return img
|
58 |
+
ddim_steps = shared.opts.ldsr_steps
|
59 |
+
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
60 |
+
|
61 |
+
|
62 |
+
def on_ui_settings():
|
63 |
+
import gradio as gr
|
64 |
+
|
65 |
+
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
|
66 |
+
shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
|
67 |
+
|
68 |
+
|
69 |
+
script_callbacks.on_ui_settings(on_ui_settings)
|
extensions-builtin/LDSR/sd_hijack_autoencoder.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
2 |
+
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
3 |
+
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
10 |
+
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
11 |
+
from ldm.util import instantiate_from_config
|
12 |
+
|
13 |
+
import ldm.models.autoencoder
|
14 |
+
|
15 |
+
class VQModel(pl.LightningModule):
|
16 |
+
def __init__(self,
|
17 |
+
ddconfig,
|
18 |
+
lossconfig,
|
19 |
+
n_embed,
|
20 |
+
embed_dim,
|
21 |
+
ckpt_path=None,
|
22 |
+
ignore_keys=[],
|
23 |
+
image_key="image",
|
24 |
+
colorize_nlabels=None,
|
25 |
+
monitor=None,
|
26 |
+
batch_resize_range=None,
|
27 |
+
scheduler_config=None,
|
28 |
+
lr_g_factor=1.0,
|
29 |
+
remap=None,
|
30 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
31 |
+
use_ema=False
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
self.embed_dim = embed_dim
|
35 |
+
self.n_embed = n_embed
|
36 |
+
self.image_key = image_key
|
37 |
+
self.encoder = Encoder(**ddconfig)
|
38 |
+
self.decoder = Decoder(**ddconfig)
|
39 |
+
self.loss = instantiate_from_config(lossconfig)
|
40 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
41 |
+
remap=remap,
|
42 |
+
sane_index_shape=sane_index_shape)
|
43 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
44 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
45 |
+
if colorize_nlabels is not None:
|
46 |
+
assert type(colorize_nlabels)==int
|
47 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
48 |
+
if monitor is not None:
|
49 |
+
self.monitor = monitor
|
50 |
+
self.batch_resize_range = batch_resize_range
|
51 |
+
if self.batch_resize_range is not None:
|
52 |
+
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
53 |
+
|
54 |
+
self.use_ema = use_ema
|
55 |
+
if self.use_ema:
|
56 |
+
self.model_ema = LitEma(self)
|
57 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
58 |
+
|
59 |
+
if ckpt_path is not None:
|
60 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
61 |
+
self.scheduler_config = scheduler_config
|
62 |
+
self.lr_g_factor = lr_g_factor
|
63 |
+
|
64 |
+
@contextmanager
|
65 |
+
def ema_scope(self, context=None):
|
66 |
+
if self.use_ema:
|
67 |
+
self.model_ema.store(self.parameters())
|
68 |
+
self.model_ema.copy_to(self)
|
69 |
+
if context is not None:
|
70 |
+
print(f"{context}: Switched to EMA weights")
|
71 |
+
try:
|
72 |
+
yield None
|
73 |
+
finally:
|
74 |
+
if self.use_ema:
|
75 |
+
self.model_ema.restore(self.parameters())
|
76 |
+
if context is not None:
|
77 |
+
print(f"{context}: Restored training weights")
|
78 |
+
|
79 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
80 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
81 |
+
keys = list(sd.keys())
|
82 |
+
for k in keys:
|
83 |
+
for ik in ignore_keys:
|
84 |
+
if k.startswith(ik):
|
85 |
+
print("Deleting key {} from state_dict.".format(k))
|
86 |
+
del sd[k]
|
87 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
88 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
89 |
+
if len(missing) > 0:
|
90 |
+
print(f"Missing Keys: {missing}")
|
91 |
+
print(f"Unexpected Keys: {unexpected}")
|
92 |
+
|
93 |
+
def on_train_batch_end(self, *args, **kwargs):
|
94 |
+
if self.use_ema:
|
95 |
+
self.model_ema(self)
|
96 |
+
|
97 |
+
def encode(self, x):
|
98 |
+
h = self.encoder(x)
|
99 |
+
h = self.quant_conv(h)
|
100 |
+
quant, emb_loss, info = self.quantize(h)
|
101 |
+
return quant, emb_loss, info
|
102 |
+
|
103 |
+
def encode_to_prequant(self, x):
|
104 |
+
h = self.encoder(x)
|
105 |
+
h = self.quant_conv(h)
|
106 |
+
return h
|
107 |
+
|
108 |
+
def decode(self, quant):
|
109 |
+
quant = self.post_quant_conv(quant)
|
110 |
+
dec = self.decoder(quant)
|
111 |
+
return dec
|
112 |
+
|
113 |
+
def decode_code(self, code_b):
|
114 |
+
quant_b = self.quantize.embed_code(code_b)
|
115 |
+
dec = self.decode(quant_b)
|
116 |
+
return dec
|
117 |
+
|
118 |
+
def forward(self, input, return_pred_indices=False):
|
119 |
+
quant, diff, (_,_,ind) = self.encode(input)
|
120 |
+
dec = self.decode(quant)
|
121 |
+
if return_pred_indices:
|
122 |
+
return dec, diff, ind
|
123 |
+
return dec, diff
|
124 |
+
|
125 |
+
def get_input(self, batch, k):
|
126 |
+
x = batch[k]
|
127 |
+
if len(x.shape) == 3:
|
128 |
+
x = x[..., None]
|
129 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
130 |
+
if self.batch_resize_range is not None:
|
131 |
+
lower_size = self.batch_resize_range[0]
|
132 |
+
upper_size = self.batch_resize_range[1]
|
133 |
+
if self.global_step <= 4:
|
134 |
+
# do the first few batches with max size to avoid later oom
|
135 |
+
new_resize = upper_size
|
136 |
+
else:
|
137 |
+
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
138 |
+
if new_resize != x.shape[2]:
|
139 |
+
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
140 |
+
x = x.detach()
|
141 |
+
return x
|
142 |
+
|
143 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
144 |
+
# https://github.com/pytorch/pytorch/issues/37142
|
145 |
+
# try not to fool the heuristics
|
146 |
+
x = self.get_input(batch, self.image_key)
|
147 |
+
xrec, qloss, ind = self(x, return_pred_indices=True)
|
148 |
+
|
149 |
+
if optimizer_idx == 0:
|
150 |
+
# autoencode
|
151 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
152 |
+
last_layer=self.get_last_layer(), split="train",
|
153 |
+
predicted_indices=ind)
|
154 |
+
|
155 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
156 |
+
return aeloss
|
157 |
+
|
158 |
+
if optimizer_idx == 1:
|
159 |
+
# discriminator
|
160 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
161 |
+
last_layer=self.get_last_layer(), split="train")
|
162 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
163 |
+
return discloss
|
164 |
+
|
165 |
+
def validation_step(self, batch, batch_idx):
|
166 |
+
log_dict = self._validation_step(batch, batch_idx)
|
167 |
+
with self.ema_scope():
|
168 |
+
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
169 |
+
return log_dict
|
170 |
+
|
171 |
+
def _validation_step(self, batch, batch_idx, suffix=""):
|
172 |
+
x = self.get_input(batch, self.image_key)
|
173 |
+
xrec, qloss, ind = self(x, return_pred_indices=True)
|
174 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
175 |
+
self.global_step,
|
176 |
+
last_layer=self.get_last_layer(),
|
177 |
+
split="val"+suffix,
|
178 |
+
predicted_indices=ind
|
179 |
+
)
|
180 |
+
|
181 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
182 |
+
self.global_step,
|
183 |
+
last_layer=self.get_last_layer(),
|
184 |
+
split="val"+suffix,
|
185 |
+
predicted_indices=ind
|
186 |
+
)
|
187 |
+
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
188 |
+
self.log(f"val{suffix}/rec_loss", rec_loss,
|
189 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
190 |
+
self.log(f"val{suffix}/aeloss", aeloss,
|
191 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
192 |
+
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
193 |
+
del log_dict_ae[f"val{suffix}/rec_loss"]
|
194 |
+
self.log_dict(log_dict_ae)
|
195 |
+
self.log_dict(log_dict_disc)
|
196 |
+
return self.log_dict
|
197 |
+
|
198 |
+
def configure_optimizers(self):
|
199 |
+
lr_d = self.learning_rate
|
200 |
+
lr_g = self.lr_g_factor*self.learning_rate
|
201 |
+
print("lr_d", lr_d)
|
202 |
+
print("lr_g", lr_g)
|
203 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
204 |
+
list(self.decoder.parameters())+
|
205 |
+
list(self.quantize.parameters())+
|
206 |
+
list(self.quant_conv.parameters())+
|
207 |
+
list(self.post_quant_conv.parameters()),
|
208 |
+
lr=lr_g, betas=(0.5, 0.9))
|
209 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
210 |
+
lr=lr_d, betas=(0.5, 0.9))
|
211 |
+
|
212 |
+
if self.scheduler_config is not None:
|
213 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
214 |
+
|
215 |
+
print("Setting up LambdaLR scheduler...")
|
216 |
+
scheduler = [
|
217 |
+
{
|
218 |
+
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
219 |
+
'interval': 'step',
|
220 |
+
'frequency': 1
|
221 |
+
},
|
222 |
+
{
|
223 |
+
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
224 |
+
'interval': 'step',
|
225 |
+
'frequency': 1
|
226 |
+
},
|
227 |
+
]
|
228 |
+
return [opt_ae, opt_disc], scheduler
|
229 |
+
return [opt_ae, opt_disc], []
|
230 |
+
|
231 |
+
def get_last_layer(self):
|
232 |
+
return self.decoder.conv_out.weight
|
233 |
+
|
234 |
+
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
235 |
+
log = dict()
|
236 |
+
x = self.get_input(batch, self.image_key)
|
237 |
+
x = x.to(self.device)
|
238 |
+
if only_inputs:
|
239 |
+
log["inputs"] = x
|
240 |
+
return log
|
241 |
+
xrec, _ = self(x)
|
242 |
+
if x.shape[1] > 3:
|
243 |
+
# colorize with random projection
|
244 |
+
assert xrec.shape[1] > 3
|
245 |
+
x = self.to_rgb(x)
|
246 |
+
xrec = self.to_rgb(xrec)
|
247 |
+
log["inputs"] = x
|
248 |
+
log["reconstructions"] = xrec
|
249 |
+
if plot_ema:
|
250 |
+
with self.ema_scope():
|
251 |
+
xrec_ema, _ = self(x)
|
252 |
+
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
253 |
+
log["reconstructions_ema"] = xrec_ema
|
254 |
+
return log
|
255 |
+
|
256 |
+
def to_rgb(self, x):
|
257 |
+
assert self.image_key == "segmentation"
|
258 |
+
if not hasattr(self, "colorize"):
|
259 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
260 |
+
x = F.conv2d(x, weight=self.colorize)
|
261 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
262 |
+
return x
|
263 |
+
|
264 |
+
|
265 |
+
class VQModelInterface(VQModel):
|
266 |
+
def __init__(self, embed_dim, *args, **kwargs):
|
267 |
+
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
268 |
+
self.embed_dim = embed_dim
|
269 |
+
|
270 |
+
def encode(self, x):
|
271 |
+
h = self.encoder(x)
|
272 |
+
h = self.quant_conv(h)
|
273 |
+
return h
|
274 |
+
|
275 |
+
def decode(self, h, force_not_quantize=False):
|
276 |
+
# also go through quantization layer
|
277 |
+
if not force_not_quantize:
|
278 |
+
quant, emb_loss, info = self.quantize(h)
|
279 |
+
else:
|
280 |
+
quant = h
|
281 |
+
quant = self.post_quant_conv(quant)
|
282 |
+
dec = self.decoder(quant)
|
283 |
+
return dec
|
284 |
+
|
285 |
+
setattr(ldm.models.autoencoder, "VQModel", VQModel)
|
286 |
+
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
|
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
ADDED
@@ -0,0 +1,1449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo)
|
2 |
+
# Original filename: ldm/models/diffusion/ddpm.py
|
3 |
+
# The purpose to reinstate the old DDPM logic which works with VQ, whereas the V2 one doesn't
|
4 |
+
# Some models such as LDSR require VQ to work correctly
|
5 |
+
# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import numpy as np
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from torch.optim.lr_scheduler import LambdaLR
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
from contextlib import contextmanager
|
14 |
+
from functools import partial
|
15 |
+
from tqdm import tqdm
|
16 |
+
from torchvision.utils import make_grid
|
17 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
18 |
+
|
19 |
+
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
20 |
+
from ldm.modules.ema import LitEma
|
21 |
+
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
22 |
+
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
|
23 |
+
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
24 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
25 |
+
|
26 |
+
import ldm.models.diffusion.ddpm
|
27 |
+
|
28 |
+
__conditioning_keys__ = {'concat': 'c_concat',
|
29 |
+
'crossattn': 'c_crossattn',
|
30 |
+
'adm': 'y'}
|
31 |
+
|
32 |
+
|
33 |
+
def disabled_train(self, mode=True):
|
34 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
35 |
+
does not change anymore."""
|
36 |
+
return self
|
37 |
+
|
38 |
+
|
39 |
+
def uniform_on_device(r1, r2, shape, device):
|
40 |
+
return (r1 - r2) * torch.rand(*shape, device=device) + r2
|
41 |
+
|
42 |
+
|
43 |
+
class DDPMV1(pl.LightningModule):
|
44 |
+
# classic DDPM with Gaussian diffusion, in image space
|
45 |
+
def __init__(self,
|
46 |
+
unet_config,
|
47 |
+
timesteps=1000,
|
48 |
+
beta_schedule="linear",
|
49 |
+
loss_type="l2",
|
50 |
+
ckpt_path=None,
|
51 |
+
ignore_keys=[],
|
52 |
+
load_only_unet=False,
|
53 |
+
monitor="val/loss",
|
54 |
+
use_ema=True,
|
55 |
+
first_stage_key="image",
|
56 |
+
image_size=256,
|
57 |
+
channels=3,
|
58 |
+
log_every_t=100,
|
59 |
+
clip_denoised=True,
|
60 |
+
linear_start=1e-4,
|
61 |
+
linear_end=2e-2,
|
62 |
+
cosine_s=8e-3,
|
63 |
+
given_betas=None,
|
64 |
+
original_elbo_weight=0.,
|
65 |
+
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
66 |
+
l_simple_weight=1.,
|
67 |
+
conditioning_key=None,
|
68 |
+
parameterization="eps", # all assuming fixed variance schedules
|
69 |
+
scheduler_config=None,
|
70 |
+
use_positional_encodings=False,
|
71 |
+
learn_logvar=False,
|
72 |
+
logvar_init=0.,
|
73 |
+
):
|
74 |
+
super().__init__()
|
75 |
+
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
|
76 |
+
self.parameterization = parameterization
|
77 |
+
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
|
78 |
+
self.cond_stage_model = None
|
79 |
+
self.clip_denoised = clip_denoised
|
80 |
+
self.log_every_t = log_every_t
|
81 |
+
self.first_stage_key = first_stage_key
|
82 |
+
self.image_size = image_size # try conv?
|
83 |
+
self.channels = channels
|
84 |
+
self.use_positional_encodings = use_positional_encodings
|
85 |
+
self.model = DiffusionWrapperV1(unet_config, conditioning_key)
|
86 |
+
count_params(self.model, verbose=True)
|
87 |
+
self.use_ema = use_ema
|
88 |
+
if self.use_ema:
|
89 |
+
self.model_ema = LitEma(self.model)
|
90 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
91 |
+
|
92 |
+
self.use_scheduler = scheduler_config is not None
|
93 |
+
if self.use_scheduler:
|
94 |
+
self.scheduler_config = scheduler_config
|
95 |
+
|
96 |
+
self.v_posterior = v_posterior
|
97 |
+
self.original_elbo_weight = original_elbo_weight
|
98 |
+
self.l_simple_weight = l_simple_weight
|
99 |
+
|
100 |
+
if monitor is not None:
|
101 |
+
self.monitor = monitor
|
102 |
+
if ckpt_path is not None:
|
103 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
104 |
+
|
105 |
+
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
106 |
+
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
107 |
+
|
108 |
+
self.loss_type = loss_type
|
109 |
+
|
110 |
+
self.learn_logvar = learn_logvar
|
111 |
+
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
|
112 |
+
if self.learn_logvar:
|
113 |
+
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
114 |
+
|
115 |
+
|
116 |
+
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
117 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
118 |
+
if exists(given_betas):
|
119 |
+
betas = given_betas
|
120 |
+
else:
|
121 |
+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
122 |
+
cosine_s=cosine_s)
|
123 |
+
alphas = 1. - betas
|
124 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
125 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
126 |
+
|
127 |
+
timesteps, = betas.shape
|
128 |
+
self.num_timesteps = int(timesteps)
|
129 |
+
self.linear_start = linear_start
|
130 |
+
self.linear_end = linear_end
|
131 |
+
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
132 |
+
|
133 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
134 |
+
|
135 |
+
self.register_buffer('betas', to_torch(betas))
|
136 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
137 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
138 |
+
|
139 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
140 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
141 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
142 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
143 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
144 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
145 |
+
|
146 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
147 |
+
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
148 |
+
1. - alphas_cumprod) + self.v_posterior * betas
|
149 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
150 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
151 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
152 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
153 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
154 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
155 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
156 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
157 |
+
|
158 |
+
if self.parameterization == "eps":
|
159 |
+
lvlb_weights = self.betas ** 2 / (
|
160 |
+
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
161 |
+
elif self.parameterization == "x0":
|
162 |
+
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
163 |
+
else:
|
164 |
+
raise NotImplementedError("mu not supported")
|
165 |
+
# TODO how to choose this term
|
166 |
+
lvlb_weights[0] = lvlb_weights[1]
|
167 |
+
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
168 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
169 |
+
|
170 |
+
@contextmanager
|
171 |
+
def ema_scope(self, context=None):
|
172 |
+
if self.use_ema:
|
173 |
+
self.model_ema.store(self.model.parameters())
|
174 |
+
self.model_ema.copy_to(self.model)
|
175 |
+
if context is not None:
|
176 |
+
print(f"{context}: Switched to EMA weights")
|
177 |
+
try:
|
178 |
+
yield None
|
179 |
+
finally:
|
180 |
+
if self.use_ema:
|
181 |
+
self.model_ema.restore(self.model.parameters())
|
182 |
+
if context is not None:
|
183 |
+
print(f"{context}: Restored training weights")
|
184 |
+
|
185 |
+
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
186 |
+
sd = torch.load(path, map_location="cpu")
|
187 |
+
if "state_dict" in list(sd.keys()):
|
188 |
+
sd = sd["state_dict"]
|
189 |
+
keys = list(sd.keys())
|
190 |
+
for k in keys:
|
191 |
+
for ik in ignore_keys:
|
192 |
+
if k.startswith(ik):
|
193 |
+
print("Deleting key {} from state_dict.".format(k))
|
194 |
+
del sd[k]
|
195 |
+
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
196 |
+
sd, strict=False)
|
197 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
198 |
+
if len(missing) > 0:
|
199 |
+
print(f"Missing Keys: {missing}")
|
200 |
+
if len(unexpected) > 0:
|
201 |
+
print(f"Unexpected Keys: {unexpected}")
|
202 |
+
|
203 |
+
def q_mean_variance(self, x_start, t):
|
204 |
+
"""
|
205 |
+
Get the distribution q(x_t | x_0).
|
206 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
207 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
208 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
209 |
+
"""
|
210 |
+
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
|
211 |
+
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
212 |
+
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
213 |
+
return mean, variance, log_variance
|
214 |
+
|
215 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
216 |
+
return (
|
217 |
+
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
218 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
219 |
+
)
|
220 |
+
|
221 |
+
def q_posterior(self, x_start, x_t, t):
|
222 |
+
posterior_mean = (
|
223 |
+
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
224 |
+
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
225 |
+
)
|
226 |
+
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
227 |
+
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
228 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
229 |
+
|
230 |
+
def p_mean_variance(self, x, t, clip_denoised: bool):
|
231 |
+
model_out = self.model(x, t)
|
232 |
+
if self.parameterization == "eps":
|
233 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
234 |
+
elif self.parameterization == "x0":
|
235 |
+
x_recon = model_out
|
236 |
+
if clip_denoised:
|
237 |
+
x_recon.clamp_(-1., 1.)
|
238 |
+
|
239 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
240 |
+
return model_mean, posterior_variance, posterior_log_variance
|
241 |
+
|
242 |
+
@torch.no_grad()
|
243 |
+
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
|
244 |
+
b, *_, device = *x.shape, x.device
|
245 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
|
246 |
+
noise = noise_like(x.shape, device, repeat_noise)
|
247 |
+
# no noise when t == 0
|
248 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
249 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
250 |
+
|
251 |
+
@torch.no_grad()
|
252 |
+
def p_sample_loop(self, shape, return_intermediates=False):
|
253 |
+
device = self.betas.device
|
254 |
+
b = shape[0]
|
255 |
+
img = torch.randn(shape, device=device)
|
256 |
+
intermediates = [img]
|
257 |
+
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
|
258 |
+
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
|
259 |
+
clip_denoised=self.clip_denoised)
|
260 |
+
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
|
261 |
+
intermediates.append(img)
|
262 |
+
if return_intermediates:
|
263 |
+
return img, intermediates
|
264 |
+
return img
|
265 |
+
|
266 |
+
@torch.no_grad()
|
267 |
+
def sample(self, batch_size=16, return_intermediates=False):
|
268 |
+
image_size = self.image_size
|
269 |
+
channels = self.channels
|
270 |
+
return self.p_sample_loop((batch_size, channels, image_size, image_size),
|
271 |
+
return_intermediates=return_intermediates)
|
272 |
+
|
273 |
+
def q_sample(self, x_start, t, noise=None):
|
274 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
275 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
276 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
277 |
+
|
278 |
+
def get_loss(self, pred, target, mean=True):
|
279 |
+
if self.loss_type == 'l1':
|
280 |
+
loss = (target - pred).abs()
|
281 |
+
if mean:
|
282 |
+
loss = loss.mean()
|
283 |
+
elif self.loss_type == 'l2':
|
284 |
+
if mean:
|
285 |
+
loss = torch.nn.functional.mse_loss(target, pred)
|
286 |
+
else:
|
287 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
288 |
+
else:
|
289 |
+
raise NotImplementedError("unknown loss type '{loss_type}'")
|
290 |
+
|
291 |
+
return loss
|
292 |
+
|
293 |
+
def p_losses(self, x_start, t, noise=None):
|
294 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
295 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
296 |
+
model_out = self.model(x_noisy, t)
|
297 |
+
|
298 |
+
loss_dict = {}
|
299 |
+
if self.parameterization == "eps":
|
300 |
+
target = noise
|
301 |
+
elif self.parameterization == "x0":
|
302 |
+
target = x_start
|
303 |
+
else:
|
304 |
+
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
|
305 |
+
|
306 |
+
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
307 |
+
|
308 |
+
log_prefix = 'train' if self.training else 'val'
|
309 |
+
|
310 |
+
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
|
311 |
+
loss_simple = loss.mean() * self.l_simple_weight
|
312 |
+
|
313 |
+
loss_vlb = (self.lvlb_weights[t] * loss).mean()
|
314 |
+
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
|
315 |
+
|
316 |
+
loss = loss_simple + self.original_elbo_weight * loss_vlb
|
317 |
+
|
318 |
+
loss_dict.update({f'{log_prefix}/loss': loss})
|
319 |
+
|
320 |
+
return loss, loss_dict
|
321 |
+
|
322 |
+
def forward(self, x, *args, **kwargs):
|
323 |
+
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
|
324 |
+
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
|
325 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
326 |
+
return self.p_losses(x, t, *args, **kwargs)
|
327 |
+
|
328 |
+
def get_input(self, batch, k):
|
329 |
+
x = batch[k]
|
330 |
+
if len(x.shape) == 3:
|
331 |
+
x = x[..., None]
|
332 |
+
x = rearrange(x, 'b h w c -> b c h w')
|
333 |
+
x = x.to(memory_format=torch.contiguous_format).float()
|
334 |
+
return x
|
335 |
+
|
336 |
+
def shared_step(self, batch):
|
337 |
+
x = self.get_input(batch, self.first_stage_key)
|
338 |
+
loss, loss_dict = self(x)
|
339 |
+
return loss, loss_dict
|
340 |
+
|
341 |
+
def training_step(self, batch, batch_idx):
|
342 |
+
loss, loss_dict = self.shared_step(batch)
|
343 |
+
|
344 |
+
self.log_dict(loss_dict, prog_bar=True,
|
345 |
+
logger=True, on_step=True, on_epoch=True)
|
346 |
+
|
347 |
+
self.log("global_step", self.global_step,
|
348 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
349 |
+
|
350 |
+
if self.use_scheduler:
|
351 |
+
lr = self.optimizers().param_groups[0]['lr']
|
352 |
+
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
353 |
+
|
354 |
+
return loss
|
355 |
+
|
356 |
+
@torch.no_grad()
|
357 |
+
def validation_step(self, batch, batch_idx):
|
358 |
+
_, loss_dict_no_ema = self.shared_step(batch)
|
359 |
+
with self.ema_scope():
|
360 |
+
_, loss_dict_ema = self.shared_step(batch)
|
361 |
+
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
|
362 |
+
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
363 |
+
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
364 |
+
|
365 |
+
def on_train_batch_end(self, *args, **kwargs):
|
366 |
+
if self.use_ema:
|
367 |
+
self.model_ema(self.model)
|
368 |
+
|
369 |
+
def _get_rows_from_list(self, samples):
|
370 |
+
n_imgs_per_row = len(samples)
|
371 |
+
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
|
372 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
373 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
374 |
+
return denoise_grid
|
375 |
+
|
376 |
+
@torch.no_grad()
|
377 |
+
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
378 |
+
log = dict()
|
379 |
+
x = self.get_input(batch, self.first_stage_key)
|
380 |
+
N = min(x.shape[0], N)
|
381 |
+
n_row = min(x.shape[0], n_row)
|
382 |
+
x = x.to(self.device)[:N]
|
383 |
+
log["inputs"] = x
|
384 |
+
|
385 |
+
# get diffusion row
|
386 |
+
diffusion_row = list()
|
387 |
+
x_start = x[:n_row]
|
388 |
+
|
389 |
+
for t in range(self.num_timesteps):
|
390 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
391 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
392 |
+
t = t.to(self.device).long()
|
393 |
+
noise = torch.randn_like(x_start)
|
394 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
395 |
+
diffusion_row.append(x_noisy)
|
396 |
+
|
397 |
+
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
|
398 |
+
|
399 |
+
if sample:
|
400 |
+
# get denoise row
|
401 |
+
with self.ema_scope("Plotting"):
|
402 |
+
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
|
403 |
+
|
404 |
+
log["samples"] = samples
|
405 |
+
log["denoise_row"] = self._get_rows_from_list(denoise_row)
|
406 |
+
|
407 |
+
if return_keys:
|
408 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
409 |
+
return log
|
410 |
+
else:
|
411 |
+
return {key: log[key] for key in return_keys}
|
412 |
+
return log
|
413 |
+
|
414 |
+
def configure_optimizers(self):
|
415 |
+
lr = self.learning_rate
|
416 |
+
params = list(self.model.parameters())
|
417 |
+
if self.learn_logvar:
|
418 |
+
params = params + [self.logvar]
|
419 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
420 |
+
return opt
|
421 |
+
|
422 |
+
|
423 |
+
class LatentDiffusionV1(DDPMV1):
|
424 |
+
"""main class"""
|
425 |
+
def __init__(self,
|
426 |
+
first_stage_config,
|
427 |
+
cond_stage_config,
|
428 |
+
num_timesteps_cond=None,
|
429 |
+
cond_stage_key="image",
|
430 |
+
cond_stage_trainable=False,
|
431 |
+
concat_mode=True,
|
432 |
+
cond_stage_forward=None,
|
433 |
+
conditioning_key=None,
|
434 |
+
scale_factor=1.0,
|
435 |
+
scale_by_std=False,
|
436 |
+
*args, **kwargs):
|
437 |
+
self.num_timesteps_cond = default(num_timesteps_cond, 1)
|
438 |
+
self.scale_by_std = scale_by_std
|
439 |
+
assert self.num_timesteps_cond <= kwargs['timesteps']
|
440 |
+
# for backwards compatibility after implementation of DiffusionWrapper
|
441 |
+
if conditioning_key is None:
|
442 |
+
conditioning_key = 'concat' if concat_mode else 'crossattn'
|
443 |
+
if cond_stage_config == '__is_unconditional__':
|
444 |
+
conditioning_key = None
|
445 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
446 |
+
ignore_keys = kwargs.pop("ignore_keys", [])
|
447 |
+
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
|
448 |
+
self.concat_mode = concat_mode
|
449 |
+
self.cond_stage_trainable = cond_stage_trainable
|
450 |
+
self.cond_stage_key = cond_stage_key
|
451 |
+
try:
|
452 |
+
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
453 |
+
except:
|
454 |
+
self.num_downs = 0
|
455 |
+
if not scale_by_std:
|
456 |
+
self.scale_factor = scale_factor
|
457 |
+
else:
|
458 |
+
self.register_buffer('scale_factor', torch.tensor(scale_factor))
|
459 |
+
self.instantiate_first_stage(first_stage_config)
|
460 |
+
self.instantiate_cond_stage(cond_stage_config)
|
461 |
+
self.cond_stage_forward = cond_stage_forward
|
462 |
+
self.clip_denoised = False
|
463 |
+
self.bbox_tokenizer = None
|
464 |
+
|
465 |
+
self.restarted_from_ckpt = False
|
466 |
+
if ckpt_path is not None:
|
467 |
+
self.init_from_ckpt(ckpt_path, ignore_keys)
|
468 |
+
self.restarted_from_ckpt = True
|
469 |
+
|
470 |
+
def make_cond_schedule(self, ):
|
471 |
+
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
472 |
+
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
473 |
+
self.cond_ids[:self.num_timesteps_cond] = ids
|
474 |
+
|
475 |
+
@rank_zero_only
|
476 |
+
@torch.no_grad()
|
477 |
+
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
478 |
+
# only for very first batch
|
479 |
+
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
|
480 |
+
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
|
481 |
+
# set rescale weight to 1./std of encodings
|
482 |
+
print("### USING STD-RESCALING ###")
|
483 |
+
x = super().get_input(batch, self.first_stage_key)
|
484 |
+
x = x.to(self.device)
|
485 |
+
encoder_posterior = self.encode_first_stage(x)
|
486 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
487 |
+
del self.scale_factor
|
488 |
+
self.register_buffer('scale_factor', 1. / z.flatten().std())
|
489 |
+
print(f"setting self.scale_factor to {self.scale_factor}")
|
490 |
+
print("### USING STD-RESCALING ###")
|
491 |
+
|
492 |
+
def register_schedule(self,
|
493 |
+
given_betas=None, beta_schedule="linear", timesteps=1000,
|
494 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
495 |
+
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
|
496 |
+
|
497 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
498 |
+
if self.shorten_cond_schedule:
|
499 |
+
self.make_cond_schedule()
|
500 |
+
|
501 |
+
def instantiate_first_stage(self, config):
|
502 |
+
model = instantiate_from_config(config)
|
503 |
+
self.first_stage_model = model.eval()
|
504 |
+
self.first_stage_model.train = disabled_train
|
505 |
+
for param in self.first_stage_model.parameters():
|
506 |
+
param.requires_grad = False
|
507 |
+
|
508 |
+
def instantiate_cond_stage(self, config):
|
509 |
+
if not self.cond_stage_trainable:
|
510 |
+
if config == "__is_first_stage__":
|
511 |
+
print("Using first stage also as cond stage.")
|
512 |
+
self.cond_stage_model = self.first_stage_model
|
513 |
+
elif config == "__is_unconditional__":
|
514 |
+
print(f"Training {self.__class__.__name__} as an unconditional model.")
|
515 |
+
self.cond_stage_model = None
|
516 |
+
# self.be_unconditional = True
|
517 |
+
else:
|
518 |
+
model = instantiate_from_config(config)
|
519 |
+
self.cond_stage_model = model.eval()
|
520 |
+
self.cond_stage_model.train = disabled_train
|
521 |
+
for param in self.cond_stage_model.parameters():
|
522 |
+
param.requires_grad = False
|
523 |
+
else:
|
524 |
+
assert config != '__is_first_stage__'
|
525 |
+
assert config != '__is_unconditional__'
|
526 |
+
model = instantiate_from_config(config)
|
527 |
+
self.cond_stage_model = model
|
528 |
+
|
529 |
+
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
|
530 |
+
denoise_row = []
|
531 |
+
for zd in tqdm(samples, desc=desc):
|
532 |
+
denoise_row.append(self.decode_first_stage(zd.to(self.device),
|
533 |
+
force_not_quantize=force_no_decoder_quantization))
|
534 |
+
n_imgs_per_row = len(denoise_row)
|
535 |
+
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
|
536 |
+
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
|
537 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
538 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
539 |
+
return denoise_grid
|
540 |
+
|
541 |
+
def get_first_stage_encoding(self, encoder_posterior):
|
542 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
543 |
+
z = encoder_posterior.sample()
|
544 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
545 |
+
z = encoder_posterior
|
546 |
+
else:
|
547 |
+
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
|
548 |
+
return self.scale_factor * z
|
549 |
+
|
550 |
+
def get_learned_conditioning(self, c):
|
551 |
+
if self.cond_stage_forward is None:
|
552 |
+
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
|
553 |
+
c = self.cond_stage_model.encode(c)
|
554 |
+
if isinstance(c, DiagonalGaussianDistribution):
|
555 |
+
c = c.mode()
|
556 |
+
else:
|
557 |
+
c = self.cond_stage_model(c)
|
558 |
+
else:
|
559 |
+
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
560 |
+
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
561 |
+
return c
|
562 |
+
|
563 |
+
def meshgrid(self, h, w):
|
564 |
+
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
|
565 |
+
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
|
566 |
+
|
567 |
+
arr = torch.cat([y, x], dim=-1)
|
568 |
+
return arr
|
569 |
+
|
570 |
+
def delta_border(self, h, w):
|
571 |
+
"""
|
572 |
+
:param h: height
|
573 |
+
:param w: width
|
574 |
+
:return: normalized distance to image border,
|
575 |
+
wtith min distance = 0 at border and max dist = 0.5 at image center
|
576 |
+
"""
|
577 |
+
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
|
578 |
+
arr = self.meshgrid(h, w) / lower_right_corner
|
579 |
+
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
|
580 |
+
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
|
581 |
+
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
|
582 |
+
return edge_dist
|
583 |
+
|
584 |
+
def get_weighting(self, h, w, Ly, Lx, device):
|
585 |
+
weighting = self.delta_border(h, w)
|
586 |
+
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
|
587 |
+
self.split_input_params["clip_max_weight"], )
|
588 |
+
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
|
589 |
+
|
590 |
+
if self.split_input_params["tie_braker"]:
|
591 |
+
L_weighting = self.delta_border(Ly, Lx)
|
592 |
+
L_weighting = torch.clip(L_weighting,
|
593 |
+
self.split_input_params["clip_min_tie_weight"],
|
594 |
+
self.split_input_params["clip_max_tie_weight"])
|
595 |
+
|
596 |
+
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
|
597 |
+
weighting = weighting * L_weighting
|
598 |
+
return weighting
|
599 |
+
|
600 |
+
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
|
601 |
+
"""
|
602 |
+
:param x: img of size (bs, c, h, w)
|
603 |
+
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
|
604 |
+
"""
|
605 |
+
bs, nc, h, w = x.shape
|
606 |
+
|
607 |
+
# number of crops in image
|
608 |
+
Ly = (h - kernel_size[0]) // stride[0] + 1
|
609 |
+
Lx = (w - kernel_size[1]) // stride[1] + 1
|
610 |
+
|
611 |
+
if uf == 1 and df == 1:
|
612 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
613 |
+
unfold = torch.nn.Unfold(**fold_params)
|
614 |
+
|
615 |
+
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
|
616 |
+
|
617 |
+
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
|
618 |
+
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
|
619 |
+
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
|
620 |
+
|
621 |
+
elif uf > 1 and df == 1:
|
622 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
623 |
+
unfold = torch.nn.Unfold(**fold_params)
|
624 |
+
|
625 |
+
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
|
626 |
+
dilation=1, padding=0,
|
627 |
+
stride=(stride[0] * uf, stride[1] * uf))
|
628 |
+
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
|
629 |
+
|
630 |
+
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
|
631 |
+
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
|
632 |
+
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
|
633 |
+
|
634 |
+
elif df > 1 and uf == 1:
|
635 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
636 |
+
unfold = torch.nn.Unfold(**fold_params)
|
637 |
+
|
638 |
+
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
|
639 |
+
dilation=1, padding=0,
|
640 |
+
stride=(stride[0] // df, stride[1] // df))
|
641 |
+
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
|
642 |
+
|
643 |
+
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
|
644 |
+
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
|
645 |
+
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
|
646 |
+
|
647 |
+
else:
|
648 |
+
raise NotImplementedError
|
649 |
+
|
650 |
+
return fold, unfold, normalization, weighting
|
651 |
+
|
652 |
+
@torch.no_grad()
|
653 |
+
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
|
654 |
+
cond_key=None, return_original_cond=False, bs=None):
|
655 |
+
x = super().get_input(batch, k)
|
656 |
+
if bs is not None:
|
657 |
+
x = x[:bs]
|
658 |
+
x = x.to(self.device)
|
659 |
+
encoder_posterior = self.encode_first_stage(x)
|
660 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
661 |
+
|
662 |
+
if self.model.conditioning_key is not None:
|
663 |
+
if cond_key is None:
|
664 |
+
cond_key = self.cond_stage_key
|
665 |
+
if cond_key != self.first_stage_key:
|
666 |
+
if cond_key in ['caption', 'coordinates_bbox']:
|
667 |
+
xc = batch[cond_key]
|
668 |
+
elif cond_key == 'class_label':
|
669 |
+
xc = batch
|
670 |
+
else:
|
671 |
+
xc = super().get_input(batch, cond_key).to(self.device)
|
672 |
+
else:
|
673 |
+
xc = x
|
674 |
+
if not self.cond_stage_trainable or force_c_encode:
|
675 |
+
if isinstance(xc, dict) or isinstance(xc, list):
|
676 |
+
# import pudb; pudb.set_trace()
|
677 |
+
c = self.get_learned_conditioning(xc)
|
678 |
+
else:
|
679 |
+
c = self.get_learned_conditioning(xc.to(self.device))
|
680 |
+
else:
|
681 |
+
c = xc
|
682 |
+
if bs is not None:
|
683 |
+
c = c[:bs]
|
684 |
+
|
685 |
+
if self.use_positional_encodings:
|
686 |
+
pos_x, pos_y = self.compute_latent_shifts(batch)
|
687 |
+
ckey = __conditioning_keys__[self.model.conditioning_key]
|
688 |
+
c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
|
689 |
+
|
690 |
+
else:
|
691 |
+
c = None
|
692 |
+
xc = None
|
693 |
+
if self.use_positional_encodings:
|
694 |
+
pos_x, pos_y = self.compute_latent_shifts(batch)
|
695 |
+
c = {'pos_x': pos_x, 'pos_y': pos_y}
|
696 |
+
out = [z, c]
|
697 |
+
if return_first_stage_outputs:
|
698 |
+
xrec = self.decode_first_stage(z)
|
699 |
+
out.extend([x, xrec])
|
700 |
+
if return_original_cond:
|
701 |
+
out.append(xc)
|
702 |
+
return out
|
703 |
+
|
704 |
+
@torch.no_grad()
|
705 |
+
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
706 |
+
if predict_cids:
|
707 |
+
if z.dim() == 4:
|
708 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
709 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
710 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
711 |
+
|
712 |
+
z = 1. / self.scale_factor * z
|
713 |
+
|
714 |
+
if hasattr(self, "split_input_params"):
|
715 |
+
if self.split_input_params["patch_distributed_vq"]:
|
716 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
717 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
718 |
+
uf = self.split_input_params["vqf"]
|
719 |
+
bs, nc, h, w = z.shape
|
720 |
+
if ks[0] > h or ks[1] > w:
|
721 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
722 |
+
print("reducing Kernel")
|
723 |
+
|
724 |
+
if stride[0] > h or stride[1] > w:
|
725 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
726 |
+
print("reducing stride")
|
727 |
+
|
728 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
729 |
+
|
730 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
731 |
+
# 1. Reshape to img shape
|
732 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
733 |
+
|
734 |
+
# 2. apply model loop over last dim
|
735 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
736 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
737 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
738 |
+
for i in range(z.shape[-1])]
|
739 |
+
else:
|
740 |
+
|
741 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
742 |
+
for i in range(z.shape[-1])]
|
743 |
+
|
744 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
745 |
+
o = o * weighting
|
746 |
+
# Reverse 1. reshape to img shape
|
747 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
748 |
+
# stitch crops together
|
749 |
+
decoded = fold(o)
|
750 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
751 |
+
return decoded
|
752 |
+
else:
|
753 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
754 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
755 |
+
else:
|
756 |
+
return self.first_stage_model.decode(z)
|
757 |
+
|
758 |
+
else:
|
759 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
760 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
761 |
+
else:
|
762 |
+
return self.first_stage_model.decode(z)
|
763 |
+
|
764 |
+
# same as above but without decorator
|
765 |
+
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
766 |
+
if predict_cids:
|
767 |
+
if z.dim() == 4:
|
768 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
769 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
770 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
771 |
+
|
772 |
+
z = 1. / self.scale_factor * z
|
773 |
+
|
774 |
+
if hasattr(self, "split_input_params"):
|
775 |
+
if self.split_input_params["patch_distributed_vq"]:
|
776 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
777 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
778 |
+
uf = self.split_input_params["vqf"]
|
779 |
+
bs, nc, h, w = z.shape
|
780 |
+
if ks[0] > h or ks[1] > w:
|
781 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
782 |
+
print("reducing Kernel")
|
783 |
+
|
784 |
+
if stride[0] > h or stride[1] > w:
|
785 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
786 |
+
print("reducing stride")
|
787 |
+
|
788 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
789 |
+
|
790 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
791 |
+
# 1. Reshape to img shape
|
792 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
793 |
+
|
794 |
+
# 2. apply model loop over last dim
|
795 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
796 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
797 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
798 |
+
for i in range(z.shape[-1])]
|
799 |
+
else:
|
800 |
+
|
801 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
802 |
+
for i in range(z.shape[-1])]
|
803 |
+
|
804 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
805 |
+
o = o * weighting
|
806 |
+
# Reverse 1. reshape to img shape
|
807 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
808 |
+
# stitch crops together
|
809 |
+
decoded = fold(o)
|
810 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
811 |
+
return decoded
|
812 |
+
else:
|
813 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
814 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
815 |
+
else:
|
816 |
+
return self.first_stage_model.decode(z)
|
817 |
+
|
818 |
+
else:
|
819 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
820 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
821 |
+
else:
|
822 |
+
return self.first_stage_model.decode(z)
|
823 |
+
|
824 |
+
@torch.no_grad()
|
825 |
+
def encode_first_stage(self, x):
|
826 |
+
if hasattr(self, "split_input_params"):
|
827 |
+
if self.split_input_params["patch_distributed_vq"]:
|
828 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
829 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
830 |
+
df = self.split_input_params["vqf"]
|
831 |
+
self.split_input_params['original_image_size'] = x.shape[-2:]
|
832 |
+
bs, nc, h, w = x.shape
|
833 |
+
if ks[0] > h or ks[1] > w:
|
834 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
835 |
+
print("reducing Kernel")
|
836 |
+
|
837 |
+
if stride[0] > h or stride[1] > w:
|
838 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
839 |
+
print("reducing stride")
|
840 |
+
|
841 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
|
842 |
+
z = unfold(x) # (bn, nc * prod(**ks), L)
|
843 |
+
# Reshape to img shape
|
844 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
845 |
+
|
846 |
+
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
|
847 |
+
for i in range(z.shape[-1])]
|
848 |
+
|
849 |
+
o = torch.stack(output_list, axis=-1)
|
850 |
+
o = o * weighting
|
851 |
+
|
852 |
+
# Reverse reshape to img shape
|
853 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
854 |
+
# stitch crops together
|
855 |
+
decoded = fold(o)
|
856 |
+
decoded = decoded / normalization
|
857 |
+
return decoded
|
858 |
+
|
859 |
+
else:
|
860 |
+
return self.first_stage_model.encode(x)
|
861 |
+
else:
|
862 |
+
return self.first_stage_model.encode(x)
|
863 |
+
|
864 |
+
def shared_step(self, batch, **kwargs):
|
865 |
+
x, c = self.get_input(batch, self.first_stage_key)
|
866 |
+
loss = self(x, c)
|
867 |
+
return loss
|
868 |
+
|
869 |
+
def forward(self, x, c, *args, **kwargs):
|
870 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
871 |
+
if self.model.conditioning_key is not None:
|
872 |
+
assert c is not None
|
873 |
+
if self.cond_stage_trainable:
|
874 |
+
c = self.get_learned_conditioning(c)
|
875 |
+
if self.shorten_cond_schedule: # TODO: drop this option
|
876 |
+
tc = self.cond_ids[t].to(self.device)
|
877 |
+
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
878 |
+
return self.p_losses(x, c, t, *args, **kwargs)
|
879 |
+
|
880 |
+
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
881 |
+
def rescale_bbox(bbox):
|
882 |
+
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
883 |
+
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
884 |
+
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
885 |
+
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
886 |
+
return x0, y0, w, h
|
887 |
+
|
888 |
+
return [rescale_bbox(b) for b in bboxes]
|
889 |
+
|
890 |
+
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
891 |
+
|
892 |
+
if isinstance(cond, dict):
|
893 |
+
# hybrid case, cond is exptected to be a dict
|
894 |
+
pass
|
895 |
+
else:
|
896 |
+
if not isinstance(cond, list):
|
897 |
+
cond = [cond]
|
898 |
+
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
|
899 |
+
cond = {key: cond}
|
900 |
+
|
901 |
+
if hasattr(self, "split_input_params"):
|
902 |
+
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
903 |
+
assert not return_ids
|
904 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
905 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
906 |
+
|
907 |
+
h, w = x_noisy.shape[-2:]
|
908 |
+
|
909 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
|
910 |
+
|
911 |
+
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
|
912 |
+
# Reshape to img shape
|
913 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
914 |
+
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
915 |
+
|
916 |
+
if self.cond_stage_key in ["image", "LR_image", "segmentation",
|
917 |
+
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
|
918 |
+
c_key = next(iter(cond.keys())) # get key
|
919 |
+
c = next(iter(cond.values())) # get value
|
920 |
+
assert (len(c) == 1) # todo extend to list with more than one elem
|
921 |
+
c = c[0] # get element
|
922 |
+
|
923 |
+
c = unfold(c)
|
924 |
+
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
925 |
+
|
926 |
+
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
927 |
+
|
928 |
+
elif self.cond_stage_key == 'coordinates_bbox':
|
929 |
+
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
930 |
+
|
931 |
+
# assuming padding of unfold is always 0 and its dilation is always 1
|
932 |
+
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
933 |
+
full_img_h, full_img_w = self.split_input_params['original_image_size']
|
934 |
+
# as we are operating on latents, we need the factor from the original image size to the
|
935 |
+
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
|
936 |
+
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
937 |
+
rescale_latent = 2 ** (num_downs)
|
938 |
+
|
939 |
+
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
940 |
+
# need to rescale the tl patch coordinates to be in between (0,1)
|
941 |
+
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
942 |
+
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
943 |
+
for patch_nr in range(z.shape[-1])]
|
944 |
+
|
945 |
+
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
|
946 |
+
patch_limits = [(x_tl, y_tl,
|
947 |
+
rescale_latent * ks[0] / full_img_w,
|
948 |
+
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
|
949 |
+
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
|
950 |
+
|
951 |
+
# tokenize crop coordinates for the bounding boxes of the respective patches
|
952 |
+
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
|
953 |
+
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
|
954 |
+
print(patch_limits_tknzd[0].shape)
|
955 |
+
# cut tknzd crop position from conditioning
|
956 |
+
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
|
957 |
+
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
|
958 |
+
print(cut_cond.shape)
|
959 |
+
|
960 |
+
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
|
961 |
+
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
|
962 |
+
print(adapted_cond.shape)
|
963 |
+
adapted_cond = self.get_learned_conditioning(adapted_cond)
|
964 |
+
print(adapted_cond.shape)
|
965 |
+
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
|
966 |
+
print(adapted_cond.shape)
|
967 |
+
|
968 |
+
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
|
969 |
+
|
970 |
+
else:
|
971 |
+
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
|
972 |
+
|
973 |
+
# apply model by loop over crops
|
974 |
+
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
|
975 |
+
assert not isinstance(output_list[0],
|
976 |
+
tuple) # todo cant deal with multiple model outputs check this never happens
|
977 |
+
|
978 |
+
o = torch.stack(output_list, axis=-1)
|
979 |
+
o = o * weighting
|
980 |
+
# Reverse reshape to img shape
|
981 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
982 |
+
# stitch crops together
|
983 |
+
x_recon = fold(o) / normalization
|
984 |
+
|
985 |
+
else:
|
986 |
+
x_recon = self.model(x_noisy, t, **cond)
|
987 |
+
|
988 |
+
if isinstance(x_recon, tuple) and not return_ids:
|
989 |
+
return x_recon[0]
|
990 |
+
else:
|
991 |
+
return x_recon
|
992 |
+
|
993 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
994 |
+
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
|
995 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
996 |
+
|
997 |
+
def _prior_bpd(self, x_start):
|
998 |
+
"""
|
999 |
+
Get the prior KL term for the variational lower-bound, measured in
|
1000 |
+
bits-per-dim.
|
1001 |
+
This term can't be optimized, as it only depends on the encoder.
|
1002 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
1003 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
1004 |
+
"""
|
1005 |
+
batch_size = x_start.shape[0]
|
1006 |
+
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
1007 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
1008 |
+
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
1009 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
1010 |
+
|
1011 |
+
def p_losses(self, x_start, cond, t, noise=None):
|
1012 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
1013 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
1014 |
+
model_output = self.apply_model(x_noisy, t, cond)
|
1015 |
+
|
1016 |
+
loss_dict = {}
|
1017 |
+
prefix = 'train' if self.training else 'val'
|
1018 |
+
|
1019 |
+
if self.parameterization == "x0":
|
1020 |
+
target = x_start
|
1021 |
+
elif self.parameterization == "eps":
|
1022 |
+
target = noise
|
1023 |
+
else:
|
1024 |
+
raise NotImplementedError()
|
1025 |
+
|
1026 |
+
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
1027 |
+
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
|
1028 |
+
|
1029 |
+
logvar_t = self.logvar[t].to(self.device)
|
1030 |
+
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
1031 |
+
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
1032 |
+
if self.learn_logvar:
|
1033 |
+
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
|
1034 |
+
loss_dict.update({'logvar': self.logvar.data.mean()})
|
1035 |
+
|
1036 |
+
loss = self.l_simple_weight * loss.mean()
|
1037 |
+
|
1038 |
+
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
1039 |
+
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
1040 |
+
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
|
1041 |
+
loss += (self.original_elbo_weight * loss_vlb)
|
1042 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
1043 |
+
|
1044 |
+
return loss, loss_dict
|
1045 |
+
|
1046 |
+
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
|
1047 |
+
return_x0=False, score_corrector=None, corrector_kwargs=None):
|
1048 |
+
t_in = t
|
1049 |
+
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
|
1050 |
+
|
1051 |
+
if score_corrector is not None:
|
1052 |
+
assert self.parameterization == "eps"
|
1053 |
+
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
|
1054 |
+
|
1055 |
+
if return_codebook_ids:
|
1056 |
+
model_out, logits = model_out
|
1057 |
+
|
1058 |
+
if self.parameterization == "eps":
|
1059 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
1060 |
+
elif self.parameterization == "x0":
|
1061 |
+
x_recon = model_out
|
1062 |
+
else:
|
1063 |
+
raise NotImplementedError()
|
1064 |
+
|
1065 |
+
if clip_denoised:
|
1066 |
+
x_recon.clamp_(-1., 1.)
|
1067 |
+
if quantize_denoised:
|
1068 |
+
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
|
1069 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
1070 |
+
if return_codebook_ids:
|
1071 |
+
return model_mean, posterior_variance, posterior_log_variance, logits
|
1072 |
+
elif return_x0:
|
1073 |
+
return model_mean, posterior_variance, posterior_log_variance, x_recon
|
1074 |
+
else:
|
1075 |
+
return model_mean, posterior_variance, posterior_log_variance
|
1076 |
+
|
1077 |
+
@torch.no_grad()
|
1078 |
+
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
|
1079 |
+
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
|
1080 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
|
1081 |
+
b, *_, device = *x.shape, x.device
|
1082 |
+
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
|
1083 |
+
return_codebook_ids=return_codebook_ids,
|
1084 |
+
quantize_denoised=quantize_denoised,
|
1085 |
+
return_x0=return_x0,
|
1086 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
1087 |
+
if return_codebook_ids:
|
1088 |
+
raise DeprecationWarning("Support dropped.")
|
1089 |
+
model_mean, _, model_log_variance, logits = outputs
|
1090 |
+
elif return_x0:
|
1091 |
+
model_mean, _, model_log_variance, x0 = outputs
|
1092 |
+
else:
|
1093 |
+
model_mean, _, model_log_variance = outputs
|
1094 |
+
|
1095 |
+
noise = noise_like(x.shape, device, repeat_noise) * temperature
|
1096 |
+
if noise_dropout > 0.:
|
1097 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
1098 |
+
# no noise when t == 0
|
1099 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
1100 |
+
|
1101 |
+
if return_codebook_ids:
|
1102 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
|
1103 |
+
if return_x0:
|
1104 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
|
1105 |
+
else:
|
1106 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
1107 |
+
|
1108 |
+
@torch.no_grad()
|
1109 |
+
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
|
1110 |
+
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
|
1111 |
+
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
|
1112 |
+
log_every_t=None):
|
1113 |
+
if not log_every_t:
|
1114 |
+
log_every_t = self.log_every_t
|
1115 |
+
timesteps = self.num_timesteps
|
1116 |
+
if batch_size is not None:
|
1117 |
+
b = batch_size if batch_size is not None else shape[0]
|
1118 |
+
shape = [batch_size] + list(shape)
|
1119 |
+
else:
|
1120 |
+
b = batch_size = shape[0]
|
1121 |
+
if x_T is None:
|
1122 |
+
img = torch.randn(shape, device=self.device)
|
1123 |
+
else:
|
1124 |
+
img = x_T
|
1125 |
+
intermediates = []
|
1126 |
+
if cond is not None:
|
1127 |
+
if isinstance(cond, dict):
|
1128 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
1129 |
+
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
1130 |
+
else:
|
1131 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
1132 |
+
|
1133 |
+
if start_T is not None:
|
1134 |
+
timesteps = min(timesteps, start_T)
|
1135 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
|
1136 |
+
total=timesteps) if verbose else reversed(
|
1137 |
+
range(0, timesteps))
|
1138 |
+
if type(temperature) == float:
|
1139 |
+
temperature = [temperature] * timesteps
|
1140 |
+
|
1141 |
+
for i in iterator:
|
1142 |
+
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
|
1143 |
+
if self.shorten_cond_schedule:
|
1144 |
+
assert self.model.conditioning_key != 'hybrid'
|
1145 |
+
tc = self.cond_ids[ts].to(cond.device)
|
1146 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
1147 |
+
|
1148 |
+
img, x0_partial = self.p_sample(img, cond, ts,
|
1149 |
+
clip_denoised=self.clip_denoised,
|
1150 |
+
quantize_denoised=quantize_denoised, return_x0=True,
|
1151 |
+
temperature=temperature[i], noise_dropout=noise_dropout,
|
1152 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
1153 |
+
if mask is not None:
|
1154 |
+
assert x0 is not None
|
1155 |
+
img_orig = self.q_sample(x0, ts)
|
1156 |
+
img = img_orig * mask + (1. - mask) * img
|
1157 |
+
|
1158 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
1159 |
+
intermediates.append(x0_partial)
|
1160 |
+
if callback: callback(i)
|
1161 |
+
if img_callback: img_callback(img, i)
|
1162 |
+
return img, intermediates
|
1163 |
+
|
1164 |
+
@torch.no_grad()
|
1165 |
+
def p_sample_loop(self, cond, shape, return_intermediates=False,
|
1166 |
+
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
|
1167 |
+
mask=None, x0=None, img_callback=None, start_T=None,
|
1168 |
+
log_every_t=None):
|
1169 |
+
|
1170 |
+
if not log_every_t:
|
1171 |
+
log_every_t = self.log_every_t
|
1172 |
+
device = self.betas.device
|
1173 |
+
b = shape[0]
|
1174 |
+
if x_T is None:
|
1175 |
+
img = torch.randn(shape, device=device)
|
1176 |
+
else:
|
1177 |
+
img = x_T
|
1178 |
+
|
1179 |
+
intermediates = [img]
|
1180 |
+
if timesteps is None:
|
1181 |
+
timesteps = self.num_timesteps
|
1182 |
+
|
1183 |
+
if start_T is not None:
|
1184 |
+
timesteps = min(timesteps, start_T)
|
1185 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
|
1186 |
+
range(0, timesteps))
|
1187 |
+
|
1188 |
+
if mask is not None:
|
1189 |
+
assert x0 is not None
|
1190 |
+
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
|
1191 |
+
|
1192 |
+
for i in iterator:
|
1193 |
+
ts = torch.full((b,), i, device=device, dtype=torch.long)
|
1194 |
+
if self.shorten_cond_schedule:
|
1195 |
+
assert self.model.conditioning_key != 'hybrid'
|
1196 |
+
tc = self.cond_ids[ts].to(cond.device)
|
1197 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
1198 |
+
|
1199 |
+
img = self.p_sample(img, cond, ts,
|
1200 |
+
clip_denoised=self.clip_denoised,
|
1201 |
+
quantize_denoised=quantize_denoised)
|
1202 |
+
if mask is not None:
|
1203 |
+
img_orig = self.q_sample(x0, ts)
|
1204 |
+
img = img_orig * mask + (1. - mask) * img
|
1205 |
+
|
1206 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
1207 |
+
intermediates.append(img)
|
1208 |
+
if callback: callback(i)
|
1209 |
+
if img_callback: img_callback(img, i)
|
1210 |
+
|
1211 |
+
if return_intermediates:
|
1212 |
+
return img, intermediates
|
1213 |
+
return img
|
1214 |
+
|
1215 |
+
@torch.no_grad()
|
1216 |
+
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
1217 |
+
verbose=True, timesteps=None, quantize_denoised=False,
|
1218 |
+
mask=None, x0=None, shape=None,**kwargs):
|
1219 |
+
if shape is None:
|
1220 |
+
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
1221 |
+
if cond is not None:
|
1222 |
+
if isinstance(cond, dict):
|
1223 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
1224 |
+
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
1225 |
+
else:
|
1226 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
1227 |
+
return self.p_sample_loop(cond,
|
1228 |
+
shape,
|
1229 |
+
return_intermediates=return_intermediates, x_T=x_T,
|
1230 |
+
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
|
1231 |
+
mask=mask, x0=x0)
|
1232 |
+
|
1233 |
+
@torch.no_grad()
|
1234 |
+
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
|
1235 |
+
|
1236 |
+
if ddim:
|
1237 |
+
ddim_sampler = DDIMSampler(self)
|
1238 |
+
shape = (self.channels, self.image_size, self.image_size)
|
1239 |
+
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
1240 |
+
shape,cond,verbose=False,**kwargs)
|
1241 |
+
|
1242 |
+
else:
|
1243 |
+
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
1244 |
+
return_intermediates=True,**kwargs)
|
1245 |
+
|
1246 |
+
return samples, intermediates
|
1247 |
+
|
1248 |
+
|
1249 |
+
@torch.no_grad()
|
1250 |
+
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
1251 |
+
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
1252 |
+
plot_diffusion_rows=True, **kwargs):
|
1253 |
+
|
1254 |
+
use_ddim = ddim_steps is not None
|
1255 |
+
|
1256 |
+
log = dict()
|
1257 |
+
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
1258 |
+
return_first_stage_outputs=True,
|
1259 |
+
force_c_encode=True,
|
1260 |
+
return_original_cond=True,
|
1261 |
+
bs=N)
|
1262 |
+
N = min(x.shape[0], N)
|
1263 |
+
n_row = min(x.shape[0], n_row)
|
1264 |
+
log["inputs"] = x
|
1265 |
+
log["reconstruction"] = xrec
|
1266 |
+
if self.model.conditioning_key is not None:
|
1267 |
+
if hasattr(self.cond_stage_model, "decode"):
|
1268 |
+
xc = self.cond_stage_model.decode(c)
|
1269 |
+
log["conditioning"] = xc
|
1270 |
+
elif self.cond_stage_key in ["caption"]:
|
1271 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
|
1272 |
+
log["conditioning"] = xc
|
1273 |
+
elif self.cond_stage_key == 'class_label':
|
1274 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
1275 |
+
log['conditioning'] = xc
|
1276 |
+
elif isimage(xc):
|
1277 |
+
log["conditioning"] = xc
|
1278 |
+
if ismap(xc):
|
1279 |
+
log["original_conditioning"] = self.to_rgb(xc)
|
1280 |
+
|
1281 |
+
if plot_diffusion_rows:
|
1282 |
+
# get diffusion row
|
1283 |
+
diffusion_row = list()
|
1284 |
+
z_start = z[:n_row]
|
1285 |
+
for t in range(self.num_timesteps):
|
1286 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
1287 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
1288 |
+
t = t.to(self.device).long()
|
1289 |
+
noise = torch.randn_like(z_start)
|
1290 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
1291 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
1292 |
+
|
1293 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
1294 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
1295 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
1296 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
1297 |
+
log["diffusion_row"] = diffusion_grid
|
1298 |
+
|
1299 |
+
if sample:
|
1300 |
+
# get denoise row
|
1301 |
+
with self.ema_scope("Plotting"):
|
1302 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1303 |
+
ddim_steps=ddim_steps,eta=ddim_eta)
|
1304 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
1305 |
+
x_samples = self.decode_first_stage(samples)
|
1306 |
+
log["samples"] = x_samples
|
1307 |
+
if plot_denoise_rows:
|
1308 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
1309 |
+
log["denoise_row"] = denoise_grid
|
1310 |
+
|
1311 |
+
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
1312 |
+
self.first_stage_model, IdentityFirstStage):
|
1313 |
+
# also display when quantizing x0 while sampling
|
1314 |
+
with self.ema_scope("Plotting Quantized Denoised"):
|
1315 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1316 |
+
ddim_steps=ddim_steps,eta=ddim_eta,
|
1317 |
+
quantize_denoised=True)
|
1318 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
1319 |
+
# quantize_denoised=True)
|
1320 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1321 |
+
log["samples_x0_quantized"] = x_samples
|
1322 |
+
|
1323 |
+
if inpaint:
|
1324 |
+
# make a simple center square
|
1325 |
+
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
1326 |
+
mask = torch.ones(N, h, w).to(self.device)
|
1327 |
+
# zeros will be filled in
|
1328 |
+
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
1329 |
+
mask = mask[:, None, ...]
|
1330 |
+
with self.ema_scope("Plotting Inpaint"):
|
1331 |
+
|
1332 |
+
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
1333 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1334 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1335 |
+
log["samples_inpainting"] = x_samples
|
1336 |
+
log["mask"] = mask
|
1337 |
+
|
1338 |
+
# outpaint
|
1339 |
+
with self.ema_scope("Plotting Outpaint"):
|
1340 |
+
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
1341 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1342 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1343 |
+
log["samples_outpainting"] = x_samples
|
1344 |
+
|
1345 |
+
if plot_progressive_rows:
|
1346 |
+
with self.ema_scope("Plotting Progressives"):
|
1347 |
+
img, progressives = self.progressive_denoising(c,
|
1348 |
+
shape=(self.channels, self.image_size, self.image_size),
|
1349 |
+
batch_size=N)
|
1350 |
+
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
|
1351 |
+
log["progressive_row"] = prog_row
|
1352 |
+
|
1353 |
+
if return_keys:
|
1354 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
1355 |
+
return log
|
1356 |
+
else:
|
1357 |
+
return {key: log[key] for key in return_keys}
|
1358 |
+
return log
|
1359 |
+
|
1360 |
+
def configure_optimizers(self):
|
1361 |
+
lr = self.learning_rate
|
1362 |
+
params = list(self.model.parameters())
|
1363 |
+
if self.cond_stage_trainable:
|
1364 |
+
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
|
1365 |
+
params = params + list(self.cond_stage_model.parameters())
|
1366 |
+
if self.learn_logvar:
|
1367 |
+
print('Diffusion model optimizing logvar')
|
1368 |
+
params.append(self.logvar)
|
1369 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
1370 |
+
if self.use_scheduler:
|
1371 |
+
assert 'target' in self.scheduler_config
|
1372 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
1373 |
+
|
1374 |
+
print("Setting up LambdaLR scheduler...")
|
1375 |
+
scheduler = [
|
1376 |
+
{
|
1377 |
+
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
|
1378 |
+
'interval': 'step',
|
1379 |
+
'frequency': 1
|
1380 |
+
}]
|
1381 |
+
return [opt], scheduler
|
1382 |
+
return opt
|
1383 |
+
|
1384 |
+
@torch.no_grad()
|
1385 |
+
def to_rgb(self, x):
|
1386 |
+
x = x.float()
|
1387 |
+
if not hasattr(self, "colorize"):
|
1388 |
+
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
|
1389 |
+
x = nn.functional.conv2d(x, weight=self.colorize)
|
1390 |
+
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
1391 |
+
return x
|
1392 |
+
|
1393 |
+
|
1394 |
+
class DiffusionWrapperV1(pl.LightningModule):
|
1395 |
+
def __init__(self, diff_model_config, conditioning_key):
|
1396 |
+
super().__init__()
|
1397 |
+
self.diffusion_model = instantiate_from_config(diff_model_config)
|
1398 |
+
self.conditioning_key = conditioning_key
|
1399 |
+
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
|
1400 |
+
|
1401 |
+
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
|
1402 |
+
if self.conditioning_key is None:
|
1403 |
+
out = self.diffusion_model(x, t)
|
1404 |
+
elif self.conditioning_key == 'concat':
|
1405 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
1406 |
+
out = self.diffusion_model(xc, t)
|
1407 |
+
elif self.conditioning_key == 'crossattn':
|
1408 |
+
cc = torch.cat(c_crossattn, 1)
|
1409 |
+
out = self.diffusion_model(x, t, context=cc)
|
1410 |
+
elif self.conditioning_key == 'hybrid':
|
1411 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
1412 |
+
cc = torch.cat(c_crossattn, 1)
|
1413 |
+
out = self.diffusion_model(xc, t, context=cc)
|
1414 |
+
elif self.conditioning_key == 'adm':
|
1415 |
+
cc = c_crossattn[0]
|
1416 |
+
out = self.diffusion_model(x, t, y=cc)
|
1417 |
+
else:
|
1418 |
+
raise NotImplementedError()
|
1419 |
+
|
1420 |
+
return out
|
1421 |
+
|
1422 |
+
|
1423 |
+
class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
1424 |
+
# TODO: move all layout-specific hacks to this class
|
1425 |
+
def __init__(self, cond_stage_key, *args, **kwargs):
|
1426 |
+
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
1427 |
+
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
1428 |
+
|
1429 |
+
def log_images(self, batch, N=8, *args, **kwargs):
|
1430 |
+
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
1431 |
+
|
1432 |
+
key = 'train' if self.training else 'validation'
|
1433 |
+
dset = self.trainer.datamodule.datasets[key]
|
1434 |
+
mapper = dset.conditional_builders[self.cond_stage_key]
|
1435 |
+
|
1436 |
+
bbox_imgs = []
|
1437 |
+
map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
|
1438 |
+
for tknzd_bbox in batch[self.cond_stage_key][:N]:
|
1439 |
+
bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
|
1440 |
+
bbox_imgs.append(bboximg)
|
1441 |
+
|
1442 |
+
cond_img = torch.stack(bbox_imgs, dim=0)
|
1443 |
+
logs['bbox_image'] = cond_img
|
1444 |
+
return logs
|
1445 |
+
|
1446 |
+
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
|
1447 |
+
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
|
1448 |
+
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
|
1449 |
+
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
|
extensions-builtin/Lora/__pycache__/extra_networks_lora.cpython-310.pyc
ADDED
Binary file (1.66 kB). View file
|
|
extensions-builtin/Lora/__pycache__/lora.cpython-310.pyc
ADDED
Binary file (7.08 kB). View file
|
|
extensions-builtin/Lora/__pycache__/preload.cpython-310.pyc
ADDED
Binary file (501 Bytes). View file
|
|
extensions-builtin/Lora/__pycache__/ui_extra_networks_lora.cpython-310.pyc
ADDED
Binary file (1.71 kB). View file
|
|
extensions-builtin/Lora/extra_networks_lora.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules import extra_networks, shared
|
2 |
+
import lora
|
3 |
+
|
4 |
+
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__('lora')
|
7 |
+
|
8 |
+
def activate(self, p, params_list):
|
9 |
+
additional = shared.opts.sd_lora
|
10 |
+
|
11 |
+
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
|
12 |
+
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
13 |
+
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
14 |
+
|
15 |
+
names = []
|
16 |
+
multipliers = []
|
17 |
+
for params in params_list:
|
18 |
+
assert len(params.items) > 0
|
19 |
+
|
20 |
+
names.append(params.items[0])
|
21 |
+
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
22 |
+
|
23 |
+
lora.load_loras(names, multipliers)
|
24 |
+
|
25 |
+
def deactivate(self, p):
|
26 |
+
pass
|
extensions-builtin/Lora/lora.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from modules import shared, devices, sd_models, errors
|
7 |
+
|
8 |
+
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
9 |
+
|
10 |
+
re_digits = re.compile(r"\d+")
|
11 |
+
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
12 |
+
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
|
13 |
+
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
|
14 |
+
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
15 |
+
|
16 |
+
|
17 |
+
def convert_diffusers_name_to_compvis(key):
|
18 |
+
def match(match_list, regex):
|
19 |
+
r = re.match(regex, key)
|
20 |
+
if not r:
|
21 |
+
return False
|
22 |
+
|
23 |
+
match_list.clear()
|
24 |
+
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
25 |
+
return True
|
26 |
+
|
27 |
+
m = []
|
28 |
+
|
29 |
+
if match(m, re_unet_down_blocks):
|
30 |
+
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
|
31 |
+
|
32 |
+
if match(m, re_unet_mid_blocks):
|
33 |
+
return f"diffusion_model_middle_block_1_{m[1]}"
|
34 |
+
|
35 |
+
if match(m, re_unet_up_blocks):
|
36 |
+
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
37 |
+
|
38 |
+
if match(m, re_text_block):
|
39 |
+
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
40 |
+
|
41 |
+
return key
|
42 |
+
|
43 |
+
|
44 |
+
class LoraOnDisk:
|
45 |
+
def __init__(self, name, filename):
|
46 |
+
self.name = name
|
47 |
+
self.filename = filename
|
48 |
+
self.metadata = {}
|
49 |
+
|
50 |
+
_, ext = os.path.splitext(filename)
|
51 |
+
if ext.lower() == ".safetensors":
|
52 |
+
try:
|
53 |
+
self.metadata = sd_models.read_metadata_from_safetensors(filename)
|
54 |
+
except Exception as e:
|
55 |
+
errors.display(e, f"reading lora {filename}")
|
56 |
+
|
57 |
+
if self.metadata:
|
58 |
+
m = {}
|
59 |
+
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
|
60 |
+
m[k] = v
|
61 |
+
|
62 |
+
self.metadata = m
|
63 |
+
|
64 |
+
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
65 |
+
|
66 |
+
|
67 |
+
class LoraModule:
|
68 |
+
def __init__(self, name):
|
69 |
+
self.name = name
|
70 |
+
self.multiplier = 1.0
|
71 |
+
self.modules = {}
|
72 |
+
self.mtime = None
|
73 |
+
|
74 |
+
|
75 |
+
class LoraUpDownModule:
|
76 |
+
def __init__(self):
|
77 |
+
self.up = None
|
78 |
+
self.down = None
|
79 |
+
self.alpha = None
|
80 |
+
|
81 |
+
|
82 |
+
def assign_lora_names_to_compvis_modules(sd_model):
|
83 |
+
lora_layer_mapping = {}
|
84 |
+
|
85 |
+
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
86 |
+
lora_name = name.replace(".", "_")
|
87 |
+
lora_layer_mapping[lora_name] = module
|
88 |
+
module.lora_layer_name = lora_name
|
89 |
+
|
90 |
+
for name, module in shared.sd_model.model.named_modules():
|
91 |
+
lora_name = name.replace(".", "_")
|
92 |
+
lora_layer_mapping[lora_name] = module
|
93 |
+
module.lora_layer_name = lora_name
|
94 |
+
|
95 |
+
sd_model.lora_layer_mapping = lora_layer_mapping
|
96 |
+
|
97 |
+
|
98 |
+
def load_lora(name, filename):
|
99 |
+
lora = LoraModule(name)
|
100 |
+
lora.mtime = os.path.getmtime(filename)
|
101 |
+
|
102 |
+
sd = sd_models.read_state_dict(filename)
|
103 |
+
|
104 |
+
keys_failed_to_match = []
|
105 |
+
|
106 |
+
for key_diffusers, weight in sd.items():
|
107 |
+
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
|
108 |
+
key, lora_key = fullkey.split(".", 1)
|
109 |
+
|
110 |
+
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
111 |
+
if sd_module is None:
|
112 |
+
keys_failed_to_match.append(key_diffusers)
|
113 |
+
continue
|
114 |
+
|
115 |
+
lora_module = lora.modules.get(key, None)
|
116 |
+
if lora_module is None:
|
117 |
+
lora_module = LoraUpDownModule()
|
118 |
+
lora.modules[key] = lora_module
|
119 |
+
|
120 |
+
if lora_key == "alpha":
|
121 |
+
lora_module.alpha = weight.item()
|
122 |
+
continue
|
123 |
+
|
124 |
+
if type(sd_module) == torch.nn.Linear:
|
125 |
+
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
126 |
+
elif type(sd_module) == torch.nn.Conv2d:
|
127 |
+
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
128 |
+
else:
|
129 |
+
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
130 |
+
|
131 |
+
with torch.no_grad():
|
132 |
+
module.weight.copy_(weight)
|
133 |
+
|
134 |
+
module.to(device=devices.device, dtype=devices.dtype)
|
135 |
+
|
136 |
+
if lora_key == "lora_up.weight":
|
137 |
+
lora_module.up = module
|
138 |
+
elif lora_key == "lora_down.weight":
|
139 |
+
lora_module.down = module
|
140 |
+
else:
|
141 |
+
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
|
142 |
+
|
143 |
+
if len(keys_failed_to_match) > 0:
|
144 |
+
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
145 |
+
|
146 |
+
return lora
|
147 |
+
|
148 |
+
|
149 |
+
def load_loras(names, multipliers=None):
|
150 |
+
already_loaded = {}
|
151 |
+
|
152 |
+
for lora in loaded_loras:
|
153 |
+
if lora.name in names:
|
154 |
+
already_loaded[lora.name] = lora
|
155 |
+
|
156 |
+
loaded_loras.clear()
|
157 |
+
|
158 |
+
loras_on_disk = [available_loras.get(name, None) for name in names]
|
159 |
+
if any([x is None for x in loras_on_disk]):
|
160 |
+
list_available_loras()
|
161 |
+
|
162 |
+
loras_on_disk = [available_loras.get(name, None) for name in names]
|
163 |
+
|
164 |
+
for i, name in enumerate(names):
|
165 |
+
lora = already_loaded.get(name, None)
|
166 |
+
|
167 |
+
lora_on_disk = loras_on_disk[i]
|
168 |
+
if lora_on_disk is not None:
|
169 |
+
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
170 |
+
lora = load_lora(name, lora_on_disk.filename)
|
171 |
+
|
172 |
+
if lora is None:
|
173 |
+
print(f"Couldn't find Lora with name {name}")
|
174 |
+
continue
|
175 |
+
|
176 |
+
lora.multiplier = multipliers[i] if multipliers else 1.0
|
177 |
+
loaded_loras.append(lora)
|
178 |
+
|
179 |
+
|
180 |
+
def lora_forward(module, input, res):
|
181 |
+
if len(loaded_loras) == 0:
|
182 |
+
return res
|
183 |
+
|
184 |
+
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
185 |
+
for lora in loaded_loras:
|
186 |
+
module = lora.modules.get(lora_layer_name, None)
|
187 |
+
if module is not None:
|
188 |
+
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
|
189 |
+
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
190 |
+
else:
|
191 |
+
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
192 |
+
|
193 |
+
return res
|
194 |
+
|
195 |
+
|
196 |
+
def lora_Linear_forward(self, input):
|
197 |
+
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
|
198 |
+
|
199 |
+
|
200 |
+
def lora_Conv2d_forward(self, input):
|
201 |
+
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
202 |
+
|
203 |
+
|
204 |
+
def list_available_loras():
|
205 |
+
available_loras.clear()
|
206 |
+
|
207 |
+
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
208 |
+
|
209 |
+
candidates = \
|
210 |
+
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
|
211 |
+
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
|
212 |
+
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
|
213 |
+
|
214 |
+
for filename in sorted(candidates):
|
215 |
+
if os.path.isdir(filename):
|
216 |
+
continue
|
217 |
+
|
218 |
+
name = os.path.splitext(os.path.basename(filename))[0]
|
219 |
+
|
220 |
+
available_loras[name] = LoraOnDisk(name, filename)
|
221 |
+
|
222 |
+
|
223 |
+
available_loras = {}
|
224 |
+
loaded_loras = []
|
225 |
+
|
226 |
+
list_available_loras()
|
extensions-builtin/Lora/preload.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules import paths
|
3 |
+
|
4 |
+
|
5 |
+
def preload(parser):
|
6 |
+
parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
|
extensions-builtin/Lora/scripts/__pycache__/lora_script.cpython-310.pyc
ADDED
Binary file (1.78 kB). View file
|
|
extensions-builtin/Lora/scripts/lora_script.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
import lora
|
5 |
+
import extra_networks_lora
|
6 |
+
import ui_extra_networks_lora
|
7 |
+
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
8 |
+
|
9 |
+
|
10 |
+
def unload():
|
11 |
+
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
12 |
+
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
13 |
+
|
14 |
+
|
15 |
+
def before_ui():
|
16 |
+
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
17 |
+
extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
|
18 |
+
|
19 |
+
|
20 |
+
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
21 |
+
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
|
22 |
+
|
23 |
+
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
24 |
+
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
25 |
+
|
26 |
+
torch.nn.Linear.forward = lora.lora_Linear_forward
|
27 |
+
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
28 |
+
|
29 |
+
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
30 |
+
script_callbacks.on_script_unloaded(unload)
|
31 |
+
script_callbacks.on_before_ui(before_ui)
|
32 |
+
|
33 |
+
|
34 |
+
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
35 |
+
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
36 |
+
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
|
37 |
+
|
38 |
+
}))
|
extensions-builtin/Lora/ui_extra_networks_lora.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import lora
|
4 |
+
|
5 |
+
from modules import shared, ui_extra_networks
|
6 |
+
|
7 |
+
|
8 |
+
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__('Lora')
|
11 |
+
|
12 |
+
def refresh(self):
|
13 |
+
lora.list_available_loras()
|
14 |
+
|
15 |
+
def list_items(self):
|
16 |
+
for name, lora_on_disk in lora.available_loras.items():
|
17 |
+
path, ext = os.path.splitext(lora_on_disk.filename)
|
18 |
+
yield {
|
19 |
+
"name": name,
|
20 |
+
"filename": path,
|
21 |
+
"preview": self.find_preview(path),
|
22 |
+
"description": self.find_description(path),
|
23 |
+
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
24 |
+
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
25 |
+
"local_preview": f"{path}.{shared.opts.samples_format}",
|
26 |
+
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
27 |
+
}
|
28 |
+
|
29 |
+
def allowed_directories_for_previews(self):
|
30 |
+
return [shared.cmd_opts.lora_dir]
|
31 |
+
|
extensions-builtin/ScuNET/__pycache__/preload.cpython-310.pyc
ADDED
Binary file (522 Bytes). View file
|
|
extensions-builtin/ScuNET/__pycache__/scunet_model_arch.cpython-310.pyc
ADDED
Binary file (9.3 kB). View file
|
|
extensions-builtin/ScuNET/preload.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules import paths
|
3 |
+
|
4 |
+
|
5 |
+
def preload(parser):
|
6 |
+
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
|
extensions-builtin/ScuNET/scripts/__pycache__/scunet_model.cpython-310.pyc
ADDED
Binary file (3.18 kB). View file
|
|
extensions-builtin/ScuNET/scripts/scunet_model.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import PIL.Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from basicsr.utils.download_util import load_file_from_url
|
9 |
+
|
10 |
+
import modules.upscaler
|
11 |
+
from modules import devices, modelloader
|
12 |
+
from scunet_model_arch import SCUNet as net
|
13 |
+
|
14 |
+
|
15 |
+
class UpscalerScuNET(modules.upscaler.Upscaler):
|
16 |
+
def __init__(self, dirname):
|
17 |
+
self.name = "ScuNET"
|
18 |
+
self.model_name = "ScuNET GAN"
|
19 |
+
self.model_name2 = "ScuNET PSNR"
|
20 |
+
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
|
21 |
+
self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
|
22 |
+
self.user_path = dirname
|
23 |
+
super().__init__()
|
24 |
+
model_paths = self.find_models(ext_filter=[".pth"])
|
25 |
+
scalers = []
|
26 |
+
add_model2 = True
|
27 |
+
for file in model_paths:
|
28 |
+
if "http" in file:
|
29 |
+
name = self.model_name
|
30 |
+
else:
|
31 |
+
name = modelloader.friendly_name(file)
|
32 |
+
if name == self.model_name2 or file == self.model_url2:
|
33 |
+
add_model2 = False
|
34 |
+
try:
|
35 |
+
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
36 |
+
scalers.append(scaler_data)
|
37 |
+
except Exception:
|
38 |
+
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
|
39 |
+
print(traceback.format_exc(), file=sys.stderr)
|
40 |
+
if add_model2:
|
41 |
+
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
42 |
+
scalers.append(scaler_data2)
|
43 |
+
self.scalers = scalers
|
44 |
+
|
45 |
+
def do_upscale(self, img: PIL.Image, selected_file):
|
46 |
+
torch.cuda.empty_cache()
|
47 |
+
|
48 |
+
model = self.load_model(selected_file)
|
49 |
+
if model is None:
|
50 |
+
return img
|
51 |
+
|
52 |
+
device = devices.get_device_for('scunet')
|
53 |
+
img = np.array(img)
|
54 |
+
img = img[:, :, ::-1]
|
55 |
+
img = np.moveaxis(img, 2, 0) / 255
|
56 |
+
img = torch.from_numpy(img).float()
|
57 |
+
img = img.unsqueeze(0).to(device)
|
58 |
+
|
59 |
+
with torch.no_grad():
|
60 |
+
output = model(img)
|
61 |
+
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
62 |
+
output = 255. * np.moveaxis(output, 0, 2)
|
63 |
+
output = output.astype(np.uint8)
|
64 |
+
output = output[:, :, ::-1]
|
65 |
+
torch.cuda.empty_cache()
|
66 |
+
return PIL.Image.fromarray(output, 'RGB')
|
67 |
+
|
68 |
+
def load_model(self, path: str):
|
69 |
+
device = devices.get_device_for('scunet')
|
70 |
+
if "http" in path:
|
71 |
+
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
72 |
+
progress=True)
|
73 |
+
else:
|
74 |
+
filename = path
|
75 |
+
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
76 |
+
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
77 |
+
return None
|
78 |
+
|
79 |
+
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
80 |
+
model.load_state_dict(torch.load(filename), strict=True)
|
81 |
+
model.eval()
|
82 |
+
for k, v in model.named_parameters():
|
83 |
+
v.requires_grad = False
|
84 |
+
model = model.to(device)
|
85 |
+
|
86 |
+
return model
|
87 |
+
|
extensions-builtin/ScuNET/scunet_model_arch.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
from einops.layers.torch import Rearrange
|
7 |
+
from timm.models.layers import trunc_normal_, DropPath
|
8 |
+
|
9 |
+
|
10 |
+
class WMSA(nn.Module):
|
11 |
+
""" Self-attention module in Swin Transformer
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
15 |
+
super(WMSA, self).__init__()
|
16 |
+
self.input_dim = input_dim
|
17 |
+
self.output_dim = output_dim
|
18 |
+
self.head_dim = head_dim
|
19 |
+
self.scale = self.head_dim ** -0.5
|
20 |
+
self.n_heads = input_dim // head_dim
|
21 |
+
self.window_size = window_size
|
22 |
+
self.type = type
|
23 |
+
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
24 |
+
|
25 |
+
self.relative_position_params = nn.Parameter(
|
26 |
+
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
|
27 |
+
|
28 |
+
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
29 |
+
|
30 |
+
trunc_normal_(self.relative_position_params, std=.02)
|
31 |
+
self.relative_position_params = torch.nn.Parameter(
|
32 |
+
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
|
33 |
+
2).transpose(
|
34 |
+
0, 1))
|
35 |
+
|
36 |
+
def generate_mask(self, h, w, p, shift):
|
37 |
+
""" generating the mask of SW-MSA
|
38 |
+
Args:
|
39 |
+
shift: shift parameters in CyclicShift.
|
40 |
+
Returns:
|
41 |
+
attn_mask: should be (1 1 w p p),
|
42 |
+
"""
|
43 |
+
# supporting square.
|
44 |
+
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
45 |
+
if self.type == 'W':
|
46 |
+
return attn_mask
|
47 |
+
|
48 |
+
s = p - shift
|
49 |
+
attn_mask[-1, :, :s, :, s:, :] = True
|
50 |
+
attn_mask[-1, :, s:, :, :s, :] = True
|
51 |
+
attn_mask[:, -1, :, :s, :, s:] = True
|
52 |
+
attn_mask[:, -1, :, s:, :, :s] = True
|
53 |
+
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
54 |
+
return attn_mask
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
""" Forward pass of Window Multi-head Self-attention module.
|
58 |
+
Args:
|
59 |
+
x: input tensor with shape of [b h w c];
|
60 |
+
attn_mask: attention mask, fill -inf where the value is True;
|
61 |
+
Returns:
|
62 |
+
output: tensor shape [b h w c]
|
63 |
+
"""
|
64 |
+
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
65 |
+
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
66 |
+
h_windows = x.size(1)
|
67 |
+
w_windows = x.size(2)
|
68 |
+
# square validation
|
69 |
+
# assert h_windows == w_windows
|
70 |
+
|
71 |
+
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
72 |
+
qkv = self.embedding_layer(x)
|
73 |
+
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
74 |
+
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
75 |
+
# Adding learnable relative embedding
|
76 |
+
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
77 |
+
# Using Attn Mask to distinguish different subwindows.
|
78 |
+
if self.type != 'W':
|
79 |
+
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
|
80 |
+
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
81 |
+
|
82 |
+
probs = nn.functional.softmax(sim, dim=-1)
|
83 |
+
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
84 |
+
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
85 |
+
output = self.linear(output)
|
86 |
+
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
87 |
+
|
88 |
+
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
89 |
+
dims=(1, 2))
|
90 |
+
return output
|
91 |
+
|
92 |
+
def relative_embedding(self):
|
93 |
+
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
94 |
+
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
95 |
+
# negative is allowed
|
96 |
+
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
|
97 |
+
|
98 |
+
|
99 |
+
class Block(nn.Module):
|
100 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
101 |
+
""" SwinTransformer Block
|
102 |
+
"""
|
103 |
+
super(Block, self).__init__()
|
104 |
+
self.input_dim = input_dim
|
105 |
+
self.output_dim = output_dim
|
106 |
+
assert type in ['W', 'SW']
|
107 |
+
self.type = type
|
108 |
+
if input_resolution <= window_size:
|
109 |
+
self.type = 'W'
|
110 |
+
|
111 |
+
self.ln1 = nn.LayerNorm(input_dim)
|
112 |
+
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
113 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
114 |
+
self.ln2 = nn.LayerNorm(input_dim)
|
115 |
+
self.mlp = nn.Sequential(
|
116 |
+
nn.Linear(input_dim, 4 * input_dim),
|
117 |
+
nn.GELU(),
|
118 |
+
nn.Linear(4 * input_dim, output_dim),
|
119 |
+
)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
x = x + self.drop_path(self.msa(self.ln1(x)))
|
123 |
+
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
124 |
+
return x
|
125 |
+
|
126 |
+
|
127 |
+
class ConvTransBlock(nn.Module):
|
128 |
+
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
129 |
+
""" SwinTransformer and Conv Block
|
130 |
+
"""
|
131 |
+
super(ConvTransBlock, self).__init__()
|
132 |
+
self.conv_dim = conv_dim
|
133 |
+
self.trans_dim = trans_dim
|
134 |
+
self.head_dim = head_dim
|
135 |
+
self.window_size = window_size
|
136 |
+
self.drop_path = drop_path
|
137 |
+
self.type = type
|
138 |
+
self.input_resolution = input_resolution
|
139 |
+
|
140 |
+
assert self.type in ['W', 'SW']
|
141 |
+
if self.input_resolution <= self.window_size:
|
142 |
+
self.type = 'W'
|
143 |
+
|
144 |
+
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
|
145 |
+
self.type, self.input_resolution)
|
146 |
+
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
147 |
+
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
148 |
+
|
149 |
+
self.conv_block = nn.Sequential(
|
150 |
+
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
151 |
+
nn.ReLU(True),
|
152 |
+
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
|
153 |
+
)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
|
157 |
+
conv_x = self.conv_block(conv_x) + conv_x
|
158 |
+
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
|
159 |
+
trans_x = self.trans_block(trans_x)
|
160 |
+
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
|
161 |
+
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
162 |
+
x = x + res
|
163 |
+
|
164 |
+
return x
|
165 |
+
|
166 |
+
|
167 |
+
class SCUNet(nn.Module):
|
168 |
+
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
|
169 |
+
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
|
170 |
+
super(SCUNet, self).__init__()
|
171 |
+
if config is None:
|
172 |
+
config = [2, 2, 2, 2, 2, 2, 2]
|
173 |
+
self.config = config
|
174 |
+
self.dim = dim
|
175 |
+
self.head_dim = 32
|
176 |
+
self.window_size = 8
|
177 |
+
|
178 |
+
# drop path rate for each layer
|
179 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
180 |
+
|
181 |
+
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
182 |
+
|
183 |
+
begin = 0
|
184 |
+
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
185 |
+
'W' if not i % 2 else 'SW', input_resolution)
|
186 |
+
for i in range(config[0])] + \
|
187 |
+
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
|
188 |
+
|
189 |
+
begin += config[0]
|
190 |
+
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
191 |
+
'W' if not i % 2 else 'SW', input_resolution // 2)
|
192 |
+
for i in range(config[1])] + \
|
193 |
+
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
|
194 |
+
|
195 |
+
begin += config[1]
|
196 |
+
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
197 |
+
'W' if not i % 2 else 'SW', input_resolution // 4)
|
198 |
+
for i in range(config[2])] + \
|
199 |
+
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
|
200 |
+
|
201 |
+
begin += config[2]
|
202 |
+
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
203 |
+
'W' if not i % 2 else 'SW', input_resolution // 8)
|
204 |
+
for i in range(config[3])]
|
205 |
+
|
206 |
+
begin += config[3]
|
207 |
+
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
|
208 |
+
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
209 |
+
'W' if not i % 2 else 'SW', input_resolution // 4)
|
210 |
+
for i in range(config[4])]
|
211 |
+
|
212 |
+
begin += config[4]
|
213 |
+
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
|
214 |
+
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
215 |
+
'W' if not i % 2 else 'SW', input_resolution // 2)
|
216 |
+
for i in range(config[5])]
|
217 |
+
|
218 |
+
begin += config[5]
|
219 |
+
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
|
220 |
+
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
221 |
+
'W' if not i % 2 else 'SW', input_resolution)
|
222 |
+
for i in range(config[6])]
|
223 |
+
|
224 |
+
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
225 |
+
|
226 |
+
self.m_head = nn.Sequential(*self.m_head)
|
227 |
+
self.m_down1 = nn.Sequential(*self.m_down1)
|
228 |
+
self.m_down2 = nn.Sequential(*self.m_down2)
|
229 |
+
self.m_down3 = nn.Sequential(*self.m_down3)
|
230 |
+
self.m_body = nn.Sequential(*self.m_body)
|
231 |
+
self.m_up3 = nn.Sequential(*self.m_up3)
|
232 |
+
self.m_up2 = nn.Sequential(*self.m_up2)
|
233 |
+
self.m_up1 = nn.Sequential(*self.m_up1)
|
234 |
+
self.m_tail = nn.Sequential(*self.m_tail)
|
235 |
+
# self.apply(self._init_weights)
|
236 |
+
|
237 |
+
def forward(self, x0):
|
238 |
+
|
239 |
+
h, w = x0.size()[-2:]
|
240 |
+
paddingBottom = int(np.ceil(h / 64) * 64 - h)
|
241 |
+
paddingRight = int(np.ceil(w / 64) * 64 - w)
|
242 |
+
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
|
243 |
+
|
244 |
+
x1 = self.m_head(x0)
|
245 |
+
x2 = self.m_down1(x1)
|
246 |
+
x3 = self.m_down2(x2)
|
247 |
+
x4 = self.m_down3(x3)
|
248 |
+
x = self.m_body(x4)
|
249 |
+
x = self.m_up3(x + x4)
|
250 |
+
x = self.m_up2(x + x3)
|
251 |
+
x = self.m_up1(x + x2)
|
252 |
+
x = self.m_tail(x + x1)
|
253 |
+
|
254 |
+
x = x[..., :h, :w]
|
255 |
+
|
256 |
+
return x
|
257 |
+
|
258 |
+
def _init_weights(self, m):
|
259 |
+
if isinstance(m, nn.Linear):
|
260 |
+
trunc_normal_(m.weight, std=.02)
|
261 |
+
if m.bias is not None:
|
262 |
+
nn.init.constant_(m.bias, 0)
|
263 |
+
elif isinstance(m, nn.LayerNorm):
|
264 |
+
nn.init.constant_(m.bias, 0)
|
265 |
+
nn.init.constant_(m.weight, 1.0)
|
extensions-builtin/SwinIR/__pycache__/preload.cpython-310.pyc
ADDED
Binary file (522 Bytes). View file
|
|
extensions-builtin/SwinIR/__pycache__/swinir_model_arch.cpython-310.pyc
ADDED
Binary file (27.8 kB). View file
|
|
extensions-builtin/SwinIR/__pycache__/swinir_model_arch_v2.cpython-310.pyc
ADDED
Binary file (31.4 kB). View file
|
|
extensions-builtin/SwinIR/preload.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules import paths
|
3 |
+
|
4 |
+
|
5 |
+
def preload(parser):
|
6 |
+
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
|
extensions-builtin/SwinIR/scripts/__pycache__/swinir_model.cpython-310.pyc
ADDED
Binary file (5.5 kB). View file
|
|
extensions-builtin/SwinIR/scripts/swinir_model.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from basicsr.utils.download_util import load_file_from_url
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from modules import modelloader, devices, script_callbacks, shared
|
11 |
+
from modules.shared import cmd_opts, opts, state
|
12 |
+
from swinir_model_arch import SwinIR as net
|
13 |
+
from swinir_model_arch_v2 import Swin2SR as net2
|
14 |
+
from modules.upscaler import Upscaler, UpscalerData
|
15 |
+
|
16 |
+
|
17 |
+
device_swinir = devices.get_device_for('swinir')
|
18 |
+
|
19 |
+
|
20 |
+
class UpscalerSwinIR(Upscaler):
|
21 |
+
def __init__(self, dirname):
|
22 |
+
self.name = "SwinIR"
|
23 |
+
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
24 |
+
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
25 |
+
"-L_x4_GAN.pth "
|
26 |
+
self.model_name = "SwinIR 4x"
|
27 |
+
self.user_path = dirname
|
28 |
+
super().__init__()
|
29 |
+
scalers = []
|
30 |
+
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
31 |
+
for model in model_files:
|
32 |
+
if "http" in model:
|
33 |
+
name = self.model_name
|
34 |
+
else:
|
35 |
+
name = modelloader.friendly_name(model)
|
36 |
+
model_data = UpscalerData(name, model, self)
|
37 |
+
scalers.append(model_data)
|
38 |
+
self.scalers = scalers
|
39 |
+
|
40 |
+
def do_upscale(self, img, model_file):
|
41 |
+
model = self.load_model(model_file)
|
42 |
+
if model is None:
|
43 |
+
return img
|
44 |
+
model = model.to(device_swinir, dtype=devices.dtype)
|
45 |
+
img = upscale(img, model)
|
46 |
+
try:
|
47 |
+
torch.cuda.empty_cache()
|
48 |
+
except:
|
49 |
+
pass
|
50 |
+
return img
|
51 |
+
|
52 |
+
def load_model(self, path, scale=4):
|
53 |
+
if "http" in path:
|
54 |
+
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
|
55 |
+
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
|
56 |
+
else:
|
57 |
+
filename = path
|
58 |
+
if filename is None or not os.path.exists(filename):
|
59 |
+
return None
|
60 |
+
if filename.endswith(".v2.pth"):
|
61 |
+
model = net2(
|
62 |
+
upscale=scale,
|
63 |
+
in_chans=3,
|
64 |
+
img_size=64,
|
65 |
+
window_size=8,
|
66 |
+
img_range=1.0,
|
67 |
+
depths=[6, 6, 6, 6, 6, 6],
|
68 |
+
embed_dim=180,
|
69 |
+
num_heads=[6, 6, 6, 6, 6, 6],
|
70 |
+
mlp_ratio=2,
|
71 |
+
upsampler="nearest+conv",
|
72 |
+
resi_connection="1conv",
|
73 |
+
)
|
74 |
+
params = None
|
75 |
+
else:
|
76 |
+
model = net(
|
77 |
+
upscale=scale,
|
78 |
+
in_chans=3,
|
79 |
+
img_size=64,
|
80 |
+
window_size=8,
|
81 |
+
img_range=1.0,
|
82 |
+
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
83 |
+
embed_dim=240,
|
84 |
+
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
85 |
+
mlp_ratio=2,
|
86 |
+
upsampler="nearest+conv",
|
87 |
+
resi_connection="3conv",
|
88 |
+
)
|
89 |
+
params = "params_ema"
|
90 |
+
|
91 |
+
pretrained_model = torch.load(filename)
|
92 |
+
if params is not None:
|
93 |
+
model.load_state_dict(pretrained_model[params], strict=True)
|
94 |
+
else:
|
95 |
+
model.load_state_dict(pretrained_model, strict=True)
|
96 |
+
return model
|
97 |
+
|
98 |
+
|
99 |
+
def upscale(
|
100 |
+
img,
|
101 |
+
model,
|
102 |
+
tile=None,
|
103 |
+
tile_overlap=None,
|
104 |
+
window_size=8,
|
105 |
+
scale=4,
|
106 |
+
):
|
107 |
+
tile = tile or opts.SWIN_tile
|
108 |
+
tile_overlap = tile_overlap or opts.SWIN_tile_overlap
|
109 |
+
|
110 |
+
|
111 |
+
img = np.array(img)
|
112 |
+
img = img[:, :, ::-1]
|
113 |
+
img = np.moveaxis(img, 2, 0) / 255
|
114 |
+
img = torch.from_numpy(img).float()
|
115 |
+
img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
|
116 |
+
with torch.no_grad(), devices.autocast():
|
117 |
+
_, _, h_old, w_old = img.size()
|
118 |
+
h_pad = (h_old // window_size + 1) * window_size - h_old
|
119 |
+
w_pad = (w_old // window_size + 1) * window_size - w_old
|
120 |
+
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
121 |
+
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
122 |
+
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
123 |
+
output = output[..., : h_old * scale, : w_old * scale]
|
124 |
+
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
125 |
+
if output.ndim == 3:
|
126 |
+
output = np.transpose(
|
127 |
+
output[[2, 1, 0], :, :], (1, 2, 0)
|
128 |
+
) # CHW-RGB to HCW-BGR
|
129 |
+
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
130 |
+
return Image.fromarray(output, "RGB")
|
131 |
+
|
132 |
+
|
133 |
+
def inference(img, model, tile, tile_overlap, window_size, scale):
|
134 |
+
# test the image tile by tile
|
135 |
+
b, c, h, w = img.size()
|
136 |
+
tile = min(tile, h, w)
|
137 |
+
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
138 |
+
sf = scale
|
139 |
+
|
140 |
+
stride = tile - tile_overlap
|
141 |
+
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
142 |
+
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
143 |
+
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
|
144 |
+
W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
|
145 |
+
|
146 |
+
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
147 |
+
for h_idx in h_idx_list:
|
148 |
+
if state.interrupted or state.skipped:
|
149 |
+
break
|
150 |
+
|
151 |
+
for w_idx in w_idx_list:
|
152 |
+
if state.interrupted or state.skipped:
|
153 |
+
break
|
154 |
+
|
155 |
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
156 |
+
out_patch = model(in_patch)
|
157 |
+
out_patch_mask = torch.ones_like(out_patch)
|
158 |
+
|
159 |
+
E[
|
160 |
+
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
161 |
+
].add_(out_patch)
|
162 |
+
W[
|
163 |
+
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
164 |
+
].add_(out_patch_mask)
|
165 |
+
pbar.update(1)
|
166 |
+
output = E.div_(W)
|
167 |
+
|
168 |
+
return output
|
169 |
+
|
170 |
+
|
171 |
+
def on_ui_settings():
|
172 |
+
import gradio as gr
|
173 |
+
|
174 |
+
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
175 |
+
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
176 |
+
|
177 |
+
|
178 |
+
script_callbacks.on_ui_settings(on_ui_settings)
|
extensions-builtin/SwinIR/swinir_model_arch.py
ADDED
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
3 |
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint as checkpoint
|
11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
12 |
+
|
13 |
+
|
14 |
+
class Mlp(nn.Module):
|
15 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
16 |
+
super().__init__()
|
17 |
+
out_features = out_features or in_features
|
18 |
+
hidden_features = hidden_features or in_features
|
19 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
20 |
+
self.act = act_layer()
|
21 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
22 |
+
self.drop = nn.Dropout(drop)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
x = self.fc1(x)
|
26 |
+
x = self.act(x)
|
27 |
+
x = self.drop(x)
|
28 |
+
x = self.fc2(x)
|
29 |
+
x = self.drop(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
def window_partition(x, window_size):
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
x: (B, H, W, C)
|
37 |
+
window_size (int): window size
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
x: (B, H, W, C)
|
58 |
+
"""
|
59 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
60 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
61 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class WindowAttention(nn.Module):
|
66 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
67 |
+
It supports both of shifted and non-shifted window.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
dim (int): Number of input channels.
|
71 |
+
window_size (tuple[int]): The height and width of the window.
|
72 |
+
num_heads (int): Number of attention heads.
|
73 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
74 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
75 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
76 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
80 |
+
|
81 |
+
super().__init__()
|
82 |
+
self.dim = dim
|
83 |
+
self.window_size = window_size # Wh, Ww
|
84 |
+
self.num_heads = num_heads
|
85 |
+
head_dim = dim // num_heads
|
86 |
+
self.scale = qk_scale or head_dim ** -0.5
|
87 |
+
|
88 |
+
# define a parameter table of relative position bias
|
89 |
+
self.relative_position_bias_table = nn.Parameter(
|
90 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
91 |
+
|
92 |
+
# get pair-wise relative position index for each token inside the window
|
93 |
+
coords_h = torch.arange(self.window_size[0])
|
94 |
+
coords_w = torch.arange(self.window_size[1])
|
95 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
96 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
97 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
98 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
99 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
100 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
101 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
102 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
103 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
104 |
+
|
105 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
107 |
+
self.proj = nn.Linear(dim, dim)
|
108 |
+
|
109 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
110 |
+
|
111 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
112 |
+
self.softmax = nn.Softmax(dim=-1)
|
113 |
+
|
114 |
+
def forward(self, x, mask=None):
|
115 |
+
"""
|
116 |
+
Args:
|
117 |
+
x: input features with shape of (num_windows*B, N, C)
|
118 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
119 |
+
"""
|
120 |
+
B_, N, C = x.shape
|
121 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
122 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
123 |
+
|
124 |
+
q = q * self.scale
|
125 |
+
attn = (q @ k.transpose(-2, -1))
|
126 |
+
|
127 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
128 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
129 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
130 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
131 |
+
|
132 |
+
if mask is not None:
|
133 |
+
nW = mask.shape[0]
|
134 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
135 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
136 |
+
attn = self.softmax(attn)
|
137 |
+
else:
|
138 |
+
attn = self.softmax(attn)
|
139 |
+
|
140 |
+
attn = self.attn_drop(attn)
|
141 |
+
|
142 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
143 |
+
x = self.proj(x)
|
144 |
+
x = self.proj_drop(x)
|
145 |
+
return x
|
146 |
+
|
147 |
+
def extra_repr(self) -> str:
|
148 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
149 |
+
|
150 |
+
def flops(self, N):
|
151 |
+
# calculate flops for 1 window with token length of N
|
152 |
+
flops = 0
|
153 |
+
# qkv = self.qkv(x)
|
154 |
+
flops += N * self.dim * 3 * self.dim
|
155 |
+
# attn = (q @ k.transpose(-2, -1))
|
156 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
157 |
+
# x = (attn @ v)
|
158 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
159 |
+
# x = self.proj(x)
|
160 |
+
flops += N * self.dim * self.dim
|
161 |
+
return flops
|
162 |
+
|
163 |
+
|
164 |
+
class SwinTransformerBlock(nn.Module):
|
165 |
+
r""" Swin Transformer Block.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
dim (int): Number of input channels.
|
169 |
+
input_resolution (tuple[int]): Input resolution.
|
170 |
+
num_heads (int): Number of attention heads.
|
171 |
+
window_size (int): Window size.
|
172 |
+
shift_size (int): Shift size for SW-MSA.
|
173 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
174 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
175 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
176 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
177 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
178 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
179 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
180 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
184 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
185 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
186 |
+
super().__init__()
|
187 |
+
self.dim = dim
|
188 |
+
self.input_resolution = input_resolution
|
189 |
+
self.num_heads = num_heads
|
190 |
+
self.window_size = window_size
|
191 |
+
self.shift_size = shift_size
|
192 |
+
self.mlp_ratio = mlp_ratio
|
193 |
+
if min(self.input_resolution) <= self.window_size:
|
194 |
+
# if window size is larger than input resolution, we don't partition windows
|
195 |
+
self.shift_size = 0
|
196 |
+
self.window_size = min(self.input_resolution)
|
197 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
198 |
+
|
199 |
+
self.norm1 = norm_layer(dim)
|
200 |
+
self.attn = WindowAttention(
|
201 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
202 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
203 |
+
|
204 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
205 |
+
self.norm2 = norm_layer(dim)
|
206 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
207 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
208 |
+
|
209 |
+
if self.shift_size > 0:
|
210 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
211 |
+
else:
|
212 |
+
attn_mask = None
|
213 |
+
|
214 |
+
self.register_buffer("attn_mask", attn_mask)
|
215 |
+
|
216 |
+
def calculate_mask(self, x_size):
|
217 |
+
# calculate attention mask for SW-MSA
|
218 |
+
H, W = x_size
|
219 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
220 |
+
h_slices = (slice(0, -self.window_size),
|
221 |
+
slice(-self.window_size, -self.shift_size),
|
222 |
+
slice(-self.shift_size, None))
|
223 |
+
w_slices = (slice(0, -self.window_size),
|
224 |
+
slice(-self.window_size, -self.shift_size),
|
225 |
+
slice(-self.shift_size, None))
|
226 |
+
cnt = 0
|
227 |
+
for h in h_slices:
|
228 |
+
for w in w_slices:
|
229 |
+
img_mask[:, h, w, :] = cnt
|
230 |
+
cnt += 1
|
231 |
+
|
232 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
233 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
234 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
235 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
236 |
+
|
237 |
+
return attn_mask
|
238 |
+
|
239 |
+
def forward(self, x, x_size):
|
240 |
+
H, W = x_size
|
241 |
+
B, L, C = x.shape
|
242 |
+
# assert L == H * W, "input feature has wrong size"
|
243 |
+
|
244 |
+
shortcut = x
|
245 |
+
x = self.norm1(x)
|
246 |
+
x = x.view(B, H, W, C)
|
247 |
+
|
248 |
+
# cyclic shift
|
249 |
+
if self.shift_size > 0:
|
250 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
251 |
+
else:
|
252 |
+
shifted_x = x
|
253 |
+
|
254 |
+
# partition windows
|
255 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
256 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
257 |
+
|
258 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
259 |
+
if self.input_resolution == x_size:
|
260 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
261 |
+
else:
|
262 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
263 |
+
|
264 |
+
# merge windows
|
265 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
266 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
267 |
+
|
268 |
+
# reverse cyclic shift
|
269 |
+
if self.shift_size > 0:
|
270 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
271 |
+
else:
|
272 |
+
x = shifted_x
|
273 |
+
x = x.view(B, H * W, C)
|
274 |
+
|
275 |
+
# FFN
|
276 |
+
x = shortcut + self.drop_path(x)
|
277 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
278 |
+
|
279 |
+
return x
|
280 |
+
|
281 |
+
def extra_repr(self) -> str:
|
282 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
283 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
284 |
+
|
285 |
+
def flops(self):
|
286 |
+
flops = 0
|
287 |
+
H, W = self.input_resolution
|
288 |
+
# norm1
|
289 |
+
flops += self.dim * H * W
|
290 |
+
# W-MSA/SW-MSA
|
291 |
+
nW = H * W / self.window_size / self.window_size
|
292 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
293 |
+
# mlp
|
294 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
295 |
+
# norm2
|
296 |
+
flops += self.dim * H * W
|
297 |
+
return flops
|
298 |
+
|
299 |
+
|
300 |
+
class PatchMerging(nn.Module):
|
301 |
+
r""" Patch Merging Layer.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
305 |
+
dim (int): Number of input channels.
|
306 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
310 |
+
super().__init__()
|
311 |
+
self.input_resolution = input_resolution
|
312 |
+
self.dim = dim
|
313 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
314 |
+
self.norm = norm_layer(4 * dim)
|
315 |
+
|
316 |
+
def forward(self, x):
|
317 |
+
"""
|
318 |
+
x: B, H*W, C
|
319 |
+
"""
|
320 |
+
H, W = self.input_resolution
|
321 |
+
B, L, C = x.shape
|
322 |
+
assert L == H * W, "input feature has wrong size"
|
323 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
324 |
+
|
325 |
+
x = x.view(B, H, W, C)
|
326 |
+
|
327 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
328 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
329 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
330 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
331 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
332 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
333 |
+
|
334 |
+
x = self.norm(x)
|
335 |
+
x = self.reduction(x)
|
336 |
+
|
337 |
+
return x
|
338 |
+
|
339 |
+
def extra_repr(self) -> str:
|
340 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
341 |
+
|
342 |
+
def flops(self):
|
343 |
+
H, W = self.input_resolution
|
344 |
+
flops = H * W * self.dim
|
345 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
346 |
+
return flops
|
347 |
+
|
348 |
+
|
349 |
+
class BasicLayer(nn.Module):
|
350 |
+
""" A basic Swin Transformer layer for one stage.
|
351 |
+
|
352 |
+
Args:
|
353 |
+
dim (int): Number of input channels.
|
354 |
+
input_resolution (tuple[int]): Input resolution.
|
355 |
+
depth (int): Number of blocks.
|
356 |
+
num_heads (int): Number of attention heads.
|
357 |
+
window_size (int): Local window size.
|
358 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
359 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
360 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
361 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
362 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
363 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
364 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
365 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
366 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
370 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
371 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
372 |
+
|
373 |
+
super().__init__()
|
374 |
+
self.dim = dim
|
375 |
+
self.input_resolution = input_resolution
|
376 |
+
self.depth = depth
|
377 |
+
self.use_checkpoint = use_checkpoint
|
378 |
+
|
379 |
+
# build blocks
|
380 |
+
self.blocks = nn.ModuleList([
|
381 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
382 |
+
num_heads=num_heads, window_size=window_size,
|
383 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
384 |
+
mlp_ratio=mlp_ratio,
|
385 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
386 |
+
drop=drop, attn_drop=attn_drop,
|
387 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
388 |
+
norm_layer=norm_layer)
|
389 |
+
for i in range(depth)])
|
390 |
+
|
391 |
+
# patch merging layer
|
392 |
+
if downsample is not None:
|
393 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
394 |
+
else:
|
395 |
+
self.downsample = None
|
396 |
+
|
397 |
+
def forward(self, x, x_size):
|
398 |
+
for blk in self.blocks:
|
399 |
+
if self.use_checkpoint:
|
400 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
401 |
+
else:
|
402 |
+
x = blk(x, x_size)
|
403 |
+
if self.downsample is not None:
|
404 |
+
x = self.downsample(x)
|
405 |
+
return x
|
406 |
+
|
407 |
+
def extra_repr(self) -> str:
|
408 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
409 |
+
|
410 |
+
def flops(self):
|
411 |
+
flops = 0
|
412 |
+
for blk in self.blocks:
|
413 |
+
flops += blk.flops()
|
414 |
+
if self.downsample is not None:
|
415 |
+
flops += self.downsample.flops()
|
416 |
+
return flops
|
417 |
+
|
418 |
+
|
419 |
+
class RSTB(nn.Module):
|
420 |
+
"""Residual Swin Transformer Block (RSTB).
|
421 |
+
|
422 |
+
Args:
|
423 |
+
dim (int): Number of input channels.
|
424 |
+
input_resolution (tuple[int]): Input resolution.
|
425 |
+
depth (int): Number of blocks.
|
426 |
+
num_heads (int): Number of attention heads.
|
427 |
+
window_size (int): Local window size.
|
428 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
429 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
430 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
431 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
432 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
433 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
434 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
435 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
436 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
437 |
+
img_size: Input image size.
|
438 |
+
patch_size: Patch size.
|
439 |
+
resi_connection: The convolutional block before residual connection.
|
440 |
+
"""
|
441 |
+
|
442 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
443 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
444 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
445 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
446 |
+
super(RSTB, self).__init__()
|
447 |
+
|
448 |
+
self.dim = dim
|
449 |
+
self.input_resolution = input_resolution
|
450 |
+
|
451 |
+
self.residual_group = BasicLayer(dim=dim,
|
452 |
+
input_resolution=input_resolution,
|
453 |
+
depth=depth,
|
454 |
+
num_heads=num_heads,
|
455 |
+
window_size=window_size,
|
456 |
+
mlp_ratio=mlp_ratio,
|
457 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
458 |
+
drop=drop, attn_drop=attn_drop,
|
459 |
+
drop_path=drop_path,
|
460 |
+
norm_layer=norm_layer,
|
461 |
+
downsample=downsample,
|
462 |
+
use_checkpoint=use_checkpoint)
|
463 |
+
|
464 |
+
if resi_connection == '1conv':
|
465 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
466 |
+
elif resi_connection == '3conv':
|
467 |
+
# to save parameters and memory
|
468 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
469 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
470 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
471 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
472 |
+
|
473 |
+
self.patch_embed = PatchEmbed(
|
474 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
475 |
+
norm_layer=None)
|
476 |
+
|
477 |
+
self.patch_unembed = PatchUnEmbed(
|
478 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
479 |
+
norm_layer=None)
|
480 |
+
|
481 |
+
def forward(self, x, x_size):
|
482 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
483 |
+
|
484 |
+
def flops(self):
|
485 |
+
flops = 0
|
486 |
+
flops += self.residual_group.flops()
|
487 |
+
H, W = self.input_resolution
|
488 |
+
flops += H * W * self.dim * self.dim * 9
|
489 |
+
flops += self.patch_embed.flops()
|
490 |
+
flops += self.patch_unembed.flops()
|
491 |
+
|
492 |
+
return flops
|
493 |
+
|
494 |
+
|
495 |
+
class PatchEmbed(nn.Module):
|
496 |
+
r""" Image to Patch Embedding
|
497 |
+
|
498 |
+
Args:
|
499 |
+
img_size (int): Image size. Default: 224.
|
500 |
+
patch_size (int): Patch token size. Default: 4.
|
501 |
+
in_chans (int): Number of input image channels. Default: 3.
|
502 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
503 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
504 |
+
"""
|
505 |
+
|
506 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
507 |
+
super().__init__()
|
508 |
+
img_size = to_2tuple(img_size)
|
509 |
+
patch_size = to_2tuple(patch_size)
|
510 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
511 |
+
self.img_size = img_size
|
512 |
+
self.patch_size = patch_size
|
513 |
+
self.patches_resolution = patches_resolution
|
514 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
515 |
+
|
516 |
+
self.in_chans = in_chans
|
517 |
+
self.embed_dim = embed_dim
|
518 |
+
|
519 |
+
if norm_layer is not None:
|
520 |
+
self.norm = norm_layer(embed_dim)
|
521 |
+
else:
|
522 |
+
self.norm = None
|
523 |
+
|
524 |
+
def forward(self, x):
|
525 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
526 |
+
if self.norm is not None:
|
527 |
+
x = self.norm(x)
|
528 |
+
return x
|
529 |
+
|
530 |
+
def flops(self):
|
531 |
+
flops = 0
|
532 |
+
H, W = self.img_size
|
533 |
+
if self.norm is not None:
|
534 |
+
flops += H * W * self.embed_dim
|
535 |
+
return flops
|
536 |
+
|
537 |
+
|
538 |
+
class PatchUnEmbed(nn.Module):
|
539 |
+
r""" Image to Patch Unembedding
|
540 |
+
|
541 |
+
Args:
|
542 |
+
img_size (int): Image size. Default: 224.
|
543 |
+
patch_size (int): Patch token size. Default: 4.
|
544 |
+
in_chans (int): Number of input image channels. Default: 3.
|
545 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
546 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
547 |
+
"""
|
548 |
+
|
549 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
550 |
+
super().__init__()
|
551 |
+
img_size = to_2tuple(img_size)
|
552 |
+
patch_size = to_2tuple(patch_size)
|
553 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
554 |
+
self.img_size = img_size
|
555 |
+
self.patch_size = patch_size
|
556 |
+
self.patches_resolution = patches_resolution
|
557 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
558 |
+
|
559 |
+
self.in_chans = in_chans
|
560 |
+
self.embed_dim = embed_dim
|
561 |
+
|
562 |
+
def forward(self, x, x_size):
|
563 |
+
B, HW, C = x.shape
|
564 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
565 |
+
return x
|
566 |
+
|
567 |
+
def flops(self):
|
568 |
+
flops = 0
|
569 |
+
return flops
|
570 |
+
|
571 |
+
|
572 |
+
class Upsample(nn.Sequential):
|
573 |
+
"""Upsample module.
|
574 |
+
|
575 |
+
Args:
|
576 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
577 |
+
num_feat (int): Channel number of intermediate features.
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(self, scale, num_feat):
|
581 |
+
m = []
|
582 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
583 |
+
for _ in range(int(math.log(scale, 2))):
|
584 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
585 |
+
m.append(nn.PixelShuffle(2))
|
586 |
+
elif scale == 3:
|
587 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
588 |
+
m.append(nn.PixelShuffle(3))
|
589 |
+
else:
|
590 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
591 |
+
super(Upsample, self).__init__(*m)
|
592 |
+
|
593 |
+
|
594 |
+
class UpsampleOneStep(nn.Sequential):
|
595 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
596 |
+
Used in lightweight SR to save parameters.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
600 |
+
num_feat (int): Channel number of intermediate features.
|
601 |
+
|
602 |
+
"""
|
603 |
+
|
604 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
605 |
+
self.num_feat = num_feat
|
606 |
+
self.input_resolution = input_resolution
|
607 |
+
m = []
|
608 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
609 |
+
m.append(nn.PixelShuffle(scale))
|
610 |
+
super(UpsampleOneStep, self).__init__(*m)
|
611 |
+
|
612 |
+
def flops(self):
|
613 |
+
H, W = self.input_resolution
|
614 |
+
flops = H * W * self.num_feat * 3 * 9
|
615 |
+
return flops
|
616 |
+
|
617 |
+
|
618 |
+
class SwinIR(nn.Module):
|
619 |
+
r""" SwinIR
|
620 |
+
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
621 |
+
|
622 |
+
Args:
|
623 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
624 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
625 |
+
in_chans (int): Number of input image channels. Default: 3
|
626 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
627 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
628 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
629 |
+
window_size (int): Window size. Default: 7
|
630 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
631 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
632 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
633 |
+
drop_rate (float): Dropout rate. Default: 0
|
634 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
635 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
636 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
637 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
638 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
639 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
640 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
641 |
+
img_range: Image range. 1. or 255.
|
642 |
+
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
643 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
644 |
+
"""
|
645 |
+
|
646 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
647 |
+
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
648 |
+
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
649 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
650 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
651 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
652 |
+
**kwargs):
|
653 |
+
super(SwinIR, self).__init__()
|
654 |
+
num_in_ch = in_chans
|
655 |
+
num_out_ch = in_chans
|
656 |
+
num_feat = 64
|
657 |
+
self.img_range = img_range
|
658 |
+
if in_chans == 3:
|
659 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
660 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
661 |
+
else:
|
662 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
663 |
+
self.upscale = upscale
|
664 |
+
self.upsampler = upsampler
|
665 |
+
self.window_size = window_size
|
666 |
+
|
667 |
+
#####################################################################################################
|
668 |
+
################################### 1, shallow feature extraction ###################################
|
669 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
670 |
+
|
671 |
+
#####################################################################################################
|
672 |
+
################################### 2, deep feature extraction ######################################
|
673 |
+
self.num_layers = len(depths)
|
674 |
+
self.embed_dim = embed_dim
|
675 |
+
self.ape = ape
|
676 |
+
self.patch_norm = patch_norm
|
677 |
+
self.num_features = embed_dim
|
678 |
+
self.mlp_ratio = mlp_ratio
|
679 |
+
|
680 |
+
# split image into non-overlapping patches
|
681 |
+
self.patch_embed = PatchEmbed(
|
682 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
683 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
684 |
+
num_patches = self.patch_embed.num_patches
|
685 |
+
patches_resolution = self.patch_embed.patches_resolution
|
686 |
+
self.patches_resolution = patches_resolution
|
687 |
+
|
688 |
+
# merge non-overlapping patches into image
|
689 |
+
self.patch_unembed = PatchUnEmbed(
|
690 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
691 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
692 |
+
|
693 |
+
# absolute position embedding
|
694 |
+
if self.ape:
|
695 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
696 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
697 |
+
|
698 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
699 |
+
|
700 |
+
# stochastic depth
|
701 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
702 |
+
|
703 |
+
# build Residual Swin Transformer blocks (RSTB)
|
704 |
+
self.layers = nn.ModuleList()
|
705 |
+
for i_layer in range(self.num_layers):
|
706 |
+
layer = RSTB(dim=embed_dim,
|
707 |
+
input_resolution=(patches_resolution[0],
|
708 |
+
patches_resolution[1]),
|
709 |
+
depth=depths[i_layer],
|
710 |
+
num_heads=num_heads[i_layer],
|
711 |
+
window_size=window_size,
|
712 |
+
mlp_ratio=self.mlp_ratio,
|
713 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
714 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
715 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
716 |
+
norm_layer=norm_layer,
|
717 |
+
downsample=None,
|
718 |
+
use_checkpoint=use_checkpoint,
|
719 |
+
img_size=img_size,
|
720 |
+
patch_size=patch_size,
|
721 |
+
resi_connection=resi_connection
|
722 |
+
|
723 |
+
)
|
724 |
+
self.layers.append(layer)
|
725 |
+
self.norm = norm_layer(self.num_features)
|
726 |
+
|
727 |
+
# build the last conv layer in deep feature extraction
|
728 |
+
if resi_connection == '1conv':
|
729 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
730 |
+
elif resi_connection == '3conv':
|
731 |
+
# to save parameters and memory
|
732 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
733 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
734 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
735 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
736 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
737 |
+
|
738 |
+
#####################################################################################################
|
739 |
+
################################ 3, high quality image reconstruction ################################
|
740 |
+
if self.upsampler == 'pixelshuffle':
|
741 |
+
# for classical SR
|
742 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
743 |
+
nn.LeakyReLU(inplace=True))
|
744 |
+
self.upsample = Upsample(upscale, num_feat)
|
745 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
746 |
+
elif self.upsampler == 'pixelshuffledirect':
|
747 |
+
# for lightweight SR (to save parameters)
|
748 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
749 |
+
(patches_resolution[0], patches_resolution[1]))
|
750 |
+
elif self.upsampler == 'nearest+conv':
|
751 |
+
# for real-world SR (less artifacts)
|
752 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
753 |
+
nn.LeakyReLU(inplace=True))
|
754 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
755 |
+
if self.upscale == 4:
|
756 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
757 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
758 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
759 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
760 |
+
else:
|
761 |
+
# for image denoising and JPEG compression artifact reduction
|
762 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
763 |
+
|
764 |
+
self.apply(self._init_weights)
|
765 |
+
|
766 |
+
def _init_weights(self, m):
|
767 |
+
if isinstance(m, nn.Linear):
|
768 |
+
trunc_normal_(m.weight, std=.02)
|
769 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
770 |
+
nn.init.constant_(m.bias, 0)
|
771 |
+
elif isinstance(m, nn.LayerNorm):
|
772 |
+
nn.init.constant_(m.bias, 0)
|
773 |
+
nn.init.constant_(m.weight, 1.0)
|
774 |
+
|
775 |
+
@torch.jit.ignore
|
776 |
+
def no_weight_decay(self):
|
777 |
+
return {'absolute_pos_embed'}
|
778 |
+
|
779 |
+
@torch.jit.ignore
|
780 |
+
def no_weight_decay_keywords(self):
|
781 |
+
return {'relative_position_bias_table'}
|
782 |
+
|
783 |
+
def check_image_size(self, x):
|
784 |
+
_, _, h, w = x.size()
|
785 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
786 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
787 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
788 |
+
return x
|
789 |
+
|
790 |
+
def forward_features(self, x):
|
791 |
+
x_size = (x.shape[2], x.shape[3])
|
792 |
+
x = self.patch_embed(x)
|
793 |
+
if self.ape:
|
794 |
+
x = x + self.absolute_pos_embed
|
795 |
+
x = self.pos_drop(x)
|
796 |
+
|
797 |
+
for layer in self.layers:
|
798 |
+
x = layer(x, x_size)
|
799 |
+
|
800 |
+
x = self.norm(x) # B L C
|
801 |
+
x = self.patch_unembed(x, x_size)
|
802 |
+
|
803 |
+
return x
|
804 |
+
|
805 |
+
def forward(self, x):
|
806 |
+
H, W = x.shape[2:]
|
807 |
+
x = self.check_image_size(x)
|
808 |
+
|
809 |
+
self.mean = self.mean.type_as(x)
|
810 |
+
x = (x - self.mean) * self.img_range
|
811 |
+
|
812 |
+
if self.upsampler == 'pixelshuffle':
|
813 |
+
# for classical SR
|
814 |
+
x = self.conv_first(x)
|
815 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
816 |
+
x = self.conv_before_upsample(x)
|
817 |
+
x = self.conv_last(self.upsample(x))
|
818 |
+
elif self.upsampler == 'pixelshuffledirect':
|
819 |
+
# for lightweight SR
|
820 |
+
x = self.conv_first(x)
|
821 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
822 |
+
x = self.upsample(x)
|
823 |
+
elif self.upsampler == 'nearest+conv':
|
824 |
+
# for real-world SR
|
825 |
+
x = self.conv_first(x)
|
826 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
827 |
+
x = self.conv_before_upsample(x)
|
828 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
829 |
+
if self.upscale == 4:
|
830 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
831 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
832 |
+
else:
|
833 |
+
# for image denoising and JPEG compression artifact reduction
|
834 |
+
x_first = self.conv_first(x)
|
835 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
836 |
+
x = x + self.conv_last(res)
|
837 |
+
|
838 |
+
x = x / self.img_range + self.mean
|
839 |
+
|
840 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
841 |
+
|
842 |
+
def flops(self):
|
843 |
+
flops = 0
|
844 |
+
H, W = self.patches_resolution
|
845 |
+
flops += H * W * 3 * self.embed_dim * 9
|
846 |
+
flops += self.patch_embed.flops()
|
847 |
+
for i, layer in enumerate(self.layers):
|
848 |
+
flops += layer.flops()
|
849 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
850 |
+
flops += self.upsample.flops()
|
851 |
+
return flops
|
852 |
+
|
853 |
+
|
854 |
+
if __name__ == '__main__':
|
855 |
+
upscale = 4
|
856 |
+
window_size = 8
|
857 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
858 |
+
width = (720 // upscale // window_size + 1) * window_size
|
859 |
+
model = SwinIR(upscale=2, img_size=(height, width),
|
860 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
861 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
862 |
+
print(model)
|
863 |
+
print(height, width, model.flops() / 1e9)
|
864 |
+
|
865 |
+
x = torch.randn((1, 3, height, width))
|
866 |
+
x = model(x)
|
867 |
+
print(x.shape)
|
extensions-builtin/SwinIR/swinir_model_arch_v2.py
ADDED
@@ -0,0 +1,1017 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
|
3 |
+
# Written by Conde and Choi et al.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
13 |
+
|
14 |
+
|
15 |
+
class Mlp(nn.Module):
|
16 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
17 |
+
super().__init__()
|
18 |
+
out_features = out_features or in_features
|
19 |
+
hidden_features = hidden_features or in_features
|
20 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
21 |
+
self.act = act_layer()
|
22 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
23 |
+
self.drop = nn.Dropout(drop)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.fc1(x)
|
27 |
+
x = self.act(x)
|
28 |
+
x = self.drop(x)
|
29 |
+
x = self.fc2(x)
|
30 |
+
x = self.drop(x)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
def window_partition(x, window_size):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
x: (B, H, W, C)
|
38 |
+
window_size (int): window size
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
Returns:
|
56 |
+
x: (B, H, W, C)
|
57 |
+
"""
|
58 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
59 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
60 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
61 |
+
return x
|
62 |
+
|
63 |
+
class WindowAttention(nn.Module):
|
64 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
65 |
+
It supports both of shifted and non-shifted window.
|
66 |
+
Args:
|
67 |
+
dim (int): Number of input channels.
|
68 |
+
window_size (tuple[int]): The height and width of the window.
|
69 |
+
num_heads (int): Number of attention heads.
|
70 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
71 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
72 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
73 |
+
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
77 |
+
pretrained_window_size=[0, 0]):
|
78 |
+
|
79 |
+
super().__init__()
|
80 |
+
self.dim = dim
|
81 |
+
self.window_size = window_size # Wh, Ww
|
82 |
+
self.pretrained_window_size = pretrained_window_size
|
83 |
+
self.num_heads = num_heads
|
84 |
+
|
85 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
|
86 |
+
|
87 |
+
# mlp to generate continuous relative position bias
|
88 |
+
self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
|
89 |
+
nn.ReLU(inplace=True),
|
90 |
+
nn.Linear(512, num_heads, bias=False))
|
91 |
+
|
92 |
+
# get relative_coords_table
|
93 |
+
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
|
94 |
+
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
|
95 |
+
relative_coords_table = torch.stack(
|
96 |
+
torch.meshgrid([relative_coords_h,
|
97 |
+
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
|
98 |
+
if pretrained_window_size[0] > 0:
|
99 |
+
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
|
100 |
+
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
|
101 |
+
else:
|
102 |
+
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
|
103 |
+
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
|
104 |
+
relative_coords_table *= 8 # normalize to -8, 8
|
105 |
+
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
|
106 |
+
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
|
107 |
+
|
108 |
+
self.register_buffer("relative_coords_table", relative_coords_table)
|
109 |
+
|
110 |
+
# get pair-wise relative position index for each token inside the window
|
111 |
+
coords_h = torch.arange(self.window_size[0])
|
112 |
+
coords_w = torch.arange(self.window_size[1])
|
113 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
114 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
115 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
116 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
117 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
118 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
119 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
120 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
121 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
122 |
+
|
123 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
124 |
+
if qkv_bias:
|
125 |
+
self.q_bias = nn.Parameter(torch.zeros(dim))
|
126 |
+
self.v_bias = nn.Parameter(torch.zeros(dim))
|
127 |
+
else:
|
128 |
+
self.q_bias = None
|
129 |
+
self.v_bias = None
|
130 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
131 |
+
self.proj = nn.Linear(dim, dim)
|
132 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
133 |
+
self.softmax = nn.Softmax(dim=-1)
|
134 |
+
|
135 |
+
def forward(self, x, mask=None):
|
136 |
+
"""
|
137 |
+
Args:
|
138 |
+
x: input features with shape of (num_windows*B, N, C)
|
139 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
140 |
+
"""
|
141 |
+
B_, N, C = x.shape
|
142 |
+
qkv_bias = None
|
143 |
+
if self.q_bias is not None:
|
144 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
145 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
146 |
+
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
147 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
148 |
+
|
149 |
+
# cosine attention
|
150 |
+
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
|
151 |
+
logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
|
152 |
+
attn = attn * logit_scale
|
153 |
+
|
154 |
+
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
|
155 |
+
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
156 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
157 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
158 |
+
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
|
159 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
160 |
+
|
161 |
+
if mask is not None:
|
162 |
+
nW = mask.shape[0]
|
163 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
164 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
165 |
+
attn = self.softmax(attn)
|
166 |
+
else:
|
167 |
+
attn = self.softmax(attn)
|
168 |
+
|
169 |
+
attn = self.attn_drop(attn)
|
170 |
+
|
171 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
172 |
+
x = self.proj(x)
|
173 |
+
x = self.proj_drop(x)
|
174 |
+
return x
|
175 |
+
|
176 |
+
def extra_repr(self) -> str:
|
177 |
+
return f'dim={self.dim}, window_size={self.window_size}, ' \
|
178 |
+
f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
|
179 |
+
|
180 |
+
def flops(self, N):
|
181 |
+
# calculate flops for 1 window with token length of N
|
182 |
+
flops = 0
|
183 |
+
# qkv = self.qkv(x)
|
184 |
+
flops += N * self.dim * 3 * self.dim
|
185 |
+
# attn = (q @ k.transpose(-2, -1))
|
186 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
187 |
+
# x = (attn @ v)
|
188 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
189 |
+
# x = self.proj(x)
|
190 |
+
flops += N * self.dim * self.dim
|
191 |
+
return flops
|
192 |
+
|
193 |
+
class SwinTransformerBlock(nn.Module):
|
194 |
+
r""" Swin Transformer Block.
|
195 |
+
Args:
|
196 |
+
dim (int): Number of input channels.
|
197 |
+
input_resolution (tuple[int]): Input resulotion.
|
198 |
+
num_heads (int): Number of attention heads.
|
199 |
+
window_size (int): Window size.
|
200 |
+
shift_size (int): Shift size for SW-MSA.
|
201 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
202 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
203 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
204 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
205 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
206 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
207 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
208 |
+
pretrained_window_size (int): Window size in pre-training.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
212 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
213 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
|
214 |
+
super().__init__()
|
215 |
+
self.dim = dim
|
216 |
+
self.input_resolution = input_resolution
|
217 |
+
self.num_heads = num_heads
|
218 |
+
self.window_size = window_size
|
219 |
+
self.shift_size = shift_size
|
220 |
+
self.mlp_ratio = mlp_ratio
|
221 |
+
if min(self.input_resolution) <= self.window_size:
|
222 |
+
# if window size is larger than input resolution, we don't partition windows
|
223 |
+
self.shift_size = 0
|
224 |
+
self.window_size = min(self.input_resolution)
|
225 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
226 |
+
|
227 |
+
self.norm1 = norm_layer(dim)
|
228 |
+
self.attn = WindowAttention(
|
229 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
230 |
+
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
231 |
+
pretrained_window_size=to_2tuple(pretrained_window_size))
|
232 |
+
|
233 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
234 |
+
self.norm2 = norm_layer(dim)
|
235 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
236 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
237 |
+
|
238 |
+
if self.shift_size > 0:
|
239 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
240 |
+
else:
|
241 |
+
attn_mask = None
|
242 |
+
|
243 |
+
self.register_buffer("attn_mask", attn_mask)
|
244 |
+
|
245 |
+
def calculate_mask(self, x_size):
|
246 |
+
# calculate attention mask for SW-MSA
|
247 |
+
H, W = x_size
|
248 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
249 |
+
h_slices = (slice(0, -self.window_size),
|
250 |
+
slice(-self.window_size, -self.shift_size),
|
251 |
+
slice(-self.shift_size, None))
|
252 |
+
w_slices = (slice(0, -self.window_size),
|
253 |
+
slice(-self.window_size, -self.shift_size),
|
254 |
+
slice(-self.shift_size, None))
|
255 |
+
cnt = 0
|
256 |
+
for h in h_slices:
|
257 |
+
for w in w_slices:
|
258 |
+
img_mask[:, h, w, :] = cnt
|
259 |
+
cnt += 1
|
260 |
+
|
261 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
262 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
263 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
264 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
265 |
+
|
266 |
+
return attn_mask
|
267 |
+
|
268 |
+
def forward(self, x, x_size):
|
269 |
+
H, W = x_size
|
270 |
+
B, L, C = x.shape
|
271 |
+
#assert L == H * W, "input feature has wrong size"
|
272 |
+
|
273 |
+
shortcut = x
|
274 |
+
x = x.view(B, H, W, C)
|
275 |
+
|
276 |
+
# cyclic shift
|
277 |
+
if self.shift_size > 0:
|
278 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
279 |
+
else:
|
280 |
+
shifted_x = x
|
281 |
+
|
282 |
+
# partition windows
|
283 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
284 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
285 |
+
|
286 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
287 |
+
if self.input_resolution == x_size:
|
288 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
289 |
+
else:
|
290 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
291 |
+
|
292 |
+
# merge windows
|
293 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
294 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
295 |
+
|
296 |
+
# reverse cyclic shift
|
297 |
+
if self.shift_size > 0:
|
298 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
299 |
+
else:
|
300 |
+
x = shifted_x
|
301 |
+
x = x.view(B, H * W, C)
|
302 |
+
x = shortcut + self.drop_path(self.norm1(x))
|
303 |
+
|
304 |
+
# FFN
|
305 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
306 |
+
|
307 |
+
return x
|
308 |
+
|
309 |
+
def extra_repr(self) -> str:
|
310 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
311 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
312 |
+
|
313 |
+
def flops(self):
|
314 |
+
flops = 0
|
315 |
+
H, W = self.input_resolution
|
316 |
+
# norm1
|
317 |
+
flops += self.dim * H * W
|
318 |
+
# W-MSA/SW-MSA
|
319 |
+
nW = H * W / self.window_size / self.window_size
|
320 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
321 |
+
# mlp
|
322 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
323 |
+
# norm2
|
324 |
+
flops += self.dim * H * W
|
325 |
+
return flops
|
326 |
+
|
327 |
+
class PatchMerging(nn.Module):
|
328 |
+
r""" Patch Merging Layer.
|
329 |
+
Args:
|
330 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
331 |
+
dim (int): Number of input channels.
|
332 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
333 |
+
"""
|
334 |
+
|
335 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
336 |
+
super().__init__()
|
337 |
+
self.input_resolution = input_resolution
|
338 |
+
self.dim = dim
|
339 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
340 |
+
self.norm = norm_layer(2 * dim)
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
"""
|
344 |
+
x: B, H*W, C
|
345 |
+
"""
|
346 |
+
H, W = self.input_resolution
|
347 |
+
B, L, C = x.shape
|
348 |
+
assert L == H * W, "input feature has wrong size"
|
349 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
350 |
+
|
351 |
+
x = x.view(B, H, W, C)
|
352 |
+
|
353 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
354 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
355 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
356 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
357 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
358 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
359 |
+
|
360 |
+
x = self.reduction(x)
|
361 |
+
x = self.norm(x)
|
362 |
+
|
363 |
+
return x
|
364 |
+
|
365 |
+
def extra_repr(self) -> str:
|
366 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
367 |
+
|
368 |
+
def flops(self):
|
369 |
+
H, W = self.input_resolution
|
370 |
+
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
371 |
+
flops += H * W * self.dim // 2
|
372 |
+
return flops
|
373 |
+
|
374 |
+
class BasicLayer(nn.Module):
|
375 |
+
""" A basic Swin Transformer layer for one stage.
|
376 |
+
Args:
|
377 |
+
dim (int): Number of input channels.
|
378 |
+
input_resolution (tuple[int]): Input resolution.
|
379 |
+
depth (int): Number of blocks.
|
380 |
+
num_heads (int): Number of attention heads.
|
381 |
+
window_size (int): Local window size.
|
382 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
383 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
384 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
385 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
386 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
387 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
388 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
389 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
390 |
+
pretrained_window_size (int): Local window size in pre-training.
|
391 |
+
"""
|
392 |
+
|
393 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
394 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
395 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
396 |
+
pretrained_window_size=0):
|
397 |
+
|
398 |
+
super().__init__()
|
399 |
+
self.dim = dim
|
400 |
+
self.input_resolution = input_resolution
|
401 |
+
self.depth = depth
|
402 |
+
self.use_checkpoint = use_checkpoint
|
403 |
+
|
404 |
+
# build blocks
|
405 |
+
self.blocks = nn.ModuleList([
|
406 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
407 |
+
num_heads=num_heads, window_size=window_size,
|
408 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
409 |
+
mlp_ratio=mlp_ratio,
|
410 |
+
qkv_bias=qkv_bias,
|
411 |
+
drop=drop, attn_drop=attn_drop,
|
412 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
413 |
+
norm_layer=norm_layer,
|
414 |
+
pretrained_window_size=pretrained_window_size)
|
415 |
+
for i in range(depth)])
|
416 |
+
|
417 |
+
# patch merging layer
|
418 |
+
if downsample is not None:
|
419 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
420 |
+
else:
|
421 |
+
self.downsample = None
|
422 |
+
|
423 |
+
def forward(self, x, x_size):
|
424 |
+
for blk in self.blocks:
|
425 |
+
if self.use_checkpoint:
|
426 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
427 |
+
else:
|
428 |
+
x = blk(x, x_size)
|
429 |
+
if self.downsample is not None:
|
430 |
+
x = self.downsample(x)
|
431 |
+
return x
|
432 |
+
|
433 |
+
def extra_repr(self) -> str:
|
434 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
435 |
+
|
436 |
+
def flops(self):
|
437 |
+
flops = 0
|
438 |
+
for blk in self.blocks:
|
439 |
+
flops += blk.flops()
|
440 |
+
if self.downsample is not None:
|
441 |
+
flops += self.downsample.flops()
|
442 |
+
return flops
|
443 |
+
|
444 |
+
def _init_respostnorm(self):
|
445 |
+
for blk in self.blocks:
|
446 |
+
nn.init.constant_(blk.norm1.bias, 0)
|
447 |
+
nn.init.constant_(blk.norm1.weight, 0)
|
448 |
+
nn.init.constant_(blk.norm2.bias, 0)
|
449 |
+
nn.init.constant_(blk.norm2.weight, 0)
|
450 |
+
|
451 |
+
class PatchEmbed(nn.Module):
|
452 |
+
r""" Image to Patch Embedding
|
453 |
+
Args:
|
454 |
+
img_size (int): Image size. Default: 224.
|
455 |
+
patch_size (int): Patch token size. Default: 4.
|
456 |
+
in_chans (int): Number of input image channels. Default: 3.
|
457 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
458 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
459 |
+
"""
|
460 |
+
|
461 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
462 |
+
super().__init__()
|
463 |
+
img_size = to_2tuple(img_size)
|
464 |
+
patch_size = to_2tuple(patch_size)
|
465 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
466 |
+
self.img_size = img_size
|
467 |
+
self.patch_size = patch_size
|
468 |
+
self.patches_resolution = patches_resolution
|
469 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
470 |
+
|
471 |
+
self.in_chans = in_chans
|
472 |
+
self.embed_dim = embed_dim
|
473 |
+
|
474 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
475 |
+
if norm_layer is not None:
|
476 |
+
self.norm = norm_layer(embed_dim)
|
477 |
+
else:
|
478 |
+
self.norm = None
|
479 |
+
|
480 |
+
def forward(self, x):
|
481 |
+
B, C, H, W = x.shape
|
482 |
+
# FIXME look at relaxing size constraints
|
483 |
+
# assert H == self.img_size[0] and W == self.img_size[1],
|
484 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
485 |
+
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
|
486 |
+
if self.norm is not None:
|
487 |
+
x = self.norm(x)
|
488 |
+
return x
|
489 |
+
|
490 |
+
def flops(self):
|
491 |
+
Ho, Wo = self.patches_resolution
|
492 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
493 |
+
if self.norm is not None:
|
494 |
+
flops += Ho * Wo * self.embed_dim
|
495 |
+
return flops
|
496 |
+
|
497 |
+
class RSTB(nn.Module):
|
498 |
+
"""Residual Swin Transformer Block (RSTB).
|
499 |
+
|
500 |
+
Args:
|
501 |
+
dim (int): Number of input channels.
|
502 |
+
input_resolution (tuple[int]): Input resolution.
|
503 |
+
depth (int): Number of blocks.
|
504 |
+
num_heads (int): Number of attention heads.
|
505 |
+
window_size (int): Local window size.
|
506 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
507 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
508 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
509 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
510 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
511 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
512 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
513 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
514 |
+
img_size: Input image size.
|
515 |
+
patch_size: Patch size.
|
516 |
+
resi_connection: The convolutional block before residual connection.
|
517 |
+
"""
|
518 |
+
|
519 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
520 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
521 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
522 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
523 |
+
super(RSTB, self).__init__()
|
524 |
+
|
525 |
+
self.dim = dim
|
526 |
+
self.input_resolution = input_resolution
|
527 |
+
|
528 |
+
self.residual_group = BasicLayer(dim=dim,
|
529 |
+
input_resolution=input_resolution,
|
530 |
+
depth=depth,
|
531 |
+
num_heads=num_heads,
|
532 |
+
window_size=window_size,
|
533 |
+
mlp_ratio=mlp_ratio,
|
534 |
+
qkv_bias=qkv_bias,
|
535 |
+
drop=drop, attn_drop=attn_drop,
|
536 |
+
drop_path=drop_path,
|
537 |
+
norm_layer=norm_layer,
|
538 |
+
downsample=downsample,
|
539 |
+
use_checkpoint=use_checkpoint)
|
540 |
+
|
541 |
+
if resi_connection == '1conv':
|
542 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
543 |
+
elif resi_connection == '3conv':
|
544 |
+
# to save parameters and memory
|
545 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
546 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
547 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
548 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
549 |
+
|
550 |
+
self.patch_embed = PatchEmbed(
|
551 |
+
img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
|
552 |
+
norm_layer=None)
|
553 |
+
|
554 |
+
self.patch_unembed = PatchUnEmbed(
|
555 |
+
img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
|
556 |
+
norm_layer=None)
|
557 |
+
|
558 |
+
def forward(self, x, x_size):
|
559 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
560 |
+
|
561 |
+
def flops(self):
|
562 |
+
flops = 0
|
563 |
+
flops += self.residual_group.flops()
|
564 |
+
H, W = self.input_resolution
|
565 |
+
flops += H * W * self.dim * self.dim * 9
|
566 |
+
flops += self.patch_embed.flops()
|
567 |
+
flops += self.patch_unembed.flops()
|
568 |
+
|
569 |
+
return flops
|
570 |
+
|
571 |
+
class PatchUnEmbed(nn.Module):
|
572 |
+
r""" Image to Patch Unembedding
|
573 |
+
|
574 |
+
Args:
|
575 |
+
img_size (int): Image size. Default: 224.
|
576 |
+
patch_size (int): Patch token size. Default: 4.
|
577 |
+
in_chans (int): Number of input image channels. Default: 3.
|
578 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
579 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
580 |
+
"""
|
581 |
+
|
582 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
583 |
+
super().__init__()
|
584 |
+
img_size = to_2tuple(img_size)
|
585 |
+
patch_size = to_2tuple(patch_size)
|
586 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
587 |
+
self.img_size = img_size
|
588 |
+
self.patch_size = patch_size
|
589 |
+
self.patches_resolution = patches_resolution
|
590 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
591 |
+
|
592 |
+
self.in_chans = in_chans
|
593 |
+
self.embed_dim = embed_dim
|
594 |
+
|
595 |
+
def forward(self, x, x_size):
|
596 |
+
B, HW, C = x.shape
|
597 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
598 |
+
return x
|
599 |
+
|
600 |
+
def flops(self):
|
601 |
+
flops = 0
|
602 |
+
return flops
|
603 |
+
|
604 |
+
|
605 |
+
class Upsample(nn.Sequential):
|
606 |
+
"""Upsample module.
|
607 |
+
|
608 |
+
Args:
|
609 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
610 |
+
num_feat (int): Channel number of intermediate features.
|
611 |
+
"""
|
612 |
+
|
613 |
+
def __init__(self, scale, num_feat):
|
614 |
+
m = []
|
615 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
616 |
+
for _ in range(int(math.log(scale, 2))):
|
617 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
618 |
+
m.append(nn.PixelShuffle(2))
|
619 |
+
elif scale == 3:
|
620 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
621 |
+
m.append(nn.PixelShuffle(3))
|
622 |
+
else:
|
623 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
624 |
+
super(Upsample, self).__init__(*m)
|
625 |
+
|
626 |
+
class Upsample_hf(nn.Sequential):
|
627 |
+
"""Upsample module.
|
628 |
+
|
629 |
+
Args:
|
630 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
631 |
+
num_feat (int): Channel number of intermediate features.
|
632 |
+
"""
|
633 |
+
|
634 |
+
def __init__(self, scale, num_feat):
|
635 |
+
m = []
|
636 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
637 |
+
for _ in range(int(math.log(scale, 2))):
|
638 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
639 |
+
m.append(nn.PixelShuffle(2))
|
640 |
+
elif scale == 3:
|
641 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
642 |
+
m.append(nn.PixelShuffle(3))
|
643 |
+
else:
|
644 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
645 |
+
super(Upsample_hf, self).__init__(*m)
|
646 |
+
|
647 |
+
|
648 |
+
class UpsampleOneStep(nn.Sequential):
|
649 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
650 |
+
Used in lightweight SR to save parameters.
|
651 |
+
|
652 |
+
Args:
|
653 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
654 |
+
num_feat (int): Channel number of intermediate features.
|
655 |
+
|
656 |
+
"""
|
657 |
+
|
658 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
659 |
+
self.num_feat = num_feat
|
660 |
+
self.input_resolution = input_resolution
|
661 |
+
m = []
|
662 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
663 |
+
m.append(nn.PixelShuffle(scale))
|
664 |
+
super(UpsampleOneStep, self).__init__(*m)
|
665 |
+
|
666 |
+
def flops(self):
|
667 |
+
H, W = self.input_resolution
|
668 |
+
flops = H * W * self.num_feat * 3 * 9
|
669 |
+
return flops
|
670 |
+
|
671 |
+
|
672 |
+
|
673 |
+
class Swin2SR(nn.Module):
|
674 |
+
r""" Swin2SR
|
675 |
+
A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
|
676 |
+
|
677 |
+
Args:
|
678 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
679 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
680 |
+
in_chans (int): Number of input image channels. Default: 3
|
681 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
682 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
683 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
684 |
+
window_size (int): Window size. Default: 7
|
685 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
686 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
687 |
+
drop_rate (float): Dropout rate. Default: 0
|
688 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
689 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
690 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
691 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
692 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
693 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
694 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
695 |
+
img_range: Image range. 1. or 255.
|
696 |
+
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
697 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
698 |
+
"""
|
699 |
+
|
700 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
701 |
+
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
702 |
+
window_size=7, mlp_ratio=4., qkv_bias=True,
|
703 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
704 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
705 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
706 |
+
**kwargs):
|
707 |
+
super(Swin2SR, self).__init__()
|
708 |
+
num_in_ch = in_chans
|
709 |
+
num_out_ch = in_chans
|
710 |
+
num_feat = 64
|
711 |
+
self.img_range = img_range
|
712 |
+
if in_chans == 3:
|
713 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
714 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
715 |
+
else:
|
716 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
717 |
+
self.upscale = upscale
|
718 |
+
self.upsampler = upsampler
|
719 |
+
self.window_size = window_size
|
720 |
+
|
721 |
+
#####################################################################################################
|
722 |
+
################################### 1, shallow feature extraction ###################################
|
723 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
724 |
+
|
725 |
+
#####################################################################################################
|
726 |
+
################################### 2, deep feature extraction ######################################
|
727 |
+
self.num_layers = len(depths)
|
728 |
+
self.embed_dim = embed_dim
|
729 |
+
self.ape = ape
|
730 |
+
self.patch_norm = patch_norm
|
731 |
+
self.num_features = embed_dim
|
732 |
+
self.mlp_ratio = mlp_ratio
|
733 |
+
|
734 |
+
# split image into non-overlapping patches
|
735 |
+
self.patch_embed = PatchEmbed(
|
736 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
737 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
738 |
+
num_patches = self.patch_embed.num_patches
|
739 |
+
patches_resolution = self.patch_embed.patches_resolution
|
740 |
+
self.patches_resolution = patches_resolution
|
741 |
+
|
742 |
+
# merge non-overlapping patches into image
|
743 |
+
self.patch_unembed = PatchUnEmbed(
|
744 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
745 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
746 |
+
|
747 |
+
# absolute position embedding
|
748 |
+
if self.ape:
|
749 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
750 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
751 |
+
|
752 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
753 |
+
|
754 |
+
# stochastic depth
|
755 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
756 |
+
|
757 |
+
# build Residual Swin Transformer blocks (RSTB)
|
758 |
+
self.layers = nn.ModuleList()
|
759 |
+
for i_layer in range(self.num_layers):
|
760 |
+
layer = RSTB(dim=embed_dim,
|
761 |
+
input_resolution=(patches_resolution[0],
|
762 |
+
patches_resolution[1]),
|
763 |
+
depth=depths[i_layer],
|
764 |
+
num_heads=num_heads[i_layer],
|
765 |
+
window_size=window_size,
|
766 |
+
mlp_ratio=self.mlp_ratio,
|
767 |
+
qkv_bias=qkv_bias,
|
768 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
769 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
770 |
+
norm_layer=norm_layer,
|
771 |
+
downsample=None,
|
772 |
+
use_checkpoint=use_checkpoint,
|
773 |
+
img_size=img_size,
|
774 |
+
patch_size=patch_size,
|
775 |
+
resi_connection=resi_connection
|
776 |
+
|
777 |
+
)
|
778 |
+
self.layers.append(layer)
|
779 |
+
|
780 |
+
if self.upsampler == 'pixelshuffle_hf':
|
781 |
+
self.layers_hf = nn.ModuleList()
|
782 |
+
for i_layer in range(self.num_layers):
|
783 |
+
layer = RSTB(dim=embed_dim,
|
784 |
+
input_resolution=(patches_resolution[0],
|
785 |
+
patches_resolution[1]),
|
786 |
+
depth=depths[i_layer],
|
787 |
+
num_heads=num_heads[i_layer],
|
788 |
+
window_size=window_size,
|
789 |
+
mlp_ratio=self.mlp_ratio,
|
790 |
+
qkv_bias=qkv_bias,
|
791 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
792 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
793 |
+
norm_layer=norm_layer,
|
794 |
+
downsample=None,
|
795 |
+
use_checkpoint=use_checkpoint,
|
796 |
+
img_size=img_size,
|
797 |
+
patch_size=patch_size,
|
798 |
+
resi_connection=resi_connection
|
799 |
+
|
800 |
+
)
|
801 |
+
self.layers_hf.append(layer)
|
802 |
+
|
803 |
+
self.norm = norm_layer(self.num_features)
|
804 |
+
|
805 |
+
# build the last conv layer in deep feature extraction
|
806 |
+
if resi_connection == '1conv':
|
807 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
808 |
+
elif resi_connection == '3conv':
|
809 |
+
# to save parameters and memory
|
810 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
811 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
812 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
813 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
814 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
815 |
+
|
816 |
+
#####################################################################################################
|
817 |
+
################################ 3, high quality image reconstruction ################################
|
818 |
+
if self.upsampler == 'pixelshuffle':
|
819 |
+
# for classical SR
|
820 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
821 |
+
nn.LeakyReLU(inplace=True))
|
822 |
+
self.upsample = Upsample(upscale, num_feat)
|
823 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
824 |
+
elif self.upsampler == 'pixelshuffle_aux':
|
825 |
+
self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
826 |
+
self.conv_before_upsample = nn.Sequential(
|
827 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
828 |
+
nn.LeakyReLU(inplace=True))
|
829 |
+
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
830 |
+
self.conv_after_aux = nn.Sequential(
|
831 |
+
nn.Conv2d(3, num_feat, 3, 1, 1),
|
832 |
+
nn.LeakyReLU(inplace=True))
|
833 |
+
self.upsample = Upsample(upscale, num_feat)
|
834 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
835 |
+
|
836 |
+
elif self.upsampler == 'pixelshuffle_hf':
|
837 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
838 |
+
nn.LeakyReLU(inplace=True))
|
839 |
+
self.upsample = Upsample(upscale, num_feat)
|
840 |
+
self.upsample_hf = Upsample_hf(upscale, num_feat)
|
841 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
842 |
+
self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
|
843 |
+
nn.LeakyReLU(inplace=True))
|
844 |
+
self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
845 |
+
self.conv_before_upsample_hf = nn.Sequential(
|
846 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
847 |
+
nn.LeakyReLU(inplace=True))
|
848 |
+
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
849 |
+
|
850 |
+
elif self.upsampler == 'pixelshuffledirect':
|
851 |
+
# for lightweight SR (to save parameters)
|
852 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
853 |
+
(patches_resolution[0], patches_resolution[1]))
|
854 |
+
elif self.upsampler == 'nearest+conv':
|
855 |
+
# for real-world SR (less artifacts)
|
856 |
+
assert self.upscale == 4, 'only support x4 now.'
|
857 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
858 |
+
nn.LeakyReLU(inplace=True))
|
859 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
860 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
861 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
862 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
863 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
864 |
+
else:
|
865 |
+
# for image denoising and JPEG compression artifact reduction
|
866 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
867 |
+
|
868 |
+
self.apply(self._init_weights)
|
869 |
+
|
870 |
+
def _init_weights(self, m):
|
871 |
+
if isinstance(m, nn.Linear):
|
872 |
+
trunc_normal_(m.weight, std=.02)
|
873 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
874 |
+
nn.init.constant_(m.bias, 0)
|
875 |
+
elif isinstance(m, nn.LayerNorm):
|
876 |
+
nn.init.constant_(m.bias, 0)
|
877 |
+
nn.init.constant_(m.weight, 1.0)
|
878 |
+
|
879 |
+
@torch.jit.ignore
|
880 |
+
def no_weight_decay(self):
|
881 |
+
return {'absolute_pos_embed'}
|
882 |
+
|
883 |
+
@torch.jit.ignore
|
884 |
+
def no_weight_decay_keywords(self):
|
885 |
+
return {'relative_position_bias_table'}
|
886 |
+
|
887 |
+
def check_image_size(self, x):
|
888 |
+
_, _, h, w = x.size()
|
889 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
890 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
891 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
892 |
+
return x
|
893 |
+
|
894 |
+
def forward_features(self, x):
|
895 |
+
x_size = (x.shape[2], x.shape[3])
|
896 |
+
x = self.patch_embed(x)
|
897 |
+
if self.ape:
|
898 |
+
x = x + self.absolute_pos_embed
|
899 |
+
x = self.pos_drop(x)
|
900 |
+
|
901 |
+
for layer in self.layers:
|
902 |
+
x = layer(x, x_size)
|
903 |
+
|
904 |
+
x = self.norm(x) # B L C
|
905 |
+
x = self.patch_unembed(x, x_size)
|
906 |
+
|
907 |
+
return x
|
908 |
+
|
909 |
+
def forward_features_hf(self, x):
|
910 |
+
x_size = (x.shape[2], x.shape[3])
|
911 |
+
x = self.patch_embed(x)
|
912 |
+
if self.ape:
|
913 |
+
x = x + self.absolute_pos_embed
|
914 |
+
x = self.pos_drop(x)
|
915 |
+
|
916 |
+
for layer in self.layers_hf:
|
917 |
+
x = layer(x, x_size)
|
918 |
+
|
919 |
+
x = self.norm(x) # B L C
|
920 |
+
x = self.patch_unembed(x, x_size)
|
921 |
+
|
922 |
+
return x
|
923 |
+
|
924 |
+
def forward(self, x):
|
925 |
+
H, W = x.shape[2:]
|
926 |
+
x = self.check_image_size(x)
|
927 |
+
|
928 |
+
self.mean = self.mean.type_as(x)
|
929 |
+
x = (x - self.mean) * self.img_range
|
930 |
+
|
931 |
+
if self.upsampler == 'pixelshuffle':
|
932 |
+
# for classical SR
|
933 |
+
x = self.conv_first(x)
|
934 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
935 |
+
x = self.conv_before_upsample(x)
|
936 |
+
x = self.conv_last(self.upsample(x))
|
937 |
+
elif self.upsampler == 'pixelshuffle_aux':
|
938 |
+
bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
|
939 |
+
bicubic = self.conv_bicubic(bicubic)
|
940 |
+
x = self.conv_first(x)
|
941 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
942 |
+
x = self.conv_before_upsample(x)
|
943 |
+
aux = self.conv_aux(x) # b, 3, LR_H, LR_W
|
944 |
+
x = self.conv_after_aux(aux)
|
945 |
+
x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
|
946 |
+
x = self.conv_last(x)
|
947 |
+
aux = aux / self.img_range + self.mean
|
948 |
+
elif self.upsampler == 'pixelshuffle_hf':
|
949 |
+
# for classical SR with HF
|
950 |
+
x = self.conv_first(x)
|
951 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
952 |
+
x_before = self.conv_before_upsample(x)
|
953 |
+
x_out = self.conv_last(self.upsample(x_before))
|
954 |
+
|
955 |
+
x_hf = self.conv_first_hf(x_before)
|
956 |
+
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
957 |
+
x_hf = self.conv_before_upsample_hf(x_hf)
|
958 |
+
x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
|
959 |
+
x = x_out + x_hf
|
960 |
+
x_hf = x_hf / self.img_range + self.mean
|
961 |
+
|
962 |
+
elif self.upsampler == 'pixelshuffledirect':
|
963 |
+
# for lightweight SR
|
964 |
+
x = self.conv_first(x)
|
965 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
966 |
+
x = self.upsample(x)
|
967 |
+
elif self.upsampler == 'nearest+conv':
|
968 |
+
# for real-world SR
|
969 |
+
x = self.conv_first(x)
|
970 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
971 |
+
x = self.conv_before_upsample(x)
|
972 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
973 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
974 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
975 |
+
else:
|
976 |
+
# for image denoising and JPEG compression artifact reduction
|
977 |
+
x_first = self.conv_first(x)
|
978 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
979 |
+
x = x + self.conv_last(res)
|
980 |
+
|
981 |
+
x = x / self.img_range + self.mean
|
982 |
+
if self.upsampler == "pixelshuffle_aux":
|
983 |
+
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
984 |
+
|
985 |
+
elif self.upsampler == "pixelshuffle_hf":
|
986 |
+
x_out = x_out / self.img_range + self.mean
|
987 |
+
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
988 |
+
|
989 |
+
else:
|
990 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
991 |
+
|
992 |
+
def flops(self):
|
993 |
+
flops = 0
|
994 |
+
H, W = self.patches_resolution
|
995 |
+
flops += H * W * 3 * self.embed_dim * 9
|
996 |
+
flops += self.patch_embed.flops()
|
997 |
+
for i, layer in enumerate(self.layers):
|
998 |
+
flops += layer.flops()
|
999 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
1000 |
+
flops += self.upsample.flops()
|
1001 |
+
return flops
|
1002 |
+
|
1003 |
+
|
1004 |
+
if __name__ == '__main__':
|
1005 |
+
upscale = 4
|
1006 |
+
window_size = 8
|
1007 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
1008 |
+
width = (720 // upscale // window_size + 1) * window_size
|
1009 |
+
model = Swin2SR(upscale=2, img_size=(height, width),
|
1010 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
1011 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
1012 |
+
print(model)
|
1013 |
+
print(height, width, model.flops() / 1e9)
|
1014 |
+
|
1015 |
+
x = torch.randn((1, 3, height, width))
|
1016 |
+
x = model(x)
|
1017 |
+
print(x.shape)
|
extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Stable Diffusion WebUI - Bracket checker
|
2 |
+
// Version 1.0
|
3 |
+
// By Hingashi no Florin/Bwin4L
|
4 |
+
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
5 |
+
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
6 |
+
|
7 |
+
function checkBrackets(evt, textArea, counterElt) {
|
8 |
+
errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
|
9 |
+
errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
|
10 |
+
errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
|
11 |
+
|
12 |
+
openBracketRegExp = /\(/g;
|
13 |
+
closeBracketRegExp = /\)/g;
|
14 |
+
|
15 |
+
openSquareBracketRegExp = /\[/g;
|
16 |
+
closeSquareBracketRegExp = /\]/g;
|
17 |
+
|
18 |
+
openCurlyBracketRegExp = /\{/g;
|
19 |
+
closeCurlyBracketRegExp = /\}/g;
|
20 |
+
|
21 |
+
totalOpenBracketMatches = 0;
|
22 |
+
totalCloseBracketMatches = 0;
|
23 |
+
totalOpenSquareBracketMatches = 0;
|
24 |
+
totalCloseSquareBracketMatches = 0;
|
25 |
+
totalOpenCurlyBracketMatches = 0;
|
26 |
+
totalCloseCurlyBracketMatches = 0;
|
27 |
+
|
28 |
+
openBracketMatches = textArea.value.match(openBracketRegExp);
|
29 |
+
if(openBracketMatches) {
|
30 |
+
totalOpenBracketMatches = openBracketMatches.length;
|
31 |
+
}
|
32 |
+
|
33 |
+
closeBracketMatches = textArea.value.match(closeBracketRegExp);
|
34 |
+
if(closeBracketMatches) {
|
35 |
+
totalCloseBracketMatches = closeBracketMatches.length;
|
36 |
+
}
|
37 |
+
|
38 |
+
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
|
39 |
+
if(openSquareBracketMatches) {
|
40 |
+
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
|
41 |
+
}
|
42 |
+
|
43 |
+
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
|
44 |
+
if(closeSquareBracketMatches) {
|
45 |
+
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
|
46 |
+
}
|
47 |
+
|
48 |
+
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
|
49 |
+
if(openCurlyBracketMatches) {
|
50 |
+
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
|
51 |
+
}
|
52 |
+
|
53 |
+
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
|
54 |
+
if(closeCurlyBracketMatches) {
|
55 |
+
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
|
56 |
+
}
|
57 |
+
|
58 |
+
if(totalOpenBracketMatches != totalCloseBracketMatches) {
|
59 |
+
if(!counterElt.title.includes(errorStringParen)) {
|
60 |
+
counterElt.title += errorStringParen;
|
61 |
+
}
|
62 |
+
} else {
|
63 |
+
counterElt.title = counterElt.title.replace(errorStringParen, '');
|
64 |
+
}
|
65 |
+
|
66 |
+
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
|
67 |
+
if(!counterElt.title.includes(errorStringSquare)) {
|
68 |
+
counterElt.title += errorStringSquare;
|
69 |
+
}
|
70 |
+
} else {
|
71 |
+
counterElt.title = counterElt.title.replace(errorStringSquare, '');
|
72 |
+
}
|
73 |
+
|
74 |
+
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
|
75 |
+
if(!counterElt.title.includes(errorStringCurly)) {
|
76 |
+
counterElt.title += errorStringCurly;
|
77 |
+
}
|
78 |
+
} else {
|
79 |
+
counterElt.title = counterElt.title.replace(errorStringCurly, '');
|
80 |
+
}
|
81 |
+
|
82 |
+
if(counterElt.title != '') {
|
83 |
+
counterElt.classList.add('error');
|
84 |
+
} else {
|
85 |
+
counterElt.classList.remove('error');
|
86 |
+
}
|
87 |
+
}
|
88 |
+
|
89 |
+
function setupBracketChecking(id_prompt, id_counter){
|
90 |
+
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
91 |
+
var counter = gradioApp().getElementById(id_counter)
|
92 |
+
textarea.addEventListener("input", function(evt){
|
93 |
+
checkBrackets(evt, textarea, counter)
|
94 |
+
});
|
95 |
+
}
|
96 |
+
|
97 |
+
var shadowRootLoaded = setInterval(function() {
|
98 |
+
var shadowRoot = document.querySelector('gradio-app').shadowRoot;
|
99 |
+
if(! shadowRoot) return false;
|
100 |
+
|
101 |
+
var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
102 |
+
if(shadowTextArea.length < 1) return false;
|
103 |
+
|
104 |
+
clearInterval(shadowRootLoaded);
|
105 |
+
|
106 |
+
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
|
107 |
+
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
|
108 |
+
setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
|
109 |
+
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
|
110 |
+
}, 1000);
|
extensions/gif2gif/README.md
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# gif2gif
|
2 |
+
Automatic1111 Stable Diffusion WebUI GIF Extension
|
3 |
+
|
4 |
+
### gif2gif script extension
|
5 |
+
|
6 |
+
The purpose of this script is to accept an animated gif as input, process frames as img2img typically would, and recombine them back into an animated gif. Intended to provide a fun, fast, gif-to-gif workflow that supports new models and methods such as Controlnet and InstructPix2Pix. Drop in a gif and go. Referenced code from prompts_from_file.
|
7 |
+
|
8 |
+
![combined](https://user-images.githubusercontent.com/93007558/224235828-f4d0be70-67da-41fc-b225-558576b4b5d4.gif)
|
9 |
+
|
10 |
+
Experimental/WIP similar repos:
|
11 |
+
- [keyframer](https://github.com/LonicaMewinsky/sd-webui-keyframer) - Multiple images in same latent space. Good for keyframes.
|
12 |
+
- [frame2frame](https://github.com/LonicaMewinsky/frame2frame) - Handles video files (and gifs).
|
13 |
+
|
14 |
+
**Instructions:**
|
15 |
+
- For ControlNet support, make sure to enable "Allow other script to control this extension" in settings.
|
16 |
+
- img2img batch *count* represents completed GIFs, not individual images.
|
17 |
+
- All images in a single batch will be blended together. May help with consistency between frames.
|
18 |
+
- Drop or select gif in the script's box; a preview should appear if it is a valid animated gif.
|
19 |
+
- Inpainting works, but currently limited to one mask applied to all frames equally.
|
20 |
+
- Optionally blend all frames together for more predictable inpaint coverage.
|
21 |
+
- Adjust desired FPS if needed/wanted. Default slider position is original FPS.
|
22 |
+
- Add interpolation frames if wanted. Preview should render.
|
23 |
+
- Count of interp frames represent the number of blend steps between keyframes.
|
24 |
+
- This is a very simple dynamic interp function; the keyframes are left as-is.
|
25 |
+
- When *actual FPS* reaches 50, the maximum, the resultant gif will slow and extend to accomodate interp.
|
26 |
+
- Results are dropped into outputs/img2img/gif2gif, and displayed in output gallery on right side
|
27 |
+
- [ControlNet](https://github.com/Mikubill/sd-webui-controlnet) extension handling improved:
|
28 |
+
- Script will no longer overwrite existing ControlNet input images.
|
29 |
+
- Script will only target ControlNet models with no input image specified.
|
30 |
+
- Allows, for example, a static depth background while animation feeds openpose.
|
31 |
+
|
32 |
+
![ControlNetInst](https://user-images.githubusercontent.com/93007558/224233623-88abcf87-3e01-4bf3-8209-6ee691b1f749.jpg)
|
33 |
+
|
34 |
+
**Tips:**
|
35 |
+
- Configure and process the gif in img2img (it'll use the first frame) before running the script. Find a good seed!
|
36 |
+
- If you add an image into ControlNet image window, it will default to that image for guidance for ALL frames.
|
37 |
+
- Interpolation is not always necessary nor helpful.
|
38 |
+
|
39 |
+
**Installation:**
|
40 |
+
- Install from the Automatic1111 WebUI extensions list, restart UI or
|
41 |
+
- Clone this repo into your Automatic1111 WebUI /extensions folder, restart UI
|
42 |
+
|
43 |
+
**Changelog:**
|
44 |
+
- 3/09/23: ControlNet extension handling completely re-worked.
|
45 |
+
- 3/06/23: GIFs are now sent to results gallery(!) and "re-use last seed" works more reliably.
|
46 |
+
- 3/03/23: Added support for embedding generation into into GIF.
|
47 |
+
- 3/03/23: Blended inpainting picture had major performance issues on some systems; made optional.
|
extensions/gif2gif/instructions.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# gif2gif
|
2 |
+
|
3 |
+
The purpose of this script is to accept an animated gif as input, process frames as img2img typically would, and recombine them back into an animated gif. Not intended to have extensive functionality. Referenced code from prompts_from_file.
|
4 |
+
|
5 |
+
**Instructions:**
|
6 |
+
- img2img batch *count* represents completed GIFs, not individual images.
|
7 |
+
- All images in a single batch will be blended together
|
8 |
+
- Drop or select gif in the script's box; a preview should appear if it is a valid animated gif.
|
9 |
+
- Inpainting works, but currently limited to one mask applied to all frames equally.
|
10 |
+
- Adjust desired FPS if needed/wanted. Default slider position is original FPS.
|
11 |
+
- Add interpolation frames if wanted. Preview should render.
|
12 |
+
- Count of interp frames represent the number of blend steps between keyframes.
|
13 |
+
- This is a very simple dynamic interp function; the keyframes are left as-is.
|
14 |
+
- When *actual FPS* reaches 50, the maximum, the resultant gif will slow and extend to accomodate interp.
|
15 |
+
- Results are dropped into outputs/img2img/gif2gif.
|
16 |
+
|
17 |
+
**Tips:**
|
18 |
+
- Configure and process the gif in img2img (it'll use the first frame) before running the script. Find a good seed!
|
19 |
+
- Interpolation is not always necessary nor helpful.
|
extensions/gif2gif/javascript/gif2gif_hints.js
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// mouseover tooltips for various UI elements in the form of "UI element label"="Tooltip text".
|
2 |
+
|
3 |
+
gif2gif_titles = {
|
4 |
+
"Upload GIF": "Click here to upload your GIF",
|
5 |
+
"Desired FPS": "Target FPS; defaults to original FPS",
|
6 |
+
"Interpolation frames": "Number of transition frames between key frames",
|
7 |
+
"Loopback decay": "Factor change for every loop generation. <1 for noise falloff, >1 for noise rampup"
|
8 |
+
}
|
9 |
+
|
10 |
+
|
11 |
+
onUiUpdate(function(){
|
12 |
+
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
|
13 |
+
tooltip = gif2gif_titles[span.textContent];
|
14 |
+
|
15 |
+
if(!tooltip){
|
16 |
+
tooltip = gif2gif_titles[span.value];
|
17 |
+
}
|
18 |
+
|
19 |
+
if(!tooltip){
|
20 |
+
for (const c of span.classList) {
|
21 |
+
if (c in gif2gif_titles) {
|
22 |
+
tooltip = gif2gif_titles[c];
|
23 |
+
break;
|
24 |
+
}
|
25 |
+
}
|
26 |
+
}
|
27 |
+
|
28 |
+
if(tooltip){
|
29 |
+
span.title = tooltip;
|
30 |
+
}
|
31 |
+
})
|
32 |
+
|
33 |
+
gradioApp().querySelectorAll('select').forEach(function(select){
|
34 |
+
if (select.onchange != null) return;
|
35 |
+
|
36 |
+
select.onchange = function(){
|
37 |
+
select.title = gif2gif_titles[select.value] || "";
|
38 |
+
}
|
39 |
+
})
|
40 |
+
})
|
extensions/gif2gif/scripts/__pycache__/gif2gif.cpython-310.pyc
ADDED
Binary file (13.6 kB). View file
|
|
extensions/gif2gif/scripts/gif2gif.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
import modules.scripts as scripts
|
4 |
+
import modules.images
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import tempfile
|
8 |
+
import importlib
|
9 |
+
from PIL import Image, ImageSequence
|
10 |
+
from modules.processing import Processed, process_images
|
11 |
+
from modules.shared import opts, state, sd_upscalers
|
12 |
+
|
13 |
+
with open(os.path.join(scripts.basedir(), "instructions.txt"), 'r') as file:
|
14 |
+
mkd_inst = file.read()
|
15 |
+
|
16 |
+
#Rudimentary interpolation
|
17 |
+
def interp(gif, iframes, dur):
|
18 |
+
try:
|
19 |
+
working_images, resframes = [], []
|
20 |
+
pilgif = Image.open(gif)
|
21 |
+
for frame in ImageSequence.Iterator(pilgif):
|
22 |
+
converted = frame.convert('RGBA')
|
23 |
+
working_images.append(converted)
|
24 |
+
resframes.append(working_images[0]) #Seed the first frame
|
25 |
+
alphas = np.linspace(0, 1, iframes+2)[1:]
|
26 |
+
for i in range(1, len(working_images), 1):
|
27 |
+
for a in range(len(alphas)):
|
28 |
+
intermediate_image = Image.blend(working_images[i-1],working_images[i],alphas[a])
|
29 |
+
resframes.append(intermediate_image)
|
30 |
+
resframes[0].save(gif,
|
31 |
+
save_all = True, append_images = resframes[1:], loop = 0,
|
32 |
+
optimize = False, duration = dur, format='GIF')
|
33 |
+
return gif
|
34 |
+
except:
|
35 |
+
return False
|
36 |
+
|
37 |
+
#Get num closest to 8
|
38 |
+
def cl8(num):
|
39 |
+
rem = num % 8
|
40 |
+
if rem <= 4:
|
41 |
+
return round(num - rem)
|
42 |
+
else:
|
43 |
+
return round(num + (8 - rem))
|
44 |
+
|
45 |
+
def upscale(image, upscaler_name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
|
46 |
+
if upscale_mode == 1:
|
47 |
+
upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
|
48 |
+
|
49 |
+
upscaler = next(iter([x for x in sd_upscalers if x.name == upscaler_name]), None)
|
50 |
+
assert upscaler or (upscaler_name is None), f'could not find upscaler named {upscaler_name}'
|
51 |
+
|
52 |
+
image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path)
|
53 |
+
if upscale_mode == 1 and upscale_crop:
|
54 |
+
cropped = Image.new("RGB", (upscale_to_width, upscale_to_height))
|
55 |
+
cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2))
|
56 |
+
image = cropped
|
57 |
+
|
58 |
+
return image
|
59 |
+
|
60 |
+
def blend_images(images):
|
61 |
+
sizes = [img.size for img in images]
|
62 |
+
min_width, min_height = min(sizes, key=lambda s: s[0]*s[1])
|
63 |
+
blended_img = Image.new('RGB', (min_width, min_height))
|
64 |
+
|
65 |
+
for x in range(min_width):
|
66 |
+
for y in range(min_height):
|
67 |
+
colors = [img.getpixel((x, y)) for img in images]
|
68 |
+
avg_color = tuple(int(sum(c[i] for c in colors) / len(colors)) for i in range(3))
|
69 |
+
blended_img.putpixel((x, y), avg_color)
|
70 |
+
|
71 |
+
return blended_img
|
72 |
+
|
73 |
+
class Script(scripts.Script):
|
74 |
+
def __init__(self):
|
75 |
+
self.gif_name = str()
|
76 |
+
self.gif_frames = []
|
77 |
+
self.orig_fps = 0
|
78 |
+
self.orig_duration = 0
|
79 |
+
self.orig_total_seconds = 0
|
80 |
+
self.orig_n_frames = 0
|
81 |
+
self.orig_dimensions = (0,0)
|
82 |
+
self.ready = False
|
83 |
+
self.desired_fps = 0
|
84 |
+
self.desired_interp = 0
|
85 |
+
self.desired_duration = 0
|
86 |
+
self.desired_total_seconds = 0
|
87 |
+
self.slowmo = False
|
88 |
+
self.gif2gifdir = tempfile.TemporaryDirectory()
|
89 |
+
self.img2img_component = gr.Image()
|
90 |
+
self.img2img_inpaint_component = gr.Image()
|
91 |
+
self.img2img_gallery = gr.Gallery()
|
92 |
+
self.img2img_w_slider = gr.Slider()
|
93 |
+
self.img2img_h_slider = gr.Slider()
|
94 |
+
return None
|
95 |
+
|
96 |
+
def title(self):
|
97 |
+
return "gif2gif"
|
98 |
+
|
99 |
+
def show(self, is_img2img):
|
100 |
+
return is_img2img
|
101 |
+
|
102 |
+
def ui(self, is_img2img):
|
103 |
+
#Controls
|
104 |
+
with gr.Column():
|
105 |
+
upload_gif = gr.File(label="Upload GIF", visible=True, file_types = ['.gif','.webp','.plc'], file_count = "single")
|
106 |
+
with gr.Tabs():
|
107 |
+
with gr.Tab("Settings"):
|
108 |
+
with gr.Column():
|
109 |
+
with gr.Row():
|
110 |
+
with gr.Column():
|
111 |
+
with gr.Box():
|
112 |
+
fps_slider = gr.Slider(1, 50, step = 1, label = "Desired FPS", elem_id="harbl")
|
113 |
+
interp_slider = gr.Slider(label = "Interpolation frames", value = 0)
|
114 |
+
gif_resize = gr.Checkbox(value = True, label="Resize result back to original dimensions")
|
115 |
+
gif_clear_frames = gr.Checkbox(value = True, label="Delete intermediate frames after GIF generation")
|
116 |
+
gif_common_seed = gr.Checkbox(value = True, label="For -1 seed, all frames in a GIF have common seed")
|
117 |
+
with gr.Column():
|
118 |
+
with gr.Row():
|
119 |
+
with gr.Box():
|
120 |
+
with gr.Column():
|
121 |
+
fps_actual = gr.Textbox(value="", interactive = False, label = "Actual FPS")
|
122 |
+
seconds_actual = gr.Textbox(value="", interactive = False, label = "Actual total duration")
|
123 |
+
frames_actual = gr.Textbox(value="", interactive = False, label = "Actual total frames")
|
124 |
+
with gr.Box():
|
125 |
+
with gr.Column():
|
126 |
+
fps_original = gr.Textbox(value="", interactive = False, label = "Original FPS")
|
127 |
+
seconds_original = gr.Textbox(value="", interactive = False, label = "Original total duration")
|
128 |
+
frames_original = gr.Textbox(value="", interactive = False, label = "Original total frames")
|
129 |
+
with gr.Tab("Loopback"):
|
130 |
+
loop_backs = gr.Slider(0, 50, step = 1, label = "Generation loopbacks", value = 0)
|
131 |
+
loop_denoise = gr.Slider(0.01, 1, step = 0.01, value=0.10, interactive = True, label = "Loopback denoise strength")
|
132 |
+
loop_decay = gr.Slider(0, 2, step = 0.05, value=1.0, interactive = True, label = "Loopback decay")
|
133 |
+
with gr.Tab("Upscaling"):
|
134 |
+
with gr.Row():
|
135 |
+
with gr.Column():
|
136 |
+
with gr.Box():
|
137 |
+
ups_upscaler = gr.Dropdown(value = "None", interactive = True, choices = [x.name for x in sd_upscalers], label = "Upscaler")
|
138 |
+
ups_only_upscale = gr.Checkbox(value = False, label = "Skip generation, only upscale")
|
139 |
+
with gr.Column():
|
140 |
+
with gr.Tabs():
|
141 |
+
with gr.Tab("Scale by") as tab_scale_by:
|
142 |
+
with gr.Box():
|
143 |
+
ups_scale_by = gr.Slider(1, 8, step = 0.1, value=2, interactive = True, label = "Factor")
|
144 |
+
with gr.Tab("Scale to") as tab_scale_to:
|
145 |
+
with gr.Box():
|
146 |
+
with gr.Column():
|
147 |
+
ups_scale_to_w = gr.Slider(0, 8000, step = 8, value=512, interactive = True, label = "Target width")
|
148 |
+
ups_scale_to_h = gr.Slider(0, 8000, step = 8, value=512, interactive = True, label = "Target height")
|
149 |
+
ups_scale_to_crop = gr.Checkbox(value = False, label = "Crop to fit")
|
150 |
+
with gr.Tab("Inpainting", open = False):
|
151 |
+
with gr.Column():
|
152 |
+
make_blend = gr.Button("Send blended image to img2img Inpainting tab")
|
153 |
+
with gr.Tab("Readme", open = False):
|
154 |
+
gr.Markdown(mkd_inst)
|
155 |
+
with gr.Column():
|
156 |
+
display_gif = gr.Image(label = "Preview GIF", Source="Upload", visible=False, interactive=True, type="filepath")
|
157 |
+
|
158 |
+
def processgif(file):
|
159 |
+
try:
|
160 |
+
pimg = ImageSequence.Iterator(Image.open(file.name))[0]
|
161 |
+
except:
|
162 |
+
print("Could not load GIF.") #Make no changes
|
163 |
+
return gr.Image.update(), gr.Image.update(), gr.Image.update(), gr.Slider.update(), gr.Textbox.update(), gr.Textbox.update(), gr.Textbox.update()
|
164 |
+
init_gif = Image.open(file.name)
|
165 |
+
self.gif_name = file.name
|
166 |
+
self.orig_dimensions = init_gif.size
|
167 |
+
self.orig_duration = init_gif.info["duration"]
|
168 |
+
self.orig_n_frames = init_gif.n_frames
|
169 |
+
self.orig_total_seconds = round((self.orig_duration * self.orig_n_frames)/1000, 2)
|
170 |
+
self.orig_fps = round(1000 / int(init_gif.info["duration"]), 2)
|
171 |
+
#Need to also put images in img2img/inpainting windows (ui will not run without)
|
172 |
+
#Gradio painting tools act weird with smaller images.. resize to 480 if smaller
|
173 |
+
self.gif_frames = []
|
174 |
+
for frame in ImageSequence.Iterator(init_gif):
|
175 |
+
if frame.height < 480:
|
176 |
+
converted = frame.resize((round(480*frame.width/frame.height), 480), Image.Resampling.LANCZOS).convert('RGBA')
|
177 |
+
else:
|
178 |
+
converted = frame.convert('RGBA')
|
179 |
+
self.gif_frames.append(converted)
|
180 |
+
self.ready = True
|
181 |
+
self.img2img_gallery.update([file.name])
|
182 |
+
return self.gif_frames[0], self.gif_frames[0], cl8(pimg.width), cl8(pimg.height), gr.File.update(visible=False), gr.Image.update(value=file.name, visible=True), self.orig_fps, self.orig_fps, (f"{self.orig_total_seconds} seconds"), self.orig_n_frames
|
183 |
+
|
184 |
+
def clear_gif(gif):
|
185 |
+
if gif == None:
|
186 |
+
return None, None, gr.File.update(value=None, visible=True), gr.Image.update(visible=False)
|
187 |
+
else:
|
188 |
+
return gr.Image.update(), gr.Image.update(), gr.File.update(), gr.Image.update()
|
189 |
+
|
190 |
+
def fpsupdate(fps, interp_frames):
|
191 |
+
if (self.ready and fps and (interp_frames != None)):
|
192 |
+
self.desired_fps = fps
|
193 |
+
self.desired_interp = interp_frames
|
194 |
+
total_n_frames = self.orig_n_frames + ((self.orig_n_frames -1) * self.desired_interp)
|
195 |
+
calcdur = (1000 / fps) / (total_n_frames/self.orig_n_frames)
|
196 |
+
if calcdur < 20:
|
197 |
+
calcdur = 20
|
198 |
+
self.slowmo = True
|
199 |
+
self.desired_duration = calcdur
|
200 |
+
self.desired_total_seconds = round((self.desired_duration * total_n_frames)/1000, 2)
|
201 |
+
gifbuffer = (f"{self.gif2gifdir.name}/previewgif.gif")
|
202 |
+
self.gif_frames[0].save(gifbuffer,
|
203 |
+
save_all = True, append_images = self.gif_frames[1:], loop = 0,
|
204 |
+
optimize = False, duration = self.desired_duration)
|
205 |
+
return gifbuffer, round(1000/self.desired_duration, 2), f"{self.desired_total_seconds} seconds", total_n_frames
|
206 |
+
|
207 |
+
def send_blend():
|
208 |
+
if self.gif_frames == None:
|
209 |
+
print("No loaded; cannot blend")
|
210 |
+
return gr.Image.update()
|
211 |
+
blend = blend_images(self.gif_frames)
|
212 |
+
return blend
|
213 |
+
|
214 |
+
#Control change events
|
215 |
+
fps_slider.change(fn=fpsupdate, inputs = [fps_slider, interp_slider], outputs = [display_gif, fps_actual, seconds_actual, frames_actual])
|
216 |
+
interp_slider.change(fn=fpsupdate, inputs = [fps_slider, interp_slider], outputs = [display_gif, fps_actual, seconds_actual, frames_actual])
|
217 |
+
ups_scale_mode = gr.State(value = 0)
|
218 |
+
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[ups_scale_mode])
|
219 |
+
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[ups_scale_mode])
|
220 |
+
upload_gif.upload(fn=processgif, inputs = upload_gif, outputs = [self.img2img_component, self.img2img_inpaint_component, self.img2img_w_slider, self.img2img_h_slider, upload_gif, display_gif, fps_slider, fps_original, seconds_original, frames_original])
|
221 |
+
display_gif.change(fn=clear_gif, inputs=display_gif, outputs=[self.img2img_component, self.img2img_inpaint_component, upload_gif, display_gif])
|
222 |
+
make_blend.click(fn=send_blend, inputs=None, outputs=[self.img2img_inpaint_component])
|
223 |
+
|
224 |
+
return [gif_resize, gif_clear_frames, gif_common_seed, loop_backs, loop_denoise, loop_decay, ups_upscaler, ups_only_upscale, ups_scale_mode, ups_scale_by, ups_scale_to_w, ups_scale_to_h, ups_scale_to_crop]
|
225 |
+
|
226 |
+
#Grab the img2img image components for update later
|
227 |
+
#Maybe there's a better way to do this?
|
228 |
+
def after_component(self, component, **kwargs):
|
229 |
+
if component.elem_id == "img2img_image":
|
230 |
+
self.img2img_component = component
|
231 |
+
return self.img2img_component
|
232 |
+
if component.elem_id == "img2maskimg":
|
233 |
+
self.img2img_inpaint_component = component
|
234 |
+
return self.img2img_inpaint_component
|
235 |
+
if component.elem_id == "img2img_width":
|
236 |
+
self.img2img_w_slider = component
|
237 |
+
return self.img2img_w_slider
|
238 |
+
if component.elem_id == "img2img_height":
|
239 |
+
self.img2img_h_slider = component
|
240 |
+
return self.img2img_h_slider
|
241 |
+
|
242 |
+
#Main run
|
243 |
+
def run(self, p, gif_resize, gif_clear_frames, gif_common_seed, loop_backs, loop_denoise, loop_decay, ups_upscaler, ups_only_upscale, ups_scale_mode, ups_scale_by, ups_scale_to_w, ups_scale_to_h, ups_scale_to_crop):
|
244 |
+
cnet_present = False
|
245 |
+
try:
|
246 |
+
cnet = importlib.import_module('extensions.sd-webui-controlnet.scripts.external_code', 'external_code')
|
247 |
+
cn_layers = cnet.get_all_units_in_processing(p)
|
248 |
+
target_layer_indices = []
|
249 |
+
for i in range(len(cn_layers)):
|
250 |
+
if (cn_layers[i].image == None) and (cn_layers[i].enabled == True):
|
251 |
+
target_layer_indices.append(i)
|
252 |
+
if len(target_layer_indices) >0:
|
253 |
+
cnet_present = True
|
254 |
+
except:
|
255 |
+
pass
|
256 |
+
orig_p = copy.copy(p)
|
257 |
+
try:
|
258 |
+
inc_frames = self.gif_frames
|
259 |
+
except:
|
260 |
+
print("Something went wrong with GIF. Processing still from img2img.")
|
261 |
+
proc = process_images(p)
|
262 |
+
return proc
|
263 |
+
outpath = os.path.join(p.outpath_samples, "gif2gif")
|
264 |
+
|
265 |
+
#Handle upscaling
|
266 |
+
if (ups_upscaler != "None"):
|
267 |
+
inc_frames = [upscale(frame, ups_upscaler, ups_scale_mode, ups_scale_by, ups_scale_to_w, ups_scale_to_h, ups_scale_to_crop) for frame in inc_frames]
|
268 |
+
if ups_only_upscale:
|
269 |
+
gif_filename = (modules.images.save_image(inc_frames[0], outpath, "gif2gif", extension = 'gif')[0])
|
270 |
+
print(f"gif2gif: Generating GIF to {gif_filename}..")
|
271 |
+
inc_frames[0].save(gif_filename,
|
272 |
+
save_all = True, append_images = inc_frames[1:], loop = 0,
|
273 |
+
optimize = False, duration = self.desired_duration)
|
274 |
+
return Processed(p, inc_frames)
|
275 |
+
|
276 |
+
#Fix/setup vars
|
277 |
+
return_images, all_prompts, infotexts, inter_images = [], [], [], []
|
278 |
+
state.job_count = self.orig_n_frames * p.n_iter * (loop_backs+1)
|
279 |
+
p.do_not_save_grid = True
|
280 |
+
p.do_not_save_samples = gif_clear_frames
|
281 |
+
gif_n_iter = p.n_iter
|
282 |
+
p.n_iter = 1
|
283 |
+
|
284 |
+
#Iterate batch count
|
285 |
+
print(f"Will process {gif_n_iter} GIF(s) with {state.job_count * p.batch_size} total generations.")
|
286 |
+
for x in range(gif_n_iter):
|
287 |
+
if state.skipped: state.skipped = False
|
288 |
+
if state.interrupted: break
|
289 |
+
color_correction = [modules.processing.setup_color_correction(p.init_images[0])]
|
290 |
+
if(gif_common_seed and (p.seed == -1)):
|
291 |
+
modules.processing.fix_seed(p)
|
292 |
+
|
293 |
+
#Iterate frames
|
294 |
+
for frame in inc_frames:
|
295 |
+
if state.skipped: state.skipped = False
|
296 |
+
if state.interrupted: break
|
297 |
+
p.denoising_strength = orig_p.denoising_strength #reset denoise
|
298 |
+
frame_loop_denoise = loop_denoise
|
299 |
+
state.job = f"{state.job_no + 1} out of {state.job_count}"
|
300 |
+
p.init_images = [frame] * p.batch_size #inject current frame
|
301 |
+
#Handle controlnets
|
302 |
+
if cnet_present:
|
303 |
+
new_layers = []
|
304 |
+
for i in range(len(cn_layers)):
|
305 |
+
if i in target_layer_indices:
|
306 |
+
nimg = np.array(frame.convert("RGB"))
|
307 |
+
bimg = np.zeros((frame.width, frame.height, 3), dtype = np.uint8)
|
308 |
+
cn_layers[i].image = [{"image" : nimg, "mask" : bimg}]
|
309 |
+
new_layers.append(cn_layers[i])
|
310 |
+
cnet.update_cn_script_in_processing(p, new_layers)
|
311 |
+
#Process
|
312 |
+
|
313 |
+
proc = process_images(p) #process
|
314 |
+
#Do loopbacks
|
315 |
+
for _ in range(loop_backs):
|
316 |
+
p.init_images = [proc.images[0].convert("RGB")] * p.batch_size
|
317 |
+
p.color_corrections = color_correction
|
318 |
+
p.denoising_strength = frame_loop_denoise
|
319 |
+
proc = process_images(p)
|
320 |
+
frame_loop_denoise = round(frame_loop_denoise*loop_decay, 2)
|
321 |
+
#Handle batches
|
322 |
+
proc_batch = []
|
323 |
+
for pi in proc.images:
|
324 |
+
if type(pi) is Image.Image:
|
325 |
+
proc_batch.append(pi)
|
326 |
+
if len(proc_batch) > 1 and p.batch_size > 1:
|
327 |
+
inter_images.append(blend_images(proc_batch))
|
328 |
+
else:
|
329 |
+
inter_images.append(proc_batch[0])
|
330 |
+
all_prompts += proc.all_prompts
|
331 |
+
infotexts += proc.infotexts
|
332 |
+
#Resize and make gif
|
333 |
+
if(gif_resize):
|
334 |
+
for i in range(len(inter_images)):
|
335 |
+
inter_images[i] = inter_images[i].resize(self.orig_dimensions)
|
336 |
+
#First make temporary file via save_images, then save actual gif over it. First index returns path.
|
337 |
+
gif_filename = (modules.images.save_image(inc_frames[0], outpath, "gif2gif", extension = 'gif')[0])
|
338 |
+
#Handle infotext embedding
|
339 |
+
gif_info=""
|
340 |
+
if opts.enable_pnginfo and infotexts[0] is not None:
|
341 |
+
gif_info = infotexts[0].replace('\n', ', ')
|
342 |
+
#Generate animation
|
343 |
+
print(f"gif2gif: Generating GIF to {gif_filename}..")
|
344 |
+
inter_images[0].save(gif_filename,
|
345 |
+
save_all = True, append_images = inter_images[1:], loop = 0,
|
346 |
+
optimize = False, duration = self.desired_duration, comment=gif_info)
|
347 |
+
if(self.desired_interp > 0):
|
348 |
+
print(f"gif2gif: Interpolating {gif_filename}..")
|
349 |
+
interp(gif_filename, self.desired_interp, self.desired_duration)
|
350 |
+
#Returns
|
351 |
+
return_images.append(gif_filename)
|
352 |
+
if not gif_clear_frames:
|
353 |
+
return_images.extend(inter_images)
|
354 |
+
|
355 |
+
return Processed(p, return_images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
|
extensions/put extensions here.txt
ADDED
File without changes
|