nolanaatama commited on
Commit
f2a3c57
1 Parent(s): 33b550a

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/ISSUE_TEMPLATE/bug_report.yml +83 -0
  2. .github/ISSUE_TEMPLATE/config.yml +5 -0
  3. .github/ISSUE_TEMPLATE/feature_request.yml +40 -0
  4. .github/PULL_REQUEST_TEMPLATE/pull_request_template.md +28 -0
  5. .github/workflows/on_pull_request.yaml +42 -0
  6. .github/workflows/run_tests.yaml +31 -0
  7. .gitignore +34 -0
  8. .pylintrc +3 -0
  9. CODEOWNERS +12 -0
  10. README.md +151 -3
  11. artists.csv +0 -0
  12. configs/alt-diffusion-inference.yaml +72 -0
  13. configs/v1-inference.yaml +70 -0
  14. environment-wsl2.yaml +11 -0
  15. extensions-builtin/LDSR/ldsr_model_arch.py +256 -0
  16. extensions-builtin/LDSR/preload.py +6 -0
  17. extensions-builtin/LDSR/scripts/ldsr_model.py +69 -0
  18. extensions-builtin/LDSR/sd_hijack_autoencoder.py +286 -0
  19. extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +1449 -0
  20. extensions-builtin/ScuNET/preload.py +6 -0
  21. extensions-builtin/ScuNET/scripts/scunet_model.py +87 -0
  22. extensions-builtin/ScuNET/scunet_model_arch.py +265 -0
  23. extensions-builtin/SwinIR/preload.py +6 -0
  24. extensions-builtin/SwinIR/scripts/swinir_model.py +172 -0
  25. extensions-builtin/SwinIR/swinir_model_arch.py +867 -0
  26. extensions-builtin/SwinIR/swinir_model_arch_v2.py +1017 -0
  27. extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +107 -0
  28. javascript/aspectRatioOverlay.js +108 -0
  29. javascript/contextMenus.js +177 -0
  30. javascript/dragdrop.js +89 -0
  31. javascript/edit-attention.js +75 -0
  32. javascript/extensions.js +35 -0
  33. javascript/generationParams.js +33 -0
  34. javascript/hints.js +136 -0
  35. javascript/imageMaskFix.js +45 -0
  36. javascript/imageParams.js +19 -0
  37. javascript/imageviewer.js +276 -0
  38. javascript/localization.js +167 -0
  39. javascript/notification.js +49 -0
  40. javascript/progressbar.js +149 -0
  41. javascript/textualInversion.js +8 -0
  42. javascript/ui.js +222 -0
  43. launch.py +295 -0
  44. localizations/Put localization files here.txt +0 -0
  45. modules/api/api.py +418 -0
  46. modules/api/models.py +251 -0
  47. modules/artists.py +25 -0
  48. modules/call_queue.py +98 -0
  49. modules/codeformer/codeformer_arch.py +278 -0
  50. modules/codeformer/vqgan_arch.py +437 -0
