Commit f2a3c57 (parent: 33b550a)
Initial commit

This view is limited to 50 files because it contains too many changes.

Files changed:
- .github/ISSUE_TEMPLATE/bug_report.yml +83 -0
- .github/ISSUE_TEMPLATE/config.yml +5 -0
- .github/ISSUE_TEMPLATE/feature_request.yml +40 -0
- .github/PULL_REQUEST_TEMPLATE/pull_request_template.md +28 -0
- .github/workflows/on_pull_request.yaml +42 -0
- .github/workflows/run_tests.yaml +31 -0
- .gitignore +34 -0
- .pylintrc +3 -0
- CODEOWNERS +12 -0
- README.md +151 -3
- artists.csv +0 -0
- configs/alt-diffusion-inference.yaml +72 -0
- configs/v1-inference.yaml +70 -0
- environment-wsl2.yaml +11 -0
- extensions-builtin/LDSR/ldsr_model_arch.py +256 -0
- extensions-builtin/LDSR/preload.py +6 -0
- extensions-builtin/LDSR/scripts/ldsr_model.py +69 -0
- extensions-builtin/LDSR/sd_hijack_autoencoder.py +286 -0
- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +1449 -0
- extensions-builtin/ScuNET/preload.py +6 -0
- extensions-builtin/ScuNET/scripts/scunet_model.py +87 -0
- extensions-builtin/ScuNET/scunet_model_arch.py +265 -0
- extensions-builtin/SwinIR/preload.py +6 -0
- extensions-builtin/SwinIR/scripts/swinir_model.py +172 -0
- extensions-builtin/SwinIR/swinir_model_arch.py +867 -0
- extensions-builtin/SwinIR/swinir_model_arch_v2.py +1017 -0
- extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +107 -0
- javascript/aspectRatioOverlay.js +108 -0
- javascript/contextMenus.js +177 -0
- javascript/dragdrop.js +89 -0
- javascript/edit-attention.js +75 -0
- javascript/extensions.js +35 -0
- javascript/generationParams.js +33 -0
- javascript/hints.js +136 -0
- javascript/imageMaskFix.js +45 -0
- javascript/imageParams.js +19 -0
- javascript/imageviewer.js +276 -0
- javascript/localization.js +167 -0
- javascript/notification.js +49 -0
- javascript/progressbar.js +149 -0
- javascript/textualInversion.js +8 -0
- javascript/ui.js +222 -0
- launch.py +295 -0
- localizations/Put localization files here.txt +0 -0
- modules/api/api.py +418 -0
- modules/api/models.py +251 -0
- modules/artists.py +25 -0
- modules/call_queue.py +98 -0
- modules/codeformer/codeformer_arch.py +278 -0
- modules/codeformer/vqgan_arch.py +437 -0
.github/ISSUE_TEMPLATE/bug_report.yml
ADDED
@@ -0,0 +1,83 @@
+name: Bug Report
+description: You think somethings is broken in the UI
+title: "[Bug]: "
+labels: ["bug-report"]
+
+body:
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+      options:
+        - label: I have searched the existing issues and checked the recent builds/commits
+          required: true
+  - type: markdown
+    attributes:
+      value: |
+        *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
+  - type: textarea
+    id: what-did
+    attributes:
+      label: What happened?
+      description: Tell us what happened in a very clear and simple way
+    validations:
+      required: true
+  - type: textarea
+    id: steps
+    attributes:
+      label: Steps to reproduce the problem
+      description: Please provide us with precise step by step information on how to reproduce the bug
+      value: |
+        1. Go to ....
+        2. Press ....
+        3. ...
+    validations:
+      required: true
+  - type: textarea
+    id: what-should
+    attributes:
+      label: What should have happened?
+      description: tell what you think the normal behavior should be
+    validations:
+      required: true
+  - type: input
+    id: commit
+    attributes:
+      label: Commit where the problem happens
+      description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+    validations:
+      required: true
+  - type: dropdown
+    id: platforms
+    attributes:
+      label: What platforms do you use to access UI ?
+      multiple: true
+      options:
+        - Windows
+        - Linux
+        - MacOS
+        - iOS
+        - Android
+        - Other/Cloud
+  - type: dropdown
+    id: browsers
+    attributes:
+      label: What browsers do you use to access the UI ?
+      multiple: true
+      options:
+        - Mozilla Firefox
+        - Google Chrome
+        - Brave
+        - Apple Safari
+        - Microsoft Edge
+  - type: textarea
+    id: cmdargs
+    attributes:
+      label: Command Line Arguments
+      description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
+      render: Shell
+  - type: textarea
+    id: misc
+    attributes:
+      label: Additional information, context and logs
+      description: Please provide us with any relevant additional info, context or log output.
.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1,5 @@
+blank_issues_enabled: false
+contact_links:
+  - name: WebUI Community Support
+    url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
+    about: Please ask and answer questions here.
.github/ISSUE_TEMPLATE/feature_request.yml
ADDED
@@ -0,0 +1,40 @@
+name: Feature request
+description: Suggest an idea for this project
+title: "[Feature Request]: "
+labels: ["suggestion"]
+
+body:
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
+      options:
+        - label: I have searched the existing issues and checked the recent builds/commits
+          required: true
+  - type: markdown
+    attributes:
+      value: |
+        *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
+  - type: textarea
+    id: feature
+    attributes:
+      label: What would your feature do ?
+      description: Tell us about your feature in a very clear and simple way, and what problem it would solve
+    validations:
+      required: true
+  - type: textarea
+    id: workflow
+    attributes:
+      label: Proposed workflow
+      description: Please provide us with step by step information on how you'd like the feature to be accessed and used
+      value: |
+        1. Go to ....
+        2. Press ....
+        3. ...
+    validations:
+      required: true
+  - type: textarea
+    id: misc
+    attributes:
+      label: Additional information
+      description: Add any other context or screenshots about the feature request here.
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
ADDED
@@ -0,0 +1,28 @@
+# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
+
+If you have a large change, pay special attention to this paragraph:
+
+> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
+
+Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
+
+**Describe what this pull request is trying to achieve.**
+
+A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
+
+**Additional notes and description of your changes**
+
+More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
+
+**Environment this was tested in**
+
+List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
+ - OS: [e.g. Windows, Linux]
+ - Browser [e.g. chrome, safari]
+ - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
+
+**Screenshots or videos of your changes**
+
+If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
+
+This is **required** for anything that touches the user interface.
.github/workflows/on_pull_request.yaml
ADDED
@@ -0,0 +1,42 @@
+# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
+name: Run Linting/Formatting on Pull Requests
+
+on:
+  - push
+  - pull_request
+  # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
+  # if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
+  # pull_request:
+  #   branches:
+  #     - master
+  #   branches-ignore:
+  #     - development
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v3
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v3
+        with:
+          python-version: 3.10.6
+      - uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-
+      - name: Install PyLint
+        run: |
+          python -m pip install --upgrade pip
+          pip install pylint
+      # This lets PyLint check to see if it can resolve imports
+      - name: Install dependencies
+        run : |
+          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
+          python launch.py
+      - name: Analysing the code with pylint
+        run: |
+          pylint $(git ls-files '*.py')
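The lint step can be reproduced locally the same way the workflow does it: run `python launch.py` once with `COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"` so that the repositories and dependencies PyLint needs for import resolution are installed, then run `pylint $(git ls-files '*.py')` from the repo root.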
.github/workflows/run_tests.yaml
ADDED
@@ -0,0 +1,31 @@
+name: Run basic features tests on CPU with empty SD model
+
+on:
+  - push
+  - pull_request
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v3
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.10.6
+      - uses: actions/cache@v3
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
+          restore-keys: ${{ runner.os }}-pip-
+      - name: Run tests
+        run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+      - name: Upload main app stdout-stderr
+        uses: actions/upload-artifact@v3
+        if: always()
+        with:
+          name: stdout-stderr
+          path: |
+            test/stdout.txt
+            test/stderr.txt
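The same test run works locally with the command from the "Run tests" step; the app's output lands in `test/stdout.txt` and `test/stderr.txt`, which the workflow uploads as an artifact and the `.gitignore` below keeps out of the repository.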
.gitignore
ADDED
@@ -0,0 +1,34 @@
+__pycache__
+*.ckpt
+*.safetensors
+*.pth
+/ESRGAN/*
+/SwinIR/*
+/repositories
+/venv
+/tmp
+/model.ckpt
+/models/**/*
+/GFPGANv1.3.pth
+/gfpgan/weights/*.pth
+/ui-config.json
+/outputs
+/config.json
+/log
+/webui.settings.bat
+/embeddings
+/styles.csv
+/params.txt
+/styles.csv.bak
+/webui-user.bat
+/webui-user.sh
+/interrogate
+/user.css
+/.idea
+notification.mp3
+/SwinIR
+/textual_inversion
+.vscode
+/extensions
+/test/stdout.txt
+/test/stderr.txt
.pylintrc
ADDED
@@ -0,0 +1,3 @@
+# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
+[MESSAGES CONTROL]
+disable=C,R,W,E,I
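This disables every standard pylint message category (C = convention, R = refactor, W = warning, E = error, I = informational), so the lint workflow above effectively only fails on fatal, F-category problems such as files pylint cannot parse at all.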
CODEOWNERS
ADDED
@@ -0,0 +1,12 @@
+* @AUTOMATIC1111
+
+# if you were managing a localization and were removed from this file, this is because
+# the intended way to do localizations now is via extensions. See:
+# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
+# Make a repo with your localization and since you are still listed as a collaborator
+# you can add it to the wiki page yourself. This change is because some people complained
+# the git commit log is cluttered with things unrelated to almost everyone and
+# because I believe this is the best overall for the project to handle localizations almost
+# entirely without my oversight.
+
+
README.md
CHANGED
@@ -1,3 +1,151 @@
-
-
-
+# Stable Diffusion web UI
+A browser interface based on Gradio library for Stable Diffusion.
+
+![](txt2img_Screenshot.png)
+
+Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
+
+## Features
+[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
+- Original txt2img and img2img modes
+- One click install and run script (but you still must install python and git)
+- Outpainting
+- Inpainting
+- Color Sketch
+- Prompt Matrix
+- Stable Diffusion Upscale
+- Attention, specify parts of text that the model should pay more attention to
+  - a man in a ((tuxedo)) - will pay more attention to tuxedo
+  - a man in a (tuxedo:1.21) - alternative syntax
+  - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
+- Loopback, run img2img processing multiple times
+- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
+- Textual Inversion
+  - have as many embeddings as you want and use any names you like for them
+  - use multiple embeddings with different numbers of vectors per token
+  - works with half precision floating point numbers
+  - train embeddings on 8GB (also reports of 6GB working)
+- Extras tab with:
+  - GFPGAN, neural network that fixes faces
+  - CodeFormer, face restoration tool as an alternative to GFPGAN
+  - RealESRGAN, neural network upscaler
+  - ESRGAN, neural network upscaler with a lot of third party models
+  - SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
+  - LDSR, Latent diffusion super resolution upscaling
+- Resizing aspect ratio options
+- Sampling method selection
+  - Adjust sampler eta values (noise multiplier)
+  - More advanced noise setting options
+- Interrupt processing at any time
+- 4GB video card support (also reports of 2GB working)
+- Correct seeds for batches
+- Live prompt token length validation
+- Generation parameters
+  - parameters you used to generate images are saved with that image
+  - in PNG chunks for PNG, in EXIF for JPEG
+  - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
+  - can be disabled in settings
+  - drag and drop an image/text-parameters to promptbox
+- Read Generation Parameters Button, loads parameters in promptbox to UI
+- Settings page
+- Running arbitrary python code from UI (must run with --allow-code to enable)
+- Mouseover hints for most UI elements
+- Possible to change defaults/mix/max/step values for UI elements via text config
+- Random artist button
+- Tiling support, a checkbox to create images that can be tiled like textures
+- Progress bar and live image generation preview
+- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
+- Styles, a way to save part of prompt and easily apply them via dropdown later
+- Variations, a way to generate same image but with tiny differences
+- Seed resizing, a way to generate same image but at slightly different resolution
+- CLIP interrogator, a button that tries to guess prompt from an image
+- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
+- Batch Processing, process a group of files using img2img
+- Img2img Alternative, reverse Euler method of cross attention control
+- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
+- Reloading checkpoints on the fly
+- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
+- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
+- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
+  - separate prompts using uppercase `AND`
+  - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
+- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
+- DeepDanbooru integration, creates danbooru style tags for anime prompts
+- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
+- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
+- Generate forever option
+- Training tab
+  - hypernetworks and embeddings options
+  - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
+- Clip skip
+- Use Hypernetworks
+- Use VAEs
+- Estimated completion time in progress bar
+- API
+- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
+- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
+- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
+
+## Installation and Running
+Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
+
+Alternatively, use online services (like Google Colab):
+
+- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
+
+### Automatic Installation on Windows
+1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
+2. Install [git](https://git-scm.com/download/win).
+3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
+4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
+5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
+6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
+
+### Automatic Installation on Linux
+1. Install the dependencies:
+```bash
+# Debian-based:
+sudo apt install wget git python3 python3-venv
+# Red Hat-based:
+sudo dnf install wget git python3
+# Arch-based:
+sudo pacman -S wget git python3
+```
+2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
+```bash
+bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
+```
+
+### Installation on Apple Silicon
+
+Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
+
+## Contributing
+Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
+
+## Documentation
+The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
+
+## Credits
+- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
+- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
+- GFPGAN - https://github.com/TencentARC/GFPGAN.git
+- CodeFormer - https://github.com/sczhou/CodeFormer
+- ESRGAN - https://github.com/xinntao/ESRGAN
+- SwinIR - https://github.com/JingyunLiang/SwinIR
+- Swin2SR - https://github.com/mv-lab/swin2sr
+- LDSR - https://github.com/Hafiidz/latent-diffusion
+- MiDaS - https://github.com/isl-org/MiDaS
+- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
+- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
+- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
+- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
+- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
+- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
+- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
+- xformers - https://github.com/facebookresearch/xformers
+- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Security advice - RyotaK
+- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
+- (You)
artists.csv
ADDED
The diff for this file is too large to render.
configs/alt-diffusion-inference.yaml
ADDED
@@ -0,0 +1,72 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: modules.xlmr.BertSeriesModelWithTransformation
+      params:
+        name: "XLMR-Large"
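Compared with `configs/v1-inference.yaml` below, the only difference is the `cond_stage_config`: the AltDiffusion config conditions on the multilingual `modules.xlmr.BertSeriesModelWithTransformation` ("XLMR-Large") text encoder instead of `FrozenCLIPEmbedder`.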
configs/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
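These inference configs are consumed through OmegaConf and `instantiate_from_config`, the same mechanism the LDSR code later in this commit uses. A minimal sketch of that pattern, assuming the `ldm` package from the CompVis/Stability repositories is importable and that a matching checkpoint exists at the hypothetical path `model.ckpt`:

```python
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Load the YAML and build the LatentDiffusion model named by its `target` key.
config = OmegaConf.load("configs/v1-inference.yaml")
model = instantiate_from_config(config.model)

# Load SD v1 weights into it (the path is an assumption, not part of this commit).
state = torch.load("model.ckpt", map_location="cpu")
model.load_state_dict(state.get("state_dict", state), strict=False)
model.eval()
```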
environment-wsl2.yaml
ADDED
@@ -0,0 +1,11 @@
+name: automatic
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - python=3.10
+  - pip=22.2.2
+  - cudatoolkit=11.3
+  - pytorch=1.12.1
+  - torchvision=0.13.1
+  - numpy=1.23.1
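This WSL2 environment file is meant for conda; it can be applied with `conda env create -f environment-wsl2.yaml`, which creates the `automatic` environment named on its first line with the pinned PyTorch/CUDA stack.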
extensions-builtin/LDSR/ldsr_model_arch.py
ADDED
@@ -0,0 +1,256 @@
+import os
+import gc
+import time
+import warnings
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from einops import rearrange, repeat
+from omegaconf import OmegaConf
+import safetensors.torch
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.util import instantiate_from_config, ismap
+from modules import shared, sd_hijack
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+cached_ldsr_model: torch.nn.Module = None
+
+
+# Create LDSR Class
+class LDSR:
+    def load_model_from_config(self, half_attention):
+        global cached_ldsr_model
+
+        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
+            print("Loading model from cache")
+            model: torch.nn.Module = cached_ldsr_model
+        else:
+            print(f"Loading model from {self.modelPath}")
+            _, extension = os.path.splitext(self.modelPath)
+            if extension.lower() == ".safetensors":
+                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
+            else:
+                pl_sd = torch.load(self.modelPath, map_location="cpu")
+            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
+            config = OmegaConf.load(self.yamlPath)
+            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
+            model: torch.nn.Module = instantiate_from_config(config.model)
+            model.load_state_dict(sd, strict=False)
+            model = model.to(shared.device)
+            if half_attention:
+                model = model.half()
+            if shared.cmd_opts.opt_channelslast:
+                model = model.to(memory_format=torch.channels_last)
+
+            sd_hijack.model_hijack.hijack(model)  # apply optimization
+            model.eval()
+
+            if shared.opts.ldsr_cached:
+                cached_ldsr_model = model
+
+        return {"model": model}
+
+    def __init__(self, model_path, yaml_path):
+        self.modelPath = model_path
+        self.yamlPath = yaml_path
+
+    @staticmethod
+    def run(model, selected_path, custom_steps, eta):
+        example = get_cond(selected_path)
+
+        n_runs = 1
+        guider = None
+        ckwargs = None
+        ddim_use_x0_pred = False
+        temperature = 1.
+        eta = eta
+        custom_shape = None
+
+        height, width = example["image"].shape[1:3]
+        split_input = height >= 128 and width >= 128
+
+        if split_input:
+            ks = 128
+            stride = 64
+            vqf = 4  #
+            model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
+                                        "vqf": vqf,
+                                        "patch_distributed_vq": True,
+                                        "tie_braker": False,
+                                        "clip_max_weight": 0.5,
+                                        "clip_min_weight": 0.01,
+                                        "clip_max_tie_weight": 0.5,
+                                        "clip_min_tie_weight": 0.01}
+        else:
+            if hasattr(model, "split_input_params"):
+                delattr(model, "split_input_params")
+
+        x_t = None
+        logs = None
+        for n in range(n_runs):
+            if custom_shape is not None:
+                x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
+                x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
+
+            logs = make_convolutional_sample(example, model,
+                                             custom_steps=custom_steps,
+                                             eta=eta, quantize_x0=False,
+                                             custom_shape=custom_shape,
+                                             temperature=temperature, noise_dropout=0.,
+                                             corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
+                                             ddim_use_x0_pred=ddim_use_x0_pred
+                                             )
+        return logs
+
+    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
+        model = self.load_model_from_config(half_attention)
+
+        # Run settings
+        diffusion_steps = int(steps)
+        eta = 1.0
+
+        down_sample_method = 'Lanczos'
+
+        gc.collect()
+        if torch.cuda.is_available:
+            torch.cuda.empty_cache()
+
+        im_og = image
+        width_og, height_og = im_og.size
+        # If we can adjust the max upscale size, then the 4 below should be our variable
+        down_sample_rate = target_scale / 4
+        wd = width_og * down_sample_rate
+        hd = height_og * down_sample_rate
+        width_downsampled_pre = int(np.ceil(wd))
+        height_downsampled_pre = int(np.ceil(hd))
+
+        if down_sample_rate != 1:
+            print(
+                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
+            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
+        else:
+            print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
+
+        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
+        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
+        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+
+        logs = self.run(model["model"], im_padded, diffusion_steps, eta)
+
+        sample = logs["sample"]
+        sample = sample.detach().cpu()
+        sample = torch.clamp(sample, -1., 1.)
+        sample = (sample + 1.) / 2. * 255
+        sample = sample.numpy().astype(np.uint8)
+        sample = np.transpose(sample, (0, 2, 3, 1))
+        a = Image.fromarray(sample[0])
+
+        # remove padding
+        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
+
+        del model
+        gc.collect()
+        if torch.cuda.is_available:
+            torch.cuda.empty_cache()
+
+        return a
+
+
+def get_cond(selected_path):
+    example = dict()
+    up_f = 4
+    c = selected_path.convert('RGB')
+    c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
+    c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
+                                                    antialias=True)
+    c_up = rearrange(c_up, '1 c h w -> 1 h w c')
+    c = rearrange(c, '1 c h w -> 1 h w c')
+    c = 2. * c - 1.
+
+    c = c.to(shared.device)
+    example["LR_image"] = c
+    example["image"] = c_up
+
+    return example
+
+
+@torch.no_grad()
+def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
+                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
+                    corrector_kwargs=None, x_t=None
+                    ):
+    ddim = DDIMSampler(model)
+    bs = shape[0]
+    shape = shape[1:]
+    print(f"Sampling with eta = {eta}; steps: {steps}")
+    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
+                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
+                                         mask=mask, x0=x0, temperature=temperature, verbose=False,
+                                         score_corrector=score_corrector,
+                                         corrector_kwargs=corrector_kwargs, x_t=x_t)
+
+    return samples, intermediates
+
+
+@torch.no_grad()
+def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
+                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
+    log = dict()
+
+    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
+                                        return_first_stage_outputs=True,
+                                        force_c_encode=not (hasattr(model, 'split_input_params')
+                                                            and model.cond_stage_key == 'coordinates_bbox'),
+                                        return_original_cond=True)
+
+    if custom_shape is not None:
+        z = torch.randn(custom_shape)
+        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
+
+    z0 = None
+
+    log["input"] = x
+    log["reconstruction"] = xrec
+
+    if ismap(xc):
+        log["original_conditioning"] = model.to_rgb(xc)
+        if hasattr(model, 'cond_stage_key'):
+            log[model.cond_stage_key] = model.to_rgb(xc)
+
+    else:
+        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
+        if model.cond_stage_model:
+            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
+            if model.cond_stage_key == 'class_label':
+                log[model.cond_stage_key] = xc[model.cond_stage_key]
+
+    with model.ema_scope("Plotting"):
+        t0 = time.time()
+
+        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
+                                                eta=eta,
+                                                quantize_x0=quantize_x0, mask=None, x0=z0,
+                                                temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
+                                                x_t=x_T)
+        t1 = time.time()
+
+    if ddim_use_x0_pred:
+        sample = intermediates['pred_x0'][-1]
+
+    x_sample = model.decode_first_stage(sample)
+
+    try:
+        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
+        log["sample_noquant"] = x_sample_noquant
+        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
+    except:
+        pass
+
+    log["sample"] = x_sample
+    log["time"] = t1 - t0
+
+    return log
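The `scripts/ldsr_model.py` upscaler below drives this class through the web UI's Extras tab. A minimal direct-usage sketch, assuming the web UI's `modules.shared` has been initialized and that the checkpoint/YAML paths (both hypothetical names here) point at a downloaded LDSR model and its project.yaml:

```python
from PIL import Image

# ldsr_model_arch is the module added above; the file paths are assumptions.
from ldsr_model_arch import LDSR

upscaler = LDSR("models/LDSR/model.ckpt", "models/LDSR/project.yaml")
lowres = Image.open("input.png")

# 100 DDIM steps, 2x target scale, full-precision attention.
result = upscaler.super_resolution(lowres, steps=100, target_scale=2, half_attention=False)
result.save("output.png")
```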
extensions-builtin/LDSR/preload.py
ADDED
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
extensions-builtin/LDSR/scripts/ldsr_model.py
ADDED
@@ -0,0 +1,69 @@
+import os
+import sys
+import traceback
+
+from basicsr.utils.download_util import load_file_from_url
+
+from modules.upscaler import Upscaler, UpscalerData
+from ldsr_model_arch import LDSR
+from modules import shared, script_callbacks
+import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+
+
+class UpscalerLDSR(Upscaler):
+    def __init__(self, user_path):
+        self.name = "LDSR"
+        self.user_path = user_path
+        self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
+        self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
+        super().__init__()
+        scaler_data = UpscalerData("LDSR", None, self)
+        self.scalers = [scaler_data]
+
+    def load_model(self, path: str):
+        # Remove incorrect project.yaml file if too big
+        yaml_path = os.path.join(self.model_path, "project.yaml")
+        old_model_path = os.path.join(self.model_path, "model.pth")
+        new_model_path = os.path.join(self.model_path, "model.ckpt")
+        safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
+        if os.path.exists(yaml_path):
+            statinfo = os.stat(yaml_path)
+            if statinfo.st_size >= 10485760:
+                print("Removing invalid LDSR YAML file.")
+                os.remove(yaml_path)
+        if os.path.exists(old_model_path):
+            print("Renaming model from model.pth to model.ckpt")
+            os.rename(old_model_path, new_model_path)
+        if os.path.exists(safetensors_model_path):
+            model = safetensors_model_path
+        else:
+            model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+                                       file_name="model.ckpt", progress=True)
+        yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
+                                  file_name="project.yaml", progress=True)
+
+        try:
+            return LDSR(model, yaml)
+
+        except Exception:
+            print("Error importing LDSR:", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+        return None
+
+    def do_upscale(self, img, path):
+        ldsr = self.load_model(path)
+        if ldsr is None:
+            print("NO LDSR!")
+            return img
+        ddim_steps = shared.opts.ldsr_steps
+        return ldsr.super_resolution(img, ddim_steps, self.scale)
+
+
+def on_ui_settings():
+    import gradio as gr
+
+    shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+    shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
extensions-builtin/LDSR/sd_hijack_autoencoder.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
2 |
+
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
3 |
+
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
10 |
+
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
11 |
+
from ldm.util import instantiate_from_config
|
12 |
+
|
13 |
+
import ldm.models.autoencoder
|
14 |
+
|
15 |
+
class VQModel(pl.LightningModule):
|
16 |
+
def __init__(self,
|
17 |
+
ddconfig,
|
18 |
+
lossconfig,
|
19 |
+
n_embed,
|
20 |
+
embed_dim,
|
21 |
+
ckpt_path=None,
|
22 |
+
ignore_keys=[],
|
23 |
+
image_key="image",
|
24 |
+
colorize_nlabels=None,
|
25 |
+
monitor=None,
|
26 |
+
batch_resize_range=None,
|
27 |
+
scheduler_config=None,
|
28 |
+
lr_g_factor=1.0,
|
29 |
+
remap=None,
|
30 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
31 |
+
use_ema=False
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
self.embed_dim = embed_dim
|
35 |
+
self.n_embed = n_embed
|
36 |
+
self.image_key = image_key
|
37 |
+
self.encoder = Encoder(**ddconfig)
|
38 |
+
self.decoder = Decoder(**ddconfig)
|
39 |
+
self.loss = instantiate_from_config(lossconfig)
|
40 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
41 |
+
remap=remap,
|
42 |
+
sane_index_shape=sane_index_shape)
|
43 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
44 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
45 |
+
if colorize_nlabels is not None:
|
46 |
+
assert type(colorize_nlabels)==int
|
47 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
48 |
+
if monitor is not None:
|
49 |
+
self.monitor = monitor
|
50 |
+
self.batch_resize_range = batch_resize_range
|
51 |
+
if self.batch_resize_range is not None:
|
52 |
+
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
53 |
+
|
54 |
+
self.use_ema = use_ema
|
55 |
+
if self.use_ema:
|
56 |
+
self.model_ema = LitEma(self)
|
57 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
58 |
+
|
59 |
+
if ckpt_path is not None:
|
60 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
61 |
+
self.scheduler_config = scheduler_config
|
62 |
+
self.lr_g_factor = lr_g_factor
|
63 |
+
|
64 |
+
@contextmanager
|
65 |
+
def ema_scope(self, context=None):
|
66 |
+
if self.use_ema:
|
67 |
+
self.model_ema.store(self.parameters())
|
68 |
+
self.model_ema.copy_to(self)
|
69 |
+
if context is not None:
|
70 |
+
print(f"{context}: Switched to EMA weights")
|
71 |
+
try:
|
72 |
+
yield None
|
73 |
+
finally:
|
74 |
+
if self.use_ema:
|
75 |
+
self.model_ema.restore(self.parameters())
|
76 |
+
if context is not None:
|
77 |
+
print(f"{context}: Restored training weights")
|
78 |
+
|
79 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
80 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
81 |
+
keys = list(sd.keys())
|
82 |
+
for k in keys:
|
83 |
+
for ik in ignore_keys:
|
84 |
+
if k.startswith(ik):
|
85 |
+
print("Deleting key {} from state_dict.".format(k))
|
86 |
+
del sd[k]
|
87 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
88 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
89 |
+
if len(missing) > 0:
|
90 |
+
print(f"Missing Keys: {missing}")
|
91 |
+
print(f"Unexpected Keys: {unexpected}")
|
92 |
+
|
93 |
+
def on_train_batch_end(self, *args, **kwargs):
|
94 |
+
if self.use_ema:
|
95 |
+
self.model_ema(self)
|
96 |
+
|
97 |
+
def encode(self, x):
|
98 |
+
h = self.encoder(x)
|
99 |
+
h = self.quant_conv(h)
|
100 |
+
quant, emb_loss, info = self.quantize(h)
|
101 |
+
return quant, emb_loss, info
|
102 |
+
|
103 |
+
def encode_to_prequant(self, x):
|
104 |
+
h = self.encoder(x)
|
105 |
+
h = self.quant_conv(h)
|
106 |
+
return h
|
107 |
+
|
108 |
+
def decode(self, quant):
|
109 |
+
quant = self.post_quant_conv(quant)
|
110 |
+
dec = self.decoder(quant)
|
111 |
+
return dec
|
112 |
+
|
113 |
+
def decode_code(self, code_b):
|
114 |
+
quant_b = self.quantize.embed_code(code_b)
|
115 |
+
dec = self.decode(quant_b)
|
116 |
+
return dec
|
117 |
+
|
118 |
+
def forward(self, input, return_pred_indices=False):
|
119 |
+
quant, diff, (_,_,ind) = self.encode(input)
|
120 |
+
dec = self.decode(quant)
|
121 |
+
if return_pred_indices:
|
122 |
+
return dec, diff, ind
|
123 |
+
return dec, diff
|
124 |
+
|
125 |
+
def get_input(self, batch, k):
|
126 |
+
x = batch[k]
|
127 |
+
if len(x.shape) == 3:
|
128 |
+
x = x[..., None]
|
129 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
130 |
+
if self.batch_resize_range is not None:
|
131 |
+
lower_size = self.batch_resize_range[0]
|
132 |
+
upper_size = self.batch_resize_range[1]
|
133 |
+
if self.global_step <= 4:
|
134 |
+
# do the first few batches with max size to avoid later oom
|
135 |
+
new_resize = upper_size
|
136 |
+
else:
|
137 |
+
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
138 |
+
if new_resize != x.shape[2]:
|
139 |
+
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
140 |
+
x = x.detach()
|
141 |
+
return x
|
142 |
+
|
143 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
144 |
+
# https://github.com/pytorch/pytorch/issues/37142
|
145 |
+
# try not to fool the heuristics
|
146 |
+
x = self.get_input(batch, self.image_key)
|
147 |
+
xrec, qloss, ind = self(x, return_pred_indices=True)
|
148 |
+
|
149 |
+
        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


setattr(ldm.models.autoencoder, "VQModel", VQModel)
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
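Editor's note: the two `setattr` calls above are the point of this hijack module: they re-register the V1 `VQModel`/`VQModelInterface` classes on `ldm.models.autoencoder`, so that code imported afterwards (such as the `from ldm.models.autoencoder import VQModelInterface` in the next file) picks up the patched classes. The sketch below shows the same module-patching pattern on a toy stand-in module so it runs without the `ldm` package; everything except the `setattr` idiom itself is hypothetical and purely illustrative.

import sys, types

toy = types.ModuleType("toy_autoencoder")      # stand-in for ldm.models.autoencoder
sys.modules["toy_autoencoder"] = toy

class PatchedVQModel:                          # stand-in for the VQModel defined above
    pass

setattr(toy, "VQModel", PatchedVQModel)

# Any module imported afterwards that does `from toy_autoencoder import VQModel`
# now receives the patched class:
from toy_autoencoder import VQModel
assert VQModel is PatchedVQModel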
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
ADDED
@@ -0,0 +1,1449 @@
# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo)
# Original filename: ldm/models/diffusion/ddpm.py
# The purpose is to reinstate the old DDPM logic, which works with VQ, whereas the V2 one doesn't
# Some models such as LDSR require VQ to work correctly
# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only

from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler

import ldm.models.diffusion.ddpm

__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def uniform_on_device(r1, r2, shape, device):
    return (r1 - r2) * torch.rand(*shape, device=device) + r2


class DDPMV1(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapperV1(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)


    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                    1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                        2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_intermediates=False):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, batch_size=16, return_intermediates=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                  return_intermediates=return_intermediates)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def shared_step(self, batch):
        x = self.get_input(batch, self.first_stage_key)
        loss, loss_dict = self(x)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        n_imgs_per_row = len(samples)
        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.learn_logvar:
            params = params + [self.logvar]
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

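# --- Editor's illustrative sketch (not part of the upstream file) ----------------------
# register_schedule() above turns a beta schedule into the closed-form DDPM buffers.
# The helper below redoes the core arithmetic in plain NumPy for a simple linear
# schedule (the endpoints match the DDPMV1 defaults); the linspace is only a stand-in
# for make_beta_schedule(), and the function is never called by this module.
def _ddpm_schedule_sketch(timesteps=1000, linear_start=1e-4, linear_end=2e-2, v_posterior=0.):
    betas = np.linspace(linear_start, linear_end, timesteps)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
    # q(x_t | x_0) coefficients used by q_sample():
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
    # posterior variance of q(x_{t-1} | x_t, x_0), interpolated via v_posterior as above:
    posterior_variance = (1 - v_posterior) * betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + v_posterior * betas
    return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, posterior_variance
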
class LatentDiffusionV1(DDPMV1):
    """main class"""
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True

    def make_cond_schedule(self, ):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")

    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(self.decode_first_stage(zd.to(self.device),
                                                       force_not_quantize=force_no_decoder_quantization))
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

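    # --- Editor's illustrative sketch (not part of the upstream file) ------------------
    # get_first_stage_encoding() returns scale_factor * z, and on_train_batch_start()
    # optionally sets scale_factor = 1 / std(z) on the very first batch, so the latents
    # handed to the diffusion model have roughly unit variance. The helper below shows
    # that effect on a random stand-in tensor; it is never called by this module.
    def _std_rescaling_sketch(self):
        z = 5. * torch.randn(4, 3, 32, 32)         # pretend first-stage encodings with a wide spread
        scale_factor = 1. / z.flatten().std()
        z_rescaled = scale_factor * z               # what get_first_stage_encoding() would return
        return z.flatten().std(), z_rescaled.flatten().std()   # roughly 5.0 -> 1.0
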
    def meshgrid(self, h, w):
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

        arr = torch.cat([y, x], dim=-1)
        return arr

    def delta_border(self, h, w):
        """
        :param h: height
        :param w: width
        :return: normalized distance to image border,
         with min distance = 0 at border and max dist = 0.5 at image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
        return edge_dist

    def get_weighting(self, h, w, Ly, Lx, device):
        weighting = self.delta_border(h, w)
        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
                               self.split_input_params["clip_max_weight"], )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(L_weighting,
                                     self.split_input_params["clip_min_tie_weight"],
                                     self.split_input_params["clip_max_tie_weight"])

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting

    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                                dilation=1, padding=0,
                                stride=(stride[0] * uf, stride[1] * uf))
            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

        elif df > 1 and uf == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                                dilation=1, padding=0,
                                stride=(stride[0] // df, stride[1] // df))
            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting

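    # --- Editor's illustrative sketch (not part of the upstream file) ------------------
    # get_fold_unfold() above pairs torch.nn.Unfold/Fold so a tensor can be cut into
    # overlapping crops, processed per crop, and stitched back together; dividing by a
    # normalization map averages the overlaps instead of summing them. The helper below
    # shows the bare round trip for the uf == df == 1 case with uniform weighting (the
    # real code substitutes the smooth border weighting from get_weighting()); it is
    # never called by this module.
    def _fold_unfold_sketch(self):
        x = torch.randn(1, 3, 64, 64)
        fold_params = dict(kernel_size=(32, 32), dilation=1, padding=0, stride=(16, 16))
        unfold = torch.nn.Unfold(**fold_params)
        fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
        crops = unfold(x)                                   # (1, 3*32*32, L) columns, one per crop
        normalization = fold(unfold(torch.ones_like(x)))    # how often each pixel is covered
        stitched = fold(crops) / normalization              # recovers x up to float error
        return torch.allclose(stitched, x, atol=1e-5)       # True
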
@torch.no_grad()
|
653 |
+
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
|
654 |
+
cond_key=None, return_original_cond=False, bs=None):
|
655 |
+
x = super().get_input(batch, k)
|
656 |
+
if bs is not None:
|
657 |
+
x = x[:bs]
|
658 |
+
x = x.to(self.device)
|
659 |
+
encoder_posterior = self.encode_first_stage(x)
|
660 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
661 |
+
|
662 |
+
if self.model.conditioning_key is not None:
|
663 |
+
if cond_key is None:
|
664 |
+
cond_key = self.cond_stage_key
|
665 |
+
if cond_key != self.first_stage_key:
|
666 |
+
if cond_key in ['caption', 'coordinates_bbox']:
|
667 |
+
xc = batch[cond_key]
|
668 |
+
elif cond_key == 'class_label':
|
669 |
+
xc = batch
|
670 |
+
else:
|
671 |
+
xc = super().get_input(batch, cond_key).to(self.device)
|
672 |
+
else:
|
673 |
+
xc = x
|
674 |
+
if not self.cond_stage_trainable or force_c_encode:
|
675 |
+
if isinstance(xc, dict) or isinstance(xc, list):
|
676 |
+
# import pudb; pudb.set_trace()
|
677 |
+
c = self.get_learned_conditioning(xc)
|
678 |
+
else:
|
679 |
+
c = self.get_learned_conditioning(xc.to(self.device))
|
680 |
+
else:
|
681 |
+
c = xc
|
682 |
+
if bs is not None:
|
683 |
+
c = c[:bs]
|
684 |
+
|
685 |
+
if self.use_positional_encodings:
|
686 |
+
pos_x, pos_y = self.compute_latent_shifts(batch)
|
687 |
+
ckey = __conditioning_keys__[self.model.conditioning_key]
|
688 |
+
c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
|
689 |
+
|
690 |
+
else:
|
691 |
+
c = None
|
692 |
+
xc = None
|
693 |
+
if self.use_positional_encodings:
|
694 |
+
pos_x, pos_y = self.compute_latent_shifts(batch)
|
695 |
+
c = {'pos_x': pos_x, 'pos_y': pos_y}
|
696 |
+
out = [z, c]
|
697 |
+
if return_first_stage_outputs:
|
698 |
+
xrec = self.decode_first_stage(z)
|
699 |
+
out.extend([x, xrec])
|
700 |
+
if return_original_cond:
|
701 |
+
out.append(xc)
|
702 |
+
return out
|
703 |
+
|
704 |
+
@torch.no_grad()
|
705 |
+
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
706 |
+
if predict_cids:
|
707 |
+
if z.dim() == 4:
|
708 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
709 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
710 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
711 |
+
|
712 |
+
z = 1. / self.scale_factor * z
|
713 |
+
|
714 |
+
if hasattr(self, "split_input_params"):
|
715 |
+
if self.split_input_params["patch_distributed_vq"]:
|
716 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
717 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
718 |
+
uf = self.split_input_params["vqf"]
|
719 |
+
bs, nc, h, w = z.shape
|
720 |
+
if ks[0] > h or ks[1] > w:
|
721 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
722 |
+
print("reducing Kernel")
|
723 |
+
|
724 |
+
if stride[0] > h or stride[1] > w:
|
725 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
726 |
+
print("reducing stride")
|
727 |
+
|
728 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
729 |
+
|
730 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
731 |
+
# 1. Reshape to img shape
|
732 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
733 |
+
|
734 |
+
# 2. apply model loop over last dim
|
735 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
736 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
737 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
738 |
+
for i in range(z.shape[-1])]
|
739 |
+
else:
|
740 |
+
|
741 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
742 |
+
for i in range(z.shape[-1])]
|
743 |
+
|
744 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
745 |
+
o = o * weighting
|
746 |
+
# Reverse 1. reshape to img shape
|
747 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
748 |
+
# stitch crops together
|
749 |
+
decoded = fold(o)
|
750 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
751 |
+
return decoded
|
752 |
+
else:
|
753 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
754 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
755 |
+
else:
|
756 |
+
return self.first_stage_model.decode(z)
|
757 |
+
|
758 |
+
else:
|
759 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
760 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
761 |
+
else:
|
762 |
+
return self.first_stage_model.decode(z)
|
763 |
+
|
764 |
+
# same as above but without decorator
|
765 |
+
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
766 |
+
if predict_cids:
|
767 |
+
if z.dim() == 4:
|
768 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
769 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
770 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
771 |
+
|
772 |
+
z = 1. / self.scale_factor * z
|
773 |
+
|
774 |
+
if hasattr(self, "split_input_params"):
|
775 |
+
if self.split_input_params["patch_distributed_vq"]:
|
776 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
777 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
778 |
+
uf = self.split_input_params["vqf"]
|
779 |
+
bs, nc, h, w = z.shape
|
780 |
+
if ks[0] > h or ks[1] > w:
|
781 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
782 |
+
print("reducing Kernel")
|
783 |
+
|
784 |
+
if stride[0] > h or stride[1] > w:
|
785 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
786 |
+
print("reducing stride")
|
787 |
+
|
788 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
789 |
+
|
790 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
791 |
+
# 1. Reshape to img shape
|
792 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
793 |
+
|
794 |
+
# 2. apply model loop over last dim
|
795 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
796 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
797 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
798 |
+
for i in range(z.shape[-1])]
|
799 |
+
else:
|
800 |
+
|
801 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
802 |
+
for i in range(z.shape[-1])]
|
803 |
+
|
804 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
805 |
+
o = o * weighting
|
806 |
+
# Reverse 1. reshape to img shape
|
807 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
808 |
+
# stitch crops together
|
809 |
+
decoded = fold(o)
|
810 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
811 |
+
return decoded
|
812 |
+
else:
|
813 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
814 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
815 |
+
else:
|
816 |
+
return self.first_stage_model.decode(z)
|
817 |
+
|
818 |
+
else:
|
819 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
820 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
821 |
+
else:
|
822 |
+
return self.first_stage_model.decode(z)
|
823 |
+
|
824 |
+
@torch.no_grad()
|
825 |
+
def encode_first_stage(self, x):
|
826 |
+
if hasattr(self, "split_input_params"):
|
827 |
+
if self.split_input_params["patch_distributed_vq"]:
|
828 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
829 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
830 |
+
df = self.split_input_params["vqf"]
|
831 |
+
self.split_input_params['original_image_size'] = x.shape[-2:]
|
832 |
+
bs, nc, h, w = x.shape
|
833 |
+
if ks[0] > h or ks[1] > w:
|
834 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
835 |
+
print("reducing Kernel")
|
836 |
+
|
837 |
+
if stride[0] > h or stride[1] > w:
|
838 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
839 |
+
print("reducing stride")
|
840 |
+
|
841 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
|
842 |
+
z = unfold(x) # (bn, nc * prod(**ks), L)
|
843 |
+
# Reshape to img shape
|
844 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
845 |
+
|
846 |
+
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
|
847 |
+
for i in range(z.shape[-1])]
|
848 |
+
|
849 |
+
o = torch.stack(output_list, axis=-1)
|
850 |
+
o = o * weighting
|
851 |
+
|
852 |
+
# Reverse reshape to img shape
|
853 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
854 |
+
# stitch crops together
|
855 |
+
decoded = fold(o)
|
856 |
+
decoded = decoded / normalization
|
857 |
+
return decoded
|
858 |
+
|
859 |
+
else:
|
860 |
+
return self.first_stage_model.encode(x)
|
861 |
+
else:
|
862 |
+
return self.first_stage_model.encode(x)
|
863 |
+
|
864 |
+
def shared_step(self, batch, **kwargs):
|
865 |
+
x, c = self.get_input(batch, self.first_stage_key)
|
866 |
+
loss = self(x, c)
|
867 |
+
return loss
|
868 |
+
|
869 |
+
def forward(self, x, c, *args, **kwargs):
|
870 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
871 |
+
if self.model.conditioning_key is not None:
|
872 |
+
assert c is not None
|
873 |
+
if self.cond_stage_trainable:
|
874 |
+
c = self.get_learned_conditioning(c)
|
875 |
+
if self.shorten_cond_schedule: # TODO: drop this option
|
876 |
+
tc = self.cond_ids[t].to(self.device)
|
877 |
+
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
878 |
+
return self.p_losses(x, c, t, *args, **kwargs)
|
879 |
+
|
880 |
+
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
881 |
+
def rescale_bbox(bbox):
|
882 |
+
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
883 |
+
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
884 |
+
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
885 |
+
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
886 |
+
return x0, y0, w, h
|
887 |
+
|
888 |
+
return [rescale_bbox(b) for b in bboxes]
|
889 |
+
|
890 |
+
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
891 |
+
|
892 |
+
if isinstance(cond, dict):
|
893 |
+
# hybrid case, cond is exptected to be a dict
|
894 |
+
pass
|
895 |
+
else:
|
896 |
+
if not isinstance(cond, list):
|
897 |
+
cond = [cond]
|
898 |
+
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
|
899 |
+
cond = {key: cond}
|
900 |
+
|
901 |
+
if hasattr(self, "split_input_params"):
|
902 |
+
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
903 |
+
assert not return_ids
|
904 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
905 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
906 |
+
|
907 |
+
h, w = x_noisy.shape[-2:]
|
908 |
+
|
909 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
|
910 |
+
|
911 |
+
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
|
912 |
+
# Reshape to img shape
|
913 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
914 |
+
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
915 |
+
|
916 |
+
if self.cond_stage_key in ["image", "LR_image", "segmentation",
|
917 |
+
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
|
918 |
+
c_key = next(iter(cond.keys())) # get key
|
919 |
+
c = next(iter(cond.values())) # get value
|
920 |
+
assert (len(c) == 1) # todo extend to list with more than one elem
|
921 |
+
c = c[0] # get element
|
922 |
+
|
923 |
+
c = unfold(c)
|
924 |
+
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
925 |
+
|
926 |
+
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
927 |
+
|
928 |
+
elif self.cond_stage_key == 'coordinates_bbox':
|
929 |
+
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
930 |
+
|
931 |
+
# assuming padding of unfold is always 0 and its dilation is always 1
|
932 |
+
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
933 |
+
full_img_h, full_img_w = self.split_input_params['original_image_size']
|
934 |
+
# as we are operating on latents, we need the factor from the original image size to the
|
935 |
+
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
|
936 |
+
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
937 |
+
rescale_latent = 2 ** (num_downs)
|
938 |
+
|
939 |
+
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
940 |
+
# need to rescale the tl patch coordinates to be in between (0,1)
|
941 |
+
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
942 |
+
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
943 |
+
for patch_nr in range(z.shape[-1])]
|
944 |
+
|
945 |
+
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
|
946 |
+
patch_limits = [(x_tl, y_tl,
|
947 |
+
rescale_latent * ks[0] / full_img_w,
|
948 |
+
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
|
949 |
+
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
|
950 |
+
|
951 |
+
# tokenize crop coordinates for the bounding boxes of the respective patches
|
952 |
+
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
|
953 |
+
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
|
954 |
+
print(patch_limits_tknzd[0].shape)
|
955 |
+
# cut tknzd crop position from conditioning
|
956 |
+
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
|
957 |
+
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
|
958 |
+
print(cut_cond.shape)
|
959 |
+
|
960 |
+
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
|
961 |
+
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
|
962 |
+
print(adapted_cond.shape)
|
963 |
+
adapted_cond = self.get_learned_conditioning(adapted_cond)
|
964 |
+
print(adapted_cond.shape)
|
965 |
+
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
|
966 |
+
print(adapted_cond.shape)
|
967 |
+
|
968 |
+
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
|
969 |
+
|
970 |
+
else:
|
971 |
+
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
|
972 |
+
|
973 |
+
# apply model by loop over crops
|
974 |
+
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
|
975 |
+
assert not isinstance(output_list[0],
|
976 |
+
tuple) # todo cant deal with multiple model outputs check this never happens
|
977 |
+
|
978 |
+
o = torch.stack(output_list, axis=-1)
|
979 |
+
o = o * weighting
|
980 |
+
# Reverse reshape to img shape
|
981 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
982 |
+
# stitch crops together
|
983 |
+
x_recon = fold(o) / normalization
|
984 |
+
|
985 |
+
else:
|
986 |
+
x_recon = self.model(x_noisy, t, **cond)
|
987 |
+
|
988 |
+
if isinstance(x_recon, tuple) and not return_ids:
|
989 |
+
return x_recon[0]
|
990 |
+
else:
|
991 |
+
return x_recon
|
992 |
+
|
993 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
994 |
+
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
|
995 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
996 |
+
|
997 |
+
def _prior_bpd(self, x_start):
|
998 |
+
"""
|
999 |
+
Get the prior KL term for the variational lower-bound, measured in
|
1000 |
+
bits-per-dim.
|
1001 |
+
This term can't be optimized, as it only depends on the encoder.
|
1002 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
1003 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
1004 |
+
"""
|
1005 |
+
batch_size = x_start.shape[0]
|
1006 |
+
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
1007 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
1008 |
+
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
1009 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
1010 |
+
|
1011 |
+
def p_losses(self, x_start, cond, t, noise=None):
|
1012 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
1013 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
1014 |
+
model_output = self.apply_model(x_noisy, t, cond)
|
1015 |
+
|
1016 |
+
loss_dict = {}
|
1017 |
+
prefix = 'train' if self.training else 'val'
|
1018 |
+
|
1019 |
+
if self.parameterization == "x0":
|
1020 |
+
target = x_start
|
1021 |
+
elif self.parameterization == "eps":
|
1022 |
+
target = noise
|
1023 |
+
else:
|
1024 |
+
raise NotImplementedError()
|
1025 |
+
|
1026 |
+
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
1027 |
+
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
|
1028 |
+
|
1029 |
+
logvar_t = self.logvar[t].to(self.device)
|
1030 |
+
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
1031 |
+
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
1032 |
+
if self.learn_logvar:
|
1033 |
+
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
|
1034 |
+
loss_dict.update({'logvar': self.logvar.data.mean()})
|
1035 |
+
|
1036 |
+
loss = self.l_simple_weight * loss.mean()
|
1037 |
+
|
1038 |
+
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
1039 |
+
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
1040 |
+
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
|
1041 |
+
loss += (self.original_elbo_weight * loss_vlb)
|
1042 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
1043 |
+
|
1044 |
+
return loss, loss_dict
|
1045 |
+
|
1046 |
+
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
|
1047 |
+
return_x0=False, score_corrector=None, corrector_kwargs=None):
|
1048 |
+
t_in = t
|
1049 |
+
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
|
1050 |
+
|
1051 |
+
if score_corrector is not None:
|
1052 |
+
assert self.parameterization == "eps"
|
1053 |
+
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
|
1054 |
+
|
1055 |
+
if return_codebook_ids:
|
1056 |
+
model_out, logits = model_out
|
1057 |
+
|
1058 |
+
if self.parameterization == "eps":
|
1059 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
1060 |
+
elif self.parameterization == "x0":
|
1061 |
+
x_recon = model_out
|
1062 |
+
else:
|
1063 |
+
raise NotImplementedError()
|
1064 |
+
|
1065 |
+
if clip_denoised:
|
1066 |
+
x_recon.clamp_(-1., 1.)
|
1067 |
+
if quantize_denoised:
|
1068 |
+
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
|
1069 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
1070 |
+
if return_codebook_ids:
|
1071 |
+
return model_mean, posterior_variance, posterior_log_variance, logits
|
1072 |
+
elif return_x0:
|
1073 |
+
return model_mean, posterior_variance, posterior_log_variance, x_recon
|
1074 |
+
else:
|
1075 |
+
return model_mean, posterior_variance, posterior_log_variance
|
1076 |
+
|
1077 |
+
@torch.no_grad()
|
1078 |
+
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
|
1079 |
+
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
|
1080 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
|
1081 |
+
b, *_, device = *x.shape, x.device
|
1082 |
+
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
|
1083 |
+
return_codebook_ids=return_codebook_ids,
|
1084 |
+
quantize_denoised=quantize_denoised,
|
1085 |
+
return_x0=return_x0,
|
1086 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
1087 |
+
if return_codebook_ids:
|
1088 |
+
raise DeprecationWarning("Support dropped.")
|
1089 |
+
model_mean, _, model_log_variance, logits = outputs
|
1090 |
+
elif return_x0:
|
1091 |
+
model_mean, _, model_log_variance, x0 = outputs
|
1092 |
+
else:
|
1093 |
+
model_mean, _, model_log_variance = outputs
|
1094 |
+
|
1095 |
+
noise = noise_like(x.shape, device, repeat_noise) * temperature
|
1096 |
+
if noise_dropout > 0.:
|
1097 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
1098 |
+
# no noise when t == 0
|
1099 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
1100 |
+
|
1101 |
+
if return_codebook_ids:
|
1102 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
|
1103 |
+
if return_x0:
|
1104 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
|
1105 |
+
else:
|
1106 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
1107 |
+
|
1108 |
+
@torch.no_grad()
|
1109 |
+
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
|
1110 |
+
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
|
1111 |
+
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
|
1112 |
+
log_every_t=None):
|
1113 |
+
if not log_every_t:
|
1114 |
+
log_every_t = self.log_every_t
|
1115 |
+
timesteps = self.num_timesteps
|
1116 |
+
if batch_size is not None:
|
1117 |
+
b = batch_size if batch_size is not None else shape[0]
|
1118 |
+
shape = [batch_size] + list(shape)
|
1119 |
+
else:
|
1120 |
+
b = batch_size = shape[0]
|
1121 |
+
if x_T is None:
|
1122 |
+
img = torch.randn(shape, device=self.device)
|
1123 |
+
else:
|
1124 |
+
img = x_T
|
1125 |
+
intermediates = []
|
1126 |
+
if cond is not None:
|
1127 |
+
if isinstance(cond, dict):
|
1128 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
1129 |
+
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
1130 |
+
else:
|
1131 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
1132 |
+
|
1133 |
+
if start_T is not None:
|
1134 |
+
timesteps = min(timesteps, start_T)
|
1135 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
|
1136 |
+
total=timesteps) if verbose else reversed(
|
1137 |
+
range(0, timesteps))
|
1138 |
+
if type(temperature) == float:
|
1139 |
+
temperature = [temperature] * timesteps
|
1140 |
+
|
1141 |
+
for i in iterator:
|
1142 |
+
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
|
1143 |
+
if self.shorten_cond_schedule:
|
1144 |
+
assert self.model.conditioning_key != 'hybrid'
|
1145 |
+
tc = self.cond_ids[ts].to(cond.device)
|
1146 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
1147 |
+
|
1148 |
+
img, x0_partial = self.p_sample(img, cond, ts,
|
1149 |
+
clip_denoised=self.clip_denoised,
|
1150 |
+
quantize_denoised=quantize_denoised, return_x0=True,
|
1151 |
+
temperature=temperature[i], noise_dropout=noise_dropout,
|
1152 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
1153 |
+
if mask is not None:
|
1154 |
+
assert x0 is not None
|
1155 |
+
img_orig = self.q_sample(x0, ts)
|
1156 |
+
img = img_orig * mask + (1. - mask) * img
|
1157 |
+
|
1158 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
1159 |
+
intermediates.append(x0_partial)
|
1160 |
+
if callback: callback(i)
|
1161 |
+
if img_callback: img_callback(img, i)
|
1162 |
+
return img, intermediates
|
1163 |
+
|
1164 |
+
@torch.no_grad()
|
1165 |
+
def p_sample_loop(self, cond, shape, return_intermediates=False,
|
1166 |
+
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
|
1167 |
+
mask=None, x0=None, img_callback=None, start_T=None,
|
1168 |
+
log_every_t=None):
|
1169 |
+
|
1170 |
+
if not log_every_t:
|
1171 |
+
log_every_t = self.log_every_t
|
1172 |
+
device = self.betas.device
|
1173 |
+
b = shape[0]
|
1174 |
+
if x_T is None:
|
1175 |
+
img = torch.randn(shape, device=device)
|
1176 |
+
else:
|
1177 |
+
img = x_T
|
1178 |
+
|
1179 |
+
intermediates = [img]
|
1180 |
+
if timesteps is None:
|
1181 |
+
timesteps = self.num_timesteps
|
1182 |
+
|
1183 |
+
if start_T is not None:
|
1184 |
+
timesteps = min(timesteps, start_T)
|
1185 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
|
1186 |
+
range(0, timesteps))
|
1187 |
+
|
1188 |
+
if mask is not None:
|
1189 |
+
assert x0 is not None
|
1190 |
+
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
|
1191 |
+
|
1192 |
+
for i in iterator:
|
1193 |
+
ts = torch.full((b,), i, device=device, dtype=torch.long)
|
1194 |
+
if self.shorten_cond_schedule:
|
1195 |
+
assert self.model.conditioning_key != 'hybrid'
|
1196 |
+
tc = self.cond_ids[ts].to(cond.device)
|
1197 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
1198 |
+
|
1199 |
+
img = self.p_sample(img, cond, ts,
|
1200 |
+
clip_denoised=self.clip_denoised,
|
1201 |
+
quantize_denoised=quantize_denoised)
|
1202 |
+
if mask is not None:
|
1203 |
+
img_orig = self.q_sample(x0, ts)
|
1204 |
+
img = img_orig * mask + (1. - mask) * img
|
1205 |
+
|
1206 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
1207 |
+
intermediates.append(img)
|
1208 |
+
if callback: callback(i)
|
1209 |
+
if img_callback: img_callback(img, i)
|
1210 |
+
|
1211 |
+
if return_intermediates:
|
1212 |
+
return img, intermediates
|
1213 |
+
return img
|
1214 |
+
|
1215 |
+
@torch.no_grad()
|
1216 |
+
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
1217 |
+
verbose=True, timesteps=None, quantize_denoised=False,
|
1218 |
+
mask=None, x0=None, shape=None,**kwargs):
|
1219 |
+
if shape is None:
|
1220 |
+
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
1221 |
+
if cond is not None:
|
1222 |
+
if isinstance(cond, dict):
|
1223 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
1224 |
+
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
1225 |
+
else:
|
1226 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
1227 |
+
return self.p_sample_loop(cond,
|
1228 |
+
shape,
|
1229 |
+
return_intermediates=return_intermediates, x_T=x_T,
|
1230 |
+
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
|
1231 |
+
mask=mask, x0=x0)
|
1232 |
+
|
1233 |
+
@torch.no_grad()
|
1234 |
+
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
|
1235 |
+
|
1236 |
+
if ddim:
|
1237 |
+
ddim_sampler = DDIMSampler(self)
|
1238 |
+
shape = (self.channels, self.image_size, self.image_size)
|
1239 |
+
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
1240 |
+
shape,cond,verbose=False,**kwargs)
|
1241 |
+
|
1242 |
+
else:
|
1243 |
+
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
1244 |
+
return_intermediates=True,**kwargs)
|
1245 |
+
|
1246 |
+
return samples, intermediates
|
1247 |
+
|
1248 |
+
|
1249 |
+
@torch.no_grad()
|
1250 |
+
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
1251 |
+
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
1252 |
+
plot_diffusion_rows=True, **kwargs):
|
1253 |
+
|
1254 |
+
use_ddim = ddim_steps is not None
|
1255 |
+
|
1256 |
+
log = dict()
|
1257 |
+
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
1258 |
+
return_first_stage_outputs=True,
|
1259 |
+
force_c_encode=True,
|
1260 |
+
return_original_cond=True,
|
1261 |
+
bs=N)
|
1262 |
+
N = min(x.shape[0], N)
|
1263 |
+
n_row = min(x.shape[0], n_row)
|
1264 |
+
log["inputs"] = x
|
1265 |
+
log["reconstruction"] = xrec
|
1266 |
+
if self.model.conditioning_key is not None:
|
1267 |
+
if hasattr(self.cond_stage_model, "decode"):
|
1268 |
+
xc = self.cond_stage_model.decode(c)
|
1269 |
+
log["conditioning"] = xc
|
1270 |
+
elif self.cond_stage_key in ["caption"]:
|
1271 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
|
1272 |
+
log["conditioning"] = xc
|
1273 |
+
elif self.cond_stage_key == 'class_label':
|
1274 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
1275 |
+
log['conditioning'] = xc
|
1276 |
+
elif isimage(xc):
|
1277 |
+
log["conditioning"] = xc
|
1278 |
+
if ismap(xc):
|
1279 |
+
log["original_conditioning"] = self.to_rgb(xc)
|
1280 |
+
|
1281 |
+
if plot_diffusion_rows:
|
1282 |
+
# get diffusion row
|
1283 |
+
diffusion_row = list()
|
1284 |
+
z_start = z[:n_row]
|
1285 |
+
for t in range(self.num_timesteps):
|
1286 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
1287 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
1288 |
+
t = t.to(self.device).long()
|
1289 |
+
noise = torch.randn_like(z_start)
|
1290 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
1291 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
1292 |
+
|
1293 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
1294 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
1295 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
1296 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
1297 |
+
log["diffusion_row"] = diffusion_grid
|
1298 |
+
|
1299 |
+
if sample:
|
1300 |
+
# get denoise row
|
1301 |
+
with self.ema_scope("Plotting"):
|
1302 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1303 |
+
ddim_steps=ddim_steps,eta=ddim_eta)
|
1304 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
1305 |
+
x_samples = self.decode_first_stage(samples)
|
1306 |
+
log["samples"] = x_samples
|
1307 |
+
if plot_denoise_rows:
|
1308 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
1309 |
+
log["denoise_row"] = denoise_grid
|
1310 |
+
|
1311 |
+
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
1312 |
+
self.first_stage_model, IdentityFirstStage):
|
1313 |
+
# also display when quantizing x0 while sampling
|
1314 |
+
with self.ema_scope("Plotting Quantized Denoised"):
|
1315 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1316 |
+
ddim_steps=ddim_steps,eta=ddim_eta,
|
1317 |
+
quantize_denoised=True)
|
1318 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
1319 |
+
# quantize_denoised=True)
|
1320 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1321 |
+
log["samples_x0_quantized"] = x_samples
|
1322 |
+
|
1323 |
+
if inpaint:
|
1324 |
+
# make a simple center square
|
1325 |
+
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
1326 |
+
mask = torch.ones(N, h, w).to(self.device)
|
1327 |
+
# zeros will be filled in
|
1328 |
+
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
1329 |
+
mask = mask[:, None, ...]
|
1330 |
+
with self.ema_scope("Plotting Inpaint"):
|
1331 |
+
|
1332 |
+
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
1333 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1334 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1335 |
+
log["samples_inpainting"] = x_samples
|
1336 |
+
log["mask"] = mask
|
1337 |
+
|
1338 |
+
# outpaint
|
1339 |
+
with self.ema_scope("Plotting Outpaint"):
|
1340 |
+
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
1341 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1342 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1343 |
+
log["samples_outpainting"] = x_samples
|
1344 |
+
|
1345 |
+
if plot_progressive_rows:
|
1346 |
+
with self.ema_scope("Plotting Progressives"):
|
1347 |
+
img, progressives = self.progressive_denoising(c,
|
1348 |
+
shape=(self.channels, self.image_size, self.image_size),
|
1349 |
+
batch_size=N)
|
1350 |
+
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
|
1351 |
+
log["progressive_row"] = prog_row
|
1352 |
+
|
1353 |
+
if return_keys:
|
1354 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
1355 |
+
return log
|
1356 |
+
else:
|
1357 |
+
return {key: log[key] for key in return_keys}
|
1358 |
+
return log
|
1359 |
+
|
1360 |
+
def configure_optimizers(self):
|
1361 |
+
lr = self.learning_rate
|
1362 |
+
params = list(self.model.parameters())
|
1363 |
+
if self.cond_stage_trainable:
|
1364 |
+
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
|
1365 |
+
params = params + list(self.cond_stage_model.parameters())
|
1366 |
+
if self.learn_logvar:
|
1367 |
+
print('Diffusion model optimizing logvar')
|
1368 |
+
params.append(self.logvar)
|
1369 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
1370 |
+
if self.use_scheduler:
|
1371 |
+
assert 'target' in self.scheduler_config
|
1372 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
1373 |
+
|
1374 |
+
print("Setting up LambdaLR scheduler...")
|
1375 |
+
scheduler = [
|
1376 |
+
{
|
1377 |
+
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
|
1378 |
+
'interval': 'step',
|
1379 |
+
'frequency': 1
|
1380 |
+
}]
|
1381 |
+
return [opt], scheduler
|
1382 |
+
return opt
|
1383 |
+
|
1384 |
+
@torch.no_grad()
|
1385 |
+
def to_rgb(self, x):
|
1386 |
+
x = x.float()
|
1387 |
+
if not hasattr(self, "colorize"):
|
1388 |
+
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
|
1389 |
+
x = nn.functional.conv2d(x, weight=self.colorize)
|
1390 |
+
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
1391 |
+
return x
|
1392 |
+
|
1393 |
+
|
1394 |
+
class DiffusionWrapperV1(pl.LightningModule):
|
1395 |
+
def __init__(self, diff_model_config, conditioning_key):
|
1396 |
+
super().__init__()
|
1397 |
+
self.diffusion_model = instantiate_from_config(diff_model_config)
|
1398 |
+
self.conditioning_key = conditioning_key
|
1399 |
+
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
|
1400 |
+
|
1401 |
+
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
|
1402 |
+
if self.conditioning_key is None:
|
1403 |
+
out = self.diffusion_model(x, t)
|
1404 |
+
elif self.conditioning_key == 'concat':
|
1405 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
1406 |
+
out = self.diffusion_model(xc, t)
|
1407 |
+
elif self.conditioning_key == 'crossattn':
|
1408 |
+
cc = torch.cat(c_crossattn, 1)
|
1409 |
+
out = self.diffusion_model(x, t, context=cc)
|
1410 |
+
elif self.conditioning_key == 'hybrid':
|
1411 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
1412 |
+
cc = torch.cat(c_crossattn, 1)
|
1413 |
+
out = self.diffusion_model(xc, t, context=cc)
|
1414 |
+
elif self.conditioning_key == 'adm':
|
1415 |
+
cc = c_crossattn[0]
|
1416 |
+
out = self.diffusion_model(x, t, y=cc)
|
1417 |
+
else:
|
1418 |
+
raise NotImplementedError()
|
1419 |
+
|
1420 |
+
return out
|
1421 |
+
|
1422 |
+
|
1423 |
+
class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
1424 |
+
# TODO: move all layout-specific hacks to this class
|
1425 |
+
def __init__(self, cond_stage_key, *args, **kwargs):
|
1426 |
+
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
1427 |
+
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
1428 |
+
|
1429 |
+
def log_images(self, batch, N=8, *args, **kwargs):
|
1430 |
+
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
1431 |
+
|
1432 |
+
key = 'train' if self.training else 'validation'
|
1433 |
+
dset = self.trainer.datamodule.datasets[key]
|
1434 |
+
mapper = dset.conditional_builders[self.cond_stage_key]
|
1435 |
+
|
1436 |
+
bbox_imgs = []
|
1437 |
+
map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
|
1438 |
+
for tknzd_bbox in batch[self.cond_stage_key][:N]:
|
1439 |
+
bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
|
1440 |
+
bbox_imgs.append(bboximg)
|
1441 |
+
|
1442 |
+
cond_img = torch.stack(bbox_imgs, dim=0)
|
1443 |
+
logs['bbox_image'] = cond_img
|
1444 |
+
return logs
|
1445 |
+
|
1446 |
+
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
|
1447 |
+
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
|
1448 |
+
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
|
1449 |
+
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
|
extensions-builtin/ScuNET/preload.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules import paths
|
3 |
+
|
4 |
+
|
5 |
+
def preload(parser):
|
6 |
+
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
|
extensions-builtin/ScuNET/scripts/scunet_model.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import PIL.Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from basicsr.utils.download_util import load_file_from_url
|
9 |
+
|
10 |
+
import modules.upscaler
|
11 |
+
from modules import devices, modelloader
|
12 |
+
from scunet_model_arch import SCUNet as net
|
13 |
+
|
14 |
+
|
15 |
+
class UpscalerScuNET(modules.upscaler.Upscaler):
|
16 |
+
def __init__(self, dirname):
|
17 |
+
self.name = "ScuNET"
|
18 |
+
self.model_name = "ScuNET GAN"
|
19 |
+
self.model_name2 = "ScuNET PSNR"
|
20 |
+
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
|
21 |
+
self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
|
22 |
+
self.user_path = dirname
|
23 |
+
super().__init__()
|
24 |
+
model_paths = self.find_models(ext_filter=[".pth"])
|
25 |
+
scalers = []
|
26 |
+
add_model2 = True
|
27 |
+
for file in model_paths:
|
28 |
+
if "http" in file:
|
29 |
+
name = self.model_name
|
30 |
+
else:
|
31 |
+
name = modelloader.friendly_name(file)
|
32 |
+
if name == self.model_name2 or file == self.model_url2:
|
33 |
+
add_model2 = False
|
34 |
+
try:
|
35 |
+
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
36 |
+
scalers.append(scaler_data)
|
37 |
+
except Exception:
|
38 |
+
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
|
39 |
+
print(traceback.format_exc(), file=sys.stderr)
|
40 |
+
if add_model2:
|
41 |
+
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
42 |
+
scalers.append(scaler_data2)
|
43 |
+
self.scalers = scalers
|
44 |
+
|
45 |
+
def do_upscale(self, img: PIL.Image, selected_file):
|
46 |
+
torch.cuda.empty_cache()
|
47 |
+
|
48 |
+
model = self.load_model(selected_file)
|
49 |
+
if model is None:
|
50 |
+
return img
|
51 |
+
|
52 |
+
device = devices.get_device_for('scunet')
|
53 |
+
img = np.array(img)
|
54 |
+
img = img[:, :, ::-1]
|
55 |
+
img = np.moveaxis(img, 2, 0) / 255
|
56 |
+
img = torch.from_numpy(img).float()
|
57 |
+
img = img.unsqueeze(0).to(device)
|
58 |
+
|
59 |
+
with torch.no_grad():
|
60 |
+
output = model(img)
|
61 |
+
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
62 |
+
output = 255. * np.moveaxis(output, 0, 2)
|
63 |
+
output = output.astype(np.uint8)
|
64 |
+
output = output[:, :, ::-1]
|
65 |
+
torch.cuda.empty_cache()
|
66 |
+
return PIL.Image.fromarray(output, 'RGB')
|
67 |
+
|
68 |
+
def load_model(self, path: str):
|
69 |
+
device = devices.get_device_for('scunet')
|
70 |
+
if "http" in path:
|
71 |
+
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
72 |
+
progress=True)
|
73 |
+
else:
|
74 |
+
filename = path
|
75 |
+
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
76 |
+
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
77 |
+
return None
|
78 |
+
|
79 |
+
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
80 |
+
model.load_state_dict(torch.load(filename), strict=True)
|
81 |
+
model.eval()
|
82 |
+
for k, v in model.named_parameters():
|
83 |
+
v.requires_grad = False
|
84 |
+
model = model.to(device)
|
85 |
+
|
86 |
+
return model
|
87 |
+
|
extensions-builtin/ScuNET/scunet_model_arch.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
from einops.layers.torch import Rearrange
|
7 |
+
from timm.models.layers import trunc_normal_, DropPath
|
8 |
+
|
9 |
+
|
10 |
+
class WMSA(nn.Module):
|
11 |
+
""" Self-attention module in Swin Transformer
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
15 |
+
super(WMSA, self).__init__()
|
16 |
+
self.input_dim = input_dim
|
17 |
+
self.output_dim = output_dim
|
18 |
+
self.head_dim = head_dim
|
19 |
+
self.scale = self.head_dim ** -0.5
|
20 |
+
self.n_heads = input_dim // head_dim
|
21 |
+
self.window_size = window_size
|
22 |
+
self.type = type
|
23 |
+
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
24 |
+
|
25 |
+
self.relative_position_params = nn.Parameter(
|
26 |
+
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
|
27 |
+
|
28 |
+
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
29 |
+
|
30 |
+
trunc_normal_(self.relative_position_params, std=.02)
|
31 |
+
self.relative_position_params = torch.nn.Parameter(
|
32 |
+
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
|
33 |
+
2).transpose(
|
34 |
+
0, 1))
|
35 |
+
|
36 |
+
def generate_mask(self, h, w, p, shift):
|
37 |
+
""" generating the mask of SW-MSA
|
38 |
+
Args:
|
39 |
+
shift: shift parameters in CyclicShift.
|
40 |
+
Returns:
|
41 |
+
attn_mask: should be (1 1 w p p),
|
42 |
+
"""
|
43 |
+
# supporting square.
|
44 |
+
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
45 |
+
if self.type == 'W':
|
46 |
+
return attn_mask
|
47 |
+
|
48 |
+
s = p - shift
|
49 |
+
attn_mask[-1, :, :s, :, s:, :] = True
|
50 |
+
attn_mask[-1, :, s:, :, :s, :] = True
|
51 |
+
attn_mask[:, -1, :, :s, :, s:] = True
|
52 |
+
attn_mask[:, -1, :, s:, :, :s] = True
|
53 |
+
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
54 |
+
return attn_mask
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
""" Forward pass of Window Multi-head Self-attention module.
|
58 |
+
Args:
|
59 |
+
x: input tensor with shape of [b h w c];
|
60 |
+
attn_mask: attention mask, fill -inf where the value is True;
|
61 |
+
Returns:
|
62 |
+
output: tensor shape [b h w c]
|
63 |
+
"""
|
64 |
+
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
65 |
+
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
66 |
+
h_windows = x.size(1)
|
67 |
+
w_windows = x.size(2)
|
68 |
+
# square validation
|
69 |
+
# assert h_windows == w_windows
|
70 |
+
|
71 |
+
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
72 |
+
qkv = self.embedding_layer(x)
|
73 |
+
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
74 |
+
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
75 |
+
# Adding learnable relative embedding
|
76 |
+
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
77 |
+
# Using Attn Mask to distinguish different subwindows.
|
78 |
+
if self.type != 'W':
|
79 |
+
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
|
80 |
+
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
81 |
+
|
82 |
+
probs = nn.functional.softmax(sim, dim=-1)
|
83 |
+
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
84 |
+
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
85 |
+
output = self.linear(output)
|
86 |
+
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
87 |
+
|
88 |
+
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
89 |
+
dims=(1, 2))
|
90 |
+
return output
|
91 |
+
|
92 |
+
def relative_embedding(self):
|
93 |
+
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
94 |
+
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
95 |
+
# negative is allowed
|
96 |
+
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
|
97 |
+
|
98 |
+
|
99 |
+
class Block(nn.Module):
|
100 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
101 |
+
""" SwinTransformer Block
|
102 |
+
"""
|
103 |
+
super(Block, self).__init__()
|
104 |
+
self.input_dim = input_dim
|
105 |
+
self.output_dim = output_dim
|
106 |
+
assert type in ['W', 'SW']
|
107 |
+
self.type = type
|
108 |
+
if input_resolution <= window_size:
|
109 |
+
self.type = 'W'
|
110 |
+
|
111 |
+
self.ln1 = nn.LayerNorm(input_dim)
|
112 |
+
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
113 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
114 |
+
self.ln2 = nn.LayerNorm(input_dim)
|
115 |
+
self.mlp = nn.Sequential(
|
116 |
+
nn.Linear(input_dim, 4 * input_dim),
|
117 |
+
nn.GELU(),
|
118 |
+
nn.Linear(4 * input_dim, output_dim),
|
119 |
+
)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
x = x + self.drop_path(self.msa(self.ln1(x)))
|
123 |
+
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
124 |
+
return x
|
125 |
+
|
126 |
+
|
127 |
+
class ConvTransBlock(nn.Module):
|
128 |
+
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
129 |
+
""" SwinTransformer and Conv Block
|
130 |
+
"""
|
131 |
+
super(ConvTransBlock, self).__init__()
|
132 |
+
self.conv_dim = conv_dim
|
133 |
+
self.trans_dim = trans_dim
|
134 |
+
self.head_dim = head_dim
|
135 |
+
self.window_size = window_size
|
136 |
+
self.drop_path = drop_path
|
137 |
+
self.type = type
|
138 |
+
self.input_resolution = input_resolution
|
139 |
+
|
140 |
+
assert self.type in ['W', 'SW']
|
141 |
+
if self.input_resolution <= self.window_size:
|
142 |
+
self.type = 'W'
|
143 |
+
|
144 |
+
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
|
145 |
+
self.type, self.input_resolution)
|
146 |
+
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
147 |
+
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
148 |
+
|
149 |
+
self.conv_block = nn.Sequential(
|
150 |
+
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
151 |
+
nn.ReLU(True),
|
152 |
+
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
|
153 |
+
)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
|
157 |
+
conv_x = self.conv_block(conv_x) + conv_x
|
158 |
+
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
|
159 |
+
trans_x = self.trans_block(trans_x)
|
160 |
+
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
|
161 |
+
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
162 |
+
x = x + res
|
163 |
+
|
164 |
+
return x
|
165 |
+
|
166 |
+
|
167 |
+
class SCUNet(nn.Module):
|
168 |
+
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
|
169 |
+
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
|
170 |
+
super(SCUNet, self).__init__()
|
171 |
+
if config is None:
|
172 |
+
config = [2, 2, 2, 2, 2, 2, 2]
|
173 |
+
self.config = config
|
174 |
+
self.dim = dim
|
175 |
+
self.head_dim = 32
|
176 |
+
self.window_size = 8
|
177 |
+
|
178 |
+
# drop path rate for each layer
|
179 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
180 |
+
|
181 |
+
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
182 |
+
|
183 |
+
begin = 0
|
184 |
+
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
185 |
+
'W' if not i % 2 else 'SW', input_resolution)
|
186 |
+
for i in range(config[0])] + \
|
187 |
+
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
|
188 |
+
|
189 |
+
begin += config[0]
|
190 |
+
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
191 |
+
'W' if not i % 2 else 'SW', input_resolution // 2)
|
192 |
+
for i in range(config[1])] + \
|
193 |
+
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
|
194 |
+
|
195 |
+
begin += config[1]
|
196 |
+
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
197 |
+
'W' if not i % 2 else 'SW', input_resolution // 4)
|
198 |
+
for i in range(config[2])] + \
|
199 |
+
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
|
200 |
+
|
201 |
+
begin += config[2]
|
202 |
+
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
203 |
+
'W' if not i % 2 else 'SW', input_resolution // 8)
|
204 |
+
for i in range(config[3])]
|
205 |
+
|
206 |
+
begin += config[3]
|
207 |
+
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
|
208 |
+
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
209 |
+
'W' if not i % 2 else 'SW', input_resolution // 4)
|
210 |
+
for i in range(config[4])]
|
211 |
+
|
212 |
+
begin += config[4]
|
213 |
+
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
|
214 |
+
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
215 |
+
'W' if not i % 2 else 'SW', input_resolution // 2)
|
216 |
+
for i in range(config[5])]
|
217 |
+
|
218 |
+
begin += config[5]
|
219 |
+
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
|
220 |
+
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
221 |
+
'W' if not i % 2 else 'SW', input_resolution)
|
222 |
+
for i in range(config[6])]
|
223 |
+
|
224 |
+
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
225 |
+
|
226 |
+
self.m_head = nn.Sequential(*self.m_head)
|
227 |
+
self.m_down1 = nn.Sequential(*self.m_down1)
|
228 |
+
self.m_down2 = nn.Sequential(*self.m_down2)
|
229 |
+
self.m_down3 = nn.Sequential(*self.m_down3)
|
230 |
+
self.m_body = nn.Sequential(*self.m_body)
|
231 |
+
self.m_up3 = nn.Sequential(*self.m_up3)
|
232 |
+
self.m_up2 = nn.Sequential(*self.m_up2)
|
233 |
+
self.m_up1 = nn.Sequential(*self.m_up1)
|
234 |
+
self.m_tail = nn.Sequential(*self.m_tail)
|
235 |
+
# self.apply(self._init_weights)
|
236 |
+
|
237 |
+
def forward(self, x0):
|
238 |
+
|
239 |
+
h, w = x0.size()[-2:]
|
240 |
+
paddingBottom = int(np.ceil(h / 64) * 64 - h)
|
241 |
+
paddingRight = int(np.ceil(w / 64) * 64 - w)
|
242 |
+
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
|
243 |
+
|
244 |
+
x1 = self.m_head(x0)
|
245 |
+
x2 = self.m_down1(x1)
|
246 |
+
x3 = self.m_down2(x2)
|
247 |
+
x4 = self.m_down3(x3)
|
248 |
+
x = self.m_body(x4)
|
249 |
+
x = self.m_up3(x + x4)
|
250 |
+
x = self.m_up2(x + x3)
|
251 |
+
x = self.m_up1(x + x2)
|
252 |
+
x = self.m_tail(x + x1)
|
253 |
+
|
254 |
+
x = x[..., :h, :w]
|
255 |
+
|
256 |
+
return x
|
257 |
+
|
258 |
+
def _init_weights(self, m):
|
259 |
+
if isinstance(m, nn.Linear):
|
260 |
+
trunc_normal_(m.weight, std=.02)
|
261 |
+
if m.bias is not None:
|
262 |
+
nn.init.constant_(m.bias, 0)
|
263 |
+
elif isinstance(m, nn.LayerNorm):
|
264 |
+
nn.init.constant_(m.bias, 0)
|
265 |
+
nn.init.constant_(m.weight, 1.0)
|
extensions-builtin/SwinIR/preload.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from modules import paths
|
3 |
+
|
4 |
+
|
5 |
+
def preload(parser):
|
6 |
+
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
|
extensions-builtin/SwinIR/scripts/swinir_model.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from basicsr.utils.download_util import load_file_from_url
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from modules import modelloader, devices, script_callbacks, shared
|
11 |
+
from modules.shared import cmd_opts, opts
|
12 |
+
from swinir_model_arch import SwinIR as net
|
13 |
+
from swinir_model_arch_v2 import Swin2SR as net2
|
14 |
+
from modules.upscaler import Upscaler, UpscalerData
|
15 |
+
|
16 |
+
|
17 |
+
device_swinir = devices.get_device_for('swinir')
|
18 |
+
|
19 |
+
|
20 |
+
class UpscalerSwinIR(Upscaler):
|
21 |
+
def __init__(self, dirname):
|
22 |
+
self.name = "SwinIR"
|
23 |
+
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
24 |
+
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
25 |
+
"-L_x4_GAN.pth "
|
26 |
+
self.model_name = "SwinIR 4x"
|
27 |
+
self.user_path = dirname
|
28 |
+
super().__init__()
|
29 |
+
scalers = []
|
30 |
+
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
31 |
+
for model in model_files:
|
32 |
+
if "http" in model:
|
33 |
+
name = self.model_name
|
34 |
+
else:
|
35 |
+
name = modelloader.friendly_name(model)
|
36 |
+
model_data = UpscalerData(name, model, self)
|
37 |
+
scalers.append(model_data)
|
38 |
+
self.scalers = scalers
|
39 |
+
|
40 |
+
def do_upscale(self, img, model_file):
|
41 |
+
model = self.load_model(model_file)
|
42 |
+
if model is None:
|
43 |
+
return img
|
44 |
+
model = model.to(device_swinir, dtype=devices.dtype)
|
45 |
+
img = upscale(img, model)
|
46 |
+
try:
|
47 |
+
torch.cuda.empty_cache()
|
48 |
+
except:
|
49 |
+
pass
|
50 |
+
return img
|
51 |
+
|
52 |
+
def load_model(self, path, scale=4):
|
53 |
+
if "http" in path:
|
54 |
+
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
|
55 |
+
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
|
56 |
+
else:
|
57 |
+
filename = path
|
58 |
+
if filename is None or not os.path.exists(filename):
|
59 |
+
return None
|
60 |
+
if filename.endswith(".v2.pth"):
|
61 |
+
model = net2(
|
62 |
+
upscale=scale,
|
63 |
+
in_chans=3,
|
64 |
+
img_size=64,
|
65 |
+
window_size=8,
|
66 |
+
img_range=1.0,
|
67 |
+
depths=[6, 6, 6, 6, 6, 6],
|
68 |
+
embed_dim=180,
|
69 |
+
num_heads=[6, 6, 6, 6, 6, 6],
|
70 |
+
mlp_ratio=2,
|
71 |
+
upsampler="nearest+conv",
|
72 |
+
resi_connection="1conv",
|
73 |
+
)
|
74 |
+
params = None
|
75 |
+
else:
|
76 |
+
model = net(
|
77 |
+
upscale=scale,
|
78 |
+
in_chans=3,
|
79 |
+
img_size=64,
|
80 |
+
window_size=8,
|
81 |
+
img_range=1.0,
|
82 |
+
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
83 |
+
embed_dim=240,
|
84 |
+
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
85 |
+
mlp_ratio=2,
|
86 |
+
upsampler="nearest+conv",
|
87 |
+
resi_connection="3conv",
|
88 |
+
)
|
89 |
+
params = "params_ema"
|
90 |
+
|
91 |
+
pretrained_model = torch.load(filename)
|
92 |
+
if params is not None:
|
93 |
+
model.load_state_dict(pretrained_model[params], strict=True)
|
94 |
+
else:
|
95 |
+
model.load_state_dict(pretrained_model, strict=True)
|
96 |
+
return model
|
97 |
+
|
98 |
+
|
99 |
+
def upscale(
|
100 |
+
img,
|
101 |
+
model,
|
102 |
+
tile=None,
|
103 |
+
tile_overlap=None,
|
104 |
+
window_size=8,
|
105 |
+
scale=4,
|
106 |
+
):
|
107 |
+
tile = tile or opts.SWIN_tile
|
108 |
+
tile_overlap = tile_overlap or opts.SWIN_tile_overlap
|
109 |
+
|
110 |
+
|
111 |
+
img = np.array(img)
|
112 |
+
img = img[:, :, ::-1]
|
113 |
+
img = np.moveaxis(img, 2, 0) / 255
|
114 |
+
img = torch.from_numpy(img).float()
|
115 |
+
img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
|
116 |
+
with torch.no_grad(), devices.autocast():
|
117 |
+
_, _, h_old, w_old = img.size()
|
118 |
+
h_pad = (h_old // window_size + 1) * window_size - h_old
|
119 |
+
w_pad = (w_old // window_size + 1) * window_size - w_old
|
120 |
+
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
121 |
+
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
122 |
+
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
123 |
+
output = output[..., : h_old * scale, : w_old * scale]
|
124 |
+
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
125 |
+
if output.ndim == 3:
|
126 |
+
output = np.transpose(
|
127 |
+
output[[2, 1, 0], :, :], (1, 2, 0)
|
128 |
+
) # CHW-RGB to HCW-BGR
|
129 |
+
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
130 |
+
return Image.fromarray(output, "RGB")
|
131 |
+
|
132 |
+
|
133 |
+
def inference(img, model, tile, tile_overlap, window_size, scale):
|
134 |
+
# test the image tile by tile
|
135 |
+
b, c, h, w = img.size()
|
136 |
+
tile = min(tile, h, w)
|
137 |
+
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
138 |
+
sf = scale
|
139 |
+
|
140 |
+
stride = tile - tile_overlap
|
141 |
+
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
142 |
+
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
143 |
+
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
|
144 |
+
W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
|
145 |
+
|
146 |
+
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
147 |
+
for h_idx in h_idx_list:
|
148 |
+
for w_idx in w_idx_list:
|
149 |
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
150 |
+
out_patch = model(in_patch)
|
151 |
+
out_patch_mask = torch.ones_like(out_patch)
|
152 |
+
|
153 |
+
E[
|
154 |
+
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
155 |
+
].add_(out_patch)
|
156 |
+
W[
|
157 |
+
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
158 |
+
].add_(out_patch_mask)
|
159 |
+
pbar.update(1)
|
160 |
+
output = E.div_(W)
|
161 |
+
|
162 |
+
return output
|
163 |
+
|
164 |
+
|
165 |
+
def on_ui_settings():
|
166 |
+
import gradio as gr
|
167 |
+
|
168 |
+
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
169 |
+
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
170 |
+
|
171 |
+
|
172 |
+
script_callbacks.on_ui_settings(on_ui_settings)
|
extensions-builtin/SwinIR/swinir_model_arch.py
ADDED
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
3 |
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint as checkpoint
|
11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
12 |
+
|
13 |
+
|
14 |
+
class Mlp(nn.Module):
|
15 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
16 |
+
super().__init__()
|
17 |
+
out_features = out_features or in_features
|
18 |
+
hidden_features = hidden_features or in_features
|
19 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
20 |
+
self.act = act_layer()
|
21 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
22 |
+
self.drop = nn.Dropout(drop)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
x = self.fc1(x)
|
26 |
+
x = self.act(x)
|
27 |
+
x = self.drop(x)
|
28 |
+
x = self.fc2(x)
|
29 |
+
x = self.drop(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
def window_partition(x, window_size):
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
x: (B, H, W, C)
|
37 |
+
window_size (int): window size
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
x: (B, H, W, C)
|
58 |
+
"""
|
59 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
60 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
61 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class WindowAttention(nn.Module):
|
66 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
67 |
+
It supports both of shifted and non-shifted window.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
dim (int): Number of input channels.
|
71 |
+
window_size (tuple[int]): The height and width of the window.
|
72 |
+
num_heads (int): Number of attention heads.
|
73 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
74 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
75 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
76 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
80 |
+
|
81 |
+
super().__init__()
|
82 |
+
self.dim = dim
|
83 |
+
self.window_size = window_size # Wh, Ww
|
84 |
+
self.num_heads = num_heads
|
85 |
+
head_dim = dim // num_heads
|
86 |
+
self.scale = qk_scale or head_dim ** -0.5
|
87 |
+
|
88 |
+
# define a parameter table of relative position bias
|
89 |
+
self.relative_position_bias_table = nn.Parameter(
|
90 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
91 |
+
|
92 |
+
# get pair-wise relative position index for each token inside the window
|
93 |
+
coords_h = torch.arange(self.window_size[0])
|
94 |
+
coords_w = torch.arange(self.window_size[1])
|
95 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
96 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
97 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
98 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
99 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
100 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
101 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
102 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
103 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
104 |
+
|
105 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
107 |
+
self.proj = nn.Linear(dim, dim)
|
108 |
+
|
109 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
110 |
+
|
111 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
112 |
+
self.softmax = nn.Softmax(dim=-1)
|
113 |
+
|
114 |
+
def forward(self, x, mask=None):
|
115 |
+
"""
|
116 |
+
Args:
|
117 |
+
x: input features with shape of (num_windows*B, N, C)
|
118 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
119 |
+
"""
|
120 |
+
B_, N, C = x.shape
|
121 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
122 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
123 |
+
|
124 |
+
q = q * self.scale
|
125 |
+
attn = (q @ k.transpose(-2, -1))
|
126 |
+
|
127 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
128 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
129 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
130 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
131 |
+
|
132 |
+
if mask is not None:
|
133 |
+
nW = mask.shape[0]
|
134 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
135 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
136 |
+
attn = self.softmax(attn)
|
137 |
+
else:
|
138 |
+
attn = self.softmax(attn)
|
139 |
+
|
140 |
+
attn = self.attn_drop(attn)
|
141 |
+
|
142 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
143 |
+
x = self.proj(x)
|
144 |
+
x = self.proj_drop(x)
|
145 |
+
return x
|
146 |
+
|
147 |
+
def extra_repr(self) -> str:
|
148 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
149 |
+
|
150 |
+
def flops(self, N):
|
151 |
+
# calculate flops for 1 window with token length of N
|
152 |
+
flops = 0
|
153 |
+
# qkv = self.qkv(x)
|
154 |
+
flops += N * self.dim * 3 * self.dim
|
155 |
+
# attn = (q @ k.transpose(-2, -1))
|
156 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
157 |
+
# x = (attn @ v)
|
158 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
159 |
+
# x = self.proj(x)
|
160 |
+
flops += N * self.dim * self.dim
|
161 |
+
return flops
|
162 |
+
|
163 |
+
|
164 |
+
class SwinTransformerBlock(nn.Module):
|
165 |
+
r""" Swin Transformer Block.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
dim (int): Number of input channels.
|
169 |
+
input_resolution (tuple[int]): Input resolution.
|
170 |
+
num_heads (int): Number of attention heads.
|
171 |
+
window_size (int): Window size.
|
172 |
+
shift_size (int): Shift size for SW-MSA.
|
173 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
174 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
175 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
176 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
177 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
178 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
179 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
180 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
184 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
185 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
186 |
+
super().__init__()
|
187 |
+
self.dim = dim
|
188 |
+
self.input_resolution = input_resolution
|
189 |
+
self.num_heads = num_heads
|
190 |
+
self.window_size = window_size
|
191 |
+
self.shift_size = shift_size
|
192 |
+
self.mlp_ratio = mlp_ratio
|
193 |
+
if min(self.input_resolution) <= self.window_size:
|
194 |
+
# if window size is larger than input resolution, we don't partition windows
|
195 |
+
self.shift_size = 0
|
196 |
+
self.window_size = min(self.input_resolution)
|
197 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
198 |
+
|
199 |
+
self.norm1 = norm_layer(dim)
|
200 |
+
self.attn = WindowAttention(
|
201 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
202 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
203 |
+
|
204 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
205 |
+
self.norm2 = norm_layer(dim)
|
206 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
207 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
208 |
+
|
209 |
+
if self.shift_size > 0:
|
210 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
211 |
+
else:
|
212 |
+
attn_mask = None
|
213 |
+
|
214 |
+
self.register_buffer("attn_mask", attn_mask)
|
215 |
+
|
216 |
+
def calculate_mask(self, x_size):
|
217 |
+
# calculate attention mask for SW-MSA
|
218 |
+
H, W = x_size
|
219 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
220 |
+
h_slices = (slice(0, -self.window_size),
|
221 |
+
slice(-self.window_size, -self.shift_size),
|
222 |
+
slice(-self.shift_size, None))
|
223 |
+
w_slices = (slice(0, -self.window_size),
|
224 |
+
slice(-self.window_size, -self.shift_size),
|
225 |
+
slice(-self.shift_size, None))
|
226 |
+
cnt = 0
|
227 |
+
for h in h_slices:
|
228 |
+
for w in w_slices:
|
229 |
+
img_mask[:, h, w, :] = cnt
|
230 |
+
cnt += 1
|
231 |
+
|
232 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
233 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
234 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
235 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
236 |
+
|
237 |
+
return attn_mask
|
238 |
+
|
239 |
+
def forward(self, x, x_size):
|
240 |
+
H, W = x_size
|
241 |
+
B, L, C = x.shape
|
242 |
+
# assert L == H * W, "input feature has wrong size"
|
243 |
+
|
244 |
+
shortcut = x
|
245 |
+
x = self.norm1(x)
|
246 |
+
x = x.view(B, H, W, C)
|
247 |
+
|
248 |
+
# cyclic shift
|
249 |
+
if self.shift_size > 0:
|
250 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
251 |
+
else:
|
252 |
+
shifted_x = x
|
253 |
+
|
254 |
+
# partition windows
|
255 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
256 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
257 |
+
|
258 |
+
# W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of window size)
|
259 |
+
if self.input_resolution == x_size:
|
260 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
261 |
+
else:
|
262 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
263 |
+
|
264 |
+
# merge windows
|
265 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
266 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
267 |
+
|
268 |
+
# reverse cyclic shift
|
269 |
+
if self.shift_size > 0:
|
270 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
271 |
+
else:
|
272 |
+
x = shifted_x
|
273 |
+
x = x.view(B, H * W, C)
|
274 |
+
|
275 |
+
# FFN
|
276 |
+
x = shortcut + self.drop_path(x)
|
277 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
278 |
+
|
279 |
+
return x
|
280 |
+
|
281 |
+
def extra_repr(self) -> str:
|
282 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
283 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
284 |
+
|
285 |
+
def flops(self):
|
286 |
+
flops = 0
|
287 |
+
H, W = self.input_resolution
|
288 |
+
# norm1
|
289 |
+
flops += self.dim * H * W
|
290 |
+
# W-MSA/SW-MSA
|
291 |
+
nW = H * W / self.window_size / self.window_size
|
292 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
293 |
+
# mlp
|
294 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
295 |
+
# norm2
|
296 |
+
flops += self.dim * H * W
|
297 |
+
return flops
|
298 |
+
|
299 |
+
|
300 |
+
class PatchMerging(nn.Module):
|
301 |
+
r""" Patch Merging Layer.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
305 |
+
dim (int): Number of input channels.
|
306 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
310 |
+
super().__init__()
|
311 |
+
self.input_resolution = input_resolution
|
312 |
+
self.dim = dim
|
313 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
314 |
+
self.norm = norm_layer(4 * dim)
|
315 |
+
|
316 |
+
def forward(self, x):
|
317 |
+
"""
|
318 |
+
x: B, H*W, C
|
319 |
+
"""
|
320 |
+
H, W = self.input_resolution
|
321 |
+
B, L, C = x.shape
|
322 |
+
assert L == H * W, "input feature has wrong size"
|
323 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
|
324 |
+
|
325 |
+
x = x.view(B, H, W, C)
|
326 |
+
|
327 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
328 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
329 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
330 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
331 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
332 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
333 |
+
|
334 |
+
x = self.norm(x)
|
335 |
+
x = self.reduction(x)
|
336 |
+
|
337 |
+
return x
|
338 |
+
|
339 |
+
def extra_repr(self) -> str:
|
340 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
341 |
+
|
342 |
+
def flops(self):
|
343 |
+
H, W = self.input_resolution
|
344 |
+
flops = H * W * self.dim
|
345 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
346 |
+
return flops
|
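A minimal usage sketch (not part of the commit) for the PatchMerging layer above: it concatenates each 2x2 neighborhood into 4*C channels and reduces them to 2*C, so the token count drops by a factor of 4 while the channel dimension doubles. The instantiation below is a hypothetical illustration, assuming torch and the class defined above are in scope.

import torch

H, W, C = 8, 8, 96
pm = PatchMerging(input_resolution=(H, W), dim=C)   # class defined just above
tokens = torch.randn(2, H * W, C)                    # (B, H*W, C)
print(pm(tokens).shape)                              # torch.Size([2, 16, 192]) -> (B, H/2*W/2, 2*C)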
347 |
+
|
348 |
+
|
349 |
+
class BasicLayer(nn.Module):
|
350 |
+
""" A basic Swin Transformer layer for one stage.
|
351 |
+
|
352 |
+
Args:
|
353 |
+
dim (int): Number of input channels.
|
354 |
+
input_resolution (tuple[int]): Input resolution.
|
355 |
+
depth (int): Number of blocks.
|
356 |
+
num_heads (int): Number of attention heads.
|
357 |
+
window_size (int): Local window size.
|
358 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
359 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
360 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
361 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
362 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
363 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
364 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
365 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
366 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
370 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
371 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
372 |
+
|
373 |
+
super().__init__()
|
374 |
+
self.dim = dim
|
375 |
+
self.input_resolution = input_resolution
|
376 |
+
self.depth = depth
|
377 |
+
self.use_checkpoint = use_checkpoint
|
378 |
+
|
379 |
+
# build blocks
|
380 |
+
self.blocks = nn.ModuleList([
|
381 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
382 |
+
num_heads=num_heads, window_size=window_size,
|
383 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
384 |
+
mlp_ratio=mlp_ratio,
|
385 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
386 |
+
drop=drop, attn_drop=attn_drop,
|
387 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
388 |
+
norm_layer=norm_layer)
|
389 |
+
for i in range(depth)])
|
390 |
+
|
391 |
+
# patch merging layer
|
392 |
+
if downsample is not None:
|
393 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
394 |
+
else:
|
395 |
+
self.downsample = None
|
396 |
+
|
397 |
+
def forward(self, x, x_size):
|
398 |
+
for blk in self.blocks:
|
399 |
+
if self.use_checkpoint:
|
400 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
401 |
+
else:
|
402 |
+
x = blk(x, x_size)
|
403 |
+
if self.downsample is not None:
|
404 |
+
x = self.downsample(x)
|
405 |
+
return x
|
406 |
+
|
407 |
+
def extra_repr(self) -> str:
|
408 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
409 |
+
|
410 |
+
def flops(self):
|
411 |
+
flops = 0
|
412 |
+
for blk in self.blocks:
|
413 |
+
flops += blk.flops()
|
414 |
+
if self.downsample is not None:
|
415 |
+
flops += self.downsample.flops()
|
416 |
+
return flops
|
417 |
+
|
418 |
+
|
419 |
+
class RSTB(nn.Module):
|
420 |
+
"""Residual Swin Transformer Block (RSTB).
|
421 |
+
|
422 |
+
Args:
|
423 |
+
dim (int): Number of input channels.
|
424 |
+
input_resolution (tuple[int]): Input resolution.
|
425 |
+
depth (int): Number of blocks.
|
426 |
+
num_heads (int): Number of attention heads.
|
427 |
+
window_size (int): Local window size.
|
428 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
429 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
430 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
431 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
432 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
433 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
434 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
435 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
436 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
437 |
+
img_size: Input image size.
|
438 |
+
patch_size: Patch size.
|
439 |
+
resi_connection: The convolutional block before residual connection.
|
440 |
+
"""
|
441 |
+
|
442 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
443 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
444 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
445 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
446 |
+
super(RSTB, self).__init__()
|
447 |
+
|
448 |
+
self.dim = dim
|
449 |
+
self.input_resolution = input_resolution
|
450 |
+
|
451 |
+
self.residual_group = BasicLayer(dim=dim,
|
452 |
+
input_resolution=input_resolution,
|
453 |
+
depth=depth,
|
454 |
+
num_heads=num_heads,
|
455 |
+
window_size=window_size,
|
456 |
+
mlp_ratio=mlp_ratio,
|
457 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
458 |
+
drop=drop, attn_drop=attn_drop,
|
459 |
+
drop_path=drop_path,
|
460 |
+
norm_layer=norm_layer,
|
461 |
+
downsample=downsample,
|
462 |
+
use_checkpoint=use_checkpoint)
|
463 |
+
|
464 |
+
if resi_connection == '1conv':
|
465 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
466 |
+
elif resi_connection == '3conv':
|
467 |
+
# to save parameters and memory
|
468 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
469 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
470 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
471 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
472 |
+
|
473 |
+
self.patch_embed = PatchEmbed(
|
474 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
475 |
+
norm_layer=None)
|
476 |
+
|
477 |
+
self.patch_unembed = PatchUnEmbed(
|
478 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
479 |
+
norm_layer=None)
|
480 |
+
|
481 |
+
def forward(self, x, x_size):
|
482 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
483 |
+
|
484 |
+
def flops(self):
|
485 |
+
flops = 0
|
486 |
+
flops += self.residual_group.flops()
|
487 |
+
H, W = self.input_resolution
|
488 |
+
flops += H * W * self.dim * self.dim * 9
|
489 |
+
flops += self.patch_embed.flops()
|
490 |
+
flops += self.patch_unembed.flops()
|
491 |
+
|
492 |
+
return flops
|
493 |
+
|
494 |
+
|
495 |
+
class PatchEmbed(nn.Module):
|
496 |
+
r""" Image to Patch Embedding
|
497 |
+
|
498 |
+
Args:
|
499 |
+
img_size (int): Image size. Default: 224.
|
500 |
+
patch_size (int): Patch token size. Default: 4.
|
501 |
+
in_chans (int): Number of input image channels. Default: 3.
|
502 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
503 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
504 |
+
"""
|
505 |
+
|
506 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
507 |
+
super().__init__()
|
508 |
+
img_size = to_2tuple(img_size)
|
509 |
+
patch_size = to_2tuple(patch_size)
|
510 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
511 |
+
self.img_size = img_size
|
512 |
+
self.patch_size = patch_size
|
513 |
+
self.patches_resolution = patches_resolution
|
514 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
515 |
+
|
516 |
+
self.in_chans = in_chans
|
517 |
+
self.embed_dim = embed_dim
|
518 |
+
|
519 |
+
if norm_layer is not None:
|
520 |
+
self.norm = norm_layer(embed_dim)
|
521 |
+
else:
|
522 |
+
self.norm = None
|
523 |
+
|
524 |
+
def forward(self, x):
|
525 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
526 |
+
if self.norm is not None:
|
527 |
+
x = self.norm(x)
|
528 |
+
return x
|
529 |
+
|
530 |
+
def flops(self):
|
531 |
+
flops = 0
|
532 |
+
H, W = self.img_size
|
533 |
+
if self.norm is not None:
|
534 |
+
flops += H * W * self.embed_dim
|
535 |
+
return flops
|
536 |
+
|
537 |
+
|
538 |
+
class PatchUnEmbed(nn.Module):
|
539 |
+
r""" Image to Patch Unembedding
|
540 |
+
|
541 |
+
Args:
|
542 |
+
img_size (int): Image size. Default: 224.
|
543 |
+
patch_size (int): Patch token size. Default: 4.
|
544 |
+
in_chans (int): Number of input image channels. Default: 3.
|
545 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
546 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
547 |
+
"""
|
548 |
+
|
549 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
550 |
+
super().__init__()
|
551 |
+
img_size = to_2tuple(img_size)
|
552 |
+
patch_size = to_2tuple(patch_size)
|
553 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
554 |
+
self.img_size = img_size
|
555 |
+
self.patch_size = patch_size
|
556 |
+
self.patches_resolution = patches_resolution
|
557 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
558 |
+
|
559 |
+
self.in_chans = in_chans
|
560 |
+
self.embed_dim = embed_dim
|
561 |
+
|
562 |
+
def forward(self, x, x_size):
|
563 |
+
B, HW, C = x.shape
|
564 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
565 |
+
return x
|
566 |
+
|
567 |
+
def flops(self):
|
568 |
+
flops = 0
|
569 |
+
return flops
|
570 |
+
|
571 |
+
|
572 |
+
class Upsample(nn.Sequential):
|
573 |
+
"""Upsample module.
|
574 |
+
|
575 |
+
Args:
|
576 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
577 |
+
num_feat (int): Channel number of intermediate features.
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(self, scale, num_feat):
|
581 |
+
m = []
|
582 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
583 |
+
for _ in range(int(math.log(scale, 2))):
|
584 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
585 |
+
m.append(nn.PixelShuffle(2))
|
586 |
+
elif scale == 3:
|
587 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
588 |
+
m.append(nn.PixelShuffle(3))
|
589 |
+
else:
|
590 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
591 |
+
super(Upsample, self).__init__(*m)
|
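A minimal sketch (not part of the commit) of how the Upsample module above behaves for a power-of-two scale: each conv + PixelShuffle stage doubles the spatial size, so scale=4 stacks two such stages. It assumes torch and the Upsample class defined above are in scope.

import torch

up = Upsample(scale=4, num_feat=64)       # two conv + PixelShuffle(2) stages
feat = torch.randn(1, 64, 8, 8)
print(up(feat).shape)                     # torch.Size([1, 64, 32, 32])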
592 |
+
|
593 |
+
|
594 |
+
class UpsampleOneStep(nn.Sequential):
|
595 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
596 |
+
Used in lightweight SR to save parameters.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
600 |
+
num_feat (int): Channel number of intermediate features.
|
601 |
+
|
602 |
+
"""
|
603 |
+
|
604 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
605 |
+
self.num_feat = num_feat
|
606 |
+
self.input_resolution = input_resolution
|
607 |
+
m = []
|
608 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
609 |
+
m.append(nn.PixelShuffle(scale))
|
610 |
+
super(UpsampleOneStep, self).__init__(*m)
|
611 |
+
|
612 |
+
def flops(self):
|
613 |
+
H, W = self.input_resolution
|
614 |
+
flops = H * W * self.num_feat * 3 * 9
|
615 |
+
return flops
|
616 |
+
|
617 |
+
|
618 |
+
class SwinIR(nn.Module):
|
619 |
+
r""" SwinIR
|
620 |
+
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
621 |
+
|
622 |
+
Args:
|
623 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
624 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
625 |
+
in_chans (int): Number of input image channels. Default: 3
|
626 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
627 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
628 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
629 |
+
window_size (int): Window size. Default: 7
|
630 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
631 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
632 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
633 |
+
drop_rate (float): Dropout rate. Default: 0
|
634 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
635 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
636 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
637 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
638 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
639 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
640 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
|
641 |
+
img_range: Image range. 1. or 255.
|
642 |
+
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
643 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
644 |
+
"""
|
645 |
+
|
646 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
647 |
+
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
648 |
+
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
649 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
650 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
651 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
652 |
+
**kwargs):
|
653 |
+
super(SwinIR, self).__init__()
|
654 |
+
num_in_ch = in_chans
|
655 |
+
num_out_ch = in_chans
|
656 |
+
num_feat = 64
|
657 |
+
self.img_range = img_range
|
658 |
+
if in_chans == 3:
|
659 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
660 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
661 |
+
else:
|
662 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
663 |
+
self.upscale = upscale
|
664 |
+
self.upsampler = upsampler
|
665 |
+
self.window_size = window_size
|
666 |
+
|
667 |
+
#####################################################################################################
|
668 |
+
################################### 1, shallow feature extraction ###################################
|
669 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
670 |
+
|
671 |
+
#####################################################################################################
|
672 |
+
################################### 2, deep feature extraction ######################################
|
673 |
+
self.num_layers = len(depths)
|
674 |
+
self.embed_dim = embed_dim
|
675 |
+
self.ape = ape
|
676 |
+
self.patch_norm = patch_norm
|
677 |
+
self.num_features = embed_dim
|
678 |
+
self.mlp_ratio = mlp_ratio
|
679 |
+
|
680 |
+
# split image into non-overlapping patches
|
681 |
+
self.patch_embed = PatchEmbed(
|
682 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
683 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
684 |
+
num_patches = self.patch_embed.num_patches
|
685 |
+
patches_resolution = self.patch_embed.patches_resolution
|
686 |
+
self.patches_resolution = patches_resolution
|
687 |
+
|
688 |
+
# merge non-overlapping patches into image
|
689 |
+
self.patch_unembed = PatchUnEmbed(
|
690 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
691 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
692 |
+
|
693 |
+
# absolute position embedding
|
694 |
+
if self.ape:
|
695 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
696 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
697 |
+
|
698 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
699 |
+
|
700 |
+
# stochastic depth
|
701 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
702 |
+
|
703 |
+
# build Residual Swin Transformer blocks (RSTB)
|
704 |
+
self.layers = nn.ModuleList()
|
705 |
+
for i_layer in range(self.num_layers):
|
706 |
+
layer = RSTB(dim=embed_dim,
|
707 |
+
input_resolution=(patches_resolution[0],
|
708 |
+
patches_resolution[1]),
|
709 |
+
depth=depths[i_layer],
|
710 |
+
num_heads=num_heads[i_layer],
|
711 |
+
window_size=window_size,
|
712 |
+
mlp_ratio=self.mlp_ratio,
|
713 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
714 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
715 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
716 |
+
norm_layer=norm_layer,
|
717 |
+
downsample=None,
|
718 |
+
use_checkpoint=use_checkpoint,
|
719 |
+
img_size=img_size,
|
720 |
+
patch_size=patch_size,
|
721 |
+
resi_connection=resi_connection
|
722 |
+
|
723 |
+
)
|
724 |
+
self.layers.append(layer)
|
725 |
+
self.norm = norm_layer(self.num_features)
|
726 |
+
|
727 |
+
# build the last conv layer in deep feature extraction
|
728 |
+
if resi_connection == '1conv':
|
729 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
730 |
+
elif resi_connection == '3conv':
|
731 |
+
# to save parameters and memory
|
732 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
733 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
734 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
735 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
736 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
737 |
+
|
738 |
+
#####################################################################################################
|
739 |
+
################################ 3, high quality image reconstruction ################################
|
740 |
+
if self.upsampler == 'pixelshuffle':
|
741 |
+
# for classical SR
|
742 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
743 |
+
nn.LeakyReLU(inplace=True))
|
744 |
+
self.upsample = Upsample(upscale, num_feat)
|
745 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
746 |
+
elif self.upsampler == 'pixelshuffledirect':
|
747 |
+
# for lightweight SR (to save parameters)
|
748 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
749 |
+
(patches_resolution[0], patches_resolution[1]))
|
750 |
+
elif self.upsampler == 'nearest+conv':
|
751 |
+
# for real-world SR (less artifacts)
|
752 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
753 |
+
nn.LeakyReLU(inplace=True))
|
754 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
755 |
+
if self.upscale == 4:
|
756 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
757 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
758 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
759 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
760 |
+
else:
|
761 |
+
# for image denoising and JPEG compression artifact reduction
|
762 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
763 |
+
|
764 |
+
self.apply(self._init_weights)
|
765 |
+
|
766 |
+
def _init_weights(self, m):
|
767 |
+
if isinstance(m, nn.Linear):
|
768 |
+
trunc_normal_(m.weight, std=.02)
|
769 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
770 |
+
nn.init.constant_(m.bias, 0)
|
771 |
+
elif isinstance(m, nn.LayerNorm):
|
772 |
+
nn.init.constant_(m.bias, 0)
|
773 |
+
nn.init.constant_(m.weight, 1.0)
|
774 |
+
|
775 |
+
@torch.jit.ignore
|
776 |
+
def no_weight_decay(self):
|
777 |
+
return {'absolute_pos_embed'}
|
778 |
+
|
779 |
+
@torch.jit.ignore
|
780 |
+
def no_weight_decay_keywords(self):
|
781 |
+
return {'relative_position_bias_table'}
|
782 |
+
|
783 |
+
def check_image_size(self, x):
|
784 |
+
_, _, h, w = x.size()
|
785 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
786 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
787 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
788 |
+
return x
|
789 |
+
|
790 |
+
def forward_features(self, x):
|
791 |
+
x_size = (x.shape[2], x.shape[3])
|
792 |
+
x = self.patch_embed(x)
|
793 |
+
if self.ape:
|
794 |
+
x = x + self.absolute_pos_embed
|
795 |
+
x = self.pos_drop(x)
|
796 |
+
|
797 |
+
for layer in self.layers:
|
798 |
+
x = layer(x, x_size)
|
799 |
+
|
800 |
+
x = self.norm(x) # B L C
|
801 |
+
x = self.patch_unembed(x, x_size)
|
802 |
+
|
803 |
+
return x
|
804 |
+
|
805 |
+
def forward(self, x):
|
806 |
+
H, W = x.shape[2:]
|
807 |
+
x = self.check_image_size(x)
|
808 |
+
|
809 |
+
self.mean = self.mean.type_as(x)
|
810 |
+
x = (x - self.mean) * self.img_range
|
811 |
+
|
812 |
+
if self.upsampler == 'pixelshuffle':
|
813 |
+
# for classical SR
|
814 |
+
x = self.conv_first(x)
|
815 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
816 |
+
x = self.conv_before_upsample(x)
|
817 |
+
x = self.conv_last(self.upsample(x))
|
818 |
+
elif self.upsampler == 'pixelshuffledirect':
|
819 |
+
# for lightweight SR
|
820 |
+
x = self.conv_first(x)
|
821 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
822 |
+
x = self.upsample(x)
|
823 |
+
elif self.upsampler == 'nearest+conv':
|
824 |
+
# for real-world SR
|
825 |
+
x = self.conv_first(x)
|
826 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
827 |
+
x = self.conv_before_upsample(x)
|
828 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
829 |
+
if self.upscale == 4:
|
830 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
831 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
832 |
+
else:
|
833 |
+
# for image denoising and JPEG compression artifact reduction
|
834 |
+
x_first = self.conv_first(x)
|
835 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
836 |
+
x = x + self.conv_last(res)
|
837 |
+
|
838 |
+
x = x / self.img_range + self.mean
|
839 |
+
|
840 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
841 |
+
|
842 |
+
def flops(self):
|
843 |
+
flops = 0
|
844 |
+
H, W = self.patches_resolution
|
845 |
+
flops += H * W * 3 * self.embed_dim * 9
|
846 |
+
flops += self.patch_embed.flops()
|
847 |
+
for i, layer in enumerate(self.layers):
|
848 |
+
flops += layer.flops()
|
849 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
850 |
+
flops += self.upsample.flops()
|
851 |
+
return flops
|
852 |
+
|
853 |
+
|
854 |
+
if __name__ == '__main__':
|
855 |
+
upscale = 4
|
856 |
+
window_size = 8
|
857 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
858 |
+
width = (720 // upscale // window_size + 1) * window_size
|
859 |
+
model = SwinIR(upscale=2, img_size=(height, width),
|
860 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
861 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
862 |
+
print(model)
|
863 |
+
print(height, width, model.flops() / 1e9)
|
864 |
+
|
865 |
+
x = torch.randn((1, 3, height, width))
|
866 |
+
x = model(x)
|
867 |
+
print(x.shape)
|
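A minimal sketch (not part of the commit) of the window-partition round trip that the blocks in this file rely on: window_partition cuts a (B, H, W, C) feature map into non-overlapping window_size x window_size tiles and window_reverse reassembles them losslessly. It assumes torch and the two helpers used above are in scope.

import torch

B, H, W, C, ws = 2, 16, 16, 8, 8                     # ws divides H and W
x = torch.randn(B, H, W, C)

windows = window_partition(x, ws)                    # (B * H/ws * W/ws, ws, ws, C)
restored = window_reverse(windows, ws, H, W)         # back to (B, H, W, C)
print(windows.shape, torch.equal(x, restored))       # torch.Size([8, 8, 8, 8]) True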
extensions-builtin/SwinIR/swinir_model_arch_v2.py
ADDED
@@ -0,0 +1,1017 @@
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
|
3 |
+
# Written by Conde and Choi et al.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
13 |
+
|
14 |
+
|
15 |
+
class Mlp(nn.Module):
|
16 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
17 |
+
super().__init__()
|
18 |
+
out_features = out_features or in_features
|
19 |
+
hidden_features = hidden_features or in_features
|
20 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
21 |
+
self.act = act_layer()
|
22 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
23 |
+
self.drop = nn.Dropout(drop)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.fc1(x)
|
27 |
+
x = self.act(x)
|
28 |
+
x = self.drop(x)
|
29 |
+
x = self.fc2(x)
|
30 |
+
x = self.drop(x)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
def window_partition(x, window_size):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
x: (B, H, W, C)
|
38 |
+
window_size (int): window size
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
Returns:
|
56 |
+
x: (B, H, W, C)
|
57 |
+
"""
|
58 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
59 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
60 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
61 |
+
return x
|
62 |
+
|
63 |
+
class WindowAttention(nn.Module):
|
64 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
65 |
+
It supports both of shifted and non-shifted window.
|
66 |
+
Args:
|
67 |
+
dim (int): Number of input channels.
|
68 |
+
window_size (tuple[int]): The height and width of the window.
|
69 |
+
num_heads (int): Number of attention heads.
|
70 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
71 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
72 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
73 |
+
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
77 |
+
pretrained_window_size=[0, 0]):
|
78 |
+
|
79 |
+
super().__init__()
|
80 |
+
self.dim = dim
|
81 |
+
self.window_size = window_size # Wh, Ww
|
82 |
+
self.pretrained_window_size = pretrained_window_size
|
83 |
+
self.num_heads = num_heads
|
84 |
+
|
85 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
|
86 |
+
|
87 |
+
# mlp to generate continuous relative position bias
|
88 |
+
self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
|
89 |
+
nn.ReLU(inplace=True),
|
90 |
+
nn.Linear(512, num_heads, bias=False))
|
91 |
+
|
92 |
+
# get relative_coords_table
|
93 |
+
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
|
94 |
+
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
|
95 |
+
relative_coords_table = torch.stack(
|
96 |
+
torch.meshgrid([relative_coords_h,
|
97 |
+
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
|
98 |
+
if pretrained_window_size[0] > 0:
|
99 |
+
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
|
100 |
+
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
|
101 |
+
else:
|
102 |
+
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
|
103 |
+
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
|
104 |
+
relative_coords_table *= 8 # normalize to -8, 8
|
105 |
+
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
|
106 |
+
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
|
107 |
+
|
108 |
+
self.register_buffer("relative_coords_table", relative_coords_table)
|
109 |
+
|
110 |
+
# get pair-wise relative position index for each token inside the window
|
111 |
+
coords_h = torch.arange(self.window_size[0])
|
112 |
+
coords_w = torch.arange(self.window_size[1])
|
113 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
114 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
115 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
116 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
117 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
118 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
119 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
120 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
121 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
122 |
+
|
123 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
124 |
+
if qkv_bias:
|
125 |
+
self.q_bias = nn.Parameter(torch.zeros(dim))
|
126 |
+
self.v_bias = nn.Parameter(torch.zeros(dim))
|
127 |
+
else:
|
128 |
+
self.q_bias = None
|
129 |
+
self.v_bias = None
|
130 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
131 |
+
self.proj = nn.Linear(dim, dim)
|
132 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
133 |
+
self.softmax = nn.Softmax(dim=-1)
|
134 |
+
|
135 |
+
def forward(self, x, mask=None):
|
136 |
+
"""
|
137 |
+
Args:
|
138 |
+
x: input features with shape of (num_windows*B, N, C)
|
139 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
140 |
+
"""
|
141 |
+
B_, N, C = x.shape
|
142 |
+
qkv_bias = None
|
143 |
+
if self.q_bias is not None:
|
144 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
145 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
146 |
+
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
147 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
148 |
+
|
149 |
+
# cosine attention
|
150 |
+
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
|
151 |
+
logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
|
152 |
+
attn = attn * logit_scale
|
153 |
+
|
154 |
+
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
|
155 |
+
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
156 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
157 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
158 |
+
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
|
159 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
160 |
+
|
161 |
+
if mask is not None:
|
162 |
+
nW = mask.shape[0]
|
163 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
164 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
165 |
+
attn = self.softmax(attn)
|
166 |
+
else:
|
167 |
+
attn = self.softmax(attn)
|
168 |
+
|
169 |
+
attn = self.attn_drop(attn)
|
170 |
+
|
171 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
172 |
+
x = self.proj(x)
|
173 |
+
x = self.proj_drop(x)
|
174 |
+
return x
|
175 |
+
|
176 |
+
def extra_repr(self) -> str:
|
177 |
+
return f'dim={self.dim}, window_size={self.window_size}, ' \
|
178 |
+
f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
|
179 |
+
|
180 |
+
def flops(self, N):
|
181 |
+
# calculate flops for 1 window with token length of N
|
182 |
+
flops = 0
|
183 |
+
# qkv = self.qkv(x)
|
184 |
+
flops += N * self.dim * 3 * self.dim
|
185 |
+
# attn = (q @ k.transpose(-2, -1))
|
186 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
187 |
+
# x = (attn @ v)
|
188 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
189 |
+
# x = self.proj(x)
|
190 |
+
flops += N * self.dim * self.dim
|
191 |
+
return flops
|
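A minimal sketch (not part of the commit) of the cosine attention used by the WindowAttention above: instead of the fixed head_dim ** -0.5 scaling of the v1 file, scores are cosine similarities between q and k, rescaled by a learnable per-head logit scale that is clamped before exponentiation. The shapes and values below are hypothetical; it assumes torch is in scope.

import torch
import torch.nn.functional as F

num_heads, N, head_dim = 6, 64, 30
q = torch.randn(1, num_heads, N, head_dim)
k = torch.randn(1, num_heads, N, head_dim)
logit_scale = torch.log(10 * torch.ones(num_heads, 1, 1))      # learnable parameter in the module

attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
attn = attn * torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
print(attn.shape)                                              # torch.Size([1, 6, 64, 64])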
192 |
+
|
193 |
+
class SwinTransformerBlock(nn.Module):
|
194 |
+
r""" Swin Transformer Block.
|
195 |
+
Args:
|
196 |
+
dim (int): Number of input channels.
|
197 |
+
input_resolution (tuple[int]): Input resolution.
|
198 |
+
num_heads (int): Number of attention heads.
|
199 |
+
window_size (int): Window size.
|
200 |
+
shift_size (int): Shift size for SW-MSA.
|
201 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
202 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
203 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
204 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
205 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
206 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
207 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
208 |
+
pretrained_window_size (int): Window size in pre-training.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
212 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
213 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
|
214 |
+
super().__init__()
|
215 |
+
self.dim = dim
|
216 |
+
self.input_resolution = input_resolution
|
217 |
+
self.num_heads = num_heads
|
218 |
+
self.window_size = window_size
|
219 |
+
self.shift_size = shift_size
|
220 |
+
self.mlp_ratio = mlp_ratio
|
221 |
+
if min(self.input_resolution) <= self.window_size:
|
222 |
+
# if window size is larger than input resolution, we don't partition windows
|
223 |
+
self.shift_size = 0
|
224 |
+
self.window_size = min(self.input_resolution)
|
225 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
|
226 |
+
|
227 |
+
self.norm1 = norm_layer(dim)
|
228 |
+
self.attn = WindowAttention(
|
229 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
230 |
+
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
231 |
+
pretrained_window_size=to_2tuple(pretrained_window_size))
|
232 |
+
|
233 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
234 |
+
self.norm2 = norm_layer(dim)
|
235 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
236 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
237 |
+
|
238 |
+
if self.shift_size > 0:
|
239 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
240 |
+
else:
|
241 |
+
attn_mask = None
|
242 |
+
|
243 |
+
self.register_buffer("attn_mask", attn_mask)
|
244 |
+
|
245 |
+
def calculate_mask(self, x_size):
|
246 |
+
# calculate attention mask for SW-MSA
|
247 |
+
H, W = x_size
|
248 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
249 |
+
h_slices = (slice(0, -self.window_size),
|
250 |
+
slice(-self.window_size, -self.shift_size),
|
251 |
+
slice(-self.shift_size, None))
|
252 |
+
w_slices = (slice(0, -self.window_size),
|
253 |
+
slice(-self.window_size, -self.shift_size),
|
254 |
+
slice(-self.shift_size, None))
|
255 |
+
cnt = 0
|
256 |
+
for h in h_slices:
|
257 |
+
for w in w_slices:
|
258 |
+
img_mask[:, h, w, :] = cnt
|
259 |
+
cnt += 1
|
260 |
+
|
261 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
262 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
263 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
264 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
265 |
+
|
266 |
+
return attn_mask
|
267 |
+
|
268 |
+
def forward(self, x, x_size):
|
269 |
+
H, W = x_size
|
270 |
+
B, L, C = x.shape
|
271 |
+
#assert L == H * W, "input feature has wrong size"
|
272 |
+
|
273 |
+
shortcut = x
|
274 |
+
x = x.view(B, H, W, C)
|
275 |
+
|
276 |
+
# cyclic shift
|
277 |
+
if self.shift_size > 0:
|
278 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
279 |
+
else:
|
280 |
+
shifted_x = x
|
281 |
+
|
282 |
+
# partition windows
|
283 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
284 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
285 |
+
|
286 |
+
# W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of window size)
|
287 |
+
if self.input_resolution == x_size:
|
288 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
289 |
+
else:
|
290 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
291 |
+
|
292 |
+
# merge windows
|
293 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
294 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
295 |
+
|
296 |
+
# reverse cyclic shift
|
297 |
+
if self.shift_size > 0:
|
298 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
299 |
+
else:
|
300 |
+
x = shifted_x
|
301 |
+
x = x.view(B, H * W, C)
|
302 |
+
x = shortcut + self.drop_path(self.norm1(x))
|
303 |
+
|
304 |
+
# FFN
|
305 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
306 |
+
|
307 |
+
return x
|
308 |
+
|
309 |
+
def extra_repr(self) -> str:
|
310 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
311 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
312 |
+
|
313 |
+
def flops(self):
|
314 |
+
flops = 0
|
315 |
+
H, W = self.input_resolution
|
316 |
+
# norm1
|
317 |
+
flops += self.dim * H * W
|
318 |
+
# W-MSA/SW-MSA
|
319 |
+
nW = H * W / self.window_size / self.window_size
|
320 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
321 |
+
# mlp
|
322 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
323 |
+
# norm2
|
324 |
+
flops += self.dim * H * W
|
325 |
+
return flops
|
326 |
+
|
327 |
+
class PatchMerging(nn.Module):
|
328 |
+
r""" Patch Merging Layer.
|
329 |
+
Args:
|
330 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
331 |
+
dim (int): Number of input channels.
|
332 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
333 |
+
"""
|
334 |
+
|
335 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
336 |
+
super().__init__()
|
337 |
+
self.input_resolution = input_resolution
|
338 |
+
self.dim = dim
|
339 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
340 |
+
self.norm = norm_layer(2 * dim)
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
"""
|
344 |
+
x: B, H*W, C
|
345 |
+
"""
|
346 |
+
H, W = self.input_resolution
|
347 |
+
B, L, C = x.shape
|
348 |
+
assert L == H * W, "input feature has wrong size"
|
349 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
|
350 |
+
|
351 |
+
x = x.view(B, H, W, C)
|
352 |
+
|
353 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
354 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
355 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
356 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
357 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
358 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
359 |
+
|
360 |
+
x = self.reduction(x)
|
361 |
+
x = self.norm(x)
|
362 |
+
|
363 |
+
return x
|
364 |
+
|
365 |
+
def extra_repr(self) -> str:
|
366 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
367 |
+
|
368 |
+
def flops(self):
|
369 |
+
H, W = self.input_resolution
|
370 |
+
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
371 |
+
flops += H * W * self.dim // 2
|
372 |
+
return flops
|
373 |
+
|
374 |
+
class BasicLayer(nn.Module):
|
375 |
+
""" A basic Swin Transformer layer for one stage.
|
376 |
+
Args:
|
377 |
+
dim (int): Number of input channels.
|
378 |
+
input_resolution (tuple[int]): Input resolution.
|
379 |
+
depth (int): Number of blocks.
|
380 |
+
num_heads (int): Number of attention heads.
|
381 |
+
window_size (int): Local window size.
|
382 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
383 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
384 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
385 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
386 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
387 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
388 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
389 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
390 |
+
pretrained_window_size (int): Local window size in pre-training.
|
391 |
+
"""
|
392 |
+
|
393 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
394 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
395 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
396 |
+
pretrained_window_size=0):
|
397 |
+
|
398 |
+
super().__init__()
|
399 |
+
self.dim = dim
|
400 |
+
self.input_resolution = input_resolution
|
401 |
+
self.depth = depth
|
402 |
+
self.use_checkpoint = use_checkpoint
|
403 |
+
|
404 |
+
# build blocks
|
405 |
+
self.blocks = nn.ModuleList([
|
406 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
407 |
+
num_heads=num_heads, window_size=window_size,
|
408 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
409 |
+
mlp_ratio=mlp_ratio,
|
410 |
+
qkv_bias=qkv_bias,
|
411 |
+
drop=drop, attn_drop=attn_drop,
|
412 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
413 |
+
norm_layer=norm_layer,
|
414 |
+
pretrained_window_size=pretrained_window_size)
|
415 |
+
for i in range(depth)])
|
416 |
+
|
417 |
+
# patch merging layer
|
418 |
+
if downsample is not None:
|
419 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
420 |
+
else:
|
421 |
+
self.downsample = None
|
422 |
+
|
423 |
+
def forward(self, x, x_size):
|
424 |
+
for blk in self.blocks:
|
425 |
+
if self.use_checkpoint:
|
426 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
427 |
+
else:
|
428 |
+
x = blk(x, x_size)
|
429 |
+
if self.downsample is not None:
|
430 |
+
x = self.downsample(x)
|
431 |
+
return x
|
432 |
+
|
433 |
+
def extra_repr(self) -> str:
|
434 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
435 |
+
|
436 |
+
def flops(self):
|
437 |
+
flops = 0
|
438 |
+
for blk in self.blocks:
|
439 |
+
flops += blk.flops()
|
440 |
+
if self.downsample is not None:
|
441 |
+
flops += self.downsample.flops()
|
442 |
+
return flops
|
443 |
+
|
444 |
+
def _init_respostnorm(self):
|
445 |
+
for blk in self.blocks:
|
446 |
+
nn.init.constant_(blk.norm1.bias, 0)
|
447 |
+
nn.init.constant_(blk.norm1.weight, 0)
|
448 |
+
nn.init.constant_(blk.norm2.bias, 0)
|
449 |
+
nn.init.constant_(blk.norm2.weight, 0)
|
450 |
+
|
451 |
+
class PatchEmbed(nn.Module):
|
452 |
+
r""" Image to Patch Embedding
|
453 |
+
Args:
|
454 |
+
img_size (int): Image size. Default: 224.
|
455 |
+
patch_size (int): Patch token size. Default: 4.
|
456 |
+
in_chans (int): Number of input image channels. Default: 3.
|
457 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
458 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
459 |
+
"""
|
460 |
+
|
461 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
462 |
+
super().__init__()
|
463 |
+
img_size = to_2tuple(img_size)
|
464 |
+
patch_size = to_2tuple(patch_size)
|
465 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
466 |
+
self.img_size = img_size
|
467 |
+
self.patch_size = patch_size
|
468 |
+
self.patches_resolution = patches_resolution
|
469 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
470 |
+
|
471 |
+
self.in_chans = in_chans
|
472 |
+
self.embed_dim = embed_dim
|
473 |
+
|
474 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
475 |
+
if norm_layer is not None:
|
476 |
+
self.norm = norm_layer(embed_dim)
|
477 |
+
else:
|
478 |
+
self.norm = None
|
479 |
+
|
480 |
+
def forward(self, x):
|
481 |
+
B, C, H, W = x.shape
|
482 |
+
# FIXME look at relaxing size constraints
|
483 |
+
# assert H == self.img_size[0] and W == self.img_size[1],
|
484 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
485 |
+
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops

class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B Ph*Pw C
        return x

    def flops(self):
        flops = 0
        return flops


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)

class Upsample_hf(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample_hf, self).__init__(*m)


class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
       Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = []
        m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
        m.append(nn.PixelShuffle(scale))
        super(UpsampleOneStep, self).__init__(*m)

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.num_feat * 3 * 9
        return flops


class Swin2SR(nn.Module):
    r""" Swin2SR
        A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        super(Swin2SR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection

                         )
            self.layers.append(layer)

        if self.upsampler == 'pixelshuffle_hf':
            self.layers_hf = nn.ModuleList()
            for i_layer in range(self.num_layers):
                layer = RSTB(dim=embed_dim,
                             input_resolution=(patches_resolution[0],
                                               patches_resolution[1]),
                             depth=depths[i_layer],
                             num_heads=num_heads[i_layer],
                             window_size=window_size,
                             mlp_ratio=self.mlp_ratio,
                             qkv_bias=qkv_bias,
                             drop=drop_rate, attn_drop=attn_drop_rate,
                             drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                             norm_layer=norm_layer,
                             downsample=None,
                             use_checkpoint=use_checkpoint,
                             img_size=img_size,
                             patch_size=patch_size,
                             resi_connection=resi_connection

                             )
                self.layers_hf.append(layer)

        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffle_aux':
            self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.conv_after_aux = nn.Sequential(
                nn.Conv2d(3, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        elif self.upsampler == 'pixelshuffle_hf':
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.upsample_hf = Upsample_hf(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
                                               nn.LeakyReLU(inplace=True))
            self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
            self.conv_before_upsample_hf = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward_features_hf(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers_hf:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffle_aux':
            bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
            bicubic = self.conv_bicubic(bicubic)
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            aux = self.conv_aux(x)  # b, 3, LR_H, LR_W
            x = self.conv_after_aux(aux)
            x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
            x = self.conv_last(x)
            aux = aux / self.img_range + self.mean
        elif self.upsampler == 'pixelshuffle_hf':
            # for classical SR with HF
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x_before = self.conv_before_upsample(x)
            x_out = self.conv_last(self.upsample(x_before))

            x_hf = self.conv_first_hf(x_before)
            x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
            x_hf = self.conv_before_upsample_hf(x_hf)
            x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
            x = x_out + x_hf
            x_hf = x_hf / self.img_range + self.mean

        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean
        if self.upsampler == "pixelshuffle_aux":
            return x[:, :, :H*self.upscale, :W*self.upscale], aux

        elif self.upsampler == "pixelshuffle_hf":
            x_out = x_out / self.img_range + self.mean
            return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]

        else:
            return x[:, :, :H*self.upscale, :W*self.upscale]

    def flops(self):
        flops = 0
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
        return flops


if __name__ == '__main__':
    upscale = 4
    window_size = 8
    height = (1024 // upscale // window_size + 1) * window_size
    width = (720 // upscale // window_size + 1) * window_size
    model = Swin2SR(upscale=2, img_size=(height, width),
                    window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
                    embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
    print(model)
    print(height, width, model.flops() / 1e9)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
ADDED
@@ -0,0 +1,107 @@
// Stable Diffusion WebUI - Bracket checker
// Version 1.0
// By Hingashi no Florin/Bwin4L
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.

function checkBrackets(evt) {
    textArea = evt.target;
    tabName = evt.target.parentElement.parentElement.id.split("_")[0];
    counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');

    promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';

    errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
    errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
    errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';

    openBracketRegExp = /\(/g;
    closeBracketRegExp = /\)/g;

    openSquareBracketRegExp = /\[/g;
    closeSquareBracketRegExp = /\]/g;

    openCurlyBracketRegExp = /\{/g;
    closeCurlyBracketRegExp = /\}/g;

    totalOpenBracketMatches = 0;
    totalCloseBracketMatches = 0;
    totalOpenSquareBracketMatches = 0;
    totalCloseSquareBracketMatches = 0;
    totalOpenCurlyBracketMatches = 0;
    totalCloseCurlyBracketMatches = 0;

    openBracketMatches = textArea.value.match(openBracketRegExp);
    if(openBracketMatches) {
        totalOpenBracketMatches = openBracketMatches.length;
    }

    closeBracketMatches = textArea.value.match(closeBracketRegExp);
    if(closeBracketMatches) {
        totalCloseBracketMatches = closeBracketMatches.length;
    }

    openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
    if(openSquareBracketMatches) {
        totalOpenSquareBracketMatches = openSquareBracketMatches.length;
    }

    closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
    if(closeSquareBracketMatches) {
        totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
    }

    openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
    if(openCurlyBracketMatches) {
        totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
    }

    closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
    if(closeCurlyBracketMatches) {
        totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
    }

    if(totalOpenBracketMatches != totalCloseBracketMatches) {
        if(!counterElt.title.includes(errorStringParen)) {
            counterElt.title += errorStringParen;
        }
    } else {
        counterElt.title = counterElt.title.replace(errorStringParen, '');
    }

    if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
        if(!counterElt.title.includes(errorStringSquare)) {
            counterElt.title += errorStringSquare;
        }
    } else {
        counterElt.title = counterElt.title.replace(errorStringSquare, '');
    }

    if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
        if(!counterElt.title.includes(errorStringCurly)) {
            counterElt.title += errorStringCurly;
        }
    } else {
        counterElt.title = counterElt.title.replace(errorStringCurly, '');
    }

    if(counterElt.title != '') {
        counterElt.style = 'color: #FF5555;';
    } else {
        counterElt.style = '';
    }
}

var shadowRootLoaded = setInterval(function() {
    var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
    if(shadowTextArea.length < 1) {
        return false;
    }

    clearInterval(shadowRootLoaded);

    document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
    document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
    document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
    document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
}, 1000);
javascript/aspectRatioOverlay.js
ADDED
@@ -0,0 +1,108 @@
let currentWidth = null;
let currentHeight = null;
let arFrameTimeout = setTimeout(function(){},0);

function dimensionChange(e, is_width, is_height){

    if(is_width){
        currentWidth = e.target.value*1.0
    }
    if(is_height){
        currentHeight = e.target.value*1.0
    }

    var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))

    if(!inImg2img){
        return;
    }

    var targetElement = null;

    var tabIndex = get_tab_index('mode_img2img')
    if(tabIndex == 0){
        targetElement = gradioApp().querySelector('div[data-testid=image] img');
    } else if(tabIndex == 1){
        targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
    }

    if(targetElement){

        var arPreviewRect = gradioApp().querySelector('#imageARPreview');
        if(!arPreviewRect){
            arPreviewRect = document.createElement('div')
            arPreviewRect.id = "imageARPreview";
            gradioApp().getRootNode().appendChild(arPreviewRect)
        }


        var viewportOffset = targetElement.getBoundingClientRect();

        viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )

        scaledx = targetElement.naturalWidth*viewportscale
        scaledy = targetElement.naturalHeight*viewportscale

        cleintRectTop = (viewportOffset.top+window.scrollY)
        cleintRectLeft = (viewportOffset.left+window.scrollX)
        cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
        cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)

        viewRectTop = cleintRectCentreY-(scaledy/2)
        viewRectLeft = cleintRectCentreX-(scaledx/2)
        arRectWidth = scaledx
        arRectHeight = scaledy

        arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
        arscaledx = currentWidth*arscale
        arscaledy = currentHeight*arscale

        arRectTop = cleintRectCentreY-(arscaledy/2)
        arRectLeft = cleintRectCentreX-(arscaledx/2)
        arRectWidth = arscaledx
        arRectHeight = arscaledy

        arPreviewRect.style.top = arRectTop+'px';
        arPreviewRect.style.left = arRectLeft+'px';
        arPreviewRect.style.width = arRectWidth+'px';
        arPreviewRect.style.height = arRectHeight+'px';

        clearTimeout(arFrameTimeout);
        arFrameTimeout = setTimeout(function(){
            arPreviewRect.style.display = 'none';
        },2000);

        arPreviewRect.style.display = 'block';

    }

}


onUiUpdate(function(){
    var arPreviewRect = gradioApp().querySelector('#imageARPreview');
    if(arPreviewRect){
        arPreviewRect.style.display = 'none';
    }
    var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
    if(inImg2img){
        let inputs = gradioApp().querySelectorAll('input');
        inputs.forEach(function(e){
            var is_width = e.parentElement.id == "img2img_width"
            var is_height = e.parentElement.id == "img2img_height"

            if((is_width || is_height) && !e.classList.contains('scrollwatch')){
                e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
                e.classList.add('scrollwatch')
            }
            if(is_width){
                currentWidth = e.value*1.0
            }
            if(is_height){
                currentHeight = e.value*1.0
            }
        })
    }
});
javascript/contextMenus.js
ADDED
@@ -0,0 +1,177 @@
contextMenuInit = function(){
    let eventListenerApplied=false;
    let menuSpecs = new Map();

    const uid = function(){
        return Date.now().toString(36) + Math.random().toString(36).substr(2);
    }

    function showContextMenu(event,element,menuEntries){
        let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
        let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;

        let oldMenu = gradioApp().querySelector('#context-menu')
        if(oldMenu){
            oldMenu.remove()
        }

        let tabButton = uiCurrentTab
        let baseStyle = window.getComputedStyle(tabButton)

        const contextMenu = document.createElement('nav')
        contextMenu.id = "context-menu"
        contextMenu.style.background = baseStyle.background
        contextMenu.style.color = baseStyle.color
        contextMenu.style.fontFamily = baseStyle.fontFamily
        contextMenu.style.top = posy+'px'
        contextMenu.style.left = posx+'px'


        const contextMenuList = document.createElement('ul')
        contextMenuList.className = 'context-menu-items';
        contextMenu.append(contextMenuList);

        menuEntries.forEach(function(entry){
            let contextMenuEntry = document.createElement('a')
            contextMenuEntry.innerHTML = entry['name']
            contextMenuEntry.addEventListener("click", function(e) {
                entry['func']();
            })
            contextMenuList.append(contextMenuEntry);

        })

        gradioApp().getRootNode().appendChild(contextMenu)

        let menuWidth = contextMenu.offsetWidth + 4;
        let menuHeight = contextMenu.offsetHeight + 4;

        let windowWidth = window.innerWidth;
        let windowHeight = window.innerHeight;

        if ( (windowWidth - posx) < menuWidth ) {
            contextMenu.style.left = windowWidth - menuWidth + "px";
        }

        if ( (windowHeight - posy) < menuHeight ) {
            contextMenu.style.top = windowHeight - menuHeight + "px";
        }

    }

    function appendContextMenuOption(targetElementSelector,entryName,entryFunction){

        currentItems = menuSpecs.get(targetElementSelector)

        if(!currentItems){
            currentItems = []
            menuSpecs.set(targetElementSelector,currentItems);
        }
        let newItem = {'id':targetElementSelector+'_'+uid(),
                       'name':entryName,
                       'func':entryFunction,
                       'isNew':true}

        currentItems.push(newItem)
        return newItem['id']
    }

    function removeContextMenuOption(uid){
        menuSpecs.forEach(function(v,k) {
            let index = -1
            v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
            if(index>=0){
                v.splice(index, 1);
            }
        })
    }

    function addContextMenuEventListener(){
        if(eventListenerApplied){
            return;
        }
        gradioApp().addEventListener("click", function(e) {
            let source = e.composedPath()[0]
            if(source.id && source.id.indexOf('check_progress')>-1){
                return
            }

            let oldMenu = gradioApp().querySelector('#context-menu')
            if(oldMenu){
                oldMenu.remove()
            }
        });
        gradioApp().addEventListener("contextmenu", function(e) {
            let oldMenu = gradioApp().querySelector('#context-menu')
            if(oldMenu){
                oldMenu.remove()
            }
            menuSpecs.forEach(function(v,k) {
                if(e.composedPath()[0].matches(k)){
                    showContextMenu(e,e.composedPath()[0],v)
                    e.preventDefault()
                    return
                }
            })
        });
        eventListenerApplied=true

    }

    return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
}

initResponse = contextMenuInit();
appendContextMenuOption = initResponse[0];
removeContextMenuOption = initResponse[1];
addContextMenuEventListener = initResponse[2];

(function(){
    //Start example Context Menu Items
    let generateOnRepeat = function(genbuttonid,interruptbuttonid){
        let genbutton = gradioApp().querySelector(genbuttonid);
        let interruptbutton = gradioApp().querySelector(interruptbuttonid);
        if(!interruptbutton.offsetParent){
            genbutton.click();
        }
        clearInterval(window.generateOnRepeatInterval)
        window.generateOnRepeatInterval = setInterval(function(){
            if(!interruptbutton.offsetParent){
                genbutton.click();
            }
        },
        500)
    }

    appendContextMenuOption('#txt2img_generate','Generate forever',function(){
        generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
    })
    appendContextMenuOption('#img2img_generate','Generate forever',function(){
        generateOnRepeat('#img2img_generate','#img2img_interrupt');
    })

    let cancelGenerateForever = function(){
        clearInterval(window.generateOnRepeatInterval)
    }

    appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)

    appendContextMenuOption('#roll','Roll three',
        function(){
            let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
            setTimeout(function(){rollbutton.click()},100)
            setTimeout(function(){rollbutton.click()},200)
            setTimeout(function(){rollbutton.click()},300)
        }
    )
})();
//End example Context Menu Items

onUiUpdate(function(){
    addContextMenuEventListener()
});
javascript/dragdrop.js
ADDED
@@ -0,0 +1,89 @@
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard

function isValidImageList( files ) {
    return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
}

function dropReplaceImage( imgWrap, files ) {
    if ( ! isValidImageList( files ) ) {
        return;
    }

    imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
    const callback = () => {
        const fileInput = imgWrap.querySelector('input[type="file"]');
        if ( fileInput ) {
            fileInput.files = files;
            fileInput.dispatchEvent(new Event('change'));
        }
    };

    if ( imgWrap.closest('#pnginfo_image') ) {
        // special treatment for PNG Info tab, wait for fetch request to finish
        const oldFetch = window.fetch;
        window.fetch = async (input, options) => {
            const response = await oldFetch(input, options);
            if ( 'api/predict/' === input ) {
                const content = await response.text();
                window.fetch = oldFetch;
                window.requestAnimationFrame( () => callback() );
                return new Response(content, {
                    status: response.status,
                    statusText: response.statusText,
                    headers: response.headers
                })
            }
            return response;
        };
    } else {
        window.requestAnimationFrame( () => callback() );
    }
}

window.document.addEventListener('dragover', e => {
    const target = e.composedPath()[0];
    const imgWrap = target.closest('[data-testid="image"]');
    if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
        return;
    }
    e.stopPropagation();
    e.preventDefault();
    e.dataTransfer.dropEffect = 'copy';
});

window.document.addEventListener('drop', e => {
    const target = e.composedPath()[0];
    if (target.placeholder.indexOf("Prompt") == -1) {
        return;
    }
    const imgWrap = target.closest('[data-testid="image"]');
    if ( !imgWrap ) {
        return;
    }
    e.stopPropagation();
    e.preventDefault();
    const files = e.dataTransfer.files;
    dropReplaceImage( imgWrap, files );
});

window.addEventListener('paste', e => {
    const files = e.clipboardData.files;
    if ( ! isValidImageList( files ) ) {
        return;
    }

    const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
        .filter(el => uiElementIsVisible(el));
    if ( ! visibleImageFields.length ) {
        return;
    }

    const firstFreeImageField = visibleImageFields
        .filter(el => el.querySelector('input[type=file]'))?.[0];

    dropReplaceImage(
        firstFreeImageField ?
            firstFreeImageField :
            visibleImageFields[visibleImageFields.length - 1]
    , files );
});
javascript/edit-attention.js
ADDED
@@ -0,0 +1,75 @@
addEventListener('keydown', (event) => {
    let target = event.originalTarget || event.composedPath()[0];
    if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
    if (! (event.metaKey || event.ctrlKey)) return;


    let plus = "ArrowUp"
    let minus = "ArrowDown"
    if (event.key != plus && event.key != minus) return;

    let selectionStart = target.selectionStart;
    let selectionEnd = target.selectionEnd;
    // If the user hasn't selected anything, let's select their current parenthesis block
    if (selectionStart === selectionEnd) {
        // Find opening parenthesis around current cursor
        const before = target.value.substring(0, selectionStart);
        let beforeParen = before.lastIndexOf("(");
        if (beforeParen == -1) return;
        let beforeParenClose = before.lastIndexOf(")");
        while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
            beforeParen = before.lastIndexOf("(", beforeParen - 1);
            beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
        }

        // Find closing parenthesis around current cursor
        const after = target.value.substring(selectionStart);
        let afterParen = after.indexOf(")");
        if (afterParen == -1) return;
        let afterParenOpen = after.indexOf("(");
        while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
            afterParen = after.indexOf(")", afterParen + 1);
            afterParenOpen = after.indexOf("(", afterParenOpen + 1);
        }
        if (beforeParen === -1 || afterParen === -1) return;

        // Set the selection to the text between the parenthesis
        const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
        const lastColon = parenContent.lastIndexOf(":");
        selectionStart = beforeParen + 1;
        selectionEnd = selectionStart + lastColon;
        target.setSelectionRange(selectionStart, selectionEnd);
    }

    event.preventDefault();

    if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
        target.value = target.value.slice(0, selectionStart) +
            "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
            target.value.slice(selectionEnd);

        target.focus();
        target.selectionStart = selectionStart + 1;
        target.selectionEnd = selectionEnd + 1;

    } else {
        end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
        weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
        if (isNaN(weight)) return;
        if (event.key == minus) weight -= 0.1;
        if (event.key == plus) weight += 0.1;

        weight = parseFloat(weight.toPrecision(12));

        target.value = target.value.slice(0, selectionEnd + 1) +
            weight +
            target.value.slice(selectionEnd + 1 + end - 1);

        target.focus();
        target.selectionStart = selectionStart;
        target.selectionEnd = selectionEnd;
    }
    // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
    // internal Svelte data binding remains in sync.
    target.dispatchEvent(new Event("input", { bubbles: true }));
});
javascript/extensions.js
ADDED
@@ -0,0 +1,35 @@
function extensions_apply(_, _){
    disable = []
    update = []
    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
        if(x.name.startsWith("enable_") && ! x.checked)
            disable.push(x.name.substr(7))

        if(x.name.startsWith("update_") && x.checked)
            update.push(x.name.substr(7))
    })

    restart_reload()

    return [JSON.stringify(disable), JSON.stringify(update)]
}

function extensions_check(){
    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
        x.innerHTML = "Loading..."
    })

    return []
}

function install_extension_from_index(button, url){
    button.disabled = "disabled"
    button.value = "Installing..."

    textarea = gradioApp().querySelector('#extension_to_install textarea')
    textarea.value = url
    textarea.dispatchEvent(new Event("input", { bubbles: true }))

    gradioApp().querySelector('#install_extension_button').click()
}
javascript/generationParams.js
ADDED
@@ -0,0 +1,33 @@
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes

let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function(){
    if (!txt2img_gallery) {
        txt2img_gallery = attachGalleryListeners("txt2img")
    }
    if (!img2img_gallery) {
        img2img_gallery = attachGalleryListeners("img2img")
    }
    if (!modal) {
        modal = gradioApp().getElementById('lightboxModal')
        modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
    }
});

let modalObserver = new MutationObserver(function(mutations) {
    mutations.forEach(function(mutationRecord) {
        let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
        if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
            gradioApp().getElementById(selectedTab+"_generation_info_button").click()
    });
});

function attachGalleryListeners(tab_name) {
    gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
    gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
    gallery?.addEventListener('keydown', (e) => {
        if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
            gradioApp().getElementById(tab_name+"_generation_info_button").click()
    });
    return gallery;
}
javascript/hints.js
ADDED
@@ -0,0 +1,136 @@
// mouseover tooltips for various UI elements
|
2 |
+
|
3 |
+
titles = {
|
4 |
+
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
|
5 |
+
"Sampling method": "Which algorithm to use to produce the image",
|
6 |
+
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
|
7 |
+
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
|
8 |
+
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
|
9 |
+
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
|
10 |
+
|
11 |
+
"Batch count": "How many batches of images to create",
|
12 |
+
"Batch size": "How many image to create in a single batch",
|
13 |
+
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
|
14 |
+
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
15 |
+
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
16 |
+
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
17 |
+
"\u{1f3a8}": "Add a random artist to the prompt.",
|
18 |
+
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
19 |
+
"\u{1f4c2}": "Open images output directory",
|
20 |
+
"\u{1f4be}": "Save style",
|
21 |
+
"\U0001F5D1": "Clear prompt",
|
22 |
+
"\u{1f4cb}": "Apply selected styles to current prompt",
|
23 |
+
|
24 |
+
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
25 |
+
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
26 |
+
|
27 |
+
"Just resize": "Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio.",
|
28 |
+
"Crop and resize": "Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out.",
|
29 |
+
"Resize and fill": "Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors.",
|
30 |
+
|
31 |
+
"Mask blur": "How much to blur the mask before processing, in pixels.",
|
32 |
+
"Masked content": "What to put inside the masked area before processing it with Stable Diffusion.",
|
33 |
+
"fill": "fill it with colors of the image",
|
34 |
+
"original": "keep whatever was there originally",
|
35 |
+
"latent noise": "fill it with latent space noise",
|
36 |
+
"latent nothing": "fill it with latent space zeroes",
|
37 |
+
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
|
38 |
+
|
39 |
+
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
|
40 |
+
"Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
|
41 |
+
|
42 |
+
"Skip": "Stop processing current image and continue processing.",
|
43 |
+
"Interrupt": "Stop processing images and return any results accumulated so far.",
|
44 |
+
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
|
45 |
+
|
46 |
+
"X values": "Separate values for X axis using commas.",
|
47 |
+
"Y values": "Separate values for Y axis using commas.",
|
48 |
+
|
49 |
+
"None": "Do not do anything special",
|
50 |
+
"Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
|
51 |
+
"X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
|
52 |
+
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
|
53 |
+
|
54 |
+
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
|
55 |
+
"Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
|
56 |
+
|
57 |
+
"Tiling": "Produce an image that can be tiled.",
|
58 |
+
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
|
59 |
+
|
60 |
+
"Variation seed": "Seed of a different picture to be mixed into the generation.",
|
61 |
+
"Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).",
|
62 |
+
"Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
|
63 |
+
"Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
|
64 |
+
|
65 |
+
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
66 |
+
|
67 |
+
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
68 |
+
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
69 |
+
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
70 |
+
|
71 |
+
"Loopback": "Process an image, use it as an input, repeat.",
|
72 |
+
"Loops": "How many times to repeat processing an image and using it as input for the next iteration",
|
73 |
+
|
74 |
+
"Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
75 |
+
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
76 |
+
"Apply style": "Insert selected styles into prompt fields",
|
77 |
+
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
|
78 |
+
|
79 |
+
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
|
80 |
+
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
|
81 |
+
|
82 |
+
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
|
83 |
+
|
84 |
+
"Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
|
85 |
+
"Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
|
86 |
+
|
87 |
+
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
88 |
+
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
|
89 |
+
|
90 |
+
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
|
91 |
+
"Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
|
92 |
+
|
93 |
+
"Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.",
|
94 |
+
|
95 |
+
"Weighted sum": "Result = A * (1 - M) + B * M",
|
96 |
+
"Add difference": "Result = A + (B - C) * M",
|
97 |
+
|
98 |
+
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
99 |
+
|
100 |
+
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
101 |
+
|
102 |
+
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
|
103 |
+
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
onUiUpdate(function(){
|
108 |
+
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
|
109 |
+
tooltip = titles[span.textContent];
|
110 |
+
|
111 |
+
if(!tooltip){
|
112 |
+
tooltip = titles[span.value];
|
113 |
+
}
|
114 |
+
|
115 |
+
if(!tooltip){
|
116 |
+
for (const c of span.classList) {
|
117 |
+
if (c in titles) {
|
118 |
+
tooltip = titles[c];
|
119 |
+
break;
|
120 |
+
}
|
121 |
+
}
|
122 |
+
}
|
123 |
+
|
124 |
+
if(tooltip){
|
125 |
+
span.title = tooltip;
|
126 |
+
}
|
127 |
+
})
|
128 |
+
|
129 |
+
gradioApp().querySelectorAll('select').forEach(function(select){
|
130 |
+
if (select.onchange != null) return;
|
131 |
+
|
132 |
+
select.onchange = function(){
|
133 |
+
select.title = titles[select.value] || "";
|
134 |
+
}
|
135 |
+
})
|
136 |
+
})
|
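A note on the handler above: the tooltip lookup falls back from the element's text, to its value, to any of its CSS classes. A minimal sketch of that lookup order, assuming the dictionary above is exposed as the global `titles` object declared at the top of hints.js and that `onUiUpdate`/`gradioApp` are the helpers from the project's script.js:

// Sketch only: same fallback order as the onUiUpdate handler above.
function lookupHint(el){
    return titles[el.textContent]
        || titles[el.value]
        || Array.from(el.classList).map(c => titles[c]).find(Boolean);
}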
javascript/imageMaskFix.js
ADDED
@@ -0,0 +1,45 @@
1 |
+
/**
|
2 |
+
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
|
3 |
+
* @see https://github.com/gradio-app/gradio/issues/1721
|
4 |
+
*/
|
5 |
+
window.addEventListener( 'resize', () => imageMaskResize());
|
6 |
+
function imageMaskResize() {
|
7 |
+
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
8 |
+
if ( ! canvases.length ) {
|
9 |
+
canvases_fixed = false;
|
10 |
+
window.removeEventListener( 'resize', imageMaskResize );
|
11 |
+
return;
|
12 |
+
}
|
13 |
+
|
14 |
+
const wrapper = canvases[0].closest('.touch-none');
|
15 |
+
const previewImage = wrapper.previousElementSibling;
|
16 |
+
|
17 |
+
if ( ! previewImage.complete ) {
|
18 |
+
previewImage.addEventListener( 'load', () => imageMaskResize());
|
19 |
+
return;
|
20 |
+
}
|
21 |
+
|
22 |
+
const w = previewImage.width;
|
23 |
+
const h = previewImage.height;
|
24 |
+
const nw = previewImage.naturalWidth;
|
25 |
+
const nh = previewImage.naturalHeight;
|
26 |
+
const portrait = nh > nw;
|
27 |
+
const factor = portrait;
|
28 |
+
|
29 |
+
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
|
30 |
+
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
|
31 |
+
|
32 |
+
wrapper.style.width = `${wW}px`;
|
33 |
+
wrapper.style.height = `${wH}px`;
|
34 |
+
wrapper.style.left = `0px`;
|
35 |
+
wrapper.style.top = `0px`;
|
36 |
+
|
37 |
+
canvases.forEach( c => {
|
38 |
+
c.style.width = c.style.height = '';
|
39 |
+
c.style.maxWidth = '100%';
|
40 |
+
c.style.maxHeight = '100%';
|
41 |
+
c.style.objectFit = 'contain';
|
42 |
+
});
|
43 |
+
}
|
44 |
+
|
45 |
+
onUiUpdate(() => imageMaskResize());
|
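The width/height computed above amount to a "contain" fit: the mask wrapper is sized to the displayed preview while preserving the image's natural aspect ratio. A simplified sketch of that idea (not the exact branch-by-orientation math used above):

// Sketch: scale a natural size so it fits inside a box without distortion.
function containFit(naturalW, naturalH, boxW, boxH) {
    const scale = Math.min(boxW / naturalW, boxH / naturalH);
    return { width: naturalW * scale, height: naturalH * scale };
}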
javascript/imageParams.js
ADDED
@@ -0,0 +1,19 @@
1 |
+
window.onload = (function(){
|
2 |
+
window.addEventListener('drop', e => {
|
3 |
+
const target = e.composedPath()[0];
|
4 |
+
const idx = selected_gallery_index();
|
5 |
+
if (target.placeholder.indexOf("Prompt") == -1) return;
|
6 |
+
|
7 |
+
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
8 |
+
|
9 |
+
e.stopPropagation();
|
10 |
+
e.preventDefault();
|
11 |
+
const imgParent = gradioApp().getElementById(prompt_target);
|
12 |
+
const files = e.dataTransfer.files;
|
13 |
+
const fileInput = imgParent.querySelector('input[type="file"]');
|
14 |
+
if ( fileInput ) {
|
15 |
+
fileInput.files = files;
|
16 |
+
fileInput.dispatchEvent(new Event('change'));
|
17 |
+
}
|
18 |
+
});
|
19 |
+
});
|
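The drop handler above works by handing the dropped files to Gradio's hidden file input and firing a change event. A hedged sketch of the same technique when no drop event is available, using a DataTransfer object to build a FileList programmatically (function name is illustrative only):

// Sketch: assign a File to an <input type="file"> and notify listeners.
function assignFileToInput(fileInput, file) {
    const dt = new DataTransfer();
    dt.items.add(file);
    fileInput.files = dt.files;
    fileInput.dispatchEvent(new Event('change'));
}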
javascript/imageviewer.js
ADDED
@@ -0,0 +1,276 @@
1 |
+
// A full size 'lightbox' preview modal shown when left clicking on gallery previews
|
2 |
+
function closeModal() {
|
3 |
+
gradioApp().getElementById("lightboxModal").style.display = "none";
|
4 |
+
}
|
5 |
+
|
6 |
+
function showModal(event) {
|
7 |
+
const source = event.target || event.srcElement;
|
8 |
+
const modalImage = gradioApp().getElementById("modalImage")
|
9 |
+
const lb = gradioApp().getElementById("lightboxModal")
|
10 |
+
modalImage.src = source.src
|
11 |
+
if (modalImage.style.display === 'none') {
|
12 |
+
lb.style.setProperty('background-image', 'url(' + source.src + ')');
|
13 |
+
}
|
14 |
+
lb.style.display = "block";
|
15 |
+
lb.focus()
|
16 |
+
|
17 |
+
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
|
18 |
+
const tabImg2Img = gradioApp().getElementById("tab_img2img")
|
19 |
+
// show the save button in modal only on txt2img or img2img tabs
|
20 |
+
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
|
21 |
+
gradioApp().getElementById("modal_save").style.display = "inline"
|
22 |
+
} else {
|
23 |
+
gradioApp().getElementById("modal_save").style.display = "none"
|
24 |
+
}
|
25 |
+
event.stopPropagation()
|
26 |
+
}
|
27 |
+
|
28 |
+
function negmod(n, m) {
|
29 |
+
return ((n % m) + m) % m;
|
30 |
+
}
|
31 |
+
|
32 |
+
function updateOnBackgroundChange() {
|
33 |
+
const modalImage = gradioApp().getElementById("modalImage")
|
34 |
+
if (modalImage && modalImage.offsetParent) {
|
35 |
+
let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
|
36 |
+
let currentButton = null
|
37 |
+
allcurrentButtons.forEach(function(elem) {
|
38 |
+
if (elem.parentElement.offsetParent) {
|
39 |
+
currentButton = elem;
|
40 |
+
}
|
41 |
+
})
|
42 |
+
|
43 |
+
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
44 |
+
modalImage.src = currentButton.children[0].src;
|
45 |
+
if (modalImage.style.display === 'none') {
|
46 |
+
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
47 |
+
}
|
48 |
+
}
|
49 |
+
}
|
50 |
+
}
|
51 |
+
|
52 |
+
function modalImageSwitch(offset) {
|
53 |
+
var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
|
54 |
+
var galleryButtons = []
|
55 |
+
allgalleryButtons.forEach(function(elem) {
|
56 |
+
if (elem.parentElement.offsetParent) {
|
57 |
+
galleryButtons.push(elem);
|
58 |
+
}
|
59 |
+
})
|
60 |
+
|
61 |
+
if (galleryButtons.length > 1) {
|
62 |
+
var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
|
63 |
+
var currentButton = null
|
64 |
+
allcurrentButtons.forEach(function(elem) {
|
65 |
+
if (elem.parentElement.offsetParent) {
|
66 |
+
currentButton = elem;
|
67 |
+
}
|
68 |
+
})
|
69 |
+
|
70 |
+
var result = -1
|
71 |
+
galleryButtons.forEach(function(v, i) {
|
72 |
+
if (v == currentButton) {
|
73 |
+
result = i
|
74 |
+
}
|
75 |
+
})
|
76 |
+
|
77 |
+
if (result != -1) {
|
78 |
+
nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
79 |
+
nextButton.click()
|
80 |
+
const modalImage = gradioApp().getElementById("modalImage");
|
81 |
+
const modal = gradioApp().getElementById("lightboxModal");
|
82 |
+
modalImage.src = nextButton.children[0].src;
|
83 |
+
if (modalImage.style.display === 'none') {
|
84 |
+
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
85 |
+
}
|
86 |
+
setTimeout(function() {
|
87 |
+
modal.focus()
|
88 |
+
}, 10)
|
89 |
+
}
|
90 |
+
}
|
91 |
+
}
|
92 |
+
|
93 |
+
function saveImage(){
|
94 |
+
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
|
95 |
+
const tabImg2Img = gradioApp().getElementById("tab_img2img")
|
96 |
+
const saveTxt2Img = "save_txt2img"
|
97 |
+
const saveImg2Img = "save_img2img"
|
98 |
+
if (tabTxt2Img.style.display != "none") {
|
99 |
+
gradioApp().getElementById(saveTxt2Img).click()
|
100 |
+
} else if (tabImg2Img.style.display != "none") {
|
101 |
+
gradioApp().getElementById(saveImg2Img).click()
|
102 |
+
} else {
|
103 |
+
console.error("missing implementation for saving modal of this type")
|
104 |
+
}
|
105 |
+
}
|
106 |
+
|
107 |
+
function modalSaveImage(event) {
|
108 |
+
saveImage()
|
109 |
+
event.stopPropagation()
|
110 |
+
}
|
111 |
+
|
112 |
+
function modalNextImage(event) {
|
113 |
+
modalImageSwitch(1)
|
114 |
+
event.stopPropagation()
|
115 |
+
}
|
116 |
+
|
117 |
+
function modalPrevImage(event) {
|
118 |
+
modalImageSwitch(-1)
|
119 |
+
event.stopPropagation()
|
120 |
+
}
|
121 |
+
|
122 |
+
function modalKeyHandler(event) {
|
123 |
+
switch (event.key) {
|
124 |
+
case "s":
|
125 |
+
saveImage()
|
126 |
+
break;
|
127 |
+
case "ArrowLeft":
|
128 |
+
modalPrevImage(event)
|
129 |
+
break;
|
130 |
+
case "ArrowRight":
|
131 |
+
modalNextImage(event)
|
132 |
+
break;
|
133 |
+
case "Escape":
|
134 |
+
closeModal();
|
135 |
+
break;
|
136 |
+
}
|
137 |
+
}
|
138 |
+
|
139 |
+
function showGalleryImage() {
|
140 |
+
setTimeout(function() {
|
141 |
+
fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
|
142 |
+
|
143 |
+
if (fullImg_preview != null) {
|
144 |
+
fullImg_preview.forEach(function function_name(e) {
|
145 |
+
if (e.dataset.modded)
|
146 |
+
return;
|
147 |
+
e.dataset.modded = true;
|
148 |
+
if(e && e.parentElement.tagName == 'DIV'){
|
149 |
+
e.style.cursor='pointer'
|
150 |
+
e.style.userSelect='none'
|
151 |
+
e.addEventListener('click', function (evt) {
|
152 |
+
if(!opts.js_modal_lightbox) return;
|
153 |
+
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
154 |
+
showModal(evt)
|
155 |
+
}, true);
|
156 |
+
}
|
157 |
+
});
|
158 |
+
}
|
159 |
+
|
160 |
+
}, 100);
|
161 |
+
}
|
162 |
+
|
163 |
+
function modalZoomSet(modalImage, enable) {
|
164 |
+
if (enable) {
|
165 |
+
modalImage.classList.add('modalImageFullscreen');
|
166 |
+
} else {
|
167 |
+
modalImage.classList.remove('modalImageFullscreen');
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
function modalZoomToggle(event) {
|
172 |
+
modalImage = gradioApp().getElementById("modalImage");
|
173 |
+
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
174 |
+
event.stopPropagation()
|
175 |
+
}
|
176 |
+
|
177 |
+
function modalTileImageToggle(event) {
|
178 |
+
const modalImage = gradioApp().getElementById("modalImage");
|
179 |
+
const modal = gradioApp().getElementById("lightboxModal");
|
180 |
+
const isTiling = modalImage.style.display === 'none';
|
181 |
+
if (isTiling) {
|
182 |
+
modalImage.style.display = 'block';
|
183 |
+
modal.style.setProperty('background-image', 'none')
|
184 |
+
} else {
|
185 |
+
modalImage.style.display = 'none';
|
186 |
+
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
187 |
+
}
|
188 |
+
|
189 |
+
event.stopPropagation()
|
190 |
+
}
|
191 |
+
|
192 |
+
function galleryImageHandler(e) {
|
193 |
+
if (e && e.parentElement.tagName == 'BUTTON') {
|
194 |
+
e.onclick = showGalleryImage;
|
195 |
+
}
|
196 |
+
}
|
197 |
+
|
198 |
+
onUiUpdate(function() {
|
199 |
+
fullImg_preview = gradioApp().querySelectorAll('img.w-full')
|
200 |
+
if (fullImg_preview != null) {
|
201 |
+
fullImg_preview.forEach(galleryImageHandler);
|
202 |
+
}
|
203 |
+
updateOnBackgroundChange();
|
204 |
+
})
|
205 |
+
|
206 |
+
document.addEventListener("DOMContentLoaded", function() {
|
207 |
+
const modalFragment = document.createDocumentFragment();
|
208 |
+
const modal = document.createElement('div')
|
209 |
+
modal.onclick = closeModal;
|
210 |
+
modal.id = "lightboxModal";
|
211 |
+
modal.tabIndex = 0
|
212 |
+
modal.addEventListener('keydown', modalKeyHandler, true)
|
213 |
+
|
214 |
+
const modalControls = document.createElement('div')
|
215 |
+
modalControls.className = 'modalControls gradio-container';
|
216 |
+
modal.append(modalControls);
|
217 |
+
|
218 |
+
const modalZoom = document.createElement('span')
|
219 |
+
modalZoom.className = 'modalZoom cursor';
|
220 |
+
modalZoom.innerHTML = '⤡'
|
221 |
+
modalZoom.addEventListener('click', modalZoomToggle, true)
|
222 |
+
modalZoom.title = "Toggle zoomed view";
|
223 |
+
modalControls.appendChild(modalZoom)
|
224 |
+
|
225 |
+
const modalTileImage = document.createElement('span')
|
226 |
+
modalTileImage.className = 'modalTileImage cursor';
|
227 |
+
modalTileImage.innerHTML = '⊞'
|
228 |
+
modalTileImage.addEventListener('click', modalTileImageToggle, true)
|
229 |
+
modalTileImage.title = "Preview tiling";
|
230 |
+
modalControls.appendChild(modalTileImage)
|
231 |
+
|
232 |
+
const modalSave = document.createElement("span")
|
233 |
+
modalSave.className = "modalSave cursor"
|
234 |
+
modalSave.id = "modal_save"
|
235 |
+
modalSave.innerHTML = "🖫"
|
236 |
+
modalSave.addEventListener("click", modalSaveImage, true)
|
237 |
+
modalSave.title = "Save Image(s)"
|
238 |
+
modalControls.appendChild(modalSave)
|
239 |
+
|
240 |
+
const modalClose = document.createElement('span')
|
241 |
+
modalClose.className = 'modalClose cursor';
|
242 |
+
modalClose.innerHTML = '×'
|
243 |
+
modalClose.onclick = closeModal;
|
244 |
+
modalClose.title = "Close image viewer";
|
245 |
+
modalControls.appendChild(modalClose)
|
246 |
+
|
247 |
+
const modalImage = document.createElement('img')
|
248 |
+
modalImage.id = 'modalImage';
|
249 |
+
modalImage.onclick = closeModal;
|
250 |
+
modalImage.tabIndex = 0
|
251 |
+
modalImage.addEventListener('keydown', modalKeyHandler, true)
|
252 |
+
modal.appendChild(modalImage)
|
253 |
+
|
254 |
+
const modalPrev = document.createElement('a')
|
255 |
+
modalPrev.className = 'modalPrev';
|
256 |
+
modalPrev.innerHTML = '❮'
|
257 |
+
modalPrev.tabIndex = 0
|
258 |
+
modalPrev.addEventListener('click', modalPrevImage, true);
|
259 |
+
modalPrev.addEventListener('keydown', modalKeyHandler, true)
|
260 |
+
modal.appendChild(modalPrev)
|
261 |
+
|
262 |
+
const modalNext = document.createElement('a')
|
263 |
+
modalNext.className = 'modalNext';
|
264 |
+
modalNext.innerHTML = '❯'
|
265 |
+
modalNext.tabIndex = 0
|
266 |
+
modalNext.addEventListener('click', modalNextImage, true);
|
267 |
+
modalNext.addEventListener('keydown', modalKeyHandler, true)
|
268 |
+
|
269 |
+
modal.appendChild(modalNext)
|
270 |
+
|
271 |
+
|
272 |
+
gradioApp().getRootNode().appendChild(modal)
|
273 |
+
|
274 |
+
document.body.appendChild(modalFragment);
|
275 |
+
|
276 |
+
});
|
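A small usage note on `negmod()` above: gallery navigation uses it so that stepping past either end wraps around, which a plain `%` would not do for negative offsets. Sketch:

// Sketch: wrap-around stepping through a gallery of three items.
const images = ['a.png', 'b.png', 'c.png'];
let current = 0;
current = negmod(current - 1, images.length);   // 2, not -1
current = negmod(current + 1, images.length);   // back to 0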
javascript/localization.js
ADDED
@@ -0,0 +1,167 @@
1 |
+
|
2 |
+
// localization = {} -- the dict with translations is created by the backend
|
3 |
+
|
4 |
+
ignore_ids_for_localization={
|
5 |
+
setting_sd_hypernetwork: 'OPTION',
|
6 |
+
setting_sd_model_checkpoint: 'OPTION',
|
7 |
+
setting_realesrgan_enabled_models: 'OPTION',
|
8 |
+
modelmerger_primary_model_name: 'OPTION',
|
9 |
+
modelmerger_secondary_model_name: 'OPTION',
|
10 |
+
modelmerger_tertiary_model_name: 'OPTION',
|
11 |
+
train_embedding: 'OPTION',
|
12 |
+
train_hypernetwork: 'OPTION',
|
13 |
+
txt2img_style_index: 'OPTION',
|
14 |
+
txt2img_style2_index: 'OPTION',
|
15 |
+
img2img_style_index: 'OPTION',
|
16 |
+
img2img_style2_index: 'OPTION',
|
17 |
+
setting_random_artist_categories: 'SPAN',
|
18 |
+
setting_face_restoration_model: 'SPAN',
|
19 |
+
setting_realesrgan_enabled_models: 'SPAN',
|
20 |
+
extras_upscaler_1: 'SPAN',
|
21 |
+
extras_upscaler_2: 'SPAN',
|
22 |
+
}
|
23 |
+
|
24 |
+
re_num = /^[\.\d]+$/
|
25 |
+
re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
|
26 |
+
|
27 |
+
original_lines = {}
|
28 |
+
translated_lines = {}
|
29 |
+
|
30 |
+
function textNodesUnder(el){
|
31 |
+
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
32 |
+
while(n=walk.nextNode()) a.push(n);
|
33 |
+
return a;
|
34 |
+
}
|
35 |
+
|
36 |
+
function canBeTranslated(node, text){
|
37 |
+
if(! text) return false;
|
38 |
+
if(! node.parentElement) return false;
|
39 |
+
|
40 |
+
parentType = node.parentElement.nodeName
|
41 |
+
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
42 |
+
|
43 |
+
if (parentType=='OPTION' || parentType=='SPAN'){
|
44 |
+
pnode = node
|
45 |
+
for(var level=0; level<4; level++){
|
46 |
+
pnode = pnode.parentElement
|
47 |
+
if(! pnode) break;
|
48 |
+
|
49 |
+
if(ignore_ids_for_localization[pnode.id] == parentType) return false;
|
50 |
+
}
|
51 |
+
}
|
52 |
+
|
53 |
+
if(re_num.test(text)) return false;
|
54 |
+
if(re_emoji.test(text)) return false;
|
55 |
+
return true
|
56 |
+
}
|
57 |
+
|
58 |
+
function getTranslation(text){
|
59 |
+
if(! text) return undefined
|
60 |
+
|
61 |
+
if(translated_lines[text] === undefined){
|
62 |
+
original_lines[text] = 1
|
63 |
+
}
|
64 |
+
|
65 |
+
tl = localization[text]
|
66 |
+
if(tl !== undefined){
|
67 |
+
translated_lines[tl] = 1
|
68 |
+
}
|
69 |
+
|
70 |
+
return tl
|
71 |
+
}
|
72 |
+
|
73 |
+
function processTextNode(node){
|
74 |
+
text = node.textContent.trim()
|
75 |
+
|
76 |
+
if(! canBeTranslated(node, text)) return
|
77 |
+
|
78 |
+
tl = getTranslation(text)
|
79 |
+
if(tl !== undefined){
|
80 |
+
node.textContent = tl
|
81 |
+
}
|
82 |
+
}
|
83 |
+
|
84 |
+
function processNode(node){
|
85 |
+
if(node.nodeType == 3){
|
86 |
+
processTextNode(node)
|
87 |
+
return
|
88 |
+
}
|
89 |
+
|
90 |
+
if(node.title){
|
91 |
+
tl = getTranslation(node.title)
|
92 |
+
if(tl !== undefined){
|
93 |
+
node.title = tl
|
94 |
+
}
|
95 |
+
}
|
96 |
+
|
97 |
+
if(node.placeholder){
|
98 |
+
tl = getTranslation(node.placeholder)
|
99 |
+
if(tl !== undefined){
|
100 |
+
node.placeholder = tl
|
101 |
+
}
|
102 |
+
}
|
103 |
+
|
104 |
+
textNodesUnder(node).forEach(function(node){
|
105 |
+
processTextNode(node)
|
106 |
+
})
|
107 |
+
}
|
108 |
+
|
109 |
+
function dumpTranslations(){
|
110 |
+
dumped = {}
|
111 |
+
if (localization.rtl) {
|
112 |
+
dumped.rtl = true
|
113 |
+
}
|
114 |
+
|
115 |
+
Object.keys(original_lines).forEach(function(text){
|
116 |
+
if(dumped[text] !== undefined) return
|
117 |
+
|
118 |
+
dumped[text] = localization[text] || text
|
119 |
+
})
|
120 |
+
|
121 |
+
return dumped
|
122 |
+
}
|
123 |
+
|
124 |
+
onUiUpdate(function(m){
|
125 |
+
m.forEach(function(mutation){
|
126 |
+
mutation.addedNodes.forEach(function(node){
|
127 |
+
processNode(node)
|
128 |
+
})
|
129 |
+
});
|
130 |
+
})
|
131 |
+
|
132 |
+
|
133 |
+
document.addEventListener("DOMContentLoaded", function() {
|
134 |
+
processNode(gradioApp())
|
135 |
+
|
136 |
+
if (localization.rtl) { // if the language is from right to left,
|
137 |
+
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
138 |
+
mutations.forEach(mutation => {
|
139 |
+
mutation.addedNodes.forEach(node => {
|
140 |
+
if (node.tagName === 'STYLE') {
|
141 |
+
observer.disconnect();
|
142 |
+
|
143 |
+
for (const x of node.sheet.rules) { // find all rtl media rules
|
144 |
+
if (Array.from(x.media || []).includes('rtl')) {
|
145 |
+
x.media.appendMedium('all'); // enable them
|
146 |
+
}
|
147 |
+
}
|
148 |
+
}
|
149 |
+
})
|
150 |
+
});
|
151 |
+
})).observe(gradioApp(), { childList: true });
|
152 |
+
}
|
153 |
+
})
|
154 |
+
|
155 |
+
function download_localization() {
|
156 |
+
text = JSON.stringify(dumpTranslations(), null, 4)
|
157 |
+
|
158 |
+
var element = document.createElement('a');
|
159 |
+
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
160 |
+
element.setAttribute('download', "localization.json");
|
161 |
+
element.style.display = 'none';
|
162 |
+
document.body.appendChild(element);
|
163 |
+
|
164 |
+
element.click();
|
165 |
+
|
166 |
+
document.body.removeChild(element);
|
167 |
+
}
|
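For context, a sketch of the data these functions expect. The backend injects a global `localization` object mapping original UI text to its translation; the entries below are hypothetical German examples, not part of the repository:

// Sketch of the shape of the injected dictionary (hypothetical entries).
localization = {
    "Generate": "Generieren",
    "Interrupt": "Abbrechen",
    rtl: false,
}
// getTranslation("Generate") would then return "Generieren"; strings with no entry
// are recorded in original_lines so dumpTranslations() can emit them as a template.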
javascript/notification.js
ADDED
@@ -0,0 +1,49 @@
1 |
+
// Monitors the gallery and sends a browser notification when the leading image is new.
|
2 |
+
|
3 |
+
let lastHeadImg = null;
|
4 |
+
|
5 |
+
notificationButton = null
|
6 |
+
|
7 |
+
onUiUpdate(function(){
|
8 |
+
if(notificationButton == null){
|
9 |
+
notificationButton = gradioApp().getElementById('request_notifications')
|
10 |
+
|
11 |
+
if(notificationButton != null){
|
12 |
+
notificationButton.addEventListener('click', function (evt) {
|
13 |
+
Notification.requestPermission();
|
14 |
+
},true);
|
15 |
+
}
|
16 |
+
}
|
17 |
+
|
18 |
+
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
|
19 |
+
|
20 |
+
if (galleryPreviews == null) return;
|
21 |
+
|
22 |
+
const headImg = galleryPreviews[0]?.src;
|
23 |
+
|
24 |
+
if (headImg == null || headImg == lastHeadImg) return;
|
25 |
+
|
26 |
+
lastHeadImg = headImg;
|
27 |
+
|
28 |
+
// play notification sound if available
|
29 |
+
gradioApp().querySelector('#audio_notification audio')?.play();
|
30 |
+
|
31 |
+
if (document.hasFocus()) return;
|
32 |
+
|
33 |
+
// Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
|
34 |
+
const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));
|
35 |
+
|
36 |
+
const notification = new Notification(
|
37 |
+
'Stable Diffusion',
|
38 |
+
{
|
39 |
+
body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
|
40 |
+
icon: headImg,
|
41 |
+
image: headImg,
|
42 |
+
}
|
43 |
+
);
|
44 |
+
|
45 |
+
notification.onclick = function(_){
|
46 |
+
parent.focus();
|
47 |
+
this.close();
|
48 |
+
};
|
49 |
+
});
|
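The `new Notification(...)` call above only works once the user has granted permission, which is why the request is wired to the button click handler at the top of the file. A minimal sketch of the guarded pattern:

// Sketch: only create a notification when permission has been granted.
if (Notification.permission === "granted") {
    new Notification("Stable Diffusion", { body: "Generated 1 image" });
} else if (Notification.permission !== "denied") {
    Notification.requestPermission();   // must be triggered from a user gesture
}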
javascript/progressbar.js
ADDED
@@ -0,0 +1,149 @@
1 |
+
// code related to showing and updating progressbar shown as the image is being made
|
2 |
+
global_progressbars = {}
|
3 |
+
galleries = {}
|
4 |
+
galleryObservers = {}
|
5 |
+
|
6 |
+
// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
|
7 |
+
timeoutIds = {}
|
8 |
+
|
9 |
+
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
10 |
+
// gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
|
11 |
+
// every time you use gr.HTML(elem_id='xxx'), so we handle this here
|
12 |
+
var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
|
13 |
+
var progressbarParent
|
14 |
+
if(progressbar){
|
15 |
+
progressbarParent = gradioApp().querySelector("#"+id_progressbar)
|
16 |
+
} else{
|
17 |
+
progressbar = gradioApp().getElementById(id_progressbar)
|
18 |
+
progressbarParent = null
|
19 |
+
}
|
20 |
+
|
21 |
+
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
22 |
+
var interrupt = gradioApp().getElementById(id_interrupt)
|
23 |
+
|
24 |
+
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
25 |
+
if(progressbar.innerText){
|
26 |
+
let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
|
27 |
+
if(document.title != newtitle){
|
28 |
+
document.title = newtitle;
|
29 |
+
}
|
30 |
+
}else{
|
31 |
+
let newtitle = 'Stable Diffusion'
|
32 |
+
if(document.title != newtitle){
|
33 |
+
document.title = newtitle;
|
34 |
+
}
|
35 |
+
}
|
36 |
+
}
|
37 |
+
|
38 |
+
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
|
39 |
+
global_progressbars[id_progressbar] = progressbar
|
40 |
+
|
41 |
+
var mutationObserver = new MutationObserver(function(m){
|
42 |
+
if(timeoutIds[id_part]) return;
|
43 |
+
|
44 |
+
preview = gradioApp().getElementById(id_preview)
|
45 |
+
gallery = gradioApp().getElementById(id_gallery)
|
46 |
+
|
47 |
+
if(preview != null && gallery != null){
|
48 |
+
preview.style.width = gallery.clientWidth + "px"
|
49 |
+
preview.style.height = gallery.clientHeight + "px"
|
50 |
+
if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
|
51 |
+
|
52 |
+
//only watch gallery if there is a generation process going on
|
53 |
+
check_gallery(id_gallery);
|
54 |
+
|
55 |
+
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
56 |
+
if(progressDiv){
|
57 |
+
timeoutIds[id_part] = window.setTimeout(function() {
|
58 |
+
timeoutIds[id_part] = null
|
59 |
+
requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
|
60 |
+
}, 500)
|
61 |
+
} else{
|
62 |
+
if (skip) {
|
63 |
+
skip.style.display = "none"
|
64 |
+
}
|
65 |
+
interrupt.style.display = "none"
|
66 |
+
|
67 |
+
//disconnect observer once generation finished, so user can close selected image if they want
|
68 |
+
if (galleryObservers[id_gallery]) {
|
69 |
+
galleryObservers[id_gallery].disconnect();
|
70 |
+
galleries[id_gallery] = null;
|
71 |
+
}
|
72 |
+
}
|
73 |
+
}
|
74 |
+
|
75 |
+
});
|
76 |
+
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
77 |
+
}
|
78 |
+
}
|
79 |
+
|
80 |
+
function check_gallery(id_gallery){
|
81 |
+
let gallery = gradioApp().getElementById(id_gallery)
|
82 |
+
// if the gallery has not changed, there is no need to set up the observer again.
|
83 |
+
if (gallery && galleries[id_gallery] !== gallery){
|
84 |
+
galleries[id_gallery] = gallery;
|
85 |
+
if(galleryObservers[id_gallery]){
|
86 |
+
galleryObservers[id_gallery].disconnect();
|
87 |
+
}
|
88 |
+
let prevSelectedIndex = selected_gallery_index();
|
89 |
+
galleryObservers[id_gallery] = new MutationObserver(function (){
|
90 |
+
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
91 |
+
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
92 |
+
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
93 |
+
// automatically re-open previously selected index (if exists)
|
94 |
+
activeElement = gradioApp().activeElement;
|
95 |
+
let scrollX = window.scrollX;
|
96 |
+
let scrollY = window.scrollY;
|
97 |
+
|
98 |
+
galleryButtons[prevSelectedIndex].click();
|
99 |
+
showGalleryImage();
|
100 |
+
|
101 |
+
// When the gallery button is clicked, it gains focus and scrolls itself into view
|
102 |
+
// We need to scroll back to the previous position
|
103 |
+
setTimeout(function (){
|
104 |
+
window.scrollTo(scrollX, scrollY);
|
105 |
+
}, 50);
|
106 |
+
|
107 |
+
if(activeElement){
|
108 |
+
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
|
109 |
+
// if someone has a better solution please by all means
|
110 |
+
setTimeout(function (){
|
111 |
+
activeElement.focus({
|
112 |
+
preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
|
113 |
+
})
|
114 |
+
}, 1);
|
115 |
+
}
|
116 |
+
}
|
117 |
+
})
|
118 |
+
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
onUiUpdate(function(){
|
123 |
+
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
124 |
+
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
125 |
+
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
126 |
+
})
|
127 |
+
|
128 |
+
function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
|
129 |
+
btn = gradioApp().getElementById(id_part+"_check_progress");
|
130 |
+
if(btn==null) return;
|
131 |
+
|
132 |
+
btn.click();
|
133 |
+
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
134 |
+
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
135 |
+
var interrupt = gradioApp().getElementById(id_interrupt)
|
136 |
+
if(progressDiv && interrupt){
|
137 |
+
if (skip) {
|
138 |
+
skip.style.display = "block"
|
139 |
+
}
|
140 |
+
interrupt.style.display = "block"
|
141 |
+
}
|
142 |
+
}
|
143 |
+
|
144 |
+
function requestProgress(id_part){
|
145 |
+
btn = gradioApp().getElementById(id_part+"_check_progress_initial");
|
146 |
+
if(btn==null) return;
|
147 |
+
|
148 |
+
btn.click();
|
149 |
+
}
|
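The progress loop above boils down to clicking a hidden Gradio button on a timer; each click asks the backend for fresh progress HTML. A simplified sketch of that polling pattern, with `id_part` being 'txt2img', 'img2img' or 'ti' as in check_progressbar() (a real caller stops once the progress span disappears, as requestMoreProgress() does):

// Sketch only: poll the backend by clicking the hidden "_check_progress" button.
function pollProgress(id_part) {
    const btn = gradioApp().getElementById(id_part + "_check_progress");
    if (btn == null) return;
    btn.click();
    window.setTimeout(() => pollProgress(id_part), 500);
}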
javascript/textualInversion.js
ADDED
@@ -0,0 +1,8 @@
1 |
+
|
2 |
+
|
3 |
+
function start_training_textual_inversion(){
|
4 |
+
requestProgress('ti')
|
5 |
+
gradioApp().querySelector('#ti_error').innerHTML=''
|
6 |
+
|
7 |
+
return args_to_array(arguments)
|
8 |
+
}
|
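The `return args_to_array(arguments)` line above follows the Gradio _js-callback convention used throughout these scripts: the function receives the component values positionally and must return them so they are forwarded to the Python handler, with UI side effects done first. A sketch (the function name is hypothetical):

// Sketch of the pass-through convention for a Gradio _js hook.
function my_hypothetical_js_hook(){
    requestProgress('ti');              // UI side effect: start the progress poll
    return args_to_array(arguments);    // forward the inputs unchanged to Python
}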
javascript/ui.js
ADDED
@@ -0,0 +1,222 @@
1 |
+
// various functions for interaction with ui.py, not large enough to warrant putting them in separate files
|
2 |
+
|
3 |
+
function set_theme(theme){
|
4 |
+
gradioURL = window.location.href
|
5 |
+
if (!gradioURL.includes('?__theme=')) {
|
6 |
+
window.location.replace(gradioURL + '?__theme=' + theme);
|
7 |
+
}
|
8 |
+
}
|
9 |
+
|
10 |
+
function selected_gallery_index(){
|
11 |
+
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
|
12 |
+
var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
|
13 |
+
|
14 |
+
var result = -1
|
15 |
+
buttons.forEach(function(v, i){ if(v==button) { result = i } })
|
16 |
+
|
17 |
+
return result
|
18 |
+
}
|
19 |
+
|
20 |
+
function extract_image_from_gallery(gallery){
|
21 |
+
if(gallery.length == 1){
|
22 |
+
return gallery[0]
|
23 |
+
}
|
24 |
+
|
25 |
+
index = selected_gallery_index()
|
26 |
+
|
27 |
+
if (index < 0 || index >= gallery.length){
|
28 |
+
return [null]
|
29 |
+
}
|
30 |
+
|
31 |
+
return gallery[index];
|
32 |
+
}
|
33 |
+
|
34 |
+
function args_to_array(args){
|
35 |
+
res = []
|
36 |
+
for(var i=0;i<args.length;i++){
|
37 |
+
res.push(args[i])
|
38 |
+
}
|
39 |
+
return res
|
40 |
+
}
|
41 |
+
|
42 |
+
function switch_to_txt2img(){
|
43 |
+
gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
|
44 |
+
|
45 |
+
return args_to_array(arguments);
|
46 |
+
}
|
47 |
+
|
48 |
+
function switch_to_img2img(){
|
49 |
+
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
50 |
+
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
|
51 |
+
|
52 |
+
return args_to_array(arguments);
|
53 |
+
}
|
54 |
+
|
55 |
+
function switch_to_inpaint(){
|
56 |
+
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
57 |
+
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
|
58 |
+
|
59 |
+
return args_to_array(arguments);
|
60 |
+
}
|
61 |
+
|
62 |
+
function switch_to_extras(){
|
63 |
+
gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
|
64 |
+
|
65 |
+
return args_to_array(arguments);
|
66 |
+
}
|
67 |
+
|
68 |
+
function get_tab_index(tabId){
|
69 |
+
var res = 0
|
70 |
+
|
71 |
+
gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
|
72 |
+
if(button.className.indexOf('bg-white') != -1)
|
73 |
+
res = i
|
74 |
+
})
|
75 |
+
|
76 |
+
return res
|
77 |
+
}
|
78 |
+
|
79 |
+
function create_tab_index_args(tabId, args){
|
80 |
+
var res = []
|
81 |
+
for(var i=0; i<args.length; i++){
|
82 |
+
res.push(args[i])
|
83 |
+
}
|
84 |
+
|
85 |
+
res[0] = get_tab_index(tabId)
|
86 |
+
|
87 |
+
return res
|
88 |
+
}
|
89 |
+
|
90 |
+
function get_extras_tab_index(){
|
91 |
+
const [,,...args] = [...arguments]
|
92 |
+
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
|
93 |
+
}
|
94 |
+
|
95 |
+
function create_submit_args(args){
|
96 |
+
res = []
|
97 |
+
for(var i=0;i<args.length;i++){
|
98 |
+
res.push(args[i])
|
99 |
+
}
|
100 |
+
|
101 |
+
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
|
102 |
+
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
|
103 |
+
// I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
|
104 |
+
// If gradio at some point stops sending outputs, this may break something
|
105 |
+
if(Array.isArray(res[res.length - 3])){
|
106 |
+
res[res.length - 3] = null
|
107 |
+
}
|
108 |
+
|
109 |
+
return res
|
110 |
+
}
|
111 |
+
|
112 |
+
function submit(){
|
113 |
+
requestProgress('txt2img')
|
114 |
+
|
115 |
+
return create_submit_args(arguments)
|
116 |
+
}
|
117 |
+
|
118 |
+
function submit_img2img(){
|
119 |
+
requestProgress('img2img')
|
120 |
+
|
121 |
+
res = create_submit_args(arguments)
|
122 |
+
|
123 |
+
res[0] = get_tab_index('mode_img2img')
|
124 |
+
|
125 |
+
return res
|
126 |
+
}
|
127 |
+
|
128 |
+
|
129 |
+
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
130 |
+
name_ = prompt('Style name:')
|
131 |
+
return [name_, prompt_text, negative_prompt_text]
|
132 |
+
}
|
133 |
+
|
134 |
+
function confirm_clear_prompt(prompt, negative_prompt) {
|
135 |
+
if(confirm("Delete prompt?")) {
|
136 |
+
prompt = ""
|
137 |
+
negative_prompt = ""
|
138 |
+
}
|
139 |
+
|
140 |
+
return [prompt, negative_prompt]
|
141 |
+
}
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
opts = {}
|
146 |
+
function apply_settings(jsdata){
|
147 |
+
console.log(jsdata)
|
148 |
+
|
149 |
+
opts = JSON.parse(jsdata)
|
150 |
+
|
151 |
+
return jsdata
|
152 |
+
}
|
153 |
+
|
154 |
+
onUiUpdate(function(){
|
155 |
+
if(Object.keys(opts).length != 0) return;
|
156 |
+
|
157 |
+
json_elem = gradioApp().getElementById('settings_json')
|
158 |
+
if(json_elem == null) return;
|
159 |
+
|
160 |
+
textarea = json_elem.querySelector('textarea')
|
161 |
+
jsdata = textarea.value
|
162 |
+
opts = JSON.parse(jsdata)
|
163 |
+
|
164 |
+
|
165 |
+
Object.defineProperty(textarea, 'value', {
|
166 |
+
set: function(newValue) {
|
167 |
+
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
|
168 |
+
var oldValue = valueProp.get.call(textarea);
|
169 |
+
valueProp.set.call(textarea, newValue);
|
170 |
+
|
171 |
+
if (oldValue != newValue) {
|
172 |
+
opts = JSON.parse(textarea.value)
|
173 |
+
}
|
174 |
+
},
|
175 |
+
get: function() {
|
176 |
+
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
|
177 |
+
return valueProp.get.call(textarea);
|
178 |
+
}
|
179 |
+
});
|
180 |
+
|
181 |
+
json_elem.parentElement.style.display="none"
|
182 |
+
|
183 |
+
if (!txt2img_textarea) {
|
184 |
+
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
185 |
+
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
186 |
+
}
|
187 |
+
if (!img2img_textarea) {
|
188 |
+
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
189 |
+
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
190 |
+
}
|
191 |
+
})
|
192 |
+
|
193 |
+
let txt2img_textarea, img2img_textarea = undefined;
|
194 |
+
let wait_time = 800
|
195 |
+
let token_timeout;
|
196 |
+
|
197 |
+
function update_txt2img_tokens(...args) {
|
198 |
+
update_token_counter("txt2img_token_button")
|
199 |
+
if (args.length == 2)
|
200 |
+
return args[0]
|
201 |
+
return args;
|
202 |
+
}
|
203 |
+
|
204 |
+
function update_img2img_tokens(...args) {
|
205 |
+
update_token_counter("img2img_token_button")
|
206 |
+
if (args.length == 2)
|
207 |
+
return args[0]
|
208 |
+
return args;
|
209 |
+
}
|
210 |
+
|
211 |
+
function update_token_counter(button_id) {
|
212 |
+
if (token_timeout)
|
213 |
+
clearTimeout(token_timeout);
|
214 |
+
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
215 |
+
}
|
216 |
+
|
217 |
+
function restart_reload(){
|
218 |
+
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
219 |
+
setTimeout(function(){location.reload()},2000)
|
220 |
+
|
221 |
+
return []
|
222 |
+
}
|
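The `Object.defineProperty` block above is the trick that keeps `opts` in sync: the original value accessor from HTMLTextAreaElement.prototype is reused so writes still reach the element, while every new value is observed. A generic sketch of the same technique (helper name is illustrative only):

// Sketch: intercept writes to a textarea's value without breaking it.
function watchTextareaValue(textarea, onChange) {
    const valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
    Object.defineProperty(textarea, 'value', {
        get() { return valueProp.get.call(textarea); },
        set(newValue) {
            const oldValue = valueProp.get.call(textarea);
            valueProp.set.call(textarea, newValue);
            if (oldValue != newValue) onChange(newValue);
        }
    });
}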
launch.py
ADDED
@@ -0,0 +1,295 @@
1 |
+
# this script installs necessary requirements and launches the main program in webui.py
|
2 |
+
import subprocess
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import importlib.util
|
6 |
+
import shlex
|
7 |
+
import platform
|
8 |
+
import argparse
|
9 |
+
import json
|
10 |
+
|
11 |
+
dir_repos = "repositories"
|
12 |
+
dir_extensions = "extensions"
|
13 |
+
python = sys.executable
|
14 |
+
git = os.environ.get('GIT', "git")
|
15 |
+
index_url = os.environ.get('INDEX_URL', "")
|
16 |
+
|
17 |
+
|
18 |
+
def extract_arg(args, name):
|
19 |
+
return [x for x in args if x != name], name in args
|
20 |
+
|
21 |
+
|
22 |
+
def extract_opt(args, name):
|
23 |
+
opt = None
|
24 |
+
is_present = False
|
25 |
+
if name in args:
|
26 |
+
is_present = True
|
27 |
+
idx = args.index(name)
|
28 |
+
del args[idx]
|
29 |
+
if idx < len(args) and args[idx][0] != "-":
|
30 |
+
opt = args[idx]
|
31 |
+
del args[idx]
|
32 |
+
return args, is_present, opt
|
33 |
+
|
34 |
+
|
35 |
+
def run(command, desc=None, errdesc=None, custom_env=None):
|
36 |
+
if desc is not None:
|
37 |
+
print(desc)
|
38 |
+
|
39 |
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
40 |
+
|
41 |
+
if result.returncode != 0:
|
42 |
+
|
43 |
+
message = f"""{errdesc or 'Error running command'}.
|
44 |
+
Command: {command}
|
45 |
+
Error code: {result.returncode}
|
46 |
+
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
|
47 |
+
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
|
48 |
+
"""
|
49 |
+
raise RuntimeError(message)
|
50 |
+
|
51 |
+
return result.stdout.decode(encoding="utf8", errors="ignore")
|
52 |
+
|
53 |
+
|
54 |
+
def check_run(command):
|
55 |
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
56 |
+
return result.returncode == 0
|
57 |
+
|
58 |
+
|
59 |
+
def is_installed(package):
|
60 |
+
try:
|
61 |
+
spec = importlib.util.find_spec(package)
|
62 |
+
except ModuleNotFoundError:
|
63 |
+
return False
|
64 |
+
|
65 |
+
return spec is not None
|
66 |
+
|
67 |
+
|
68 |
+
def repo_dir(name):
|
69 |
+
return os.path.join(dir_repos, name)
|
70 |
+
|
71 |
+
|
72 |
+
def run_python(code, desc=None, errdesc=None):
|
73 |
+
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
74 |
+
|
75 |
+
|
76 |
+
def run_pip(args, desc=None):
|
77 |
+
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
78 |
+
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
79 |
+
|
80 |
+
|
81 |
+
def check_run_python(code):
|
82 |
+
return check_run(f'"{python}" -c "{code}"')
|
83 |
+
|
84 |
+
|
85 |
+
def git_clone(url, dir, name, commithash=None):
|
86 |
+
# TODO clone into temporary dir and move if successful
|
87 |
+
|
88 |
+
if os.path.exists(dir):
|
89 |
+
if commithash is None:
|
90 |
+
return
|
91 |
+
|
92 |
+
current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
93 |
+
if current_hash == commithash:
|
94 |
+
return
|
95 |
+
|
96 |
+
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
97 |
+
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
98 |
+
return
|
99 |
+
|
100 |
+
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
101 |
+
|
102 |
+
if commithash is not None:
|
103 |
+
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
104 |
+
|
105 |
+
|
106 |
+
def version_check(commit):
|
107 |
+
try:
|
108 |
+
import requests
|
109 |
+
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
|
110 |
+
if commit != "<none>" and commits['commit']['sha'] != commit:
|
111 |
+
print("--------------------------------------------------------")
|
112 |
+
print("| You are not up to date with the most recent release. |")
|
113 |
+
print("| Consider running `git pull` to update. |")
|
114 |
+
print("--------------------------------------------------------")
|
115 |
+
elif commits['commit']['sha'] == commit:
|
116 |
+
print("You are up to date with the most recent release.")
|
117 |
+
else:
|
118 |
+
print("Not a git clone, can't perform version check.")
|
119 |
+
except Exception as e:
|
120 |
+
print("version check failed", e)
|
121 |
+
|
122 |
+
|
123 |
+
def run_extension_installer(extension_dir):
|
124 |
+
path_installer = os.path.join(extension_dir, "install.py")
|
125 |
+
if not os.path.isfile(path_installer):
|
126 |
+
return
|
127 |
+
|
128 |
+
try:
|
129 |
+
env = os.environ.copy()
|
130 |
+
env['PYTHONPATH'] = os.path.abspath(".")
|
131 |
+
|
132 |
+
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
133 |
+
except Exception as e:
|
134 |
+
print(e, file=sys.stderr)
|
135 |
+
|
136 |
+
|
137 |
+
def list_extensions(settings_file):
|
138 |
+
settings = {}
|
139 |
+
|
140 |
+
try:
|
141 |
+
if os.path.isfile(settings_file):
|
142 |
+
with open(settings_file, "r", encoding="utf8") as file:
|
143 |
+
settings = json.load(file)
|
144 |
+
except Exception as e:
|
145 |
+
print(e, file=sys.stderr)
|
146 |
+
|
147 |
+
disabled_extensions = set(settings.get('disabled_extensions', []))
|
148 |
+
|
149 |
+
return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
|
150 |
+
|
151 |
+
|
152 |
+
def run_extensions_installers(settings_file):
|
153 |
+
if not os.path.isdir(dir_extensions):
|
154 |
+
return
|
155 |
+
|
156 |
+
for dirname_extension in list_extensions(settings_file):
|
157 |
+
run_extension_installer(os.path.join(dir_extensions, dirname_extension))
|
158 |
+
|
159 |
+
|
160 |
+
def prepare_environment():
|
161 |
+
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
162 |
+
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
163 |
+
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
164 |
+
|
165 |
+
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
166 |
+
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
167 |
+
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
168 |
+
|
169 |
+
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
|
170 |
+
|
171 |
+
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
172 |
+
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
173 |
+
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
174 |
+
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
175 |
+
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
176 |
+
|
177 |
+
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
|
178 |
+
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
179 |
+
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
|
180 |
+
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
181 |
+
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
182 |
+
|
183 |
+
sys.argv += shlex.split(commandline_args)
|
184 |
+
|
185 |
+
parser = argparse.ArgumentParser()
|
186 |
+
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
|
187 |
+
args, _ = parser.parse_known_args(sys.argv)
|
188 |
+
|
189 |
+
sys.argv, _ = extract_arg(sys.argv, '-f')
|
190 |
+
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
191 |
+
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
192 |
+
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
|
193 |
+
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
|
194 |
+
xformers = '--xformers' in sys.argv
|
195 |
+
ngrok = '--ngrok' in sys.argv
|
196 |
+
|
197 |
+
try:
|
198 |
+
commit = run(f"{git} rev-parse HEAD").strip()
|
199 |
+
except Exception:
|
200 |
+
commit = "<none>"
|
201 |
+
|
202 |
+
print(f"Python {sys.version}")
|
203 |
+
print(f"Commit hash: {commit}")
|
204 |
+
|
205 |
+
if not is_installed("torch") or not is_installed("torchvision"):
|
206 |
+
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
|
207 |
+
|
208 |
+
if not skip_torch_cuda_test:
|
209 |
+
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
210 |
+
|
211 |
+
if not is_installed("gfpgan"):
|
212 |
+
run_pip(f"install {gfpgan_package}", "gfpgan")
|
213 |
+
|
214 |
+
if not is_installed("clip"):
|
215 |
+
run_pip(f"install {clip_package}", "clip")
|
216 |
+
|
217 |
+
if not is_installed("open_clip"):
|
218 |
+
run_pip(f"install {openclip_package}", "open_clip")
|
219 |
+
|
220 |
+
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
221 |
+
if platform.system() == "Windows":
|
222 |
+
if platform.python_version().startswith("3.10"):
|
223 |
+
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
|
224 |
+
else:
|
225 |
+
print("Installation of xformers is not supported in this version of Python.")
|
226 |
+
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
227 |
+
if not is_installed("xformers"):
|
228 |
+
exit(0)
|
229 |
+
elif platform.system() == "Linux":
|
230 |
+
run_pip("install xformers", "xformers")
|
231 |
+
|
232 |
+
if not is_installed("pyngrok") and ngrok:
|
233 |
+
run_pip("install pyngrok", "ngrok")
|
234 |
+
|
235 |
+
os.makedirs(dir_repos, exist_ok=True)
|
236 |
+
|
237 |
+
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
238 |
+
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
239 |
+
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
240 |
+
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
241 |
+
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
242 |
+
|
243 |
+
if not is_installed("lpips"):
|
244 |
+
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
245 |
+
|
246 |
+
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
|
247 |
+
|
248 |
+
run_extensions_installers(settings_file=args.ui_settings_file)
|
249 |
+
|
250 |
+
if update_check:
|
251 |
+
version_check(commit)
|
252 |
+
|
253 |
+
if "--exit" in sys.argv:
|
254 |
+
print("Exiting because of --exit argument")
|
255 |
+
exit(0)
|
256 |
+
|
257 |
+
if run_tests:
|
258 |
+
exitcode = tests(test_dir)
|
259 |
+
exit(exitcode)
|
260 |
+
|
261 |
+
|
262 |
+
def tests(test_dir):
|
263 |
+
if "--api" not in sys.argv:
|
264 |
+
sys.argv.append("--api")
|
265 |
+
if "--ckpt" not in sys.argv:
|
266 |
+
sys.argv.append("--ckpt")
|
267 |
+
sys.argv.append("./test/test_files/empty.pt")
|
268 |
+
if "--skip-torch-cuda-test" not in sys.argv:
|
269 |
+
sys.argv.append("--skip-torch-cuda-test")
|
270 |
+
|
271 |
+
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
|
272 |
+
|
273 |
+
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
|
274 |
+
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
|
275 |
+
|
276 |
+
import test.server_poll
|
277 |
+
exitcode = test.server_poll.run_tests(proc, test_dir)
|
278 |
+
|
279 |
+
print(f"Stopping Web UI process with id {proc.pid}")
|
280 |
+
proc.kill()
|
281 |
+
return exitcode
|
282 |
+
|
283 |
+
|
284 |
+
def start():
|
285 |
+
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
|
286 |
+
import webui
|
287 |
+
if '--nowebui' in sys.argv:
|
288 |
+
webui.api_only()
|
289 |
+
else:
|
290 |
+
webui.webui()
|
291 |
+
|
292 |
+
|
293 |
+
if __name__ == "__main__":
|
294 |
+
prepare_environment()
|
295 |
+
start()
|
localizations/Put localization files here.txt
ADDED
File without changes
|
modules/api/api.py
ADDED
@@ -0,0 +1,418 @@
import base64
import io
import time
import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras, run_pnginfo
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin, Image
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List


def upscaler_to_index(name: str):
    try:
        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
    except:
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}")


def validate_sampler_name(name):
    config = sd_samplers.all_samplers_map.get(name, None)
    if config is None:
        raise HTTPException(status_code=404, detail="Sampler not found")

    return name


def setUpscalers(req: dict):
    reqDict = vars(req)
    reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
    reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
    reqDict.pop('upscaler_1')
    reqDict.pop('upscaler_2')
    return reqDict


def decode_base64_to_image(encoding):
    if encoding.startswith("data:image/"):
        encoding = encoding.split(";")[1].split(",")[1]
    return Image.open(BytesIO(base64.b64decode(encoding)))


def encode_pil_to_base64(image):
    with io.BytesIO() as output_bytes:

        # Copy any text-only metadata
        use_metadata = False
        metadata = PngImagePlugin.PngInfo()
        for key, value in image.info.items():
            if isinstance(key, str) and isinstance(value, str):
                metadata.add_text(key, value)
                use_metadata = True

        image.save(
            output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
        )
        bytes_data = output_bytes.getvalue()
        return base64.b64encode(bytes_data)

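The two base64 helpers above are inverses of each other (PNG bytes to base64 string and back), with text-only metadata carried across via PngInfo. A small round-trip sketch, assuming only Pillow is available, illustrates the expected behavior:

# Illustrative round trip for encode_pil_to_base64 / decode_base64_to_image (a sketch, not part of the module).
from PIL import Image

img = Image.new("RGB", (64, 64), color=(255, 0, 0))
img.info["parameters"] = "test prompt"            # text-only metadata is preserved as a PNG text chunk

b64 = encode_pil_to_base64(img)                   # bytes containing base64-encoded PNG data
restored = decode_base64_to_image(b64.decode())   # accepts a plain or "data:image/png;base64,..." string
assert restored.size == (64, 64)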
class Api:
    def __init__(self, app: FastAPI, queue_lock: Lock):
        if shared.cmd_opts.api_auth:
            self.credentials = dict()
            for auth in shared.cmd_opts.api_auth.split(","):
                user, password = auth.split(":")
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)

    def add_api_route(self, path: str, endpoint, **kwargs):
        if shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
        return self.app.add_api_route(path, endpoint, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        if credentials.username in self.credentials:
            if compare_digest(credentials.password, self.credentials[credentials.username]):
                return True

        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True
            }
        )
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))

            shared.state.begin()
            processed = process_images(p)
            shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))

        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask
            }
        )
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in the model, but it does not work for a reason I cannot determine

        with self.queue_lock:
            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
            p.init_images = [decode_base64_to_image(x) for x in init_images]

            shared.state.begin()
            processed = process_images(p)
            shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))

        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())

    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
        reqDict = setUpscalers(req)

        reqDict['image'] = decode_base64_to_image(reqDict['image'])

        with self.queue_lock:
            result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        reqDict = setUpscalers(req)

        def prepareFiles(file):
            file = decode_base64_to_file(file.data, file_path=file.name)
            file.orig_name = file.name
            return file

        reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
        reqDict.pop('imageList')

        with self.queue_lock:
            result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

    def pnginfoapi(self, req: PNGInfoRequest):
        if(not req.image.strip()):
            return PNGInfoResponse(info="")

        result = run_pnginfo(decode_base64_to_image(req.image.strip()))

        return PNGInfoResponse(info=result[1])

    def progressapi(self, req: ProgressRequest = Depends()):
        # copied from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())

        # avoid division by zero
        progress = 0.01

        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        time_since_start = time.time() - shared.state.time_start
        eta = (time_since_start/progress)
        eta_relative = eta-time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)

    def interrogateapi(self, interrogatereq: InterrogateRequest):
        image_b64 = interrogatereq.image
        if image_b64 is None:
            raise HTTPException(status_code=404, detail="Image not found")

        img = decode_base64_to_image(image_b64)
        img = img.convert('RGB')

        # Override object param
        with self.queue_lock:
            if interrogatereq.model == "clip":
                processed = shared.interrogator.interrogate(img)
            elif interrogatereq.model == "deepdanbooru":
                processed = deepbooru.model.tag(img)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return InterrogateResponse(caption=processed)

    def interruptapi(self):
        shared.state.interrupt()

        return {}

    def skip(self):
        shared.state.skip()

    def get_config(self):
        options = {}
        for key in shared.opts.data.keys():
            metadata = shared.opts.data_labels.get(key)
            if(metadata is not None):
                options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
            else:
                options.update({key: shared.opts.data.get(key, None)})

        return options

    def set_config(self, req: Dict[str, Any]):
        for k, v in req.items():
            shared.opts.set(k, v)

        shared.opts.save(shared.config_filename)
        return

    def get_cmd_flags(self):
        return vars(shared.cmd_opts)

    def get_samplers(self):
        return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]

    def get_upscalers(self):
        upscalers = []

        for upscaler in shared.sd_upscalers:
            u = upscaler.scaler
            upscalers.append({"name": u.name, "model_name": u.model_name, "model_path": u.model_path, "model_url": u.model_url})

        return upscalers

    def get_sd_models(self):
        return [{"title": x.title, "model_name": x.model_name, "hash": x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]

    def get_hypernetworks(self):
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

    def get_face_restorers(self):
        return [{"name": x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]

    def get_realesrgan_models(self):
        return [{"name": x.name, "path": x.data_path, "scale": x.scale} for x in get_realesrgan_models(None)]

    def get_prompt_styles(self):
        styleList = []
        for k in shared.prompt_styles.styles:
            style = shared.prompt_styles.styles[k]
            styleList.append({"name": style[0], "prompt": style[1], "negative_prompt": style[2]})

        return styleList

    def get_artists_categories(self):
        return shared.artist_db.cats

    def get_artists(self):
        return [{"name": x[0], "score": x[1], "category": x[2]} for x in shared.artist_db.artists]

    def refresh_checkpoints(self):
        shared.refresh_checkpoints()

    def create_embedding(self, args: dict):
        try:
            shared.state.begin()
            filename = create_embedding(**args)  # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()  # reload embeddings so new one can be immediately used
            shared.state.end()
            return CreateResponse(info="create embedding filename: {filename}".format(filename=filename))
        except AssertionError as e:
            shared.state.end()
            return TrainResponse(info="create embedding error: {error}".format(error=e))

    def create_hypernetwork(self, args: dict):
        try:
            shared.state.begin()
            filename = create_hypernetwork(**args)  # create empty hypernetwork
            shared.state.end()
            return CreateResponse(info="create hypernetwork filename: {filename}".format(filename=filename))
        except AssertionError as e:
            shared.state.end()
            return TrainResponse(info="create hypernetwork error: {error}".format(error=e))

    def preprocess(self, args: dict):
        try:
            shared.state.begin()
            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
            shared.state.end()
            return PreprocessResponse(info='preprocess complete')
        except KeyError as e:
            shared.state.end()
            return PreprocessResponse(info="preprocess error: invalid token: {error}".format(error=e))
        except AssertionError as e:
            shared.state.end()
            return PreprocessResponse(info="preprocess error: {error}".format(error=e))
        except FileNotFoundError as e:
            shared.state.end()
            return PreprocessResponse(info='preprocess error: {error}'.format(error=e))

    def train_embedding(self, args: dict):
        try:
            shared.state.begin()
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                embedding, filename = train_embedding(**args)  # can take a long time to complete
            except Exception as e:
                error = e
            finally:
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            shared.state.end()
            return TrainResponse(info="train embedding error: {msg}".format(msg=msg))

    def train_hypernetwork(self, args: dict):
        try:
            shared.state.begin()
            initial_hypernetwork = shared.loaded_hypernetwork
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                hypernetwork, filename = train_hypernetwork(**args)
            except Exception as e:
                error = e
            finally:
                shared.loaded_hypernetwork = initial_hypernetwork
                shared.sd_model.cond_stage_model.to(devices.device)
                shared.sd_model.first_stage_model.to(devices.device)
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return TrainResponse(info="train hypernetwork complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            shared.state.end()
            return TrainResponse(info="train hypernetwork error: {error}".format(error=msg))

    def launch(self, server_name, port):
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)
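The txt2img route registered above accepts the dynamically generated request model, so any field of StableDiffusionProcessingTxt2Img can be sent in the JSON body. As a minimal client sketch (assuming the server is running locally with the API enabled; 127.0.0.1:7860 is the usual default address and a placeholder here, and the `requests` package is an extra dependency of the sketch, not of this module):

# Sketch: calling POST /sdapi/v1/txt2img and saving the first returned image.
import base64
import requests

payload = {
    "prompt": "a photo of a cat",
    "steps": 20,
    "sampler_index": "Euler",
}

resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
resp.raise_for_status()

# TextToImageResponse: "images" is a list of base64-encoded PNGs
with open("txt2img_result.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["images"][0]))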
|
modules/api/models.py
ADDED
@@ -0,0 +1,251 @@
1 |
+
import inspect
|
2 |
+
from pydantic import BaseModel, Field, create_model
|
3 |
+
from typing import Any, Optional
|
4 |
+
from typing_extensions import Literal
|
5 |
+
from inflection import underscore
|
6 |
+
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
7 |
+
from modules.shared import sd_upscalers, opts, parser
|
8 |
+
from typing import Dict, List
|
9 |
+
|
10 |
+
API_NOT_ALLOWED = [
|
11 |
+
"self",
|
12 |
+
"kwargs",
|
13 |
+
"sd_model",
|
14 |
+
"outpath_samples",
|
15 |
+
"outpath_grids",
|
16 |
+
"sampler_index",
|
17 |
+
"do_not_save_samples",
|
18 |
+
"do_not_save_grid",
|
19 |
+
"extra_generation_params",
|
20 |
+
"overlay_images",
|
21 |
+
"do_not_reload_embeddings",
|
22 |
+
"seed_enable_extras",
|
23 |
+
"prompt_for_display",
|
24 |
+
"sampler_noise_scheduler_override",
|
25 |
+
"ddim_discretize"
|
26 |
+
]
|
27 |
+
|
28 |
+
class ModelDef(BaseModel):
|
29 |
+
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
30 |
+
|
31 |
+
field: str
|
32 |
+
field_alias: str
|
33 |
+
field_type: Any
|
34 |
+
field_value: Any
|
35 |
+
field_exclude: bool = False
|
36 |
+
|
37 |
+
|
38 |
+
class PydanticModelGenerator:
|
39 |
+
"""
|
40 |
+
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
41 |
+
source_data is a snapshot of the default values produced by the class
|
42 |
+
params are the names of the actual keys required by __init__
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
model_name: str = None,
|
48 |
+
class_instance = None,
|
49 |
+
additional_fields = None,
|
50 |
+
):
|
51 |
+
def field_type_generator(k, v):
|
52 |
+
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
53 |
+
# print(k, v.annotation, v.default)
|
54 |
+
field_type = v.annotation
|
55 |
+
|
56 |
+
return Optional[field_type]
|
57 |
+
|
58 |
+
def merge_class_params(class_):
|
59 |
+
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
60 |
+
parameters = {}
|
61 |
+
for classes in all_classes:
|
62 |
+
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
63 |
+
return parameters
|
64 |
+
|
65 |
+
|
66 |
+
self._model_name = model_name
|
67 |
+
self._class_data = merge_class_params(class_instance)
|
68 |
+
|
69 |
+
self._model_def = [
|
70 |
+
ModelDef(
|
71 |
+
field=underscore(k),
|
72 |
+
field_alias=k,
|
73 |
+
field_type=field_type_generator(k, v),
|
74 |
+
field_value=v.default
|
75 |
+
)
|
76 |
+
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
77 |
+
]
|
78 |
+
|
79 |
+
for fields in additional_fields:
|
80 |
+
self._model_def.append(ModelDef(
|
81 |
+
field=underscore(fields["key"]),
|
82 |
+
field_alias=fields["key"],
|
83 |
+
field_type=fields["type"],
|
84 |
+
field_value=fields["default"],
|
85 |
+
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
86 |
+
|
87 |
+
def generate_model(self):
|
88 |
+
"""
|
89 |
+
Creates a pydantic BaseModel
|
90 |
+
from the json and overrides provided at initialization
|
91 |
+
"""
|
92 |
+
fields = {
|
93 |
+
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
94 |
+
}
|
95 |
+
DynamicModel = create_model(self._model_name, **fields)
|
96 |
+
DynamicModel.__config__.allow_population_by_field_name = True
|
97 |
+
DynamicModel.__config__.allow_mutation = True
|
98 |
+
return DynamicModel
|
99 |
+
|
100 |
+
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
101 |
+
"StableDiffusionProcessingTxt2Img",
|
102 |
+
StableDiffusionProcessingTxt2Img,
|
103 |
+
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
104 |
+
).generate_model()
|
105 |
+
|
106 |
+
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
107 |
+
"StableDiffusionProcessingImg2Img",
|
108 |
+
StableDiffusionProcessingImg2Img,
|
109 |
+
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
|
110 |
+
).generate_model()
|
111 |
+
|
112 |
+
class TextToImageResponse(BaseModel):
|
113 |
+
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
114 |
+
parameters: dict
|
115 |
+
info: str
|
116 |
+
|
117 |
+
class ImageToImageResponse(BaseModel):
|
118 |
+
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
119 |
+
parameters: dict
|
120 |
+
info: str
|
121 |
+
|
122 |
+
class ExtrasBaseRequest(BaseModel):
|
123 |
+
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
124 |
+
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
125 |
+
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
126 |
+
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
127 |
+
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
128 |
+
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
129 |
+
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
130 |
+
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
131 |
+
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
|
132 |
+
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use; it has to be one of the following: {' , '.join([x.name for x in sd_upscalers])}")
|
133 |
+
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use; it has to be one of the following: {' , '.join([x.name for x in sd_upscalers])}")
|
134 |
+
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
135 |
+
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
|
136 |
+
|
137 |
+
class ExtraBaseResponse(BaseModel):
|
138 |
+
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
139 |
+
|
140 |
+
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
141 |
+
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
142 |
+
|
143 |
+
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
144 |
+
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
145 |
+
|
146 |
+
class FileData(BaseModel):
|
147 |
+
data: str = Field(title="File data", description="Base64 representation of the file")
|
148 |
+
name: str = Field(title="File name")
|
149 |
+
|
150 |
+
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
151 |
+
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
152 |
+
|
153 |
+
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
154 |
+
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
|
155 |
+
|
156 |
+
class PNGInfoRequest(BaseModel):
|
157 |
+
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
158 |
+
|
159 |
+
class PNGInfoResponse(BaseModel):
|
160 |
+
info: str = Field(title="Image info", description="A string with all the info the image had")
|
161 |
+
|
162 |
+
class ProgressRequest(BaseModel):
|
163 |
+
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
164 |
+
|
165 |
+
class ProgressResponse(BaseModel):
|
166 |
+
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
167 |
+
eta_relative: float = Field(title="ETA in secs")
|
168 |
+
state: dict = Field(title="State", description="The current state snapshot")
|
169 |
+
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
170 |
+
|
171 |
+
class InterrogateRequest(BaseModel):
|
172 |
+
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
173 |
+
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
|
174 |
+
|
175 |
+
class InterrogateResponse(BaseModel):
|
176 |
+
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
177 |
+
|
178 |
+
class TrainResponse(BaseModel):
|
179 |
+
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
|
180 |
+
|
181 |
+
class CreateResponse(BaseModel):
|
182 |
+
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
183 |
+
|
184 |
+
class PreprocessResponse(BaseModel):
|
185 |
+
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
|
186 |
+
|
187 |
+
fields = {}
|
188 |
+
for key, metadata in opts.data_labels.items():
|
189 |
+
value = opts.data.get(key)
|
190 |
+
optType = opts.typemap.get(type(metadata.default), type(value))
|
191 |
+
|
192 |
+
if (metadata is not None):
|
193 |
+
fields.update({key: (Optional[optType], Field(
|
194 |
+
default=metadata.default ,description=metadata.label))})
|
195 |
+
else:
|
196 |
+
fields.update({key: (Optional[optType], Field())})
|
197 |
+
|
198 |
+
OptionsModel = create_model("Options", **fields)
|
199 |
+
|
200 |
+
flags = {}
|
201 |
+
_options = vars(parser)['_option_string_actions']
|
202 |
+
for key in _options:
|
203 |
+
if(_options[key].dest != 'help'):
|
204 |
+
flag = _options[key]
|
205 |
+
_type = str
|
206 |
+
if _options[key].default is not None: _type = type(_options[key].default)
|
207 |
+
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
208 |
+
|
209 |
+
FlagsModel = create_model("Flags", **flags)
|
210 |
+
|
211 |
+
class SamplerItem(BaseModel):
|
212 |
+
name: str = Field(title="Name")
|
213 |
+
aliases: List[str] = Field(title="Aliases")
|
214 |
+
options: Dict[str, str] = Field(title="Options")
|
215 |
+
|
216 |
+
class UpscalerItem(BaseModel):
|
217 |
+
name: str = Field(title="Name")
|
218 |
+
model_name: Optional[str] = Field(title="Model Name")
|
219 |
+
model_path: Optional[str] = Field(title="Path")
|
220 |
+
model_url: Optional[str] = Field(title="URL")
|
221 |
+
|
222 |
+
class SDModelItem(BaseModel):
|
223 |
+
title: str = Field(title="Title")
|
224 |
+
model_name: str = Field(title="Model Name")
|
225 |
+
hash: str = Field(title="Hash")
|
226 |
+
filename: str = Field(title="Filename")
|
227 |
+
config: str = Field(title="Config file")
|
228 |
+
|
229 |
+
class HypernetworkItem(BaseModel):
|
230 |
+
name: str = Field(title="Name")
|
231 |
+
path: Optional[str] = Field(title="Path")
|
232 |
+
|
233 |
+
class FaceRestorerItem(BaseModel):
|
234 |
+
name: str = Field(title="Name")
|
235 |
+
cmd_dir: Optional[str] = Field(title="Path")
|
236 |
+
|
237 |
+
class RealesrganItem(BaseModel):
|
238 |
+
name: str = Field(title="Name")
|
239 |
+
path: Optional[str] = Field(title="Path")
|
240 |
+
scale: Optional[int] = Field(title="Scale")
|
241 |
+
|
242 |
+
class PromptStyleItem(BaseModel):
|
243 |
+
name: str = Field(title="Name")
|
244 |
+
prompt: Optional[str] = Field(title="Prompt")
|
245 |
+
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
246 |
+
|
247 |
+
class ArtistItem(BaseModel):
|
248 |
+
name: str = Field(title="Name")
|
249 |
+
score: float = Field(title="Score")
|
250 |
+
category: str = Field(title="Category")
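Because allow_population_by_field_name is enabled on the generated models, request fields can be set either by their original __init__ name (kept as the alias) or by the underscored field name, and unset fields fall back to the defaults captured from the processing class. A small sketch of instantiating the generated txt2img request model (prompt and steps are assumed to be parameters inherited from StableDiffusionProcessingTxt2Img):

# Sketch: using the dynamically generated request model directly.
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI

req = StableDiffusionTxt2ImgProcessingAPI(prompt="a lighthouse at dusk", steps=30)
print(req.sampler_index)        # "Euler", the extra field added by the generator
print(req.dict()["prompt"])     # fields not passed keep their captured defaults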
|
251 |
+
|
modules/artists.py
ADDED
@@ -0,0 +1,25 @@
import os.path
import csv
from collections import namedtuple

Artist = namedtuple("Artist", ['name', 'weight', 'category'])


class ArtistsDatabase:
    def __init__(self, filename):
        self.cats = set()
        self.artists = []

        if not os.path.exists(filename):
            return

        with open(filename, "r", newline='', encoding="utf8") as file:
            reader = csv.DictReader(file)

            for row in reader:
                artist = Artist(row["artist"], float(row["score"]), row["category"])
                self.artists.append(artist)
                self.cats.add(artist.category)

    def categories(self):
        return sorted(self.cats)
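A short usage sketch of the database (the CSV path is an assumption about the caller's working directory; the constructor silently leaves the database empty if the file is missing):

# Sketch: loading the artist database and listing its categories.
from modules.artists import ArtistsDatabase

db = ArtistsDatabase("artists.csv")
print(len(db.artists), "artists in", len(db.categories()), "categories")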
|
modules/call_queue.py
ADDED
@@ -0,0 +1,98 @@
import html
import sys
import threading
import traceback
import time

from modules import shared

queue_lock = threading.Lock()


def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)

        return res

    return f


def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):

        shared.state.begin()

        with queue_lock:
            res = func(*args, **kwargs)

        shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)


def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text
            max_debug_str_len = 131072  # (1024*1024)/8

            print("Error completing request", file=sys.stderr)
            argStr = f"Arguments: {str(args)} {str(kwargs)}"
            print(argStr[:max_debug_str_len], file=sys.stderr)
            if len(argStr) > max_debug_str_len:
                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)

            print(traceback.format_exc(), file=sys.stderr)

            shared.state.job = ""
            shared.state.job_count = 0

            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        if not add_stats:
            return tuple(res)

        elapsed = time.perf_counter() - t
        elapsed_m = int(elapsed // 60)
        elapsed_s = elapsed % 60
        elapsed_text = f"{elapsed_s:.2f}s"
        if elapsed_m > 0:
            elapsed_text = f"{elapsed_m}m "+elapsed_text

        if run_memmon:
            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
            active_peak = mem_stats['active_peak']
            reserved_peak = mem_stats['reserved_peak']
            sys_peak = mem_stats['system_peak']
            sys_total = mem_stats['total']
            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)

            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
        else:
            vram_html = ''

        # last item is always HTML
        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        return tuple(res)

    return f
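These wrappers are meant to be composed around UI callbacks inside a running web UI process: wrap_gradio_gpu_call serializes GPU work behind queue_lock, and wrap_gradio_call appends timing/VRAM stats to the last (HTML) output on success or swaps in an error div on failure. A hedged sketch with a hypothetical callback:

# Sketch: wrapping a hypothetical callback; requires an initialized shared.opts / shared.state.
from modules.call_queue import wrap_gradio_gpu_call

def my_generate(prompt):                       # hypothetical callback returning (result, html_info)
    return f"result for {prompt}", "<p>done</p>"

wrapped = wrap_gradio_gpu_call(my_generate, extra_outputs=[None, ''])
outputs = wrapped("a prompt")                  # runs under queue_lock; the last element gains a performance div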
|
modules/codeformer/codeformer_arch.py
ADDED
@@ -0,0 +1,278 @@
1 |
+
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
2 |
+
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn, Tensor
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from typing import Optional, List
|
9 |
+
|
10 |
+
from modules.codeformer.vqgan_arch import *
|
11 |
+
from basicsr.utils import get_root_logger
|
12 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
13 |
+
|
14 |
+
def calc_mean_std(feat, eps=1e-5):
|
15 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
feat (Tensor): 4D tensor.
|
19 |
+
eps (float): A small value added to the variance to avoid
|
20 |
+
divide-by-zero. Default: 1e-5.
|
21 |
+
"""
|
22 |
+
size = feat.size()
|
23 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
24 |
+
b, c = size[:2]
|
25 |
+
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
26 |
+
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
27 |
+
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
28 |
+
return feat_mean, feat_std
|
29 |
+
|
30 |
+
|
31 |
+
def adaptive_instance_normalization(content_feat, style_feat):
|
32 |
+
"""Adaptive instance normalization.
|
33 |
+
|
34 |
+
Adjust the reference features to have a similar color and illumination
|
35 |
+
to those in the degraded features.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
content_feat (Tensor): The reference feature.
|
39 |
+
style_feat (Tensor): The degradate features.
|
40 |
+
"""
|
41 |
+
size = content_feat.size()
|
42 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
43 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
44 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
45 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
46 |
+
|
47 |
+
|
48 |
+
class PositionEmbeddingSine(nn.Module):
|
49 |
+
"""
|
50 |
+
This is a more standard version of the position embedding, very similar to the one
|
51 |
+
used by the Attention is all you need paper, generalized to work on images.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
55 |
+
super().__init__()
|
56 |
+
self.num_pos_feats = num_pos_feats
|
57 |
+
self.temperature = temperature
|
58 |
+
self.normalize = normalize
|
59 |
+
if scale is not None and normalize is False:
|
60 |
+
raise ValueError("normalize should be True if scale is passed")
|
61 |
+
if scale is None:
|
62 |
+
scale = 2 * math.pi
|
63 |
+
self.scale = scale
|
64 |
+
|
65 |
+
def forward(self, x, mask=None):
|
66 |
+
if mask is None:
|
67 |
+
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
68 |
+
not_mask = ~mask
|
69 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
70 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
71 |
+
if self.normalize:
|
72 |
+
eps = 1e-6
|
73 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
74 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
75 |
+
|
76 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
77 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
78 |
+
|
79 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
80 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
81 |
+
pos_x = torch.stack(
|
82 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
83 |
+
).flatten(3)
|
84 |
+
pos_y = torch.stack(
|
85 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
86 |
+
).flatten(3)
|
87 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
88 |
+
return pos
|
89 |
+
|
90 |
+
def _get_activation_fn(activation):
|
91 |
+
"""Return an activation function given a string"""
|
92 |
+
if activation == "relu":
|
93 |
+
return F.relu
|
94 |
+
if activation == "gelu":
|
95 |
+
return F.gelu
|
96 |
+
if activation == "glu":
|
97 |
+
return F.glu
|
98 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
99 |
+
|
100 |
+
|
101 |
+
class TransformerSALayer(nn.Module):
|
102 |
+
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
103 |
+
super().__init__()
|
104 |
+
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
105 |
+
# Implementation of Feedforward model - MLP
|
106 |
+
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
107 |
+
self.dropout = nn.Dropout(dropout)
|
108 |
+
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
109 |
+
|
110 |
+
self.norm1 = nn.LayerNorm(embed_dim)
|
111 |
+
self.norm2 = nn.LayerNorm(embed_dim)
|
112 |
+
self.dropout1 = nn.Dropout(dropout)
|
113 |
+
self.dropout2 = nn.Dropout(dropout)
|
114 |
+
|
115 |
+
self.activation = _get_activation_fn(activation)
|
116 |
+
|
117 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
118 |
+
return tensor if pos is None else tensor + pos
|
119 |
+
|
120 |
+
def forward(self, tgt,
|
121 |
+
tgt_mask: Optional[Tensor] = None,
|
122 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
123 |
+
query_pos: Optional[Tensor] = None):
|
124 |
+
|
125 |
+
# self attention
|
126 |
+
tgt2 = self.norm1(tgt)
|
127 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
128 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
129 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
130 |
+
tgt = tgt + self.dropout1(tgt2)
|
131 |
+
|
132 |
+
# ffn
|
133 |
+
tgt2 = self.norm2(tgt)
|
134 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
135 |
+
tgt = tgt + self.dropout2(tgt2)
|
136 |
+
return tgt
|
137 |
+
|
138 |
+
class Fuse_sft_block(nn.Module):
|
139 |
+
def __init__(self, in_ch, out_ch):
|
140 |
+
super().__init__()
|
141 |
+
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
142 |
+
|
143 |
+
self.scale = nn.Sequential(
|
144 |
+
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
145 |
+
nn.LeakyReLU(0.2, True),
|
146 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
147 |
+
|
148 |
+
self.shift = nn.Sequential(
|
149 |
+
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
150 |
+
nn.LeakyReLU(0.2, True),
|
151 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
152 |
+
|
153 |
+
def forward(self, enc_feat, dec_feat, w=1):
|
154 |
+
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
155 |
+
scale = self.scale(enc_feat)
|
156 |
+
shift = self.shift(enc_feat)
|
157 |
+
residual = w * (dec_feat * scale + shift)
|
158 |
+
out = dec_feat + residual
|
159 |
+
return out
|
160 |
+
|
161 |
+
|
162 |
+
@ARCH_REGISTRY.register()
|
163 |
+
class CodeFormer(VQAutoEncoder):
|
164 |
+
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
165 |
+
codebook_size=1024, latent_size=256,
|
166 |
+
connect_list=['32', '64', '128', '256'],
|
167 |
+
fix_modules=['quantize','generator']):
|
168 |
+
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
169 |
+
|
170 |
+
if fix_modules is not None:
|
171 |
+
for module in fix_modules:
|
172 |
+
for param in getattr(self, module).parameters():
|
173 |
+
param.requires_grad = False
|
174 |
+
|
175 |
+
self.connect_list = connect_list
|
176 |
+
self.n_layers = n_layers
|
177 |
+
self.dim_embd = dim_embd
|
178 |
+
self.dim_mlp = dim_embd*2
|
179 |
+
|
180 |
+
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
181 |
+
self.feat_emb = nn.Linear(256, self.dim_embd)
|
182 |
+
|
183 |
+
# transformer
|
184 |
+
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
185 |
+
for _ in range(self.n_layers)])
|
186 |
+
|
187 |
+
# logits_predict head
|
188 |
+
self.idx_pred_layer = nn.Sequential(
|
189 |
+
nn.LayerNorm(dim_embd),
|
190 |
+
nn.Linear(dim_embd, codebook_size, bias=False))
|
191 |
+
|
192 |
+
self.channels = {
|
193 |
+
'16': 512,
|
194 |
+
'32': 256,
|
195 |
+
'64': 256,
|
196 |
+
'128': 128,
|
197 |
+
'256': 128,
|
198 |
+
'512': 64,
|
199 |
+
}
|
200 |
+
|
201 |
+
# after second residual block for > 16, before attn layer for ==16
|
202 |
+
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
203 |
+
# after first residual block for > 16, before attn layer for ==16
|
204 |
+
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
205 |
+
|
206 |
+
# fuse_convs_dict
|
207 |
+
self.fuse_convs_dict = nn.ModuleDict()
|
208 |
+
for f_size in self.connect_list:
|
209 |
+
in_ch = self.channels[f_size]
|
210 |
+
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
211 |
+
|
212 |
+
def _init_weights(self, module):
|
213 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
214 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
215 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
216 |
+
module.bias.data.zero_()
|
217 |
+
elif isinstance(module, nn.LayerNorm):
|
218 |
+
module.bias.data.zero_()
|
219 |
+
module.weight.data.fill_(1.0)
|
220 |
+
|
221 |
+
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
222 |
+
# ################### Encoder #####################
|
223 |
+
enc_feat_dict = {}
|
224 |
+
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
225 |
+
for i, block in enumerate(self.encoder.blocks):
|
226 |
+
x = block(x)
|
227 |
+
if i in out_list:
|
228 |
+
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
229 |
+
|
230 |
+
lq_feat = x
|
231 |
+
# ################# Transformer ###################
|
232 |
+
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
233 |
+
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
234 |
+
# BCHW -> BC(HW) -> (HW)BC
|
235 |
+
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
236 |
+
query_emb = feat_emb
|
237 |
+
# Transformer encoder
|
238 |
+
for layer in self.ft_layers:
|
239 |
+
query_emb = layer(query_emb, query_pos=pos_emb)
|
240 |
+
|
241 |
+
# output logits
|
242 |
+
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
243 |
+
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
244 |
+
|
245 |
+
if code_only: # for training stage II
|
246 |
+
# logits doesn't need softmax before cross_entropy loss
|
247 |
+
return logits, lq_feat
|
248 |
+
|
249 |
+
# ################# Quantization ###################
|
250 |
+
# if self.training:
|
251 |
+
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
252 |
+
# # b(hw)c -> bc(hw) -> bchw
|
253 |
+
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
254 |
+
# ------------
|
255 |
+
soft_one_hot = F.softmax(logits, dim=2)
|
256 |
+
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
257 |
+
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
258 |
+
# preserve gradients
|
259 |
+
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
260 |
+
|
261 |
+
if detach_16:
|
262 |
+
quant_feat = quant_feat.detach() # for training stage III
|
263 |
+
if adain:
|
264 |
+
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
265 |
+
|
266 |
+
# ################## Generator ####################
|
267 |
+
x = quant_feat
|
268 |
+
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
269 |
+
|
270 |
+
for i, block in enumerate(self.generator.blocks):
|
271 |
+
x = block(x)
|
272 |
+
if i in fuse_list: # fuse after i-th block
|
273 |
+
f_size = str(x.shape[-1])
|
274 |
+
if w>0:
|
275 |
+
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
276 |
+
out = x
|
277 |
+
# logits doesn't need softmax before cross_entropy loss
|
278 |
+
return out, logits, lq_feat
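calc_mean_std and adaptive_instance_normalization operate on plain 4D feature tensors, so their behavior is easy to check in isolation; after AdaIN the per-channel mean and standard deviation of the output match those of the style input. A small sketch with random features (shapes are arbitrary):

# Sketch: AdaIN on random feature maps; the output takes on the channel statistics of `style`.
import torch
from modules.codeformer.codeformer_arch import adaptive_instance_normalization, calc_mean_std

content = torch.randn(1, 8, 16, 16)
style = torch.randn(1, 8, 16, 16) * 3 + 1

out = adaptive_instance_normalization(content, style)
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style)
print(torch.allclose(out_mean, style_mean, atol=1e-3), torch.allclose(out_std, style_std, atol=1e-3))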
|
modules/codeformer/vqgan_arch.py
ADDED
@@ -0,0 +1,437 @@
1 |
+
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
2 |
+
|
3 |
+
'''
|
4 |
+
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
5 |
+
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
6 |
+
|
7 |
+
'''
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import copy
|
13 |
+
from basicsr.utils import get_root_logger
|
14 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
15 |
+
|
16 |
+
def normalize(in_channels):
|
17 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
18 |
+
|
19 |
+
|
20 |
+
@torch.jit.script
|
21 |
+
def swish(x):
|
22 |
+
return x*torch.sigmoid(x)
|
23 |
+
|
24 |
+
|
# Define VQVAE classes
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        # [0-1], higher score, higher confidence
        min_encoding_scores = torch.exp(-min_encoding_scores/10)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "min_encoding_scores": min_encoding_scores,
            "mean_distance": mean_distance
            }

    def get_codebook_feat(self, indices, shape):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1, 1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q
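
# VectorQuantizer snaps each spatial position of a (B, emb_dim, H, W) latent to its
# nearest codebook entry. The line `z_q = z + (z_q - z).detach()` is the straight-through
# estimator: the forward pass uses the quantized values while gradients flow back to the
# encoder as if quantization were the identity. The returned loss is the codebook loss
# plus beta times the commitment loss, and the stats dict reports perplexity and the
# per-position nearest-code indices/scores.
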
class GumbelQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        hard = self.straight_through if self.training else True

        logits = self.proj(z)

        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)

        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {
            "min_encoding_indices": min_encoding_indices
            }
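
# GumbelQuantizer is a differentiable alternative to nearest-neighbour lookup: code
# assignments are drawn with the Gumbel-softmax trick (hard one-hot at eval time; during
# training, hard sampling follows the `straight_through` flag), and `diff` is a KL penalty
# that pushes the code distribution towards the uniform prior.
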
class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x
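
# Downsample zero-pads one pixel on the right and bottom before a 3x3 stride-2 convolution,
# halving even-sized feature maps without an asymmetric seam on the left/top edges;
# Upsample doubles the resolution with nearest-neighbour interpolation followed by a 3x3 conv.
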
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        # use self.out_channels so the default out_channels=None does not break the layer sizes
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in
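
# ResBlock is a pre-activation residual block (GroupNorm -> swish -> 3x3 conv, twice);
# when the channel count changes, the 1x1 conv_out projects the skip connection to match.
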
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h*w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
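
# AttnBlock is single-head self-attention over all H*W spatial positions (the 1x1 convs
# act as the q/k/v and output projections), so its cost grows quadratically with the
# number of positions; Encoder and Generator therefore only insert it at the resolutions
# listed in attn_resolutions (16x16 by default).
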
class Encoder(nn.Module):
    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,) + tuple(ch_mult)

        blocks = []
        # initial convolution
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention on smaller res (16x16)
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x
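
# With ch_mult of length n, the Encoder downsamples n-1 times, so a resolution x resolution
# input ends up as an emb_dim-channel latent of spatial size resolution / 2**(n-1);
# attention blocks are added wherever the current resolution appears in attn_resolutions.
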
class Generator(nn.Module):
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions - 1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x
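
# Generator mirrors the Encoder: it starts from the quantized latent at
# img_size / 2**(len(ch_mult)-1), upsamples back to img_size and projects to 3 RGB channels.
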
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta  # 0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats
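
# When model_path is given, VQAutoEncoder loads weights from the 'params_ema' or 'params'
# key and VQGANDiscriminator below from 'params_d' or 'params' (the layout produced by
# BasicSR-style training); any other checkpoint layout raises ValueError.
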
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        return self.main(x)