.github/ISSUE_TEMPLATE/bug_report.yml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Bug Report
2
+ description: You think somethings is broken in the UI
3
+ title: "[Bug]: "
4
+ labels: ["bug-report"]
5
+
6
+ body:
7
+ - type: checkboxes
8
+ attributes:
9
+ label: Is there an existing issue for this?
10
+ description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
11
+ options:
12
+ - label: I have searched the existing issues and checked the recent builds/commits
13
+ required: true
14
+ - type: markdown
15
+ attributes:
16
+ value: |
17
+ *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
18
+ - type: textarea
19
+ id: what-did
20
+ attributes:
21
+ label: What happened?
22
+ description: Tell us what happened in a very clear and simple way
23
+ validations:
24
+ required: true
25
+ - type: textarea
26
+ id: steps
27
+ attributes:
28
+ label: Steps to reproduce the problem
29
+ description: Please provide us with precise step by step information on how to reproduce the bug
30
+ value: |
31
+ 1. Go to ....
32
+ 2. Press ....
33
+ 3. ...
34
+ validations:
35
+ required: true
36
+ - type: textarea
37
+ id: what-should
38
+ attributes:
39
+ label: What should have happened?
40
+ description: tell what you think the normal behavior should be
41
+ validations:
42
+ required: true
43
+ - type: input
44
+ id: commit
45
+ attributes:
46
+ label: Commit where the problem happens
47
+ description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
48
+ validations:
49
+ required: true
50
+ - type: dropdown
51
+ id: platforms
52
+ attributes:
53
+ label: What platforms do you use to access UI ?
54
+ multiple: true
55
+ options:
56
+ - Windows
57
+ - Linux
58
+ - MacOS
59
+ - iOS
60
+ - Android
61
+ - Other/Cloud
62
+ - type: dropdown
63
+ id: browsers
64
+ attributes:
65
+ label: What browsers do you use to access the UI ?
66
+ multiple: true
67
+ options:
68
+ - Mozilla Firefox
69
+ - Google Chrome
70
+ - Brave
71
+ - Apple Safari
72
+ - Microsoft Edge
73
+ - type: textarea
74
+ id: cmdargs
75
+ attributes:
76
+ label: Command Line Arguments
77
+ description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
78
+ render: Shell
79
+ - type: textarea
80
+ id: misc
81
+ attributes:
82
+ label: Additional information, context and logs
83
+ description: Please provide us with any relevant additional info, context or log output.
.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ blank_issues_enabled: false
2
+ contact_links:
3
+ - name: WebUI Community Support
4
+ url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
5
+ about: Please ask and answer questions here.
.github/ISSUE_TEMPLATE/feature_request.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Feature request
2
+ description: Suggest an idea for this project
3
+ title: "[Feature Request]: "
4
+ labels: ["suggestion"]
5
+
6
+ body:
7
+ - type: checkboxes
8
+ attributes:
9
+ label: Is there an existing issue for this?
10
+ description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
11
+ options:
12
+ - label: I have searched the existing issues and checked the recent builds/commits
13
+ required: true
14
+ - type: markdown
15
+ attributes:
16
+ value: |
17
+ *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
18
+ - type: textarea
19
+ id: feature
20
+ attributes:
21
+ label: What would your feature do ?
22
+ description: Tell us about your feature in a very clear and simple way, and what problem it would solve
23
+ validations:
24
+ required: true
25
+ - type: textarea
26
+ id: workflow
27
+ attributes:
28
+ label: Proposed workflow
29
+ description: Please provide us with step by step information on how you'd like the feature to be accessed and used
30
+ value: |
31
+ 1. Go to ....
32
+ 2. Press ....
33
+ 3. ...
34
+ validations:
35
+ required: true
36
+ - type: textarea
37
+ id: misc
38
+ attributes:
39
+ label: Additional information
40
+ description: Add any other context or screenshots about the feature request here.
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
2
+
3
+ If you have a large change, pay special attention to this paragraph:
4
+
5
+ > Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
6
+
7
+ Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
8
+
9
+ **Describe what this pull request is trying to achieve.**
10
+
11
+ A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
12
+
13
+ **Additional notes and description of your changes**
14
+
15
+ More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
16
+
17
+ **Environment this was tested in**
18
+
19
+ List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
20
+ - OS: [e.g. Windows, Linux]
21
+ - Browser [e.g. chrome, safari]
22
+ - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
23
+
24
+ **Screenshots or videos of your changes**
25
+
26
+ If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
27
+
28
+ This is **required** for anything that touches the user interface.
.github/workflows/on_pull_request.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
2
+ name: Run Linting/Formatting on Pull Requests
3
+
4
+ on:
5
+ - push
6
+ - pull_request
7
+ # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
8
+ # if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
9
+ # pull_request:
10
+ # branches:
11
+ # - master
12
+ # branches-ignore:
13
+ # - development
14
+
15
+ jobs:
16
+ lint:
17
+ runs-on: ubuntu-latest
18
+ steps:
19
+ - name: Checkout Code
20
+ uses: actions/checkout@v3
21
+ - name: Set up Python 3.10
22
+ uses: actions/setup-python@v3
23
+ with:
24
+ python-version: 3.10.6
25
+ - uses: actions/cache@v2
26
+ with:
27
+ path: ~/.cache/pip
28
+ key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
29
+ restore-keys: |
30
+ ${{ runner.os }}-pip-
31
+ - name: Install PyLint
32
+ run: |
33
+ python -m pip install --upgrade pip
34
+ pip install pylint
35
+ # This lets PyLint check to see if it can resolve imports
36
+ - name: Install dependencies
37
+ run : |
38
+ export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
39
+ python launch.py
40
+ - name: Analysing the code with pylint
41
+ run: |
42
+ pylint $(git ls-files '*.py')
.github/workflows/run_tests.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Run basic features tests on CPU with empty SD model
2
+
3
+ on:
4
+ - push
5
+ - pull_request
6
+
7
+ jobs:
8
+ test:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - name: Checkout Code
12
+ uses: actions/checkout@v3
13
+ - name: Set up Python 3.10
14
+ uses: actions/setup-python@v4
15
+ with:
16
+ python-version: 3.10.6
17
+ - uses: actions/cache@v3
18
+ with:
19
+ path: ~/.cache/pip
20
+ key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
21
+ restore-keys: ${{ runner.os }}-pip-
22
+ - name: Run tests
23
+ run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
24
+ - name: Upload main app stdout-stderr
25
+ uses: actions/upload-artifact@v3
26
+ if: always()
27
+ with:
28
+ name: stdout-stderr
29
+ path: |
30
+ test/stdout.txt
31
+ test/stderr.txt
.gitignore ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.ckpt
3
+ *.safetensors
4
+ *.pth
5
+ /ESRGAN/*
6
+ /SwinIR/*
7
+ /repositories
8
+ /venv
9
+ /tmp
10
+ /model.ckpt
11
+ /models/**/*
12
+ /GFPGANv1.3.pth
13
+ /gfpgan/weights/*.pth
14
+ /ui-config.json
15
+ /outputs
16
+ /config.json
17
+ /log
18
+ /webui.settings.bat
19
+ /embeddings
20
+ /styles.csv
21
+ /params.txt
22
+ /styles.csv.bak
23
+ /webui-user.bat
24
+ /webui-user.sh
25
+ /interrogate
26
+ /user.css
27
+ /.idea
28
+ notification.mp3
29
+ /SwinIR
30
+ /textual_inversion
31
+ .vscode
32
+ /extensions
33
+ /test/stdout.txt
34
+ /test/stderr.txt
.pylintrc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
2
+ [MESSAGES CONTROL]
3
+ disable=C,R,W,E,I
CODEOWNERS ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * @AUTOMATIC1111
2
+
3
+ # if you were managing a localization and were removed from this file, this is because
4
+ # the intended way to do localizations now is via extensions. See:
5
+ # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
6
+ # Make a repo with your localization and since you are still listed as a collaborator
7
+ # you can add it to the wiki page yourself. This change is because some people complained
8
+ # the git commit log is cluttered with things unrelated to almost everyone and
9
+ # because I believe this is the best overall for the project to handle localizations almost
10
+ # entirely without my oversight.
11
+
12
+
README.md CHANGED
@@ -1,3 +1,151 @@
1
- ---
2
- license: creativeml-openrail-m
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stable Diffusion web UI
2
+ A browser interface based on Gradio library for Stable Diffusion.
3
+
4
+ ![](txt2img_Screenshot.png)
5
+
6
+ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
7
+
8
+ ## Features
9
+ [Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
10
+ - Original txt2img and img2img modes
11
+ - One click install and run script (but you still must install python and git)
12
+ - Outpainting
13
+ - Inpainting
14
+ - Color Sketch
15
+ - Prompt Matrix
16
+ - Stable Diffusion Upscale
17
+ - Attention, specify parts of text that the model should pay more attention to
18
+ - a man in a ((tuxedo)) - will pay more attention to tuxedo
19
+ - a man in a (tuxedo:1.21) - alternative syntax
20
+ - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
21
+ - Loopback, run img2img processing multiple times
22
+ - X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
23
+ - Textual Inversion
24
+ - have as many embeddings as you want and use any names you like for them
25
+ - use multiple embeddings with different numbers of vectors per token
26
+ - works with half precision floating point numbers
27
+ - train embeddings on 8GB (also reports of 6GB working)
28
+ - Extras tab with:
29
+ - GFPGAN, neural network that fixes faces
30
+ - CodeFormer, face restoration tool as an alternative to GFPGAN
31
+ - RealESRGAN, neural network upscaler
32
+ - ESRGAN, neural network upscaler with a lot of third party models
33
+ - SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
34
+ - LDSR, Latent diffusion super resolution upscaling
35
+ - Resizing aspect ratio options
36
+ - Sampling method selection
37
+ - Adjust sampler eta values (noise multiplier)
38
+ - More advanced noise setting options
39
+ - Interrupt processing at any time
40
+ - 4GB video card support (also reports of 2GB working)
41
+ - Correct seeds for batches
42
+ - Live prompt token length validation
43
+ - Generation parameters
44
+ - parameters you used to generate images are saved with that image
45
+ - in PNG chunks for PNG, in EXIF for JPEG
46
+ - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
47
+ - can be disabled in settings
48
+ - drag and drop an image/text-parameters to promptbox
49
+ - Read Generation Parameters Button, loads parameters in promptbox to UI
50
+ - Settings page
51
+ - Running arbitrary python code from UI (must run with --allow-code to enable)
52
+ - Mouseover hints for most UI elements
53
+ - Possible to change defaults/mix/max/step values for UI elements via text config
54
+ - Random artist button
55
+ - Tiling support, a checkbox to create images that can be tiled like textures
56
+ - Progress bar and live image generation preview
57
+ - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
58
+ - Styles, a way to save part of prompt and easily apply them via dropdown later
59
+ - Variations, a way to generate same image but with tiny differences
60
+ - Seed resizing, a way to generate same image but at slightly different resolution
61
+ - CLIP interrogator, a button that tries to guess prompt from an image
62
+ - Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
63
+ - Batch Processing, process a group of files using img2img
64
+ - Img2img Alternative, reverse Euler method of cross attention control
65
+ - Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
66
+ - Reloading checkpoints on the fly
67
+ - Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
68
+ - [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
69
+ - [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
70
+ - separate prompts using uppercase `AND`
71
+ - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
72
+ - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
73
+ - DeepDanbooru integration, creates danbooru style tags for anime prompts
74
+ - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
75
+ - via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
76
+ - Generate forever option
77
+ - Training tab
78
+ - hypernetworks and embeddings options
79
+ - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
80
+ - Clip skip
81
+ - Use Hypernetworks
82
+ - Use VAEs
83
+ - Estimated completion time in progress bar
84
+ - API
85
+ - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
86
+ - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
87
+ - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
88
+
89
+ ## Installation and Running
90
+ Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
91
+
92
+ Alternatively, use online services (like Google Colab):
93
+
94
+ - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
95
+
96
+ ### Automatic Installation on Windows
97
+ 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
98
+ 2. Install [git](https://git-scm.com/download/win).
99
+ 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
100
+ 4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
101
+ 5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
102
+ 6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
103
+
104
+ ### Automatic Installation on Linux
105
+ 1. Install the dependencies:
106
+ ```bash
107
+ # Debian-based:
108
+ sudo apt install wget git python3 python3-venv
109
+ # Red Hat-based:
110
+ sudo dnf install wget git python3
111
+ # Arch-based:
112
+ sudo pacman -S wget git python3
113
+ ```
114
+ 2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
115
+ ```bash
116
+ bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
117
+ ```
118
+
119
+ ### Installation on Apple Silicon
120
+
121
+ Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
122
+
123
+ ## Contributing
124
+ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
125
+
126
+ ## Documentation
127
+ The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
128
+
129
+ ## Credits
130
+ - Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
131
+ - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
132
+ - GFPGAN - https://github.com/TencentARC/GFPGAN.git
133
+ - CodeFormer - https://github.com/sczhou/CodeFormer
134
+ - ESRGAN - https://github.com/xinntao/ESRGAN
135
+ - SwinIR - https://github.com/JingyunLiang/SwinIR
136
+ - Swin2SR - https://github.com/mv-lab/swin2sr
137
+ - LDSR - https://github.com/Hafiidz/latent-diffusion
138
+ - MiDaS - https://github.com/isl-org/MiDaS
139
+ - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
140
+ - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
141
+ - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
142
+ - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
143
+ - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
144
+ - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
145
+ - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
146
+ - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
147
+ - xformers - https://github.com/facebookresearch/xformers
148
+ - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
149
+ - Security advice - RyotaK
150
+ - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
151
+ - (You)
artists.csv ADDED
The diff for this file is too large to render. See raw diff
 
configs/alt-diffusion-inference.yaml ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: modules.xlmr.BertSeriesModelWithTransformation
71
+ params:
72
+ name: "XLMR-Large"
configs/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
environment-wsl2.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: automatic
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.10
7
+ - pip=22.2.2
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.12.1
10
+ - torchvision=0.13.1
11
+ - numpy=1.23.1
extensions-builtin/LDSR/ldsr_model_arch.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import time
4
+ import warnings
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ from PIL import Image
10
+ from einops import rearrange, repeat
11
+ from omegaconf import OmegaConf
12
+ import safetensors.torch
13
+
14
+ from ldm.models.diffusion.ddim import DDIMSampler
15
+ from ldm.util import instantiate_from_config, ismap
16
+ from modules import shared, sd_hijack
17
+
18
+ warnings.filterwarnings("ignore", category=UserWarning)
19
+
20
+ cached_ldsr_model: torch.nn.Module = None
21
+
22
+
23
+ # Create LDSR Class
24
+ class LDSR:
25
+ def load_model_from_config(self, half_attention):
26
+ global cached_ldsr_model
27
+
28
+ if shared.opts.ldsr_cached and cached_ldsr_model is not None:
29
+ print("Loading model from cache")
30
+ model: torch.nn.Module = cached_ldsr_model
31
+ else:
32
+ print(f"Loading model from {self.modelPath}")
33
+ _, extension = os.path.splitext(self.modelPath)
34
+ if extension.lower() == ".safetensors":
35
+ pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
36
+ else:
37
+ pl_sd = torch.load(self.modelPath, map_location="cpu")
38
+ sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
39
+ config = OmegaConf.load(self.yamlPath)
40
+ config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
41
+ model: torch.nn.Module = instantiate_from_config(config.model)
42
+ model.load_state_dict(sd, strict=False)
43
+ model = model.to(shared.device)
44
+ if half_attention:
45
+ model = model.half()
46
+ if shared.cmd_opts.opt_channelslast:
47
+ model = model.to(memory_format=torch.channels_last)
48
+
49
+ sd_hijack.model_hijack.hijack(model) # apply optimization
50
+ model.eval()
51
+
52
+ if shared.opts.ldsr_cached:
53
+ cached_ldsr_model = model
54
+
55
+ return {"model": model}
56
+
57
+ def __init__(self, model_path, yaml_path):
58
+ self.modelPath = model_path
59
+ self.yamlPath = yaml_path
60
+
61
+ @staticmethod
62
+ def run(model, selected_path, custom_steps, eta):
63
+ example = get_cond(selected_path)
64
+
65
+ n_runs = 1
66
+ guider = None
67
+ ckwargs = None
68
+ ddim_use_x0_pred = False
69
+ temperature = 1.
70
+ eta = eta
71
+ custom_shape = None
72
+
73
+ height, width = example["image"].shape[1:3]
74
+ split_input = height >= 128 and width >= 128
75
+
76
+ if split_input:
77
+ ks = 128
78
+ stride = 64
79
+ vqf = 4 #
80
+ model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
81
+ "vqf": vqf,
82
+ "patch_distributed_vq": True,
83
+ "tie_braker": False,
84
+ "clip_max_weight": 0.5,
85
+ "clip_min_weight": 0.01,
86
+ "clip_max_tie_weight": 0.5,
87
+ "clip_min_tie_weight": 0.01}
88
+ else:
89
+ if hasattr(model, "split_input_params"):
90
+ delattr(model, "split_input_params")
91
+
92
+ x_t = None
93
+ logs = None
94
+ for n in range(n_runs):
95
+ if custom_shape is not None:
96
+ x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
97
+ x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
98
+
99
+ logs = make_convolutional_sample(example, model,
100
+ custom_steps=custom_steps,
101
+ eta=eta, quantize_x0=False,
102
+ custom_shape=custom_shape,
103
+ temperature=temperature, noise_dropout=0.,
104
+ corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
105
+ ddim_use_x0_pred=ddim_use_x0_pred
106
+ )
107
+ return logs
108
+
109
+ def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
110
+ model = self.load_model_from_config(half_attention)
111
+
112
+ # Run settings
113
+ diffusion_steps = int(steps)
114
+ eta = 1.0
115
+
116
+ down_sample_method = 'Lanczos'
117
+
118
+ gc.collect()
119
+ if torch.cuda.is_available:
120
+ torch.cuda.empty_cache()
121
+
122
+ im_og = image
123
+ width_og, height_og = im_og.size
124
+ # If we can adjust the max upscale size, then the 4 below should be our variable
125
+ down_sample_rate = target_scale / 4
126
+ wd = width_og * down_sample_rate
127
+ hd = height_og * down_sample_rate
128
+ width_downsampled_pre = int(np.ceil(wd))
129
+ height_downsampled_pre = int(np.ceil(hd))
130
+
131
+ if down_sample_rate != 1:
132
+ print(
133
+ f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
134
+ im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
135
+ else:
136
+ print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
137
+
138
+ # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
139
+ pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
140
+ im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
141
+
142
+ logs = self.run(model["model"], im_padded, diffusion_steps, eta)
143
+
144
+ sample = logs["sample"]
145
+ sample = sample.detach().cpu()
146
+ sample = torch.clamp(sample, -1., 1.)
147
+ sample = (sample + 1.) / 2. * 255
148
+ sample = sample.numpy().astype(np.uint8)
149
+ sample = np.transpose(sample, (0, 2, 3, 1))
150
+ a = Image.fromarray(sample[0])
151
+
152
+ # remove padding
153
+ a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
154
+
155
+ del model
156
+ gc.collect()
157
+ if torch.cuda.is_available:
158
+ torch.cuda.empty_cache()
159
+
160
+ return a
161
+
162
+
163
+ def get_cond(selected_path):
164
+ example = dict()
165
+ up_f = 4
166
+ c = selected_path.convert('RGB')
167
+ c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
168
+ c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
169
+ antialias=True)
170
+ c_up = rearrange(c_up, '1 c h w -> 1 h w c')
171
+ c = rearrange(c, '1 c h w -> 1 h w c')
172
+ c = 2. * c - 1.
173
+
174
+ c = c.to(shared.device)
175
+ example["LR_image"] = c
176
+ example["image"] = c_up
177
+
178
+ return example
179
+
180
+
181
+ @torch.no_grad()
182
+ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
183
+ mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
184
+ corrector_kwargs=None, x_t=None
185
+ ):
186
+ ddim = DDIMSampler(model)
187
+ bs = shape[0]
188
+ shape = shape[1:]
189
+ print(f"Sampling with eta = {eta}; steps: {steps}")
190
+ samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
191
+ normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
192
+ mask=mask, x0=x0, temperature=temperature, verbose=False,
193
+ score_corrector=score_corrector,
194
+ corrector_kwargs=corrector_kwargs, x_t=x_t)
195
+
196
+ return samples, intermediates
197
+
198
+
199
+ @torch.no_grad()
200
+ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
201
+ corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
202
+ log = dict()
203
+
204
+ z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
205
+ return_first_stage_outputs=True,
206
+ force_c_encode=not (hasattr(model, 'split_input_params')
207
+ and model.cond_stage_key == 'coordinates_bbox'),
208
+ return_original_cond=True)
209
+
210
+ if custom_shape is not None:
211
+ z = torch.randn(custom_shape)
212
+ print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
213
+
214
+ z0 = None
215
+
216
+ log["input"] = x
217
+ log["reconstruction"] = xrec
218
+
219
+ if ismap(xc):
220
+ log["original_conditioning"] = model.to_rgb(xc)
221
+ if hasattr(model, 'cond_stage_key'):
222
+ log[model.cond_stage_key] = model.to_rgb(xc)
223
+
224
+ else:
225
+ log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
226
+ if model.cond_stage_model:
227
+ log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
228
+ if model.cond_stage_key == 'class_label':
229
+ log[model.cond_stage_key] = xc[model.cond_stage_key]
230
+
231
+ with model.ema_scope("Plotting"):
232
+ t0 = time.time()
233
+
234
+ sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
235
+ eta=eta,
236
+ quantize_x0=quantize_x0, mask=None, x0=z0,
237
+ temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
238
+ x_t=x_T)
239
+ t1 = time.time()
240
+
241
+ if ddim_use_x0_pred:
242
+ sample = intermediates['pred_x0'][-1]
243
+
244
+ x_sample = model.decode_first_stage(sample)
245
+
246
+ try:
247
+ x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
248
+ log["sample_noquant"] = x_sample_noquant
249
+ log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
250
+ except:
251
+ pass
252
+
253
+ log["sample"] = x_sample
254
+ log["time"] = t1 - t0
255
+
256
+ return log
extensions-builtin/LDSR/preload.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+ from modules import paths
3
+
4
+
5
+ def preload(parser):
6
+ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
extensions-builtin/LDSR/scripts/ldsr_model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ from basicsr.utils.download_util import load_file_from_url
6
+
7
+ from modules.upscaler import Upscaler, UpscalerData
8
+ from ldsr_model_arch import LDSR
9
+ from modules import shared, script_callbacks
10
+ import sd_hijack_autoencoder, sd_hijack_ddpm_v1
11
+
12
+
13
+ class UpscalerLDSR(Upscaler):
14
+ def __init__(self, user_path):
15
+ self.name = "LDSR"
16
+ self.user_path = user_path
17
+ self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
18
+ self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
19
+ super().__init__()
20
+ scaler_data = UpscalerData("LDSR", None, self)
21
+ self.scalers = [scaler_data]
22
+
23
+ def load_model(self, path: str):
24
+ # Remove incorrect project.yaml file if too big
25
+ yaml_path = os.path.join(self.model_path, "project.yaml")
26
+ old_model_path = os.path.join(self.model_path, "model.pth")
27
+ new_model_path = os.path.join(self.model_path, "model.ckpt")
28
+ safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
29
+ if os.path.exists(yaml_path):
30
+ statinfo = os.stat(yaml_path)
31
+ if statinfo.st_size >= 10485760:
32
+ print("Removing invalid LDSR YAML file.")
33
+ os.remove(yaml_path)
34
+ if os.path.exists(old_model_path):
35
+ print("Renaming model from model.pth to model.ckpt")
36
+ os.rename(old_model_path, new_model_path)
37
+ if os.path.exists(safetensors_model_path):
38
+ model = safetensors_model_path
39
+ else:
40
+ model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
41
+ file_name="model.ckpt", progress=True)
42
+ yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
43
+ file_name="project.yaml", progress=True)
44
+
45
+ try:
46
+ return LDSR(model, yaml)
47
+
48
+ except Exception:
49
+ print("Error importing LDSR:", file=sys.stderr)
50
+ print(traceback.format_exc(), file=sys.stderr)
51
+ return None
52
+
53
+ def do_upscale(self, img, path):
54
+ ldsr = self.load_model(path)
55
+ if ldsr is None:
56
+ print("NO LDSR!")
57
+ return img
58
+ ddim_steps = shared.opts.ldsr_steps
59
+ return ldsr.super_resolution(img, ddim_steps, self.scale)
60
+
61
+
62
+ def on_ui_settings():
63
+ import gradio as gr
64
+
65
+ shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
66
+ shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
67
+
68
+
69
+ script_callbacks.on_ui_settings(on_ui_settings)
extensions-builtin/LDSR/sd_hijack_autoencoder.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
2
+ # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
3
+ # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
4
+
5
+ import torch
6
+ import pytorch_lightning as pl
7
+ import torch.nn.functional as F
8
+ from contextlib import contextmanager
9
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
10
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
11
+ from ldm.util import instantiate_from_config
12
+
13
+ import ldm.models.autoencoder
14
+
15
+ class VQModel(pl.LightningModule):
16
+ def __init__(self,
17
+ ddconfig,
18
+ lossconfig,
19
+ n_embed,
20
+ embed_dim,
21
+ ckpt_path=None,
22
+ ignore_keys=[],
23
+ image_key="image",
24
+ colorize_nlabels=None,
25
+ monitor=None,
26
+ batch_resize_range=None,
27
+ scheduler_config=None,
28
+ lr_g_factor=1.0,
29
+ remap=None,
30
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
31
+ use_ema=False
32
+ ):
33
+ super().__init__()
34
+ self.embed_dim = embed_dim
35
+ self.n_embed = n_embed
36
+ self.image_key = image_key
37
+ self.encoder = Encoder(**ddconfig)
38
+ self.decoder = Decoder(**ddconfig)
39
+ self.loss = instantiate_from_config(lossconfig)
40
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
41
+ remap=remap,
42
+ sane_index_shape=sane_index_shape)
43
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
44
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
45
+ if colorize_nlabels is not None:
46
+ assert type(colorize_nlabels)==int
47
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
48
+ if monitor is not None:
49
+ self.monitor = monitor
50
+ self.batch_resize_range = batch_resize_range
51
+ if self.batch_resize_range is not None:
52
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
53
+
54
+ self.use_ema = use_ema
55
+ if self.use_ema:
56
+ self.model_ema = LitEma(self)
57
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
58
+
59
+ if ckpt_path is not None:
60
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
61
+ self.scheduler_config = scheduler_config
62
+ self.lr_g_factor = lr_g_factor
63
+
64
+ @contextmanager
65
+ def ema_scope(self, context=None):
66
+ if self.use_ema:
67
+ self.model_ema.store(self.parameters())
68
+ self.model_ema.copy_to(self)
69
+ if context is not None:
70
+ print(f"{context}: Switched to EMA weights")
71
+ try:
72
+ yield None
73
+ finally:
74
+ if self.use_ema:
75
+ self.model_ema.restore(self.parameters())
76
+ if context is not None:
77
+ print(f"{context}: Restored training weights")
78
+
79
+ def init_from_ckpt(self, path, ignore_keys=list()):
80
+ sd = torch.load(path, map_location="cpu")["state_dict"]
81
+ keys = list(sd.keys())
82
+ for k in keys:
83
+ for ik in ignore_keys:
84
+ if k.startswith(ik):
85
+ print("Deleting key {} from state_dict.".format(k))
86
+ del sd[k]
87
+ missing, unexpected = self.load_state_dict(sd, strict=False)
88
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
89
+ if len(missing) > 0:
90
+ print(f"Missing Keys: {missing}")
91
+ print(f"Unexpected Keys: {unexpected}")
92
+
93
+ def on_train_batch_end(self, *args, **kwargs):
94
+ if self.use_ema:
95
+ self.model_ema(self)
96
+
97
+ def encode(self, x):
98
+ h = self.encoder(x)
99
+ h = self.quant_conv(h)
100
+ quant, emb_loss, info = self.quantize(h)
101
+ return quant, emb_loss, info
102
+
103
+ def encode_to_prequant(self, x):
104
+ h = self.encoder(x)
105
+ h = self.quant_conv(h)
106
+ return h
107
+
108
+ def decode(self, quant):
109
+ quant = self.post_quant_conv(quant)
110
+ dec = self.decoder(quant)
111
+ return dec
112
+
113
+ def decode_code(self, code_b):
114
+ quant_b = self.quantize.embed_code(code_b)
115
+ dec = self.decode(quant_b)
116
+ return dec
117
+
118
+ def forward(self, input, return_pred_indices=False):
119
+ quant, diff, (_,_,ind) = self.encode(input)
120
+ dec = self.decode(quant)
121
+ if return_pred_indices:
122
+ return dec, diff, ind
123
+ return dec, diff
124
+
125
+ def get_input(self, batch, k):
126
+ x = batch[k]
127
+ if len(x.shape) == 3:
128
+ x = x[..., None]
129
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
130
+ if self.batch_resize_range is not None:
131
+ lower_size = self.batch_resize_range[0]
132
+ upper_size = self.batch_resize_range[1]
133
+ if self.global_step <= 4:
134
+ # do the first few batches with max size to avoid later oom
135
+ new_resize = upper_size
136
+ else:
137
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
138
+ if new_resize != x.shape[2]:
139
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
140
+ x = x.detach()
141
+ return x
142
+
143
+ def training_step(self, batch, batch_idx, optimizer_idx):
144
+ # https://github.com/pytorch/pytorch/issues/37142
145
+ # try not to fool the heuristics
146
+ x = self.get_input(batch, self.image_key)
147
+ xrec, qloss, ind = self(x, return_pred_indices=True)
148
+
149
+ if optimizer_idx == 0:
150
+ # autoencode
151
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
152
+ last_layer=self.get_last_layer(), split="train",
153
+ predicted_indices=ind)
154
+
155
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
156
+ return aeloss
157
+
158
+ if optimizer_idx == 1:
159
+ # discriminator
160
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
161
+ last_layer=self.get_last_layer(), split="train")
162
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
163
+ return discloss
164
+
165
+ def validation_step(self, batch, batch_idx):
166
+ log_dict = self._validation_step(batch, batch_idx)
167
+ with self.ema_scope():
168
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
169
+ return log_dict
170
+
171
+ def _validation_step(self, batch, batch_idx, suffix=""):
172
+ x = self.get_input(batch, self.image_key)
173
+ xrec, qloss, ind = self(x, return_pred_indices=True)
174
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
175
+ self.global_step,
176
+ last_layer=self.get_last_layer(),
177
+ split="val"+suffix,
178
+ predicted_indices=ind
179
+ )
180
+
181
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
182
+ self.global_step,
183
+ last_layer=self.get_last_layer(),
184
+ split="val"+suffix,
185
+ predicted_indices=ind
186
+ )
187
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
188
+ self.log(f"val{suffix}/rec_loss", rec_loss,
189
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
190
+ self.log(f"val{suffix}/aeloss", aeloss,
191
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
192
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
193
+ del log_dict_ae[f"val{suffix}/rec_loss"]
194
+ self.log_dict(log_dict_ae)
195
+ self.log_dict(log_dict_disc)
196
+ return self.log_dict
197
+
198
+ def configure_optimizers(self):
199
+ lr_d = self.learning_rate
200
+ lr_g = self.lr_g_factor*self.learning_rate
201
+ print("lr_d", lr_d)
202
+ print("lr_g", lr_g)
203
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
204
+ list(self.decoder.parameters())+
205
+ list(self.quantize.parameters())+
206
+ list(self.quant_conv.parameters())+
207
+ list(self.post_quant_conv.parameters()),
208
+ lr=lr_g, betas=(0.5, 0.9))
209
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
210
+ lr=lr_d, betas=(0.5, 0.9))
211
+
212
+ if self.scheduler_config is not None:
213
+ scheduler = instantiate_from_config(self.scheduler_config)
214
+
215
+ print("Setting up LambdaLR scheduler...")
216
+ scheduler = [
217
+ {
218
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
219
+ 'interval': 'step',
220
+ 'frequency': 1
221
+ },
222
+ {
223
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
224
+ 'interval': 'step',
225
+ 'frequency': 1
226
+ },
227
+ ]
228
+ return [opt_ae, opt_disc], scheduler
229
+ return [opt_ae, opt_disc], []
230
+
231
+ def get_last_layer(self):
232
+ return self.decoder.conv_out.weight
233
+
234
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
235
+ log = dict()
236
+ x = self.get_input(batch, self.image_key)
237
+ x = x.to(self.device)
238
+ if only_inputs:
239
+ log["inputs"] = x
240
+ return log
241
+ xrec, _ = self(x)
242
+ if x.shape[1] > 3:
243
+ # colorize with random projection
244
+ assert xrec.shape[1] > 3
245
+ x = self.to_rgb(x)
246
+ xrec = self.to_rgb(xrec)
247
+ log["inputs"] = x
248
+ log["reconstructions"] = xrec
249
+ if plot_ema:
250
+ with self.ema_scope():
251
+ xrec_ema, _ = self(x)
252
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
253
+ log["reconstructions_ema"] = xrec_ema
254
+ return log
255
+
256
+ def to_rgb(self, x):
257
+ assert self.image_key == "segmentation"
258
+ if not hasattr(self, "colorize"):
259
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
260
+ x = F.conv2d(x, weight=self.colorize)
261
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
262
+ return x
263
+
264
+
265
+ class VQModelInterface(VQModel):
266
+ def __init__(self, embed_dim, *args, **kwargs):
267
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
268
+ self.embed_dim = embed_dim
269
+
270
+ def encode(self, x):
271
+ h = self.encoder(x)
272
+ h = self.quant_conv(h)
273
+ return h
274
+
275
+ def decode(self, h, force_not_quantize=False):
276
+ # also go through quantization layer
277
+ if not force_not_quantize:
278
+ quant, emb_loss, info = self.quantize(h)
279
+ else:
280
+ quant = h
281
+ quant = self.post_quant_conv(quant)
282
+ dec = self.decoder(quant)
283
+ return dec
284
+
285
+ setattr(ldm.models.autoencoder, "VQModel", VQModel)
286
+ setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py ADDED
@@ -0,0 +1,1449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo)
2
+ # Original filename: ldm/models/diffusion/ddpm.py
3
+ # The purpose to reinstate the old DDPM logic which works with VQ, whereas the V2 one doesn't
4
+ # Some models such as LDSR require VQ to work correctly
5
+ # The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ import pytorch_lightning as pl
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from einops import rearrange, repeat
13
+ from contextlib import contextmanager
14
+ from functools import partial
15
+ from tqdm import tqdm
16
+ from torchvision.utils import make_grid
17
+ from pytorch_lightning.utilities.distributed import rank_zero_only
18
+
19
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
20
+ from ldm.modules.ema import LitEma
21
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
22
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
23
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
24
+ from ldm.models.diffusion.ddim import DDIMSampler
25
+
26
+ import ldm.models.diffusion.ddpm
27
+
28
+ __conditioning_keys__ = {'concat': 'c_concat',
29
+ 'crossattn': 'c_crossattn',
30
+ 'adm': 'y'}
31
+
32
+
33
+ def disabled_train(self, mode=True):
34
+ """Overwrite model.train with this function to make sure train/eval mode
35
+ does not change anymore."""
36
+ return self
37
+
38
+
39
+ def uniform_on_device(r1, r2, shape, device):
40
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
41
+
42
+
43
+ class DDPMV1(pl.LightningModule):
44
+ # classic DDPM with Gaussian diffusion, in image space
45
+ def __init__(self,
46
+ unet_config,
47
+ timesteps=1000,
48
+ beta_schedule="linear",
49
+ loss_type="l2",
50
+ ckpt_path=None,
51
+ ignore_keys=[],
52
+ load_only_unet=False,
53
+ monitor="val/loss",
54
+ use_ema=True,
55
+ first_stage_key="image",
56
+ image_size=256,
57
+ channels=3,
58
+ log_every_t=100,
59
+ clip_denoised=True,
60
+ linear_start=1e-4,
61
+ linear_end=2e-2,
62
+ cosine_s=8e-3,
63
+ given_betas=None,
64
+ original_elbo_weight=0.,
65
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
66
+ l_simple_weight=1.,
67
+ conditioning_key=None,
68
+ parameterization="eps", # all assuming fixed variance schedules
69
+ scheduler_config=None,
70
+ use_positional_encodings=False,
71
+ learn_logvar=False,
72
+ logvar_init=0.,
73
+ ):
74
+ super().__init__()
75
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
76
+ self.parameterization = parameterization
77
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
78
+ self.cond_stage_model = None
79
+ self.clip_denoised = clip_denoised
80
+ self.log_every_t = log_every_t
81
+ self.first_stage_key = first_stage_key
82
+ self.image_size = image_size # try conv?
83
+ self.channels = channels
84
+ self.use_positional_encodings = use_positional_encodings
85
+ self.model = DiffusionWrapperV1(unet_config, conditioning_key)
86
+ count_params(self.model, verbose=True)
87
+ self.use_ema = use_ema
88
+ if self.use_ema:
89
+ self.model_ema = LitEma(self.model)
90
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
91
+
92
+ self.use_scheduler = scheduler_config is not None
93
+ if self.use_scheduler:
94
+ self.scheduler_config = scheduler_config
95
+
96
+ self.v_posterior = v_posterior
97
+ self.original_elbo_weight = original_elbo_weight
98
+ self.l_simple_weight = l_simple_weight
99
+
100
+ if monitor is not None:
101
+ self.monitor = monitor
102
+ if ckpt_path is not None:
103
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
104
+
105
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
106
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
107
+
108
+ self.loss_type = loss_type
109
+
110
+ self.learn_logvar = learn_logvar
111
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
112
+ if self.learn_logvar:
113
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
114
+
115
+
116
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
117
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
118
+ if exists(given_betas):
119
+ betas = given_betas
120
+ else:
121
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
122
+ cosine_s=cosine_s)
123
+ alphas = 1. - betas
124
+ alphas_cumprod = np.cumprod(alphas, axis=0)
125
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
126
+
127
+ timesteps, = betas.shape
128
+ self.num_timesteps = int(timesteps)
129
+ self.linear_start = linear_start
130
+ self.linear_end = linear_end
131
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
132
+
133
+ to_torch = partial(torch.tensor, dtype=torch.float32)
134
+
135
+ self.register_buffer('betas', to_torch(betas))
136
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
137
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
138
+
139
+ # calculations for diffusion q(x_t | x_{t-1}) and others
140
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
141
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
142
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
143
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
144
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
145
+
146
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
147
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
148
+ 1. - alphas_cumprod) + self.v_posterior * betas
149
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
150
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
151
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
152
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
153
+ self.register_buffer('posterior_mean_coef1', to_torch(
154
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
155
+ self.register_buffer('posterior_mean_coef2', to_torch(
156
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
157
+
158
+ if self.parameterization == "eps":
159
+ lvlb_weights = self.betas ** 2 / (
160
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
161
+ elif self.parameterization == "x0":
162
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
163
+ else:
164
+ raise NotImplementedError("mu not supported")
165
+ # TODO how to choose this term
166
+ lvlb_weights[0] = lvlb_weights[1]
167
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
168
+ assert not torch.isnan(self.lvlb_weights).all()
169
+
170
+ @contextmanager
171
+ def ema_scope(self, context=None):
172
+ if self.use_ema:
173
+ self.model_ema.store(self.model.parameters())
174
+ self.model_ema.copy_to(self.model)
175
+ if context is not None:
176
+ print(f"{context}: Switched to EMA weights")
177
+ try:
178
+ yield None
179
+ finally:
180
+ if self.use_ema:
181
+ self.model_ema.restore(self.model.parameters())
182
+ if context is not None:
183
+ print(f"{context}: Restored training weights")
184
+
185
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
186
+ sd = torch.load(path, map_location="cpu")
187
+ if "state_dict" in list(sd.keys()):
188
+ sd = sd["state_dict"]
189
+ keys = list(sd.keys())
190
+ for k in keys:
191
+ for ik in ignore_keys:
192
+ if k.startswith(ik):
193
+ print("Deleting key {} from state_dict.".format(k))
194
+ del sd[k]
195
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
196
+ sd, strict=False)
197
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
198
+ if len(missing) > 0:
199
+ print(f"Missing Keys: {missing}")
200
+ if len(unexpected) > 0:
201
+ print(f"Unexpected Keys: {unexpected}")
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
211
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def predict_start_from_noise(self, x_t, t, noise):
216
+ return (
217
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
218
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
219
+ )
220
+
221
+ def q_posterior(self, x_start, x_t, t):
222
+ posterior_mean = (
223
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
224
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
225
+ )
226
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
227
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
228
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
229
+
230
+ def p_mean_variance(self, x, t, clip_denoised: bool):
231
+ model_out = self.model(x, t)
232
+ if self.parameterization == "eps":
233
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
234
+ elif self.parameterization == "x0":
235
+ x_recon = model_out
236
+ if clip_denoised:
237
+ x_recon.clamp_(-1., 1.)
238
+
239
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
240
+ return model_mean, posterior_variance, posterior_log_variance
241
+
242
+ @torch.no_grad()
243
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
244
+ b, *_, device = *x.shape, x.device
245
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
246
+ noise = noise_like(x.shape, device, repeat_noise)
247
+ # no noise when t == 0
248
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
249
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
250
+
251
+ @torch.no_grad()
252
+ def p_sample_loop(self, shape, return_intermediates=False):
253
+ device = self.betas.device
254
+ b = shape[0]
255
+ img = torch.randn(shape, device=device)
256
+ intermediates = [img]
257
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
258
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
259
+ clip_denoised=self.clip_denoised)
260
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
261
+ intermediates.append(img)
262
+ if return_intermediates:
263
+ return img, intermediates
264
+ return img
265
+
266
+ @torch.no_grad()
267
+ def sample(self, batch_size=16, return_intermediates=False):
268
+ image_size = self.image_size
269
+ channels = self.channels
270
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
271
+ return_intermediates=return_intermediates)
272
+
273
+ def q_sample(self, x_start, t, noise=None):
274
+ noise = default(noise, lambda: torch.randn_like(x_start))
275
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
276
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
277
+
278
+ def get_loss(self, pred, target, mean=True):
279
+ if self.loss_type == 'l1':
280
+ loss = (target - pred).abs()
281
+ if mean:
282
+ loss = loss.mean()
283
+ elif self.loss_type == 'l2':
284
+ if mean:
285
+ loss = torch.nn.functional.mse_loss(target, pred)
286
+ else:
287
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
288
+ else:
289
+ raise NotImplementedError("unknown loss type '{loss_type}'")
290
+
291
+ return loss
292
+
293
+ def p_losses(self, x_start, t, noise=None):
294
+ noise = default(noise, lambda: torch.randn_like(x_start))
295
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
296
+ model_out = self.model(x_noisy, t)
297
+
298
+ loss_dict = {}
299
+ if self.parameterization == "eps":
300
+ target = noise
301
+ elif self.parameterization == "x0":
302
+ target = x_start
303
+ else:
304
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
305
+
306
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
307
+
308
+ log_prefix = 'train' if self.training else 'val'
309
+
310
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
311
+ loss_simple = loss.mean() * self.l_simple_weight
312
+
313
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
314
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
315
+
316
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
317
+
318
+ loss_dict.update({f'{log_prefix}/loss': loss})
319
+
320
+ return loss, loss_dict
321
+
322
+ def forward(self, x, *args, **kwargs):
323
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
324
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
325
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
326
+ return self.p_losses(x, t, *args, **kwargs)
327
+
328
+ def get_input(self, batch, k):
329
+ x = batch[k]
330
+ if len(x.shape) == 3:
331
+ x = x[..., None]
332
+ x = rearrange(x, 'b h w c -> b c h w')
333
+ x = x.to(memory_format=torch.contiguous_format).float()
334
+ return x
335
+
336
+ def shared_step(self, batch):
337
+ x = self.get_input(batch, self.first_stage_key)
338
+ loss, loss_dict = self(x)
339
+ return loss, loss_dict
340
+
341
+ def training_step(self, batch, batch_idx):
342
+ loss, loss_dict = self.shared_step(batch)
343
+
344
+ self.log_dict(loss_dict, prog_bar=True,
345
+ logger=True, on_step=True, on_epoch=True)
346
+
347
+ self.log("global_step", self.global_step,
348
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
349
+
350
+ if self.use_scheduler:
351
+ lr = self.optimizers().param_groups[0]['lr']
352
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
353
+
354
+ return loss
355
+
356
+ @torch.no_grad()
357
+ def validation_step(self, batch, batch_idx):
358
+ _, loss_dict_no_ema = self.shared_step(batch)
359
+ with self.ema_scope():
360
+ _, loss_dict_ema = self.shared_step(batch)
361
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
362
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
363
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
364
+
365
+ def on_train_batch_end(self, *args, **kwargs):
366
+ if self.use_ema:
367
+ self.model_ema(self.model)
368
+
369
+ def _get_rows_from_list(self, samples):
370
+ n_imgs_per_row = len(samples)
371
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
372
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
373
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
374
+ return denoise_grid
375
+
376
+ @torch.no_grad()
377
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
378
+ log = dict()
379
+ x = self.get_input(batch, self.first_stage_key)
380
+ N = min(x.shape[0], N)
381
+ n_row = min(x.shape[0], n_row)
382
+ x = x.to(self.device)[:N]
383
+ log["inputs"] = x
384
+
385
+ # get diffusion row
386
+ diffusion_row = list()
387
+ x_start = x[:n_row]
388
+
389
+ for t in range(self.num_timesteps):
390
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
391
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
392
+ t = t.to(self.device).long()
393
+ noise = torch.randn_like(x_start)
394
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
395
+ diffusion_row.append(x_noisy)
396
+
397
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
398
+
399
+ if sample:
400
+ # get denoise row
401
+ with self.ema_scope("Plotting"):
402
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
403
+
404
+ log["samples"] = samples
405
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
406
+
407
+ if return_keys:
408
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
409
+ return log
410
+ else:
411
+ return {key: log[key] for key in return_keys}
412
+ return log
413
+
414
+ def configure_optimizers(self):
415
+ lr = self.learning_rate
416
+ params = list(self.model.parameters())
417
+ if self.learn_logvar:
418
+ params = params + [self.logvar]
419
+ opt = torch.optim.AdamW(params, lr=lr)
420
+ return opt
421
+
422
+
423
+ class LatentDiffusionV1(DDPMV1):
424
+ """main class"""
425
+ def __init__(self,
426
+ first_stage_config,
427
+ cond_stage_config,
428
+ num_timesteps_cond=None,
429
+ cond_stage_key="image",
430
+ cond_stage_trainable=False,
431
+ concat_mode=True,
432
+ cond_stage_forward=None,
433
+ conditioning_key=None,
434
+ scale_factor=1.0,
435
+ scale_by_std=False,
436
+ *args, **kwargs):
437
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
438
+ self.scale_by_std = scale_by_std
439
+ assert self.num_timesteps_cond <= kwargs['timesteps']
440
+ # for backwards compatibility after implementation of DiffusionWrapper
441
+ if conditioning_key is None:
442
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
443
+ if cond_stage_config == '__is_unconditional__':
444
+ conditioning_key = None
445
+ ckpt_path = kwargs.pop("ckpt_path", None)
446
+ ignore_keys = kwargs.pop("ignore_keys", [])
447
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
448
+ self.concat_mode = concat_mode
449
+ self.cond_stage_trainable = cond_stage_trainable
450
+ self.cond_stage_key = cond_stage_key
451
+ try:
452
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
453
+ except:
454
+ self.num_downs = 0
455
+ if not scale_by_std:
456
+ self.scale_factor = scale_factor
457
+ else:
458
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
459
+ self.instantiate_first_stage(first_stage_config)
460
+ self.instantiate_cond_stage(cond_stage_config)
461
+ self.cond_stage_forward = cond_stage_forward
462
+ self.clip_denoised = False
463
+ self.bbox_tokenizer = None
464
+
465
+ self.restarted_from_ckpt = False
466
+ if ckpt_path is not None:
467
+ self.init_from_ckpt(ckpt_path, ignore_keys)
468
+ self.restarted_from_ckpt = True
469
+
470
+ def make_cond_schedule(self, ):
471
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
472
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
473
+ self.cond_ids[:self.num_timesteps_cond] = ids
474
+
475
+ @rank_zero_only
476
+ @torch.no_grad()
477
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
478
+ # only for very first batch
479
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
480
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
481
+ # set rescale weight to 1./std of encodings
482
+ print("### USING STD-RESCALING ###")
483
+ x = super().get_input(batch, self.first_stage_key)
484
+ x = x.to(self.device)
485
+ encoder_posterior = self.encode_first_stage(x)
486
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
487
+ del self.scale_factor
488
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
489
+ print(f"setting self.scale_factor to {self.scale_factor}")
490
+ print("### USING STD-RESCALING ###")
491
+
492
+ def register_schedule(self,
493
+ given_betas=None, beta_schedule="linear", timesteps=1000,
494
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
495
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
496
+
497
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
498
+ if self.shorten_cond_schedule:
499
+ self.make_cond_schedule()
500
+
501
+ def instantiate_first_stage(self, config):
502
+ model = instantiate_from_config(config)
503
+ self.first_stage_model = model.eval()
504
+ self.first_stage_model.train = disabled_train
505
+ for param in self.first_stage_model.parameters():
506
+ param.requires_grad = False
507
+
508
+ def instantiate_cond_stage(self, config):
509
+ if not self.cond_stage_trainable:
510
+ if config == "__is_first_stage__":
511
+ print("Using first stage also as cond stage.")
512
+ self.cond_stage_model = self.first_stage_model
513
+ elif config == "__is_unconditional__":
514
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
515
+ self.cond_stage_model = None
516
+ # self.be_unconditional = True
517
+ else:
518
+ model = instantiate_from_config(config)
519
+ self.cond_stage_model = model.eval()
520
+ self.cond_stage_model.train = disabled_train
521
+ for param in self.cond_stage_model.parameters():
522
+ param.requires_grad = False
523
+ else:
524
+ assert config != '__is_first_stage__'
525
+ assert config != '__is_unconditional__'
526
+ model = instantiate_from_config(config)
527
+ self.cond_stage_model = model
528
+
529
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
530
+ denoise_row = []
531
+ for zd in tqdm(samples, desc=desc):
532
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
533
+ force_not_quantize=force_no_decoder_quantization))
534
+ n_imgs_per_row = len(denoise_row)
535
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
536
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
537
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
538
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
539
+ return denoise_grid
540
+
541
+ def get_first_stage_encoding(self, encoder_posterior):
542
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
543
+ z = encoder_posterior.sample()
544
+ elif isinstance(encoder_posterior, torch.Tensor):
545
+ z = encoder_posterior
546
+ else:
547
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
548
+ return self.scale_factor * z
549
+
550
+ def get_learned_conditioning(self, c):
551
+ if self.cond_stage_forward is None:
552
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
553
+ c = self.cond_stage_model.encode(c)
554
+ if isinstance(c, DiagonalGaussianDistribution):
555
+ c = c.mode()
556
+ else:
557
+ c = self.cond_stage_model(c)
558
+ else:
559
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
560
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
561
+ return c
562
+
563
+ def meshgrid(self, h, w):
564
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
565
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
566
+
567
+ arr = torch.cat([y, x], dim=-1)
568
+ return arr
569
+
570
+ def delta_border(self, h, w):
571
+ """
572
+ :param h: height
573
+ :param w: width
574
+ :return: normalized distance to image border,
575
+ wtith min distance = 0 at border and max dist = 0.5 at image center
576
+ """
577
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
578
+ arr = self.meshgrid(h, w) / lower_right_corner
579
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
580
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
581
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
582
+ return edge_dist
583
+
584
+ def get_weighting(self, h, w, Ly, Lx, device):
585
+ weighting = self.delta_border(h, w)
586
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
587
+ self.split_input_params["clip_max_weight"], )
588
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
589
+
590
+ if self.split_input_params["tie_braker"]:
591
+ L_weighting = self.delta_border(Ly, Lx)
592
+ L_weighting = torch.clip(L_weighting,
593
+ self.split_input_params["clip_min_tie_weight"],
594
+ self.split_input_params["clip_max_tie_weight"])
595
+
596
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
597
+ weighting = weighting * L_weighting
598
+ return weighting
599
+
600
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
601
+ """
602
+ :param x: img of size (bs, c, h, w)
603
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
604
+ """
605
+ bs, nc, h, w = x.shape
606
+
607
+ # number of crops in image
608
+ Ly = (h - kernel_size[0]) // stride[0] + 1
609
+ Lx = (w - kernel_size[1]) // stride[1] + 1
610
+
611
+ if uf == 1 and df == 1:
612
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
613
+ unfold = torch.nn.Unfold(**fold_params)
614
+
615
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
616
+
617
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
618
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
619
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
620
+
621
+ elif uf > 1 and df == 1:
622
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
623
+ unfold = torch.nn.Unfold(**fold_params)
624
+
625
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
626
+ dilation=1, padding=0,
627
+ stride=(stride[0] * uf, stride[1] * uf))
628
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
629
+
630
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
631
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
632
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
633
+
634
+ elif df > 1 and uf == 1:
635
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
636
+ unfold = torch.nn.Unfold(**fold_params)
637
+
638
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
639
+ dilation=1, padding=0,
640
+ stride=(stride[0] // df, stride[1] // df))
641
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
642
+
643
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
644
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
645
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
646
+
647
+ else:
648
+ raise NotImplementedError
649
+
650
+ return fold, unfold, normalization, weighting
651
+
652
+ @torch.no_grad()
653
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
654
+ cond_key=None, return_original_cond=False, bs=None):
655
+ x = super().get_input(batch, k)
656
+ if bs is not None:
657
+ x = x[:bs]
658
+ x = x.to(self.device)
659
+ encoder_posterior = self.encode_first_stage(x)
660
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
661
+
662
+ if self.model.conditioning_key is not None:
663
+ if cond_key is None:
664
+ cond_key = self.cond_stage_key
665
+ if cond_key != self.first_stage_key:
666
+ if cond_key in ['caption', 'coordinates_bbox']:
667
+ xc = batch[cond_key]
668
+ elif cond_key == 'class_label':
669
+ xc = batch
670
+ else:
671
+ xc = super().get_input(batch, cond_key).to(self.device)
672
+ else:
673
+ xc = x
674
+ if not self.cond_stage_trainable or force_c_encode:
675
+ if isinstance(xc, dict) or isinstance(xc, list):
676
+ # import pudb; pudb.set_trace()
677
+ c = self.get_learned_conditioning(xc)
678
+ else:
679
+ c = self.get_learned_conditioning(xc.to(self.device))
680
+ else:
681
+ c = xc
682
+ if bs is not None:
683
+ c = c[:bs]
684
+
685
+ if self.use_positional_encodings:
686
+ pos_x, pos_y = self.compute_latent_shifts(batch)
687
+ ckey = __conditioning_keys__[self.model.conditioning_key]
688
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
689
+
690
+ else:
691
+ c = None
692
+ xc = None
693
+ if self.use_positional_encodings:
694
+ pos_x, pos_y = self.compute_latent_shifts(batch)
695
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
696
+ out = [z, c]
697
+ if return_first_stage_outputs:
698
+ xrec = self.decode_first_stage(z)
699
+ out.extend([x, xrec])
700
+ if return_original_cond:
701
+ out.append(xc)
702
+ return out
703
+
704
+ @torch.no_grad()
705
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
706
+ if predict_cids:
707
+ if z.dim() == 4:
708
+ z = torch.argmax(z.exp(), dim=1).long()
709
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
710
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
711
+
712
+ z = 1. / self.scale_factor * z
713
+
714
+ if hasattr(self, "split_input_params"):
715
+ if self.split_input_params["patch_distributed_vq"]:
716
+ ks = self.split_input_params["ks"] # eg. (128, 128)
717
+ stride = self.split_input_params["stride"] # eg. (64, 64)
718
+ uf = self.split_input_params["vqf"]
719
+ bs, nc, h, w = z.shape
720
+ if ks[0] > h or ks[1] > w:
721
+ ks = (min(ks[0], h), min(ks[1], w))
722
+ print("reducing Kernel")
723
+
724
+ if stride[0] > h or stride[1] > w:
725
+ stride = (min(stride[0], h), min(stride[1], w))
726
+ print("reducing stride")
727
+
728
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
729
+
730
+ z = unfold(z) # (bn, nc * prod(**ks), L)
731
+ # 1. Reshape to img shape
732
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
733
+
734
+ # 2. apply model loop over last dim
735
+ if isinstance(self.first_stage_model, VQModelInterface):
736
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
737
+ force_not_quantize=predict_cids or force_not_quantize)
738
+ for i in range(z.shape[-1])]
739
+ else:
740
+
741
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
742
+ for i in range(z.shape[-1])]
743
+
744
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
745
+ o = o * weighting
746
+ # Reverse 1. reshape to img shape
747
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
748
+ # stitch crops together
749
+ decoded = fold(o)
750
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
751
+ return decoded
752
+ else:
753
+ if isinstance(self.first_stage_model, VQModelInterface):
754
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
755
+ else:
756
+ return self.first_stage_model.decode(z)
757
+
758
+ else:
759
+ if isinstance(self.first_stage_model, VQModelInterface):
760
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
761
+ else:
762
+ return self.first_stage_model.decode(z)
763
+
764
+ # same as above but without decorator
765
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
766
+ if predict_cids:
767
+ if z.dim() == 4:
768
+ z = torch.argmax(z.exp(), dim=1).long()
769
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
770
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
771
+
772
+ z = 1. / self.scale_factor * z
773
+
774
+ if hasattr(self, "split_input_params"):
775
+ if self.split_input_params["patch_distributed_vq"]:
776
+ ks = self.split_input_params["ks"] # eg. (128, 128)
777
+ stride = self.split_input_params["stride"] # eg. (64, 64)
778
+ uf = self.split_input_params["vqf"]
779
+ bs, nc, h, w = z.shape
780
+ if ks[0] > h or ks[1] > w:
781
+ ks = (min(ks[0], h), min(ks[1], w))
782
+ print("reducing Kernel")
783
+
784
+ if stride[0] > h or stride[1] > w:
785
+ stride = (min(stride[0], h), min(stride[1], w))
786
+ print("reducing stride")
787
+
788
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
789
+
790
+ z = unfold(z) # (bn, nc * prod(**ks), L)
791
+ # 1. Reshape to img shape
792
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
793
+
794
+ # 2. apply model loop over last dim
795
+ if isinstance(self.first_stage_model, VQModelInterface):
796
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
797
+ force_not_quantize=predict_cids or force_not_quantize)
798
+ for i in range(z.shape[-1])]
799
+ else:
800
+
801
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
802
+ for i in range(z.shape[-1])]
803
+
804
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
805
+ o = o * weighting
806
+ # Reverse 1. reshape to img shape
807
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
808
+ # stitch crops together
809
+ decoded = fold(o)
810
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
811
+ return decoded
812
+ else:
813
+ if isinstance(self.first_stage_model, VQModelInterface):
814
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
815
+ else:
816
+ return self.first_stage_model.decode(z)
817
+
818
+ else:
819
+ if isinstance(self.first_stage_model, VQModelInterface):
820
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
821
+ else:
822
+ return self.first_stage_model.decode(z)
823
+
824
+ @torch.no_grad()
825
+ def encode_first_stage(self, x):
826
+ if hasattr(self, "split_input_params"):
827
+ if self.split_input_params["patch_distributed_vq"]:
828
+ ks = self.split_input_params["ks"] # eg. (128, 128)
829
+ stride = self.split_input_params["stride"] # eg. (64, 64)
830
+ df = self.split_input_params["vqf"]
831
+ self.split_input_params['original_image_size'] = x.shape[-2:]
832
+ bs, nc, h, w = x.shape
833
+ if ks[0] > h or ks[1] > w:
834
+ ks = (min(ks[0], h), min(ks[1], w))
835
+ print("reducing Kernel")
836
+
837
+ if stride[0] > h or stride[1] > w:
838
+ stride = (min(stride[0], h), min(stride[1], w))
839
+ print("reducing stride")
840
+
841
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
842
+ z = unfold(x) # (bn, nc * prod(**ks), L)
843
+ # Reshape to img shape
844
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
845
+
846
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
847
+ for i in range(z.shape[-1])]
848
+
849
+ o = torch.stack(output_list, axis=-1)
850
+ o = o * weighting
851
+
852
+ # Reverse reshape to img shape
853
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
854
+ # stitch crops together
855
+ decoded = fold(o)
856
+ decoded = decoded / normalization
857
+ return decoded
858
+
859
+ else:
860
+ return self.first_stage_model.encode(x)
861
+ else:
862
+ return self.first_stage_model.encode(x)
863
+
864
+ def shared_step(self, batch, **kwargs):
865
+ x, c = self.get_input(batch, self.first_stage_key)
866
+ loss = self(x, c)
867
+ return loss
868
+
869
+ def forward(self, x, c, *args, **kwargs):
870
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
871
+ if self.model.conditioning_key is not None:
872
+ assert c is not None
873
+ if self.cond_stage_trainable:
874
+ c = self.get_learned_conditioning(c)
875
+ if self.shorten_cond_schedule: # TODO: drop this option
876
+ tc = self.cond_ids[t].to(self.device)
877
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
878
+ return self.p_losses(x, c, t, *args, **kwargs)
879
+
880
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
881
+ def rescale_bbox(bbox):
882
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
883
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
884
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
885
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
886
+ return x0, y0, w, h
887
+
888
+ return [rescale_bbox(b) for b in bboxes]
889
+
890
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
891
+
892
+ if isinstance(cond, dict):
893
+ # hybrid case, cond is exptected to be a dict
894
+ pass
895
+ else:
896
+ if not isinstance(cond, list):
897
+ cond = [cond]
898
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
899
+ cond = {key: cond}
900
+
901
+ if hasattr(self, "split_input_params"):
902
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
903
+ assert not return_ids
904
+ ks = self.split_input_params["ks"] # eg. (128, 128)
905
+ stride = self.split_input_params["stride"] # eg. (64, 64)
906
+
907
+ h, w = x_noisy.shape[-2:]
908
+
909
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
910
+
911
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
912
+ # Reshape to img shape
913
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
914
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
915
+
916
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
917
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
918
+ c_key = next(iter(cond.keys())) # get key
919
+ c = next(iter(cond.values())) # get value
920
+ assert (len(c) == 1) # todo extend to list with more than one elem
921
+ c = c[0] # get element
922
+
923
+ c = unfold(c)
924
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
925
+
926
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
927
+
928
+ elif self.cond_stage_key == 'coordinates_bbox':
929
+ assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
930
+
931
+ # assuming padding of unfold is always 0 and its dilation is always 1
932
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
933
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
934
+ # as we are operating on latents, we need the factor from the original image size to the
935
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
936
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
937
+ rescale_latent = 2 ** (num_downs)
938
+
939
+ # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
940
+ # need to rescale the tl patch coordinates to be in between (0,1)
941
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
942
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
943
+ for patch_nr in range(z.shape[-1])]
944
+
945
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
946
+ patch_limits = [(x_tl, y_tl,
947
+ rescale_latent * ks[0] / full_img_w,
948
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
949
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
950
+
951
+ # tokenize crop coordinates for the bounding boxes of the respective patches
952
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
953
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
954
+ print(patch_limits_tknzd[0].shape)
955
+ # cut tknzd crop position from conditioning
956
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
957
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
958
+ print(cut_cond.shape)
959
+
960
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
961
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
962
+ print(adapted_cond.shape)
963
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
964
+ print(adapted_cond.shape)
965
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
966
+ print(adapted_cond.shape)
967
+
968
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
969
+
970
+ else:
971
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
972
+
973
+ # apply model by loop over crops
974
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
975
+ assert not isinstance(output_list[0],
976
+ tuple) # todo cant deal with multiple model outputs check this never happens
977
+
978
+ o = torch.stack(output_list, axis=-1)
979
+ o = o * weighting
980
+ # Reverse reshape to img shape
981
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
982
+ # stitch crops together
983
+ x_recon = fold(o) / normalization
984
+
985
+ else:
986
+ x_recon = self.model(x_noisy, t, **cond)
987
+
988
+ if isinstance(x_recon, tuple) and not return_ids:
989
+ return x_recon[0]
990
+ else:
991
+ return x_recon
992
+
993
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
994
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
995
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
996
+
997
+ def _prior_bpd(self, x_start):
998
+ """
999
+ Get the prior KL term for the variational lower-bound, measured in
1000
+ bits-per-dim.
1001
+ This term can't be optimized, as it only depends on the encoder.
1002
+ :param x_start: the [N x C x ...] tensor of inputs.
1003
+ :return: a batch of [N] KL values (in bits), one per batch element.
1004
+ """
1005
+ batch_size = x_start.shape[0]
1006
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1007
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1008
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1009
+ return mean_flat(kl_prior) / np.log(2.0)
1010
+
1011
+ def p_losses(self, x_start, cond, t, noise=None):
1012
+ noise = default(noise, lambda: torch.randn_like(x_start))
1013
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1014
+ model_output = self.apply_model(x_noisy, t, cond)
1015
+
1016
+ loss_dict = {}
1017
+ prefix = 'train' if self.training else 'val'
1018
+
1019
+ if self.parameterization == "x0":
1020
+ target = x_start
1021
+ elif self.parameterization == "eps":
1022
+ target = noise
1023
+ else:
1024
+ raise NotImplementedError()
1025
+
1026
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1027
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1028
+
1029
+ logvar_t = self.logvar[t].to(self.device)
1030
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1031
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1032
+ if self.learn_logvar:
1033
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1034
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1035
+
1036
+ loss = self.l_simple_weight * loss.mean()
1037
+
1038
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1039
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1040
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1041
+ loss += (self.original_elbo_weight * loss_vlb)
1042
+ loss_dict.update({f'{prefix}/loss': loss})
1043
+
1044
+ return loss, loss_dict
1045
+
1046
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1047
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1048
+ t_in = t
1049
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1050
+
1051
+ if score_corrector is not None:
1052
+ assert self.parameterization == "eps"
1053
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1054
+
1055
+ if return_codebook_ids:
1056
+ model_out, logits = model_out
1057
+
1058
+ if self.parameterization == "eps":
1059
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1060
+ elif self.parameterization == "x0":
1061
+ x_recon = model_out
1062
+ else:
1063
+ raise NotImplementedError()
1064
+
1065
+ if clip_denoised:
1066
+ x_recon.clamp_(-1., 1.)
1067
+ if quantize_denoised:
1068
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1069
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1070
+ if return_codebook_ids:
1071
+ return model_mean, posterior_variance, posterior_log_variance, logits
1072
+ elif return_x0:
1073
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1074
+ else:
1075
+ return model_mean, posterior_variance, posterior_log_variance
1076
+
1077
+ @torch.no_grad()
1078
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1079
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1080
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1081
+ b, *_, device = *x.shape, x.device
1082
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1083
+ return_codebook_ids=return_codebook_ids,
1084
+ quantize_denoised=quantize_denoised,
1085
+ return_x0=return_x0,
1086
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1087
+ if return_codebook_ids:
1088
+ raise DeprecationWarning("Support dropped.")
1089
+ model_mean, _, model_log_variance, logits = outputs
1090
+ elif return_x0:
1091
+ model_mean, _, model_log_variance, x0 = outputs
1092
+ else:
1093
+ model_mean, _, model_log_variance = outputs
1094
+
1095
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1096
+ if noise_dropout > 0.:
1097
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1098
+ # no noise when t == 0
1099
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1100
+
1101
+ if return_codebook_ids:
1102
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1103
+ if return_x0:
1104
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1105
+ else:
1106
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1107
+
1108
+ @torch.no_grad()
1109
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1110
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1111
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1112
+ log_every_t=None):
1113
+ if not log_every_t:
1114
+ log_every_t = self.log_every_t
1115
+ timesteps = self.num_timesteps
1116
+ if batch_size is not None:
1117
+ b = batch_size if batch_size is not None else shape[0]
1118
+ shape = [batch_size] + list(shape)
1119
+ else:
1120
+ b = batch_size = shape[0]
1121
+ if x_T is None:
1122
+ img = torch.randn(shape, device=self.device)
1123
+ else:
1124
+ img = x_T
1125
+ intermediates = []
1126
+ if cond is not None:
1127
+ if isinstance(cond, dict):
1128
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1129
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1130
+ else:
1131
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1132
+
1133
+ if start_T is not None:
1134
+ timesteps = min(timesteps, start_T)
1135
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1136
+ total=timesteps) if verbose else reversed(
1137
+ range(0, timesteps))
1138
+ if type(temperature) == float:
1139
+ temperature = [temperature] * timesteps
1140
+
1141
+ for i in iterator:
1142
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1143
+ if self.shorten_cond_schedule:
1144
+ assert self.model.conditioning_key != 'hybrid'
1145
+ tc = self.cond_ids[ts].to(cond.device)
1146
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1147
+
1148
+ img, x0_partial = self.p_sample(img, cond, ts,
1149
+ clip_denoised=self.clip_denoised,
1150
+ quantize_denoised=quantize_denoised, return_x0=True,
1151
+ temperature=temperature[i], noise_dropout=noise_dropout,
1152
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1153
+ if mask is not None:
1154
+ assert x0 is not None
1155
+ img_orig = self.q_sample(x0, ts)
1156
+ img = img_orig * mask + (1. - mask) * img
1157
+
1158
+ if i % log_every_t == 0 or i == timesteps - 1:
1159
+ intermediates.append(x0_partial)
1160
+ if callback: callback(i)
1161
+ if img_callback: img_callback(img, i)
1162
+ return img, intermediates
1163
+
1164
+ @torch.no_grad()
1165
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1166
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1167
+ mask=None, x0=None, img_callback=None, start_T=None,
1168
+ log_every_t=None):
1169
+
1170
+ if not log_every_t:
1171
+ log_every_t = self.log_every_t
1172
+ device = self.betas.device
1173
+ b = shape[0]
1174
+ if x_T is None:
1175
+ img = torch.randn(shape, device=device)
1176
+ else:
1177
+ img = x_T
1178
+
1179
+ intermediates = [img]
1180
+ if timesteps is None:
1181
+ timesteps = self.num_timesteps
1182
+
1183
+ if start_T is not None:
1184
+ timesteps = min(timesteps, start_T)
1185
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1186
+ range(0, timesteps))
1187
+
1188
+ if mask is not None:
1189
+ assert x0 is not None
1190
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1191
+
1192
+ for i in iterator:
1193
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1194
+ if self.shorten_cond_schedule:
1195
+ assert self.model.conditioning_key != 'hybrid'
1196
+ tc = self.cond_ids[ts].to(cond.device)
1197
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1198
+
1199
+ img = self.p_sample(img, cond, ts,
1200
+ clip_denoised=self.clip_denoised,
1201
+ quantize_denoised=quantize_denoised)
1202
+ if mask is not None:
1203
+ img_orig = self.q_sample(x0, ts)
1204
+ img = img_orig * mask + (1. - mask) * img
1205
+
1206
+ if i % log_every_t == 0 or i == timesteps - 1:
1207
+ intermediates.append(img)
1208
+ if callback: callback(i)
1209
+ if img_callback: img_callback(img, i)
1210
+
1211
+ if return_intermediates:
1212
+ return img, intermediates
1213
+ return img
1214
+
1215
+ @torch.no_grad()
1216
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1217
+ verbose=True, timesteps=None, quantize_denoised=False,
1218
+ mask=None, x0=None, shape=None,**kwargs):
1219
+ if shape is None:
1220
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1221
+ if cond is not None:
1222
+ if isinstance(cond, dict):
1223
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1224
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1225
+ else:
1226
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1227
+ return self.p_sample_loop(cond,
1228
+ shape,
1229
+ return_intermediates=return_intermediates, x_T=x_T,
1230
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1231
+ mask=mask, x0=x0)
1232
+
1233
+ @torch.no_grad()
1234
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1235
+
1236
+ if ddim:
1237
+ ddim_sampler = DDIMSampler(self)
1238
+ shape = (self.channels, self.image_size, self.image_size)
1239
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1240
+ shape,cond,verbose=False,**kwargs)
1241
+
1242
+ else:
1243
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1244
+ return_intermediates=True,**kwargs)
1245
+
1246
+ return samples, intermediates
1247
+
1248
+
1249
+ @torch.no_grad()
1250
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1251
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1252
+ plot_diffusion_rows=True, **kwargs):
1253
+
1254
+ use_ddim = ddim_steps is not None
1255
+
1256
+ log = dict()
1257
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1258
+ return_first_stage_outputs=True,
1259
+ force_c_encode=True,
1260
+ return_original_cond=True,
1261
+ bs=N)
1262
+ N = min(x.shape[0], N)
1263
+ n_row = min(x.shape[0], n_row)
1264
+ log["inputs"] = x
1265
+ log["reconstruction"] = xrec
1266
+ if self.model.conditioning_key is not None:
1267
+ if hasattr(self.cond_stage_model, "decode"):
1268
+ xc = self.cond_stage_model.decode(c)
1269
+ log["conditioning"] = xc
1270
+ elif self.cond_stage_key in ["caption"]:
1271
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1272
+ log["conditioning"] = xc
1273
+ elif self.cond_stage_key == 'class_label':
1274
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1275
+ log['conditioning'] = xc
1276
+ elif isimage(xc):
1277
+ log["conditioning"] = xc
1278
+ if ismap(xc):
1279
+ log["original_conditioning"] = self.to_rgb(xc)
1280
+
1281
+ if plot_diffusion_rows:
1282
+ # get diffusion row
1283
+ diffusion_row = list()
1284
+ z_start = z[:n_row]
1285
+ for t in range(self.num_timesteps):
1286
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1287
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1288
+ t = t.to(self.device).long()
1289
+ noise = torch.randn_like(z_start)
1290
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1291
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1292
+
1293
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1294
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1295
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1296
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1297
+ log["diffusion_row"] = diffusion_grid
1298
+
1299
+ if sample:
1300
+ # get denoise row
1301
+ with self.ema_scope("Plotting"):
1302
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1303
+ ddim_steps=ddim_steps,eta=ddim_eta)
1304
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1305
+ x_samples = self.decode_first_stage(samples)
1306
+ log["samples"] = x_samples
1307
+ if plot_denoise_rows:
1308
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1309
+ log["denoise_row"] = denoise_grid
1310
+
1311
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1312
+ self.first_stage_model, IdentityFirstStage):
1313
+ # also display when quantizing x0 while sampling
1314
+ with self.ema_scope("Plotting Quantized Denoised"):
1315
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1316
+ ddim_steps=ddim_steps,eta=ddim_eta,
1317
+ quantize_denoised=True)
1318
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1319
+ # quantize_denoised=True)
1320
+ x_samples = self.decode_first_stage(samples.to(self.device))
1321
+ log["samples_x0_quantized"] = x_samples
1322
+
1323
+ if inpaint:
1324
+ # make a simple center square
1325
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1326
+ mask = torch.ones(N, h, w).to(self.device)
1327
+ # zeros will be filled in
1328
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1329
+ mask = mask[:, None, ...]
1330
+ with self.ema_scope("Plotting Inpaint"):
1331
+
1332
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1333
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1334
+ x_samples = self.decode_first_stage(samples.to(self.device))
1335
+ log["samples_inpainting"] = x_samples
1336
+ log["mask"] = mask
1337
+
1338
+ # outpaint
1339
+ with self.ema_scope("Plotting Outpaint"):
1340
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1341
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1342
+ x_samples = self.decode_first_stage(samples.to(self.device))
1343
+ log["samples_outpainting"] = x_samples
1344
+
1345
+ if plot_progressive_rows:
1346
+ with self.ema_scope("Plotting Progressives"):
1347
+ img, progressives = self.progressive_denoising(c,
1348
+ shape=(self.channels, self.image_size, self.image_size),
1349
+ batch_size=N)
1350
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1351
+ log["progressive_row"] = prog_row
1352
+
1353
+ if return_keys:
1354
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1355
+ return log
1356
+ else:
1357
+ return {key: log[key] for key in return_keys}
1358
+ return log
1359
+
1360
+ def configure_optimizers(self):
1361
+ lr = self.learning_rate
1362
+ params = list(self.model.parameters())
1363
+ if self.cond_stage_trainable:
1364
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1365
+ params = params + list(self.cond_stage_model.parameters())
1366
+ if self.learn_logvar:
1367
+ print('Diffusion model optimizing logvar')
1368
+ params.append(self.logvar)
1369
+ opt = torch.optim.AdamW(params, lr=lr)
1370
+ if self.use_scheduler:
1371
+ assert 'target' in self.scheduler_config
1372
+ scheduler = instantiate_from_config(self.scheduler_config)
1373
+
1374
+ print("Setting up LambdaLR scheduler...")
1375
+ scheduler = [
1376
+ {
1377
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1378
+ 'interval': 'step',
1379
+ 'frequency': 1
1380
+ }]
1381
+ return [opt], scheduler
1382
+ return opt
1383
+
1384
+ @torch.no_grad()
1385
+ def to_rgb(self, x):
1386
+ x = x.float()
1387
+ if not hasattr(self, "colorize"):
1388
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1389
+ x = nn.functional.conv2d(x, weight=self.colorize)
1390
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1391
+ return x
1392
+
1393
+
1394
+ class DiffusionWrapperV1(pl.LightningModule):
1395
+ def __init__(self, diff_model_config, conditioning_key):
1396
+ super().__init__()
1397
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1398
+ self.conditioning_key = conditioning_key
1399
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
1400
+
1401
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
1402
+ if self.conditioning_key is None:
1403
+ out = self.diffusion_model(x, t)
1404
+ elif self.conditioning_key == 'concat':
1405
+ xc = torch.cat([x] + c_concat, dim=1)
1406
+ out = self.diffusion_model(xc, t)
1407
+ elif self.conditioning_key == 'crossattn':
1408
+ cc = torch.cat(c_crossattn, 1)
1409
+ out = self.diffusion_model(x, t, context=cc)
1410
+ elif self.conditioning_key == 'hybrid':
1411
+ xc = torch.cat([x] + c_concat, dim=1)
1412
+ cc = torch.cat(c_crossattn, 1)
1413
+ out = self.diffusion_model(xc, t, context=cc)
1414
+ elif self.conditioning_key == 'adm':
1415
+ cc = c_crossattn[0]
1416
+ out = self.diffusion_model(x, t, y=cc)
1417
+ else:
1418
+ raise NotImplementedError()
1419
+
1420
+ return out
1421
+
1422
+
1423
+ class Layout2ImgDiffusionV1(LatentDiffusionV1):
1424
+ # TODO: move all layout-specific hacks to this class
1425
+ def __init__(self, cond_stage_key, *args, **kwargs):
1426
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1427
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1428
+
1429
+ def log_images(self, batch, N=8, *args, **kwargs):
1430
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1431
+
1432
+ key = 'train' if self.training else 'validation'
1433
+ dset = self.trainer.datamodule.datasets[key]
1434
+ mapper = dset.conditional_builders[self.cond_stage_key]
1435
+
1436
+ bbox_imgs = []
1437
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1438
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1439
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1440
+ bbox_imgs.append(bboximg)
1441
+
1442
+ cond_img = torch.stack(bbox_imgs, dim=0)
1443
+ logs['bbox_image'] = cond_img
1444
+ return logs
1445
+
1446
+ setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
1447
+ setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
1448
+ setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
1449
+ setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
extensions-builtin/ScuNET/preload.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+ from modules import paths
3
+
4
+
5
+ def preload(parser):
6
+ parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
extensions-builtin/ScuNET/scripts/scunet_model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import sys
3
+ import traceback
4
+
5
+ import PIL.Image
6
+ import numpy as np
7
+ import torch
8
+ from basicsr.utils.download_util import load_file_from_url
9
+
10
+ import modules.upscaler
11
+ from modules import devices, modelloader
12
+ from scunet_model_arch import SCUNet as net
13
+
14
+
15
+ class UpscalerScuNET(modules.upscaler.Upscaler):
16
+ def __init__(self, dirname):
17
+ self.name = "ScuNET"
18
+ self.model_name = "ScuNET GAN"
19
+ self.model_name2 = "ScuNET PSNR"
20
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
21
+ self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
22
+ self.user_path = dirname
23
+ super().__init__()
24
+ model_paths = self.find_models(ext_filter=[".pth"])
25
+ scalers = []
26
+ add_model2 = True
27
+ for file in model_paths:
28
+ if "http" in file:
29
+ name = self.model_name
30
+ else:
31
+ name = modelloader.friendly_name(file)
32
+ if name == self.model_name2 or file == self.model_url2:
33
+ add_model2 = False
34
+ try:
35
+ scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
36
+ scalers.append(scaler_data)
37
+ except Exception:
38
+ print(f"Error loading ScuNET model: {file}", file=sys.stderr)
39
+ print(traceback.format_exc(), file=sys.stderr)
40
+ if add_model2:
41
+ scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
42
+ scalers.append(scaler_data2)
43
+ self.scalers = scalers
44
+
45
+ def do_upscale(self, img: PIL.Image, selected_file):
46
+ torch.cuda.empty_cache()
47
+
48
+ model = self.load_model(selected_file)
49
+ if model is None:
50
+ return img
51
+
52
+ device = devices.get_device_for('scunet')
53
+ img = np.array(img)
54
+ img = img[:, :, ::-1]
55
+ img = np.moveaxis(img, 2, 0) / 255
56
+ img = torch.from_numpy(img).float()
57
+ img = img.unsqueeze(0).to(device)
58
+
59
+ with torch.no_grad():
60
+ output = model(img)
61
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
62
+ output = 255. * np.moveaxis(output, 0, 2)
63
+ output = output.astype(np.uint8)
64
+ output = output[:, :, ::-1]
65
+ torch.cuda.empty_cache()
66
+ return PIL.Image.fromarray(output, 'RGB')
67
+
68
+ def load_model(self, path: str):
69
+ device = devices.get_device_for('scunet')
70
+ if "http" in path:
71
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
72
+ progress=True)
73
+ else:
74
+ filename = path
75
+ if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
76
+ print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
77
+ return None
78
+
79
+ model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
80
+ model.load_state_dict(torch.load(filename), strict=True)
81
+ model.eval()
82
+ for k, v in model.named_parameters():
83
+ v.requires_grad = False
84
+ model = model.to(device)
85
+
86
+ return model
87
+
extensions-builtin/ScuNET/scunet_model_arch.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from timm.models.layers import trunc_normal_, DropPath
8
+
9
+
10
+ class WMSA(nn.Module):
11
+ """ Self-attention module in Swin Transformer
12
+ """
13
+
14
+ def __init__(self, input_dim, output_dim, head_dim, window_size, type):
15
+ super(WMSA, self).__init__()
16
+ self.input_dim = input_dim
17
+ self.output_dim = output_dim
18
+ self.head_dim = head_dim
19
+ self.scale = self.head_dim ** -0.5
20
+ self.n_heads = input_dim // head_dim
21
+ self.window_size = window_size
22
+ self.type = type
23
+ self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
24
+
25
+ self.relative_position_params = nn.Parameter(
26
+ torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
27
+
28
+ self.linear = nn.Linear(self.input_dim, self.output_dim)
29
+
30
+ trunc_normal_(self.relative_position_params, std=.02)
31
+ self.relative_position_params = torch.nn.Parameter(
32
+ self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
33
+ 2).transpose(
34
+ 0, 1))
35
+
36
+ def generate_mask(self, h, w, p, shift):
37
+ """ generating the mask of SW-MSA
38
+ Args:
39
+ shift: shift parameters in CyclicShift.
40
+ Returns:
41
+ attn_mask: should be (1 1 w p p),
42
+ """
43
+ # supporting square.
44
+ attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
45
+ if self.type == 'W':
46
+ return attn_mask
47
+
48
+ s = p - shift
49
+ attn_mask[-1, :, :s, :, s:, :] = True
50
+ attn_mask[-1, :, s:, :, :s, :] = True
51
+ attn_mask[:, -1, :, :s, :, s:] = True
52
+ attn_mask[:, -1, :, s:, :, :s] = True
53
+ attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
54
+ return attn_mask
55
+
56
+ def forward(self, x):
57
+ """ Forward pass of Window Multi-head Self-attention module.
58
+ Args:
59
+ x: input tensor with shape of [b h w c];
60
+ attn_mask: attention mask, fill -inf where the value is True;
61
+ Returns:
62
+ output: tensor shape [b h w c]
63
+ """
64
+ if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
65
+ x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
66
+ h_windows = x.size(1)
67
+ w_windows = x.size(2)
68
+ # square validation
69
+ # assert h_windows == w_windows
70
+
71
+ x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
72
+ qkv = self.embedding_layer(x)
73
+ q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
74
+ sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
75
+ # Adding learnable relative embedding
76
+ sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
77
+ # Using Attn Mask to distinguish different subwindows.
78
+ if self.type != 'W':
79
+ attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
80
+ sim = sim.masked_fill_(attn_mask, float("-inf"))
81
+
82
+ probs = nn.functional.softmax(sim, dim=-1)
83
+ output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
84
+ output = rearrange(output, 'h b w p c -> b w p (h c)')
85
+ output = self.linear(output)
86
+ output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
87
+
88
+ if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
89
+ dims=(1, 2))
90
+ return output
91
+
92
+ def relative_embedding(self):
93
+ cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
94
+ relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
95
+ # negative is allowed
96
+ return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
97
+
98
+
99
+ class Block(nn.Module):
100
+ def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
101
+ """ SwinTransformer Block
102
+ """
103
+ super(Block, self).__init__()
104
+ self.input_dim = input_dim
105
+ self.output_dim = output_dim
106
+ assert type in ['W', 'SW']
107
+ self.type = type
108
+ if input_resolution <= window_size:
109
+ self.type = 'W'
110
+
111
+ self.ln1 = nn.LayerNorm(input_dim)
112
+ self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
113
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114
+ self.ln2 = nn.LayerNorm(input_dim)
115
+ self.mlp = nn.Sequential(
116
+ nn.Linear(input_dim, 4 * input_dim),
117
+ nn.GELU(),
118
+ nn.Linear(4 * input_dim, output_dim),
119
+ )
120
+
121
+ def forward(self, x):
122
+ x = x + self.drop_path(self.msa(self.ln1(x)))
123
+ x = x + self.drop_path(self.mlp(self.ln2(x)))
124
+ return x
125
+
126
+
127
+ class ConvTransBlock(nn.Module):
128
+ def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
129
+ """ SwinTransformer and Conv Block
130
+ """
131
+ super(ConvTransBlock, self).__init__()
132
+ self.conv_dim = conv_dim
133
+ self.trans_dim = trans_dim
134
+ self.head_dim = head_dim
135
+ self.window_size = window_size
136
+ self.drop_path = drop_path
137
+ self.type = type
138
+ self.input_resolution = input_resolution
139
+
140
+ assert self.type in ['W', 'SW']
141
+ if self.input_resolution <= self.window_size:
142
+ self.type = 'W'
143
+
144
+ self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
145
+ self.type, self.input_resolution)
146
+ self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
147
+ self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
148
+
149
+ self.conv_block = nn.Sequential(
150
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
151
+ nn.ReLU(True),
152
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
153
+ )
154
+
155
+ def forward(self, x):
156
+ conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
157
+ conv_x = self.conv_block(conv_x) + conv_x
158
+ trans_x = Rearrange('b c h w -> b h w c')(trans_x)
159
+ trans_x = self.trans_block(trans_x)
160
+ trans_x = Rearrange('b h w c -> b c h w')(trans_x)
161
+ res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
162
+ x = x + res
163
+
164
+ return x
165
+
166
+
167
+ class SCUNet(nn.Module):
168
+ # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
169
+ def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
170
+ super(SCUNet, self).__init__()
171
+ if config is None:
172
+ config = [2, 2, 2, 2, 2, 2, 2]
173
+ self.config = config
174
+ self.dim = dim
175
+ self.head_dim = 32
176
+ self.window_size = 8
177
+
178
+ # drop path rate for each layer
179
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
180
+
181
+ self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
182
+
183
+ begin = 0
184
+ self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
185
+ 'W' if not i % 2 else 'SW', input_resolution)
186
+ for i in range(config[0])] + \
187
+ [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
188
+
189
+ begin += config[0]
190
+ self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
191
+ 'W' if not i % 2 else 'SW', input_resolution // 2)
192
+ for i in range(config[1])] + \
193
+ [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
194
+
195
+ begin += config[1]
196
+ self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
197
+ 'W' if not i % 2 else 'SW', input_resolution // 4)
198
+ for i in range(config[2])] + \
199
+ [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
200
+
201
+ begin += config[2]
202
+ self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
203
+ 'W' if not i % 2 else 'SW', input_resolution // 8)
204
+ for i in range(config[3])]
205
+
206
+ begin += config[3]
207
+ self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
208
+ [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
209
+ 'W' if not i % 2 else 'SW', input_resolution // 4)
210
+ for i in range(config[4])]
211
+
212
+ begin += config[4]
213
+ self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
214
+ [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
215
+ 'W' if not i % 2 else 'SW', input_resolution // 2)
216
+ for i in range(config[5])]
217
+
218
+ begin += config[5]
219
+ self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
220
+ [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
221
+ 'W' if not i % 2 else 'SW', input_resolution)
222
+ for i in range(config[6])]
223
+
224
+ self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
225
+
226
+ self.m_head = nn.Sequential(*self.m_head)
227
+ self.m_down1 = nn.Sequential(*self.m_down1)
228
+ self.m_down2 = nn.Sequential(*self.m_down2)
229
+ self.m_down3 = nn.Sequential(*self.m_down3)
230
+ self.m_body = nn.Sequential(*self.m_body)
231
+ self.m_up3 = nn.Sequential(*self.m_up3)
232
+ self.m_up2 = nn.Sequential(*self.m_up2)
233
+ self.m_up1 = nn.Sequential(*self.m_up1)
234
+ self.m_tail = nn.Sequential(*self.m_tail)
235
+ # self.apply(self._init_weights)
236
+
237
+ def forward(self, x0):
238
+
239
+ h, w = x0.size()[-2:]
240
+ paddingBottom = int(np.ceil(h / 64) * 64 - h)
241
+ paddingRight = int(np.ceil(w / 64) * 64 - w)
242
+ x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
243
+
244
+ x1 = self.m_head(x0)
245
+ x2 = self.m_down1(x1)
246
+ x3 = self.m_down2(x2)
247
+ x4 = self.m_down3(x3)
248
+ x = self.m_body(x4)
249
+ x = self.m_up3(x + x4)
250
+ x = self.m_up2(x + x3)
251
+ x = self.m_up1(x + x2)
252
+ x = self.m_tail(x + x1)
253
+
254
+ x = x[..., :h, :w]
255
+
256
+ return x
257
+
258
+ def _init_weights(self, m):
259
+ if isinstance(m, nn.Linear):
260
+ trunc_normal_(m.weight, std=.02)
261
+ if m.bias is not None:
262
+ nn.init.constant_(m.bias, 0)
263
+ elif isinstance(m, nn.LayerNorm):
264
+ nn.init.constant_(m.bias, 0)
265
+ nn.init.constant_(m.weight, 1.0)
extensions-builtin/SwinIR/preload.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+ from modules import paths
3
+
4
+
5
+ def preload(parser):
6
+ parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
extensions-builtin/SwinIR/scripts/swinir_model.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from tqdm import tqdm
9
+
10
+ from modules import modelloader, devices, script_callbacks, shared
11
+ from modules.shared import cmd_opts, opts
12
+ from swinir_model_arch import SwinIR as net
13
+ from swinir_model_arch_v2 import Swin2SR as net2
14
+ from modules.upscaler import Upscaler, UpscalerData
15
+
16
+
17
+ device_swinir = devices.get_device_for('swinir')
18
+
19
+
20
+ class UpscalerSwinIR(Upscaler):
21
+ def __init__(self, dirname):
22
+ self.name = "SwinIR"
23
+ self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
24
+ "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
25
+ "-L_x4_GAN.pth "
26
+ self.model_name = "SwinIR 4x"
27
+ self.user_path = dirname
28
+ super().__init__()
29
+ scalers = []
30
+ model_files = self.find_models(ext_filter=[".pt", ".pth"])
31
+ for model in model_files:
32
+ if "http" in model:
33
+ name = self.model_name
34
+ else:
35
+ name = modelloader.friendly_name(model)
36
+ model_data = UpscalerData(name, model, self)
37
+ scalers.append(model_data)
38
+ self.scalers = scalers
39
+
40
+ def do_upscale(self, img, model_file):
41
+ model = self.load_model(model_file)
42
+ if model is None:
43
+ return img
44
+ model = model.to(device_swinir, dtype=devices.dtype)
45
+ img = upscale(img, model)
46
+ try:
47
+ torch.cuda.empty_cache()
48
+ except:
49
+ pass
50
+ return img
51
+
52
+ def load_model(self, path, scale=4):
53
+ if "http" in path:
54
+ dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
55
+ filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
56
+ else:
57
+ filename = path
58
+ if filename is None or not os.path.exists(filename):
59
+ return None
60
+ if filename.endswith(".v2.pth"):
61
+ model = net2(
62
+ upscale=scale,
63
+ in_chans=3,
64
+ img_size=64,
65
+ window_size=8,
66
+ img_range=1.0,
67
+ depths=[6, 6, 6, 6, 6, 6],
68
+ embed_dim=180,
69
+ num_heads=[6, 6, 6, 6, 6, 6],
70
+ mlp_ratio=2,
71
+ upsampler="nearest+conv",
72
+ resi_connection="1conv",
73
+ )
74
+ params = None
75
+ else:
76
+ model = net(
77
+ upscale=scale,
78
+ in_chans=3,
79
+ img_size=64,
80
+ window_size=8,
81
+ img_range=1.0,
82
+ depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
83
+ embed_dim=240,
84
+ num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
85
+ mlp_ratio=2,
86
+ upsampler="nearest+conv",
87
+ resi_connection="3conv",
88
+ )
89
+ params = "params_ema"
90
+
91
+ pretrained_model = torch.load(filename)
92
+ if params is not None:
93
+ model.load_state_dict(pretrained_model[params], strict=True)
94
+ else:
95
+ model.load_state_dict(pretrained_model, strict=True)
96
+ return model
97
+
98
+
99
+ def upscale(
100
+ img,
101
+ model,
102
+ tile=None,
103
+ tile_overlap=None,
104
+ window_size=8,
105
+ scale=4,
106
+ ):
107
+ tile = tile or opts.SWIN_tile
108
+ tile_overlap = tile_overlap or opts.SWIN_tile_overlap
109
+
110
+
111
+ img = np.array(img)
112
+ img = img[:, :, ::-1]
113
+ img = np.moveaxis(img, 2, 0) / 255
114
+ img = torch.from_numpy(img).float()
115
+ img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
116
+ with torch.no_grad(), devices.autocast():
117
+ _, _, h_old, w_old = img.size()
118
+ h_pad = (h_old // window_size + 1) * window_size - h_old
119
+ w_pad = (w_old // window_size + 1) * window_size - w_old
120
+ img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
121
+ img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
122
+ output = inference(img, model, tile, tile_overlap, window_size, scale)
123
+ output = output[..., : h_old * scale, : w_old * scale]
124
+ output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
125
+ if output.ndim == 3:
126
+ output = np.transpose(
127
+ output[[2, 1, 0], :, :], (1, 2, 0)
128
+ ) # CHW-RGB to HCW-BGR
129
+ output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
130
+ return Image.fromarray(output, "RGB")
131
+
132
+
133
+ def inference(img, model, tile, tile_overlap, window_size, scale):
134
+ # test the image tile by tile
135
+ b, c, h, w = img.size()
136
+ tile = min(tile, h, w)
137
+ assert tile % window_size == 0, "tile size should be a multiple of window_size"
138
+ sf = scale
139
+
140
+ stride = tile - tile_overlap
141
+ h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
142
+ w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
143
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
144
+ W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
145
+
146
+ with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
147
+ for h_idx in h_idx_list:
148
+ for w_idx in w_idx_list:
149
+ in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
150
+ out_patch = model(in_patch)
151
+ out_patch_mask = torch.ones_like(out_patch)
152
+
153
+ E[
154
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
155
+ ].add_(out_patch)
156
+ W[
157
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
158
+ ].add_(out_patch_mask)
159
+ pbar.update(1)
160
+ output = E.div_(W)
161
+
162
+ return output
163
+
164
+
165
+ def on_ui_settings():
166
+ import gradio as gr
167
+
168
+ shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
169
+ shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
170
+
171
+
172
+ script_callbacks.on_ui_settings(on_ui_settings)
extensions-builtin/SwinIR/swinir_model_arch.py ADDED
@@ -0,0 +1,867 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -----------------------------------------------------------------------------------
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+ # -----------------------------------------------------------------------------------
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+
13
+
14
+ class Mlp(nn.Module):
15
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16
+ super().__init__()
17
+ out_features = out_features or in_features
18
+ hidden_features = hidden_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features)
20
+ self.act = act_layer()
21
+ self.fc2 = nn.Linear(hidden_features, out_features)
22
+ self.drop = nn.Dropout(drop)
23
+
24
+ def forward(self, x):
25
+ x = self.fc1(x)
26
+ x = self.act(x)
27
+ x = self.drop(x)
28
+ x = self.fc2(x)
29
+ x = self.drop(x)
30
+ return x
31
+
32
+
33
+ def window_partition(x, window_size):
34
+ """
35
+ Args:
36
+ x: (B, H, W, C)
37
+ window_size (int): window size
38
+
39
+ Returns:
40
+ windows: (num_windows*B, window_size, window_size, C)
41
+ """
42
+ B, H, W, C = x.shape
43
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
+ return windows
46
+
47
+
48
+ def window_reverse(windows, window_size, H, W):
49
+ """
50
+ Args:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ window_size (int): Window size
53
+ H (int): Height of image
54
+ W (int): Width of image
55
+
56
+ Returns:
57
+ x: (B, H, W, C)
58
+ """
59
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
60
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
61
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
62
+ return x
63
+
64
+
65
+ class WindowAttention(nn.Module):
66
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
67
+ It supports both of shifted and non-shifted window.
68
+
69
+ Args:
70
+ dim (int): Number of input channels.
71
+ window_size (tuple[int]): The height and width of the window.
72
+ num_heads (int): Number of attention heads.
73
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
74
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
75
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
76
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
77
+ """
78
+
79
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
80
+
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.window_size = window_size # Wh, Ww
84
+ self.num_heads = num_heads
85
+ head_dim = dim // num_heads
86
+ self.scale = qk_scale or head_dim ** -0.5
87
+
88
+ # define a parameter table of relative position bias
89
+ self.relative_position_bias_table = nn.Parameter(
90
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
91
+
92
+ # get pair-wise relative position index for each token inside the window
93
+ coords_h = torch.arange(self.window_size[0])
94
+ coords_w = torch.arange(self.window_size[1])
95
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
96
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
97
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
98
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
99
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
100
+ relative_coords[:, :, 1] += self.window_size[1] - 1
101
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
102
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103
+ self.register_buffer("relative_position_index", relative_position_index)
104
+
105
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+
109
+ self.proj_drop = nn.Dropout(proj_drop)
110
+
111
+ trunc_normal_(self.relative_position_bias_table, std=.02)
112
+ self.softmax = nn.Softmax(dim=-1)
113
+
114
+ def forward(self, x, mask=None):
115
+ """
116
+ Args:
117
+ x: input features with shape of (num_windows*B, N, C)
118
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
119
+ """
120
+ B_, N, C = x.shape
121
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
122
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
123
+
124
+ q = q * self.scale
125
+ attn = (q @ k.transpose(-2, -1))
126
+
127
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
128
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
129
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
130
+ attn = attn + relative_position_bias.unsqueeze(0)
131
+
132
+ if mask is not None:
133
+ nW = mask.shape[0]
134
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
135
+ attn = attn.view(-1, self.num_heads, N, N)
136
+ attn = self.softmax(attn)
137
+ else:
138
+ attn = self.softmax(attn)
139
+
140
+ attn = self.attn_drop(attn)
141
+
142
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
143
+ x = self.proj(x)
144
+ x = self.proj_drop(x)
145
+ return x
146
+
147
+ def extra_repr(self) -> str:
148
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
149
+
150
+ def flops(self, N):
151
+ # calculate flops for 1 window with token length of N
152
+ flops = 0
153
+ # qkv = self.qkv(x)
154
+ flops += N * self.dim * 3 * self.dim
155
+ # attn = (q @ k.transpose(-2, -1))
156
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
157
+ # x = (attn @ v)
158
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
159
+ # x = self.proj(x)
160
+ flops += N * self.dim * self.dim
161
+ return flops
162
+
163
+
164
+ class SwinTransformerBlock(nn.Module):
165
+ r""" Swin Transformer Block.
166
+
167
+ Args:
168
+ dim (int): Number of input channels.
169
+ input_resolution (tuple[int]): Input resolution.
170
+ num_heads (int): Number of attention heads.
171
+ window_size (int): Window size.
172
+ shift_size (int): Shift size for SW-MSA.
173
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
174
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
175
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
176
+ drop (float, optional): Dropout rate. Default: 0.0
177
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
178
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
179
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
180
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
181
+ """
182
+
183
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
184
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
185
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
186
+ super().__init__()
187
+ self.dim = dim
188
+ self.input_resolution = input_resolution
189
+ self.num_heads = num_heads
190
+ self.window_size = window_size
191
+ self.shift_size = shift_size
192
+ self.mlp_ratio = mlp_ratio
193
+ if min(self.input_resolution) <= self.window_size:
194
+ # if window size is larger than input resolution, we don't partition windows
195
+ self.shift_size = 0
196
+ self.window_size = min(self.input_resolution)
197
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
198
+
199
+ self.norm1 = norm_layer(dim)
200
+ self.attn = WindowAttention(
201
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
202
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
203
+
204
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
205
+ self.norm2 = norm_layer(dim)
206
+ mlp_hidden_dim = int(dim * mlp_ratio)
207
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
208
+
209
+ if self.shift_size > 0:
210
+ attn_mask = self.calculate_mask(self.input_resolution)
211
+ else:
212
+ attn_mask = None
213
+
214
+ self.register_buffer("attn_mask", attn_mask)
215
+
216
+ def calculate_mask(self, x_size):
217
+ # calculate attention mask for SW-MSA
218
+ H, W = x_size
219
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
220
+ h_slices = (slice(0, -self.window_size),
221
+ slice(-self.window_size, -self.shift_size),
222
+ slice(-self.shift_size, None))
223
+ w_slices = (slice(0, -self.window_size),
224
+ slice(-self.window_size, -self.shift_size),
225
+ slice(-self.shift_size, None))
226
+ cnt = 0
227
+ for h in h_slices:
228
+ for w in w_slices:
229
+ img_mask[:, h, w, :] = cnt
230
+ cnt += 1
231
+
232
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
233
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
234
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
235
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
236
+
237
+ return attn_mask
238
+
239
+ def forward(self, x, x_size):
240
+ H, W = x_size
241
+ B, L, C = x.shape
242
+ # assert L == H * W, "input feature has wrong size"
243
+
244
+ shortcut = x
245
+ x = self.norm1(x)
246
+ x = x.view(B, H, W, C)
247
+
248
+ # cyclic shift
249
+ if self.shift_size > 0:
250
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
251
+ else:
252
+ shifted_x = x
253
+
254
+ # partition windows
255
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
256
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
257
+
258
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
259
+ if self.input_resolution == x_size:
260
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
261
+ else:
262
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
263
+
264
+ # merge windows
265
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
266
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
267
+
268
+ # reverse cyclic shift
269
+ if self.shift_size > 0:
270
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
271
+ else:
272
+ x = shifted_x
273
+ x = x.view(B, H * W, C)
274
+
275
+ # FFN
276
+ x = shortcut + self.drop_path(x)
277
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
278
+
279
+ return x
280
+
281
+ def extra_repr(self) -> str:
282
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
283
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
284
+
285
+ def flops(self):
286
+ flops = 0
287
+ H, W = self.input_resolution
288
+ # norm1
289
+ flops += self.dim * H * W
290
+ # W-MSA/SW-MSA
291
+ nW = H * W / self.window_size / self.window_size
292
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
293
+ # mlp
294
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
295
+ # norm2
296
+ flops += self.dim * H * W
297
+ return flops
298
+
299
+
300
+ class PatchMerging(nn.Module):
301
+ r""" Patch Merging Layer.
302
+
303
+ Args:
304
+ input_resolution (tuple[int]): Resolution of input feature.
305
+ dim (int): Number of input channels.
306
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307
+ """
308
+
309
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
310
+ super().__init__()
311
+ self.input_resolution = input_resolution
312
+ self.dim = dim
313
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
314
+ self.norm = norm_layer(4 * dim)
315
+
316
+ def forward(self, x):
317
+ """
318
+ x: B, H*W, C
319
+ """
320
+ H, W = self.input_resolution
321
+ B, L, C = x.shape
322
+ assert L == H * W, "input feature has wrong size"
323
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
324
+
325
+ x = x.view(B, H, W, C)
326
+
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
338
+
339
+ def extra_repr(self) -> str:
340
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
341
+
342
+ def flops(self):
343
+ H, W = self.input_resolution
344
+ flops = H * W * self.dim
345
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
346
+ return flops
347
+
348
+
349
+ class BasicLayer(nn.Module):
350
+ """ A basic Swin Transformer layer for one stage.
351
+
352
+ Args:
353
+ dim (int): Number of input channels.
354
+ input_resolution (tuple[int]): Input resolution.
355
+ depth (int): Number of blocks.
356
+ num_heads (int): Number of attention heads.
357
+ window_size (int): Local window size.
358
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
359
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
361
+ drop (float, optional): Dropout rate. Default: 0.0
362
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
363
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
364
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
365
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
366
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
367
+ """
368
+
369
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
370
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
371
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
372
+
373
+ super().__init__()
374
+ self.dim = dim
375
+ self.input_resolution = input_resolution
376
+ self.depth = depth
377
+ self.use_checkpoint = use_checkpoint
378
+
379
+ # build blocks
380
+ self.blocks = nn.ModuleList([
381
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
382
+ num_heads=num_heads, window_size=window_size,
383
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
384
+ mlp_ratio=mlp_ratio,
385
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
386
+ drop=drop, attn_drop=attn_drop,
387
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
388
+ norm_layer=norm_layer)
389
+ for i in range(depth)])
390
+
391
+ # patch merging layer
392
+ if downsample is not None:
393
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
394
+ else:
395
+ self.downsample = None
396
+
397
+ def forward(self, x, x_size):
398
+ for blk in self.blocks:
399
+ if self.use_checkpoint:
400
+ x = checkpoint.checkpoint(blk, x, x_size)
401
+ else:
402
+ x = blk(x, x_size)
403
+ if self.downsample is not None:
404
+ x = self.downsample(x)
405
+ return x
406
+
407
+ def extra_repr(self) -> str:
408
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
409
+
410
+ def flops(self):
411
+ flops = 0
412
+ for blk in self.blocks:
413
+ flops += blk.flops()
414
+ if self.downsample is not None:
415
+ flops += self.downsample.flops()
416
+ return flops
417
+
418
+
419
+ class RSTB(nn.Module):
420
+ """Residual Swin Transformer Block (RSTB).
421
+
422
+ Args:
423
+ dim (int): Number of input channels.
424
+ input_resolution (tuple[int]): Input resolution.
425
+ depth (int): Number of blocks.
426
+ num_heads (int): Number of attention heads.
427
+ window_size (int): Local window size.
428
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
429
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
430
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
431
+ drop (float, optional): Dropout rate. Default: 0.0
432
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
433
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
434
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
435
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
436
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
437
+ img_size: Input image size.
438
+ patch_size: Patch size.
439
+ resi_connection: The convolutional block before residual connection.
440
+ """
441
+
442
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
443
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
444
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
445
+ img_size=224, patch_size=4, resi_connection='1conv'):
446
+ super(RSTB, self).__init__()
447
+
448
+ self.dim = dim
449
+ self.input_resolution = input_resolution
450
+
451
+ self.residual_group = BasicLayer(dim=dim,
452
+ input_resolution=input_resolution,
453
+ depth=depth,
454
+ num_heads=num_heads,
455
+ window_size=window_size,
456
+ mlp_ratio=mlp_ratio,
457
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
458
+ drop=drop, attn_drop=attn_drop,
459
+ drop_path=drop_path,
460
+ norm_layer=norm_layer,
461
+ downsample=downsample,
462
+ use_checkpoint=use_checkpoint)
463
+
464
+ if resi_connection == '1conv':
465
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
466
+ elif resi_connection == '3conv':
467
+ # to save parameters and memory
468
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
469
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
470
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
471
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
472
+
473
+ self.patch_embed = PatchEmbed(
474
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
475
+ norm_layer=None)
476
+
477
+ self.patch_unembed = PatchUnEmbed(
478
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
479
+ norm_layer=None)
480
+
481
+ def forward(self, x, x_size):
482
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
483
+
484
+ def flops(self):
485
+ flops = 0
486
+ flops += self.residual_group.flops()
487
+ H, W = self.input_resolution
488
+ flops += H * W * self.dim * self.dim * 9
489
+ flops += self.patch_embed.flops()
490
+ flops += self.patch_unembed.flops()
491
+
492
+ return flops
493
+
494
+
495
+ class PatchEmbed(nn.Module):
496
+ r""" Image to Patch Embedding
497
+
498
+ Args:
499
+ img_size (int): Image size. Default: 224.
500
+ patch_size (int): Patch token size. Default: 4.
501
+ in_chans (int): Number of input image channels. Default: 3.
502
+ embed_dim (int): Number of linear projection output channels. Default: 96.
503
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
504
+ """
505
+
506
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
507
+ super().__init__()
508
+ img_size = to_2tuple(img_size)
509
+ patch_size = to_2tuple(patch_size)
510
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
511
+ self.img_size = img_size
512
+ self.patch_size = patch_size
513
+ self.patches_resolution = patches_resolution
514
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
515
+
516
+ self.in_chans = in_chans
517
+ self.embed_dim = embed_dim
518
+
519
+ if norm_layer is not None:
520
+ self.norm = norm_layer(embed_dim)
521
+ else:
522
+ self.norm = None
523
+
524
+ def forward(self, x):
525
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
526
+ if self.norm is not None:
527
+ x = self.norm(x)
528
+ return x
529
+
530
+ def flops(self):
531
+ flops = 0
532
+ H, W = self.img_size
533
+ if self.norm is not None:
534
+ flops += H * W * self.embed_dim
535
+ return flops
536
+
537
+
538
+ class PatchUnEmbed(nn.Module):
539
+ r""" Image to Patch Unembedding
540
+
541
+ Args:
542
+ img_size (int): Image size. Default: 224.
543
+ patch_size (int): Patch token size. Default: 4.
544
+ in_chans (int): Number of input image channels. Default: 3.
545
+ embed_dim (int): Number of linear projection output channels. Default: 96.
546
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
547
+ """
548
+
549
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
550
+ super().__init__()
551
+ img_size = to_2tuple(img_size)
552
+ patch_size = to_2tuple(patch_size)
553
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
554
+ self.img_size = img_size
555
+ self.patch_size = patch_size
556
+ self.patches_resolution = patches_resolution
557
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
558
+
559
+ self.in_chans = in_chans
560
+ self.embed_dim = embed_dim
561
+
562
+ def forward(self, x, x_size):
563
+ B, HW, C = x.shape
564
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
565
+ return x
566
+
567
+ def flops(self):
568
+ flops = 0
569
+ return flops
570
+
571
+
572
+ class Upsample(nn.Sequential):
573
+ """Upsample module.
574
+
575
+ Args:
576
+ scale (int): Scale factor. Supported scales: 2^n and 3.
577
+ num_feat (int): Channel number of intermediate features.
578
+ """
579
+
580
+ def __init__(self, scale, num_feat):
581
+ m = []
582
+ if (scale & (scale - 1)) == 0: # scale = 2^n
583
+ for _ in range(int(math.log(scale, 2))):
584
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
585
+ m.append(nn.PixelShuffle(2))
586
+ elif scale == 3:
587
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
588
+ m.append(nn.PixelShuffle(3))
589
+ else:
590
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
591
+ super(Upsample, self).__init__(*m)
592
+
593
+
594
+ class UpsampleOneStep(nn.Sequential):
595
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
596
+ Used in lightweight SR to save parameters.
597
+
598
+ Args:
599
+ scale (int): Scale factor. Supported scales: 2^n and 3.
600
+ num_feat (int): Channel number of intermediate features.
601
+
602
+ """
603
+
604
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
605
+ self.num_feat = num_feat
606
+ self.input_resolution = input_resolution
607
+ m = []
608
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
609
+ m.append(nn.PixelShuffle(scale))
610
+ super(UpsampleOneStep, self).__init__(*m)
611
+
612
+ def flops(self):
613
+ H, W = self.input_resolution
614
+ flops = H * W * self.num_feat * 3 * 9
615
+ return flops
616
+
617
+
618
+ class SwinIR(nn.Module):
619
+ r""" SwinIR
620
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
621
+
622
+ Args:
623
+ img_size (int | tuple(int)): Input image size. Default 64
624
+ patch_size (int | tuple(int)): Patch size. Default: 1
625
+ in_chans (int): Number of input image channels. Default: 3
626
+ embed_dim (int): Patch embedding dimension. Default: 96
627
+ depths (tuple(int)): Depth of each Swin Transformer layer.
628
+ num_heads (tuple(int)): Number of attention heads in different layers.
629
+ window_size (int): Window size. Default: 7
630
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
631
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
632
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
633
+ drop_rate (float): Dropout rate. Default: 0
634
+ attn_drop_rate (float): Attention dropout rate. Default: 0
635
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
636
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
637
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
638
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
639
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
640
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
641
+ img_range: Image range. 1. or 255.
642
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
643
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
644
+ """
645
+
646
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
647
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
648
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
649
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
650
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
651
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
652
+ **kwargs):
653
+ super(SwinIR, self).__init__()
654
+ num_in_ch = in_chans
655
+ num_out_ch = in_chans
656
+ num_feat = 64
657
+ self.img_range = img_range
658
+ if in_chans == 3:
659
+ rgb_mean = (0.4488, 0.4371, 0.4040)
660
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
661
+ else:
662
+ self.mean = torch.zeros(1, 1, 1, 1)
663
+ self.upscale = upscale
664
+ self.upsampler = upsampler
665
+ self.window_size = window_size
666
+
667
+ #####################################################################################################
668
+ ################################### 1, shallow feature extraction ###################################
669
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
670
+
671
+ #####################################################################################################
672
+ ################################### 2, deep feature extraction ######################################
673
+ self.num_layers = len(depths)
674
+ self.embed_dim = embed_dim
675
+ self.ape = ape
676
+ self.patch_norm = patch_norm
677
+ self.num_features = embed_dim
678
+ self.mlp_ratio = mlp_ratio
679
+
680
+ # split image into non-overlapping patches
681
+ self.patch_embed = PatchEmbed(
682
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
683
+ norm_layer=norm_layer if self.patch_norm else None)
684
+ num_patches = self.patch_embed.num_patches
685
+ patches_resolution = self.patch_embed.patches_resolution
686
+ self.patches_resolution = patches_resolution
687
+
688
+ # merge non-overlapping patches into image
689
+ self.patch_unembed = PatchUnEmbed(
690
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
691
+ norm_layer=norm_layer if self.patch_norm else None)
692
+
693
+ # absolute position embedding
694
+ if self.ape:
695
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
696
+ trunc_normal_(self.absolute_pos_embed, std=.02)
697
+
698
+ self.pos_drop = nn.Dropout(p=drop_rate)
699
+
700
+ # stochastic depth
701
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
702
+
703
+ # build Residual Swin Transformer blocks (RSTB)
704
+ self.layers = nn.ModuleList()
705
+ for i_layer in range(self.num_layers):
706
+ layer = RSTB(dim=embed_dim,
707
+ input_resolution=(patches_resolution[0],
708
+ patches_resolution[1]),
709
+ depth=depths[i_layer],
710
+ num_heads=num_heads[i_layer],
711
+ window_size=window_size,
712
+ mlp_ratio=self.mlp_ratio,
713
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
714
+ drop=drop_rate, attn_drop=attn_drop_rate,
715
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
716
+ norm_layer=norm_layer,
717
+ downsample=None,
718
+ use_checkpoint=use_checkpoint,
719
+ img_size=img_size,
720
+ patch_size=patch_size,
721
+ resi_connection=resi_connection
722
+
723
+ )
724
+ self.layers.append(layer)
725
+ self.norm = norm_layer(self.num_features)
726
+
727
+ # build the last conv layer in deep feature extraction
728
+ if resi_connection == '1conv':
729
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
730
+ elif resi_connection == '3conv':
731
+ # to save parameters and memory
732
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
733
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
734
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
735
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
736
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
737
+
738
+ #####################################################################################################
739
+ ################################ 3, high quality image reconstruction ################################
740
+ if self.upsampler == 'pixelshuffle':
741
+ # for classical SR
742
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
743
+ nn.LeakyReLU(inplace=True))
744
+ self.upsample = Upsample(upscale, num_feat)
745
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
746
+ elif self.upsampler == 'pixelshuffledirect':
747
+ # for lightweight SR (to save parameters)
748
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
749
+ (patches_resolution[0], patches_resolution[1]))
750
+ elif self.upsampler == 'nearest+conv':
751
+ # for real-world SR (less artifacts)
752
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
753
+ nn.LeakyReLU(inplace=True))
754
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
755
+ if self.upscale == 4:
756
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
757
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
758
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
759
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
760
+ else:
761
+ # for image denoising and JPEG compression artifact reduction
762
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
763
+
764
+ self.apply(self._init_weights)
765
+
766
+ def _init_weights(self, m):
767
+ if isinstance(m, nn.Linear):
768
+ trunc_normal_(m.weight, std=.02)
769
+ if isinstance(m, nn.Linear) and m.bias is not None:
770
+ nn.init.constant_(m.bias, 0)
771
+ elif isinstance(m, nn.LayerNorm):
772
+ nn.init.constant_(m.bias, 0)
773
+ nn.init.constant_(m.weight, 1.0)
774
+
775
+ @torch.jit.ignore
776
+ def no_weight_decay(self):
777
+ return {'absolute_pos_embed'}
778
+
779
+ @torch.jit.ignore
780
+ def no_weight_decay_keywords(self):
781
+ return {'relative_position_bias_table'}
782
+
783
+ def check_image_size(self, x):
784
+ _, _, h, w = x.size()
785
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
786
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
787
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
788
+ return x
789
+
790
+ def forward_features(self, x):
791
+ x_size = (x.shape[2], x.shape[3])
792
+ x = self.patch_embed(x)
793
+ if self.ape:
794
+ x = x + self.absolute_pos_embed
795
+ x = self.pos_drop(x)
796
+
797
+ for layer in self.layers:
798
+ x = layer(x, x_size)
799
+
800
+ x = self.norm(x) # B L C
801
+ x = self.patch_unembed(x, x_size)
802
+
803
+ return x
804
+
805
+ def forward(self, x):
806
+ H, W = x.shape[2:]
807
+ x = self.check_image_size(x)
808
+
809
+ self.mean = self.mean.type_as(x)
810
+ x = (x - self.mean) * self.img_range
811
+
812
+ if self.upsampler == 'pixelshuffle':
813
+ # for classical SR
814
+ x = self.conv_first(x)
815
+ x = self.conv_after_body(self.forward_features(x)) + x
816
+ x = self.conv_before_upsample(x)
817
+ x = self.conv_last(self.upsample(x))
818
+ elif self.upsampler == 'pixelshuffledirect':
819
+ # for lightweight SR
820
+ x = self.conv_first(x)
821
+ x = self.conv_after_body(self.forward_features(x)) + x
822
+ x = self.upsample(x)
823
+ elif self.upsampler == 'nearest+conv':
824
+ # for real-world SR
825
+ x = self.conv_first(x)
826
+ x = self.conv_after_body(self.forward_features(x)) + x
827
+ x = self.conv_before_upsample(x)
828
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
829
+ if self.upscale == 4:
830
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
831
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
832
+ else:
833
+ # for image denoising and JPEG compression artifact reduction
834
+ x_first = self.conv_first(x)
835
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
836
+ x = x + self.conv_last(res)
837
+
838
+ x = x / self.img_range + self.mean
839
+
840
+ return x[:, :, :H*self.upscale, :W*self.upscale]
841
+
842
+ def flops(self):
843
+ flops = 0
844
+ H, W = self.patches_resolution
845
+ flops += H * W * 3 * self.embed_dim * 9
846
+ flops += self.patch_embed.flops()
847
+ for i, layer in enumerate(self.layers):
848
+ flops += layer.flops()
849
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
850
+ flops += self.upsample.flops()
851
+ return flops
852
+
853
+
854
+ if __name__ == '__main__':
855
+ upscale = 4
856
+ window_size = 8
857
+ height = (1024 // upscale // window_size + 1) * window_size
858
+ width = (720 // upscale // window_size + 1) * window_size
859
+ model = SwinIR(upscale=2, img_size=(height, width),
860
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
861
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
862
+ print(model)
863
+ print(height, width, model.flops() / 1e9)
864
+
865
+ x = torch.randn((1, 3, height, width))
866
+ x = model(x)
867
+ print(x.shape)
extensions-builtin/SwinIR/swinir_model_arch_v2.py ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -----------------------------------------------------------------------------------
2
+ # Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
3
+ # Written by Conde and Choi et al.
4
+ # -----------------------------------------------------------------------------------
5
+
6
+ import math
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as checkpoint
12
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13
+
14
+
15
+ class Mlp(nn.Module):
16
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
17
+ super().__init__()
18
+ out_features = out_features or in_features
19
+ hidden_features = hidden_features or in_features
20
+ self.fc1 = nn.Linear(in_features, hidden_features)
21
+ self.act = act_layer()
22
+ self.fc2 = nn.Linear(hidden_features, out_features)
23
+ self.drop = nn.Dropout(drop)
24
+
25
+ def forward(self, x):
26
+ x = self.fc1(x)
27
+ x = self.act(x)
28
+ x = self.drop(x)
29
+ x = self.fc2(x)
30
+ x = self.drop(x)
31
+ return x
32
+
33
+
34
+ def window_partition(x, window_size):
35
+ """
36
+ Args:
37
+ x: (B, H, W, C)
38
+ window_size (int): window size
39
+ Returns:
40
+ windows: (num_windows*B, window_size, window_size, C)
41
+ """
42
+ B, H, W, C = x.shape
43
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
+ return windows
46
+
47
+
48
+ def window_reverse(windows, window_size, H, W):
49
+ """
50
+ Args:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ window_size (int): Window size
53
+ H (int): Height of image
54
+ W (int): Width of image
55
+ Returns:
56
+ x: (B, H, W, C)
57
+ """
58
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
59
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
60
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
61
+ return x
62
+
63
+ class WindowAttention(nn.Module):
64
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
65
+ It supports both of shifted and non-shifted window.
66
+ Args:
67
+ dim (int): Number of input channels.
68
+ window_size (tuple[int]): The height and width of the window.
69
+ num_heads (int): Number of attention heads.
70
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
71
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
72
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
73
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
74
+ """
75
+
76
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
77
+ pretrained_window_size=[0, 0]):
78
+
79
+ super().__init__()
80
+ self.dim = dim
81
+ self.window_size = window_size # Wh, Ww
82
+ self.pretrained_window_size = pretrained_window_size
83
+ self.num_heads = num_heads
84
+
85
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
86
+
87
+ # mlp to generate continuous relative position bias
88
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
89
+ nn.ReLU(inplace=True),
90
+ nn.Linear(512, num_heads, bias=False))
91
+
92
+ # get relative_coords_table
93
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
94
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
95
+ relative_coords_table = torch.stack(
96
+ torch.meshgrid([relative_coords_h,
97
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
98
+ if pretrained_window_size[0] > 0:
99
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
100
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
101
+ else:
102
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
103
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
104
+ relative_coords_table *= 8 # normalize to -8, 8
105
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
106
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
107
+
108
+ self.register_buffer("relative_coords_table", relative_coords_table)
109
+
110
+ # get pair-wise relative position index for each token inside the window
111
+ coords_h = torch.arange(self.window_size[0])
112
+ coords_w = torch.arange(self.window_size[1])
113
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
114
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
116
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
117
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
118
+ relative_coords[:, :, 1] += self.window_size[1] - 1
119
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
121
+ self.register_buffer("relative_position_index", relative_position_index)
122
+
123
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
124
+ if qkv_bias:
125
+ self.q_bias = nn.Parameter(torch.zeros(dim))
126
+ self.v_bias = nn.Parameter(torch.zeros(dim))
127
+ else:
128
+ self.q_bias = None
129
+ self.v_bias = None
130
+ self.attn_drop = nn.Dropout(attn_drop)
131
+ self.proj = nn.Linear(dim, dim)
132
+ self.proj_drop = nn.Dropout(proj_drop)
133
+ self.softmax = nn.Softmax(dim=-1)
134
+
135
+ def forward(self, x, mask=None):
136
+ """
137
+ Args:
138
+ x: input features with shape of (num_windows*B, N, C)
139
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
140
+ """
141
+ B_, N, C = x.shape
142
+ qkv_bias = None
143
+ if self.q_bias is not None:
144
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
145
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
146
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
147
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
148
+
149
+ # cosine attention
150
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
151
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
152
+ attn = attn * logit_scale
153
+
154
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
155
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
156
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
157
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
158
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
159
+ attn = attn + relative_position_bias.unsqueeze(0)
160
+
161
+ if mask is not None:
162
+ nW = mask.shape[0]
163
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
164
+ attn = attn.view(-1, self.num_heads, N, N)
165
+ attn = self.softmax(attn)
166
+ else:
167
+ attn = self.softmax(attn)
168
+
169
+ attn = self.attn_drop(attn)
170
+
171
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
172
+ x = self.proj(x)
173
+ x = self.proj_drop(x)
174
+ return x
175
+
176
+ def extra_repr(self) -> str:
177
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
178
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
179
+
180
+ def flops(self, N):
181
+ # calculate flops for 1 window with token length of N
182
+ flops = 0
183
+ # qkv = self.qkv(x)
184
+ flops += N * self.dim * 3 * self.dim
185
+ # attn = (q @ k.transpose(-2, -1))
186
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
187
+ # x = (attn @ v)
188
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
189
+ # x = self.proj(x)
190
+ flops += N * self.dim * self.dim
191
+ return flops
192
+
193
+ class SwinTransformerBlock(nn.Module):
194
+ r""" Swin Transformer Block.
195
+ Args:
196
+ dim (int): Number of input channels.
197
+ input_resolution (tuple[int]): Input resulotion.
198
+ num_heads (int): Number of attention heads.
199
+ window_size (int): Window size.
200
+ shift_size (int): Shift size for SW-MSA.
201
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
202
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
203
+ drop (float, optional): Dropout rate. Default: 0.0
204
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
205
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
206
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
207
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
208
+ pretrained_window_size (int): Window size in pre-training.
209
+ """
210
+
211
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
212
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
213
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
214
+ super().__init__()
215
+ self.dim = dim
216
+ self.input_resolution = input_resolution
217
+ self.num_heads = num_heads
218
+ self.window_size = window_size
219
+ self.shift_size = shift_size
220
+ self.mlp_ratio = mlp_ratio
221
+ if min(self.input_resolution) <= self.window_size:
222
+ # if window size is larger than input resolution, we don't partition windows
223
+ self.shift_size = 0
224
+ self.window_size = min(self.input_resolution)
225
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
226
+
227
+ self.norm1 = norm_layer(dim)
228
+ self.attn = WindowAttention(
229
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
230
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
231
+ pretrained_window_size=to_2tuple(pretrained_window_size))
232
+
233
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
234
+ self.norm2 = norm_layer(dim)
235
+ mlp_hidden_dim = int(dim * mlp_ratio)
236
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
237
+
238
+ if self.shift_size > 0:
239
+ attn_mask = self.calculate_mask(self.input_resolution)
240
+ else:
241
+ attn_mask = None
242
+
243
+ self.register_buffer("attn_mask", attn_mask)
244
+
245
+ def calculate_mask(self, x_size):
246
+ # calculate attention mask for SW-MSA
247
+ H, W = x_size
248
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
249
+ h_slices = (slice(0, -self.window_size),
250
+ slice(-self.window_size, -self.shift_size),
251
+ slice(-self.shift_size, None))
252
+ w_slices = (slice(0, -self.window_size),
253
+ slice(-self.window_size, -self.shift_size),
254
+ slice(-self.shift_size, None))
255
+ cnt = 0
256
+ for h in h_slices:
257
+ for w in w_slices:
258
+ img_mask[:, h, w, :] = cnt
259
+ cnt += 1
260
+
261
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
262
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
263
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
264
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
265
+
266
+ return attn_mask
267
+
268
+ def forward(self, x, x_size):
269
+ H, W = x_size
270
+ B, L, C = x.shape
271
+ #assert L == H * W, "input feature has wrong size"
272
+
273
+ shortcut = x
274
+ x = x.view(B, H, W, C)
275
+
276
+ # cyclic shift
277
+ if self.shift_size > 0:
278
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
279
+ else:
280
+ shifted_x = x
281
+
282
+ # partition windows
283
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
284
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
285
+
286
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
287
+ if self.input_resolution == x_size:
288
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
289
+ else:
290
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
291
+
292
+ # merge windows
293
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
294
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
295
+
296
+ # reverse cyclic shift
297
+ if self.shift_size > 0:
298
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
299
+ else:
300
+ x = shifted_x
301
+ x = x.view(B, H * W, C)
302
+ x = shortcut + self.drop_path(self.norm1(x))
303
+
304
+ # FFN
305
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
306
+
307
+ return x
308
+
309
+ def extra_repr(self) -> str:
310
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
311
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
312
+
313
+ def flops(self):
314
+ flops = 0
315
+ H, W = self.input_resolution
316
+ # norm1
317
+ flops += self.dim * H * W
318
+ # W-MSA/SW-MSA
319
+ nW = H * W / self.window_size / self.window_size
320
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
321
+ # mlp
322
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
323
+ # norm2
324
+ flops += self.dim * H * W
325
+ return flops
326
+
327
+ class PatchMerging(nn.Module):
328
+ r""" Patch Merging Layer.
329
+ Args:
330
+ input_resolution (tuple[int]): Resolution of input feature.
331
+ dim (int): Number of input channels.
332
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
333
+ """
334
+
335
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
336
+ super().__init__()
337
+ self.input_resolution = input_resolution
338
+ self.dim = dim
339
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
340
+ self.norm = norm_layer(2 * dim)
341
+
342
+ def forward(self, x):
343
+ """
344
+ x: B, H*W, C
345
+ """
346
+ H, W = self.input_resolution
347
+ B, L, C = x.shape
348
+ assert L == H * W, "input feature has wrong size"
349
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
350
+
351
+ x = x.view(B, H, W, C)
352
+
353
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
354
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
355
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
356
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
357
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
358
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
359
+
360
+ x = self.reduction(x)
361
+ x = self.norm(x)
362
+
363
+ return x
364
+
365
+ def extra_repr(self) -> str:
366
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
367
+
368
+ def flops(self):
369
+ H, W = self.input_resolution
370
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
371
+ flops += H * W * self.dim // 2
372
+ return flops
373
+
374
+ class BasicLayer(nn.Module):
375
+ """ A basic Swin Transformer layer for one stage.
376
+ Args:
377
+ dim (int): Number of input channels.
378
+ input_resolution (tuple[int]): Input resolution.
379
+ depth (int): Number of blocks.
380
+ num_heads (int): Number of attention heads.
381
+ window_size (int): Local window size.
382
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
383
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
384
+ drop (float, optional): Dropout rate. Default: 0.0
385
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
386
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
387
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
388
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
389
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
390
+ pretrained_window_size (int): Local window size in pre-training.
391
+ """
392
+
393
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
394
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
395
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
396
+ pretrained_window_size=0):
397
+
398
+ super().__init__()
399
+ self.dim = dim
400
+ self.input_resolution = input_resolution
401
+ self.depth = depth
402
+ self.use_checkpoint = use_checkpoint
403
+
404
+ # build blocks
405
+ self.blocks = nn.ModuleList([
406
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
407
+ num_heads=num_heads, window_size=window_size,
408
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
409
+ mlp_ratio=mlp_ratio,
410
+ qkv_bias=qkv_bias,
411
+ drop=drop, attn_drop=attn_drop,
412
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
413
+ norm_layer=norm_layer,
414
+ pretrained_window_size=pretrained_window_size)
415
+ for i in range(depth)])
416
+
417
+ # patch merging layer
418
+ if downsample is not None:
419
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
420
+ else:
421
+ self.downsample = None
422
+
423
+ def forward(self, x, x_size):
424
+ for blk in self.blocks:
425
+ if self.use_checkpoint:
426
+ x = checkpoint.checkpoint(blk, x, x_size)
427
+ else:
428
+ x = blk(x, x_size)
429
+ if self.downsample is not None:
430
+ x = self.downsample(x)
431
+ return x
432
+
433
+ def extra_repr(self) -> str:
434
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
435
+
436
+ def flops(self):
437
+ flops = 0
438
+ for blk in self.blocks:
439
+ flops += blk.flops()
440
+ if self.downsample is not None:
441
+ flops += self.downsample.flops()
442
+ return flops
443
+
444
+ def _init_respostnorm(self):
445
+ for blk in self.blocks:
446
+ nn.init.constant_(blk.norm1.bias, 0)
447
+ nn.init.constant_(blk.norm1.weight, 0)
448
+ nn.init.constant_(blk.norm2.bias, 0)
449
+ nn.init.constant_(blk.norm2.weight, 0)
450
+
451
+ class PatchEmbed(nn.Module):
452
+ r""" Image to Patch Embedding
453
+ Args:
454
+ img_size (int): Image size. Default: 224.
455
+ patch_size (int): Patch token size. Default: 4.
456
+ in_chans (int): Number of input image channels. Default: 3.
457
+ embed_dim (int): Number of linear projection output channels. Default: 96.
458
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
459
+ """
460
+
461
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
462
+ super().__init__()
463
+ img_size = to_2tuple(img_size)
464
+ patch_size = to_2tuple(patch_size)
465
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
466
+ self.img_size = img_size
467
+ self.patch_size = patch_size
468
+ self.patches_resolution = patches_resolution
469
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
470
+
471
+ self.in_chans = in_chans
472
+ self.embed_dim = embed_dim
473
+
474
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
475
+ if norm_layer is not None:
476
+ self.norm = norm_layer(embed_dim)
477
+ else:
478
+ self.norm = None
479
+
480
+ def forward(self, x):
481
+ B, C, H, W = x.shape
482
+ # FIXME look at relaxing size constraints
483
+ # assert H == self.img_size[0] and W == self.img_size[1],
484
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
485
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
486
+ if self.norm is not None:
487
+ x = self.norm(x)
488
+ return x
489
+
490
+ def flops(self):
491
+ Ho, Wo = self.patches_resolution
492
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
493
+ if self.norm is not None:
494
+ flops += Ho * Wo * self.embed_dim
495
+ return flops
496
+
497
+ class RSTB(nn.Module):
498
+ """Residual Swin Transformer Block (RSTB).
499
+
500
+ Args:
501
+ dim (int): Number of input channels.
502
+ input_resolution (tuple[int]): Input resolution.
503
+ depth (int): Number of blocks.
504
+ num_heads (int): Number of attention heads.
505
+ window_size (int): Local window size.
506
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
507
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
508
+ drop (float, optional): Dropout rate. Default: 0.0
509
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
510
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
511
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
512
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
513
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
514
+ img_size: Input image size.
515
+ patch_size: Patch size.
516
+ resi_connection: The convolutional block before residual connection.
517
+ """
518
+
519
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
520
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
521
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
522
+ img_size=224, patch_size=4, resi_connection='1conv'):
523
+ super(RSTB, self).__init__()
524
+
525
+ self.dim = dim
526
+ self.input_resolution = input_resolution
527
+
528
+ self.residual_group = BasicLayer(dim=dim,
529
+ input_resolution=input_resolution,
530
+ depth=depth,
531
+ num_heads=num_heads,
532
+ window_size=window_size,
533
+ mlp_ratio=mlp_ratio,
534
+ qkv_bias=qkv_bias,
535
+ drop=drop, attn_drop=attn_drop,
536
+ drop_path=drop_path,
537
+ norm_layer=norm_layer,
538
+ downsample=downsample,
539
+ use_checkpoint=use_checkpoint)
540
+
541
+ if resi_connection == '1conv':
542
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
543
+ elif resi_connection == '3conv':
544
+ # to save parameters and memory
545
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
546
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
547
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
548
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
549
+
550
+ self.patch_embed = PatchEmbed(
551
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
552
+ norm_layer=None)
553
+
554
+ self.patch_unembed = PatchUnEmbed(
555
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
556
+ norm_layer=None)
557
+
558
+ def forward(self, x, x_size):
559
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
560
+
561
+ def flops(self):
562
+ flops = 0
563
+ flops += self.residual_group.flops()
564
+ H, W = self.input_resolution
565
+ flops += H * W * self.dim * self.dim * 9
566
+ flops += self.patch_embed.flops()
567
+ flops += self.patch_unembed.flops()
568
+
569
+ return flops
570
+
571
+ class PatchUnEmbed(nn.Module):
572
+ r""" Image to Patch Unembedding
573
+
574
+ Args:
575
+ img_size (int): Image size. Default: 224.
576
+ patch_size (int): Patch token size. Default: 4.
577
+ in_chans (int): Number of input image channels. Default: 3.
578
+ embed_dim (int): Number of linear projection output channels. Default: 96.
579
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
580
+ """
581
+
582
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
583
+ super().__init__()
584
+ img_size = to_2tuple(img_size)
585
+ patch_size = to_2tuple(patch_size)
586
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
587
+ self.img_size = img_size
588
+ self.patch_size = patch_size
589
+ self.patches_resolution = patches_resolution
590
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
591
+
592
+ self.in_chans = in_chans
593
+ self.embed_dim = embed_dim
594
+
595
+ def forward(self, x, x_size):
596
+ B, HW, C = x.shape
597
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
598
+ return x
599
+
600
+ def flops(self):
601
+ flops = 0
602
+ return flops
603
+
604
+
605
+ class Upsample(nn.Sequential):
606
+ """Upsample module.
607
+
608
+ Args:
609
+ scale (int): Scale factor. Supported scales: 2^n and 3.
610
+ num_feat (int): Channel number of intermediate features.
611
+ """
612
+
613
+ def __init__(self, scale, num_feat):
614
+ m = []
615
+ if (scale & (scale - 1)) == 0: # scale = 2^n
616
+ for _ in range(int(math.log(scale, 2))):
617
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
618
+ m.append(nn.PixelShuffle(2))
619
+ elif scale == 3:
620
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
621
+ m.append(nn.PixelShuffle(3))
622
+ else:
623
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
624
+ super(Upsample, self).__init__(*m)
625
+
626
+ class Upsample_hf(nn.Sequential):
627
+ """Upsample module.
628
+
629
+ Args:
630
+ scale (int): Scale factor. Supported scales: 2^n and 3.
631
+ num_feat (int): Channel number of intermediate features.
632
+ """
633
+
634
+ def __init__(self, scale, num_feat):
635
+ m = []
636
+ if (scale & (scale - 1)) == 0: # scale = 2^n
637
+ for _ in range(int(math.log(scale, 2))):
638
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
639
+ m.append(nn.PixelShuffle(2))
640
+ elif scale == 3:
641
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
642
+ m.append(nn.PixelShuffle(3))
643
+ else:
644
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
645
+ super(Upsample_hf, self).__init__(*m)
646
+
647
+
648
+ class UpsampleOneStep(nn.Sequential):
649
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
650
+ Used in lightweight SR to save parameters.
651
+
652
+ Args:
653
+ scale (int): Scale factor. Supported scales: 2^n and 3.
654
+ num_feat (int): Channel number of intermediate features.
655
+
656
+ """
657
+
658
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
659
+ self.num_feat = num_feat
660
+ self.input_resolution = input_resolution
661
+ m = []
662
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
663
+ m.append(nn.PixelShuffle(scale))
664
+ super(UpsampleOneStep, self).__init__(*m)
665
+
666
+ def flops(self):
667
+ H, W = self.input_resolution
668
+ flops = H * W * self.num_feat * 3 * 9
669
+ return flops
670
+
671
+
672
+
673
+ class Swin2SR(nn.Module):
674
+ r""" Swin2SR
675
+ A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
676
+
677
+ Args:
678
+ img_size (int | tuple(int)): Input image size. Default 64
679
+ patch_size (int | tuple(int)): Patch size. Default: 1
680
+ in_chans (int): Number of input image channels. Default: 3
681
+ embed_dim (int): Patch embedding dimension. Default: 96
682
+ depths (tuple(int)): Depth of each Swin Transformer layer.
683
+ num_heads (tuple(int)): Number of attention heads in different layers.
684
+ window_size (int): Window size. Default: 7
685
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
686
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
687
+ drop_rate (float): Dropout rate. Default: 0
688
+ attn_drop_rate (float): Attention dropout rate. Default: 0
689
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
690
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
691
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
692
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
693
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
694
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
695
+ img_range: Image range. 1. or 255.
696
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
697
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
698
+ """
699
+
700
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
701
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
702
+ window_size=7, mlp_ratio=4., qkv_bias=True,
703
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
704
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
705
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
706
+ **kwargs):
707
+ super(Swin2SR, self).__init__()
708
+ num_in_ch = in_chans
709
+ num_out_ch = in_chans
710
+ num_feat = 64
711
+ self.img_range = img_range
712
+ if in_chans == 3:
713
+ rgb_mean = (0.4488, 0.4371, 0.4040)
714
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
715
+ else:
716
+ self.mean = torch.zeros(1, 1, 1, 1)
717
+ self.upscale = upscale
718
+ self.upsampler = upsampler
719
+ self.window_size = window_size
720
+
721
+ #####################################################################################################
722
+ ################################### 1, shallow feature extraction ###################################
723
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
724
+
725
+ #####################################################################################################
726
+ ################################### 2, deep feature extraction ######################################
727
+ self.num_layers = len(depths)
728
+ self.embed_dim = embed_dim
729
+ self.ape = ape
730
+ self.patch_norm = patch_norm
731
+ self.num_features = embed_dim
732
+ self.mlp_ratio = mlp_ratio
733
+
734
+ # split image into non-overlapping patches
735
+ self.patch_embed = PatchEmbed(
736
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
737
+ norm_layer=norm_layer if self.patch_norm else None)
738
+ num_patches = self.patch_embed.num_patches
739
+ patches_resolution = self.patch_embed.patches_resolution
740
+ self.patches_resolution = patches_resolution
741
+
742
+ # merge non-overlapping patches into image
743
+ self.patch_unembed = PatchUnEmbed(
744
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
745
+ norm_layer=norm_layer if self.patch_norm else None)
746
+
747
+ # absolute position embedding
748
+ if self.ape:
749
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
750
+ trunc_normal_(self.absolute_pos_embed, std=.02)
751
+
752
+ self.pos_drop = nn.Dropout(p=drop_rate)
753
+
754
+ # stochastic depth
755
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
756
+
757
+ # build Residual Swin Transformer blocks (RSTB)
758
+ self.layers = nn.ModuleList()
759
+ for i_layer in range(self.num_layers):
760
+ layer = RSTB(dim=embed_dim,
761
+ input_resolution=(patches_resolution[0],
762
+ patches_resolution[1]),
763
+ depth=depths[i_layer],
764
+ num_heads=num_heads[i_layer],
765
+ window_size=window_size,
766
+ mlp_ratio=self.mlp_ratio,
767
+ qkv_bias=qkv_bias,
768
+ drop=drop_rate, attn_drop=attn_drop_rate,
769
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
770
+ norm_layer=norm_layer,
771
+ downsample=None,
772
+ use_checkpoint=use_checkpoint,
773
+ img_size=img_size,
774
+ patch_size=patch_size,
775
+ resi_connection=resi_connection
776
+
777
+ )
778
+ self.layers.append(layer)
779
+
780
+ if self.upsampler == 'pixelshuffle_hf':
781
+ self.layers_hf = nn.ModuleList()
782
+ for i_layer in range(self.num_layers):
783
+ layer = RSTB(dim=embed_dim,
784
+ input_resolution=(patches_resolution[0],
785
+ patches_resolution[1]),
786
+ depth=depths[i_layer],
787
+ num_heads=num_heads[i_layer],
788
+ window_size=window_size,
789
+ mlp_ratio=self.mlp_ratio,
790
+ qkv_bias=qkv_bias,
791
+ drop=drop_rate, attn_drop=attn_drop_rate,
792
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
793
+ norm_layer=norm_layer,
794
+ downsample=None,
795
+ use_checkpoint=use_checkpoint,
796
+ img_size=img_size,
797
+ patch_size=patch_size,
798
+ resi_connection=resi_connection
799
+
800
+ )
801
+ self.layers_hf.append(layer)
802
+
803
+ self.norm = norm_layer(self.num_features)
804
+
805
+ # build the last conv layer in deep feature extraction
806
+ if resi_connection == '1conv':
807
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
808
+ elif resi_connection == '3conv':
809
+ # to save parameters and memory
810
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
811
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
812
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
813
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
814
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
815
+
816
+ #####################################################################################################
817
+ ################################ 3, high quality image reconstruction ################################
818
+ if self.upsampler == 'pixelshuffle':
819
+ # for classical SR
820
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
821
+ nn.LeakyReLU(inplace=True))
822
+ self.upsample = Upsample(upscale, num_feat)
823
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
824
+ elif self.upsampler == 'pixelshuffle_aux':
825
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
826
+ self.conv_before_upsample = nn.Sequential(
827
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
828
+ nn.LeakyReLU(inplace=True))
829
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
830
+ self.conv_after_aux = nn.Sequential(
831
+ nn.Conv2d(3, num_feat, 3, 1, 1),
832
+ nn.LeakyReLU(inplace=True))
833
+ self.upsample = Upsample(upscale, num_feat)
834
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
835
+
836
+ elif self.upsampler == 'pixelshuffle_hf':
837
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
838
+ nn.LeakyReLU(inplace=True))
839
+ self.upsample = Upsample(upscale, num_feat)
840
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
841
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
842
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
843
+ nn.LeakyReLU(inplace=True))
844
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
845
+ self.conv_before_upsample_hf = nn.Sequential(
846
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
847
+ nn.LeakyReLU(inplace=True))
848
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
849
+
850
+ elif self.upsampler == 'pixelshuffledirect':
851
+ # for lightweight SR (to save parameters)
852
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
853
+ (patches_resolution[0], patches_resolution[1]))
854
+ elif self.upsampler == 'nearest+conv':
855
+ # for real-world SR (less artifacts)
856
+ assert self.upscale == 4, 'only support x4 now.'
857
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
858
+ nn.LeakyReLU(inplace=True))
859
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
860
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
861
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
862
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
863
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
864
+ else:
865
+ # for image denoising and JPEG compression artifact reduction
866
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
867
+
868
+ self.apply(self._init_weights)
869
+
870
+ def _init_weights(self, m):
871
+ if isinstance(m, nn.Linear):
872
+ trunc_normal_(m.weight, std=.02)
873
+ if isinstance(m, nn.Linear) and m.bias is not None:
874
+ nn.init.constant_(m.bias, 0)
875
+ elif isinstance(m, nn.LayerNorm):
876
+ nn.init.constant_(m.bias, 0)
877
+ nn.init.constant_(m.weight, 1.0)
878
+
879
+ @torch.jit.ignore
880
+ def no_weight_decay(self):
881
+ return {'absolute_pos_embed'}
882
+
883
+ @torch.jit.ignore
884
+ def no_weight_decay_keywords(self):
885
+ return {'relative_position_bias_table'}
886
+
887
+ def check_image_size(self, x):
888
+ _, _, h, w = x.size()
889
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
890
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
891
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
892
+ return x
893
+
894
+ def forward_features(self, x):
895
+ x_size = (x.shape[2], x.shape[3])
896
+ x = self.patch_embed(x)
897
+ if self.ape:
898
+ x = x + self.absolute_pos_embed
899
+ x = self.pos_drop(x)
900
+
901
+ for layer in self.layers:
902
+ x = layer(x, x_size)
903
+
904
+ x = self.norm(x) # B L C
905
+ x = self.patch_unembed(x, x_size)
906
+
907
+ return x
908
+
909
+ def forward_features_hf(self, x):
910
+ x_size = (x.shape[2], x.shape[3])
911
+ x = self.patch_embed(x)
912
+ if self.ape:
913
+ x = x + self.absolute_pos_embed
914
+ x = self.pos_drop(x)
915
+
916
+ for layer in self.layers_hf:
917
+ x = layer(x, x_size)
918
+
919
+ x = self.norm(x) # B L C
920
+ x = self.patch_unembed(x, x_size)
921
+
922
+ return x
923
+
924
+ def forward(self, x):
925
+ H, W = x.shape[2:]
926
+ x = self.check_image_size(x)
927
+
928
+ self.mean = self.mean.type_as(x)
929
+ x = (x - self.mean) * self.img_range
930
+
931
+ if self.upsampler == 'pixelshuffle':
932
+ # for classical SR
933
+ x = self.conv_first(x)
934
+ x = self.conv_after_body(self.forward_features(x)) + x
935
+ x = self.conv_before_upsample(x)
936
+ x = self.conv_last(self.upsample(x))
937
+ elif self.upsampler == 'pixelshuffle_aux':
938
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
939
+ bicubic = self.conv_bicubic(bicubic)
940
+ x = self.conv_first(x)
941
+ x = self.conv_after_body(self.forward_features(x)) + x
942
+ x = self.conv_before_upsample(x)
943
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
944
+ x = self.conv_after_aux(aux)
945
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
946
+ x = self.conv_last(x)
947
+ aux = aux / self.img_range + self.mean
948
+ elif self.upsampler == 'pixelshuffle_hf':
949
+ # for classical SR with HF
950
+ x = self.conv_first(x)
951
+ x = self.conv_after_body(self.forward_features(x)) + x
952
+ x_before = self.conv_before_upsample(x)
953
+ x_out = self.conv_last(self.upsample(x_before))
954
+
955
+ x_hf = self.conv_first_hf(x_before)
956
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
957
+ x_hf = self.conv_before_upsample_hf(x_hf)
958
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
959
+ x = x_out + x_hf
960
+ x_hf = x_hf / self.img_range + self.mean
961
+
962
+ elif self.upsampler == 'pixelshuffledirect':
963
+ # for lightweight SR
964
+ x = self.conv_first(x)
965
+ x = self.conv_after_body(self.forward_features(x)) + x
966
+ x = self.upsample(x)
967
+ elif self.upsampler == 'nearest+conv':
968
+ # for real-world SR
969
+ x = self.conv_first(x)
970
+ x = self.conv_after_body(self.forward_features(x)) + x
971
+ x = self.conv_before_upsample(x)
972
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
973
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
974
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
975
+ else:
976
+ # for image denoising and JPEG compression artifact reduction
977
+ x_first = self.conv_first(x)
978
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
979
+ x = x + self.conv_last(res)
980
+
981
+ x = x / self.img_range + self.mean
982
+ if self.upsampler == "pixelshuffle_aux":
983
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux
984
+
985
+ elif self.upsampler == "pixelshuffle_hf":
986
+ x_out = x_out / self.img_range + self.mean
987
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
988
+
989
+ else:
990
+ return x[:, :, :H*self.upscale, :W*self.upscale]
991
+
992
+ def flops(self):
993
+ flops = 0
994
+ H, W = self.patches_resolution
995
+ flops += H * W * 3 * self.embed_dim * 9
996
+ flops += self.patch_embed.flops()
997
+ for i, layer in enumerate(self.layers):
998
+ flops += layer.flops()
999
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1000
+ flops += self.upsample.flops()
1001
+ return flops
1002
+
1003
+
1004
+ if __name__ == '__main__':
1005
+ upscale = 4
1006
+ window_size = 8
1007
+ height = (1024 // upscale // window_size + 1) * window_size
1008
+ width = (720 // upscale // window_size + 1) * window_size
1009
+ model = Swin2SR(upscale=2, img_size=(height, width),
1010
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
1011
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
1012
+ print(model)
1013
+ print(height, width, model.flops() / 1e9)
1014
+
1015
+ x = torch.randn((1, 3, height, width))
1016
+ x = model(x)
1017
+ print(x.shape)
extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Stable Diffusion WebUI - Bracket checker
2
+ // Version 1.0
3
+ // By Hingashi no Florin/Bwin4L
4
+ // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
5
+ // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
6
+
7
+ function checkBrackets(evt) {
8
+ textArea = evt.target;
9
+ tabName = evt.target.parentElement.parentElement.id.split("_")[0];
10
+ counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
11
+
12
+ promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
13
+
14
+ errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
15
+ errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
16
+ errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
17
+
18
+ openBracketRegExp = /\(/g;
19
+ closeBracketRegExp = /\)/g;
20
+
21
+ openSquareBracketRegExp = /\[/g;
22
+ closeSquareBracketRegExp = /\]/g;
23
+
24
+ openCurlyBracketRegExp = /\{/g;
25
+ closeCurlyBracketRegExp = /\}/g;
26
+
27
+ totalOpenBracketMatches = 0;
28
+ totalCloseBracketMatches = 0;
29
+ totalOpenSquareBracketMatches = 0;
30
+ totalCloseSquareBracketMatches = 0;
31
+ totalOpenCurlyBracketMatches = 0;
32
+ totalCloseCurlyBracketMatches = 0;
33
+
34
+ openBracketMatches = textArea.value.match(openBracketRegExp);
35
+ if(openBracketMatches) {
36
+ totalOpenBracketMatches = openBracketMatches.length;
37
+ }
38
+
39
+ closeBracketMatches = textArea.value.match(closeBracketRegExp);
40
+ if(closeBracketMatches) {
41
+ totalCloseBracketMatches = closeBracketMatches.length;
42
+ }
43
+
44
+ openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
45
+ if(openSquareBracketMatches) {
46
+ totalOpenSquareBracketMatches = openSquareBracketMatches.length;
47
+ }
48
+
49
+ closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
50
+ if(closeSquareBracketMatches) {
51
+ totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
52
+ }
53
+
54
+ openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
55
+ if(openCurlyBracketMatches) {
56
+ totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
57
+ }
58
+
59
+ closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
60
+ if(closeCurlyBracketMatches) {
61
+ totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
62
+ }
63
+
64
+ if(totalOpenBracketMatches != totalCloseBracketMatches) {
65
+ if(!counterElt.title.includes(errorStringParen)) {
66
+ counterElt.title += errorStringParen;
67
+ }
68
+ } else {
69
+ counterElt.title = counterElt.title.replace(errorStringParen, '');
70
+ }
71
+
72
+ if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
73
+ if(!counterElt.title.includes(errorStringSquare)) {
74
+ counterElt.title += errorStringSquare;
75
+ }
76
+ } else {
77
+ counterElt.title = counterElt.title.replace(errorStringSquare, '');
78
+ }
79
+
80
+ if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
81
+ if(!counterElt.title.includes(errorStringCurly)) {
82
+ counterElt.title += errorStringCurly;
83
+ }
84
+ } else {
85
+ counterElt.title = counterElt.title.replace(errorStringCurly, '');
86
+ }
87
+
88
+ if(counterElt.title != '') {
89
+ counterElt.style = 'color: #FF5555;';
90
+ } else {
91
+ counterElt.style = '';
92
+ }
93
+ }
94
+
95
+ var shadowRootLoaded = setInterval(function() {
96
+ var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
97
+ if(shadowTextArea.length < 1) {
98
+ return false;
99
+ }
100
+
101
+ clearInterval(shadowRootLoaded);
102
+
103
+ document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
104
+ document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
105
+ document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
106
+ document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
107
+ }, 1000);
javascript/aspectRatioOverlay.js ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ let currentWidth = null;
3
+ let currentHeight = null;
4
+ let arFrameTimeout = setTimeout(function(){},0);
5
+
6
+ function dimensionChange(e, is_width, is_height){
7
+
8
+ if(is_width){
9
+ currentWidth = e.target.value*1.0
10
+ }
11
+ if(is_height){
12
+ currentHeight = e.target.value*1.0
13
+ }
14
+
15
+ var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
16
+
17
+ if(!inImg2img){
18
+ return;
19
+ }
20
+
21
+ var targetElement = null;
22
+
23
+ var tabIndex = get_tab_index('mode_img2img')
24
+ if(tabIndex == 0){
25
+ targetElement = gradioApp().querySelector('div[data-testid=image] img');
26
+ } else if(tabIndex == 1){
27
+ targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
28
+ }
29
+
30
+ if(targetElement){
31
+
32
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
33
+ if(!arPreviewRect){
34
+ arPreviewRect = document.createElement('div')
35
+ arPreviewRect.id = "imageARPreview";
36
+ gradioApp().getRootNode().appendChild(arPreviewRect)
37
+ }
38
+
39
+
40
+
41
+ var viewportOffset = targetElement.getBoundingClientRect();
42
+
43
+ viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
44
+
45
+ scaledx = targetElement.naturalWidth*viewportscale
46
+ scaledy = targetElement.naturalHeight*viewportscale
47
+
48
+ cleintRectTop = (viewportOffset.top+window.scrollY)
49
+ cleintRectLeft = (viewportOffset.left+window.scrollX)
50
+ cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
51
+ cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
52
+
53
+ viewRectTop = cleintRectCentreY-(scaledy/2)
54
+ viewRectLeft = cleintRectCentreX-(scaledx/2)
55
+ arRectWidth = scaledx
56
+ arRectHeight = scaledy
57
+
58
+ arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
59
+ arscaledx = currentWidth*arscale
60
+ arscaledy = currentHeight*arscale
61
+
62
+ arRectTop = cleintRectCentreY-(arscaledy/2)
63
+ arRectLeft = cleintRectCentreX-(arscaledx/2)
64
+ arRectWidth = arscaledx
65
+ arRectHeight = arscaledy
66
+
67
+ arPreviewRect.style.top = arRectTop+'px';
68
+ arPreviewRect.style.left = arRectLeft+'px';
69
+ arPreviewRect.style.width = arRectWidth+'px';
70
+ arPreviewRect.style.height = arRectHeight+'px';
71
+
72
+ clearTimeout(arFrameTimeout);
73
+ arFrameTimeout = setTimeout(function(){
74
+ arPreviewRect.style.display = 'none';
75
+ },2000);
76
+
77
+ arPreviewRect.style.display = 'block';
78
+
79
+ }
80
+
81
+ }
82
+
83
+
84
+ onUiUpdate(function(){
85
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
86
+ if(arPreviewRect){
87
+ arPreviewRect.style.display = 'none';
88
+ }
89
+ var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
90
+ if(inImg2img){
91
+ let inputs = gradioApp().querySelectorAll('input');
92
+ inputs.forEach(function(e){
93
+ var is_width = e.parentElement.id == "img2img_width"
94
+ var is_height = e.parentElement.id == "img2img_height"
95
+
96
+ if((is_width || is_height) && !e.classList.contains('scrollwatch')){
97
+ e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
98
+ e.classList.add('scrollwatch')
99
+ }
100
+ if(is_width){
101
+ currentWidth = e.value*1.0
102
+ }
103
+ if(is_height){
104
+ currentHeight = e.value*1.0
105
+ }
106
+ })
107
+ }
108
+ });
javascript/contextMenus.js ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ contextMenuInit = function(){
3
+ let eventListenerApplied=false;
4
+ let menuSpecs = new Map();
5
+
6
+ const uid = function(){
7
+ return Date.now().toString(36) + Math.random().toString(36).substr(2);
8
+ }
9
+
10
+ function showContextMenu(event,element,menuEntries){
11
+ let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
12
+ let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
13
+
14
+ let oldMenu = gradioApp().querySelector('#context-menu')
15
+ if(oldMenu){
16
+ oldMenu.remove()
17
+ }
18
+
19
+ let tabButton = uiCurrentTab
20
+ let baseStyle = window.getComputedStyle(tabButton)
21
+
22
+ const contextMenu = document.createElement('nav')
23
+ contextMenu.id = "context-menu"
24
+ contextMenu.style.background = baseStyle.background
25
+ contextMenu.style.color = baseStyle.color
26
+ contextMenu.style.fontFamily = baseStyle.fontFamily
27
+ contextMenu.style.top = posy+'px'
28
+ contextMenu.style.left = posx+'px'
29
+
30
+
31
+
32
+ const contextMenuList = document.createElement('ul')
33
+ contextMenuList.className = 'context-menu-items';
34
+ contextMenu.append(contextMenuList);
35
+
36
+ menuEntries.forEach(function(entry){
37
+ let contextMenuEntry = document.createElement('a')
38
+ contextMenuEntry.innerHTML = entry['name']
39
+ contextMenuEntry.addEventListener("click", function(e) {
40
+ entry['func']();
41
+ })
42
+ contextMenuList.append(contextMenuEntry);
43
+
44
+ })
45
+
46
+ gradioApp().getRootNode().appendChild(contextMenu)
47
+
48
+ let menuWidth = contextMenu.offsetWidth + 4;
49
+ let menuHeight = contextMenu.offsetHeight + 4;
50
+
51
+ let windowWidth = window.innerWidth;
52
+ let windowHeight = window.innerHeight;
53
+
54
+ if ( (windowWidth - posx) < menuWidth ) {
55
+ contextMenu.style.left = windowWidth - menuWidth + "px";
56
+ }
57
+
58
+ if ( (windowHeight - posy) < menuHeight ) {
59
+ contextMenu.style.top = windowHeight - menuHeight + "px";
60
+ }
61
+
62
+ }
63
+
64
+ function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
65
+
66
+ currentItems = menuSpecs.get(targetElementSelector)
67
+
68
+ if(!currentItems){
69
+ currentItems = []
70
+ menuSpecs.set(targetElementSelector,currentItems);
71
+ }
72
+ let newItem = {'id':targetElementSelector+'_'+uid(),
73
+ 'name':entryName,
74
+ 'func':entryFunction,
75
+ 'isNew':true}
76
+
77
+ currentItems.push(newItem)
78
+ return newItem['id']
79
+ }
80
+
81
+ function removeContextMenuOption(uid){
82
+ menuSpecs.forEach(function(v,k) {
83
+ let index = -1
84
+ v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
85
+ if(index>=0){
86
+ v.splice(index, 1);
87
+ }
88
+ })
89
+ }
90
+
91
+ function addContextMenuEventListener(){
92
+ if(eventListenerApplied){
93
+ return;
94
+ }
95
+ gradioApp().addEventListener("click", function(e) {
96
+ let source = e.composedPath()[0]
97
+ if(source.id && source.id.indexOf('check_progress')>-1){
98
+ return
99
+ }
100
+
101
+ let oldMenu = gradioApp().querySelector('#context-menu')
102
+ if(oldMenu){
103
+ oldMenu.remove()
104
+ }
105
+ });
106
+ gradioApp().addEventListener("contextmenu", function(e) {
107
+ let oldMenu = gradioApp().querySelector('#context-menu')
108
+ if(oldMenu){
109
+ oldMenu.remove()
110
+ }
111
+ menuSpecs.forEach(function(v,k) {
112
+ if(e.composedPath()[0].matches(k)){
113
+ showContextMenu(e,e.composedPath()[0],v)
114
+ e.preventDefault()
115
+ return
116
+ }
117
+ })
118
+ });
119
+ eventListenerApplied=true
120
+
121
+ }
122
+
123
+ return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
124
+ }
125
+
126
+ initResponse = contextMenuInit();
127
+ appendContextMenuOption = initResponse[0];
128
+ removeContextMenuOption = initResponse[1];
129
+ addContextMenuEventListener = initResponse[2];
130
+
131
+ (function(){
132
+ //Start example Context Menu Items
133
+ let generateOnRepeat = function(genbuttonid,interruptbuttonid){
134
+ let genbutton = gradioApp().querySelector(genbuttonid);
135
+ let interruptbutton = gradioApp().querySelector(interruptbuttonid);
136
+ if(!interruptbutton.offsetParent){
137
+ genbutton.click();
138
+ }
139
+ clearInterval(window.generateOnRepeatInterval)
140
+ window.generateOnRepeatInterval = setInterval(function(){
141
+ if(!interruptbutton.offsetParent){
142
+ genbutton.click();
143
+ }
144
+ },
145
+ 500)
146
+ }
147
+
148
+ appendContextMenuOption('#txt2img_generate','Generate forever',function(){
149
+ generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
150
+ })
151
+ appendContextMenuOption('#img2img_generate','Generate forever',function(){
152
+ generateOnRepeat('#img2img_generate','#img2img_interrupt');
153
+ })
154
+
155
+ let cancelGenerateForever = function(){
156
+ clearInterval(window.generateOnRepeatInterval)
157
+ }
158
+
159
+ appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
160
+ appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
161
+ appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
162
+ appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
163
+
164
+ appendContextMenuOption('#roll','Roll three',
165
+ function(){
166
+ let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
167
+ setTimeout(function(){rollbutton.click()},100)
168
+ setTimeout(function(){rollbutton.click()},200)
169
+ setTimeout(function(){rollbutton.click()},300)
170
+ }
171
+ )
172
+ })();
173
+ //End example Context Menu Items
174
+
175
+ onUiUpdate(function(){
176
+ addContextMenuEventListener()
177
+ });
javascript/dragdrop.js ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // allows drag-dropping files into gradio image elements, and also pasting images from clipboard
2
+
3
+ function isValidImageList( files ) {
4
+ return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
5
+ }
6
+
7
+ function dropReplaceImage( imgWrap, files ) {
8
+ if ( ! isValidImageList( files ) ) {
9
+ return;
10
+ }
11
+
12
+ imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
13
+ const callback = () => {
14
+ const fileInput = imgWrap.querySelector('input[type="file"]');
15
+ if ( fileInput ) {
16
+ fileInput.files = files;
17
+ fileInput.dispatchEvent(new Event('change'));
18
+ }
19
+ };
20
+
21
+ if ( imgWrap.closest('#pnginfo_image') ) {
22
+ // special treatment for PNG Info tab, wait for fetch request to finish
23
+ const oldFetch = window.fetch;
24
+ window.fetch = async (input, options) => {
25
+ const response = await oldFetch(input, options);
26
+ if ( 'api/predict/' === input ) {
27
+ const content = await response.text();
28
+ window.fetch = oldFetch;
29
+ window.requestAnimationFrame( () => callback() );
30
+ return new Response(content, {
31
+ status: response.status,
32
+ statusText: response.statusText,
33
+ headers: response.headers
34
+ })
35
+ }
36
+ return response;
37
+ };
38
+ } else {
39
+ window.requestAnimationFrame( () => callback() );
40
+ }
41
+ }
42
+
43
+ window.document.addEventListener('dragover', e => {
44
+ const target = e.composedPath()[0];
45
+ const imgWrap = target.closest('[data-testid="image"]');
46
+ if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
47
+ return;
48
+ }
49
+ e.stopPropagation();
50
+ e.preventDefault();
51
+ e.dataTransfer.dropEffect = 'copy';
52
+ });
53
+
54
+ window.document.addEventListener('drop', e => {
55
+ const target = e.composedPath()[0];
56
+ if (target.placeholder.indexOf("Prompt") == -1) {
57
+ return;
58
+ }
59
+ const imgWrap = target.closest('[data-testid="image"]');
60
+ if ( !imgWrap ) {
61
+ return;
62
+ }
63
+ e.stopPropagation();
64
+ e.preventDefault();
65
+ const files = e.dataTransfer.files;
66
+ dropReplaceImage( imgWrap, files );
67
+ });
68
+
69
+ window.addEventListener('paste', e => {
70
+ const files = e.clipboardData.files;
71
+ if ( ! isValidImageList( files ) ) {
72
+ return;
73
+ }
74
+
75
+ const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
76
+ .filter(el => uiElementIsVisible(el));
77
+ if ( ! visibleImageFields.length ) {
78
+ return;
79
+ }
80
+
81
+ const firstFreeImageField = visibleImageFields
82
+ .filter(el => el.querySelector('input[type=file]'))?.[0];
83
+
84
+ dropReplaceImage(
85
+ firstFreeImageField ?
86
+ firstFreeImageField :
87
+ visibleImageFields[visibleImageFields.length - 1]
88
+ , files );
89
+ });
javascript/edit-attention.js ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addEventListener('keydown', (event) => {
2
+ let target = event.originalTarget || event.composedPath()[0];
3
+ if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
4
+ if (! (event.metaKey || event.ctrlKey)) return;
5
+
6
+
7
+ let plus = "ArrowUp"
8
+ let minus = "ArrowDown"
9
+ if (event.key != plus && event.key != minus) return;
10
+
11
+ let selectionStart = target.selectionStart;
12
+ let selectionEnd = target.selectionEnd;
13
+ // If the user hasn't selected anything, let's select their current parenthesis block
14
+ if (selectionStart === selectionEnd) {
15
+ // Find opening parenthesis around current cursor
16
+ const before = target.value.substring(0, selectionStart);
17
+ let beforeParen = before.lastIndexOf("(");
18
+ if (beforeParen == -1) return;
19
+ let beforeParenClose = before.lastIndexOf(")");
20
+ while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
21
+ beforeParen = before.lastIndexOf("(", beforeParen - 1);
22
+ beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
23
+ }
24
+
25
+ // Find closing parenthesis around current cursor
26
+ const after = target.value.substring(selectionStart);
27
+ let afterParen = after.indexOf(")");
28
+ if (afterParen == -1) return;
29
+ let afterParenOpen = after.indexOf("(");
30
+ while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
31
+ afterParen = after.indexOf(")", afterParen + 1);
32
+ afterParenOpen = after.indexOf("(", afterParenOpen + 1);
33
+ }
34
+ if (beforeParen === -1 || afterParen === -1) return;
35
+
36
+ // Set the selection to the text between the parenthesis
37
+ const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
38
+ const lastColon = parenContent.lastIndexOf(":");
39
+ selectionStart = beforeParen + 1;
40
+ selectionEnd = selectionStart + lastColon;
41
+ target.setSelectionRange(selectionStart, selectionEnd);
42
+ }
43
+
44
+ event.preventDefault();
45
+
46
+ if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
47
+ target.value = target.value.slice(0, selectionStart) +
48
+ "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
49
+ target.value.slice(selectionEnd);
50
+
51
+ target.focus();
52
+ target.selectionStart = selectionStart + 1;
53
+ target.selectionEnd = selectionEnd + 1;
54
+
55
+ } else {
56
+ end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
57
+ weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
58
+ if (isNaN(weight)) return;
59
+ if (event.key == minus) weight -= 0.1;
60
+ if (event.key == plus) weight += 0.1;
61
+
62
+ weight = parseFloat(weight.toPrecision(12));
63
+
64
+ target.value = target.value.slice(0, selectionEnd + 1) +
65
+ weight +
66
+ target.value.slice(selectionEnd + 1 + end - 1);
67
+
68
+ target.focus();
69
+ target.selectionStart = selectionStart;
70
+ target.selectionEnd = selectionEnd;
71
+ }
72
+ // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
73
+ // internal Svelte data binding remains in sync.
74
+ target.dispatchEvent(new Event("input", { bubbles: true }));
75
+ });
javascript/extensions.js ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ function extensions_apply(_, _){
3
+ disable = []
4
+ update = []
5
+ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
6
+ if(x.name.startsWith("enable_") && ! x.checked)
7
+ disable.push(x.name.substr(7))
8
+
9
+ if(x.name.startsWith("update_") && x.checked)
10
+ update.push(x.name.substr(7))
11
+ })
12
+
13
+ restart_reload()
14
+
15
+ return [JSON.stringify(disable), JSON.stringify(update)]
16
+ }
17
+
18
+ function extensions_check(){
19
+ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
20
+ x.innerHTML = "Loading..."
21
+ })
22
+
23
+ return []
24
+ }
25
+
26
+ function install_extension_from_index(button, url){
27
+ button.disabled = "disabled"
28
+ button.value = "Installing..."
29
+
30
+ textarea = gradioApp().querySelector('#extension_to_install textarea')
31
+ textarea.value = url
32
+ textarea.dispatchEvent(new Event("input", { bubbles: true }))
33
+
34
+ gradioApp().querySelector('#install_extension_button').click()
35
+ }
javascript/generationParams.js ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
2
+
3
+ let txt2img_gallery, img2img_gallery, modal = undefined;
4
+ onUiUpdate(function(){
5
+ if (!txt2img_gallery) {
6
+ txt2img_gallery = attachGalleryListeners("txt2img")
7
+ }
8
+ if (!img2img_gallery) {
9
+ img2img_gallery = attachGalleryListeners("img2img")
10
+ }
11
+ if (!modal) {
12
+ modal = gradioApp().getElementById('lightboxModal')
13
+ modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
14
+ }
15
+ });
16
+
17
+ let modalObserver = new MutationObserver(function(mutations) {
18
+ mutations.forEach(function(mutationRecord) {
19
+ let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
20
+ if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
21
+ gradioApp().getElementById(selectedTab+"_generation_info_button").click()
22
+ });
23
+ });
24
+
25
+ function attachGalleryListeners(tab_name) {
26
+ gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
27
+ gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
28
+ gallery?.addEventListener('keydown', (e) => {
29
+ if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
30
+ gradioApp().getElementById(tab_name+"_generation_info_button").click()
31
+ });
32
+ return gallery;
33
+ }
javascript/hints.js ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // mouseover tooltips for various UI elements
2
+
3
+ titles = {
4
+ "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
5
+ "Sampling method": "Which algorithm to use to produce the image",
6
+ "GFPGAN": "Restore low quality faces using GFPGAN neural network",
7
+ "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
8
+ "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
9
+ "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
10
+
11
+ "Batch count": "How many batches of images to create",
12
+ "Batch size": "How many image to create in a single batch",
13
+ "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
14
+ "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
15
+ "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
16
+ "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
17
+ "\u{1f3a8}": "Add a random artist to the prompt.",
18
+ "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
19
+ "\u{1f4c2}": "Open images output directory",
20
+ "\u{1f4be}": "Save style",
21
+ "\U0001F5D1": "Clear prompt",
22
+ "\u{1f4cb}": "Apply selected styles to current prompt",
23
+
24
+ "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
25
+ "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
26
+
27
+ "Just resize": "Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio.",
28
+ "Crop and resize": "Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out.",
29
+ "Resize and fill": "Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors.",
30
+
31
+ "Mask blur": "How much to blur the mask before processing, in pixels.",
32
+ "Masked content": "What to put inside the masked area before processing it with Stable Diffusion.",
33
+ "fill": "fill it with colors of the image",
34
+ "original": "keep whatever was there originally",
35
+ "latent noise": "fill it with latent space noise",
36
+ "latent nothing": "fill it with latent space zeroes",
37
+ "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
38
+
39
+ "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
40
+ "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
41
+
42
+ "Skip": "Stop processing current image and continue processing.",
43
+ "Interrupt": "Stop processing images and return any results accumulated so far.",
44
+ "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
45
+
46
+ "X values": "Separate values for X axis using commas.",
47
+ "Y values": "Separate values for Y axis using commas.",
48
+
49
+ "None": "Do not do anything special",
50
+ "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
51
+ "X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
52
+ "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
53
+
54
+ "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
55
+ "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
56
+
57
+ "Tiling": "Produce an image that can be tiled.",
58
+ "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
59
+
60
+ "Variation seed": "Seed of a different picture to be mixed into the generation.",
61
+ "Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).",
62
+ "Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
63
+ "Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
64
+
65
+ "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
66
+
67
+ "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
68
+ "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
69
+ "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
70
+
71
+ "Loopback": "Process an image, use it as an input, repeat.",
72
+ "Loops": "How many times to repeat processing an image and using it as input for the next iteration",
73
+
74
+ "Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
75
+ "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
76
+ "Apply style": "Insert selected styles into prompt fields",
77
+ "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
78
+
79
+ "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
80
+ "Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
81
+
82
+ "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
83
+
84
+ "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
85
+ "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
86
+
87
+ "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
88
+ "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
89
+
90
+ "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
91
+ "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
92
+
93
+ "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.",
94
+
95
+ "Weighted sum": "Result = A * (1 - M) + B * M",
96
+ "Add difference": "Result = A + (B - C) * M",
97
+
98
+ "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
99
+
100
+ "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
101
+
102
+ "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
103
+ "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
104
+ }
105
+
106
+
107
+ onUiUpdate(function(){
108
+ gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
109
+ tooltip = titles[span.textContent];
110
+
111
+ if(!tooltip){
112
+ tooltip = titles[span.value];
113
+ }
114
+
115
+ if(!tooltip){
116
+ for (const c of span.classList) {
117
+ if (c in titles) {
118
+ tooltip = titles[c];
119
+ break;
120
+ }
121
+ }
122
+ }
123
+
124
+ if(tooltip){
125
+ span.title = tooltip;
126
+ }
127
+ })
128
+
129
+ gradioApp().querySelectorAll('select').forEach(function(select){
130
+ if (select.onchange != null) return;
131
+
132
+ select.onchange = function(){
133
+ select.title = titles[select.value] || "";
134
+ }
135
+ })
136
+ })
javascript/imageMaskFix.js ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
3
+ * @see https://github.com/gradio-app/gradio/issues/1721
4
+ */
5
+ window.addEventListener( 'resize', () => imageMaskResize());
6
+ function imageMaskResize() {
7
+ const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
8
+ if ( ! canvases.length ) {
9
+ canvases_fixed = false;
10
+ window.removeEventListener( 'resize', imageMaskResize );
11
+ return;
12
+ }
13
+
14
+ const wrapper = canvases[0].closest('.touch-none');
15
+ const previewImage = wrapper.previousElementSibling;
16
+
17
+ if ( ! previewImage.complete ) {
18
+ previewImage.addEventListener( 'load', () => imageMaskResize());
19
+ return;
20
+ }
21
+
22
+ const w = previewImage.width;
23
+ const h = previewImage.height;
24
+ const nw = previewImage.naturalWidth;
25
+ const nh = previewImage.naturalHeight;
26
+ const portrait = nh > nw;
27
+ const factor = portrait;
28
+
29
+ const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
30
+ const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
31
+
32
+ wrapper.style.width = `${wW}px`;
33
+ wrapper.style.height = `${wH}px`;
34
+ wrapper.style.left = `0px`;
35
+ wrapper.style.top = `0px`;
36
+
37
+ canvases.forEach( c => {
38
+ c.style.width = c.style.height = '';
39
+ c.style.maxWidth = '100%';
40
+ c.style.maxHeight = '100%';
41
+ c.style.objectFit = 'contain';
42
+ });
43
+ }
44
+
45
+ onUiUpdate(() => imageMaskResize());
javascript/imageParams.js ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ window.onload = (function(){
2
+ window.addEventListener('drop', e => {
3
+ const target = e.composedPath()[0];
4
+ const idx = selected_gallery_index();
5
+ if (target.placeholder.indexOf("Prompt") == -1) return;
6
+
7
+ let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
8
+
9
+ e.stopPropagation();
10
+ e.preventDefault();
11
+ const imgParent = gradioApp().getElementById(prompt_target);
12
+ const files = e.dataTransfer.files;
13
+ const fileInput = imgParent.querySelector('input[type="file"]');
14
+ if ( fileInput ) {
15
+ fileInput.files = files;
16
+ fileInput.dispatchEvent(new Event('change'));
17
+ }
18
+ });
19
+ });
javascript/imageviewer.js ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // A full size 'lightbox' preview modal shown when left clicking on gallery previews
2
+ function closeModal() {
3
+ gradioApp().getElementById("lightboxModal").style.display = "none";
4
+ }
5
+
6
+ function showModal(event) {
7
+ const source = event.target || event.srcElement;
8
+ const modalImage = gradioApp().getElementById("modalImage")
9
+ const lb = gradioApp().getElementById("lightboxModal")
10
+ modalImage.src = source.src
11
+ if (modalImage.style.display === 'none') {
12
+ lb.style.setProperty('background-image', 'url(' + source.src + ')');
13
+ }
14
+ lb.style.display = "block";
15
+ lb.focus()
16
+
17
+ const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
18
+ const tabImg2Img = gradioApp().getElementById("tab_img2img")
19
+ // show the save button in modal only on txt2img or img2img tabs
20
+ if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
21
+ gradioApp().getElementById("modal_save").style.display = "inline"
22
+ } else {
23
+ gradioApp().getElementById("modal_save").style.display = "none"
24
+ }
25
+ event.stopPropagation()
26
+ }
27
+
28
+ function negmod(n, m) {
29
+ return ((n % m) + m) % m;
30
+ }
31
+
32
+ function updateOnBackgroundChange() {
33
+ const modalImage = gradioApp().getElementById("modalImage")
34
+ if (modalImage && modalImage.offsetParent) {
35
+ let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
36
+ let currentButton = null
37
+ allcurrentButtons.forEach(function(elem) {
38
+ if (elem.parentElement.offsetParent) {
39
+ currentButton = elem;
40
+ }
41
+ })
42
+
43
+ if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
44
+ modalImage.src = currentButton.children[0].src;
45
+ if (modalImage.style.display === 'none') {
46
+ modal.style.setProperty('background-image', `url(${modalImage.src})`)
47
+ }
48
+ }
49
+ }
50
+ }
51
+
52
+ function modalImageSwitch(offset) {
53
+ var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
54
+ var galleryButtons = []
55
+ allgalleryButtons.forEach(function(elem) {
56
+ if (elem.parentElement.offsetParent) {
57
+ galleryButtons.push(elem);
58
+ }
59
+ })
60
+
61
+ if (galleryButtons.length > 1) {
62
+ var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
63
+ var currentButton = null
64
+ allcurrentButtons.forEach(function(elem) {
65
+ if (elem.parentElement.offsetParent) {
66
+ currentButton = elem;
67
+ }
68
+ })
69
+
70
+ var result = -1
71
+ galleryButtons.forEach(function(v, i) {
72
+ if (v == currentButton) {
73
+ result = i
74
+ }
75
+ })
76
+
77
+ if (result != -1) {
78
+ nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
79
+ nextButton.click()
80
+ const modalImage = gradioApp().getElementById("modalImage");
81
+ const modal = gradioApp().getElementById("lightboxModal");
82
+ modalImage.src = nextButton.children[0].src;
83
+ if (modalImage.style.display === 'none') {
84
+ modal.style.setProperty('background-image', `url(${modalImage.src})`)
85
+ }
86
+ setTimeout(function() {
87
+ modal.focus()
88
+ }, 10)
89
+ }
90
+ }
91
+ }
92
+
93
+ function saveImage(){
94
+ const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
95
+ const tabImg2Img = gradioApp().getElementById("tab_img2img")
96
+ const saveTxt2Img = "save_txt2img"
97
+ const saveImg2Img = "save_img2img"
98
+ if (tabTxt2Img.style.display != "none") {
99
+ gradioApp().getElementById(saveTxt2Img).click()
100
+ } else if (tabImg2Img.style.display != "none") {
101
+ gradioApp().getElementById(saveImg2Img).click()
102
+ } else {
103
+ console.error("missing implementation for saving modal of this type")
104
+ }
105
+ }
106
+
107
+ function modalSaveImage(event) {
108
+ saveImage()
109
+ event.stopPropagation()
110
+ }
111
+
112
+ function modalNextImage(event) {
113
+ modalImageSwitch(1)
114
+ event.stopPropagation()
115
+ }
116
+
117
+ function modalPrevImage(event) {
118
+ modalImageSwitch(-1)
119
+ event.stopPropagation()
120
+ }
121
+
122
+ function modalKeyHandler(event) {
123
+ switch (event.key) {
124
+ case "s":
125
+ saveImage()
126
+ break;
127
+ case "ArrowLeft":
128
+ modalPrevImage(event)
129
+ break;
130
+ case "ArrowRight":
131
+ modalNextImage(event)
132
+ break;
133
+ case "Escape":
134
+ closeModal();
135
+ break;
136
+ }
137
+ }
138
+
139
+ function showGalleryImage() {
140
+ setTimeout(function() {
141
+ fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
142
+
143
+ if (fullImg_preview != null) {
144
+ fullImg_preview.forEach(function function_name(e) {
145
+ if (e.dataset.modded)
146
+ return;
147
+ e.dataset.modded = true;
148
+ if(e && e.parentElement.tagName == 'DIV'){
149
+ e.style.cursor='pointer'
150
+ e.style.userSelect='none'
151
+ e.addEventListener('click', function (evt) {
152
+ if(!opts.js_modal_lightbox) return;
153
+ modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
154
+ showModal(evt)
155
+ }, true);
156
+ }
157
+ });
158
+ }
159
+
160
+ }, 100);
161
+ }
162
+
163
+ function modalZoomSet(modalImage, enable) {
164
+ if (enable) {
165
+ modalImage.classList.add('modalImageFullscreen');
166
+ } else {
167
+ modalImage.classList.remove('modalImageFullscreen');
168
+ }
169
+ }
170
+
171
+ function modalZoomToggle(event) {
172
+ modalImage = gradioApp().getElementById("modalImage");
173
+ modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
174
+ event.stopPropagation()
175
+ }
176
+
177
+ function modalTileImageToggle(event) {
178
+ const modalImage = gradioApp().getElementById("modalImage");
179
+ const modal = gradioApp().getElementById("lightboxModal");
180
+ const isTiling = modalImage.style.display === 'none';
181
+ if (isTiling) {
182
+ modalImage.style.display = 'block';
183
+ modal.style.setProperty('background-image', 'none')
184
+ } else {
185
+ modalImage.style.display = 'none';
186
+ modal.style.setProperty('background-image', `url(${modalImage.src})`)
187
+ }
188
+
189
+ event.stopPropagation()
190
+ }
191
+
192
+ function galleryImageHandler(e) {
193
+ if (e && e.parentElement.tagName == 'BUTTON') {
194
+ e.onclick = showGalleryImage;
195
+ }
196
+ }
197
+
198
+ onUiUpdate(function() {
199
+ fullImg_preview = gradioApp().querySelectorAll('img.w-full')
200
+ if (fullImg_preview != null) {
201
+ fullImg_preview.forEach(galleryImageHandler);
202
+ }
203
+ updateOnBackgroundChange();
204
+ })
205
+
206
+ document.addEventListener("DOMContentLoaded", function() {
207
+ const modalFragment = document.createDocumentFragment();
208
+ const modal = document.createElement('div')
209
+ modal.onclick = closeModal;
210
+ modal.id = "lightboxModal";
211
+ modal.tabIndex = 0
212
+ modal.addEventListener('keydown', modalKeyHandler, true)
213
+
214
+ const modalControls = document.createElement('div')
215
+ modalControls.className = 'modalControls gradio-container';
216
+ modal.append(modalControls);
217
+
218
+ const modalZoom = document.createElement('span')
219
+ modalZoom.className = 'modalZoom cursor';
220
+ modalZoom.innerHTML = '&#10529;'
221
+ modalZoom.addEventListener('click', modalZoomToggle, true)
222
+ modalZoom.title = "Toggle zoomed view";
223
+ modalControls.appendChild(modalZoom)
224
+
225
+ const modalTileImage = document.createElement('span')
226
+ modalTileImage.className = 'modalTileImage cursor';
227
+ modalTileImage.innerHTML = '&#8862;'
228
+ modalTileImage.addEventListener('click', modalTileImageToggle, true)
229
+ modalTileImage.title = "Preview tiling";
230
+ modalControls.appendChild(modalTileImage)
231
+
232
+ const modalSave = document.createElement("span")
233
+ modalSave.className = "modalSave cursor"
234
+ modalSave.id = "modal_save"
235
+ modalSave.innerHTML = "&#x1F5AB;"
236
+ modalSave.addEventListener("click", modalSaveImage, true)
237
+ modalSave.title = "Save Image(s)"
238
+ modalControls.appendChild(modalSave)
239
+
240
+ const modalClose = document.createElement('span')
241
+ modalClose.className = 'modalClose cursor';
242
+ modalClose.innerHTML = '&times;'
243
+ modalClose.onclick = closeModal;
244
+ modalClose.title = "Close image viewer";
245
+ modalControls.appendChild(modalClose)
246
+
247
+ const modalImage = document.createElement('img')
248
+ modalImage.id = 'modalImage';
249
+ modalImage.onclick = closeModal;
250
+ modalImage.tabIndex = 0
251
+ modalImage.addEventListener('keydown', modalKeyHandler, true)
252
+ modal.appendChild(modalImage)
253
+
254
+ const modalPrev = document.createElement('a')
255
+ modalPrev.className = 'modalPrev';
256
+ modalPrev.innerHTML = '&#10094;'
257
+ modalPrev.tabIndex = 0
258
+ modalPrev.addEventListener('click', modalPrevImage, true);
259
+ modalPrev.addEventListener('keydown', modalKeyHandler, true)
260
+ modal.appendChild(modalPrev)
261
+
262
+ const modalNext = document.createElement('a')
263
+ modalNext.className = 'modalNext';
264
+ modalNext.innerHTML = '&#10095;'
265
+ modalNext.tabIndex = 0
266
+ modalNext.addEventListener('click', modalNextImage, true);
267
+ modalNext.addEventListener('keydown', modalKeyHandler, true)
268
+
269
+ modal.appendChild(modalNext)
270
+
271
+
272
+ gradioApp().getRootNode().appendChild(modal)
273
+
274
+ document.body.appendChild(modalFragment);
275
+
276
+ });
javascript/localization.js ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ // localization = {} -- the dict with translations is created by the backend
3
+
4
+ ignore_ids_for_localization={
5
+ setting_sd_hypernetwork: 'OPTION',
6
+ setting_sd_model_checkpoint: 'OPTION',
7
+ setting_realesrgan_enabled_models: 'OPTION',
8
+ modelmerger_primary_model_name: 'OPTION',
9
+ modelmerger_secondary_model_name: 'OPTION',
10
+ modelmerger_tertiary_model_name: 'OPTION',
11
+ train_embedding: 'OPTION',
12
+ train_hypernetwork: 'OPTION',
13
+ txt2img_style_index: 'OPTION',
14
+ txt2img_style2_index: 'OPTION',
15
+ img2img_style_index: 'OPTION',
16
+ img2img_style2_index: 'OPTION',
17
+ setting_random_artist_categories: 'SPAN',
18
+ setting_face_restoration_model: 'SPAN',
19
+ setting_realesrgan_enabled_models: 'SPAN',
20
+ extras_upscaler_1: 'SPAN',
21
+ extras_upscaler_2: 'SPAN',
22
+ }
23
+
24
+ re_num = /^[\.\d]+$/
25
+ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
26
+
27
+ original_lines = {}
28
+ translated_lines = {}
29
+
30
+ function textNodesUnder(el){
31
+ var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
32
+ while(n=walk.nextNode()) a.push(n);
33
+ return a;
34
+ }
35
+
36
+ function canBeTranslated(node, text){
37
+ if(! text) return false;
38
+ if(! node.parentElement) return false;
39
+
40
+ parentType = node.parentElement.nodeName
41
+ if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
42
+
43
+ if (parentType=='OPTION' || parentType=='SPAN'){
44
+ pnode = node
45
+ for(var level=0; level<4; level++){
46
+ pnode = pnode.parentElement
47
+ if(! pnode) break;
48
+
49
+ if(ignore_ids_for_localization[pnode.id] == parentType) return false;
50
+ }
51
+ }
52
+
53
+ if(re_num.test(text)) return false;
54
+ if(re_emoji.test(text)) return false;
55
+ return true
56
+ }
57
+
58
+ function getTranslation(text){
59
+ if(! text) return undefined
60
+
61
+ if(translated_lines[text] === undefined){
62
+ original_lines[text] = 1
63
+ }
64
+
65
+ tl = localization[text]
66
+ if(tl !== undefined){
67
+ translated_lines[tl] = 1
68
+ }
69
+
70
+ return tl
71
+ }
72
+
73
+ function processTextNode(node){
74
+ text = node.textContent.trim()
75
+
76
+ if(! canBeTranslated(node, text)) return
77
+
78
+ tl = getTranslation(text)
79
+ if(tl !== undefined){
80
+ node.textContent = tl
81
+ }
82
+ }
83
+
84
+ function processNode(node){
85
+ if(node.nodeType == 3){
86
+ processTextNode(node)
87
+ return
88
+ }
89
+
90
+ if(node.title){
91
+ tl = getTranslation(node.title)
92
+ if(tl !== undefined){
93
+ node.title = tl
94
+ }
95
+ }
96
+
97
+ if(node.placeholder){
98
+ tl = getTranslation(node.placeholder)
99
+ if(tl !== undefined){
100
+ node.placeholder = tl
101
+ }
102
+ }
103
+
104
+ textNodesUnder(node).forEach(function(node){
105
+ processTextNode(node)
106
+ })
107
+ }
108
+
109
+ function dumpTranslations(){
110
+ dumped = {}
111
+ if (localization.rtl) {
112
+ dumped.rtl = true
113
+ }
114
+
115
+ Object.keys(original_lines).forEach(function(text){
116
+ if(dumped[text] !== undefined) return
117
+
118
+ dumped[text] = localization[text] || text
119
+ })
120
+
121
+ return dumped
122
+ }
123
+
124
+ onUiUpdate(function(m){
125
+ m.forEach(function(mutation){
126
+ mutation.addedNodes.forEach(function(node){
127
+ processNode(node)
128
+ })
129
+ });
130
+ })
131
+
132
+
133
+ document.addEventListener("DOMContentLoaded", function() {
134
+ processNode(gradioApp())
135
+
136
+ if (localization.rtl) { // if the language is from right to left,
137
+ (new MutationObserver((mutations, observer) => { // wait for the style to load
138
+ mutations.forEach(mutation => {
139
+ mutation.addedNodes.forEach(node => {
140
+ if (node.tagName === 'STYLE') {
141
+ observer.disconnect();
142
+
143
+ for (const x of node.sheet.rules) { // find all rtl media rules
144
+ if (Array.from(x.media || []).includes('rtl')) {
145
+ x.media.appendMedium('all'); // enable them
146
+ }
147
+ }
148
+ }
149
+ })
150
+ });
151
+ })).observe(gradioApp(), { childList: true });
152
+ }
153
+ })
154
+
155
+ function download_localization() {
156
+ text = JSON.stringify(dumpTranslations(), null, 4)
157
+
158
+ var element = document.createElement('a');
159
+ element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
160
+ element.setAttribute('download', "localization.json");
161
+ element.style.display = 'none';
162
+ document.body.appendChild(element);
163
+
164
+ element.click();
165
+
166
+ document.body.removeChild(element);
167
+ }
javascript/notification.js ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Monitors the gallery and sends a browser notification when the leading image is new.
2
+
3
+ let lastHeadImg = null;
4
+
5
+ notificationButton = null
6
+
7
+ onUiUpdate(function(){
8
+ if(notificationButton == null){
9
+ notificationButton = gradioApp().getElementById('request_notifications')
10
+
11
+ if(notificationButton != null){
12
+ notificationButton.addEventListener('click', function (evt) {
13
+ Notification.requestPermission();
14
+ },true);
15
+ }
16
+ }
17
+
18
+ const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
19
+
20
+ if (galleryPreviews == null) return;
21
+
22
+ const headImg = galleryPreviews[0]?.src;
23
+
24
+ if (headImg == null || headImg == lastHeadImg) return;
25
+
26
+ lastHeadImg = headImg;
27
+
28
+ // play notification sound if available
29
+ gradioApp().querySelector('#audio_notification audio')?.play();
30
+
31
+ if (document.hasFocus()) return;
32
+
33
+ // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
34
+ const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));
35
+
36
+ const notification = new Notification(
37
+ 'Stable Diffusion',
38
+ {
39
+ body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
40
+ icon: headImg,
41
+ image: headImg,
42
+ }
43
+ );
44
+
45
+ notification.onclick = function(_){
46
+ parent.focus();
47
+ this.close();
48
+ };
49
+ });
javascript/progressbar.js ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // code related to showing and updating progressbar shown as the image is being made
2
+ global_progressbars = {}
3
+ galleries = {}
4
+ galleryObservers = {}
5
+
6
+ // this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
7
+ timeoutIds = {}
8
+
9
+ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
10
+ // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
11
+ // every time you use gr.HTML(elem_id='xxx'), so we handle this here
12
+ var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
13
+ var progressbarParent
14
+ if(progressbar){
15
+ progressbarParent = gradioApp().querySelector("#"+id_progressbar)
16
+ } else{
17
+ progressbar = gradioApp().getElementById(id_progressbar)
18
+ progressbarParent = null
19
+ }
20
+
21
+ var skip = id_skip ? gradioApp().getElementById(id_skip) : null
22
+ var interrupt = gradioApp().getElementById(id_interrupt)
23
+
24
+ if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
25
+ if(progressbar.innerText){
26
+ let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
27
+ if(document.title != newtitle){
28
+ document.title = newtitle;
29
+ }
30
+ }else{
31
+ let newtitle = 'Stable Diffusion'
32
+ if(document.title != newtitle){
33
+ document.title = newtitle;
34
+ }
35
+ }
36
+ }
37
+
38
+ if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
39
+ global_progressbars[id_progressbar] = progressbar
40
+
41
+ var mutationObserver = new MutationObserver(function(m){
42
+ if(timeoutIds[id_part]) return;
43
+
44
+ preview = gradioApp().getElementById(id_preview)
45
+ gallery = gradioApp().getElementById(id_gallery)
46
+
47
+ if(preview != null && gallery != null){
48
+ preview.style.width = gallery.clientWidth + "px"
49
+ preview.style.height = gallery.clientHeight + "px"
50
+ if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
51
+
52
+ //only watch gallery if there is a generation process going on
53
+ check_gallery(id_gallery);
54
+
55
+ var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
56
+ if(progressDiv){
57
+ timeoutIds[id_part] = window.setTimeout(function() {
58
+ timeoutIds[id_part] = null
59
+ requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
60
+ }, 500)
61
+ } else{
62
+ if (skip) {
63
+ skip.style.display = "none"
64
+ }
65
+ interrupt.style.display = "none"
66
+
67
+ //disconnect observer once generation finished, so user can close selected image if they want
68
+ if (galleryObservers[id_gallery]) {
69
+ galleryObservers[id_gallery].disconnect();
70
+ galleries[id_gallery] = null;
71
+ }
72
+ }
73
+ }
74
+
75
+ });
76
+ mutationObserver.observe( progressbar, { childList:true, subtree:true })
77
+ }
78
+ }
79
+
80
+ function check_gallery(id_gallery){
81
+ let gallery = gradioApp().getElementById(id_gallery)
82
+ // if gallery has no change, no need to setting up observer again.
83
+ if (gallery && galleries[id_gallery] !== gallery){
84
+ galleries[id_gallery] = gallery;
85
+ if(galleryObservers[id_gallery]){
86
+ galleryObservers[id_gallery].disconnect();
87
+ }
88
+ let prevSelectedIndex = selected_gallery_index();
89
+ galleryObservers[id_gallery] = new MutationObserver(function (){
90
+ let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
91
+ let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
92
+ if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
93
+ // automatically re-open previously selected index (if exists)
94
+ activeElement = gradioApp().activeElement;
95
+ let scrollX = window.scrollX;
96
+ let scrollY = window.scrollY;
97
+
98
+ galleryButtons[prevSelectedIndex].click();
99
+ showGalleryImage();
100
+
101
+ // When the gallery button is clicked, it gains focus and scrolls itself into view
102
+ // We need to scroll back to the previous position
103
+ setTimeout(function (){
104
+ window.scrollTo(scrollX, scrollY);
105
+ }, 50);
106
+
107
+ if(activeElement){
108
+ // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
109
+ // if someone has a better solution please by all means
110
+ setTimeout(function (){
111
+ activeElement.focus({
112
+ preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
113
+ })
114
+ }, 1);
115
+ }
116
+ }
117
+ })
118
+ galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
119
+ }
120
+ }
121
+
122
+ onUiUpdate(function(){
123
+ check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
124
+ check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
125
+ check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
126
+ })
127
+
128
+ function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
129
+ btn = gradioApp().getElementById(id_part+"_check_progress");
130
+ if(btn==null) return;
131
+
132
+ btn.click();
133
+ var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
134
+ var skip = id_skip ? gradioApp().getElementById(id_skip) : null
135
+ var interrupt = gradioApp().getElementById(id_interrupt)
136
+ if(progressDiv && interrupt){
137
+ if (skip) {
138
+ skip.style.display = "block"
139
+ }
140
+ interrupt.style.display = "block"
141
+ }
142
+ }
143
+
144
+ function requestProgress(id_part){
145
+ btn = gradioApp().getElementById(id_part+"_check_progress_initial");
146
+ if(btn==null) return;
147
+
148
+ btn.click();
149
+ }
javascript/textualInversion.js ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ function start_training_textual_inversion(){
4
+ requestProgress('ti')
5
+ gradioApp().querySelector('#ti_error').innerHTML=''
6
+
7
+ return args_to_array(arguments)
8
+ }
javascript/ui.js ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // various functions for interation with ui.py not large enough to warrant putting them in separate files
2
+
3
+ function set_theme(theme){
4
+ gradioURL = window.location.href
5
+ if (!gradioURL.includes('?__theme=')) {
6
+ window.location.replace(gradioURL + '?__theme=' + theme);
7
+ }
8
+ }
9
+
10
+ function selected_gallery_index(){
11
+ var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
12
+ var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
13
+
14
+ var result = -1
15
+ buttons.forEach(function(v, i){ if(v==button) { result = i } })
16
+
17
+ return result
18
+ }
19
+
20
+ function extract_image_from_gallery(gallery){
21
+ if(gallery.length == 1){
22
+ return gallery[0]
23
+ }
24
+
25
+ index = selected_gallery_index()
26
+
27
+ if (index < 0 || index >= gallery.length){
28
+ return [null]
29
+ }
30
+
31
+ return gallery[index];
32
+ }
33
+
34
+ function args_to_array(args){
35
+ res = []
36
+ for(var i=0;i<args.length;i++){
37
+ res.push(args[i])
38
+ }
39
+ return res
40
+ }
41
+
42
+ function switch_to_txt2img(){
43
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
44
+
45
+ return args_to_array(arguments);
46
+ }
47
+
48
+ function switch_to_img2img(){
49
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
50
+ gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
51
+
52
+ return args_to_array(arguments);
53
+ }
54
+
55
+ function switch_to_inpaint(){
56
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
57
+ gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
58
+
59
+ return args_to_array(arguments);
60
+ }
61
+
62
+ function switch_to_extras(){
63
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
64
+
65
+ return args_to_array(arguments);
66
+ }
67
+
68
+ function get_tab_index(tabId){
69
+ var res = 0
70
+
71
+ gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
72
+ if(button.className.indexOf('bg-white') != -1)
73
+ res = i
74
+ })
75
+
76
+ return res
77
+ }
78
+
79
+ function create_tab_index_args(tabId, args){
80
+ var res = []
81
+ for(var i=0; i<args.length; i++){
82
+ res.push(args[i])
83
+ }
84
+
85
+ res[0] = get_tab_index(tabId)
86
+
87
+ return res
88
+ }
89
+
90
+ function get_extras_tab_index(){
91
+ const [,,...args] = [...arguments]
92
+ return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
93
+ }
94
+
95
+ function create_submit_args(args){
96
+ res = []
97
+ for(var i=0;i<args.length;i++){
98
+ res.push(args[i])
99
+ }
100
+
101
+ // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
102
+ // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
103
+ // I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
104
+ // If gradio at some point stops sending outputs, this may break something
105
+ if(Array.isArray(res[res.length - 3])){
106
+ res[res.length - 3] = null
107
+ }
108
+
109
+ return res
110
+ }
111
+
112
+ function submit(){
113
+ requestProgress('txt2img')
114
+
115
+ return create_submit_args(arguments)
116
+ }
117
+
118
+ function submit_img2img(){
119
+ requestProgress('img2img')
120
+
121
+ res = create_submit_args(arguments)
122
+
123
+ res[0] = get_tab_index('mode_img2img')
124
+
125
+ return res
126
+ }
127
+
128
+
129
+ function ask_for_style_name(_, prompt_text, negative_prompt_text) {
130
+ name_ = prompt('Style name:')
131
+ return [name_, prompt_text, negative_prompt_text]
132
+ }
133
+
134
+ function confirm_clear_prompt(prompt, negative_prompt) {
135
+ if(confirm("Delete prompt?")) {
136
+ prompt = ""
137
+ negative_prompt = ""
138
+ }
139
+
140
+ return [prompt, negative_prompt]
141
+ }
142
+
143
+
144
+
145
+ opts = {}
146
+ function apply_settings(jsdata){
147
+ console.log(jsdata)
148
+
149
+ opts = JSON.parse(jsdata)
150
+
151
+ return jsdata
152
+ }
153
+
154
+ onUiUpdate(function(){
155
+ if(Object.keys(opts).length != 0) return;
156
+
157
+ json_elem = gradioApp().getElementById('settings_json')
158
+ if(json_elem == null) return;
159
+
160
+ textarea = json_elem.querySelector('textarea')
161
+ jsdata = textarea.value
162
+ opts = JSON.parse(jsdata)
163
+
164
+
165
+ Object.defineProperty(textarea, 'value', {
166
+ set: function(newValue) {
167
+ var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
168
+ var oldValue = valueProp.get.call(textarea);
169
+ valueProp.set.call(textarea, newValue);
170
+
171
+ if (oldValue != newValue) {
172
+ opts = JSON.parse(textarea.value)
173
+ }
174
+ },
175
+ get: function() {
176
+ var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
177
+ return valueProp.get.call(textarea);
178
+ }
179
+ });
180
+
181
+ json_elem.parentElement.style.display="none"
182
+
183
+ if (!txt2img_textarea) {
184
+ txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
185
+ txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
186
+ }
187
+ if (!img2img_textarea) {
188
+ img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
189
+ img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
190
+ }
191
+ })
192
+
193
+ let txt2img_textarea, img2img_textarea = undefined;
194
+ let wait_time = 800
195
+ let token_timeout;
196
+
197
+ function update_txt2img_tokens(...args) {
198
+ update_token_counter("txt2img_token_button")
199
+ if (args.length == 2)
200
+ return args[0]
201
+ return args;
202
+ }
203
+
204
+ function update_img2img_tokens(...args) {
205
+ update_token_counter("img2img_token_button")
206
+ if (args.length == 2)
207
+ return args[0]
208
+ return args;
209
+ }
210
+
211
+ function update_token_counter(button_id) {
212
+ if (token_timeout)
213
+ clearTimeout(token_timeout);
214
+ token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
215
+ }
216
+
217
+ function restart_reload(){
218
+ document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
219
+ setTimeout(function(){location.reload()},2000)
220
+
221
+ return []
222
+ }
launch.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this scripts installs necessary requirements and launches main program in webui.py
2
+ import subprocess
3
+ import os
4
+ import sys
5
+ import importlib.util
6
+ import shlex
7
+ import platform
8
+ import argparse
9
+ import json
10
+
11
+ dir_repos = "repositories"
12
+ dir_extensions = "extensions"
13
+ python = sys.executable
14
+ git = os.environ.get('GIT', "git")
15
+ index_url = os.environ.get('INDEX_URL', "")
16
+
17
+
18
+ def extract_arg(args, name):
19
+ return [x for x in args if x != name], name in args
20
+
21
+
22
+ def extract_opt(args, name):
23
+ opt = None
24
+ is_present = False
25
+ if name in args:
26
+ is_present = True
27
+ idx = args.index(name)
28
+ del args[idx]
29
+ if idx < len(args) and args[idx][0] != "-":
30
+ opt = args[idx]
31
+ del args[idx]
32
+ return args, is_present, opt
33
+
34
+
35
+ def run(command, desc=None, errdesc=None, custom_env=None):
36
+ if desc is not None:
37
+ print(desc)
38
+
39
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
40
+
41
+ if result.returncode != 0:
42
+
43
+ message = f"""{errdesc or 'Error running command'}.
44
+ Command: {command}
45
+ Error code: {result.returncode}
46
+ stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
47
+ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
48
+ """
49
+ raise RuntimeError(message)
50
+
51
+ return result.stdout.decode(encoding="utf8", errors="ignore")
52
+
53
+
54
+ def check_run(command):
55
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
56
+ return result.returncode == 0
57
+
58
+
59
+ def is_installed(package):
60
+ try:
61
+ spec = importlib.util.find_spec(package)
62
+ except ModuleNotFoundError:
63
+ return False
64
+
65
+ return spec is not None
66
+
67
+
68
+ def repo_dir(name):
69
+ return os.path.join(dir_repos, name)
70
+
71
+
72
+ def run_python(code, desc=None, errdesc=None):
73
+ return run(f'"{python}" -c "{code}"', desc, errdesc)
74
+
75
+
76
+ def run_pip(args, desc=None):
77
+ index_url_line = f' --index-url {index_url}' if index_url != '' else ''
78
+ return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
79
+
80
+
81
+ def check_run_python(code):
82
+ return check_run(f'"{python}" -c "{code}"')
83
+
84
+
85
+ def git_clone(url, dir, name, commithash=None):
86
+ # TODO clone into temporary dir and move if successful
87
+
88
+ if os.path.exists(dir):
89
+ if commithash is None:
90
+ return
91
+
92
+ current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
93
+ if current_hash == commithash:
94
+ return
95
+
96
+ run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
97
+ run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
98
+ return
99
+
100
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
101
+
102
+ if commithash is not None:
103
+ run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
104
+
105
+
106
+ def version_check(commit):
107
+ try:
108
+ import requests
109
+ commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
110
+ if commit != "<none>" and commits['commit']['sha'] != commit:
111
+ print("--------------------------------------------------------")
112
+ print("| You are not up to date with the most recent release. |")
113
+ print("| Consider running `git pull` to update. |")
114
+ print("--------------------------------------------------------")
115
+ elif commits['commit']['sha'] == commit:
116
+ print("You are up to date with the most recent release.")
117
+ else:
118
+ print("Not a git clone, can't perform version check.")
119
+ except Exception as e:
120
+ print("version check failed", e)
121
+
122
+
123
+ def run_extension_installer(extension_dir):
124
+ path_installer = os.path.join(extension_dir, "install.py")
125
+ if not os.path.isfile(path_installer):
126
+ return
127
+
128
+ try:
129
+ env = os.environ.copy()
130
+ env['PYTHONPATH'] = os.path.abspath(".")
131
+
132
+ print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
133
+ except Exception as e:
134
+ print(e, file=sys.stderr)
135
+
136
+
137
+ def list_extensions(settings_file):
138
+ settings = {}
139
+
140
+ try:
141
+ if os.path.isfile(settings_file):
142
+ with open(settings_file, "r", encoding="utf8") as file:
143
+ settings = json.load(file)
144
+ except Exception as e:
145
+ print(e, file=sys.stderr)
146
+
147
+ disabled_extensions = set(settings.get('disabled_extensions', []))
148
+
149
+ return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
150
+
151
+
152
+ def run_extensions_installers(settings_file):
153
+ if not os.path.isdir(dir_extensions):
154
+ return
155
+
156
+ for dirname_extension in list_extensions(settings_file):
157
+ run_extension_installer(os.path.join(dir_extensions, dirname_extension))
158
+
159
+
160
+ def prepare_environment():
161
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
162
+ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
163
+ commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
164
+
165
+ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
166
+ clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
167
+ openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
168
+
169
+ xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
170
+
171
+ stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
172
+ taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
173
+ k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
174
+ codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
175
+ blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
176
+
177
+ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
178
+ taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
179
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
180
+ codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
181
+ blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
182
+
183
+ sys.argv += shlex.split(commandline_args)
184
+
185
+ parser = argparse.ArgumentParser()
186
+ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
187
+ args, _ = parser.parse_known_args(sys.argv)
188
+
189
+ sys.argv, _ = extract_arg(sys.argv, '-f')
190
+ sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
191
+ sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
192
+ sys.argv, update_check = extract_arg(sys.argv, '--update-check')
193
+ sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
194
+ xformers = '--xformers' in sys.argv
195
+ ngrok = '--ngrok' in sys.argv
196
+
197
+ try:
198
+ commit = run(f"{git} rev-parse HEAD").strip()
199
+ except Exception:
200
+ commit = "<none>"
201
+
202
+ print(f"Python {sys.version}")
203
+ print(f"Commit hash: {commit}")
204
+
205
+ if not is_installed("torch") or not is_installed("torchvision"):
206
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
207
+
208
+ if not skip_torch_cuda_test:
209
+ run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
210
+
211
+ if not is_installed("gfpgan"):
212
+ run_pip(f"install {gfpgan_package}", "gfpgan")
213
+
214
+ if not is_installed("clip"):
215
+ run_pip(f"install {clip_package}", "clip")
216
+
217
+ if not is_installed("open_clip"):
218
+ run_pip(f"install {openclip_package}", "open_clip")
219
+
220
+ if (not is_installed("xformers") or reinstall_xformers) and xformers:
221
+ if platform.system() == "Windows":
222
+ if platform.python_version().startswith("3.10"):
223
+ run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
224
+ else:
225
+ print("Installation of xformers is not supported in this version of Python.")
226
+ print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
227
+ if not is_installed("xformers"):
228
+ exit(0)
229
+ elif platform.system() == "Linux":
230
+ run_pip("install xformers", "xformers")
231
+
232
+ if not is_installed("pyngrok") and ngrok:
233
+ run_pip("install pyngrok", "ngrok")
234
+
235
+ os.makedirs(dir_repos, exist_ok=True)
236
+
237
+ git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
238
+ git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
239
+ git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
240
+ git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
241
+ git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
242
+
243
+ if not is_installed("lpips"):
244
+ run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
245
+
246
+ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
247
+
248
+ run_extensions_installers(settings_file=args.ui_settings_file)
249
+
250
+ if update_check:
251
+ version_check(commit)
252
+
253
+ if "--exit" in sys.argv:
254
+ print("Exiting because of --exit argument")
255
+ exit(0)
256
+
257
+ if run_tests:
258
+ exitcode = tests(test_dir)
259
+ exit(exitcode)
260
+
261
+
262
+ def tests(test_dir):
263
+ if "--api" not in sys.argv:
264
+ sys.argv.append("--api")
265
+ if "--ckpt" not in sys.argv:
266
+ sys.argv.append("--ckpt")
267
+ sys.argv.append("./test/test_files/empty.pt")
268
+ if "--skip-torch-cuda-test" not in sys.argv:
269
+ sys.argv.append("--skip-torch-cuda-test")
270
+
271
+ print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
272
+
273
+ with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
274
+ proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
275
+
276
+ import test.server_poll
277
+ exitcode = test.server_poll.run_tests(proc, test_dir)
278
+
279
+ print(f"Stopping Web UI process with id {proc.pid}")
280
+ proc.kill()
281
+ return exitcode
282
+
283
+
284
+ def start():
285
+ print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
286
+ import webui
287
+ if '--nowebui' in sys.argv:
288
+ webui.api_only()
289
+ else:
290
+ webui.webui()
291
+
292
+
293
+ if __name__ == "__main__":
294
+ prepare_environment()
295
+ start()
localizations/Put localization files here.txt ADDED
File without changes
modules/api/api.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import time
4
+ import uvicorn
5
+ from threading import Lock
6
+ from io import BytesIO
7
+ from gradio.processing_utils import decode_base64_to_file
8
+ from fastapi import APIRouter, Depends, FastAPI, HTTPException
9
+ from fastapi.security import HTTPBasic, HTTPBasicCredentials
10
+ from secrets import compare_digest
11
+
12
+ import modules.shared as shared
13
+ from modules import sd_samplers, deepbooru, sd_hijack
14
+ from modules.api.models import *
15
+ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
16
+ from modules.extras import run_extras, run_pnginfo
17
+ from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
18
+ from modules.textual_inversion.preprocess import preprocess
19
+ from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
20
+ from PIL import PngImagePlugin,Image
21
+ from modules.sd_models import checkpoints_list
22
+ from modules.realesrgan_model import get_realesrgan_models
23
+ from modules import devices
24
+ from typing import List
25
+
26
+ def upscaler_to_index(name: str):
27
+ try:
28
+ return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
29
+ except:
30
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
31
+
32
+
33
+ def validate_sampler_name(name):
34
+ config = sd_samplers.all_samplers_map.get(name, None)
35
+ if config is None:
36
+ raise HTTPException(status_code=404, detail="Sampler not found")
37
+
38
+ return name
39
+
40
+ def setUpscalers(req: dict):
41
+ reqDict = vars(req)
42
+ reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
43
+ reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
44
+ reqDict.pop('upscaler_1')
45
+ reqDict.pop('upscaler_2')
46
+ return reqDict
47
+
48
+ def decode_base64_to_image(encoding):
49
+ if encoding.startswith("data:image/"):
50
+ encoding = encoding.split(";")[1].split(",")[1]
51
+ return Image.open(BytesIO(base64.b64decode(encoding)))
52
+
53
+ def encode_pil_to_base64(image):
54
+ with io.BytesIO() as output_bytes:
55
+
56
+ # Copy any text-only metadata
57
+ use_metadata = False
58
+ metadata = PngImagePlugin.PngInfo()
59
+ for key, value in image.info.items():
60
+ if isinstance(key, str) and isinstance(value, str):
61
+ metadata.add_text(key, value)
62
+ use_metadata = True
63
+
64
+ image.save(
65
+ output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
66
+ )
67
+ bytes_data = output_bytes.getvalue()
68
+ return base64.b64encode(bytes_data)
69
+
70
+
71
+ class Api:
72
+ def __init__(self, app: FastAPI, queue_lock: Lock):
73
+ if shared.cmd_opts.api_auth:
74
+ self.credentials = dict()
75
+ for auth in shared.cmd_opts.api_auth.split(","):
76
+ user, password = auth.split(":")
77
+ self.credentials[user] = password
78
+
79
+ self.router = APIRouter()
80
+ self.app = app
81
+ self.queue_lock = queue_lock
82
+ self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
83
+ self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
84
+ self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
85
+ self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
86
+ self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
87
+ self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
88
+ self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
89
+ self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
90
+ self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
91
+ self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
92
+ self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
93
+ self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
94
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
95
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
96
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
97
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
98
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
99
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
100
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
101
+ self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
102
+ self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
103
+ self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
104
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
105
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
106
+ self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
107
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
108
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
109
+
110
+ def add_api_route(self, path: str, endpoint, **kwargs):
111
+ if shared.cmd_opts.api_auth:
112
+ return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
113
+ return self.app.add_api_route(path, endpoint, **kwargs)
114
+
115
+ def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
116
+ if credentials.username in self.credentials:
117
+ if compare_digest(credentials.password, self.credentials[credentials.username]):
118
+ return True
119
+
120
+ raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
121
+
122
+ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
123
+ populate = txt2imgreq.copy(update={ # Override __init__ params
124
+ "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
125
+ "do_not_save_samples": True,
126
+ "do_not_save_grid": True
127
+ }
128
+ )
129
+ if populate.sampler_name:
130
+ populate.sampler_index = None # prevent a warning later on
131
+
132
+ with self.queue_lock:
133
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
134
+
135
+ shared.state.begin()
136
+ processed = process_images(p)
137
+ shared.state.end()
138
+
139
+
140
+ b64images = list(map(encode_pil_to_base64, processed.images))
141
+
142
+ return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
143
+
144
+ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
145
+ init_images = img2imgreq.init_images
146
+ if init_images is None:
147
+ raise HTTPException(status_code=404, detail="Init image not found")
148
+
149
+ mask = img2imgreq.mask
150
+ if mask:
151
+ mask = decode_base64_to_image(mask)
152
+
153
+ populate = img2imgreq.copy(update={ # Override __init__ params
154
+ "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
155
+ "do_not_save_samples": True,
156
+ "do_not_save_grid": True,
157
+ "mask": mask
158
+ }
159
+ )
160
+ if populate.sampler_name:
161
+ populate.sampler_index = None # prevent a warning later on
162
+
163
+ args = vars(populate)
164
+ args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
165
+
166
+ with self.queue_lock:
167
+ p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
168
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
169
+
170
+ shared.state.begin()
171
+ processed = process_images(p)
172
+ shared.state.end()
173
+
174
+ b64images = list(map(encode_pil_to_base64, processed.images))
175
+
176
+ if not img2imgreq.include_init_images:
177
+ img2imgreq.init_images = None
178
+ img2imgreq.mask = None
179
+
180
+ return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
181
+
182
+ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
183
+ reqDict = setUpscalers(req)
184
+
185
+ reqDict['image'] = decode_base64_to_image(reqDict['image'])
186
+
187
+ with self.queue_lock:
188
+ result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
189
+
190
+ return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
191
+
192
+ def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
193
+ reqDict = setUpscalers(req)
194
+
195
+ def prepareFiles(file):
196
+ file = decode_base64_to_file(file.data, file_path=file.name)
197
+ file.orig_name = file.name
198
+ return file
199
+
200
+ reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
201
+ reqDict.pop('imageList')
202
+
203
+ with self.queue_lock:
204
+ result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
205
+
206
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
207
+
208
+ def pnginfoapi(self, req: PNGInfoRequest):
209
+ if(not req.image.strip()):
210
+ return PNGInfoResponse(info="")
211
+
212
+ result = run_pnginfo(decode_base64_to_image(req.image.strip()))
213
+
214
+ return PNGInfoResponse(info=result[1])
215
+
216
+ def progressapi(self, req: ProgressRequest = Depends()):
217
+ # copy from check_progress_call of ui.py
218
+
219
+ if shared.state.job_count == 0:
220
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
221
+
222
+ # avoid dividing zero
223
+ progress = 0.01
224
+
225
+ if shared.state.job_count > 0:
226
+ progress += shared.state.job_no / shared.state.job_count
227
+ if shared.state.sampling_steps > 0:
228
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
229
+
230
+ time_since_start = time.time() - shared.state.time_start
231
+ eta = (time_since_start/progress)
232
+ eta_relative = eta-time_since_start
233
+
234
+ progress = min(progress, 1)
235
+
236
+ shared.state.set_current_image()
237
+
238
+ current_image = None
239
+ if shared.state.current_image and not req.skip_current_image:
240
+ current_image = encode_pil_to_base64(shared.state.current_image)
241
+
242
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
243
+
244
+ def interrogateapi(self, interrogatereq: InterrogateRequest):
245
+ image_b64 = interrogatereq.image
246
+ if image_b64 is None:
247
+ raise HTTPException(status_code=404, detail="Image not found")
248
+
249
+ img = decode_base64_to_image(image_b64)
250
+ img = img.convert('RGB')
251
+
252
+ # Override object param
253
+ with self.queue_lock:
254
+ if interrogatereq.model == "clip":
255
+ processed = shared.interrogator.interrogate(img)
256
+ elif interrogatereq.model == "deepdanbooru":
257
+ processed = deepbooru.model.tag(img)
258
+ else:
259
+ raise HTTPException(status_code=404, detail="Model not found")
260
+
261
+ return InterrogateResponse(caption=processed)
262
+
263
+ def interruptapi(self):
264
+ shared.state.interrupt()
265
+
266
+ return {}
267
+
268
+ def skip(self):
269
+ shared.state.skip()
270
+
271
+ def get_config(self):
272
+ options = {}
273
+ for key in shared.opts.data.keys():
274
+ metadata = shared.opts.data_labels.get(key)
275
+ if(metadata is not None):
276
+ options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
277
+ else:
278
+ options.update({key: shared.opts.data.get(key, None)})
279
+
280
+ return options
281
+
282
+ def set_config(self, req: Dict[str, Any]):
283
+ for k, v in req.items():
284
+ shared.opts.set(k, v)
285
+
286
+ shared.opts.save(shared.config_filename)
287
+ return
288
+
289
+ def get_cmd_flags(self):
290
+ return vars(shared.cmd_opts)
291
+
292
+ def get_samplers(self):
293
+ return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
294
+
295
+ def get_upscalers(self):
296
+ upscalers = []
297
+
298
+ for upscaler in shared.sd_upscalers:
299
+ u = upscaler.scaler
300
+ upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
301
+
302
+ return upscalers
303
+
304
+ def get_sd_models(self):
305
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
306
+
307
+ def get_hypernetworks(self):
308
+ return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
309
+
310
+ def get_face_restorers(self):
311
+ return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
312
+
313
+ def get_realesrgan_models(self):
314
+ return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
315
+
316
+ def get_prompt_styles(self):
317
+ styleList = []
318
+ for k in shared.prompt_styles.styles:
319
+ style = shared.prompt_styles.styles[k]
320
+ styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
321
+
322
+ return styleList
323
+
324
+ def get_artists_categories(self):
325
+ return shared.artist_db.cats
326
+
327
+ def get_artists(self):
328
+ return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
329
+
330
+ def refresh_checkpoints(self):
331
+ shared.refresh_checkpoints()
332
+
333
+ def create_embedding(self, args: dict):
334
+ try:
335
+ shared.state.begin()
336
+ filename = create_embedding(**args) # create empty embedding
337
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
338
+ shared.state.end()
339
+ return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
340
+ except AssertionError as e:
341
+ shared.state.end()
342
+ return TrainResponse(info = "create embedding error: {error}".format(error = e))
343
+
344
+ def create_hypernetwork(self, args: dict):
345
+ try:
346
+ shared.state.begin()
347
+ filename = create_hypernetwork(**args) # create empty embedding
348
+ shared.state.end()
349
+ return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
350
+ except AssertionError as e:
351
+ shared.state.end()
352
+ return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
353
+
354
+ def preprocess(self, args: dict):
355
+ try:
356
+ shared.state.begin()
357
+ preprocess(**args) # quick operation unless blip/booru interrogation is enabled
358
+ shared.state.end()
359
+ return PreprocessResponse(info = 'preprocess complete')
360
+ except KeyError as e:
361
+ shared.state.end()
362
+ return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
363
+ except AssertionError as e:
364
+ shared.state.end()
365
+ return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
366
+ except FileNotFoundError as e:
367
+ shared.state.end()
368
+ return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
369
+
370
+ def train_embedding(self, args: dict):
371
+ try:
372
+ shared.state.begin()
373
+ apply_optimizations = shared.opts.training_xattention_optimizations
374
+ error = None
375
+ filename = ''
376
+ if not apply_optimizations:
377
+ sd_hijack.undo_optimizations()
378
+ try:
379
+ embedding, filename = train_embedding(**args) # can take a long time to complete
380
+ except Exception as e:
381
+ error = e
382
+ finally:
383
+ if not apply_optimizations:
384
+ sd_hijack.apply_optimizations()
385
+ shared.state.end()
386
+ return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
387
+ except AssertionError as msg:
388
+ shared.state.end()
389
+ return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
390
+
391
+ def train_hypernetwork(self, args: dict):
392
+ try:
393
+ shared.state.begin()
394
+ initial_hypernetwork = shared.loaded_hypernetwork
395
+ apply_optimizations = shared.opts.training_xattention_optimizations
396
+ error = None
397
+ filename = ''
398
+ if not apply_optimizations:
399
+ sd_hijack.undo_optimizations()
400
+ try:
401
+ hypernetwork, filename = train_hypernetwork(*args)
402
+ except Exception as e:
403
+ error = e
404
+ finally:
405
+ shared.loaded_hypernetwork = initial_hypernetwork
406
+ shared.sd_model.cond_stage_model.to(devices.device)
407
+ shared.sd_model.first_stage_model.to(devices.device)
408
+ if not apply_optimizations:
409
+ sd_hijack.apply_optimizations()
410
+ shared.state.end()
411
+ return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
412
+ except AssertionError as msg:
413
+ shared.state.end()
414
+ return TrainResponse(info = "train embedding error: {error}".format(error = error))
415
+
416
+ def launch(self, server_name, port):
417
+ self.app.include_router(self.router)
418
+ uvicorn.run(self.app, host=server_name, port=port)
modules/api/models.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from pydantic import BaseModel, Field, create_model
3
+ from typing import Any, Optional
4
+ from typing_extensions import Literal
5
+ from inflection import underscore
6
+ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
7
+ from modules.shared import sd_upscalers, opts, parser
8
+ from typing import Dict, List
9
+
10
+ API_NOT_ALLOWED = [
11
+ "self",
12
+ "kwargs",
13
+ "sd_model",
14
+ "outpath_samples",
15
+ "outpath_grids",
16
+ "sampler_index",
17
+ "do_not_save_samples",
18
+ "do_not_save_grid",
19
+ "extra_generation_params",
20
+ "overlay_images",
21
+ "do_not_reload_embeddings",
22
+ "seed_enable_extras",
23
+ "prompt_for_display",
24
+ "sampler_noise_scheduler_override",
25
+ "ddim_discretize"
26
+ ]
27
+
28
+ class ModelDef(BaseModel):
29
+ """Assistance Class for Pydantic Dynamic Model Generation"""
30
+
31
+ field: str
32
+ field_alias: str
33
+ field_type: Any
34
+ field_value: Any
35
+ field_exclude: bool = False
36
+
37
+
38
+ class PydanticModelGenerator:
39
+ """
40
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
41
+ source_data is a snapshot of the default values produced by the class
42
+ params are the names of the actual keys required by __init__
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ model_name: str = None,
48
+ class_instance = None,
49
+ additional_fields = None,
50
+ ):
51
+ def field_type_generator(k, v):
52
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
53
+ # print(k, v.annotation, v.default)
54
+ field_type = v.annotation
55
+
56
+ return Optional[field_type]
57
+
58
+ def merge_class_params(class_):
59
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
60
+ parameters = {}
61
+ for classes in all_classes:
62
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
63
+ return parameters
64
+
65
+
66
+ self._model_name = model_name
67
+ self._class_data = merge_class_params(class_instance)
68
+
69
+ self._model_def = [
70
+ ModelDef(
71
+ field=underscore(k),
72
+ field_alias=k,
73
+ field_type=field_type_generator(k, v),
74
+ field_value=v.default
75
+ )
76
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
77
+ ]
78
+
79
+ for fields in additional_fields:
80
+ self._model_def.append(ModelDef(
81
+ field=underscore(fields["key"]),
82
+ field_alias=fields["key"],
83
+ field_type=fields["type"],
84
+ field_value=fields["default"],
85
+ field_exclude=fields["exclude"] if "exclude" in fields else False))
86
+
87
+ def generate_model(self):
88
+ """
89
+ Creates a pydantic BaseModel
90
+ from the json and overrides provided at initialization
91
+ """
92
+ fields = {
93
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
94
+ }
95
+ DynamicModel = create_model(self._model_name, **fields)
96
+ DynamicModel.__config__.allow_population_by_field_name = True
97
+ DynamicModel.__config__.allow_mutation = True
98
+ return DynamicModel
99
+
100
+ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
101
+ "StableDiffusionProcessingTxt2Img",
102
+ StableDiffusionProcessingTxt2Img,
103
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
104
+ ).generate_model()
105
+
106
+ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
107
+ "StableDiffusionProcessingImg2Img",
108
+ StableDiffusionProcessingImg2Img,
109
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
110
+ ).generate_model()
111
+
112
+ class TextToImageResponse(BaseModel):
113
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
114
+ parameters: dict
115
+ info: str
116
+
117
+ class ImageToImageResponse(BaseModel):
118
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
119
+ parameters: dict
120
+ info: str
121
+
122
+ class ExtrasBaseRequest(BaseModel):
123
+ resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
124
+ show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
125
+ gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
126
+ codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
127
+ codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
128
+ upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
129
+ upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
130
+ upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
131
+ upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
132
+ upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
133
+ upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
134
+ extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
135
+ upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
136
+
137
+ class ExtraBaseResponse(BaseModel):
138
+ html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
139
+
140
+ class ExtrasSingleImageRequest(ExtrasBaseRequest):
141
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
142
+
143
+ class ExtrasSingleImageResponse(ExtraBaseResponse):
144
+ image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
145
+
146
+ class FileData(BaseModel):
147
+ data: str = Field(title="File data", description="Base64 representation of the file")
148
+ name: str = Field(title="File name")
149
+
150
+ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
151
+ imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
152
+
153
+ class ExtrasBatchImagesResponse(ExtraBaseResponse):
154
+ images: List[str] = Field(title="Images", description="The generated images in base64 format.")
155
+
156
+ class PNGInfoRequest(BaseModel):
157
+ image: str = Field(title="Image", description="The base64 encoded PNG image")
158
+
159
+ class PNGInfoResponse(BaseModel):
160
+ info: str = Field(title="Image info", description="A string with all the info the image had")
161
+
162
+ class ProgressRequest(BaseModel):
163
+ skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
164
+
165
+ class ProgressResponse(BaseModel):
166
+ progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
167
+ eta_relative: float = Field(title="ETA in secs")
168
+ state: dict = Field(title="State", description="The current state snapshot")
169
+ current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
170
+
171
+ class InterrogateRequest(BaseModel):
172
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
173
+ model: str = Field(default="clip", title="Model", description="The interrogate model used.")
174
+
175
+ class InterrogateResponse(BaseModel):
176
+ caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
177
+
178
+ class TrainResponse(BaseModel):
179
+ info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
180
+
181
+ class CreateResponse(BaseModel):
182
+ info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
183
+
184
+ class PreprocessResponse(BaseModel):
185
+ info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
186
+
187
+ fields = {}
188
+ for key, metadata in opts.data_labels.items():
189
+ value = opts.data.get(key)
190
+ optType = opts.typemap.get(type(metadata.default), type(value))
191
+
192
+ if (metadata is not None):
193
+ fields.update({key: (Optional[optType], Field(
194
+ default=metadata.default ,description=metadata.label))})
195
+ else:
196
+ fields.update({key: (Optional[optType], Field())})
197
+
198
+ OptionsModel = create_model("Options", **fields)
199
+
200
+ flags = {}
201
+ _options = vars(parser)['_option_string_actions']
202
+ for key in _options:
203
+ if(_options[key].dest != 'help'):
204
+ flag = _options[key]
205
+ _type = str
206
+ if _options[key].default is not None: _type = type(_options[key].default)
207
+ flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
208
+
209
+ FlagsModel = create_model("Flags", **flags)
210
+
211
+ class SamplerItem(BaseModel):
212
+ name: str = Field(title="Name")
213
+ aliases: List[str] = Field(title="Aliases")
214
+ options: Dict[str, str] = Field(title="Options")
215
+
216
+ class UpscalerItem(BaseModel):
217
+ name: str = Field(title="Name")
218
+ model_name: Optional[str] = Field(title="Model Name")
219
+ model_path: Optional[str] = Field(title="Path")
220
+ model_url: Optional[str] = Field(title="URL")
221
+
222
+ class SDModelItem(BaseModel):
223
+ title: str = Field(title="Title")
224
+ model_name: str = Field(title="Model Name")
225
+ hash: str = Field(title="Hash")
226
+ filename: str = Field(title="Filename")
227
+ config: str = Field(title="Config file")
228
+
229
+ class HypernetworkItem(BaseModel):
230
+ name: str = Field(title="Name")
231
+ path: Optional[str] = Field(title="Path")
232
+
233
+ class FaceRestorerItem(BaseModel):
234
+ name: str = Field(title="Name")
235
+ cmd_dir: Optional[str] = Field(title="Path")
236
+
237
+ class RealesrganItem(BaseModel):
238
+ name: str = Field(title="Name")
239
+ path: Optional[str] = Field(title="Path")
240
+ scale: Optional[int] = Field(title="Scale")
241
+
242
+ class PromptStyleItem(BaseModel):
243
+ name: str = Field(title="Name")
244
+ prompt: Optional[str] = Field(title="Prompt")
245
+ negative_prompt: Optional[str] = Field(title="Negative Prompt")
246
+
247
+ class ArtistItem(BaseModel):
248
+ name: str = Field(title="Name")
249
+ score: float = Field(title="Score")
250
+ category: str = Field(title="Category")
251
+
modules/artists.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import csv
3
+ from collections import namedtuple
4
+
5
+ Artist = namedtuple("Artist", ['name', 'weight', 'category'])
6
+
7
+
8
+ class ArtistsDatabase:
9
+ def __init__(self, filename):
10
+ self.cats = set()
11
+ self.artists = []
12
+
13
+ if not os.path.exists(filename):
14
+ return
15
+
16
+ with open(filename, "r", newline='', encoding="utf8") as file:
17
+ reader = csv.DictReader(file)
18
+
19
+ for row in reader:
20
+ artist = Artist(row["artist"], float(row["score"]), row["category"])
21
+ self.artists.append(artist)
22
+ self.cats.add(artist.category)
23
+
24
+ def categories(self):
25
+ return sorted(self.cats)
modules/call_queue.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import html
2
+ import sys
3
+ import threading
4
+ import traceback
5
+ import time
6
+
7
+ from modules import shared
8
+
9
+ queue_lock = threading.Lock()
10
+
11
+
12
+ def wrap_queued_call(func):
13
+ def f(*args, **kwargs):
14
+ with queue_lock:
15
+ res = func(*args, **kwargs)
16
+
17
+ return res
18
+
19
+ return f
20
+
21
+
22
+ def wrap_gradio_gpu_call(func, extra_outputs=None):
23
+ def f(*args, **kwargs):
24
+
25
+ shared.state.begin()
26
+
27
+ with queue_lock:
28
+ res = func(*args, **kwargs)
29
+
30
+ shared.state.end()
31
+
32
+ return res
33
+
34
+ return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
35
+
36
+
37
+ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
38
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
39
+ run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
40
+ if run_memmon:
41
+ shared.mem_mon.monitor()
42
+ t = time.perf_counter()
43
+
44
+ try:
45
+ res = list(func(*args, **kwargs))
46
+ except Exception as e:
47
+ # When printing out our debug argument list, do not print out more than a MB of text
48
+ max_debug_str_len = 131072 # (1024*1024)/8
49
+
50
+ print("Error completing request", file=sys.stderr)
51
+ argStr = f"Arguments: {str(args)} {str(kwargs)}"
52
+ print(argStr[:max_debug_str_len], file=sys.stderr)
53
+ if len(argStr) > max_debug_str_len:
54
+ print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
55
+
56
+ print(traceback.format_exc(), file=sys.stderr)
57
+
58
+ shared.state.job = ""
59
+ shared.state.job_count = 0
60
+
61
+ if extra_outputs_array is None:
62
+ extra_outputs_array = [None, '']
63
+
64
+ res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
65
+
66
+ shared.state.skipped = False
67
+ shared.state.interrupted = False
68
+ shared.state.job_count = 0
69
+
70
+ if not add_stats:
71
+ return tuple(res)
72
+
73
+ elapsed = time.perf_counter() - t
74
+ elapsed_m = int(elapsed // 60)
75
+ elapsed_s = elapsed % 60
76
+ elapsed_text = f"{elapsed_s:.2f}s"
77
+ if elapsed_m > 0:
78
+ elapsed_text = f"{elapsed_m}m "+elapsed_text
79
+
80
+ if run_memmon:
81
+ mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
82
+ active_peak = mem_stats['active_peak']
83
+ reserved_peak = mem_stats['reserved_peak']
84
+ sys_peak = mem_stats['system_peak']
85
+ sys_total = mem_stats['total']
86
+ sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
87
+
88
+ vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
89
+ else:
90
+ vram_html = ''
91
+
92
+ # last item is always HTML
93
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
94
+
95
+ return tuple(res)
96
+
97
+ return f
98
+
modules/codeformer/codeformer_arch.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
2
+
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn, Tensor
7
+ import torch.nn.functional as F
8
+ from typing import Optional, List
9
+
10
+ from modules.codeformer.vqgan_arch import *
11
+ from basicsr.utils import get_root_logger
12
+ from basicsr.utils.registry import ARCH_REGISTRY
13
+
14
+ def calc_mean_std(feat, eps=1e-5):
15
+ """Calculate mean and std for adaptive_instance_normalization.
16
+
17
+ Args:
18
+ feat (Tensor): 4D tensor.
19
+ eps (float): A small value added to the variance to avoid
20
+ divide-by-zero. Default: 1e-5.
21
+ """
22
+ size = feat.size()
23
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
24
+ b, c = size[:2]
25
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
26
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
27
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
28
+ return feat_mean, feat_std
29
+
30
+
31
+ def adaptive_instance_normalization(content_feat, style_feat):
32
+ """Adaptive instance normalization.
33
+
34
+ Adjust the reference features to have the similar color and illuminations
35
+ as those in the degradate features.
36
+
37
+ Args:
38
+ content_feat (Tensor): The reference feature.
39
+ style_feat (Tensor): The degradate features.
40
+ """
41
+ size = content_feat.size()
42
+ style_mean, style_std = calc_mean_std(style_feat)
43
+ content_mean, content_std = calc_mean_std(content_feat)
44
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
45
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
46
+
47
+
48
+ class PositionEmbeddingSine(nn.Module):
49
+ """
50
+ This is a more standard version of the position embedding, very similar to the one
51
+ used by the Attention is all you need paper, generalized to work on images.
52
+ """
53
+
54
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
55
+ super().__init__()
56
+ self.num_pos_feats = num_pos_feats
57
+ self.temperature = temperature
58
+ self.normalize = normalize
59
+ if scale is not None and normalize is False:
60
+ raise ValueError("normalize should be True if scale is passed")
61
+ if scale is None:
62
+ scale = 2 * math.pi
63
+ self.scale = scale
64
+
65
+ def forward(self, x, mask=None):
66
+ if mask is None:
67
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
68
+ not_mask = ~mask
69
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
70
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
71
+ if self.normalize:
72
+ eps = 1e-6
73
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
74
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
75
+
76
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
77
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
78
+
79
+ pos_x = x_embed[:, :, :, None] / dim_t
80
+ pos_y = y_embed[:, :, :, None] / dim_t
81
+ pos_x = torch.stack(
82
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
83
+ ).flatten(3)
84
+ pos_y = torch.stack(
85
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
86
+ ).flatten(3)
87
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
88
+ return pos
89
+
90
+ def _get_activation_fn(activation):
91
+ """Return an activation function given a string"""
92
+ if activation == "relu":
93
+ return F.relu
94
+ if activation == "gelu":
95
+ return F.gelu
96
+ if activation == "glu":
97
+ return F.glu
98
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
99
+
100
+
101
+ class TransformerSALayer(nn.Module):
102
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
103
+ super().__init__()
104
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
105
+ # Implementation of Feedforward model - MLP
106
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
107
+ self.dropout = nn.Dropout(dropout)
108
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
109
+
110
+ self.norm1 = nn.LayerNorm(embed_dim)
111
+ self.norm2 = nn.LayerNorm(embed_dim)
112
+ self.dropout1 = nn.Dropout(dropout)
113
+ self.dropout2 = nn.Dropout(dropout)
114
+
115
+ self.activation = _get_activation_fn(activation)
116
+
117
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
118
+ return tensor if pos is None else tensor + pos
119
+
120
+ def forward(self, tgt,
121
+ tgt_mask: Optional[Tensor] = None,
122
+ tgt_key_padding_mask: Optional[Tensor] = None,
123
+ query_pos: Optional[Tensor] = None):
124
+
125
+ # self attention
126
+ tgt2 = self.norm1(tgt)
127
+ q = k = self.with_pos_embed(tgt2, query_pos)
128
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
129
+ key_padding_mask=tgt_key_padding_mask)[0]
130
+ tgt = tgt + self.dropout1(tgt2)
131
+
132
+ # ffn
133
+ tgt2 = self.norm2(tgt)
134
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
135
+ tgt = tgt + self.dropout2(tgt2)
136
+ return tgt
137
+
138
+ class Fuse_sft_block(nn.Module):
139
+ def __init__(self, in_ch, out_ch):
140
+ super().__init__()
141
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
142
+
143
+ self.scale = nn.Sequential(
144
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
145
+ nn.LeakyReLU(0.2, True),
146
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
147
+
148
+ self.shift = nn.Sequential(
149
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
150
+ nn.LeakyReLU(0.2, True),
151
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
152
+
153
+ def forward(self, enc_feat, dec_feat, w=1):
154
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
155
+ scale = self.scale(enc_feat)
156
+ shift = self.shift(enc_feat)
157
+ residual = w * (dec_feat * scale + shift)
158
+ out = dec_feat + residual
159
+ return out
160
+
161
+
162
+ @ARCH_REGISTRY.register()
163
+ class CodeFormer(VQAutoEncoder):
164
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
165
+ codebook_size=1024, latent_size=256,
166
+ connect_list=['32', '64', '128', '256'],
167
+ fix_modules=['quantize','generator']):
168
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
169
+
170
+ if fix_modules is not None:
171
+ for module in fix_modules:
172
+ for param in getattr(self, module).parameters():
173
+ param.requires_grad = False
174
+
175
+ self.connect_list = connect_list
176
+ self.n_layers = n_layers
177
+ self.dim_embd = dim_embd
178
+ self.dim_mlp = dim_embd*2
179
+
180
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
181
+ self.feat_emb = nn.Linear(256, self.dim_embd)
182
+
183
+ # transformer
184
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
185
+ for _ in range(self.n_layers)])
186
+
187
+ # logits_predict head
188
+ self.idx_pred_layer = nn.Sequential(
189
+ nn.LayerNorm(dim_embd),
190
+ nn.Linear(dim_embd, codebook_size, bias=False))
191
+
192
+ self.channels = {
193
+ '16': 512,
194
+ '32': 256,
195
+ '64': 256,
196
+ '128': 128,
197
+ '256': 128,
198
+ '512': 64,
199
+ }
200
+
201
+ # after second residual block for > 16, before attn layer for ==16
202
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
203
+ # after first residual block for > 16, before attn layer for ==16
204
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
205
+
206
+ # fuse_convs_dict
207
+ self.fuse_convs_dict = nn.ModuleDict()
208
+ for f_size in self.connect_list:
209
+ in_ch = self.channels[f_size]
210
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
211
+
212
+ def _init_weights(self, module):
213
+ if isinstance(module, (nn.Linear, nn.Embedding)):
214
+ module.weight.data.normal_(mean=0.0, std=0.02)
215
+ if isinstance(module, nn.Linear) and module.bias is not None:
216
+ module.bias.data.zero_()
217
+ elif isinstance(module, nn.LayerNorm):
218
+ module.bias.data.zero_()
219
+ module.weight.data.fill_(1.0)
220
+
221
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
222
+ # ################### Encoder #####################
223
+ enc_feat_dict = {}
224
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
225
+ for i, block in enumerate(self.encoder.blocks):
226
+ x = block(x)
227
+ if i in out_list:
228
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
229
+
230
+ lq_feat = x
231
+ # ################# Transformer ###################
232
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
233
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
234
+ # BCHW -> BC(HW) -> (HW)BC
235
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
236
+ query_emb = feat_emb
237
+ # Transformer encoder
238
+ for layer in self.ft_layers:
239
+ query_emb = layer(query_emb, query_pos=pos_emb)
240
+
241
+ # output logits
242
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
243
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
244
+
245
+ if code_only: # for training stage II
246
+ # logits doesn't need softmax before cross_entropy loss
247
+ return logits, lq_feat
248
+
249
+ # ################# Quantization ###################
250
+ # if self.training:
251
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
252
+ # # b(hw)c -> bc(hw) -> bchw
253
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
254
+ # ------------
255
+ soft_one_hot = F.softmax(logits, dim=2)
256
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
257
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
258
+ # preserve gradients
259
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
260
+
261
+ if detach_16:
262
+ quant_feat = quant_feat.detach() # for training stage III
263
+ if adain:
264
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
265
+
266
+ # ################## Generator ####################
267
+ x = quant_feat
268
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
269
+
270
+ for i, block in enumerate(self.generator.blocks):
271
+ x = block(x)
272
+ if i in fuse_list: # fuse after i-th block
273
+ f_size = str(x.shape[-1])
274
+ if w>0:
275
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
276
+ out = x
277
+ # logits doesn't need softmax before cross_entropy loss
278
+ return out, logits, lq_feat
modules/codeformer/vqgan_arch.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
2
+
3
+ '''
4
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
5
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
6
+
7
+ '''
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import copy
13
+ from basicsr.utils import get_root_logger
14
+ from basicsr.utils.registry import ARCH_REGISTRY
15
+
16
+ def normalize(in_channels):
17
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
18
+
19
+
20
+ @torch.jit.script
21
+ def swish(x):
22
+ return x*torch.sigmoid(x)
23
+
24
+
25
+ # Define VQVAE classes
26
+ class VectorQuantizer(nn.Module):
27
+ def __init__(self, codebook_size, emb_dim, beta):
28
+ super(VectorQuantizer, self).__init__()
29
+ self.codebook_size = codebook_size # number of embeddings
30
+ self.emb_dim = emb_dim # dimension of embedding
31
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
32
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
33
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
34
+
35
+ def forward(self, z):
36
+ # reshape z -> (batch, height, width, channel) and flatten
37
+ z = z.permute(0, 2, 3, 1).contiguous()
38
+ z_flattened = z.view(-1, self.emb_dim)
39
+
40
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
41
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
42
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
43
+
44
+ mean_distance = torch.mean(d)
45
+ # find closest encodings
46
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
47
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
48
+ # [0-1], higher score, higher confidence
49
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
50
+
51
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
52
+ min_encodings.scatter_(1, min_encoding_indices, 1)
53
+
54
+ # get quantized latent vectors
55
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
56
+ # compute loss for embedding
57
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
58
+ # preserve gradients
59
+ z_q = z + (z_q - z).detach()
60
+
61
+ # perplexity
62
+ e_mean = torch.mean(min_encodings, dim=0)
63
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
64
+ # reshape back to match original input shape
65
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
66
+
67
+ return z_q, loss, {
68
+ "perplexity": perplexity,
69
+ "min_encodings": min_encodings,
70
+ "min_encoding_indices": min_encoding_indices,
71
+ "min_encoding_scores": min_encoding_scores,
72
+ "mean_distance": mean_distance
73
+ }
74
+
75
+ def get_codebook_feat(self, indices, shape):
76
+ # input indices: batch*token_num -> (batch*token_num)*1
77
+ # shape: batch, height, width, channel
78
+ indices = indices.view(-1,1)
79
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
80
+ min_encodings.scatter_(1, indices, 1)
81
+ # get quantized latent vectors
82
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
83
+
84
+ if shape is not None: # reshape back to match original input shape
85
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
86
+
87
+ return z_q
88
+
89
+
90
+ class GumbelQuantizer(nn.Module):
91
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
92
+ super().__init__()
93
+ self.codebook_size = codebook_size # number of embeddings
94
+ self.emb_dim = emb_dim # dimension of embedding
95
+ self.straight_through = straight_through
96
+ self.temperature = temp_init
97
+ self.kl_weight = kl_weight
98
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
99
+ self.embed = nn.Embedding(codebook_size, emb_dim)
100
+
101
+ def forward(self, z):
102
+ hard = self.straight_through if self.training else True
103
+
104
+ logits = self.proj(z)
105
+
106
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
107
+
108
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
109
+
110
+ # + kl divergence to the prior loss
111
+ qy = F.softmax(logits, dim=1)
112
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
113
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
114
+
115
+ return z_q, diff, {
116
+ "min_encoding_indices": min_encoding_indices
117
+ }
118
+
119
+
120
+ class Downsample(nn.Module):
121
+ def __init__(self, in_channels):
122
+ super().__init__()
123
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
124
+
125
+ def forward(self, x):
126
+ pad = (0, 1, 0, 1)
127
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
128
+ x = self.conv(x)
129
+ return x
130
+
131
+
132
+ class Upsample(nn.Module):
133
+ def __init__(self, in_channels):
134
+ super().__init__()
135
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
136
+
137
+ def forward(self, x):
138
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
139
+ x = self.conv(x)
140
+
141
+ return x
142
+
143
+
144
+ class ResBlock(nn.Module):
145
+ def __init__(self, in_channels, out_channels=None):
146
+ super(ResBlock, self).__init__()
147
+ self.in_channels = in_channels
148
+ self.out_channels = in_channels if out_channels is None else out_channels
149
+ self.norm1 = normalize(in_channels)
150
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
151
+ self.norm2 = normalize(out_channels)
152
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
153
+ if self.in_channels != self.out_channels:
154
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
155
+
156
+ def forward(self, x_in):
157
+ x = x_in
158
+ x = self.norm1(x)
159
+ x = swish(x)
160
+ x = self.conv1(x)
161
+ x = self.norm2(x)
162
+ x = swish(x)
163
+ x = self.conv2(x)
164
+ if self.in_channels != self.out_channels:
165
+ x_in = self.conv_out(x_in)
166
+
167
+ return x + x_in
168
+
169
+
170
+ class AttnBlock(nn.Module):
171
+ def __init__(self, in_channels):
172
+ super().__init__()
173
+ self.in_channels = in_channels
174
+
175
+ self.norm = normalize(in_channels)
176
+ self.q = torch.nn.Conv2d(
177
+ in_channels,
178
+ in_channels,
179
+ kernel_size=1,
180
+ stride=1,
181
+ padding=0
182
+ )
183
+ self.k = torch.nn.Conv2d(
184
+ in_channels,
185
+ in_channels,
186
+ kernel_size=1,
187
+ stride=1,
188
+ padding=0
189
+ )
190
+ self.v = torch.nn.Conv2d(
191
+ in_channels,
192
+ in_channels,
193
+ kernel_size=1,
194
+ stride=1,
195
+ padding=0
196
+ )
197
+ self.proj_out = torch.nn.Conv2d(
198
+ in_channels,
199
+ in_channels,
200
+ kernel_size=1,
201
+ stride=1,
202
+ padding=0
203
+ )
204
+
205
+ def forward(self, x):
206
+ h_ = x
207
+ h_ = self.norm(h_)
208
+ q = self.q(h_)
209
+ k = self.k(h_)
210
+ v = self.v(h_)
211
+
212
+ # compute attention
213
+ b, c, h, w = q.shape
214
+ q = q.reshape(b, c, h*w)
215
+ q = q.permute(0, 2, 1)
216
+ k = k.reshape(b, c, h*w)
217
+ w_ = torch.bmm(q, k)
218
+ w_ = w_ * (int(c)**(-0.5))
219
+ w_ = F.softmax(w_, dim=2)
220
+
221
+ # attend to values
222
+ v = v.reshape(b, c, h*w)
223
+ w_ = w_.permute(0, 2, 1)
224
+ h_ = torch.bmm(v, w_)
225
+ h_ = h_.reshape(b, c, h, w)
226
+
227
+ h_ = self.proj_out(h_)
228
+
229
+ return x+h_
230
+
231
+
232
+ class Encoder(nn.Module):
233
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
234
+ super().__init__()
235
+ self.nf = nf
236
+ self.num_resolutions = len(ch_mult)
237
+ self.num_res_blocks = num_res_blocks
238
+ self.resolution = resolution
239
+ self.attn_resolutions = attn_resolutions
240
+
241
+ curr_res = self.resolution
242
+ in_ch_mult = (1,)+tuple(ch_mult)
243
+
244
+ blocks = []
245
+ # initial convultion
246
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
247
+
248
+ # residual and downsampling blocks, with attention on smaller res (16x16)
249
+ for i in range(self.num_resolutions):
250
+ block_in_ch = nf * in_ch_mult[i]
251
+ block_out_ch = nf * ch_mult[i]
252
+ for _ in range(self.num_res_blocks):
253
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
254
+ block_in_ch = block_out_ch
255
+ if curr_res in attn_resolutions:
256
+ blocks.append(AttnBlock(block_in_ch))
257
+
258
+ if i != self.num_resolutions - 1:
259
+ blocks.append(Downsample(block_in_ch))
260
+ curr_res = curr_res // 2
261
+
262
+ # non-local attention block
263
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
264
+ blocks.append(AttnBlock(block_in_ch))
265
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
266
+
267
+ # normalise and convert to latent size
268
+ blocks.append(normalize(block_in_ch))
269
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
270
+ self.blocks = nn.ModuleList(blocks)
271
+
272
+ def forward(self, x):
273
+ for block in self.blocks:
274
+ x = block(x)
275
+
276
+ return x
277
+
278
+
279
+ class Generator(nn.Module):
280
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
281
+ super().__init__()
282
+ self.nf = nf
283
+ self.ch_mult = ch_mult
284
+ self.num_resolutions = len(self.ch_mult)
285
+ self.num_res_blocks = res_blocks
286
+ self.resolution = img_size
287
+ self.attn_resolutions = attn_resolutions
288
+ self.in_channels = emb_dim
289
+ self.out_channels = 3
290
+ block_in_ch = self.nf * self.ch_mult[-1]
291
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
292
+
293
+ blocks = []
294
+ # initial conv
295
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
296
+
297
+ # non-local attention block
298
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
299
+ blocks.append(AttnBlock(block_in_ch))
300
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
301
+
302
+ for i in reversed(range(self.num_resolutions)):
303
+ block_out_ch = self.nf * self.ch_mult[i]
304
+
305
+ for _ in range(self.num_res_blocks):
306
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
307
+ block_in_ch = block_out_ch
308
+
309
+ if curr_res in self.attn_resolutions:
310
+ blocks.append(AttnBlock(block_in_ch))
311
+
312
+ if i != 0:
313
+ blocks.append(Upsample(block_in_ch))
314
+ curr_res = curr_res * 2
315
+
316
+ blocks.append(normalize(block_in_ch))
317
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
318
+
319
+ self.blocks = nn.ModuleList(blocks)
320
+
321
+
322
+ def forward(self, x):
323
+ for block in self.blocks:
324
+ x = block(x)
325
+
326
+ return x
327
+
328
+
329
+ @ARCH_REGISTRY.register()
330
+ class VQAutoEncoder(nn.Module):
331
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
332
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
333
+ super().__init__()
334
+ logger = get_root_logger()
335
+ self.in_channels = 3
336
+ self.nf = nf
337
+ self.n_blocks = res_blocks
338
+ self.codebook_size = codebook_size
339
+ self.embed_dim = emb_dim
340
+ self.ch_mult = ch_mult
341
+ self.resolution = img_size
342
+ self.attn_resolutions = attn_resolutions
343
+ self.quantizer_type = quantizer
344
+ self.encoder = Encoder(
345
+ self.in_channels,
346
+ self.nf,
347
+ self.embed_dim,
348
+ self.ch_mult,
349
+ self.n_blocks,
350
+ self.resolution,
351
+ self.attn_resolutions
352
+ )
353
+ if self.quantizer_type == "nearest":
354
+ self.beta = beta #0.25
355
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
356
+ elif self.quantizer_type == "gumbel":
357
+ self.gumbel_num_hiddens = emb_dim
358
+ self.straight_through = gumbel_straight_through
359
+ self.kl_weight = gumbel_kl_weight
360
+ self.quantize = GumbelQuantizer(
361
+ self.codebook_size,
362
+ self.embed_dim,
363
+ self.gumbel_num_hiddens,
364
+ self.straight_through,
365
+ self.kl_weight
366
+ )
367
+ self.generator = Generator(
368
+ self.nf,
369
+ self.embed_dim,
370
+ self.ch_mult,
371
+ self.n_blocks,
372
+ self.resolution,
373
+ self.attn_resolutions
374
+ )
375
+
376
+ if model_path is not None:
377
+ chkpt = torch.load(model_path, map_location='cpu')
378
+ if 'params_ema' in chkpt:
379
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
380
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
381
+ elif 'params' in chkpt:
382
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
383
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
384
+ else:
385
+ raise ValueError('Wrong params!')
386
+
387
+
388
+ def forward(self, x):
389
+ x = self.encoder(x)
390
+ quant, codebook_loss, quant_stats = self.quantize(x)
391
+ x = self.generator(quant)
392
+ return x, codebook_loss, quant_stats
393
+
394
+
395
+
396
+ # patch based discriminator
397
+ @ARCH_REGISTRY.register()
398
+ class VQGANDiscriminator(nn.Module):
399
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
400
+ super().__init__()
401
+
402
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
403
+ ndf_mult = 1
404
+ ndf_mult_prev = 1
405
+ for n in range(1, n_layers): # gradually increase the number of filters
406
+ ndf_mult_prev = ndf_mult
407
+ ndf_mult = min(2 ** n, 8)
408
+ layers += [
409
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
410
+ nn.BatchNorm2d(ndf * ndf_mult),
411
+ nn.LeakyReLU(0.2, True)
412
+ ]
413
+
414
+ ndf_mult_prev = ndf_mult
415
+ ndf_mult = min(2 ** n_layers, 8)
416
+
417
+ layers += [
418
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
419
+ nn.BatchNorm2d(ndf * ndf_mult),
420
+ nn.LeakyReLU(0.2, True)
421
+ ]
422
+
423
+ layers += [
424
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
425
+ self.main = nn.Sequential(*layers)
426
+
427
+ if model_path is not None:
428
+ chkpt = torch.load(model_path, map_location='cpu')
429
+ if 'params_d' in chkpt:
430
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
431
+ elif 'params' in chkpt:
432
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
433
+ else:
434
+ raise ValueError('Wrong params!')
435
+
436
+ def forward(self, x):
437
+ return self.main(x)