bilegentile committed on
Commit
c19ca42
•
1 Parent(s): 31ce0ac

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ extensions-builtin/sd-webui-agent-scheduler/docs/images/walkthrough.png filter=lfs diff=lfs merge=lfs -text
+ extensions-builtin/stable-diffusion-webui-rembg/preview.png filter=lfs diff=lfs merge=lfs -text
+ javascript/notosans-nerdfont-regular.ttf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,68 @@
+ # defaults
+ __pycache__
+ .ruff_cache
+ /cache.json
+ /*.json
+ /*.yaml
+ /params.txt
+ /styles.csv
+ /user.css
+ /webui-user.bat
+ /webui-user.sh
+ /html/extensions.json
+ /html/themes.json
+ node_modules
+ pnpm-lock.yaml
+ package-lock.json
+ venv
+ .history
+ cache
+ **/.DS_Store
+
+ # all models and temp files
+ *.log
+ *.log.*
+ *.bak
+ *.ckpt
+ *.safetensors
+ *.pth
+ *.pt
+ *.bin
+ *.optim
+ *.lock
+ *.zip
+ *.rar
+ *.7z
+ *.pyc
+ /*.bat
+ /*.sh
+ /*.txt
+ /*.mp3
+ /*.lnk
+ !webui.bat
+ !webui.sh
+ !package.json
+
+ # all dynamic stuff
+ /extensions/**/*
+ /outputs/**/*
+ /embeddings/**/*
+ /models/**/*
+ /interrogate/**/*
+ /train/log/**/*
+ /textual_inversion/**/*
+ /detected_maps/**/*
+ /tmp
+ /log
+ /cert
+ .vscode/
+ .idea/
+ /localizations
+
+ .*/
+
+ # force included
+ !/models/VAE-approx
+ !/models/VAE-approx/model.pt
+ !/models/Reference
+ !/models/Reference/**/*
.gitmodules ADDED
@@ -0,0 +1,32 @@
+ [submodule "wiki"]
+     path = wiki
+     url = https://github.com/vladmandic/automatic.wiki
+     ignore = dirty
+ [submodule "modules/k-diffusion"]
+     path = modules/k-diffusion
+     url = https://github.com/crowsonkb/k-diffusion
+     ignore = dirty
+ [submodule "extensions-builtin/sd-extension-system-info"]
+     path = extensions-builtin/sd-extension-system-info
+     url = https://github.com/vladmandic/sd-extension-system-info
+     ignore = dirty
+ [submodule "extensions-builtin/sd-extension-chainner"]
+     path = extensions-builtin/sd-extension-chainner
+     url = https://github.com/vladmandic/sd-extension-chainner
+     ignore = dirty
+ [submodule "extensions-builtin/stable-diffusion-webui-rembg"]
+     path = extensions-builtin/stable-diffusion-webui-rembg
+     url = https://github.com/vladmandic/sd-extension-rembg
+     ignore = dirty
+ [submodule "extensions-builtin/stable-diffusion-webui-images-browser"]
+     path = extensions-builtin/stable-diffusion-webui-images-browser
+     url = https://github.com/AlUlkesh/stable-diffusion-webui-images-browser
+     ignore = dirty
+ [submodule "extensions-builtin/sd-webui-controlnet"]
+     path = extensions-builtin/sd-webui-controlnet
+     url = https://github.com/Mikubill/sd-webui-controlnet
+     ignore = dirty
+ [submodule "extensions-builtin/sd-webui-agent-scheduler"]
+     path = extensions-builtin/sd-webui-agent-scheduler
+     url = https://github.com/ArtVentureX/sd-webui-agent-scheduler
+     ignore = dirty
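Every submodule above sets `ignore = dirty`, so local modifications inside them are not reported by `git status`. For a quick programmatic view of what is pinned here, a minimal sketch using GitPython (the same `git` package `cli/clone.py` below imports), run from the repository root with submodules initialized:

```python
import git  # GitPython, the same package cli/clone.py imports

# enumerate the submodules declared in .gitmodules
repo = git.Repo('.')
for sm in repo.submodules:
    print(f'{sm.name} -> {sm.url}')
```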
.pylintrc ADDED
@@ -0,0 +1,237 @@
+ [MAIN]
+ analyse-fallback-blocks=no
+ clear-cache-post-run=no
+ #enable-all-extensions=
+ #errors-only=
+ #exit-zero=
+ extension-pkg-allow-list=
+ extension-pkg-whitelist=
+ fail-on=
+ fail-under=10
+ ignore=CVS
+ ignore-paths=/usr/lib/.*$,
+     ^repositories/.*$,
+     ^extensions/.*$,
+     ^extensions-builtin/.*$,
+     ^modules/dml/.*$,
+     ^modules/tcd/.*$,
+     ^modules/xadapters/.*$,
+ ignore-patterns=
+ ignored-modules=
+ jobs=0
+ limit-inference-results=100
+ load-plugins=
+ persistent=yes
+ py-version=3.10
+ recursive=no
+ source-roots=
+ suggestion-mode=yes
+ unsafe-load-any-extension=no
+ #verbose=
+
+ [BASIC]
+ argument-naming-style=snake_case
+ #argument-rgx=
+ attr-naming-style=snake_case
+ #attr-rgx=
+ bad-names=foo, bar, baz, toto, tutu, tata
+ bad-names-rgxs=
+ class-attribute-naming-style=any
+ class-const-naming-style=UPPER_CASE
+ #class-const-rgx=
+ class-naming-style=PascalCase
+ #class-rgx=
+ const-naming-style=snake_case
+ #const-rgx=
+ docstring-min-length=-1
+ function-naming-style=snake_case
+ #function-rgx=
+ # Good variable names which should always be accepted, separated by a comma.
+ good-names=i,j,k,e,ex,ok,p
+ good-names-rgxs=
+ include-naming-hint=no
+ inlinevar-naming-style=any
+ #inlinevar-rgx=
+ method-naming-style=snake_case
+ #method-rgx=
+ module-naming-style=snake_case
+ #module-rgx=
+ name-group=
+ no-docstring-rgx=^_
+ property-classes=abc.abstractproperty
+ #typealias-rgx=
+ #typevar-rgx=
+ variable-naming-style=snake_case
+ #variable-rgx=
+
+ [CLASSES]
+ check-protected-access-in-special-methods=no
+ defining-attr-methods=__init__,
+     __new__,
+     setUp,
+     asyncSetUp,
+     __post_init__
+ exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
+ valid-classmethod-first-arg=cls
+ valid-metaclass-classmethod-first-arg=mcs
+
+ [DESIGN]
+ exclude-too-few-public-methods=
+ ignored-parents=
+ max-args=99
+ max-attributes=99
+ max-bool-expr=99
+ max-branches=99
+ max-locals=99
+ max-parents=99
+ max-public-methods=99
+ max-returns=99
+ max-statements=199
+ min-public-methods=1
+
+ [EXCEPTIONS]
+ overgeneral-exceptions=builtins.BaseException,builtins.Exception
+
+ [FORMAT]
+ expected-line-ending-format=
+ ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+ indent-after-paren=4
+ indent-string='    '
+ max-line-length=200
+ max-module-lines=9999
+ single-line-class-stmt=no
+ single-line-if-stmt=no
+
+ [IMPORTS]
+ allow-any-import-level=
+ allow-reexport-from-package=no
+ allow-wildcard-with-all=no
+ deprecated-modules=
+ ext-import-graph=
+ import-graph=
+ int-import-graph=
+ known-standard-library=
+ known-third-party=enchant
+ preferred-modules=
+
+ [LOGGING]
+ logging-format-style=new
+ logging-modules=logging
+
+ [MESSAGES CONTROL]
+ confidence=HIGH,
+     CONTROL_FLOW,
+     INFERENCE,
+     INFERENCE_FAILURE,
+     UNDEFINED
+ # disable=C,R,W
+ disable=bad-inline-option,
+     bare-except,
+     broad-exception-caught,
+     chained-comparison,
+     consider-iterating-dictionary,
+     consider-using-dict-items,
+     consider-using-generator,
+     consider-using-enumerate,
+     consider-using-sys-exit,
+     consider-using-from-import,
+     consider-using-get,
+     consider-using-in,
+     consider-using-min-builtin,
+     dangerous-default-value,
+     deprecated-pragma,
+     duplicate-code,
+     file-ignored,
+     import-error,
+     import-outside-toplevel,
+     invalid-name,
+     line-too-long,
+     locally-disabled,
+     logging-fstring-interpolation,
+     missing-class-docstring,
+     missing-function-docstring,
+     missing-module-docstring,
+     no-else-return,
+     not-callable,
+     pointless-string-statement,
+     raw-checker-failed,
+     simplifiable-if-expression,
+     suppressed-message,
+     too-many-nested-blocks,
+     too-few-public-methods,
+     too-many-statements,
+     too-many-locals,
+     too-many-instance-attributes,
+     unnecessary-dunder-call,
+     unnecessary-lambda,
+     use-dict-literal,
+     use-symbolic-message-instead,
+     useless-suppression,
+     unidiomatic-typecheck,
+     wrong-import-position
+ enable=c-extension-no-member
+
+ [METHOD_ARGS]
+ timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
+
+ [MISCELLANEOUS]
+ notes=FIXME,
+     XXX,
+     TODO
+ notes-rgx=
+
+ [REFACTORING]
+ max-nested-blocks=5
+ never-returning-functions=sys.exit,argparse.parse_error
+
+ [REPORTS]
+ evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
+ msg-template=
+ #output-format=
+ reports=no
+ score=no
+
+ [SIMILARITIES]
+ ignore-comments=yes
+ ignore-docstrings=yes
+ ignore-imports=yes
+ ignore-signatures=yes
+ min-similarity-lines=4
+
+ [SPELLING]
+ max-spelling-suggestions=4
+ spelling-dict=
+ spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
+ spelling-ignore-words=
+ spelling-private-dict-file=
+ spelling-store-unknown-words=no
+
+ [STRING]
+ check-quote-consistency=no
+ check-str-concat-over-line-jumps=no
+
+ [TYPECHECK]
+ contextmanager-decorators=contextlib.contextmanager
+ generated-members=numpy.*,logging.*,torch.*,cv2.*
+ ignore-none=yes
+ ignore-on-opaque-inference=yes
+ ignored-checks-for-mixins=no-member,
+     not-async-context-manager,
+     not-context-manager,
+     attribute-defined-outside-init
+ ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
+ missing-member-hint=yes
+ missing-member-hint-distance=1
+ missing-member-max-choices=1
+ mixin-class-rgx=.*[Mm]ixin
+ signature-mutators=
+
+ [VARIABLES]
+ additional-builtins=
+ allow-global-unused-variables=yes
+ allowed-redefined-builtins=
+ callbacks=cb_,
+ dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+ ignored-argument-names=_.*|^ignored_|^unused_
+ init-import=no
+ redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
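One setting above worth calling out: `generated-members=numpy.*,logging.*,torch.*,cv2.*` under `[TYPECHECK]` tells pylint to accept any attribute access on those packages, since many of their members are created at runtime and static inference would otherwise raise `no-member` (E1101) false positives. A minimal illustration of the kind of code this protects:

```python
import torch

# members such as torch.float16 are registered dynamically at import time;
# without generated-members=torch.*, pylint may flag this line as no-member
x = torch.zeros(2, 2, dtype=torch.float16)
print(x.shape)
```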
CHANGELOG.md ADDED
The diff for this file is too large to render. See raw diff
 
CITATION.cff ADDED
@@ -0,0 +1,28 @@
+ cff-version: 1.2.0
+ title: SD.Next
+ url: 'https://github.com/vladmandic/automatic'
+ message: >-
+   If you use this software, please cite it using the
+   metadata from this file
+ type: software
+ authors:
+   - given-names: Vladimir
+     name-particle: Vlado
+     family-names: Mandic
+     orcid: 'https://orcid.org/0009-0003-4592-5074'
+ identifiers:
+   - type: url
+     value: 'https://github.com/vladmandic'
+     description: GitHub
+   - type: url
+     value: 'https://www.linkedin.com/in/cyan051/'
+     description: LinkedIn
+ repository-code: 'https://github.com/vladmandic/automatic'
+ abstract: >-
+   SD.Next: Advanced Implementation of Stable Diffusion and
+   other diffusion models for text, image and video
+   generation
+ keywords:
+   - stablediffusion diffusers sdnext
+ license: AGPL-3.0
+ date-released: 2022-12-24
README.md CHANGED
@@ -1,12 +1,281 @@
  ---
- title: Test
- emoji: 👍
- colorFrom: green
- colorTo: purple
+ title: test
+ app_file: webui.py
  sdk: gradio
- sdk_version: 4.31.1
- app_file: app.py
- pinned: false
+ sdk_version: 4.29.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+
+ # SD.Next
+
+ **Stable Diffusion implementation with advanced features**
+
+ [![Sponsors](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/vladmandic)
+ ![Last Commit](https://img.shields.io/github/last-commit/vladmandic/automatic?svg=true)
+ ![License](https://img.shields.io/github/license/vladmandic/automatic?svg=true)
+ [![Discord](https://img.shields.io/discord/1101998836328697867?logo=Discord&svg=true)](https://discord.gg/VjvR2tabEX)
+
+ [Wiki](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.gg/VjvR2tabEX) | [Changelog](CHANGELOG.md)
+
+ </div>
+ <br>
+
+ ## Notable features
+
+ Not all individual features are listed here; check the [ChangeLog](CHANGELOG.md) for the full list of changes
+ - Multiple backends!
+ ▹ **Diffusers | Original**
+ - Multiple diffusion models!
+ ▹ **Stable Diffusion 1.5/2.1 | SD-XL | LCM | Segmind | Kandinsky | PixArt-α | Stable Cascade | Würstchen | aMUSEd | DeepFloyd IF | UniDiffusion | SD-Distilled | BLiP Diffusion | KOALA | etc.**
+ - Built-in Control for text, image, batch and video processing!
+ ▹ **ControlNet | ControlNet XS | Control LLLite | T2I Adapters | IP Adapters**
+ - Multiplatform!
+ ▹ **Windows | Linux | MacOS with CPU | nVidia | AMD | IntelArc | DirectML | OpenVINO | ONNX+Olive | ZLUDA**
+ - Platform-specific autodetection and tuning performed on install
+ - Optimized processing with the latest `torch` developments, with built-in support for `torch.compile`
+ and multiple compile backends: *Triton, ZLUDA, StableFast, DeepCache, OpenVINO, NNCF, IPEX*
+ - Improved prompt parser
+ - Enhanced *Lora*/*LoCon*/*Lyco* code supporting the latest trends in training
+ - Built-in queue management
+ - Enterprise-level logging and hardened API
+ - Built-in installer with automatic updates and dependency management
+ - Modernized UI with theme support and a number of built-in themes *(dark and light)*
+
+ <br>
+
+ *Main text2image interface*:
+ ![Screenshot-Dark](html/screenshot-text2image.jpg)
+
+ For screenshots and information on other available themes, see the [Themes Wiki](https://github.com/vladmandic/automatic/wiki/Themes)
+
+ <br>
+
+ ## Backend support
+
+ **SD.Next** supports two main backends: *Diffusers* and *Original*:
+
+ - **Diffusers**: Based on the new [Huggingface Diffusers](https://huggingface.co/docs/diffusers/index) implementation
+ Supports *all* models listed below
+ This backend is set as the default for new installations
+ See the [wiki article](https://github.com/vladmandic/automatic/wiki/Diffusers) for more information
+ - **Original**: Based on the [LDM](https://github.com/Stability-AI/stablediffusion) reference implementation, significantly expanded on by [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+ This backend is fully compatible with most existing functionality and extensions written for *A1111 SDWebUI*
+ Supports **SD 1.x** and **SD 2.x** models
+ All other model types such as *SD-XL, LCM, PixArt, Segmind, Kandinsky, etc.* require the **Diffusers** backend
+
+ ## Model support
+
+ Additional models will be added as they become available and there is public interest in them
+
+ - [RunwayML Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)*
+ - [StabilityAI Stable Diffusion XL](https://github.com/Stability-AI/generative-models)
+ - [StabilityAI Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) Base, XT 1.0, XT 1.1
+ - [LCM: Latent Consistency Models](https://github.com/openai/consistency_models)
+ - [Playground](https://huggingface.co/playgroundai/playground-v2-256px-base) *v1, v2 256, v2 512, v2 1024 and latest v2.5*
+ - [Stable Cascade](https://github.com/Stability-AI/StableCascade) *Full* and *Lite*
+ - [aMUSEd](https://huggingface.co/amused/amused-256) 256 and 512
+ - [Segmind Vega](https://huggingface.co/segmind/Segmind-Vega)
+ - [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
+ - [Segmind SegMoE](https://github.com/segmind/segmoe) *SD and SD-XL*
+ - [Kandinsky](https://github.com/ai-forever/Kandinsky-2) *2.1, 2.2 and latest 3.0*
+ - [PixArt-α XL 2](https://github.com/PixArt-alpha/PixArt-alpha) *Medium and Large*
+ - [Warp Wuerstchen](https://huggingface.co/blog/wuertschen)
+ - [Tsinghua UniDiffusion](https://github.com/thu-ml/unidiffuser)
+ - [DeepFloyd IF](https://github.com/deep-floyd/IF) *Medium and Large*
+ - [ModelScope T2V](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b)
+ - [Segmind SD Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)*
+ - [BLIP-Diffusion](https://dxli94.github.io/BLIP-Diffusion-website/)
+ - [KOALA 700M](https://github.com/youngwanLEE/sdxl-koala)
+ - [VGen](https://huggingface.co/ali-vilab/i2vgen-xl)
+
+ Also supported are modifiers such as:
+ - **LCM** and **Turbo** (*adversarial diffusion distillation*) networks
+ - All **LoRA** types such as LoCon, LyCORIS, HADA, IA3, Lokr, OFT
+ - **IP-Adapters** for SD 1.5 and SD-XL
+ - **InstantID**, **FaceSwap**, **FaceID**, **PhotoMerge**
+ - **AnimateDiff** for SD 1.5
+
+ ## Examples
+
+ *IP Adapters*:
+ ![Screenshot-IPAdapter](html/screenshot-ipadapter.jpg)
+
+ *Color grading*:
+ ![Screenshot-Color](html/screenshot-color.jpg)
+
+ *InstantID*:
+ ![Screenshot-InstantID](html/screenshot-instantid.jpg)
+
+ > [!IMPORTANT]
+ > - Loading any model other than standard SD 1.x / SD 2.x requires use of backend **Diffusers**
+ > - Loading any other models using the **Original** backend is not supported
+ > - Loading manually downloaded model `.safetensors` files is supported for specified models only (typically SD 1.x / SD 2.x / SD-XL models only)
+ > - For all other model types, use backend **Diffusers** and the built-in model downloader, or
+ select the model from the Networks -> Models -> Reference list, in which case it will be auto-downloaded and loaded
+
+ ## Platform support
+
+ - *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
+ - *AMD* GPUs using **ROCm** libraries on *Linux*
+ Support will be extended to *Windows* once AMD releases ROCm for Windows
+ - *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
+ - Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries
+ This includes support for AMD GPUs that are not supported by native ROCm libraries
+ - Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
+ - *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
+ - *ONNX/Olive*
+
+ ## Install
+
+ - [Step-by-step install guide](https://github.com/vladmandic/automatic/wiki/Installation)
+ - [Advanced install notes](https://github.com/vladmandic/automatic/wiki/Advanced-Install)
+ - [Common installation errors](https://github.com/vladmandic/automatic/discussions/1627)
+ - [FAQ](https://github.com/vladmandic/automatic/discussions/1011)
+ - If you can't run us locally, try our friends at [RunDiffusion!](https://rundiffusion.com?utm_source=github&utm_medium=referral&utm_campaign=SDNext)
+
+ > [!TIP]
+ > - Server can run with or without a virtual environment;
+ using a `VENV` is recommended to avoid library version conflicts with other applications
+ > - **nVidia/CUDA** / **AMD/ROCm** / **Intel/OneAPI** are auto-detected if present and available;
+ for any other use case, such as **DirectML**, **ONNX/Olive** or **OpenVINO**, specify the required parameter explicitly,
+ or the wrong packages may be installed, as the installer will assume a CPU-only environment
+ > - The full startup sequence is logged in `sdnext.log`,
+ so if you encounter any issues, please check it first
+
+ ### Run
+
+ Once SD.Next is installed, simply run `webui.ps1` or `webui.bat` (*Windows*) or `webui.sh` (*Linux or MacOS*)
+
+ List of available parameters; run `webui --help` for the full & up-to-date list:
+
+ Server options:
+ --config CONFIG Use specific server configuration file, default: config.json
+ --ui-config UI_CONFIG Use specific UI configuration file, default: ui-config.json
+ --medvram Split model stages and keep only active part in VRAM, default: False
+ --lowvram Split model components and keep only active part in VRAM, default: False
+ --ckpt CKPT Path to model checkpoint to load immediately, default: None
+ --vae VAE Path to VAE checkpoint to load immediately, default: None
+ --data-dir DATA_DIR Base path where all user data is stored, default:
+ --models-dir MODELS_DIR Base path where all models are stored, default: models
+ --allow-code Allow custom script execution, default: False
+ --share Enable UI accessible through Gradio site, default: False
+ --insecure Enable extensions tab regardless of other options, default: False
+ --use-cpu USE_CPU [USE_CPU ...] Force use CPU for specified modules, default: []
+ --listen Launch web server using public IP address, default: False
+ --port PORT Launch web server with given server port, default: 7860
+ --freeze Disable editing settings
+ --auth AUTH Set access authentication like "user:pwd,user:pwd"
+ --auth-file AUTH_FILE Set access authentication using file, default: None
+ --autolaunch Open the UI URL in the system's default browser upon launch
+ --docs Mount API docs, default: False
+ --api-only Run in API only mode without starting UI
+ --api-log Enable logging of all API requests, default: False
+ --device-id DEVICE_ID Select the default CUDA device to use, default: None
+ --cors-origins CORS_ORIGINS Allowed CORS origins as comma-separated list, default: None
+ --cors-regex CORS_REGEX Allowed CORS origins as regular expression, default: None
+ --tls-keyfile TLS_KEYFILE Enable TLS and specify key file, default: None
+ --tls-certfile TLS_CERTFILE Enable TLS and specify cert file, default: None
+ --tls-selfsign Enable TLS with self-signed certificates, default: False
+ --server-name SERVER_NAME Sets hostname of server, default: None
+ --no-hashing Disable hashing of checkpoints, default: False
+ --no-metadata Disable reading of metadata from models, default: False
+ --disable-queue Disable queues, default: False
+ --subpath SUBPATH Customize the URL subpath for usage with reverse proxy
+ --backend {original,diffusers} Force model pipeline type
+ --allowed-paths ALLOWED_PATHS [ALLOWED_PATHS ...] Add additional paths to paths allowed for web access
+
+ Setup options:
+ --reset Reset main repository to latest version, default: False
+ --upgrade Upgrade main repository to latest version, default: False
+ --requirements Force re-check of requirements, default: False
+ --quick Bypass version checks, default: False
+ --use-directml Use DirectML if no compatible GPU is detected, default: False
+ --use-openvino Use Intel OpenVINO backend, default: False
+ --use-ipex Force use Intel OneAPI XPU backend, default: False
+ --use-cuda Force use nVidia CUDA backend, default: False
+ --use-rocm Force use AMD ROCm backend, default: False
+ --use-zluda Force use ZLUDA, AMD GPUs only, default: False
+ --use-xformers Force use xFormers cross-optimization, default: False
+ --skip-requirements Skips checking and installing requirements, default: False
+ --skip-extensions Skips running individual extension installers, default: False
+ --skip-git Skips running all GIT operations, default: False
+ --skip-torch Skips running Torch checks, default: False
+ --skip-all Skips running all checks, default: False
+ --skip-env Skips setting of env variables during startup, default: False
+ --experimental Allow unsupported versions of libraries, default: False
+ --reinstall Force reinstallation of all requirements, default: False
+ --test Run test only and exit
+ --version Print version information
+ --ignore Ignore any errors and attempt to continue
+ --safe Run in safe mode with no user extensions
+
+ Logging options:
+ --log LOG Set log file, default: None
+ --debug Run installer with debug logging, default: False
+ --profile Run profiler, default: False
+
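Once the server is up (optionally with `--api-only` and `--docs`), the options and models listed above are also reachable over the HTTP API under `/sdapi/v1/*`, the same endpoints the bundled `cli` scripts call. A minimal sketch, assuming a local instance on the default port 7860:

```python
import requests

base = 'http://127.0.0.1:7860'  # default --port
# read current server options; /sdapi/v1/options is the endpoint cli/create-previews.py also uses
opts = requests.get(f'{base}/sdapi/v1/options', timeout=30).json()
print(opts.get('sd_model_checkpoint'))
```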
+ ## Notes
+
+ ### Control
+
+ **SD.Next** comes with built-in control for all types of text2image, image2image, video2video and batch processing
+
+ *Control interface*:
+ ![Screenshot-Control](html/screenshot-control.jpg)
+
+ *Control processors*:
+ ![Screenshot-Process](html/screenshot-processors.jpg)
+
+ *Masking*:
+ ![Screenshot-Mask](html/screenshot-mask.jpg)
+
+ ### **Extensions**
+
+ SD.Next comes with several extensions pre-installed:
+
+ - [ControlNet](https://github.com/Mikubill/sd-webui-controlnet) (*active in backend: original only*)
+ - [Agent Scheduler](https://github.com/ArtVentureX/sd-webui-agent-scheduler)
+ - [Image Browser](https://github.com/AlUlkesh/stable-diffusion-webui-images-browser)
+
+ ### **Collab**
+
+ - We'd love to have additional maintainers (which comes with full repo rights). If you're interested, ping us!
+ - In addition to general cross-platform code, the desire is to have a lead for each of the main platforms
+ This should be fully cross-platform, but we'd really love additional contributors and/or maintainers to join and help lead the efforts on different platforms
+
+ ### **Credits**
+
+ - Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for the original codebase
+ - Additional credits are listed in [Credits](https://github.com/AUTOMATIC1111/stable-diffusion-webui/#credits)
+ - Licenses for modules are listed in [Licenses](html/licenses.html)
+
+ ### **Evolution**
+
+ <a href="https://star-history.com/#vladmandic/automatic&Date">
+ <picture width=640>
+ <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=vladmandic/automatic&type=Date&theme=dark" />
+ <img src="https://api.star-history.com/svg?repos=vladmandic/automatic&type=Date" alt="stars" width="320">
+ </picture>
+ </a>
+
+ - [OSS Stats](https://ossinsight.io/analyze/vladmandic/automatic#overview)
+
+ ### **Docs**
+
+ If you're unsure how to use a feature, the best place to start is the [Wiki](https://github.com/vladmandic/automatic/wiki) and, if it's not there,
+ check the [ChangeLog](CHANGELOG.md) for when the feature was first introduced, as it will always have a short note on how to use it
+
+ - [Wiki](https://github.com/vladmandic/automatic/wiki)
+ - [ReadMe](README.md)
+ - [ToDo](TODO.md)
+ - [ChangeLog](CHANGELOG.md)
+ - [CLI Tools](cli/README.md)
+
+ ### **Sponsors**
+
+ <div align="center">
+ <!-- sponsors --><a href="https://github.com/allangrant"><img src="https://github.com/allangrant.png" width="60px" alt="Allan Grant" /></a><a href="https://github.com/BrentOzar"><img src="https://github.com/BrentOzar.png" width="60px" alt="Brent Ozar" /></a><a href="https://github.com/inktomi"><img src="https://github.com/inktomi.png" width="60px" alt="Matthew Runo" /></a><a href="https://github.com/4joeknight4"><img src="https://github.com/4joeknight4.png" width="60px" alt="" /></a><a href="https://github.com/SaladTechnologies"><img src="https://github.com/SaladTechnologies.png" width="60px" alt="Salad Technologies" /></a><a href="https://github.com/mantzaris"><img src="https://github.com/mantzaris.png" width="60px" alt="a.v.mantzaris" /></a><a href="https://github.com/CurseWave"><img src="https://github.com/CurseWave.png" width="60px" alt="" /></a><!-- sponsors -->
+ </div>
+
+ <br>
SECURITY.md ADDED
@@ -0,0 +1,36 @@
+ # Security & Privacy Policy
+
+ <br>
+
+ ## Issues
+
+ All issues are tracked publicly on GitHub: <https://github.com/vladmandic/automatic/issues>
+
+ <br>
+
+ ## Vulnerabilities
+
+ The `SD.Next` code base and included dependencies are automatically scanned against known security vulnerabilities
+
+ Any code commit is validated before merge
+
+ - [Dependencies](https://github.com/vladmandic/automatic/security/dependabot)
+ - [Scanning Alerts](https://github.com/vladmandic/automatic/security/code-scanning)
+
+ <br>
+
+ ## Privacy
+
+ `SD.Next` app:
+
+ - Is fully self-contained and does not send or share data of any kind with external targets
+ - Does not store any user or system tracking data, user-provided inputs (images, video) or detection results
+ - Does not utilize any analytic services (such as Google Analytics)
+
+ `SD.Next` library can establish external connections *only* for the following purposes and *only* when explicitly configured by the user:
+
+ - Download extension and theme indexes, which are automatically updated
+ - Download required packages and repositories from GitHub during installation/upgrade
+ - Download installed/enabled extensions
+ - Download models from CivitAI and/or Huggingface when instructed by the user
+ - Submit benchmark info upon user interaction
TODO.md ADDED
@@ -0,0 +1,20 @@
+ # TODO
+
+ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladmandic/projects)
+
+ ## Candidates for next release
+
+ - defork
+ - stable diffusion 3.0
+ - ipadapter masking: <https://github.com/huggingface/diffusers/pull/6847>
+ - x-adapter: <https://github.com/showlab/X-Adapter>
+ - async lowvram: <https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855>
+ - init latents: variations, img2img
+ - diffusers public callbacks
+ - remove builtin: controlnet
+ - remove builtin: image-browser
+
+ ## Control missing features
+
+ - second pass: <https://github.com/vladmandic/automatic/issues/2783>
+ - control api
cli/README.md ADDED
@@ -0,0 +1,108 @@
+ # Stable-Diffusion Productivity Scripts
+
+ Note: All scripts have a built-in `--help` parameter that can be used to get more information
+
+ <br>
+
+ ## Main Scripts
+
+ ### Generate
+
+ Text-to-image with all of the possible parameters
+ Supports upsampling, face restoration and grid creation
+ > python generate.py
+
+ By default uses parameters from `generate.json`
+
+ Parameters that are not specified will be randomized:
+
+ - Prompt will be dynamically created from a template of random samples: `random.json`
+ - Sampler/Scheduler will be randomly picked from available ones
+ - CFG Scale set to 5-10
+
+ ### Train
+
+ Combined pipeline for **embeddings**, **lora**, **lycoris**, **dreambooth** and **hypernetwork**
+ Optionally runs several image processing steps before training:
+
+ - keep original image
+ - detect and extract face
+ - detect and extract body
+ - detect blur
+ - detect dynamic range
+ - attempt to upscale low resolution images
+ - attempt to restore quality of low quality images
+ - automatically generate captions using interrogate
+ - resize image
+ - square image
+ - run image segmentation to remove background
+
+ > python train.py
+
+ <br>
+
+ ## Auxiliary Scripts
+
+ ### Benchmark
+
+ > python run-benchmark.py
+
+ ### Create Previews
+
+ Create previews for **embeddings**, **lora**, **lycoris**, **dreambooth** and **hypernetwork**
+
+ > python create-previews.py
+
+ ### Image Grid
+
+ > python image-grid.py
+
+ ### Image Watermark
+
+ Create invisible image watermark and remove existing EXIF tags
+
+ > python image-watermark.py
+
+ ### Image Interrogate
+
+ Runs CLIP and Booru image interrogation
+
+ > python image-interrogate.py
+
+ ### Palette Extract
+
+ Extract color palette from image(s)
+
+ > python image-palette.py
+
+ ### Prompt Ideas
+
+ Generate complex prompt ideas
+
+ > python prompt-ideas.py
+
+ ### Prompt Promptist
+
+ Attempts to beautify the provided prompt
+
+ > python prompt-promptist.py
+
+ ### Video Extract
+
+ Extract frames from video files
+
+ > python video-extract.py
+
+ <br>
+
+ ## Utility Scripts
+
+ ### SDAPI
+
+ Utility module that handles async communication to Automatic API endpoints
+ Note: Requires SD API
+
+ Can be used to manually execute specific commands:
+ > python sdapi.py progress
+ > python sdapi.py interrupt
+ > python sdapi.py shutdown
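The module is also meant to be imported from other scripts; `create-previews.py` below uses its async `get`/`post`/`close` helpers directly. A minimal sketch of that programmatic use:

```python
import asyncio
from sdapi import get, post, close  # async helpers, as used by create-previews.py

async def main():
    opt = await get('/sdapi/v1/options')  # fetch current server options
    print(opt['sd_model_checkpoint'])
    await post('/sdapi/v1/options', opt)  # write (possibly modified) options back
    await close()                         # release the underlying HTTP session

asyncio.run(main())
```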
cli/clone.py ADDED
@@ -0,0 +1,78 @@
+ #!/usr/bin/env python
+ import os
+ import logging
+ import git
+ from rich import console, progress
+
+
+ class GitRemoteProgress(git.RemoteProgress):
+     OP_CODES = ["BEGIN", "CHECKING_OUT", "COMPRESSING", "COUNTING", "END", "FINDING_SOURCES", "RECEIVING", "RESOLVING", "WRITING"]
+     OP_CODE_MAP = { getattr(git.RemoteProgress, _op_code): _op_code for _op_code in OP_CODES }
+
+     def __init__(self, url, folder) -> None:
+         super().__init__()
+         self.url = url
+         self.folder = folder
+         self.progressbar = progress.Progress(
+             progress.SpinnerColumn(),
+             progress.TextColumn("[cyan][progress.description]{task.description}"),
+             progress.BarColumn(),
+             progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+             progress.TimeRemainingColumn(),
+             progress.TextColumn("[yellow]<{task.fields[url]}>"),
+             progress.TextColumn("{task.fields[message]}"),
+             console=console.Console(),
+             transient=False,
+         )
+         self.progressbar.start()
+         self.active_task = None
+
+     def __del__(self) -> None:
+         self.progressbar.stop()
+
+     @classmethod
+     def get_curr_op(cls, op_code: int) -> str:
+         op_code_masked = op_code & cls.OP_MASK
+         return cls.OP_CODE_MAP.get(op_code_masked, "?").title()
+
+     def update(self, op_code: int, cur_count: str | float, max_count: str | float | None = None, message: str | None = "") -> None:
+         if op_code & self.BEGIN:
+             self.curr_op = self.get_curr_op(op_code)  # pylint: disable=attribute-defined-outside-init
+             self.active_task = self.progressbar.add_task(description=self.curr_op, total=max_count, message=message, url=self.url)
+         self.progressbar.update(task_id=self.active_task, completed=cur_count, message=message)
+         if op_code & self.END:
+             self.progressbar.update(task_id=self.active_task, message=f"[bright_black]{message}")
+
+
+ def clone(url: str, folder: str):
+     git.Repo.clone_from(
+         url=url,
+         to_path=folder,
+         progress=GitRemoteProgress(url=url, folder=folder),
+         multi_options=['--config core.compression=0', '--config core.loosecompression=0', '--config pack.window=0'],
+         allow_unsafe_options=True,
+         depth=1,
+     )
+
+
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser(description = 'downloader')
+     parser.add_argument('--url', required=True, help="download url, required")
+     parser.add_argument('--folder', required=False, help="output folder, default: autodetect")
+     args = parser.parse_args()
+     logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
+     log = logging.getLogger(__name__)
+     try:
+         if not args.url.startswith('http'):
+             raise ValueError(f'invalid url: {args.url}')
+         f = args.url.split('/')[-1].split('.')[0] if args.folder is None else args.folder
+         if os.path.exists(f):
+             raise FileExistsError(f'folder already exists: {f}')
+         log.info(f'Clone start: url={args.url} folder={f}')
+         clone(url=args.url, folder=f)
+         log.info(f'Clone complete: url={args.url} folder={f}')
+     except KeyboardInterrupt:
+         log.warning(f'Clone cancelled: url={args.url} folder={f}')
+     except Exception as e:
+         log.error(f'Clone: url={args.url} {e}')
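Besides the CLI entry point, `clone()` can be called directly from Python; a minimal sketch (the target folder must not already exist, and `depth=1` plus the disabled-compression options trade repository history for a faster initial transfer):

```python
from clone import clone  # assumes cli/clone.py is on the import path

# shallow-clone with rich progress reporting
clone(url='https://github.com/vladmandic/automatic', folder='automatic')
```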
cli/create-previews.py ADDED
@@ -0,0 +1,346 @@
+ #!/usr/bin/env python
+ # pylint: disable=no-member
+ import os
+ import re
+ import json
+ import time
+ import logging
+ import importlib
+ import asyncio
+ import argparse
+ from pathlib import Path
+ from util import Map, log
+ from sdapi import get, post, close
+ from generate import generate  # pylint: disable=import-error
+ grid = importlib.import_module('image-grid').grid
+
+
+ options = Map({
+     # used by extra networks
+     'prompt': 'photo of <keyword> <embedding>, photograph, posing, pose, high detailed, intricate, elegant, sharp focus, skin texture, looking forward, facing camera, 135mm, shot on dslr, canon 5d, 4k, modelshoot style, cinematic lighting',
+     # used by models
+     'prompts': [
+         ('photo cityscape', 'cityscape during night, photorealistic, high detailed, sharp focus, depth of field, 4k'),
+         ('photo car', 'photo of a sports car, high detailed, sharp focus, dslr, cinematic lighting, realistic'),
+         ('photo woman', 'portrait photo of beautiful woman, high detailed, dslr, 35mm'),
+         ('photo naked', 'full body photo of beautiful sexy naked woman, high detailed, dslr, 35mm'),
+         ('photo taylor', 'portrait photo of beautiful woman taylor swift, high detailed, sharp focus, depth of field, dslr, 35mm <lora:taylor-swift:1>'),
+         ('photo ti-mia', 'portrait photo of beautiful woman "ti-mia", naked, high detailed, dslr, 35mm'),
+         ('photo ti-vlado', 'portrait photo of man "ti-vlado", high detailed, dslr, 35mm'),
+         ('photo lora-vlado', 'portrait photo of man vlado, high detailed, dslr, 35mm <lora:vlado-original:1>'),
+         ('wlop', 'a stunning portrait of sexy teen girl in a wet t-shirt, vivid color palette, digital painting, octane render, highly detailed, particles, light effect, volumetric lighting, art by wlop'),
+         ('greg rutkowski', 'beautiful woman, high detailed, sharp focus, depth of field, 4k, art by greg rutkowski'),
+         ('carne griffiths', 'beautiful woman taylor swift, high detailed, sharp focus, depth of field, art by carne griffiths <lora:taylor-swift:1>'),
+         ('carne griffiths', 'man vlado, high detailed, sharp focus, depth of field, art by carne griffiths <lora:vlado-full:1>'),
+     ],
+     # save format
+     'format': '.jpg',
+     # used by generate script
+     'paths': {
+         "root": "/mnt/c/Users/mandi/OneDrive/Generative/Generate",
+         "generate": "image",
+         "upscale": "upscale",
+         "grid": "grid",
+     },
+     # generate params
+     'generate': {
+         'restore_faces': True,
+         'prompt': '',
+         'negative_prompt': 'foggy, blurry, blurred, duplicate, ugly, mutilated, mutation, mutated, out of frame, bad anatomy, disfigured, deformed, censored, low res, low resolution, watermark, text, poorly drawn face, poorly drawn hands, signature',
+         'steps': 20,
+         'batch_size': 2,
+         'n_iter': 1,
+         'seed': -1,
+         'sampler_name': 'UniPC',
+         'cfg_scale': 6,
+         'width': 512,
+         'height': 512,
+     },
+     'lora': {
+         'strength': 1.0,
+     },
+     'hypernetwork': {
+         'keyword': '',
+         'strength': 1.0,
+     },
+ })
+
+
+ def preview_exists(folder, model):
+     model = os.path.splitext(model)[0]
+     for suffix in ['', '.preview']:
+         for ext in ['.jpg', '.png', '.webp']:
+             fn = os.path.join(folder, f'{model}{suffix}{ext}')
+             if os.path.exists(fn):
+                 return True
+     return False
+
+
+ async def preview_models(params):
+     data = await get('/sdapi/v1/sd-models')
+     allmodels = [m['title'] for m in data]
+     models = []
+     excluded = []
+     for m in allmodels:  # loop through all registered models
+         ok = True
+         for e in params.exclude:  # check if model is excluded
+             if e in m:
+                 excluded.append(m)
+                 ok = False
+                 break
+         if ok:
+             short = m.split(' [')[0]
+             short = short.replace('.ckpt', '').replace('.safetensors', '')
+             models.append(short)
+     if len(params.input) > 0:  # check if model is included in cmd line
+         filtered = []
+         for m in params.input:
+             if m in models:
+                 filtered.append(m)
+             else:
+                 log.error({ 'model not found': m })
+                 return
+         models = filtered
+     log.info({ 'models preview' })
+     log.info({ 'models': len(models), 'excluded': len(excluded) })
+     opt = await get('/sdapi/v1/options')
+     log.info({ 'total jobs': len(models) * options.generate.batch_size, 'per-model': options.generate.batch_size })
+     log.info(json.dumps(options, indent=2))
+     for model in models:
+         if preview_exists(opt['ckpt_dir'], model) and len(params.input) == 0:  # if model preview exists and not manually included
+             log.info({ 'model preview exists': model })
+             continue
+         fn = os.path.join(opt['ckpt_dir'], os.path.splitext(model)[0] + options.format)
+         log.info({ 'model load': model })
+         opt['sd_model_checkpoint'] = model
+         del opt['sd_lora']
+         del opt['sd_lyco']
+         await post('/sdapi/v1/options', opt)
+         opt = await get('/sdapi/v1/options')
+         images = []
+         labels = []
+         t0 = time.time()
+         for label, p in options.prompts:
+             options.generate.prompt = p
+             log.info({ 'model generating': model, 'label': label, 'prompt': options.generate.prompt })
+             data = await generate(options = options, quiet=True)
+             if 'image' in data:
+                 for img in data['image']:
+                     images.append(img)
+                     labels.append(label)
+             else:
+                 log.error({ 'model': model, 'error': data })
+         t1 = time.time()
+         if len(images) == 0:
+             log.error({ 'model': model, 'error': 'no images generated' })
+             continue
+         image = grid(images = images, labels = labels, border = 8)
+         log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
+         image.save(fn)
+         t = t1 - t0
+         its = 1.0 * options.generate.steps * len(images) / t
+         log.info({ 'model preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
+
+     opt = await get('/sdapi/v1/options')
+     if opt['sd_model_checkpoint'] != params.model:
+         log.info({ 'model set default': params.model })
+         opt['sd_model_checkpoint'] = params.model
+         del opt['sd_lora']
+         del opt['sd_lyco']
+         await post('/sdapi/v1/options', opt)
+
+
+ async def lora(params):
+     opt = await get('/sdapi/v1/options')
+     folder = opt['lora_dir']
+     if not os.path.exists(folder):
+         log.error({ 'lora directory not found': folder })
+         return
+     models1 = list(Path(folder).glob('**/*.safetensors'))
+     models2 = list(Path(folder).glob('**/*.ckpt'))
+     models = [os.path.splitext(f)[0] for f in models1 + models2]
+     log.info({ 'loras': len(models) })
+     for model in models:
+         if preview_exists('', model) and len(params.input) == 0:  # if model preview exists and not manually included
+             log.info({ 'lora preview exists': model })
+             continue
+         fn = model + options.format
+         model = os.path.basename(model)
+         images = []
+         labels = []
+         t0 = time.time()
+         keywords = re.sub(r'\d', '', model)
+         keywords = keywords.replace('-v', ' ').replace('-', ' ').strip().split(' ')
+         keyword = '\"' + '\" \"'.join(keywords) + '\"'
+         options.generate.prompt = options.prompt.replace('<keyword>', keyword)
+         options.generate.prompt = options.generate.prompt.replace('<embedding>', '')
+         options.generate.prompt += f' <lora:{model}:{options.lora.strength}>'
+         log.info({ 'lora generating': model, 'keyword': keyword, 'prompt': options.generate.prompt })
+         data = await generate(options = options, quiet=True)
+         if 'image' in data:
+             for img in data['image']:
+                 images.append(img)
+                 labels.append(keyword)
+         else:
+             log.error({ 'lora': model, 'keyword': keyword, 'error': data })
+         t1 = time.time()
+         if len(images) == 0:
+             log.error({ 'model': model, 'error': 'no images generated' })
+             continue
+         image = grid(images = images, labels = labels, border = 8)
+         log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
+         image.save(fn)
+         t = t1 - t0
+         its = 1.0 * options.generate.steps * len(images) / t
+         log.info({ 'lora preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
+
+
+ async def lyco(params):
+     opt = await get('/sdapi/v1/options')
+     folder = opt['lyco_dir']
+     if not os.path.exists(folder):
+         log.error({ 'lyco directory not found': folder })
+         return
+     models1 = list(Path(folder).glob('**/*.safetensors'))
+     models2 = list(Path(folder).glob('**/*.ckpt'))
+     models = [os.path.splitext(f)[0] for f in models1 + models2]
+     log.info({ 'lycos': len(models) })
+     for model in models:
+         if preview_exists('', model) and len(params.input) == 0:  # if model preview exists and not manually included
+             log.info({ 'lyco preview exists': model })
+             continue
+         fn = model + options.format
+         model = os.path.basename(model)
+         images = []
+         labels = []
+         t0 = time.time()
+         keywords = re.sub(r'\d', '', model)
+         keywords = keywords.replace('-v', ' ').replace('-', ' ').strip().split(' ')
+         keyword = '\"' + '\" \"'.join(keywords) + '\"'
+         options.generate.prompt = options.prompt.replace('<keyword>', keyword)
+         options.generate.prompt = options.generate.prompt.replace('<embedding>', '')
+         options.generate.prompt += f' <lyco:{model}:{options.lora.strength}>'
+         log.info({ 'lyco generating': model, 'keyword': keyword, 'prompt': options.generate.prompt })
+         data = await generate(options = options, quiet=True)
+         if 'image' in data:
+             for img in data['image']:
+                 images.append(img)
+                 labels.append(keyword)
+         else:
+             log.error({ 'lyco': model, 'keyword': keyword, 'error': data })
+         t1 = time.time()
+         if len(images) == 0:
+             log.error({ 'model': model, 'error': 'no images generated' })
+             continue
+         image = grid(images = images, labels = labels, border = 8)
+         log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
+         image.save(fn)
+         t = t1 - t0
+         its = 1.0 * options.generate.steps * len(images) / t
+         log.info({ 'lyco preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
+
+
+ async def hypernetwork(params):
+     opt = await get('/sdapi/v1/options')
+     folder = opt['hypernetwork_dir']
+     if not os.path.exists(folder):
+         log.error({ 'hypernetwork directory not found': folder })
+         return
+     models = [os.path.splitext(f)[0] for f in Path(folder).glob('**/*.pt')]
+     log.info({ 'hypernetworks': len(models) })
+     for model in models:
+         if preview_exists(folder, model) and len(params.input) == 0:  # if model preview exists and not manually included
+             log.info({ 'hypernetwork preview exists': model })
+             continue
+         fn = os.path.join(folder, model + options.format)
+         images = []
+         labels = []
+         t0 = time.time()
+         keyword = options.hypernetwork.keyword
+         options.generate.prompt = options.prompt.replace('<keyword>', options.hypernetwork.keyword)
+         options.generate.prompt = options.generate.prompt.replace('<embedding>', '')
+         options.generate.prompt = f' <hypernet:{model}:{options.hypernetwork.strength}> ' + options.generate.prompt
+         log.info({ 'hypernetwork generating': model, 'keyword': keyword, 'prompt': options.generate.prompt })
+         data = await generate(options = options, quiet=True)
+         if 'image' in data:
+             for img in data['image']:
+                 images.append(img)
+                 labels.append(keyword)
+         else:
+             log.error({ 'hypernetwork': model, 'keyword': keyword, 'error': data })
+         t1 = time.time()
+         if len(images) == 0:
+             log.error({ 'model': model, 'error': 'no images generated' })
+             continue
+         image = grid(images = images, labels = labels, border = 8)
+         log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
+         image.save(fn)
+         t = t1 - t0
+         its = 1.0 * options.generate.steps * len(images) / t
+         log.info({ 'hypernetwork preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
+
+
+ async def embedding(params):
+     opt = await get('/sdapi/v1/options')
+     folder = opt['embeddings_dir']
+     if not os.path.exists(folder):
+         log.error({ 'embeddings directory not found': folder })
+         return
+     models = [os.path.splitext(f)[0] for f in Path(folder).glob('**/*.pt')]
+     log.info({ 'embeddings': len(models) })
+     for model in models:
+         if preview_exists(folder, model) and len(params.input) == 0:  # if model preview exists and not manually included
+             log.info({ 'embedding preview exists': model })
+             continue
+         fn = os.path.join(folder, model + '.preview' + options.format)
+         images = []
+         labels = []
+         t0 = time.time()
+         keyword = '\"' + re.sub(r'\d', '', model) + '\"'
+         options.generate.batch_size = 4
+         options.generate.prompt = options.prompt.replace('<keyword>', keyword)
+         options.generate.prompt = options.generate.prompt.replace('<embedding>', '')
+         log.info({ 'embedding generating': model, 'keyword': keyword, 'prompt': options.generate.prompt })
+         data = await generate(options = options, quiet=True)
+         if 'image' in data:
+             for img in data['image']:
+                 images.append(img)
+                 labels.append(keyword)
+         else:
+             log.error({ 'embedding': model, 'keyword': keyword, 'error': data })
+         t1 = time.time()
+         if len(images) == 0:
+             log.error({ 'model': model, 'error': 'no images generated' })
+             continue
+         image = grid(images = images, labels = labels, border = 8)
+         log.info({ 'saving preview': fn, 'images': len(images), 'size': [image.width, image.height] })
+         image.save(fn)
+         t = t1 - t0
+         its = 1.0 * options.generate.steps * len(images) / t
+         log.info({ 'embedding preview created': model, 'image': fn, 'images': len(images), 'grid': [image.width, image.height], 'time': round(t, 2), 'its': round(its, 2) })
+
+
+ async def create_previews(params):
+     await preview_models(params)
+     await lora(params)
+     await lyco(params)
+     await hypernetwork(params)
+     await embedding(params)
+     await close()
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description = 'generate model previews')
+     parser.add_argument('--model', default='best/icbinp-icantbelieveIts-final.safetensors [73f48afbdc]', help="model used to create extra network previews")
+     parser.add_argument('--exclude', default=['sd-v20', 'sd-v21', 'inpainting', 'pix2pix'], help="exclude models with keywords")
+     parser.add_argument('--debug', default = False, action='store_true', help = 'print extra debug information')
+     parser.add_argument('input', type = str, nargs = '*')
+     args = parser.parse_args()
+     if args.debug:
+         log.setLevel(logging.DEBUG)
+         log.debug({ 'debug': True })
+     log.debug({ 'args': args.__dict__ })
+     asyncio.run(create_previews(args))
cli/download.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import os
3
+ import time
4
+ import argparse
5
+ import tempfile
6
+ import urllib
7
+ import requests
8
+ import urllib3
9
+ import rich.progress as p
10
+ from rich import print # pylint: disable=redefined-builtin
11
+
12
+
13
+ pbar = p.Progress(p.TextColumn('[cyan]{task.description}'), p.DownloadColumn(), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TransferSpeedColumn())
14
+ headers = {
15
+ 'Content-type': 'application/json',
16
+ 'User-Agent': 'Mozilla/5.0',
17
+ }
18
+
19
+
20
+ def get_filename(args, res):
21
+ content_fn = (res.headers.get('content-disposition', '').split('filename=')[1]).strip().strip('\"') if 'filename=' in res.headers.get('content-disposition', '') else None
22
+ return args.file or content_fn or next(tempfile._get_candidate_names()) # pylint: disable=protected-access
23
+
24
+
25
+ def download_requests(args):
26
+ res = requests.get(args.url, timeout=30, headers=headers, verify=False, allow_redirects=True, stream=True)
27
+ content_length = int(res.headers.get('content-length', 0))
28
+ fn = get_filename(args, res)
29
+ print(f'downloading: url={args.url} file={fn} size={content_length if content_length > 0 else "unknown"} lib=requests block={args.block}')
30
+ with open(fn, 'wb') as f:
31
+ with pbar:
32
+ task = pbar.add_task(description="Download starting", total=content_length)
33
+ for data in res.iter_content(args.block):
34
+ f.write(data)
35
+ pbar.update(task, advance=args.block, description="Downloading")
36
+ return fn
37
+
38
+
39
+ def download_urllib(args):
40
+ fn = ''
41
+ req = urllib.request.Request(args.url, headers=headers)
42
+ res = urllib.request.urlopen(req)
43
+ res.getheader('content-length')
44
+ content_length = int(res.getheader('content-length') or 0)
45
+ fn = get_filename(args, res)
46
+ print(f'downloading: url={args.url} file={fn} size={content_length if content_length > 0 else "unknown"} lib=urllib block={args.block}')
47
+ with open(fn, 'wb') as f:
48
+ with pbar:
49
+ task = pbar.add_task(description="Download starting", total=content_length)
50
+ while True:
51
+ buf = res.read(args.block)
52
+ if not buf:
53
+ break
54
+ f.write(buf)
55
+ pbar.update(task, advance=args.block, description="Downloading")
56
+ return fn
57
+
58
+
59
+ def download_urllib3(args):
60
+ http_pool = urllib3.PoolManager()
61
+ res = http_pool.request('GET', args.url, preload_content=False, headers=headers)
62
+ fn = get_filename(args, res)
63
+ content_length = int(res.headers.get('content-length', 0))
64
+ print(f'downloading: url={args.url} file={fn} size={content_length if content_length > 0 else "unknown"} lib=urllib3 block={args.block}')
65
+ with open(fn, 'wb') as f:
66
+ with pbar:
67
+ task = pbar.add_task(description="Download starting", total=content_length)
68
+ while True:
69
+ buf = res.read(args.block)
70
+ if not buf:
71
+ break
72
+ f.write(buf)
73
+ pbar.update(task, advance=len(buf), description="Downloading")
74
+ return fn
75
+
76
+
77
+ def download_httpx(args):
78
+ try:
79
+ import httpx
80
+ except ImportError:
81
+ print('httpx is not installed')
82
+ return None
83
+ with httpx.stream("GET", args.url, headers=headers, verify=False, follow_redirects=True) as res:
84
+ fn = get_filename(args, res)
85
+ content_length = int(res.headers.get('content-length', 0))
86
+ print(f'downloading: url={args.url} file={fn} size={content_length if content_length > 0 else "unknown"} lib=httpx block=internal')
87
+ with open(fn, 'wb') as f:
88
+ with pbar:
89
+ task = pbar.add_task(description="Download starting", total=content_length)
90
+ for buf in res.iter_bytes():
91
+ f.write(buf)
92
+ pbar.update(task, advance=len(buf), description="Downloading") # httpx picks its own chunk size, so count actual bytes
93
+ return fn
94
+
95
+
96
+ if __name__ == "__main__":
97
+ parser = argparse.ArgumentParser(description = 'downloader')
98
+ parser.add_argument('--url', required=True, help="download url, required")
99
+ parser.add_argument('--file', required=False, help="output file, default: autodetect")
100
+ parser.add_argument('--lib', required=False, default='requests', choices=['urllib', 'urllib3', 'requests', 'httpx'], help="download mode, default: %(default)s")
101
+ parser.add_argument('--block', required=False, type=int, default=16384, help="download block size, default: %(default)s")
102
+ parsed = parser.parse_args()
103
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
104
+ try:
105
+ t0 = time.time()
106
+ if parsed.lib == 'requests':
107
+ filename = download_requests(parsed)
108
+ elif parsed.lib == 'urllib':
109
+ filename = download_urllib(parsed)
110
+ elif parsed.lib == 'urllib3':
111
+ filename = download_urllib3(parsed)
112
+ elif parsed.lib == 'httpx':
113
+ filename = download_httpx(parsed)
114
+ else:
115
+ print(f'unknown download library: {parsed.lib}')
116
+ exit(1)
117
+ t1 = time.time()
118
+ if filename is None:
119
+ print(f'download error: args={parsed}')
120
+ exit(1)
121
+ speed = round(os.path.getsize(filename) / (t1 - t0) / 1024 / 1024, 3)
122
+ print(f'download complete: url={parsed.url} file={filename} speed={speed} mb/s')
123
+ except KeyboardInterrupt:
124
+ print(f'download cancelled: args={parsed}')
125
+ except Exception as e:
126
+ print(f'download error: args={parsed} {e}')
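
Example invocation (a sketch; the URL and output name are placeholders, the flags are the argparse options defined above):

    python cli/download.py --url https://example.com/model.safetensors --lib urllib3 --block 65536
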
cli/gen-styles.py ADDED
@@ -0,0 +1,79 @@
1
+ #!/bin/env python
2
+
3
+ import io
4
+ import json
5
+ import base64
6
+ import argparse
7
+ import requests
8
+ from PIL import Image
9
+
10
+
11
+ options = {
12
+ "negative_prompt": "",
13
+ "steps": 20,
14
+ "batch_size": 1,
15
+ "n_iter": 1,
16
+ "seed": -1,
17
+ "sampler_name": "UniPC",
18
+ "cfg_scale": 6,
19
+ "width": 512,
20
+ "height": 512,
21
+ "save_images": False,
22
+ "send_images": True,
23
+ }
24
+ styles = []
25
+
26
+
27
+ def pil_to_b64(img: Image, size: int, quality: int):
28
+ img = img.convert('RGB')
29
+ img = img.resize((size, size))
30
+ buffer = io.BytesIO()
31
+ img.save(buffer, format="JPEG", quality=quality)
32
+ b64encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
33
+ return f'data:image/jpeg;base64,{b64encoded}'
34
+
35
+
36
+ def post(endpoint: str, dct: dict = None):
37
+ req = requests.post(endpoint, json = dct, timeout=300, verify=False)
38
+ if req.status_code != 200:
39
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
40
+ else:
41
+ return req.json()
42
+
43
+
44
+ if __name__ == '__main__':
45
+ parser = argparse.ArgumentParser(description = 'gen-styles.py')
46
+ parser.add_argument('--input', type=str, required=True, help="input text file with one line per prompt")
47
+ parser.add_argument('--output', type=str, required=True, help="output json file")
48
+ parser.add_argument('--nopreviews', default=False, action='store_true', help = 'do not generate preview images')
49
+ parser.add_argument('--prompt', type=str, required=False, default='girl walking in a city', help="applied prompt when generating previews")
50
+ parser.add_argument('--size', type=int, default=128, help="image size for previews")
51
+ parser.add_argument('--quality', type=int, default=35, help="image quality for previews")
52
+ parser.add_argument('--url', type=str, required=False, default='http://127.0.0.1:7860', help="sd.next server url")
53
+ args = parser.parse_args()
54
+ with open(args.input, encoding='utf-8') as f:
55
+ lines = f.readlines()
56
+ for line in lines:
57
+ line = line.strip() # strip already removes trailing newlines
58
+ if len(line) == 0:
59
+ continue
60
+ print(f'processing: {line}')
61
+ if not args.nopreviews:
62
+ options['prompt'] = f'{line} {args.prompt}'
63
+ data = post(f'{args.url}/sdapi/v1/txt2img', options)
64
+ if 'error' in data:
65
+ print(f'error: {data}')
66
+ continue
67
+ b64str = data['images'][0].split(',',1)[-1] # keep payload after optional data-uri prefix
68
+ image = Image.open(io.BytesIO(base64.b64decode(b64str)))
69
+ else:
70
+ image = None
71
+ styles.append({
72
+ 'name': line,
73
+ 'prompt': line + ' {prompt}',
74
+ 'negative': '',
75
+ 'extra': '',
76
+ 'preview': pil_to_b64(image, args.size, args.quality) if image is not None else '',
77
+ })
78
+ with open(args.output, 'w', encoding='utf-8') as outfile:
79
+ json.dump(styles, outfile, indent=2)
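
Example invocation (a sketch; assumes a running SD.Next server at the default URL and a styles.txt with one style keyword per line):

    python cli/gen-styles.py --input styles.txt --output styles.json --size 128
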
cli/generate.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "paths":
3
+ {
4
+ "root": "/mnt/c/Users/mandi/OneDrive/Generative/Generate",
5
+ "generate": "image",
6
+ "upscale": "upscale",
7
+ "grid": "grid"
8
+ },
9
+ "generate":
10
+ {
11
+ "restore_faces": true,
12
+ "prompt": "dynamic",
13
+ "negative_prompt": "foggy, blurry, blurred, duplicate, ugly, mutilated, mutation, mutated, out of frame, bad anatomy, disfigured, deformed, censored, low res, watermark, text, poorly drawn face, signature",
14
+ "steps": 30,
15
+ "batch_size": 2,
16
+ "n_iter": 1,
17
+ "seed": -1,
18
+ "sampler_name": "DPM2 Karras",
19
+ "cfg_scale": 6,
20
+ "width": 512,
21
+ "height": 512
22
+ },
23
+ "upscale":
24
+ {
25
+ "upscaler_1": "SwinIR_4x",
26
+ "upscaler_2": "None",
27
+ "upscale_first": false,
28
+ "upscaling_resize": 0,
29
+ "gfpgan_visibility": 0,
30
+ "codeformer_visibility": 0,
31
+ "codeformer_weight": 0.5
32
+ },
33
+ "options":
34
+ {
35
+ "sd_model_checkpoint": "sd-v15-runwayml",
36
+ "sd_vae": "vae-ft-mse-840000-ema-pruned.ckpt"
37
+ }
38
+ }
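
This file is the default --config template consumed by cli/generate.py below; any value can be overridden per run with the matching command-line flag, e.g. (a sketch):

    python cli/generate.py --config cli/generate.json --batch 4 --steps 20
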
cli/generate.py ADDED
@@ -0,0 +1,373 @@
1
+ #!/usr/bin/env python
2
+ # pylint: disable=no-member
3
+ """generate batches of images from prompts and upscale them
4
+
5
+ params: run with `--help`
6
+
7
+ default workflow runs an infinite loop and prints stats when interrupted:
8
+ 1. choose a random sampler: look up all available and pick one
9
+ 2. generate dynamic prompt based on styles, embeddings, places, artists, suffixes
10
+ 3. beautify prompt
11
+ 4. generate 3x3 images
12
+ 5. create image grid
13
+ 6. upscale images with face restoration
14
+ """
15
+
16
+ import argparse
17
+ import asyncio
18
+ import base64
19
+ import io
20
+ import json
21
+ import logging
22
+ import math
23
+ import os
24
+ import pathlib
25
+ import secrets
26
+ import time
27
+ import sys
28
+ import importlib
29
+
30
+ from random import randrange
31
+ from PIL import Image
32
+ from PIL.ExifTags import TAGS
33
+ from PIL.TiffImagePlugin import ImageFileDirectory_v2
34
+
35
+ from sdapi import close, get, interrupt, post, session
36
+ from util import Map, log, safestring
37
+
38
+
39
+ sd = {}
40
+ random = {}
41
+ stats = Map({ 'images': 0, 'wall': 0, 'generate': 0, 'upscale': 0 })
42
+ avg = {}
43
+
44
+
45
+ def grid(data):
46
+ if len(data.image) > 1:
47
+ w, h = data.image[0].size
48
+ rows = round(math.sqrt(len(data.image)))
49
+ cols = math.ceil(len(data.image) / rows)
50
+ image = Image.new('RGB', size = (cols * w, rows * h), color = 'black')
51
+ for i, img in enumerate(data.image):
52
+ image.paste(img, box=(i % cols * w, i // cols * h))
53
+ short = data.info.prompt[:min(len(data.info.prompt), 96)] # limit prompt part of filename to 96 chars
54
+ name = '{seed:0>9} {short}'.format(short = short, seed = data.info.all_seeds[0]) # pylint: disable=consider-using-f-string
55
+ name = safestring(name) + '.jpg'
56
+ f = os.path.join(sd.paths.root, sd.paths.grid, name)
57
+ log.info({ 'grid': { 'name': f, 'size': image.size, 'images': len(data.image) } })
58
+ image.save(f, 'JPEG', exif = exif(data.info, None, 'grid'), optimize = True, quality = 70)
59
+ return image
60
+ return data.image
61
+
62
+
63
+ def exif(info, i = None, op = 'generate'):
64
+ seed = [info.all_seeds[i]] if len(info.all_seeds) > 0 and i is not None else info.all_seeds # always returns list
65
+ seed = ', '.join([str(x) for x in seed]) # int list to str list to single str
66
+ template = '{prompt} | negative {negative_prompt} | seed {s} | steps {steps} | cfgscale {cfg_scale} | sampler {sampler_name} | batch {batch_size} | timestamp {job_timestamp} | model {model} | vae {vae}'.format(s = seed, model = sd.options['sd_model_checkpoint'], vae = sd.options['sd_vae'], **info) # pylint: disable=consider-using-f-string
67
+ if op == 'upscale':
68
+ template += ' | faces gfpgan' if sd.upscale.gfpgan_visibility > 0 else ''
69
+ template += ' | faces codeformer' if sd.upscale.codeformer_visibility > 0 else ''
70
+ template += ' | upscale {resize}x {upscaler}'.format(resize = sd.upscale.upscaling_resize, upscaler = sd.upscale.upscaler_1) if sd.upscale.upscaler_1 != 'None' else '' # pylint: disable=consider-using-f-string
71
+ template += ' | upscale {resize}x {upscaler}'.format(resize = sd.upscale.upscaling_resize, upscaler = sd.upscale.upscaler_2) if sd.upscale.upscaler_2 != 'None' else '' # pylint: disable=consider-using-f-string
72
+ if op == 'grid':
73
+ template += ' | grid {num}'.format(num = sd.generate.batch_size * sd.generate.n_iter) # pylint: disable=consider-using-f-string
74
+ ifd = ImageFileDirectory_v2()
75
+ exif_stream = io.BytesIO()
76
+ _TAGS = {v: k for k, v in TAGS.items()} # enumerate possible exif tags
77
+ ifd[_TAGS['ImageDescription']] = template
78
+ ifd.save(exif_stream)
79
+ val = b'Exif\x00\x00' + exif_stream.getvalue()
80
+ return val
81
+
82
+
83
+ def randomize(lst):
84
+ if len(lst) > 0:
85
+ return secrets.choice(lst)
86
+ else:
87
+ return ''
88
+
89
+
90
+ def prompt(params): # generate dynamic prompt or use one if provided
91
+ sd.generate.prompt = params.prompt if params.prompt != 'dynamic' else randomize(random.prompts)
92
+ sd.generate.negative_prompt = params.negative if params.negative != 'dynamic' else randomize(random.negative)
93
+ embedding = params.embedding if params.embedding != 'random' else randomize(random.embeddings)
94
+ sd.generate.prompt = sd.generate.prompt.replace('<embedding>', embedding)
95
+ artist = params.artist if params.artist != 'random' else randomize(random.artists)
96
+ sd.generate.prompt = sd.generate.prompt.replace('<artist>', artist)
97
+ style = params.style if params.style != 'random' else randomize(random.styles)
98
+ sd.generate.prompt = sd.generate.prompt.replace('<style>', style)
99
+ suffix = params.suffix if params.suffix != 'random' else randomize(random.suffixes)
100
+ sd.generate.prompt = sd.generate.prompt.replace('<suffix>', suffix)
101
+ place = params.place if params.place != 'random' else randomize(random.places)
102
+ sd.generate.prompt = sd.generate.prompt.replace('<place>', place)
103
+ if params.prompts or params.debug:
104
+ log.info({ 'random initializers': random })
105
+ if params.prompt == 'dynamic':
106
+ log.info({ 'dynamic prompt': sd.generate.prompt })
107
+ return sd.generate.prompt
108
+
109
+
110
+ def sampler(params, options): # find sampler
111
+ if params.sampler == 'random':
112
+ sd.generate.sampler_name = randomize(options.samplers)
113
+ log.info({ 'random sampler': sd.generate.sampler_name })
114
+ else:
115
+ found = [i for i in options.samplers if i.startswith(params.sampler)]
116
+ if len(found) == 0:
117
+ log.error({ 'sampler error': params.sampler, 'available': options.samplers})
118
+ exit()
119
+ sd.generate.sampler_name = found[0]
120
+ return sd.generate.sampler_name
121
+
122
+
123
+ async def generate(prompt = None, options = None, quiet = False): # pylint: disable=redefined-outer-name
124
+ global sd # pylint: disable=global-statement
125
+ if options:
126
+ sd = Map(options)
127
+ if prompt is not None:
128
+ sd.generate.prompt = prompt
129
+ if not quiet:
130
+ log.info({ 'generate': sd.generate })
131
+ if sd.get('options', None) is None:
132
+ sd['options'] = await get('/sdapi/v1/options')
133
+ names = []
134
+ b64s = []
135
+ images = []
136
+ info = Map({})
137
+ data = await post('/sdapi/v1/txt2img', sd.generate)
138
+ if 'error' in data:
139
+ log.error({ 'generate': data['error'], 'reason': data['reason'] })
140
+ return Map({})
141
+ info = Map(json.loads(data['info']))
142
+ log.debug({ 'info': info })
143
+ images = data['images']
144
+ short = info.prompt[:min(len(info.prompt), 96)] # limit prompt part of filename to 96 chars
145
+ for i in range(len(images)):
146
+ b64s.append(images[i])
147
+ images[i] = Image.open(io.BytesIO(base64.b64decode(images[i].split(',',1)[-1])))
148
+ name = '{seed:0>9} {short}'.format(short = short, seed = info.all_seeds[i]) # pylint: disable=consider-using-f-string
149
+ name = safestring(name) + '.jpg'
150
+ f = os.path.join(sd.paths.root, sd.paths.generate, name)
151
+ names.append(f)
152
+ if not quiet:
153
+ log.info({ 'image': { 'name': f, 'size': images[i].size } })
154
+ images[i].save(f, 'JPEG', exif = exif(info, i), optimize = True, quality = 70)
155
+ return Map({ 'name': names, 'image': images, 'b64': b64s, 'info': info })
156
+
157
+
158
+ async def upscale(data):
159
+ data.upscaled = []
160
+ if sd.upscale.upscaling_resize <= 1:
161
+ return data
162
+ sd.upscale.image = ''
163
+ log.info({ 'upscale': sd.upscale })
164
+ for i in range(len(data.image)):
165
+ f = data.name[i].replace(sd.paths.generate, sd.paths.upscale)
166
+ sd.upscale.image = data.b64[i]
167
+ res = await post('/sdapi/v1/extra-single-image', sd.upscale)
168
+ image = Image.open(io.BytesIO(base64.b64decode(res['image'].split(',',1)[-1])))
169
+ data.upscaled.append(image)
170
+ log.info({ 'image': { 'name': f, 'size': image.size } })
171
+ image.save(f, 'JPEG', exif = exif(data.info, i, 'upscale'), optimize = True, quality = 70)
172
+ return data
173
+
174
+
175
+ async def init():
176
+ '''
177
+ import torch
178
+ log.info({ 'torch': torch.__version__, 'available': torch.cuda.is_available() })
179
+ current_device = torch.cuda.current_device()
180
+ mem_free, mem_total = torch.cuda.mem_get_info()
181
+ log.info({ 'cuda': torch.version.cuda, 'available': torch.cuda.is_available(), 'arch': torch.cuda.get_arch_list(), 'device': torch.cuda.get_device_name(current_device), 'memory': { 'free': round(mem_free / 1024 / 1024), 'total': (mem_total / 1024 / 1024) } })
182
+ '''
183
+ options = Map({})
184
+ options.flags = await get('/sdapi/v1/cmd-flags')
185
+ log.debug({ 'flags': options.flags })
186
+ data = await get('/sdapi/v1/sd-models')
187
+ options.models = [obj['title'] for obj in data]
188
+ log.debug({ 'registered models': options.models })
189
+ found = sd.options.sd_model_checkpoint if sd.options.sd_model_checkpoint in options.models else None
190
+ if found is None:
191
+ found = [i for i in options.models if i.startswith(sd.options.sd_model_checkpoint)]
192
+ if len(found) == 0:
193
+ log.error({ 'model error': sd.options.sd_model_checkpoint, 'available': options.models})
194
+ exit()
195
+ sd.options.sd_model_checkpoint = found[0]
196
+ data = await get('/sdapi/v1/samplers')
197
+ options.samplers = [obj['name'] for obj in data]
198
+ log.debug({ 'registered samplers': options.samplers })
199
+ data = await get('/sdapi/v1/upscalers')
200
+ options.upscalers = [obj['name'] for obj in data]
201
+ log.debug({ 'registered upscalers': options.upscalers })
202
+ data = await get('/sdapi/v1/face-restorers')
203
+ options.restorers = [obj['name'] for obj in data]
204
+ log.debug({ 'registered face restorers': options.restorers })
205
+ await interrupt()
206
+ await post('/sdapi/v1/options', sd.options)
207
+ options.options = await get('/sdapi/v1/options')
208
+ log.info({ 'target models': { 'diffuser': options.options['sd_model_checkpoint'], 'vae': options.options['sd_vae'] } })
209
+ log.info({ 'paths': sd.paths })
210
+ options.queue = await get('/queue/status')
211
+ log.info({ 'queue': options.queue })
212
+ pathlib.Path(sd.paths.root).mkdir(parents = True, exist_ok = True)
213
+ pathlib.Path(os.path.join(sd.paths.root, sd.paths.generate)).mkdir(parents = True, exist_ok = True)
214
+ pathlib.Path(os.path.join(sd.paths.root, sd.paths.upscale)).mkdir(parents = True, exist_ok = True)
215
+ pathlib.Path(os.path.join(sd.paths.root, sd.paths.grid)).mkdir(parents = True, exist_ok = True)
216
+ return options
217
+
218
+
219
+ def args(): # parse cmd arguments
220
+ global sd # pylint: disable=global-statement
221
+ global random # pylint: disable=global-statement
222
+ parser = argparse.ArgumentParser(description = 'sd pipeline')
223
+ parser.add_argument('--config', type = str, default = 'generate.json', required = False, help = 'configuration file')
224
+ parser.add_argument('--random', type = str, default = 'random.json', required = False, help = 'prompt file with randomized sections')
225
+ parser.add_argument('--max', type = int, default = 1, required = False, help = 'maximum number of generated images')
226
+ parser.add_argument('--prompt', type = str, default = 'dynamic', required = False, help = 'prompt')
227
+ parser.add_argument('--negative', type = str, default = 'dynamic', required = False, help = 'negative prompt')
228
+ parser.add_argument('--artist', type = str, default = 'random', required = False, help = 'artist style, used to guide dynamic prompt when prompt is not provided')
229
+ parser.add_argument('--embedding', type = str, default = 'random', required = False, help = 'use embedding, used to guide dynamic prompt when prompt is not provided')
230
+ parser.add_argument('--style', type = str, default = 'random', required = False, help = 'image style, used to guide dynamic prompt when prompt is not provided')
231
+ parser.add_argument('--suffix', type = str, default = 'random', required = False, help = 'style suffix, used to guide dynamic prompt when prompt is not provided')
232
+ parser.add_argument('--place', type = str, default = 'random', required = False, help = 'place locator, used to guide dynamic prompt when prompt is not provided')
233
+ parser.add_argument('--faces', default = False, action='store_true', help = 'restore faces during upscaling')
234
+ parser.add_argument('--steps', type = int, default = 0, required = False, help = 'number of steps')
235
+ parser.add_argument('--batch', type = int, default = 0, required = False, help = 'batch size, limited by gpu vram')
236
+ parser.add_argument('--n', type = int, default = 0, required = False, help = 'number of iterations')
237
+ parser.add_argument('--cfg', type = int, default = 0, required = False, help = 'classifier free guidance scale')
238
+ parser.add_argument('--sampler', type = str, default = 'random', required = False, help = 'sampler')
239
+ parser.add_argument('--seed', type = int, default = 0, required = False, help = 'seed, default is random')
240
+ parser.add_argument('--upscale', type = int, default = 0, required = False, help = 'upscale factor, disabled if 0')
241
+ parser.add_argument('--model', type = str, default = '', required = False, help = 'diffusion model')
242
+ parser.add_argument('--vae', type = str, default = '', required = False, help = 'vae model')
243
+ parser.add_argument('--path', type = str, default = '', required = False, help = 'output path')
244
+ parser.add_argument('--width', type = int, default = 0, required = False, help = 'width')
245
+ parser.add_argument('--height', type = int, default = 0, required = False, help = 'height')
246
+ parser.add_argument('--beautify', default = False, action='store_true', help = 'beautify prompt')
247
+ parser.add_argument('--prompts', default = False, action='store_true', help = 'print dynamic prompt templates')
248
+ parser.add_argument('--debug', default = False, action='store_true', help = 'print extra debug information')
249
+ params = parser.parse_args()
250
+ if params.debug:
251
+ log.setLevel(logging.DEBUG)
252
+ log.debug({ 'debug': True })
253
+ log.debug({ 'args': params.__dict__ })
254
+ home = pathlib.Path(sys.argv[0]).parent
255
+ if os.path.isfile(params.config):
256
+ try:
257
+ with open(params.config, 'r', encoding='utf-8') as f:
258
+ data = json.load(f)
259
+ sd = Map(data)
260
+ log.debug({ 'config': sd })
261
+ except Exception as e:
262
+ log.error({ 'config error': params.config, 'exception': e })
263
+ exit()
264
+ elif os.path.isfile(os.path.join(home, params.config)):
265
+ try:
266
+ with open(os.path.join(home, params.config), 'r', encoding='utf-8') as f:
267
+ data = json.load(f)
268
+ sd = Map(data)
269
+ log.debug({ 'config': sd })
270
+ except Exception as e:
271
+ log.error({ 'config error': params.config, 'exception': e })
272
+ exit()
273
+ else:
274
+ log.error({ 'config file not found': params.config})
275
+ exit()
276
+ if params.prompt == 'dynamic':
277
+ log.info({ 'prompt template': params.random })
278
+ if os.path.isfile(params.random):
279
+ try:
280
+ with open(params.random, 'r', encoding='utf-8') as f:
281
+ data = json.load(f)
282
+ random = Map(data)
283
+ log.debug({ 'random template': sd })
284
+ except Exception:
285
+ log.error({ 'random template error': params.random})
286
+ exit()
287
+ elif os.path.isfile(os.path.join(home, params.random)):
288
+ try:
289
+ with open(os.path.join(home, params.random), 'r', encoding='utf-8') as f:
290
+ data = json.load(f)
291
+ random = Map(data)
292
+ log.debug({ 'random template': sd })
293
+ except Exception:
294
+ log.error({ 'random template error': params.random})
295
+ exit()
296
+ else:
297
+ log.error({ 'random template file not found': params.random})
298
+ exit()
299
+ _dynamic = prompt(params)
300
+
301
+ sd.paths.root = params.path if params.path != '' else sd.paths.root
302
+ sd.generate.restore_faces = params.faces if params.faces else sd.generate.restore_faces # store_true flag is never None, only override when set
303
+ sd.generate.seed = params.seed if params.seed > 0 else sd.generate.seed
304
+ sd.generate.sampler_name = params.sampler if params.sampler != 'random' else sd.generate.sampler_name
305
+ sd.generate.batch_size = params.batch if params.batch > 0 else sd.generate.batch_size
306
+ sd.generate.cfg_scale = params.cfg if params.cfg > 0 else sd.generate.cfg_scale
307
+ sd.generate.n_iter = params.n if params.n > 0 else sd.generate.n_iter
308
+ sd.generate.width = params.width if params.width > 0 else sd.generate.width
309
+ sd.generate.height = params.height if params.height > 0 else sd.generate.height
310
+ sd.generate.steps = params.steps if params.steps > 0 else sd.generate.steps
311
+ sd.upscale.upscaling_resize = params.upscale if params.upscale > 0 else sd.upscale.upscaling_resize
312
+ sd.upscale.codeformer_visibility = 1 if params.faces else sd.upscale.codeformer_visibility
313
+ sd.options.sd_vae = params.vae if params.vae != '' else sd.options.sd_vae
314
+ sd.options.sd_model_checkpoint = params.model if params.model != '' else sd.options.sd_model_checkpoint
315
+ sd.upscale.upscaler_1 = 'SwinIR_4x' if params.upscale > 1 else sd.upscale.upscaler_1
316
+ if sd.generate.cfg_scale == 0:
317
+ sd.generate.cfg_scale = randrange(5, 10)
318
+ return params
319
+
320
+
321
+ async def main():
322
+ params = args()
323
+ sess = await session()
324
+ if sess is None:
325
+ await close()
326
+ exit()
327
+ options = await init()
328
+ iteration = 0
329
+ while True:
330
+ iteration += 1
331
+ log.info('')
332
+ log.info({ 'iteration': iteration, 'batch': sd.generate.batch_size, 'n': sd.generate.n_iter, 'total': sd.generate.n_iter * sd.generate.batch_size })
333
+ dynamic = prompt(params)
334
+ if params.beautify:
335
+ try:
336
+ promptist = importlib.import_module('modules.promptist')
337
+ sd.generate.prompt = promptist.beautify(dynamic)
338
+ except Exception as e:
339
+ log.error({ 'beautify': e })
340
+ scheduler = sampler(params, options)
341
+ t0 = time.perf_counter()
342
+ data = await generate() # generate returns list of images
343
+ if 'image' not in data:
344
+ break
345
+ stats.images += len(data.image)
346
+ t1 = time.perf_counter()
347
+ if len(data.image) > 0:
348
+ avg[scheduler] = (t1 - t0) / len(data.image)
349
+ stats.generate += t1 - t0
350
+ _image = grid(data)
351
+ data = await upscale(data)
352
+ t2 = time.perf_counter()
353
+ stats.upscale += t2 - t1
354
+ stats.wall += t2 - t0
355
+ its = sd.generate.steps / ((t1 - t0) / len(data.image)) if len(data.image) > 0 else 0
356
+ avg_time = round((t1 - t0) / len(data.image)) if len(data.image) > 0 else 0
357
+ log.info({ 'time' : { 'wall': round(t1 - t0), 'average': avg_time, 'upscale': round(t2 - t1), 'its': round(its, 2) } })
358
+ log.info({ 'generated': stats.images, 'max': params.max, 'progress': round(100 * stats.images / params.max, 1) })
359
+ if params.max != 0 and stats.images >= params.max:
360
+ break
361
+
362
+
363
+ if __name__ == '__main__':
364
+ try:
365
+ asyncio.run(main())
366
+ except KeyboardInterrupt:
367
+ asyncio.run(interrupt())
368
+ asyncio.run(close())
369
+ log.info({ 'interrupt': True })
370
+ finally:
371
+ log.info({ 'sampler performance': avg })
372
+ log.info({ 'stats' : stats })
373
+ asyncio.run(close())
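
Example invocation (a sketch; assumes a running SD.Next server plus the generate.json config above and a random.json prompt template):

    python cli/generate.py --max 9 --upscale 2 --faces --sampler UniPC
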
cli/hf-convert.py ADDED
@@ -0,0 +1,35 @@
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import sys
5
+ import logging
6
+ import torch
7
+ import diffusers
8
+ import safetensors
9
+ import safetensors.torch as sf
10
+
11
+ log = logging.getLogger("sd")
12
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s | %(message)s')
13
+
14
+
15
+ def convert(model_id, output_name):
16
+ if os.path.exists(output_name):
17
+ log.error(f'Output already exists: {output_name}')
18
+ return
19
+ pipe = diffusers.DiffusionPipeline.from_pretrained(model_id)
20
+ metadata = { 'model_id': model_id }
21
+ model = {}
22
+ model['state_dict'] = vars(pipe)['_internal_dict']
23
+ for k in model['state_dict'].keys():
24
+ # print(k, getattr(pipe, k))
25
+ model[k] = getattr(pipe, k)
26
+ sf.save_model(model, output_name, metadata=metadata)
27
+ # log.info(f'Saved model: {output_name}')
28
+
29
+ if __name__ == "__main__":
30
+ sys.argv.pop(0)
31
+ if len(sys.argv) < 2:
32
+ log.info('Usage: hf-convert.py <model_id> <output_name>')
33
+ sys.exit(1)
34
+ log.debug(f'Packages: torch={torch.__version__} diffusers={diffusers.__version__} safetensors={safetensors.__version__}')
35
+ convert(sys.argv[0], sys.argv[1])
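
Example invocation (a sketch; the model id and output name are placeholders):

    python cli/hf-convert.py runwayml/stable-diffusion-v1-5 sd15.safetensors
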
cli/hf-search.py ADDED
@@ -0,0 +1,18 @@
1
+ #!/usr/bin/env python
2
+
3
+ import sys
4
+ import huggingface_hub as hf
5
+ from rich import print # pylint: disable=redefined-builtin
6
+
7
+ if __name__ == "__main__":
8
+ sys.argv.pop(0)
9
+ keyword = sys.argv[0] if len(sys.argv) > 0 else ''
10
+ hf_api = hf.HfApi()
11
+ model_filter = hf.ModelFilter(
12
+ model_name=keyword,
13
+ # task='text-to-image',
14
+ library=['diffusers'],
15
+ )
16
+ res = hf_api.list_models(filter=model_filter, full=True, limit=50, sort="downloads", direction=-1)
17
+ models = [{ 'name': m.modelId, 'downloads': m.downloads, 'mtime': m.lastModified, 'url': f'https://huggingface.co/{m.modelId}', 'pipeline': m.pipeline_tag, 'tags': m.tags } for m in res]
18
+ print(models)
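
Example invocation (the keyword is a placeholder; prints up to 50 diffusers models sorted by downloads):

    python cli/hf-search.py dreamshaper
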
cli/idle.py ADDED
@@ -0,0 +1,60 @@
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import time
5
+ import datetime
6
+ import logging
7
+ import urllib3
8
+ import requests
9
+
10
+ class Dot(dict):
11
+ __getattr__ = dict.get
12
+ __setattr__ = dict.__setitem__
13
+ __delattr__ = dict.__delitem__
14
+
15
+ opts = Dot({
16
+ "timeout": 3600,
17
+ "frequency": 60,
18
+ "action": "sudo shutdown now",
19
+ "url": "https://127.0.0.1:7860",
20
+ "user": "",
21
+ "password": "",
22
+ })
23
+
24
+ log_format = '%(asctime)s %(levelname)s: %(message)s'
25
+ logging.basicConfig(level = logging.INFO, format = log_format)
26
+ log = logging.getLogger("sd")
27
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
28
+ status = None
29
+
30
+ def progress():
31
+ auth = requests.auth.HTTPBasicAuth(opts.user, opts.password) if opts.user is not None and len(opts.user) > 0 and opts.password is not None and len(opts.password) > 0 else None
32
+ req = requests.get(f'{opts.url}/sdapi/v1/progress?skip_current_image=true', verify=False, auth=auth, timeout=60)
33
+ if req.status_code != 200:
34
+ log.error({ 'url': req.url, 'request': req.status_code, 'reason': req.reason })
35
+ return status
36
+ else:
37
+ res = Dot(req.json())
38
+ log.debug({ 'url': req.url, 'request': req.status_code, 'result': res })
39
+ return res
40
+
41
+ log.info(f'sdnext monitor started: {opts}')
42
+ while True:
43
+ try:
44
+ status = progress()
45
+ state = status.get('state', {})
46
+ last_job = state.get('job_timestamp', None)
47
+ if last_job is None:
48
+ log.warning(f'sdnext monitoring cannot get last job info: {status}')
49
+ else:
50
+ last_job = datetime.datetime.strptime(last_job, "%Y%m%d%H%M%S")
51
+ elapsed = datetime.datetime.now() - last_job
52
+ timeout = round(opts.timeout - elapsed.total_seconds())
53
+ log.info(f'sdnext: last_job={last_job} elapsed={elapsed} timeout={timeout}')
54
+ if timeout < 0:
55
+ log.warning(f'sdnext idle timeout reached: timeout={opts.timeout} action={opts.action}')
56
+ os.system(opts.action)
57
+ except Exception as e:
58
+ log.error(f'sdnext monitor error: {e}')
59
+ finally:
60
+ time.sleep(opts.frequency)
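
The monitor takes no command-line flags; timeout, poll frequency, server url/credentials and the shutdown action are edited directly in the opts map at the top of the script, then it is simply run:

    python cli/idle.py
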
cli/image-exif.py ADDED
@@ -0,0 +1,163 @@
1
+ #!/bin/env python
2
+
3
+ import os
4
+ import io
5
+ import re
6
+ import sys
7
+ import json
8
+ from PIL import Image, ExifTags, TiffImagePlugin, PngImagePlugin
9
+ from rich import print # pylint: disable=redefined-builtin
10
+
11
+
12
+ def unquote(text):
13
+ if len(text) == 0 or text[0] != '"' or text[-1] != '"':
14
+ return text
15
+ try:
16
+ return json.loads(text)
17
+ except Exception:
18
+ return text
19
+
20
+
21
+ def parse_generation_parameters(infotext): # copied from modules.generation_parameters_copypaste
22
+ if not isinstance(infotext, str):
23
+ return {}
24
+
25
+ re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)') # multi-word: value
26
+ re_size = re.compile(r"^(\d+)x(\d+)$") # int x int
27
+ sanitized = infotext.replace('prompt:', 'Prompt:').replace('negative prompt:', 'Negative prompt:').replace('Negative Prompt', 'Negative prompt') # normalize key capitalization; bracketed sections are blanked below so re_param can work
28
+ sanitized = re.sub(r'<[^>]*>', lambda match: ' ' * len(match.group()), sanitized)
29
+ sanitized = re.sub(r'\([^)]*\)', lambda match: ' ' * len(match.group()), sanitized)
30
+ sanitized = re.sub(r'\{[^}]*\}', lambda match: ' ' * len(match.group()), sanitized)
31
+
32
+ params = dict(re_param.findall(sanitized))
33
+ params = { k.strip():params[k].strip() for k in params if k.lower() not in ['hashes', 'lora', 'embeddings', 'prompt', 'negative prompt']} # remove some keys
34
+ first_param = next(iter(params)) if params else None
35
+ params_idx = sanitized.find(f'{first_param}:') if first_param else -1
36
+ negative_idx = infotext.find("Negative prompt:")
37
+
38
+ prompt = infotext[:params_idx] if negative_idx == -1 else infotext[:negative_idx] # prompt can be with or without negative prompt
39
+ negative = infotext[negative_idx:params_idx] if negative_idx >= 0 else ''
40
+
41
+ for k, v in params.copy().items(): # avoid dict-has-changed
42
+ if len(v) > 0 and v[0] == '"' and v[-1] == '"':
43
+ v = unquote(v)
44
+ m = re_size.match(v)
45
+ if v.replace('.', '', 1).isdigit():
46
+ params[k] = float(v) if '.' in v else int(v)
47
+ elif v == "True":
48
+ params[k] = True
49
+ elif v == "False":
50
+ params[k] = False
51
+ elif m is not None:
52
+ params[f"{k}-1"] = int(m.group(1))
53
+ params[f"{k}-2"] = int(m.group(2))
54
+ elif k == 'VAE' and v == 'TAESD':
55
+ params["Full quality"] = False
56
+ else:
57
+ params[k] = v
58
+ params["Prompt"] = prompt.replace('Prompt:', '').strip()
59
+ params["Negative prompt"] = negative.replace('Negative prompt:', '').strip()
60
+ return params
61
+
62
+
63
+ class Exif: # pylint: disable=single-string-used-for-slots
64
+ __slots__ = ('__dict__') # pylint: disable=superfluous-parens
65
+ def __init__(self, image = None):
66
+ super(Exif, self).__setattr__('exif', Image.Exif()) # pylint: disable=super-with-arguments
67
+ self.pnginfo = PngImagePlugin.PngInfo()
68
+ self.tags = {**dict(ExifTags.TAGS.items()), **dict(ExifTags.GPSTAGS.items())}
69
+ self.ids = {**{v: k for k, v in ExifTags.TAGS.items()}, **{v: k for k, v in ExifTags.GPSTAGS.items()}}
70
+ if image is not None:
71
+ self.load(image)
72
+
73
+ def __getattr__(self, attr):
74
+ if attr in self.__dict__:
75
+ return self.__dict__[attr]
76
+ return self.exif.get(attr, None)
77
+
78
+ def load(self, img: Image):
79
+ img.load() # exif may not be ready
80
+ exif_dict = {}
81
+ try:
82
+ exif_dict = dict(img._getexif().items()) # pylint: disable=protected-access
83
+ except Exception:
84
+ exif_dict = dict(img.info.items())
85
+ for key, val in exif_dict.items():
86
+ if isinstance(val, bytes): # decode bytestring
87
+ val = self.decode(val)
88
+ if val is not None:
89
+ if isinstance(key, str):
90
+ self.exif[key] = val
91
+ self.pnginfo.add_text(key, str(val), zip=False)
92
+ elif isinstance(key, int) and key in ExifTags.TAGS: # add known tags
93
+ if self.tags[key] in ['ExifOffset']:
94
+ continue
95
+ self.exif[self.tags[key]] = val
96
+ self.pnginfo.add_text(self.tags[key], str(val), zip=False)
97
+ # if self.tags[key] == 'UserComment': # add geninfo from UserComment
98
+ # self.geninfo = val
99
+ else:
100
+ print('metadata unknown tag:', key, val)
101
+ for key, val in self.exif.items():
102
+ if isinstance(val, bytes): # decode bytestring
103
+ self.exif[key] = self.decode(val)
104
+
105
+ def decode(self, s: bytes):
106
+ remove_prefix = lambda text, prefix: text[len(prefix):] if text.startswith(prefix) else text # pylint: disable=unnecessary-lambda-assignment
107
+ for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']: # try different encodings
108
+ try:
109
+ s = remove_prefix(s, b'UNICODE')
110
+ s = remove_prefix(s, b'ASCII')
111
+ s = remove_prefix(s, b'\x00')
112
+ val = s.decode(encoding, errors="strict")
113
+ val = re.sub(r'[\x00-\x09]', '', val).strip() # remove remaining special characters
114
+ if len(val) == 0: # remove empty strings
115
+ val = None
116
+ return val
117
+ except Exception:
118
+ pass
119
+ return None
120
+
121
+ def parse(self):
122
+ x = self.exif.pop('parameters', None) or self.exif.pop('UserComment', None)
123
+ res = parse_generation_parameters(x)
124
+ return res
125
+
126
+ def get_bytes(self):
127
+ ifd = TiffImagePlugin.ImageFileDirectory_v2()
128
+ exif_stream = io.BytesIO()
129
+ for key, val in self.exif.items():
130
+ if key in self.ids:
131
+ ifd[self.ids[key]] = val
132
+ else:
133
+ print('metadata unknown exif tag:', key, val)
134
+ ifd.save(exif_stream)
135
+ raw = b'Exif\x00\x00' + exif_stream.getvalue()
136
+ return raw
137
+
138
+
139
+ def read_exif(filename: str):
140
+ if filename.lower().endswith('.heic'):
141
+ from pi_heif import register_heif_opener
142
+ register_heif_opener()
143
+ try:
144
+ image = Image.open(filename)
145
+ exif = Exif(image)
146
+ print('image:', filename, 'format:', image)
147
+ print('exif:', vars(exif.exif)['_data'])
148
+ print('info:', exif.parse())
149
+ except Exception as e:
150
+ print('metadata error reading:', filename, e)
151
+
152
+
153
+ if __name__ == '__main__':
154
+ sys.argv.pop(0)
155
+ if len(sys.argv) == 0:
156
+ print('metadata:', 'no files specified')
157
+ for fn in sys.argv:
158
+ if os.path.isfile(fn):
159
+ read_exif(fn)
160
+ elif os.path.isdir(fn):
161
+ for root, _dirs, files in os.walk(fn):
162
+ for file in files:
163
+ read_exif(os.path.join(root, file))
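
Example invocation (paths are placeholders; accepts any mix of image files and folders):

    python cli/image-exif.py outputs/image/ grid.jpg
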
cli/image-grid.py ADDED
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ Create image grid
4
+ """
5
+
6
+ import os
7
+ import argparse
8
+ import math
9
+ import logging
10
+ from pathlib import Path
11
+ import filetype
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ from util import log
14
+
15
+
16
+ params = None
17
+
18
+
19
+ def wrap(text: str, font: ImageFont.ImageFont, length: int):
20
+ lines = ['']
21
+ for word in text.split():
22
+ line = f'{lines[-1]} {word}'.strip()
23
+ if font.getlength(line) <= length:
24
+ lines[-1] = line
25
+ else:
26
+ lines.append(word)
27
+ return '\n'.join(lines)
28
+
29
+
30
+ def grid(images, labels = None, width = 0, height = 0, border = 0, square = False, horizontal = False, vertical = False): # pylint: disable=redefined-outer-name
31
+ if horizontal:
32
+ rows = 1
33
+ elif vertical:
34
+ rows = len(images)
35
+ elif square:
36
+ rows = round(math.sqrt(len(images)))
37
+ else:
38
+ rows = math.floor(math.sqrt(len(images)))
39
+ cols = math.ceil(len(images) / rows)
40
+ size = [0, 0]
41
+ if width == 0:
42
+ w = max([i.size[0] for i in images])
43
+ size[0] = cols * w + cols * border
44
+ else:
45
+ size[0] = width
46
+ w = round(width / cols)
47
+ if height == 0:
48
+ h = max([i.size[1] for i in images])
49
+ size[1] = rows * h + rows * border
50
+ else:
51
+ size[1] = height
52
+ h = round(height / rows)
53
+ size = tuple(size)
54
+ image = Image.new('RGB', size = size, color = 'black') # pylint: disable=redefined-outer-name
55
+ font = ImageFont.truetype('DejaVuSansMono', round(w / 40))
56
+ for i, img in enumerate(images): # pylint: disable=redefined-outer-name
57
+ x = (i % cols * w) + (i % cols * border)
58
+ y = (i // cols * h) + (i // cols * border)
59
+ img.thumbnail((w, h), Image.Resampling.HAMMING)
60
+ image.paste(img, box=(x + int(border / 2), y + int(border / 2)))
61
+ if labels is not None and len(images) == len(labels):
62
+ ctx = ImageDraw.Draw(image)
63
+ label = wrap(labels[i], font, w)
64
+ ctx.text((x + 1 + round(w / 200), y + 1 + round(w / 200)), label, font = font, fill = (0, 0, 0))
65
+ ctx.text((x, y), label, font = font, fill = (255, 255, 255))
66
+ log.info({ 'grid': { 'images': len(images), 'rows': rows, 'cols': cols, 'cell': [w, h] } })
67
+ return image
68
+
69
+
70
+ if __name__ == '__main__':
71
+ log.info('create grid')
72
+ parser = argparse.ArgumentParser(description='image grid utility')
73
+ parser.add_argument("--square", default = False, action='store_true', help = "create square grid")
74
+ parser.add_argument("--horizontal", default = False, action='store_true', help = "create horizontal grid")
75
+ parser.add_argument("--vertical", default = False, action='store_true', help = "create vertical grid")
76
+ parser.add_argument("--width", type = int, default = 0, required = False, help = "fixed grid width")
77
+ parser.add_argument("--height", type = int, default = 0, required = False, help = "fixed grid height")
78
+ parser.add_argument("--border", type = int, default = 0, required = False, help = "image border")
79
+ parser.add_argument('--nolabels', default = False, action='store_true', help = "do not print image labels")
80
+ parser.add_argument('--debug', default = False, action='store_true', help = "print extra debug information")
81
+ parser.add_argument('output', type = str)
82
+ parser.add_argument('input', type = str, nargs = '*')
83
+ params = parser.parse_args()
84
+ output = params.output if params.output.lower().endswith('.jpg') else params.output + '.jpg'
85
+ if params.debug:
86
+ log.setLevel(logging.DEBUG)
87
+ log.debug({ 'debug': True })
88
+ log.debug({ 'args': params.__dict__ })
89
+ images = []
90
+ labels = []
91
+ for f in params.input:
92
+ path = Path(f)
93
+ if path.is_dir():
94
+ files = [os.path.join(f, file) for file in os.listdir(f) if os.path.isfile(os.path.join(f, file))]
95
+ elif path.is_file():
96
+ files = [f]
97
+ else:
98
+ log.warning({ 'grid not a valid file/folder', f})
99
+ continue
100
+ files.sort()
101
+ for file in files:
102
+ if not filetype.is_image(file):
103
+ continue
104
+ if file.lower().endswith('.heic'):
105
+ from pi_heif import register_heif_opener
106
+ register_heif_opener()
107
+ log.debug(file)
108
+ img = Image.open(file)
109
+ # img.verify()
110
+ images.append(img)
111
+ fp = Path(file)
112
+ if not params.nolabels:
113
+ labels.append(fp.stem)
114
+ # log.info({ 'folder': path.parent, 'labels': labels })
115
+ if len(images) > 0:
116
+ image = grid(
117
+ images = images,
118
+ labels = labels,
119
+ width = params.width,
120
+ height = params.height,
121
+ border = params.border,
122
+ square = params.square,
123
+ horizontal = params.horizontal,
124
+ vertical = params.vertical)
125
+ image.save(output, 'JPEG', optimize = True, quality = 60)
126
+ log.info({ 'grid': { 'file': output, 'size': list(image.size) } })
127
+ else:
128
+ log.info({ 'grid': 'nothing to do' })
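
Example invocation (a sketch; the output name comes first, followed by any mix of image files and folders):

    python cli/image-grid.py --square --border 8 grid.jpg outputs/image/
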
cli/image-interrogate.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ use clip to interrogate image(s)
4
+ """
5
+
6
+ import io
7
+ import base64
8
+ import sys
9
+ import os
10
+ import asyncio
11
+ import filetype
12
+ from PIL import Image
13
+ from util import log, Map
14
+ import sdapi
15
+
16
+
17
+ stats = { 'captions': {}, 'keywords': {} }
18
+ exclude = ['a', 'in', 'on', 'out', 'at', 'the', 'and', 'with', 'next', 'to', 'it', 'for', 'of', 'into', 'that']
19
+
20
+
21
+ def decode(encoding):
22
+ if encoding.startswith("data:image/"):
23
+ encoding = encoding.split(";")[1].split(",")[1]
24
+ return Image.open(io.BytesIO(base64.b64decode(encoding)))
25
+
26
+
27
+ def encode(f):
28
+ image = Image.open(f)
29
+ exif = image.getexif()
30
+ if image.mode == 'RGBA':
31
+ image = image.convert('RGB')
32
+ with io.BytesIO() as stream:
33
+ image.save(stream, 'JPEG', exif = exif)
34
+ values = stream.getvalue()
35
+ encoded = base64.b64encode(values).decode()
36
+ return encoded
37
+
38
+
39
+ def print_summary():
40
+ captions = dict(sorted(stats['captions'].items(), key=lambda x:x[1], reverse=True))
41
+ log.info({ 'caption stats': captions })
42
+ keywords = dict(sorted(stats['keywords'].items(), key=lambda x:x[1], reverse=True))
43
+ log.info({ 'keyword stats': keywords })
44
+
45
+
46
+ async def interrogate(f):
47
+ if not filetype.is_image(f):
48
+ log.info({ 'interrogate skip': f })
49
+ return
50
+ json = Map({ 'image': encode(f) })
51
+ log.info({ 'interrogate': f })
52
+ # run clip
53
+ json.model = 'clip'
54
+ res = await sdapi.post('/sdapi/v1/interrogate', json)
55
+ caption = ""
56
+ style = ""
57
+ if 'caption' in res:
58
+ caption = res.caption
59
+ log.info({ 'interrogate caption': caption })
60
+ if ', by' in caption:
61
+ style = caption.split(', by')[1].strip()
62
+ log.info({ 'interrogate style': style })
63
+ for word in caption.split(' '):
64
+ if word not in exclude:
65
+ stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1
66
+ else:
67
+ log.error({ 'interrogate clip error': res })
68
+ # run booru
69
+ json.model = 'deepdanbooru'
70
+ res = await sdapi.post('/sdapi/v1/interrogate', json)
71
+ keywords = {}
72
+ if 'caption' in res:
73
+ for term in res.caption.split(', '):
74
+ term = term.replace('(', '').replace(')', '').replace('\\', '').split(':')
75
+ if len(term) < 2:
76
+ continue
77
+ keywords[term[0]] = term[1]
78
+ keywords = dict(sorted(keywords.items(), key=lambda x:x[1], reverse=True))
79
+ for word in keywords.items():
80
+ stats['keywords'][word[0]] = stats['keywords'][word[0]] + 1 if word[0] in stats['keywords'] else 1
81
+ log.info({ 'interrogate keywords': keywords })
82
+ else:
83
+ log.error({ 'interrogate booru error': res })
84
+ return caption, keywords, style
85
+
86
+
87
+ async def main():
88
+ sys.argv.pop(0)
89
+ await sdapi.session()
90
+ if len(sys.argv) == 0:
91
+ log.error({ 'interrogate': 'no files specified' })
92
+ for arg in sys.argv:
93
+ if os.path.exists(arg):
94
+ if os.path.isfile(arg):
95
+ await interrogate(arg)
96
+ elif os.path.isdir(arg):
97
+ for root, _dirs, files in os.walk(arg):
98
+ for f in files:
99
+ _caption, _keywords, _style = await interrogate(os.path.join(root, f))
100
+ else:
101
+ log.error({ 'interrogate unknown file type': arg })
102
+ else:
103
+ log.error({ 'interrogate file missing': arg })
104
+ await sdapi.close()
105
+ print_summary()
106
+
107
+
108
+ if __name__ == "__main__":
109
+ asyncio.run(main())
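
Example invocation (a sketch; assumes a running SD.Next server with the clip and deepdanbooru interrogate models available):

    python cli/image-interrogate.py outputs/image/
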
cli/image-palette.py ADDED
@@ -0,0 +1,129 @@
1
+ #!/usr/bin/env python
2
+ # based on <https://towardsdatascience.com/image-color-extraction-with-python-in-4-steps-8d9370d9216e>
3
+
4
+ import os
5
+ import io
6
+ import pathlib
7
+ import argparse
8
+ import importlib
9
+ import pandas as pd
10
+ import numpy as np
11
+ import extcolors
12
+ import filetype
13
+ import matplotlib.pyplot as plt
14
+ import matplotlib.patches as patches
15
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
16
+ from colormap import rgb2hex
17
+ from PIL import Image
18
+ from util import log
19
+ grid = importlib.import_module('image-grid').grid
20
+
21
+
22
+ def color_to_df(param):
23
+ colors_pre_list = str(param).replace('([(','').split(', (')[0:-1]
24
+ df_rgb = [i.split('), ')[0] + ')' for i in colors_pre_list]
25
+ df_percent = [i.split('), ')[1].replace(')','') for i in colors_pre_list]
26
+ #convert RGB to HEX code
27
+ df_color_up = [rgb2hex(int(i.split(", ")[0].replace("(","")),
28
+ int(i.split(", ")[1]),
29
+ int(i.split(", ")[2].replace(")",""))) for i in df_rgb]
30
+ df = pd.DataFrame(zip(df_color_up, df_percent), columns = ['c_code','occurence'])
31
+ return df
32
+
33
+
34
+ def palette(img, params, output):
35
+ size = 1024
36
+ img.thumbnail((size, size), Image.Resampling.HAMMING)
37
+
38
+ # create dataframe
39
+ colors_x = extcolors.extract_from_image(img, tolerance = params.color, limit = 13)
40
+ df_color = color_to_df(colors_x)
41
+
42
+ #annotate text
43
+ list_color = list(df_color['c_code'])
44
+ list_precent = [int(i) for i in list(df_color['occurence'])]
45
+ text_c = [c + ' ' + str(round(p * 100 / sum(list_precent), 1)) +'%' for c, p in zip(list_color, list_precent)]
46
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(120,60), dpi=10)
47
+ fig.set_facecolor('black')
48
+
49
+ #donut plot
50
+ wedges, _text = ax1.pie(list_precent, labels= text_c, labeldistance= 1.05, colors = list_color, textprops={'fontsize': 100, 'color':'white'})
51
+ plt.setp(wedges, width=0.3)
52
+
53
+ #add image in the center of donut plot
54
+ data = np.asarray(img)
55
+ imagebox = OffsetImage(data, zoom=2.5)
56
+ ab = AnnotationBbox(imagebox, (0, 0))
57
+ ax1.add_artist(ab)
58
+
59
+ #color palette
60
+ x_posi, y_posi, y_posi2 = 160, -260, -260
61
+ for c in list_color:
62
+ if list_color.index(c) <= 5:
63
+ y_posi += 240
64
+ rect = patches.Rectangle((x_posi, y_posi), 540, 230, facecolor = c)
65
+ ax2.add_patch(rect)
66
+ ax2.text(x = x_posi + 100, y = y_posi + 140, s = c, fontdict={'fontsize': 140}, color = 'white')
67
+ else:
68
+ y_posi2 += 240
69
+ rect = patches.Rectangle((x_posi + 600, y_posi2), 540, 230, facecolor = c)
70
+ ax2.add_artist(rect)
71
+ ax2.text(x = x_posi + 700, y = y_posi2 + 140, s = c, fontdict={'fontsize': 140}, color = 'white')
72
+
73
+ # add background to force layout
74
+ fig.set_facecolor('black')
75
+ ax2.axis('off')
76
+ tmp = Image.new('RGB', (2000, 1400), (0, 0, 0))
77
+ plt.imshow(tmp)
78
+ plt.tight_layout(rect = (-0.08, -0.2, 1.18, 1.05))
79
+
80
+ # save image
81
+ if output is not None:
82
+ buf = io.BytesIO()
83
+ plt.savefig(buf, format='png')
84
+ pltimg = Image.open(buf)
85
+ pltimg = pltimg.convert('RGB')
86
+ pltimg.save(output)
87
+ buf.close()
88
+ log.info({ 'palette created': output })
89
+
90
+ plt.close()
91
+
92
+
93
+ if __name__ == '__main__':
94
+ parser = argparse.ArgumentParser(description = 'extract image color palette')
95
+ parser.add_argument('--color', type=int, default=20, help="color tolerance threshold")
96
+ parser.add_argument('--output', type=str, required=False, default='', help='folder to store images')
97
+ parser.add_argument('--suffix', type=str, required=False, default='palette', help='add suffix to image name')
98
+ parser.add_argument('--grid', default=False, action='store_true', help = "create grid of images before processing")
99
+ parser.add_argument('input', type=str, nargs='*')
100
+ args = parser.parse_args()
101
+ log.info({ 'palette args': vars(args) })
102
+ if args.output != '':
103
+ pathlib.Path(args.output).mkdir(parents = True, exist_ok = True)
104
+ if not args.grid:
105
+ for arg in args.input:
106
+ if os.path.isfile(arg) and filetype.is_image(arg):
107
+ image = Image.open(arg)
108
+ fn = os.path.join(args.output, pathlib.Path(arg).stem + '-' + args.suffix + '.jpg')
109
+ palette(image, args, fn)
110
+ elif os.path.isdir(arg):
111
+ for root, _dirs, files in os.walk(arg):
112
+ for f in files:
113
+ if filetype.is_image(os.path.join(root, f)):
114
+ image = Image.open(os.path.join(root, f))
115
+ fn = os.path.join(args.output, pathlib.Path(f).stem + '-' + args.suffix + '.jpg')
116
+ palette(image, args, fn)
117
+ else:
118
+ images = []
119
+ for arg in args.input:
120
+ if os.path.isfile(arg) and filetype.is_image(arg):
121
+ images.append(Image.open(arg))
122
+ elif os.path.isdir(arg):
123
+ for root, _dirs, files in os.walk(arg):
124
+ for f in files:
125
+ if filetype.is_image(os.path.join(root, f)):
126
+ images.append(Image.open(os.path.join(root, f)))
127
+ image = grid(images)
128
+ fn = os.path.join(args.output, args.suffix + '.jpg')
129
+ palette(image, args, fn)
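
Example invocation (a sketch; paths are placeholders):

    python cli/image-palette.py --output palettes --color 20 outputs/image/
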
cli/image-watermark.py ADDED
@@ -0,0 +1,129 @@
1
+ #!/usr/bin/env python
2
+ import os
3
+ import io
4
+ import pathlib
5
+ import argparse
6
+ import filetype
7
+ import numpy as np
8
+ from imwatermark import WatermarkEncoder, WatermarkDecoder
9
+ from PIL import Image
10
+ from PIL.ExifTags import TAGS
11
+ from PIL.TiffImagePlugin import ImageFileDirectory_v2
12
+ from util import log, Map
13
+ import piexif
14
+ import piexif.helper
15
+
16
+
17
+ options = Map({ 'method': 'dwtDctSvd', 'type': 'bytes' })
18
+
19
+
20
+ def get_exif(image):
21
+ # using piexif
22
+ res1 = {}
23
+ try:
24
+ exif = piexif.load(image.info["exif"])
25
+ exif = exif.get("Exif", {})
26
+ for k, v in exif.items():
27
+ key = list(vars(piexif.ExifIFD).keys())[list(vars(piexif.ExifIFD).values()).index(k)]
28
+ res1[key] = piexif.helper.UserComment.load(v)
29
+ except Exception:
30
+ pass
31
+ # using pillow
32
+ res2 = {}
33
+ try:
34
+ res2 = { TAGS[k]: v for k, v in image.getexif().items() if k in TAGS }
35
+ except Exception:
36
+ pass
37
+ return {**res1, **res2}
38
+
39
+
40
+ def set_exif(d: dict):
41
+ ifd = ImageFileDirectory_v2()
42
+ _TAGS = {v: k for k, v in TAGS.items()} # enumerate possible exif tags
43
+ for k, v in d.items():
44
+ ifd[_TAGS[k]] = v
45
+ exif_stream = io.BytesIO()
46
+ ifd.save(exif_stream)
47
+ encoded = b'Exif\x00\x00' + exif_stream.getvalue()
48
+ return encoded
49
+
50
+
51
+ def get_watermark(image, params):
52
+ data = np.asarray(image)
53
+ decoder = WatermarkDecoder(options.type, params.length)
54
+ decoded = decoder.decode(data, options.method)
55
+ wm = decoded.decode(encoding='ascii', errors='ignore')
56
+ return wm
57
+
58
+
59
+ def set_watermark(image, params):
60
+ data = np.asarray(image)
61
+ encoder = WatermarkEncoder()
62
+ length = params.length // 8
63
+ text = f"{params.wm:<{length}}"[:length]
64
+ bytearr = text.encode(encoding='ascii', errors='ignore')
65
+ encoder.set_watermark(options.type, bytearr)
66
+ encoded = encoder.encode(data, options.method)
67
+ image = Image.fromarray(encoded)
68
+ return image
69
+
70
+
71
+ def watermark(params, file):
72
+ if not os.path.exists(file):
73
+ log.error({ 'watermark': 'file not found' })
74
+ return
75
+ if not filetype.is_image(file):
76
+ log.error({ 'watermark': 'file is not an image' })
77
+ return
78
+ image = Image.open(file)
79
+ if image.width * image.height < 256 * 256:
80
+ log.error({ 'watermark': 'image too small' })
81
+ return
82
+
83
+ exif = get_exif(image)
84
+
85
+ if params.command == 'read':
86
+ fn = file
87
+ wm = get_watermark(image, params)
88
+
89
+ elif params.command == 'write':
90
+ metadata = b'' if params.strip else set_exif(exif)
91
+ if params.output != '':
92
+ pathlib.Path(params.output).mkdir(parents = True, exist_ok = True)
93
+ image=set_watermark(image, params)
94
+ fn = os.path.join(params.output, file)
95
+ image.save(fn, exif=metadata)
96
+
97
+ if params.verify:
98
+ image = Image.open(fn)
99
+ data = np.asarray(image)
100
+ decoder = WatermarkDecoder(options.type, params.length)
101
+ decoded = decoder.decode(data, options.method)
102
+ wm = decoded.decode(encoding='ascii', errors='ignore')
103
+ else:
104
+ wm = params.wm
105
+
106
+ log.info({ 'file': fn })
107
+ log.info({ 'resolution': f'{image.width}x{image.height}' })
108
+ log.info({ 'watermark': wm })
109
+ log.info({ 'exif': None if params.strip else exif })
110
+
111
+
112
+ if __name__ == '__main__':
113
+ parser = argparse.ArgumentParser(description = 'image watermarking')
114
+ parser.add_argument('command', choices = ['read', 'write'])
115
+ parser.add_argument('--wm', type=str, required=False, default='sdnext', help='watermark string')
116
+ parser.add_argument('--strip', default=False, action='store_true', help = "strip existing exif data")
117
+ parser.add_argument('--verify', default=False, action='store_true', help = "verify watermark during write")
118
+ parser.add_argument('--length', type=int, default=32, help="watermark length in bits")
119
+ parser.add_argument('--output', type=str, required=False, default='', help='folder to store images, default is overwrite in-place')
120
+ parser.add_argument('input', type=str, nargs='*')
121
+ args = parser.parse_args()
122
+ # log.info({ 'watermark args': vars(args), 'options': options })
123
+ for arg in args.input:
124
+ if os.path.isfile(arg):
125
+ watermark(args, arg)
126
+ elif os.path.isdir(arg):
127
+ for root, _dirs, files in os.walk(arg):
128
+ for f in files:
129
+ watermark(args, os.path.join(root, f))
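
Example invocations (a sketch; paths are placeholders):

    python cli/image-watermark.py write --wm sdnext --verify --output marked outputs/image/
    python cli/image-watermark.py read marked/example.jpg
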
cli/install-sf.py ADDED
@@ -0,0 +1,87 @@
+ #!/usr/bin/env python
+ import os
+ import re
+ import sys
+
+ torch_supported = ['211', '212', '220', '221']
+ cuda_supported = ['cu118', 'cu121']
+ python_supported = ['39', '310', '311']
+ repo_url = 'https://github.com/chengzeyi/stable-fast'
+ api_url = 'https://api.github.com/repos/chengzeyi/stable-fast/releases/tags/nightly'
+ path_url = '/releases/download/nightly'
+
+
+ def install_pip(arg: str):
+     import subprocess
+     cmd = f'"{sys.executable}" -m pip install -U {arg}'
+     print(f'Running: {cmd}')
+     result = subprocess.run(cmd, shell=True, check=False, env=os.environ)
+     return result.returncode == 0
+
+
+ def get_nightly():
+     import requests
+     r = requests.get(api_url, timeout=10)
+     if r.status_code != 200:
+         print('Failed to get nightly version')
+         return None
+     json = r.json()
+     assets = json.get('assets', [])
+     if len(assets) == 0:
+         print('Failed to get nightly version')
+         return None
+     asset = assets[0].get('name', '')
+     pattern = r"-(.+?)\+" # version is embedded in the asset name between '-' and '+'
+     match = re.search(pattern, asset)
+     if match:
+         ver = match.group(1)
+         print(f'Nightly version: {ver}')
+         return ver
+     else:
+         print('Failed to get nightly version')
+         return None
+
+
+ def install_stable_fast():
+     import torch
+
+     python_ver = f'{sys.version_info.major}{sys.version_info.minor}'
+     if python_ver not in python_supported:
+         raise ValueError(f'StableFast unsupported python: {python_ver} supported: {python_supported}')
+     if sys.platform == 'linux':
+         bin_url = 'manylinux2014_x86_64.whl'
+     elif sys.platform == 'win32':
+         bin_url = 'win_amd64.whl'
+     else:
+         raise ValueError(f'StableFast unsupported platform: {sys.platform}')
+
+     torch_ver, cuda_ver = torch.__version__.split('+')
+     torch_ver = torch_ver.replace('.', '')
+     sf_ver = get_nightly()
+
+     if torch_ver not in torch_supported:
+         print(f'StableFast unsupported torch: {torch_ver} supported: {torch_supported}')
+         print('Installing from source...')
+         url = 'git+https://github.com/chengzeyi/stable-fast.git@main#egg=stable-fast'
+     elif cuda_ver not in cuda_supported:
+         print(f'StableFast unsupported CUDA: {cuda_ver} supported: {cuda_supported}')
+         print('Installing from source...')
+         url = 'git+https://github.com/chengzeyi/stable-fast.git@main#egg=stable-fast'
+     elif sf_ver is None:
+         print('StableFast cannot determine version')
+         print('Installing from source...')
+         url = 'git+https://github.com/chengzeyi/stable-fast.git@main#egg=stable-fast'
+     else:
+         print('Installing wheel...')
+         file_url = f'stable_fast-{sf_ver}+torch{torch_ver}{cuda_ver}-cp{python_ver}-cp{python_ver}-{bin_url}'
+         url = f'{repo_url}{path_url}/{file_url}' # path_url already starts with '/'
+
+     ok = install_pip(url)
+     if ok:
+         import sfast
+         print(f'StableFast installed: {sfast.__version__}')
+     else:
+         print('StableFast install failed')
+
+
+ if __name__ == '__main__':
+     install_stable_fast()
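+
+ # for reference, the wheel URL assembled above has this shape (version numbers illustrative):
+ #   https://github.com/chengzeyi/stable-fast/releases/download/nightly/stable_fast-1.0.0+torch220cu121-cp310-cp310-manylinux2014_x86_64.whl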
cli/latents.py ADDED
@@ -0,0 +1,170 @@
+ #!/usr/bin/env python
+
+ import os
+ import sys
+ import json
+ import pathlib
+ import argparse
+ import warnings
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm import tqdm
+ from util import Map
+
+ from rich.pretty import install as pretty_install
+ from rich.traceback import install as traceback_install
+ from rich.console import Console
+
+ console = Console(log_time=True, log_time_format='%H:%M:%S-%f')
+ pretty_install(console=console)
+ traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False)
+
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'modules', 'lora'))
+ import library.model_util as model_util
+ import library.train_util as train_util
+
+ warnings.filterwarnings('ignore')
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ options = Map({
+     'batch': 1,
+     'input': '',
+     'json': '',
+     'max': 1024,
+     'min': 256,
+     'noupscale': False,
+     'precision': 'fp32',
+     'resolution': '512,512',
+     'steps': 64,
+     'vae': 'stabilityai/sd-vae-ft-mse'
+ })
+ vae = None
+
+
+ def get_latents(local_vae, images, weight_dtype):
+     image_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ])
+     img_tensors = [image_transforms(image) for image in images]
+     img_tensors = torch.stack(img_tensors)
+     img_tensors = img_tensors.to(device, weight_dtype)
+     with torch.no_grad():
+         latents = local_vae.encode(img_tensors).latent_dist.sample().float().to('cpu').numpy()
+     return latents, [images[0].shape[0], images[0].shape[1]]
+
+
+ def get_npz_filename_wo_ext(data_dir, image_key):
+     return os.path.join(data_dir, os.path.splitext(os.path.basename(image_key))[0])
+
+
+ def create_vae_latents(local_params):
+     args = Map({**options, **local_params})
+     console.log(f'create vae latents args: {args}')
+     image_paths = train_util.glob_images(args.input)
+     if os.path.exists(args.json):
+         with open(args.json, 'rt', encoding='utf-8') as f:
+             metadata = json.load(f)
+     else:
+         return
+     if args.precision == 'fp16':
+         weight_dtype = torch.float16
+     elif args.precision == 'bf16':
+         weight_dtype = torch.bfloat16
+     else:
+         weight_dtype = torch.float32
+     global vae # pylint: disable=global-statement
+     if vae is None:
+         vae = model_util.load_vae(args.vae, weight_dtype)
+         vae.eval()
+         vae.to(device, dtype=weight_dtype)
+     max_reso = tuple(int(t) for t in args.resolution.split(','))
+     assert len(max_reso) == 2, f'illegal resolution: {args.resolution}'
+     bucket_manager = train_util.BucketManager(args.noupscale, max_reso, args.min, args.max, args.steps)
+     if not args.noupscale:
+         bucket_manager.make_buckets()
+     img_ar_errors = []
+
+     def process_batch(is_last):
+         for bucket in bucket_manager.buckets:
+             if (is_last and len(bucket) > 0) or len(bucket) >= args.batch:
+                 latents, original_size = get_latents(vae, [img for _, img in bucket], weight_dtype)
+                 assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, f'latent shape {latents.shape}, {bucket[0][1].shape}'
+                 for (image_key, _), latent in zip(bucket, latents):
+                     npz_file_name = get_npz_filename_wo_ext(args.input, image_key)
+                     # np.savez(npz_file_name, latent)
+                     kwargs = {}
+                     np.savez(
+                         npz_file_name,
+                         latents=latent,
+                         original_size=np.array(original_size),
+                         crop_ltrb=np.array([0, 0]),
+                         **kwargs,
+                     )
+                 bucket.clear()
+
+     data = [[(None, ip)] for ip in image_paths]
+     bucket_counts = {}
+     for data_entry in tqdm(data, smoothing=0.0):
+         if data_entry[0] is None:
+             continue
+         img_tensor, image_path = data_entry[0]
+         if img_tensor is not None:
+             image = transforms.functional.to_pil_image(img_tensor)
+         else:
+             image = Image.open(image_path)
+         image_key = os.path.join(os.path.basename(pathlib.Path(image_path).parent), pathlib.Path(image_path).stem)
+         if image_key not in metadata:
+             metadata[image_key] = {}
+         reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
+         img_ar_errors.append(abs(ar_error))
+         bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
+         metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
+         if not args.noupscale:
+             assert resized_size[0] == reso[0] or resized_size[1] == reso[1], f'internal error, resized size does not match: {reso}, {resized_size}, {image.width}, {image.height}'
+             assert resized_size[0] >= reso[0] and resized_size[1] >= reso[1], f'internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}'
+         image = np.array(image)
+         if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]:
+             image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
+         if resized_size[0] > reso[0]: # center-crop horizontal overflow
+             trim_size = resized_size[0] - reso[0]
+             image = image[:, trim_size//2:trim_size//2 + reso[0]]
+         if resized_size[1] > reso[1]: # center-crop vertical overflow
+             trim_size = resized_size[1] - reso[1]
+             image = image[trim_size//2:trim_size//2 + reso[1]]
+         assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f'internal error, illegal trimmed size: {image.shape}, {reso}'
+         bucket_manager.add_image(reso, (image_key, image))
+         process_batch(False)
+
+     process_batch(True)
+     vae.to('cpu')
+
+     bucket_manager.sort()
+     img_ar_errors = np.array(img_ar_errors)
+     for i, reso in enumerate(bucket_manager.resos):
+         count = bucket_counts.get(reso, 0)
+         if count > 0:
+             console.log(f'vae latents bucket: {i+1}/{len(bucket_manager.resos)} resolution: {reso} images: {count} mean-ar-error: {np.mean(img_ar_errors)}')
+     with open(args.json, 'wt', encoding='utf-8') as f:
+         json.dump(metadata, f, indent=2)
+
+
+ def unload_vae():
+     global vae # pylint: disable=global-statement
+     vae = None
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('input', type=str, help='directory with train images')
+     parser.add_argument('--json', type=str, required=True, help='metadata file to update')
+     parser.add_argument('--vae', type=str, required=True, help='model name or path used to encode latents')
+     parser.add_argument('--batch', type=int, default=1, help='batch size for inference')
+     parser.add_argument('--resolution', type=str, default='512,512', help='max resolution for fine tuning (width,height)')
+     parser.add_argument('--min', type=int, default=256, help='minimum resolution for buckets')
+     parser.add_argument('--max', type=int, default=1024, help='maximum resolution for buckets')
+     parser.add_argument('--steps', type=int, default=64, help='resolution step for buckets, divisible by 8')
+     parser.add_argument('--noupscale', action='store_true', help='make a bucket for each image without upscaling')
+     parser.add_argument('--precision', type=str, default='fp32', choices=['fp32', 'fp16', 'bf16'], help='inference precision')
+     params = parser.parse_args()
+     create_vae_latents(vars(params))
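+
+ # illustrative invocation (paths are placeholders):
+ #   python cli/latents.py --json train/meta.json --vae stabilityai/sd-vae-ft-mse --batch 4 train/images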
cli/lcm-convert.py ADDED
@@ -0,0 +1,55 @@
+ import os
+ import argparse
+ import torch
+ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoPipelineForText2Image, LCMScheduler
+
+ parser = argparse.ArgumentParser("lcm_convert")
+ parser.add_argument("--name", help="Name of the new LCM model", type=str)
+ parser.add_argument("--model", help="A model to convert", type=str)
+ parser.add_argument("--lora-scale", default=1.0, help="Strength of the LCM LoRA", type=float)
+ parser.add_argument("--huggingface", action="store_true", help="Use Hugging Face models instead of safetensors models")
+ parser.add_argument("--upload", action="store_true", help="Upload the new LCM model to Hugging Face")
+ parser.add_argument("--no-half", action="store_true", help="Convert the new LCM model to FP32")
+ parser.add_argument("--no-save", action="store_true", help="Don't save the new LCM model to local disk")
+ parser.add_argument("--sdxl", action="store_true", help="Use SDXL models")
+ parser.add_argument("--ssd-1b", action="store_true", help="Use SSD-1B models")
+
+ args = parser.parse_args()
+
+ if args.huggingface:
+     pipeline = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.float16, variant="fp16")
+ else:
+     if args.sdxl or args.ssd_1b:
+         pipeline = StableDiffusionXLPipeline.from_single_file(args.model)
+     else:
+         pipeline = StableDiffusionPipeline.from_single_file(args.model)
+
+ pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
+ if args.sdxl:
+     pipeline.load_lora_weights("latent-consistency/lcm-lora-sdxl")
+ elif args.ssd_1b:
+     pipeline.load_lora_weights("latent-consistency/lcm-lora-ssd-1b")
+ else:
+     pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+ pipeline.fuse_lora(lora_scale=args.lora_scale)
+
+ # components = pipeline.components
+ # pipeline = LatentConsistencyModelPipeline(**components)
+
+ if args.no_half:
+     pipeline = pipeline.to(dtype=torch.float32)
+ else:
+     pipeline = pipeline.to(dtype=torch.float16)
+ print(pipeline)
+
+ if not args.no_save:
+     os.makedirs(f"models--local--{args.name}/snapshots", exist_ok=True) # exist_ok so reruns do not fail
+     if args.no_half:
+         pipeline.save_pretrained(f"models--local--{args.name}/snapshots/{args.name}")
+     else:
+         pipeline.save_pretrained(f"models--local--{args.name}/snapshots/{args.name}", variant="fp16")
+ if args.upload:
+     if args.no_half:
+         pipeline.push_to_hub(args.name)
+     else:
+         pipeline.push_to_hub(args.name, variant="fp16")
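+
+ # illustrative invocation (model path is a placeholder):
+ #   python cli/lcm-convert.py --name my-lcm-xl --model models/model.safetensors --sdxl --lora-scale 1.0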
cli/model-jit.py ADDED
@@ -0,0 +1,177 @@
+ #!/usr/bin/env python
+ import os
+ import sys
+ import time
+ import functools
+ import argparse
+ import logging
+ import warnings
+ from dataclasses import dataclass
+
+ logging.getLogger("DeepSpeed").disabled = True
+ warnings.filterwarnings(action="ignore", category=FutureWarning)
+ warnings.filterwarnings(action="ignore", category=DeprecationWarning)
+
+ import torch
+ import diffusers
+
+ n_warmup = 5
+ n_traces = 10
+ n_runs = 100
+ args = {}
+ pipe = None
+ log = logging.getLogger("sd")
+
+
+ def setup_logging():
+     from rich.theme import Theme
+     from rich.logging import RichHandler
+     from rich.console import Console
+     from rich.traceback import install
+     log.setLevel(logging.DEBUG)
+     console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({ "traceback.border": "black", "traceback.border.syntax_error": "black", "inspect.value.border": "black" }))
+     logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
+     rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=logging.DEBUG, console=console)
+     rh.setLevel(logging.DEBUG)
+     log.addHandler(rh)
+     logging.getLogger("diffusers").setLevel(logging.ERROR)
+     logging.getLogger("torch").setLevel(logging.ERROR)
+     warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning)
+     install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
+
+
+ def generate_inputs():
+     if args.type == 'sd15':
+         sample = torch.randn(2, 4, 64, 64).half().cuda()
+         timestep = torch.rand(1).half().cuda() * 999
+         encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
+         return sample, timestep, encoder_hidden_states
+     if args.type == 'sdxl':
+         sample = torch.randn(2, 4, 64, 64).half().cuda()
+         timestep = torch.rand(1).half().cuda() * 999
+         encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
+         text_embeds = torch.randn(1, 77, 2048).half().cuda()
+         return sample, timestep, encoder_hidden_states, text_embeds
+
+
+ def load_model():
+     log.info(f'versions: torch={torch.__version__} diffusers={diffusers.__version__}')
+     diffusers_load_config = {
+         "low_cpu_mem_usage": True,
+         "torch_dtype": torch.float16,
+         "safety_checker": None,
+         "requires_safety_checker": False,
+         "load_safety_checker": False,
+         "load_connected_pipeline": True,
+         "use_safetensors": True,
+     }
+     pipeline = diffusers.StableDiffusionPipeline if args.type == 'sd15' else diffusers.StableDiffusionXLPipeline
+     global pipe # pylint: disable=global-statement
+     t0 = time.time()
+     pipe = pipeline.from_single_file(args.model, **diffusers_load_config).to('cuda')
+     size = os.path.getsize(args.model)
+     log.info(f'load: model={args.model} type={args.type} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
+
+
+ def load_trace(fn: str):
+
+     @dataclass
+     class UNet2DConditionOutput:
+         sample: torch.FloatTensor
+
+     class TracedUNet(torch.nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.in_channels = pipe.unet.in_channels
+             self.device = pipe.unet.device
+
+         def forward(self, latent_model_input, t, encoder_hidden_states):
+             sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+             return UNet2DConditionOutput(sample=sample)
+
+     t0 = time.time()
+     unet_traced = torch.jit.load(fn)
+     pipe.unet = TracedUNet()
+     size = os.path.getsize(fn)
+     log.info(f'load: optimized={fn} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
+
+
+ def trace_model():
+     log.info(f'tracing model: {args.model}')
+     torch.set_grad_enabled(False)
+     unet = pipe.unet
+     unet.eval()
+     # unet.to(memory_format=torch.channels_last) # use channels_last memory format
+     unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
+
+     # warmup
+     t0 = time.time()
+     for _ in range(n_warmup):
+         with torch.inference_mode():
+             inputs = generate_inputs()
+             _output = unet(*inputs)
+     log.info(f'warmup: time={time.time() - t0:.3f}s passes={n_warmup}')
+
+     # trace
+     t0 = time.time()
+     unet_traced = torch.jit.trace(unet, inputs, check_trace=True)
+     unet_traced.eval()
+     log.info(f'trace: time={time.time() - t0:.3f}s')
+
+     # optimize graph
+     t0 = time.time()
+     for _ in range(n_traces):
+         with torch.inference_mode():
+             inputs = generate_inputs()
+             _output = unet_traced(*inputs)
+     log.info(f'optimize: time={time.time() - t0:.3f}s passes={n_traces}')
+
+     # save the model
+     if args.save:
+         t0 = time.time()
+         basename, _ext = os.path.splitext(args.model)
+         fn = f"{basename}.pt"
+         unet_traced.save(fn)
+         size = os.path.getsize(fn)
+         log.info(f'save: optimized={fn} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
+         return fn
+
+     pipe.unet = unet_traced
+     return None
+
+
+ def benchmark_model(msg: str):
+     with torch.inference_mode():
+         inputs = generate_inputs()
+         torch.cuda.synchronize()
+         t0 = time.time()
+         for n in range(n_runs):
+             if n == n_runs // 10:
+                 t0 = time.time() # restart the clock once the first 10% of passes served as warmup
+             _output = pipe.unet(*inputs)
+         torch.cuda.synchronize()
+         t1 = time.time()
+     log.info(f"benchmark unet: {t1 - t0:.3f}s passes={n_runs} type={msg}")
+     return t1 - t0
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='SD.Next')
+     parser.add_argument('--model', type=str, default='', required=True, help='model path')
+     parser.add_argument('--type', type=str, default='sd15', choices=['sd15', 'sdxl'], required=False, help='model type, default: %(default)s')
+     parser.add_argument('--benchmark', default=False, action='store_true', help='run benchmarks, default: %(default)s')
+     parser.add_argument('--trace', default=True, action='store_true', help='run jit tracing, default: %(default)s')
+     parser.add_argument('--save', default=False, action='store_true', help='save optimized unet, default: %(default)s')
+     args = parser.parse_args()
+     setup_logging()
+     log.info('sdnext model jit tracing')
+     if not os.path.isfile(args.model):
+         log.error(f'invalid model path: {args.model}')
+         sys.exit(1)
+     load_model()
+     if args.benchmark:
+         time0 = benchmark_model('original')
+     unet_saved = trace_model()
+     if unet_saved is not None:
+         load_trace(unet_saved)
+     if args.benchmark:
+         time1 = benchmark_model('traced')
+         log.info(f'benchmark speedup: {100 * (time0 - time1) / time0:.3f}%')
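+
+ # illustrative invocation (model path is a placeholder):
+ #   python cli/model-jit.py --model models/Stable-diffusion/v1-5.safetensors --type sd15 --benchmark --save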
cli/model-metadata.py ADDED
@@ -0,0 +1,41 @@
+ #!/usr/bin/env python
+ import os
+ import sys
+ import json
+ from rich import print # pylint: disable=redefined-builtin
+
+
+ def read_metadata(fn):
+     res = {}
+     with open(fn, mode="rb") as f:
+         metadata_len = f.read(8)
+         metadata_len = int.from_bytes(metadata_len, "little")
+         json_start = f.read(2)
+         if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
+             print(f"Not a valid safetensors file: {fn}")
+             return # skip files that do not start with a json header
+         json_data = json_start + f.read(metadata_len - 2)
+         json_obj = json.loads(json_data)
+         for k, v in json_obj.get("__metadata__", {}).items():
+             res[k] = v
+             if isinstance(v, str) and v[0:1] == '{': # some values are json-encoded strings
+                 try:
+                     res[k] = json.loads(v)
+                 except Exception:
+                     pass
+     print(f"{fn}: {json.dumps(res, indent=4)}")
+
+
+ def main():
+     if len(sys.argv) == 0:
+         print('metadata:', 'no files specified')
+     for fn in sys.argv:
+         if os.path.isfile(fn):
+             read_metadata(fn)
+         elif os.path.isdir(fn):
+             for root, _dirs, files in os.walk(fn):
+                 for file in files:
+                     read_metadata(os.path.join(root, file))
+
+
+ if __name__ == '__main__':
+     sys.argv.pop(0)
+     main()
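+
+ # note: the reader above assumes the safetensors container layout: the first 8 bytes are a
+ # little-endian uint64 holding the JSON header length, the header follows immediately, and
+ # user-supplied key/value pairs live under its "__metadata__" entry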
cli/nvidia-smi.py ADDED
@@ -0,0 +1,35 @@
+ #!/usr/bin/env python
+ import os
+ import json
+ import shutil
+ import subprocess
+ import xmltodict
+ from rich import print # pylint: disable=redefined-builtin
+ from util import log, Map
+
+
+ def get_nvidia_smi(output='dict'):
+     smi = shutil.which('nvidia-smi')
+     if smi is None:
+         log.error("nvidia-smi not found")
+         return None
+     result = subprocess.run(f'"{smi}" -q -x', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+     xml = result.stdout.decode(encoding="utf8", errors="ignore")
+     d = xmltodict.parse(xml)
+     if 'nvidia_smi_log' in d:
+         d = d['nvidia_smi_log']
+     if 'gpu' in d and 'supported_clocks' in d['gpu']:
+         del d['gpu']['supported_clocks'] # drop the largest and least useful section
+     if output == 'dict':
+         return d
+     elif output == 'class' or output == 'map':
+         d = Map(d)
+         return d
+     elif output == 'json':
+         return json.dumps(d, indent=4)
+     return None
+
+
+ if __name__ == "__main__":
+     res = get_nvidia_smi(output='dict')
+     print(type(res), res)
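+
+ # the parsed dict mirrors the `nvidia-smi -q -x` XML; typical top-level keys include
+ # driver_version, cuda_version, attached_gpus and gpu, though exact keys vary by driver version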
cli/options.py ADDED
@@ -0,0 +1,141 @@
+ from util import Map
+
+ embedding = Map({
+     "id_task": 0,
+     "embedding_name": "",
+     "learn_rate": -1,
+     "batch_size": 1,
+     "steps": 500,
+     "data_root": "",
+     "log_directory": "train/log",
+     "template_filename": "subject_filewords.txt",
+     "gradient_step": 20,
+     "training_width": 512,
+     "training_height": 512,
+     "shuffle_tags": False,
+     "tag_drop_out": 0,
+     "clip_grad_mode": "disabled",
+     "clip_grad_value": "0.1",
+     "latent_sampling_method": "deterministic",
+     "create_image_every": 0,
+     "save_embedding_every": 0,
+     "save_image_with_stored_embedding": False,
+     "preview_from_txt2img": False,
+     "preview_prompt": "",
+     "preview_negative_prompt": "blurry, duplicate, ugly, deformed, low res, watermark, text",
+     "preview_steps": 20,
+     "preview_sampler_index": 0,
+     "preview_cfg_scale": 6,
+     "preview_seed": -1,
+     "preview_width": 512,
+     "preview_height": 512,
+     "varsize": False,
+     "use_weight": False,
+ })
+
+ lora = Map({
+     "bucket_no_upscale": False,
+     "bucket_reso_steps": 64,
+     "cache_latents": True,
+     "caption_dropout_every_n_epochs": None,
+     "caption_dropout_rate": 0.0,
+     "caption_extension": ".txt",
+     "caption_extention": ".txt", # intentional duplicate: upstream trainers accept this legacy misspelling
+     "caption_tag_dropout_rate": 0.0,
+     "clip_skip": None,
+     "color_aug": False,
+     "dataset_repeats": 1,
+     "debug_dataset": False,
+     "enable_bucket": False,
+     "face_crop_aug_range": None,
+     "flip_aug": False,
+     "full_fp16": False,
+     "gradient_accumulation_steps": 1,
+     "gradient_checkpointing": False,
+     "in_json": "",
+     "keep_tokens": None,
+     "learning_rate": 5e-05,
+     "log_prefix": None,
+     "logging_dir": None,
+     "lr_scheduler_num_cycles": 1,
+     "lr_scheduler_power": 1,
+     "lr_scheduler": "cosine",
+     "lr_warmup_steps": 0,
+     "max_bucket_reso": 1024,
+     "max_data_loader_n_workers": 8,
+     "max_grad_norm": 0.0,
+     "max_token_length": None,
+     "max_train_epochs": None,
+     "max_train_steps": 2500,
+     "mem_eff_attn": False,
+     "min_bucket_reso": 256,
+     "mixed_precision": "fp16",
+     "network_alpha": 1.0,
+     "network_args": None,
+     "network_dim": 16,
+     "network_module": "networks.lora",
+     "network_train_text_encoder_only": False,
+     "network_train_unet_only": False,
+     "network_weights": None,
+     "no_metadata": False,
+     "output_dir": "",
+     "output_name": "",
+     "persistent_data_loader_workers": False,
+     "pretrained_model_name_or_path": "",
+     "prior_loss_weight": 1.0,
+     "random_crop": False,
+     "reg_data_dir": None,
+     "resolution": "512,512",
+     "resume": None,
+     "save_every_n_epochs": None,
+     "save_last_n_epochs_state": None,
+     "save_last_n_epochs": None,
+     "save_model_as": "ckpt",
+     "save_n_epoch_ratio": None,
+     "save_precision": "fp16",
+     "save_state": False,
+     "seed": 42,
+     "shuffle_caption": False,
+     "text_encoder_lr": 5e-05,
+     "train_batch_size": 1,
+     "train_data_dir": "",
+     "training_comment": "",
+     "unet_lr": 1e-04,
+     "use_8bit_adam": False,
+     "v_parameterization": False,
+     "v2": False,
+     "vae": None,
+     "xformers": False,
+ })
+
+ process = Map({
+     # general settings, do not modify
+     'format': '.jpg', # image format
+     'target_size': 512, # target resolution
+     'segmentation_model': 0, # segmentation model 0/general 1/landscape
+     'segmentation_background': (192, 192, 192), # segmentation background color
+     'blur_score': 1.8, # max score for face blur detection
+     'blur_samplesize': 60, # sample size to use for blur detection
+     'similarity_score': 0.8, # maximum similarity score before image is discarded
+     'similarity_size': 64, # base similarity detection on reduced images
+     'range_score': 0.15, # min score for face color dynamic range detection
+     # face processing settings
+     'face_score': 0.7, # min face detection score
+     'face_pad': 0.1, # pad face image percentage
+     'face_model': 1, # which face model to use 0/close-up 1/standard
+     # body processing settings
+     'body_score': 0.9, # min body detection score
+     'body_visibility': 0.5, # min visibility score for each detected body part
+     'body_parts': 15, # min number of detected body parts with sufficient visibility
+     'body_pad': 0.2, # pad body image percentage
+     'body_model': 2, # body model to use 0/low 1/medium 2/high
+     # interrogate settings
+     'interrogate': False, # interrogate images
+     'interrogate_model': ['clip', 'deepdanbooru'], # interrogate models
+     'tag_limit': 5, # number of tags to extract
+     # segmentation settings (tbd)
+     'face_segmentation': False, # segmentation enabled
+     'body_segmentation': False, # segmentation enabled
+ })
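+
+ # these Map objects act as defaults; callers merge per-run overrides on top, e.g. (illustrative):
+ #   opts = Map({**lora, 'output_name': 'my-lora', 'max_train_steps': 1000})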
cli/process.py ADDED
@@ -0,0 +1,327 @@
+ # pylint: disable=global-statement
+ import os
+ import io
+ import math
+ import base64
+ import numpy as np
+ import mediapipe as mp
+ from PIL import Image, ImageOps
+ from pi_heif import register_heif_opener
+ from skimage.metrics import structural_similarity as ssim
+ from scipy.stats import beta
+
+ import util
+ import sdapi
+ import options
+
+ face_model = None
+ body_model = None
+ segmentation_model = None
+ all_images = []
+ all_images_by_type = {}
+
+
+ class Result():
+     def __init__(self, typ: str, fn: str, tag: str = None, requested: list = None):
+         self.type = typ
+         self.input = fn
+         self.output = ''
+         self.basename = ''
+         self.message = ''
+         self.image = None
+         self.caption = ''
+         self.tag = tag
+         self.tags = []
+         self.ops = []
+         self.steps = requested if requested is not None else [] # avoid a shared mutable default
+
+
+ def detect_blur(image: Image):
+     # based on <https://github.com/karthik9319/Blur-Detection/>
+     bw = ImageOps.grayscale(image)
+     cx, cy = image.size[0] // 2, image.size[1] // 2
+     fft = np.fft.fft2(bw)
+     fftShift = np.fft.fftshift(fft)
+     fftShift[cy - options.process.blur_samplesize: cy + options.process.blur_samplesize, cx - options.process.blur_samplesize: cx + options.process.blur_samplesize] = 0
+     fftShift = np.fft.ifftshift(fftShift)
+     recon = np.fft.ifft2(fftShift)
+     magnitude = np.log(np.abs(recon))
+     mean = round(np.mean(magnitude), 2)
+     return mean
+
+
+ def detect_dynamicrange(image: Image):
+     # based on <https://towardsdatascience.com/measuring-enhancing-image-quality-attributes-234b0f250e10>
+     data = np.asarray(image)
+     image = np.float32(data)
+     RGB = [0.299, 0.587, 0.114]
+     height, width = image.shape[:2] # pylint: disable=unsubscriptable-object
+     brightness_image = np.sqrt(image[..., 0] ** 2 * RGB[0] + image[..., 1] ** 2 * RGB[1] + image[..., 2] ** 2 * RGB[2]) # pylint: disable=unsubscriptable-object
+     hist, _ = np.histogram(brightness_image, bins=256, range=(0, 255))
+     img_brightness_pmf = hist / (height * width)
+     dist = beta(2, 2)
+     ys = dist.pdf(np.linspace(0, 1, 256))
+     ref_pmf = ys / np.sum(ys)
+     dot_product = np.dot(ref_pmf, img_brightness_pmf)
+     squared_dist_a = np.sum(ref_pmf ** 2)
+     squared_dist_b = np.sum(img_brightness_pmf ** 2)
+     res = dot_product / math.sqrt(squared_dist_a * squared_dist_b)
+     return round(res, 2)
+
+
+ def detect_similar(image: Image):
+     img = image.resize((options.process.similarity_size, options.process.similarity_size))
+     img = ImageOps.grayscale(img)
+     data = np.array(img)
+     similarity = 0
+     for i in all_images:
+         val = ssim(data, i, data_range=255, channel_axis=None, gradient=False, full=False)
+         if val > similarity:
+             similarity = val
+     all_images.append(data)
+     return similarity
+
+
+ def segmentation(res: Result):
+     global segmentation_model
+     if segmentation_model is None:
+         segmentation_model = mp.solutions.selfie_segmentation.SelfieSegmentation(model_selection=options.process.segmentation_model)
+     data = np.array(res.image)
+     results = segmentation_model.process(data)
+     condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
+     background = np.zeros(data.shape, dtype=np.uint8)
+     background[:] = options.process.segmentation_background
+     data = np.where(condition, data, background) # consider using a joint bilateral filter instead of pure combine
+     segmented = Image.fromarray(data)
+     res.image = segmented
+     res.ops.append('segmentation')
+     return res
+
+
+ def unload():
+     global face_model
+     if face_model is not None:
+         face_model = None
+     global body_model
+     if body_model is not None:
+         body_model = None
+     global segmentation_model
+     if segmentation_model is not None:
+         segmentation_model = None
+
+
+ def encode(img):
+     with io.BytesIO() as stream:
+         img.save(stream, 'JPEG')
+         values = stream.getvalue()
+         encoded = base64.b64encode(values).decode()
+         return encoded
+
+
+ def reset():
+     unload()
+     global all_images_by_type
+     all_images_by_type = {}
+     global all_images
+     all_images = []
+
+
+ def upscale_restore_image(res: Result, upscale: bool = False, restore: bool = False):
+     kwargs = util.Map({
+         'image': encode(res.image),
+         'codeformer_visibility': 0.0,
+         'codeformer_weight': 0.0,
+     })
+     if res.image.width >= options.process.target_size and res.image.height >= options.process.target_size:
+         upscale = False # already large enough
+     if upscale:
+         kwargs.upscaler_1 = 'SwinIR_4x'
+         kwargs.upscaling_resize = 2
+         res.ops.append('upscale')
+     if restore:
+         kwargs.codeformer_visibility = 1.0
+         kwargs.codeformer_weight = 0.2
+         res.ops.append('restore')
+     if upscale or restore:
+         result = sdapi.postsync('/sdapi/v1/extra-single-image', kwargs)
+         if 'image' not in result:
+             res.message = 'failed to upscale/restore image'
+         else:
+             res.image = Image.open(io.BytesIO(base64.b64decode(result['image'])))
+     return res
+
+
+ def interrogate_image(res: Result, tag: str = None):
+     caption = ''
+     tags = []
+     for model in options.process.interrogate_model:
+         json = util.Map({ 'image': encode(res.image), 'model': model })
+         result = sdapi.postsync('/sdapi/v1/interrogate', json)
+         if model == 'clip':
+             caption = result.caption if 'caption' in result else ''
+             caption = caption.split(',')[0].replace(' a ', ' ').strip()
+             if res.tag is not None:
+                 caption = res.tag + ', ' + caption
+         if model == 'deepdanbooru':
+             dd = result.caption if 'caption' in result else ''
+             tags = dd.split(',')
+             tags = [t.replace('(', '').replace(')', '').replace('\\', '').split(':')[0].strip() for t in tags]
+             if res.tag is not None: # prepend the caller-supplied tags
+                 for t in res.tag.split(',')[::-1]:
+                     tags.insert(0, t.strip())
+                 pos = 0 if len(tags) == 0 else 1
+                 tags.insert(pos, caption.split(' ')[1])
+             tags = [t for t in tags if len(t) > 2]
+     if len(tags) > options.process.tag_limit:
+         tags = tags[:options.process.tag_limit]
+     res.caption = caption
+     res.tags = tags
+     res.ops.append('interrogate')
+     return res
+
+
+ def resize_image(res: Result):
+     resized = res.image
+     resized.thumbnail((options.process.target_size, options.process.target_size), Image.Resampling.HAMMING)
+     res.image = resized
+     res.ops.append('resize')
+     return res
+
+
+ def square_image(res: Result):
+     size = max(res.image.width, res.image.height)
+     squared = Image.new('RGB', (size, size))
+     squared.paste(res.image, ((size - res.image.width) // 2, (size - res.image.height) // 2))
+     res.image = squared
+     res.ops.append('square')
+     return res
+
+
+ def process_face(res: Result):
+     res.ops.append('face')
+     global face_model
+     if face_model is None:
+         face_model = mp.solutions.face_detection.FaceDetection(min_detection_confidence=options.process.face_score, model_selection=options.process.face_model)
+     results = face_model.process(np.array(res.image))
+     if results.detections is None:
+         res.message = 'no face detected'
+         res.image = None
+         return res
+     box = results.detections[0].location_data.relative_bounding_box
+     if box.xmin < 0 or box.ymin < 0 or (box.width - box.xmin) > 1 or (box.height - box.ymin) > 1:
+         res.message = 'face out of frame'
+         res.image = None
+         return res
+     x = max(0, (box.xmin - options.process.face_pad / 2) * res.image.width)
+     y = max(0, (box.ymin - options.process.face_pad / 2) * res.image.height)
+     w = min(res.image.width, (box.width + options.process.face_pad) * res.image.width)
+     h = min(res.image.height, (box.height + options.process.face_pad) * res.image.height)
+     res.image = res.image.crop((x, y, x + w, y + h))
+     return res
+
+
+ def process_body(res: Result):
+     res.ops.append('body')
+     global body_model
+     if body_model is None:
+         body_model = mp.solutions.pose.Pose(static_image_mode=True, min_detection_confidence=options.process.body_score, model_complexity=options.process.body_model)
+     results = body_model.process(np.array(res.image))
+     if results.pose_landmarks is None:
+         res.message = 'no body detected'
+         res.image = None
+         return res
+     x0 = [res.image.width * (i.x - options.process.body_pad / 2) for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
+     y0 = [res.image.height * (i.y - options.process.body_pad / 2) for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
+     x1 = [res.image.width * (i.x + options.process.body_pad / 2) for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
+     y1 = [res.image.height * (i.y + options.process.body_pad / 2) for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
+     if len(x0) < options.process.body_parts:
+         res.message = f'insufficient body parts detected: {len(x0)}'
+         res.image = None
+         return res
+     res.image = res.image.crop((max(0, min(x0)), max(0, min(y0)), min(res.image.width, max(x1)), min(res.image.height, max(y1))))
+     return res
+
+
+ def process_original(res: Result):
+     res.ops.append('original')
+     return res
+
+
+ def save_image(res: Result, folder: str):
+     if res.image is None or folder is None:
+         return res
+     all_images_by_type[res.type] = all_images_by_type.get(res.type, 0) + 1
+     res.basename = os.path.basename(res.input).split('.')[0]
+     res.basename = str(all_images_by_type[res.type]).rjust(3, '0') + '-' + res.type + '-' + res.basename
+     res.basename = os.path.join(folder, res.basename)
+     res.output = res.basename + options.process.format
+     res.image.save(res.output)
+     res.image.close()
+     res.ops.append('save')
+     return res
+
+
+ def file(filename: str, folder: str, tag=None, requested=None):
+     requested = requested if requested is not None else [] # avoid a shared mutable default
+     # initialize result object
+     res = Result(fn=filename, typ='unknown', tag=tag, requested=requested)
+     # open image
+     try:
+         register_heif_opener()
+         res.image = Image.open(filename)
+         if res.image.mode == 'RGBA':
+             res.image = res.image.convert('RGB')
+         res.image = ImageOps.exif_transpose(res.image) # rotate image according to EXIF orientation
+     except Exception as e:
+         res.message = f'error opening: {e}'
+         return res
+     # primary steps
+     if 'face' in requested:
+         res.type = 'face'
+         res = process_face(res)
+     elif 'body' in requested:
+         res.type = 'body'
+         res = process_body(res)
+     elif 'original' in requested:
+         res.type = 'original'
+         res = process_original(res)
+     # validation steps
+     if res.image is None:
+         return res
+     if 'blur' in requested:
+         res.ops.append('blur')
+         val = detect_blur(res.image)
+         if val > options.process.blur_score:
+             res.message = f'blur check failed: {val}'
+             res.image = None
+     if 'range' in requested and res.image is not None:
+         res.ops.append('range')
+         val = detect_dynamicrange(res.image)
+         if val < options.process.range_score:
+             res.message = f'dynamic range check failed: {val}'
+             res.image = None
+     if 'similarity' in requested and res.image is not None:
+         res.ops.append('similarity')
+         val = detect_similar(res.image)
+         if val > options.process.similarity_score:
+             res.message = f'similarity check failed: {val}'
+             res.image = None
+     if res.image is None:
+         return res
+     # post processing steps
+     res = upscale_restore_image(res, 'upscale' in requested, 'restore' in requested)
+     if res.image.width < options.process.target_size or res.image.height < options.process.target_size:
+         res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
+         res.image = None
+         return res
+     if 'interrogate' in requested:
+         res = interrogate_image(res, tag)
+     if 'resize' in requested:
+         res = resize_image(res)
+     if 'square' in requested:
+         res = square_image(res)
+     if 'segment' in requested:
+         res = segmentation(res)
+     # finally save image
+     res = save_image(res, folder)
+     return res
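+
+ # illustrative usage from another script (paths and steps are placeholders):
+ #   res = file('inputs/photo.jpg', 'outputs', tag='person', requested=['face', 'blur', 'similarity', 'resize'])
+ #   print(res.output, res.message, res.ops)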
cli/random.json ADDED
@@ -0,0 +1,31 @@
+ {
+     "prompts": [
+         "<style> of <embedding> <place>, high detailed, by <artist>, <suffix>"
+     ],
+     "negative": [
+         "watermark, fog, clouds, blurry, duplicate, deformed, mutation"
+     ],
+     "places": [
+         "standing in the city", "on a spaceship", "in fantasy landscape", "on a shore", "in a forest", "in winter wonderland"
+     ],
+     "embeddings": [
+         "man", "man next to a beautiful girl", "man next to a car", "beautiful girl", "sexy naked girl", "cute girl holding a flower", "beautiful robot",
+         "young korean girl with medium-length white hair", "monster", "pin up girl",
+         "man vlado", "beautiful girl ana", "man lee", "beautiful girl abby"
+     ],
+     "artists": [
+         "John Salminen", "Greg Rutkowski", "Akihiko Yoshida", "Alejandro Burdisio", "Artgerm", "Patrick Brown", "Walt Disney", "Neal Adams", "Jeremy Chong",
+         "Chris Rallis", "Roy Lichtenstein", "Claude Monet", "Jon Whitcomb", "Pablo Picasso", "Raymond Leech", "Tom Lovell", "Noriyoshi Ohrai", "Shingei",
+         "Helmut Newton", "Maciej Kuciara", "Daniel F. Gerhartz", "Stephan Martinière", "Magali Villeneuve", "Carne Griffiths", "Alberto Seveso",
+         "Vincent Van Gogh", "WLOP", "Frank Xavier Leyendecker", "Peter Lindbergh", "Nick Gentry", "Howard Chandler Christy", "Raphael", "Henri Matisse"
+     ],
+     "styles": [
+         "illustration", "painting", "portrait", "photograph", "drawing", "sketch", "pencil sketch", "3d render", "cartoon", "anime", "scribbles", "pop art",
+         "ink painting", "steampunk illustration", "dc comics illustration", "marvel comics", "vray render", "photoillustration", "pixar", "marble sculpture",
+         "bronze sculpture", "christmas theme"
+     ],
+     "suffixes": [
+         "cinematic lighting", "artstation", "fineart", "cinematic", "photorealistic", "soft light", "sharp focus", "bokeh", "dreamlike", "semirealism",
+         "colorful", "black and white", "intricate", "elegant"
+     ]
+ }
cli/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ aiohttp
+ mediapipe
+ extcolors
+ colormap
+ filetype
+ albumentations
+ matplotlib
cli/run-benchmark.py ADDED
@@ -0,0 +1,149 @@
+ #!/usr/bin/env python
+ """
+ sd api txt2img benchmark
+ """
+ import os
+ import asyncio
+ import base64
+ import io
+ import json
+ import time
+ import argparse
+ from PIL import Image
+ import sdapi
+ from util import Map, log
+
+
+ oom = 0
+ args = None
+ options = None
+
+
+ async def txt2img():
+     t0 = time.perf_counter()
+     data = {}
+     try:
+         data = await sdapi.post('/sdapi/v1/txt2img', options)
+     except Exception:
+         return -1
+     if 'error' in data:
+         return -1
+     if 'info' in data:
+         info = Map(json.loads(data['info']))
+     else:
+         return 0
+     log.debug({ 'info': info })
+     if options['batch_size'] != len(data['images']):
+         log.error({ 'requested': options['batch_size'], 'received': len(data['images']) })
+         return 0
+     for i in range(len(data['images'])):
+         data['images'][i] = Image.open(io.BytesIO(base64.b64decode(data['images'][i].split(',', 1)[0])))
+         if args.save:
+             fn = os.path.join(args.save, f'benchmark-{i}-{len(data["images"])}.png')
+             data["images"][i].save(fn)
+             log.debug({ 'save': fn })
+     log.debug({ "images": data["images"] })
+     t1 = time.perf_counter()
+     return t1 - t0
+
+
+ def memstats():
+     mem = sdapi.getsync('/sdapi/v1/memory')
+     cpu = mem.get('ram', 'unavailable')
+     gpu = mem.get('cuda', 'unavailable')
+     if 'active' in gpu:
+         gpu['session'] = gpu.pop('active')
+     if 'reserved' in gpu:
+         gpu.pop('allocated')
+         gpu.pop('reserved')
+         gpu.pop('inactive')
+     if 'events' in gpu:
+         global oom # pylint: disable=global-statement
+         oom = gpu['events']['oom']
+         gpu.pop('events')
+     return cpu, gpu
+
+
+ def gb(val: float):
+     return round(val / 1024 / 1024 / 1024, 2)
+
+
+ async def main():
+     sdapi.quiet = True
+     await sdapi.session()
+     await sdapi.interrupt()
+     ver = await sdapi.get("/sdapi/v1/version")
+     log.info({ 'version': ver })
+     platform = await sdapi.get("/sdapi/v1/platform")
+     log.info({ 'platform': platform })
+     opts = await sdapi.get('/sdapi/v1/options')
+     opts = Map(opts)
+     log.info({ 'model': opts.sd_model_checkpoint })
+     cpu, gpu = memstats()
+     log.info({ 'system': { 'cpu': cpu, 'gpu': gpu } })
+     batch = [1, 1, 2, 4, 8, 12, 16, 24, 32, 48, 64, 96, 128, 192, 256]
+     batch = [b for b in batch if b <= args.maxbatch]
+     log.info({ "batch-sizes": batch })
+     for i in range(len(batch)):
+         if oom > 0:
+             continue
+         options['batch_size'] = batch[i]
+         warmup = await txt2img()
+         ts = await txt2img()
+         if i == 0:
+             ts += warmup
+         if ts > 0.01: # cannot be faster than 10ms per run
+             await asyncio.sleep(0)
+             cpu, gpu = memstats()
+             if i == 0:
+                 log.info({ 'warmup': round(ts, 2) })
+             else:
+                 peak = gpu['system']['used'] # gpu['session']['peak'] if 'session' in gpu else 0
+                 log.info({ 'batch': batch[i], 'its': round(options.steps / (ts / batch[i]), 2), 'img': round(ts / batch[i], 2), 'wall': round(ts, 2), 'peak': gb(peak), 'oom': oom > 0 })
+         else:
+             await asyncio.sleep(10)
+             cpu, gpu = memstats()
+             log.info({ 'batch': batch[i], 'result': 'error', 'gpu': gpu, 'oom': oom > 0 })
+             break
+     if oom > 0:
+         log.info({ 'benchmark': 'ended with oom, restart the server before running again' })
+     await sdapi.close()
+
+
+ if __name__ == '__main__':
+     log.info('run-benchmark')
+     parser = argparse.ArgumentParser(description='run-benchmark')
+     parser.add_argument("--steps", type=int, default=50, required=False, help="number of steps")
+     parser.add_argument("--sampler", type=str, default='Euler a', required=False, help="use specific sampler")
+     parser.add_argument("--prompt", type=str, default='photo of two dice on a table', required=False, help="prompt")
+     parser.add_argument("--negative", type=str, default='foggy, blurry', required=False, help="negative prompt")
+     parser.add_argument("--maxbatch", type=int, default=16, required=False, help="max batch size")
+     parser.add_argument("--width", type=int, default=512, required=False, help="width")
+     parser.add_argument("--height", type=int, default=512, required=False, help="height")
+     parser.add_argument('--debug', default=False, action='store_true', help='debug logging')
+     parser.add_argument('--taesd', default=False, action='store_true', help='use taesd as vae')
+     parser.add_argument("--save", type=str, default='', required=False, help="save images to folder")
+     args = parser.parse_args()
+     if args.debug:
+         log.setLevel('DEBUG')
+     options = Map(
+         {
+             "prompt": args.prompt,
+             "negative_prompt": args.negative,
+             "steps": args.steps,
+             "sampler_name": args.sampler,
+             "width": args.width,
+             "height": args.height,
+             "full_quality": not args.taesd,
+             "cfg_scale": 0,
+             "batch_size": 1,
+             "n_iter": 1,
+             "seed": -1,
+         }
+     )
+     log.info({ "options": options })
+     try:
+         asyncio.run(main())
+     except KeyboardInterrupt:
+         log.warning({ 'interrupted': 'keyboard request' })
+         sdapi.interruptsync()
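+
+ # illustrative invocation against a local server:
+ #   python cli/run-benchmark.py --maxbatch 8 --steps 20 --save outputs/benchmark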
cli/sdapi.py ADDED
@@ -0,0 +1,262 @@
+ #!/usr/bin/env python
2
+ #pylint: disable=redefined-outer-name
3
+ """
4
+ helper methods that creates HTTP session with managed connection pool
5
+ provides async HTTP get/post methods and several helper methods
6
+ """
7
+
8
+ import io
9
+ import os
10
+ import sys
11
+ import ssl
12
+ import base64
13
+ import asyncio
14
+ import logging
15
+ import aiohttp
16
+ import requests
17
+ import urllib3
18
+ from PIL import Image
19
+ from util import Map, log
20
+ from rich import print # pylint: disable=redefined-builtin
21
+
22
+
23
+ sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860") # api url root
24
+ sd_username = os.environ.get('SDAPI_USR', None)
25
+ sd_password = os.environ.get('SDAPI_PWD', None)
26
+
27
+ use_session = True
28
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
29
+ ssl.create_default_context = ssl._create_unverified_context # pylint: disable=protected-access
30
+ timeout = aiohttp.ClientTimeout(total = None, sock_connect = 10, sock_read = None) # default value is 5 minutes, we need longer for training
31
+ sess = None
32
+ quiet = False
33
+ BaseThreadPolicy = asyncio.WindowsSelectorEventLoopPolicy if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy") else asyncio.DefaultEventLoopPolicy
34
+
35
+
36
+ class AnyThreadEventLoopPolicy(BaseThreadPolicy):
37
+ def get_event_loop(self) -> asyncio.AbstractEventLoop:
38
+ try:
39
+ return super().get_event_loop()
40
+ except (RuntimeError, AssertionError):
41
+ loop = self.new_event_loop()
42
+ self.set_event_loop(loop)
43
+ return loop
44
+
45
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
46
+
47
+
48
+ def authsync():
49
+ if sd_username is not None and sd_password is not None:
50
+ return requests.auth.HTTPBasicAuth(sd_username, sd_password)
51
+ return None
52
+
53
+
54
+ def auth():
55
+ if sd_username is not None and sd_password is not None:
56
+ return aiohttp.BasicAuth(sd_username, sd_password)
57
+ return None
58
+
59
+
60
+ async def result(req):
61
+ if req.status != 200:
62
+ if not quiet:
63
+ log.error({ 'request error': req.status, 'reason': req.reason, 'url': req.url })
64
+ if not use_session and sess is not None:
65
+ await sess.close()
66
+ return Map({ 'error': req.status, 'reason': req.reason, 'url': req.url })
67
+ else:
68
+ json = await req.json()
69
+ if isinstance(json, list):
70
+ res = json
71
+ elif json is None:
72
+ res = {}
73
+ else:
74
+ res = Map(json)
75
+ log.debug({ 'request': req.status, 'url': req.url, 'reason': req.reason })
76
+ return res
77
+
78
+
79
+ def resultsync(req: requests.Response):
80
+ if req.status_code != 200:
81
+ if not quiet:
82
+ log.error({ 'request error': req.status_code, 'reason': req.reason, 'url': req.url })
83
+ return Map({ 'error': req.status_code, 'reason': req.reason, 'url': req.url })
84
+ else:
85
+ json = req.json()
86
+ if isinstance(json, list):
87
+ res = json
88
+ elif json is None:
89
+ res = {}
90
+ else:
91
+ res = Map(json)
92
+ log.debug({ 'request': req.status_code, 'url': req.url, 'reason': req.reason })
93
+ return res
94
+
95
+
96
+ async def get(endpoint: str, json: dict = None):
97
+ global sess # pylint: disable=global-statement
98
+ sess = sess if sess is not None else await session()
99
+ try:
100
+ async with sess.get(url=endpoint, json=json, verify_ssl=False) as req:
101
+ res = await result(req)
102
+ return res
103
+ except Exception as err:
104
+ log.error({ 'session': err })
105
+ return {}
106
+
107
+
108
+ def getsync(endpoint: str, json: dict = None):
109
+ try:
110
+ req = requests.get(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
111
+ res = resultsync(req)
112
+ return res
113
+ except Exception as err:
114
+ log.error({ 'session': err })
115
+ return {}
116
+
117
+
118
+ async def post(endpoint: str, json: dict = None):
119
+ global sess # pylint: disable=global-statement
120
+ # sess = sess if sess is not None else await session()
121
+ if sess and not sess.closed:
122
+ await sess.close()
123
+ sess = await session()
124
+ try:
125
+ async with sess.post(url=endpoint, json=json, verify_ssl=False) as req:
126
+ res = await result(req)
127
+ return res
128
+ except Exception as err:
129
+ log.error({ 'session': err })
130
+ return {}
131
+
132
+
133
+ def postsync(endpoint: str, json: dict = None):
134
+ req = requests.post(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
135
+ res = resultsync(req)
136
+ return res
137
+
138
+
139
+ async def interrupt():
140
+ res = await get('/sdapi/v1/progress?skip_current_image=true')
141
+ if 'state' in res and res.state.job_count > 0:
142
+ log.debug({ 'interrupt': res.state })
143
+ res = await post('/sdapi/v1/interrupt')
144
+ await asyncio.sleep(1)
145
+ return res
146
+ else:
147
+ log.debug({ 'interrupt': 'idle' })
148
+ return { 'interrupt': 'idle' }
149
+
150
+
151
+ def interruptsync():
152
+ res = getsync('/sdapi/v1/progress?skip_current_image=true')
153
+ if 'state' in res and res.state.job_count > 0:
154
+ log.debug({ 'interrupt': res.state })
155
+ res = postsync('/sdapi/v1/interrupt')
156
+ return res
157
+ else:
158
+ log.debug({ 'interrupt': 'idle' })
159
+ return { 'interrupt': 'idle' }
160
+
161
+
162
+ async def progress():
163
+ res = await get('/sdapi/v1/progress?skip_current_image=false')
164
+ try:
165
+ if res is not None and res.get('current_image', None) is not None:
166
+ res.current_image = Image.open(io.BytesIO(base64.b64decode(res['current_image'])))
167
+ except Exception:
168
+ pass
169
+ log.debug({ 'progress': res })
170
+ return res
171
+
172
+
173
+ def progresssync():
174
+ res = getsync('/sdapi/v1/progress?skip_current_image=true')
175
+ log.debug({ 'progress': res })
176
+ return res
177
+
178
+
179
+ def get_log():
180
+ res = getsync('/sdapi/v1/log')
181
+ for line in res:
182
+ log.debug(line)
183
+ return res
184
+
185
+
186
+ def get_info():
187
+ import time
188
+ t0 = time.time()
189
+ res = getsync('/sdapi/v1/system-info/status?full=true&refresh=true')
190
+ t1 = time.time()
191
+ print({ 'duration': 1000 * round(t1-t0, 3), **res })
192
+ return res
193
+
194
+
195
+ def options():
196
+ opts = getsync('/sdapi/v1/options')
197
+ flags = getsync('/sdapi/v1/cmd-flags')
198
+ return { 'options': opts, 'flags': flags }
199
+
200
+
201
+ def shutdown():
202
+ try:
203
+ postsync('/sdapi/v1/shutdown')
204
+ except Exception as e:
205
+ log.info({ 'shutdown': e })
206
+
207
+
208
+ async def session():
209
+ global sess # pylint: disable=global-statement
210
+ time = aiohttp.ClientTimeout(total = None, sock_connect = 10, sock_read = None) # default value is 5 minutes, we need longer for training
211
+ sess = aiohttp.ClientSession(timeout = time, base_url = sd_url, auth=auth())
212
+ log.debug({ 'sdapi': 'session created', 'endpoint': sd_url })
213
+ """
214
+ sess = await aiohttp.ClientSession(timeout = timeout).__aenter__()
215
+ try:
216
+ async with sess.get(url = f'{sd_url}/') as req:
217
+ log.debug({ 'sdapi': 'session created', 'endpoint': sd_url })
218
+ except Exception as e:
219
+ log.error({ 'sdapi': e })
220
+ await asyncio.sleep(0)
221
+ await sess.__aexit__(None, None, None)
222
+ sess = None
223
+ return sess
224
+ """
225
+ return sess
226
+
227
+
228
+ async def close():
229
+ if sess is not None:
230
+ await asyncio.sleep(0)
231
+ await sess.close()
232
+ await sess.__aexit__(None, None, None)
233
+ log.debug({ 'sdapi': 'session closed', 'endpoint': sd_url })
234
+
235
+
236
+ if __name__ == "__main__":
237
+ sys.argv.pop(0)
238
+ log.setLevel(logging.DEBUG)
239
+ if 'interrupt' in sys.argv:
240
+ asyncio.run(interrupt())
241
+ elif 'progress' in sys.argv:
242
+ asyncio.run(progress())
243
+ elif 'progresssync' in sys.argv:
244
+ progresssync()
245
+ elif 'options' in sys.argv:
246
+ opt = options()
247
+ log.debug('options')
248
+ import json
249
+ print(json.dumps(opt['options'], indent = 2))
250
+ log.debug('cmd-flags')
251
+ print(json.dumps(opt['flags'], indent = 2))
252
+ elif 'log' in sys.argv:
253
+ get_log()
254
+ elif 'info' in sys.argv:
255
+ get_info()
256
+ elif 'shutdown' in sys.argv:
257
+ shutdown()
258
+ else:
259
+ res = getsync(sys.argv[0])
260
+ print(res)
261
+ asyncio.run(close(), debug=True)
262
+ asyncio.run(asyncio.sleep(0.5))
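As a quick reference, a minimal sketch of driving the synchronous helpers above from another script; this assumes the module is importable as `sdapi` (e.g. from within the cli folder) and that a server is reachable at SDAPI_URL:

import sdapi

# poll generation state and interrupt only when a job is actually running
state = sdapi.progresssync()
if state.get('state', {}).get('job_count', 0) > 0:
    sdapi.postsync('/sdapi/v1/interrupt')
# fetch and print the current server options
print(sdapi.getsync('/sdapi/v1/options'))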
cli/simple-img2img.py ADDED
@@ -0,0 +1,98 @@
1
+ #!/usr/bin/env python
2
+ import os
3
+ import io
4
+ import time
5
+ import base64
6
+ import logging
7
+ import argparse
8
+ import requests
9
+ import urllib3
10
+ from PIL import Image
11
+
12
+ sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
13
+ sd_username = os.environ.get('SDAPI_USR', None)
14
+ sd_password = os.environ.get('SDAPI_PWD', None)
15
+
16
+ logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
17
+ log = logging.getLogger(__name__)
18
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
19
+
20
+ options = {
21
+ "save_images": False,
22
+ "send_images": True,
23
+ }
24
+
25
+
26
+ def auth():
27
+ if sd_username is not None and sd_password is not None:
28
+ return requests.auth.HTTPBasicAuth(sd_username, sd_password)
29
+ return None
30
+
31
+
32
+ def post(endpoint: str, dct: dict = None):
33
+ req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
34
+ if req.status_code != 200:
35
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
36
+ else:
37
+ return req.json()
38
+
39
+
40
+ def encode(f):
41
+ image = Image.open(f)
42
+ if image.mode == 'RGBA':
43
+ image = image.convert('RGB')
44
+ with io.BytesIO() as stream:
45
+ image.save(stream, 'JPEG')
46
+ image.close()
47
+ values = stream.getvalue()
48
+ encoded = base64.b64encode(values).decode()
49
+ return encoded
50
+
51
+
52
+ def generate(args): # pylint: disable=redefined-outer-name
53
+ t0 = time.time()
54
+ if args.model is not None:
55
+ post('/sdapi/v1/options', { 'sd_model_checkpoint': args.model })
56
+ post('/sdapi/v1/reload-checkpoint') # needed if running in api-only to trigger new model load
57
+ options['prompt'] = args.prompt
58
+ options['negative_prompt'] = args.negative
59
+ options['steps'] = int(args.steps)
60
+ options['seed'] = int(args.seed)
61
+ options['sampler_name'] = args.sampler
62
+ options['init_images'] = [encode(args.init)]
63
+ image = Image.open(args.init)
64
+ options['width'] = image.width
65
+ options['height'] = image.height
66
+ image.close()
67
+ if args.mask is not None:
68
+ options['mask'] = encode(args.mask)
69
+ data = post('/sdapi/v1/img2img', options)
70
+ t1 = time.time()
71
+ if 'images' in data:
72
+ for i in range(len(data['images'])):
73
+ b64 = data['images'][i].split(',',1)[0]
74
+ info = data['info']
75
+ image = Image.open(io.BytesIO(base64.b64decode(b64)))
76
+ log.info(f'received image: size={image.size} time={t1-t0:.2f} info="{info}"')
77
+ if args.output:
78
+ image.save(args.output)
79
+ log.info(f'image saved: size={image.size} filename={args.output}')
80
+
81
+ else:
82
+ log.warning(f'no images received: {data}')
83
+
84
+
85
+ if __name__ == "__main__":
86
+ parser = argparse.ArgumentParser(description = 'simple-img2img')
87
+ parser.add_argument('--init', required=True, help='init image')
88
+ parser.add_argument('--mask', required=False, help='mask image')
89
+ parser.add_argument('--prompt', required=False, default='', help='prompt text')
90
+ parser.add_argument('--negative', required=False, default='', help='negative prompt text')
91
+ parser.add_argument('--steps', required=False, default=20, help='number of steps')
92
+ parser.add_argument('--seed', required=False, default=-1, help='initial seed')
93
+ parser.add_argument('--sampler', required=False, default='Euler a', help='sampler name')
94
+ parser.add_argument('--output', required=False, default=None, help='output image file')
95
+ parser.add_argument('--model', required=False, help='model name')
96
+ args = parser.parse_args()
97
+ log.info(f'img2img: {args}')
98
+ generate(args)
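For reference, the same request expressed as a minimal raw-requests sketch; the payload mirrors what the script builds above, and the file name and server address are placeholders:

import base64
import requests

with open('input.jpg', 'rb') as f:
    init = base64.b64encode(f.read()).decode()
payload = {'prompt': 'city at night', 'steps': 20, 'init_images': [init], 'send_images': True, 'save_images': False}
res = requests.post('http://127.0.0.1:7860/sdapi/v1/img2img', json=payload, timeout=300, verify=False)
print(res.status_code, list(res.json().keys()) if res.status_code == 200 else res.reason)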
cli/simple-info.py ADDED
@@ -0,0 +1,57 @@
1
+ #!/usr/bin/env python
2
+ import os
3
+ import time
4
+ import base64
5
+ import logging
6
+ import argparse
7
+ import requests
8
+ import urllib3
9
+
10
+
11
+ sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
12
+ sd_username = os.environ.get('SDAPI_USR', None)
13
+ sd_password = os.environ.get('SDAPI_PWD', None)
14
+
15
+
16
+ logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
17
+ log = logging.getLogger(__name__)
18
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
19
+
20
+
21
+ def auth():
22
+ if sd_username is not None and sd_password is not None:
23
+ return requests.auth.HTTPBasicAuth(sd_username, sd_password)
24
+ return None
25
+
26
+
27
+ def get(endpoint: str, dct: dict = None):
28
+ req = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
29
+ if req.status_code != 200:
30
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
31
+ else:
32
+ return req.json()
33
+
34
+
35
+ def post(endpoint: str, dct: dict = None):
36
+ req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
37
+ if req.status_code != 200:
38
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
39
+ else:
40
+ return req.json()
41
+
42
+
43
+ def info(args): # pylint: disable=redefined-outer-name
44
+ t0 = time.time()
45
+ with open(args.input, 'rb') as f:
46
+ content = f.read()
47
+ data = post('/sdapi/v1/png-info', { 'image': base64.b64encode(content).decode() })
48
+ t1 = time.time()
49
+ log.info(f'received: {data} time={t1-t0:.2f}')
50
+
51
+
52
+ if __name__ == "__main__":
53
+ parser = argparse.ArgumentParser(description = 'simple-info')
54
+ parser.add_argument('--input', required=True, help='input image')
55
+ args = parser.parse_args()
56
+ log.info(f'info: {args}')
57
+ info(args)
cli/simple-mask.py ADDED
@@ -0,0 +1,83 @@
1
+ #!/usr/bin/env python
2
+ import io
3
+ import os
4
+ import time
5
+ import base64
6
+ import logging
7
+ import argparse
8
+ import requests
9
+ import urllib3
10
+ from PIL import Image
11
+
12
+
13
+ sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
14
+ sd_username = os.environ.get('SDAPI_USR', None)
15
+ sd_password = os.environ.get('SDAPI_PWD', None)
16
+
17
+
18
+ logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
19
+ log = logging.getLogger(__name__)
20
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
21
+
22
+
23
+ def auth():
24
+ if sd_username is not None and sd_password is not None:
25
+ return requests.auth.HTTPBasicAuth(sd_username, sd_password)
26
+ return None
27
+
28
+
29
+ def get(endpoint: str, dct: dict = None):
30
+ req = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
31
+ if req.status_code != 200:
32
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
33
+ else:
34
+ return req.json()
35
+
36
+
37
+ def post(endpoint: str, dct: dict = None):
38
+ req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
39
+ if req.status_code != 200:
40
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
41
+ else:
42
+ return req.json()
43
+
44
+
45
+ def info(args): # pylint: disable=redefined-outer-name
46
+ t0 = time.time()
47
+ with open(args.input, 'rb') as f:
48
+ image = base64.b64encode(f.read()).decode()
49
+ if args.mask:
50
+ with open(args.mask, 'rb') as f:
51
+ mask = base64.b64encode(f.read()).decode()
52
+ else:
53
+ mask = None
54
+ options = get('/sdapi/v1/masking')
55
+ log.info(f'options: {options}')
56
+ req = {
57
+ 'image': image,
58
+ 'mask': mask,
59
+ 'type': args.type or 'Composite',
60
+ 'params': { 'auto_mask': 'Grayscale' if mask is None else None },
61
+ }
62
+ data = post('/sdapi/v1/mask', req)
63
+ t1 = time.time()
64
+ if 'mask' in data:
65
+ b64 = data['mask'].split(',',1)[0]
66
+ image = Image.open(io.BytesIO(base64.b64decode(b64)))
67
+ log.info(f'received image: size={image.size} time={t1-t0:.2f}')
68
+ if args.output:
69
+ image.save(args.output)
70
+ log.info(f'saved image: fn={args.output}')
71
+ else:
72
+ log.info(f'received: {data} time={t1-t0:.2f}')
73
+
74
+
75
+ if __name__ == "__main__":
76
+ parser = argparse.ArgumentParser(description = 'simple-mask')
77
+ parser.add_argument('--input', required=True, help='input image')
78
+ parser.add_argument('--mask', required=False, help='input mask')
79
+ parser.add_argument('--type', required=False, help='output mask type')
80
+ parser.add_argument('--output', required=False, help='output image')
81
+ args = parser.parse_args()
82
+ log.info(f'info: {args}')
83
+ info(args)
cli/simple-preprocess.py ADDED
@@ -0,0 +1,76 @@
1
+ #!/usr/bin/env python
2
+ import io
3
+ import os
4
+ import time
5
+ import base64
6
+ import logging
7
+ import argparse
8
+ import requests
9
+ import urllib3
10
+ from PIL import Image
11
+
12
+
13
+ sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
14
+ sd_username = os.environ.get('SDAPI_USR', None)
15
+ sd_password = os.environ.get('SDAPI_PWD', None)
16
+
17
+
18
+ logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
19
+ log = logging.getLogger(__name__)
20
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
21
+
22
+
23
+ def auth():
24
+ if sd_username is not None and sd_password is not None:
25
+ return requests.auth.HTTPBasicAuth(sd_username, sd_password)
26
+ return None
27
+
28
+
29
+ def get(endpoint: str, dct: dict = None):
30
+ req = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
31
+ if req.status_code != 200:
32
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
33
+ else:
34
+ return req.json()
35
+
36
+
37
+ def post(endpoint: str, dct: dict = None):
38
+ req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
39
+ if req.status_code != 200:
40
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
41
+ else:
42
+ return req.json()
43
+
44
+
45
+ def info(args): # pylint: disable=redefined-outer-name
46
+ t0 = time.time()
47
+ with open(args.input, 'rb') as f:
48
+ content = f.read()
49
+ models = get('/sdapi/v1/preprocessors')
50
+ log.info(f'models: {models}')
51
+ req = {
52
+ 'model': args.model or 'Canny',
53
+ 'image': base64.b64encode(content).decode(),
54
+ 'config': { 'low_threshold': 50 },
55
+ }
56
+ data = post('/sdapi/v1/preprocess', req)
57
+ t1 = time.time()
58
+ if 'image' in data:
59
+ b64 = data['image'].split(',',1)[0]
60
+ image = Image.open(io.BytesIO(base64.b64decode(b64)))
61
+ log.info(f'received image: size={image.size} time={t1-t0:.2f}')
62
+ if args.output:
63
+ image.save(args.output)
64
+ log.info(f'saved image: fn={args.output}')
65
+ else:
66
+ log.info(f'received: {data} time={t1-t0:.2f}')
67
+
68
+
69
+ if __name__ == "__main__":
70
+ parser = argparse.ArgumentParser(description = 'simple-preprocess')
71
+ parser.add_argument('--input', required=True, help='input image')
72
+ parser.add_argument('--model', required=True, help='preprocessing model')
73
+ parser.add_argument('--output', required=False, help='output image')
74
+ args = parser.parse_args()
75
+ log.info(f'info: {args}')
76
+ info(args)
cli/simple-txt2img.js ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/env node
2
+
3
+ // simple nodejs script to test sdnext api
4
+
5
+ const fs = require('fs'); // eslint-disable-line no-undef
6
+ const process = require('process'); // eslint-disable-line no-undef
7
+
8
+ const sd_url = process.env.SDAPI_URL || 'http://127.0.0.1:7860';
9
+ const sd_username = process.env.SDAPI_USR;
10
+ const sd_password = process.env.SDAPI_PWD;
11
+ const sd_options = {
12
+ // first pass
13
+ prompt: 'city at night',
14
+ negative_prompt: 'foggy, blurry',
15
+ sampler_name: 'UniPC',
16
+ seed: -1,
17
+ steps: 20,
18
+ batch_size: 1,
19
+ n_iter: 1,
20
+ cfg_scale: 6,
21
+ width: 512,
22
+ height: 512,
23
+ // enable second pass
24
+ enable_hr: true,
25
+ // second pass: upscale
26
+ hr_upscaler: 'SCUNet GAN',
27
+ hr_scale: 2.0,
28
+ // second pass: hires
29
+ hr_force: true,
30
+ hr_second_pass_steps: 20,
31
+ hr_sampler_name: 'UniPC',
32
+ denoising_strength: 0.5,
33
+ // second pass: refiner
34
+ refiner_steps: 5,
35
+ refiner_start: 0.8,
36
+ refiner_prompt: '',
37
+ refiner_negative: '',
38
+ // api return options
39
+ save_images: false,
40
+ send_images: true,
41
+ };
42
+
43
+ async function main() {
44
+ const method = 'POST';
45
+ const headers = new Headers();
46
+ const body = JSON.stringify(sd_options);
47
+ headers.set('Content-Type', 'application/json');
48
+ if (sd_username && sd_password) headers.set('Authorization', `Basic ${btoa(`${sd_username}:${sd_password}`)}`);
49
+ const res = await fetch(`${sd_url}/sdapi/v1/txt2img`, { method, headers, body });
50
+ if (res.status !== 200) {
51
+ console.log('Error', res.status);
52
+ } else {
53
+ const json = await res.json();
54
+ console.log('result:', json.info);
55
+ for (const i in json.images) { // eslint-disable-line guard-for-in
56
+ const f = `/tmp/test-${i}.jpg`;
57
+ fs.writeFileSync(f, atob(json.images[i]), 'binary');
58
+ console.log('image saved:', f);
59
+ }
60
+ }
61
+ }
62
+
63
+ main();
cli/simple-txt2img.py ADDED
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/env python
2
+ import io
3
+ import os
4
+ import time
5
+ import base64
6
+ import logging
7
+ import argparse
8
+ import requests
9
+ import urllib3
10
+ from PIL import Image
11
+
12
+ sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
13
+ sd_username = os.environ.get('SDAPI_USR', None)
14
+ sd_password = os.environ.get('SDAPI_PWD', None)
15
+
16
+ logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
17
+ log = logging.getLogger(__name__)
18
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
19
+
20
+ options = {
21
+ "save_images": False,
22
+ "send_images": True,
23
+ }
24
+
25
+
26
+ def auth():
27
+ if sd_username is not None and sd_password is not None:
28
+ return requests.auth.HTTPBasicAuth(sd_username, sd_password)
29
+ return None
30
+
31
+
32
+ def post(endpoint: str, dct: dict = None):
33
+ req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
34
+ if req.status_code != 200:
35
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
36
+ else:
37
+ return req.json()
38
+
39
+
40
+ def generate(args): # pylint: disable=redefined-outer-name
41
+ t0 = time.time()
42
+ if args.model is not None:
43
+ post('/sdapi/v1/options', { 'sd_model_checkpoint': args.model })
44
+ post('/sdapi/v1/reload-checkpoint') # needed if running in api-only to trigger new model load
45
+ options['prompt'] = args.prompt
46
+ options['negative_prompt'] = args.negative
47
+ options['steps'] = int(args.steps)
48
+ options['seed'] = int(args.seed)
49
+ options['sampler_name'] = args.sampler
50
+ options['width'] = int(args.width)
51
+ options['height'] = int(args.height)
52
+ data = post('/sdapi/v1/txt2img', options)
53
+ t1 = time.time()
54
+ if 'images' in data:
55
+ for i in range(len(data['images'])):
56
+ b64 = data['images'][i].split(',',1)[0]
57
+ image = Image.open(io.BytesIO(base64.b64decode(b64)))
58
+ info = data['info']
59
+ log.info(f'image received: size={image.size} time={t1-t0:.2f} info="{info}"')
60
+ if args.output:
61
+ image.save(args.output)
62
+ log.info(f'image saved: size={image.size} filename={args.output}')
63
+ else:
64
+ log.warning(f'no images received: {data}')
65
+
66
+
67
+ if __name__ == "__main__":
68
+ parser = argparse.ArgumentParser(description = 'simple-txt2img')
69
+ parser.add_argument('--prompt', required=False, default='', help='prompt text')
70
+ parser.add_argument('--negative', required=False, default='', help='negative prompt text')
71
+ parser.add_argument('--width', required=False, default=512, help='image width')
72
+ parser.add_argument('--height', required=False, default=512, help='image height')
73
+ parser.add_argument('--steps', required=False, default=20, help='number of steps')
74
+ parser.add_argument('--seed', required=False, default=-1, help='initial seed')
75
+ parser.add_argument('--sampler', required=False, default='Euler a', help='sampler name')
76
+ parser.add_argument('--output', required=False, default=None, help='output image file')
77
+ parser.add_argument('--model', required=False, help='model name')
78
+ args = parser.parse_args()
79
+ log.info(f'txt2img: {args}')
80
+ generate(args)
cli/simple-upscale.py ADDED
@@ -0,0 +1,90 @@
1
+ #!/usr/bin/env python
2
+ import os
3
+ import io
4
+ import time
5
+ import base64
6
+ import logging
7
+ import argparse
8
+ import requests
9
+ import urllib3
10
+ from PIL import Image
11
+
12
+ sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
13
+ sd_username = os.environ.get('SDAPI_USR', None)
14
+ sd_password = os.environ.get('SDAPI_PWD', None)
15
+
16
+ logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
17
+ log = logging.getLogger(__name__)
18
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
19
+
20
+
21
+ def auth():
22
+ if sd_username is not None and sd_password is not None:
23
+ return requests.auth.HTTPBasicAuth(sd_username, sd_password)
24
+ return None
25
+
26
+
27
+ def get(endpoint: str, dct: dict = None):
28
+ req = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
29
+ if req.status_code != 200:
30
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
31
+ else:
32
+ return req.json()
33
+
34
+
35
+ def post(endpoint: str, dct: dict = None):
36
+ req = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
37
+ if req.status_code != 200:
38
+ return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
39
+ else:
40
+ return req.json()
41
+
42
+
43
+ def encode(f):
44
+ image = Image.open(f)
45
+ if image.mode == 'RGBA':
46
+ image = image.convert('RGB')
47
+ log.info(f'encoding image: {image}')
48
+ with io.BytesIO() as stream:
49
+ image.save(stream, 'JPEG')
50
+ image.close()
51
+ values = stream.getvalue()
52
+ encoded = base64.b64encode(values).decode()
53
+ return encoded
54
+
55
+
56
+ def upscale(args): # pylint: disable=redefined-outer-name
57
+ t0 = time.time()
58
+ # options['mask'] = encode(args.mask)
59
+ upscalers = get('/sdapi/v1/upscalers')
60
+ upscalers = [u['name'] for u in upscalers]
61
+ log.info(f'upscalers: {upscalers}')
62
+ options = {
63
+ "save_images": False,
64
+ "send_images": True,
65
+ 'image': encode(args.input),
66
+ 'upscaler_1': args.upscaler,
67
+ 'resize_mode': 0, # rescale_by
68
+ 'upscaling_resize': args.scale,
69
+
70
+ }
71
+ data = post('/sdapi/v1/extra-single-image', options)
72
+ t1 = time.time()
73
+ if 'image' in data:
74
+ b64 = data['image'].split(',',1)[0]
75
+ image = Image.open(io.BytesIO(base64.b64decode(b64)))
76
+ image.save(args.output)
77
+ log.info(f'received: image={image} file={args.output} time={t1-t0:.2f}')
78
+ else:
79
+ log.warning(f'no images received: {data}')
80
+
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser(description = 'simple-upscale')
84
+ parser.add_argument('--input', required=True, help='input image')
85
+ parser.add_argument('--output', required=True, help='output image')
86
+ parser.add_argument('--upscaler', required=False, default='Nearest', help='upscaler name')
87
+ parser.add_argument('--scale', required=False, default=2, help='upscaler scale')
88
+ args = parser.parse_args()
89
+ log.info(f'upscale: {args}')
90
+ upscale(args)
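The same call can be issued without the CLI wrapper; a minimal sketch reusing the encode() and post() helpers above (the file name is a placeholder and the upscaler name must match one returned by /sdapi/v1/upscalers):

options = {
    'image': encode('photo.jpg'),
    'upscaler_1': 'Nearest',
    'resize_mode': 0,       # 0 = rescale by factor
    'upscaling_resize': 2,  # 2x upscale
}
data = post('/sdapi/v1/extra-single-image', options)
print(list(data.keys()))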
cli/torch-compile.py ADDED
@@ -0,0 +1,99 @@
1
+ #!/usr/bin/env python
2
+ # pylint: disable=cell-var-from-loop
3
+ """
4
+ Test Torch Dynamo functionality and backends
5
+ """
6
+ import json
7
+ import warnings
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torchvision.models import resnet18
12
+
13
+
14
+ print('torch:', torch.__version__)
15
+ try:
16
+ # must be imported explicitly or namespace is not found
17
+ import torch._dynamo as dynamo # pylint: disable=ungrouped-imports
18
+ except Exception as err:
19
+ print('torch without dynamo support', err)
20
+
21
+
22
+ N_ITERS = 20
23
+ torch._dynamo.config.verbose=True # pylint: disable=protected-access
24
+ warnings.filterwarnings('ignore', category=UserWarning) # disable these for now as many backends report tons of them
25
+ # torch.set_float32_matmul_precision('high') # enable to test in fp32
26
+
27
+
28
+ def timed(fn): # returns the result of running `fn()` and the time it took for `fn()` to run in ms using CUDA events
29
+ start = torch.cuda.Event(enable_timing=True)
30
+ end = torch.cuda.Event(enable_timing=True)
31
+ start.record()
32
+ result = fn()
33
+ end.record()
34
+ torch.cuda.synchronize()
35
+ return result, start.elapsed_time(end)
36
+
37
+
38
+ def generate_data(b):
39
+ return (
40
+ torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
41
+ torch.randint(1000, (b,)).cuda(),
42
+ )
43
+
44
+
45
+ def init_model():
46
+ return resnet18().to(torch.float32).cuda()
47
+
48
+
49
+ def evaluate(mod, val):
50
+ return mod(val)
51
+
52
+
53
+ if __name__ == '__main__':
54
+ # first pass, dynamo is going to be slower as it compiles
55
+ model = init_model()
56
+ inp = generate_data(16)[0]
57
+
58
+ # repeat test
59
+ results = {}
60
+ times = []
61
+ print('eager initial eval:', timed(lambda: evaluate(model, inp))[1])
62
+ for _i in range(N_ITERS):
63
+ inp = generate_data(16)[0]
64
+ _res, time = timed(lambda: evaluate(model, inp)) # noqa: B023
65
+ times.append(time)
66
+ results['default'] = np.median(times)
67
+
68
+ print('dynamo available backends:', dynamo.list_backends())
69
+ for backend in dynamo.list_backends():
70
+ try:
71
+ # required before changing backends
72
+ torch._dynamo.reset() # pylint: disable=protected-access
73
+ eval_dyn = dynamo.optimize(backend)(evaluate)
74
+ print('dynamo initial eval:', backend, timed(lambda: eval_dyn(model, inp))[1]) # noqa: B023
75
+ times = []
76
+ for _i in range(N_ITERS):
77
+ inp = generate_data(16)[0]
78
+ _res, time = timed(lambda: eval_dyn(model, inp)) # noqa: B023
79
+ times.append(time)
80
+ results[backend] = np.median(times)
81
+ except Exception as err:
82
+ lines = str(err).split('\n')
83
+ print('dynamo backend failed:', backend, lines[0]) # print just first error line as backtraces can be quite long
84
+ results[backend] = 'error'
85
+
86
+ # print stats
87
+ print(json.dumps(results, indent = 4))
88
+
89
+ """
90
+ Reference: <https://github.com/pytorch/pytorch/blob/4f4b62e4a255708e928445b6502139d5962974fa/docs/source/dynamo/get-started.rst>
91
+ Training & Inference backends:
92
+ dynamo.optimize("inductor") - Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton kernels
93
+ dynamo.optimize("aot_nvfuser") - nvFuser with AotAutograd
94
+ dynamo.optimize("aot_cudagraphs") - cudagraphs with AotAutograd
95
+ Inference-only backends:
96
+ dynamo.optimize("ofi") - Uses Torchscript optimize_for_inference
97
+ dynamo.optimize("fx2trt") - Uses Nvidia TensorRT for inference optimizations
98
+ dynamo.optimize("onnxrt") - Uses ONNXRT for inference on CPU/GPU
99
+ """
cli/train.py ADDED
@@ -0,0 +1,443 @@
1
+ #!/usr/bin/env python
2
+
3
+ """
4
+ Examples:
5
+ - sd15: train.py --type lora --tag girl --comments sdnext --input ~/generative/Input/mia --process original,interrogate,resize --name mia
6
+ - sdxl: train.py --type lora --tag girl --comments sdnext --input ~/generative/Input/mia --process original,interrogate,resize --precision fp32 --optimizer Adafactor --sdxl --name miaxl
7
+ - offline: train.py --type lora --tag girl --comments sdnext --input ~/generative/Input/mia --model /home/vlado/dev/sdnext/models/Stable-diffusion/sdxl/miaanimeSFWNSFWSDXL_v40.safetensors --dir /home/vlado/dev/sdnext/models/Lora/ --precision fp32 --optimizer Adafactor --sdxl --name miaxl
8
+ """
9
+
10
+ # system imports
11
+ import os
12
+ import re
13
+ import gc
14
+ import sys
15
+ import json
16
+ import shutil
17
+ import pathlib
18
+ import asyncio
19
+ import logging
20
+ import tempfile
21
+ import argparse
22
+
23
+ # local imports
24
+ import util
25
+ import sdapi
26
+ import options
27
+
28
+
29
+ # globals
30
+ args = None
31
+ log = logging.getLogger('train')
32
+ valid_steps = ['original', 'face', 'body', 'blur', 'range', 'upscale', 'restore', 'interrogate', 'resize', 'square', 'segment']
33
+ log_file = os.path.join(os.path.dirname(__file__), 'train.log')
34
+ server_ok = False
35
+
36
+ # methods
37
+
38
+ def setup_logging():
39
+ from rich.theme import Theme
40
+ from rich.logging import RichHandler
41
+ from rich.console import Console
42
+ from rich.pretty import install as pretty_install
43
+ from rich.traceback import install as traceback_install
44
+ console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
45
+ "traceback.border": "black",
46
+ "traceback.border.syntax_error": "black",
47
+ "inspect.value.border": "black",
48
+ }))
49
+ # logging.getLogger("urllib3").setLevel(logging.ERROR)
50
+ # logging.getLogger("httpx").setLevel(logging.ERROR)
51
+ level = logging.DEBUG if args.debug else logging.INFO
52
+ logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', filename=log_file, filemode='a', encoding='utf-8', force=True)
53
+ log.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
54
+ pretty_install(console=console)
55
+ traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
56
+ rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
57
+ rh.set_name(level)
58
+ while log.hasHandlers() and len(log.handlers) > 0:
59
+ log.removeHandler(log.handlers[0])
60
+ log.addHandler(rh)
61
+
62
+
63
+ def mem_stats():
64
+ gc.collect()
65
+ import torch
66
+ if torch.cuda.is_available():
67
+ with torch.no_grad():
68
+ torch.cuda.empty_cache()
69
+ with torch.cuda.device('cuda'):
70
+ torch.cuda.empty_cache()
71
+ torch.cuda.ipc_collect()
72
+ mem = util.get_memory()
73
+ peak = { 'active': mem['gpu-active']['peak'], 'allocated': mem['gpu-allocated']['peak'], 'reserved': mem['gpu-reserved']['peak'] }
74
+ log.debug(f"memory cpu: {mem.ram} gpu current: {mem.gpu} gpu peak: {peak}")
75
+
76
+
77
+ def parse_args():
78
+ global args # pylint: disable=global-statement
79
+ parser = argparse.ArgumentParser(description = 'SD.Next Train')
80
+
81
+ group_server = parser.add_argument_group('Server')
82
+ group_server.add_argument('--server', type=str, default='http://127.0.0.1:7860', required=False, help='server url, default: %(default)s')
83
+ group_server.add_argument('--user', type=str, default=None, required=False, help='server username, default: %(default)s')
84
+ group_server.add_argument('--password', type=str, default=None, required=False, help='server password, default: %(default)s')
85
+ group_server.add_argument('--dir', type=str, default=None, required=False, help='folder with trained networks, default: use server setting')
86
+
87
+ group_main = parser.add_argument_group('Main')
88
+ group_main.add_argument('--type', type=str, choices=['embedding', 'ti', 'lora', 'lyco', 'dreambooth', 'hypernetwork'], default=None, required=True, help='training type')
89
+ group_main.add_argument('--model', type=str, default='', required=False, help='base model to use for training, default: current loaded model')
90
+ group_main.add_argument('--name', type=str, default=None, required=True, help='output filename')
91
+ group_main.add_argument('--tag', type=str, default='person', required=False, help='primary tags, default: %(default)s')
92
+ group_main.add_argument('--comments', type=str, default='', required=False, help='comments to be added to trained model metadata, default: %(default)s')
93
+
94
+ group_data = parser.add_argument_group('Dataset')
95
+ group_data.add_argument('--input', type=str, default=None, required=True, help='input folder with training images')
96
+ group_data.add_argument('--interim', type=str, default='', required=False, help='where to store processed images, default is system temp/train')
97
+ group_data.add_argument('--process', type=str, default='original,interrogate,resize,square', required=False, help=f'list of possible processing steps: {valid_steps}, default: %(default)s')
98
+
99
+ group_train = parser.add_argument_group('Train')
100
+ group_train.add_argument('--gradient', type=int, default=1, required=False, help='gradient accumulation steps, default: %(default)s')
101
+ group_train.add_argument('--steps', type=int, default=2500, required=False, help='training steps, default: %(default)s')
102
+ group_train.add_argument('--batch', type=int, default=1, required=False, help='batch size, default: %(default)s')
103
+ group_train.add_argument('--lr', type=float, default=1e-04, required=False, help='model learning rate, default: %(default)s')
104
+ group_train.add_argument('--dim', type=int, default=32, required=False, help='network dimension or number of vectors, default: %(default)s')
105
+
106
+ # lora params
107
+ group_train.add_argument('--repeats', type=int, default=1, required=False, help='number of repeats per image, default: %(default)s')
108
+ group_train.add_argument('--alpha', type=float, default=0, required=False, help='lora/lyco alpha for weights scaling, default: dim/2')
109
+ group_train.add_argument('--algo', type=str, default=None, choices=['locon', 'loha', 'lokr', 'ia3'], required=False, help='alternative lyco algorithm, default: %(default)s')
110
+ group_train.add_argument('--args', type=str, default=None, required=False, help='lora/lyco additional network arguments, default: %(default)s')
111
+ group_train.add_argument('--optimizer', type=str, default='AdamW', required=False, help='optimizer type, default: %(default)s')
112
+ group_train.add_argument('--precision', type=str, choices=['fp16', 'fp32'], default='fp16', required=False, help='training precision, default: %(default)s')
113
+ group_train.add_argument('--sdxl', default = False, action='store_true', help = "run sdxl training, default: %(default)s")
114
+ # AdamW (default), AdamW8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor
115
+
116
+ group_other = parser.add_argument_group('Other')
117
+ group_other.add_argument('--overwrite', default = False, action='store_true', help = "overwrite existing training, default: %(default)s")
118
+ group_other.add_argument('--experimental', default = False, action='store_true', help = "enable experimental options, default: %(default)s")
119
+ group_other.add_argument('--debug', default = False, action='store_true', help = "enable debug level logging, default: %(default)s")
120
+
121
+ args = parser.parse_args()
122
+
123
+
124
+ def prepare_server():
125
+ global server_ok # pylint: disable=global-statement
126
+ try:
127
+ server_status = util.Map(sdapi.progresssync())
128
+ server_state = server_status['state']
129
+ server_ok = True
130
+ except Exception as e:
131
+ log.warning(f'sdnext server error: {e}')
132
+ server_ok = False
133
+ if server_ok and server_state['job_count'] > 0:
134
+ log.error(f'sdnext server not idle: {server_state}')
135
+ exit(1)
136
+ if server_ok:
137
+ server_options = util.Map(sdapi.options())
138
+ server_options.options.save_training_settings_to_txt = False
139
+ server_options.options.training_enable_tensorboard = False
140
+ server_options.options.training_tensorboard_save_images = False
141
+ server_options.options.pin_memory = True
142
+ server_options.options.save_optimizer_state = False
143
+ server_options.options.training_image_repeats_per_epoch = args.repeats
144
+ server_options.options.training_write_csv_every = 0
145
+ sdapi.postsync('/sdapi/v1/options', server_options.options)
146
+ log.info('updated server options')
147
+
148
+
149
+ def verify_args():
150
+ server_options = util.Map(sdapi.options())
151
+ if args.model != '':
152
+ if not os.path.isfile(args.model):
153
+ log.error(f'cannot find loaded model: {args.model}')
154
+ exit(1)
155
+ if server_ok:
156
+ server_options.options.sd_model_checkpoint = args.model
157
+ sdapi.postsync('/sdapi/v1/options', server_options.options)
158
+ elif server_ok:
159
+ args.model = server_options.options.sd_model_checkpoint.split(' [')[0]
160
+ if args.sdxl and (server_options.options.sd_backend != 'diffusers' or server_options.options.diffusers_pipeline != 'Stable Diffusion XL'):
161
+ log.warning('server checkpoint is not sdxl')
162
+ else:
163
+ log.error('no model specified')
164
+ exit(1)
165
+ base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
166
+ if args.type == 'lora' and not server_ok and not args.dir:
167
+ log.error('offline lora training requires --dir <lora folder>')
168
+ exit(1)
169
+ if args.type == 'lora':
170
+ import transformers
171
+ if transformers.__version__ != '4.30.2':
172
+ log.error(f'lora training requires specific transformers version: current {transformers.__version__} required transformers==4.30.2')
173
+ exit(1)
174
+ args.lora_dir = server_options.options.lora_dir or args.dir
175
+ if not os.path.isabs(args.lora_dir):
176
+ args.lora_dir = os.path.join(base_dir, args.lora_dir)
177
+ args.lyco_dir = server_options.options.lyco_dir or args.dir
178
+ if not os.path.isabs(args.lyco_dir):
179
+ args.lyco_dir = os.path.join(base_dir, args.lyco_dir)
180
+ args.embeddings_dir = server_options.options.embeddings_dir or args.dir
181
+ if not os.path.isfile(args.model):
182
+ args.ckpt_dir = server_options.options.ckpt_dir
183
+ if not os.path.isabs(args.ckpt_dir):
184
+ args.ckpt_dir = os.path.join(base_dir, args.ckpt_dir)
185
+ attempt = os.path.abspath(os.path.join(args.ckpt_dir, args.model))
186
+ args.model = attempt if os.path.isfile(attempt) else args.model
187
+ if not os.path.isfile(args.model):
188
+ attempt = os.path.abspath(os.path.join(args.ckpt_dir, args.model + '.safetensors'))
189
+ args.model = attempt if os.path.isfile(attempt) else args.model
190
+ if not os.path.isfile(args.model):
191
+ log.error(f'cannot find loaded model: {args.model}')
192
+ exit(1)
193
+ if not os.path.exists(args.input) or not os.path.isdir(args.input):
194
+ log.error(f'cannot find training folder: {args.input}')
195
+ exit(1)
196
+ if not os.path.exists(args.lora_dir) or not os.path.isdir(args.lora_dir):
197
+ log.error(f'cannot find lora folder: {args.lora_dir}')
198
+ exit(1)
199
+ if not os.path.exists(args.lyco_dir) or not os.path.isdir(args.lyco_dir):
200
+ log.error(f'cannot find lyco folder: {args.lyco_dir}')
201
+ exit(1)
202
+ if args.interim != '':
203
+ args.process_dir = args.interim
204
+ else:
205
+ args.process_dir = os.path.join(tempfile.gettempdir(), 'train', args.name)
206
+ log.debug(f'args: {vars(args)}')
207
+ log.debug(f'server flags: {server_options.flags}')
208
+ log.debug(f'server options: {server_options.options}')
209
+
210
+
211
+ async def training_loop():
212
+ async def async_train():
213
+ res = await sdapi.post('/sdapi/v1/train/embedding', options.embedding)
214
+ log.info(f'train embedding result: {res}')
215
+
216
+ async def async_monitor():
217
+ from tqdm.rich import tqdm
218
+ await asyncio.sleep(3)
219
+ res = util.Map(sdapi.progresssync())
220
+ with tqdm(desc='train embedding', total=res.state.job_count) as pbar:
221
+ while res.state.job_no < res.state.job_count and not res.state.interrupted and not res.state.skipped:
222
+ await asyncio.sleep(2)
223
+ prev_job = res.state.job_no
224
+ res = util.Map(sdapi.progresssync())
225
+ loss = re.search(r"Loss: (.*?)(?=\<)", res.textinfo)
226
+ if loss:
227
+ pbar.set_postfix({ 'loss': loss.group(0) })
228
+ pbar.update(res.state.job_no - prev_job)
229
+
230
+ a = asyncio.create_task(async_train())
231
+ b = asyncio.create_task(async_monitor())
232
+ await asyncio.gather(a, b) # wait for both pipeline and monitor to finish
233
+
234
+
235
+ def train_embedding():
236
+ log.info(f'{args.type} options: {options.embedding}')
237
+ create_options = util.Map({
238
+ "name": args.name,
239
+ "num_vectors_per_token": args.dim,
240
+ "overwrite_old": False,
241
+ "init_text": args.tag,
242
+ })
243
+ fn = os.path.join(args.embeddings_dir, args.name) + '.pt'
244
+ if os.path.exists(fn) and args.overwrite:
245
+ log.warning(f'delete existing embedding {fn}')
246
+ os.remove(fn)
247
+ else:
248
+ log.error(f'embedding exists {fn}')
249
+ return
250
+ log.info(f'create embedding {create_options}')
251
+ res = sdapi.postsync('/sdapi/v1/create/embedding', create_options)
252
+ if 'info' in res and 'error' in res['info']: # formatted error
253
+ log.error(res.info)
254
+ elif 'info' in res: # no error
255
+ asyncio.run(training_loop())
256
+ else: # unknown error
257
+ log.error(f'create embedding error {res}')
258
+
259
+
260
+ def train_lora():
261
+ fn = os.path.join(options.lora.output_dir, args.name)
262
+ for ext in ['.ckpt', '.pt', '.safetensors']:
263
+ if os.path.exists(fn + ext):
264
+ if args.overwrite:
265
+ log.warning(f'delete existing lora: {fn + ext}')
266
+ os.remove(fn + ext)
267
+ else:
268
+ log.error(f'lora exists: {fn + ext}')
269
+ return
270
+ log.info(f'{args.type} options: {options.lora}')
271
+ # lora imports
272
+ lora_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'modules', 'lora'))
273
+ lycoris_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'modules', 'lycoris'))
274
+ sys.path.append(lora_path)
275
+ if args.type == 'lyco':
276
+ sys.path.append(lycoris_path)
277
+ log.debug('importing lora lib')
278
+ if not args.sdxl:
279
+ import train_network
280
+ trainer = train_network.NetworkTrainer()
281
+ trainer.train(options.lora)
282
+ else:
283
+ import sdxl_train_network
284
+ trainer = sdxl_train_network.SdxlNetworkTrainer()
285
+ trainer.train(options.lora)
286
+ if args.type == 'lyco':
287
+ log.debug('importing lycoris lib')
288
+ import importlib
289
+ _network_module = importlib.import_module(options.lora.network_module)
290
+
291
+
292
+ def prepare_options():
293
+ if args.type == 'embedding':
294
+ log.info('train embedding')
295
+ options.lora.in_json = None
296
+ if args.type == 'dreambooth':
297
+ log.info('train using dreambooth style training')
298
+ options.lora.vae_batch_size = args.batch
299
+ options.lora.in_json = None
300
+ if args.type == 'lora':
301
+ log.info('train using lora style training')
302
+ options.lora.output_dir = args.lora_dir
303
+ options.lora.in_json = os.path.join(args.process_dir, args.name + '.json')
304
+ if args.type == 'lyco':
305
+ log.info('train using lycoris network')
306
+ options.lora.output_dir = args.lora_dir
307
+ options.lora.network_module = 'lycoris.kohya'
308
+ options.lora.in_json = os.path.join(args.process_dir, args.name + '.json')
309
+ # lora specific
310
+ options.lora.save_model_as = 'safetensors'
311
+ options.lora.pretrained_model_name_or_path = args.model
312
+ options.lora.output_name = args.name
313
+ options.lora.max_train_steps = args.steps
314
+ options.lora.network_dim = args.dim
315
+ options.lora.network_alpha = args.dim // 2 if args.alpha == 0 else args.alpha
316
+ options.lora.network_args = []
317
+ options.lora.training_comment = args.comments
318
+ options.lora.sdpa = True
319
+ options.lora.optimizer_type = args.optimizer
320
+ if args.algo is not None:
321
+ options.lora.network_args.append(f'algo={args.algo}')
322
+ if args.args is not None:
323
+ for net_arg in args.args:
324
+ options.lora.network_args.append(net_arg)
325
+ options.lora.gradient_accumulation_steps = args.gradient
326
+ options.lora.learning_rate = args.lr
327
+ options.lora.train_batch_size = args.batch
328
+ options.lora.train_data_dir = args.process_dir
329
+ options.lora.no_half_vae = args.precision == 'fp16'
330
+ # embedding specific
331
+ options.embedding.embedding_name = args.name
332
+ options.embedding.learn_rate = str(args.lr)
333
+ options.embedding.batch_size = args.batch
334
+ options.embedding.steps = args.steps
335
+ options.embedding.data_root = args.process_dir
336
+ options.embedding.log_directory = os.path.join(args.process_dir, 'log')
337
+ options.embedding.gradient_step = args.gradient
338
+
339
+
340
+ def process_inputs():
341
+ import process
342
+ import filetype
343
+ pathlib.Path(args.process_dir).mkdir(parents=True, exist_ok=True)
344
+ processing_options = [opt.strip() for opt in re.split(',| ', args.process)]
346
+ log.info(f'processing steps: {processing_options}')
347
+ for step in processing_options:
348
+ if step not in valid_steps:
349
+ log.error(f'invalid processing step: {[step]}')
350
+ exit(1)
351
+ for root, _sub_dirs, folder in os.walk(args.input):
352
+ files = [os.path.join(root, f) for f in folder if filetype.is_image(os.path.join(root, f))]
353
+ log.info(f'processing input images: {len(files)}')
354
+ if os.path.exists(args.process_dir):
355
+ if args.overwrite:
356
+ log.warning(f'removing existing processed folder: {args.process_dir}')
357
+ shutil.rmtree(args.process_dir, ignore_errors=True)
358
+ else:
359
+ log.info(f'processed folder exists: {args.process_dir}')
360
+ steps = [step for step in processing_options if step in ['face', 'body', 'original']]
361
+ process.reset()
362
+ options.process.target_size = 1024 if args.sdxl else 512
363
+ metadata = {}
364
+ for step in steps:
365
+ if step == 'face':
366
+ opts = [step for step in processing_options if step not in ['body', 'original']]
367
+ if step == 'body':
368
+ opts = [step for step in processing_options if step not in ['face', 'original', 'upscale', 'restore']] # body does not perform upscale or restore
369
+ if step == 'original':
370
+ opts = [step for step in processing_options if step not in ['face', 'body', 'upscale', 'restore', 'blur', 'range', 'segment']] # original does not perform most steps
371
+ log.info(f'processing current step: {opts}')
372
+ tag = step
373
+ if tag == 'original' and args.tag is not None:
374
+ concept = args.tag.split(',')[0].strip()
375
+ else:
376
+ concept = step
377
+ if args.type in ['lora', 'lyco', 'dreambooth']:
378
+ folder = os.path.join(args.process_dir, str(args.repeats) + '_' + concept) # separate concepts per folder
379
+ if args.type in ['embedding']:
380
+ folder = os.path.join(args.process_dir) # everything into same folder
381
+ log.info(f'processing concept: {concept}')
382
+ log.info(f'processing output folder: {folder}')
383
+ pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
384
+ results = {}
385
+ if server_ok:
386
+ for f in files:
387
+ res = process.file(filename = f, folder = folder, tag = args.tag, requested = opts)
388
+ if res.image: # valid result
389
+ results[res.type] = results.get(res.type, 0) + 1
390
+ results['total'] = results.get('total', 0) + 1
391
+ rel_path = res.basename.replace(os.path.commonpath([res.basename, args.process_dir]), '')
392
+ if rel_path.startswith(os.path.sep):
393
+ rel_path = rel_path[1:]
394
+ metadata[rel_path] = { 'caption': res.caption, 'tags': ','.join(res.tags) }
395
+ if options.lora.in_json is None:
396
+ with open(res.output.replace(options.process.format, '.txt'), "w", encoding='utf-8') as outfile:
397
+ outfile.write(res.caption)
398
+ log.info(f"processing {'saved' if res.image is not None else 'skipped'}: {f} => {res.output} {res.ops} {res.message}")
399
+ else:
400
+ log.info('processing skipped: offline')
401
+ folders = [os.path.join(args.process_dir, folder) for folder in os.listdir(args.process_dir) if os.path.isdir(os.path.join(args.process_dir, folder))]
402
+ log.info(f'input datasets {folders}')
403
+ if options.lora.in_json is not None:
404
+ with open(options.lora.in_json, "w", encoding='utf-8') as outfile: # write json at the end only
405
+ outfile.write(json.dumps(metadata, indent=2))
406
+ for folder in folders: # create latents
407
+ import latents
408
+ latents.create_vae_latents(util.Map({ 'input': folder, 'json': options.lora.in_json }))
409
+ latents.unload_vae()
410
+ r = { 'inputs': len(files), 'outputs': results, 'metadata': options.lora.in_json }
411
+ log.info(f'processing steps result: {r}')
412
+ if args.gradient < 0:
413
+ log.info(f"setting gradient accumulation to number of images: {results['total']}")
414
+ options.lora.gradient_accumulation_steps = results['total']
415
+ options.embedding.gradient_step = results['total']
416
+ process.unload()
417
+
418
+
419
+ if __name__ == '__main__':
420
+ parse_args()
421
+ setup_logging()
422
+ log.info('SD.Next Train')
423
+ sdapi.sd_url = args.server
424
+ if args.user is not None:
425
+ sdapi.sd_username = args.user
426
+ if args.password is not None:
427
+ sdapi.sd_password = args.password
428
+ prepare_server()
429
+ verify_args()
430
+ prepare_options()
431
+ mem_stats()
432
+ process_inputs()
433
+ mem_stats()
434
+ try:
435
+ if args.type == 'embedding':
436
+ train_embedding()
437
+ if args.type == 'lora' or args.type == 'lyco' or args.type == 'dreambooth':
438
+ train_lora()
439
+ except KeyboardInterrupt:
440
+ log.error('interrupt requested')
441
+ sdapi.interruptsync()
442
+ mem_stats()
443
+ log.info('done')
cli/util.py ADDED
@@ -0,0 +1,113 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ generic helper methods
4
+ """
5
+
6
+ import os
7
+ import string
8
+ import logging
9
+ import warnings
10
+
11
+ log_format = '%(asctime)s %(levelname)s: %(message)s'
12
+ logging.basicConfig(level = logging.INFO, format = log_format)
13
+ warnings.filterwarnings(action="ignore", category=DeprecationWarning)
14
+ warnings.filterwarnings(action="ignore", category=FutureWarning)
15
+ warnings.filterwarnings(action="ignore", category=UserWarning)
16
+ log = logging.getLogger("sd")
17
+
18
+
19
+ def set_logfile(logfile):
20
+ fh = logging.FileHandler(logfile)
21
+ formatter = logging.Formatter(log_format)
22
+ fh.setLevel(log.getEffectiveLevel())
23
+ fh.setFormatter(formatter)
24
+ log.addHandler(fh)
25
+ log.info({ 'log file': logfile })
26
+
27
+
28
+ def safestring(text: str):
29
+ lines = []
30
+ for line in text.splitlines():
31
+ lines.append(line.translate(str.maketrans('', '', string.punctuation)).strip())
32
+ res = ', '.join(lines)
33
+ return res[:1000]
34
+
35
+
36
+ def get_memory():
37
+ def gb(val: float):
38
+ return round(val / 1024 / 1024 / 1024, 2)
39
+ mem = {}
40
+ try:
41
+ import psutil
42
+ process = psutil.Process(os.getpid())
43
+ res = process.memory_info()
44
+ ram_total = 100 * res.rss / process.memory_percent()
45
+ ram = { 'free': gb(ram_total - res.rss), 'used': gb(res.rss), 'total': gb(ram_total) }
46
+ mem.update({ 'ram': ram })
47
+ except Exception as e:
48
+ mem.update({ 'ram': e })
49
+ try:
50
+ import torch
51
+ if torch.cuda.is_available():
52
+ s = torch.cuda.mem_get_info()
53
+ gpu = { 'free': gb(s[0]), 'used': gb(s[1] - s[0]), 'total': gb(s[1]) }
54
+ s = dict(torch.cuda.memory_stats('cuda'))
55
+ allocated = { 'current': gb(s['allocated_bytes.all.current']), 'peak': gb(s['allocated_bytes.all.peak']) }
56
+ reserved = { 'current': gb(s['reserved_bytes.all.current']), 'peak': gb(s['reserved_bytes.all.peak']) }
57
+ active = { 'current': gb(s['active_bytes.all.current']), 'peak': gb(s['active_bytes.all.peak']) }
58
+ inactive = { 'current': gb(s['inactive_split_bytes.all.current']), 'peak': gb(s['inactive_split_bytes.all.peak']) }
59
+ events = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
60
+ mem.update({
61
+ 'gpu': gpu,
62
+ 'gpu-active': active,
63
+ 'gpu-allocated': allocated,
64
+ 'gpu-reserved': reserved,
65
+ 'gpu-inactive': inactive,
66
+ 'events': events,
67
+ })
68
+ except Exception:
69
+ pass
70
+ return Map(mem)
71
+
72
+
73
+ class Map(dict): # pylint: disable=C0205
74
+ __slots__ = ('__dict__') # pylint: disable=superfluous-parens
75
+ def __init__(self, *args, **kwargs):
76
+ super(Map, self).__init__(*args, **kwargs) # pylint: disable=super-with-arguments
77
+ for arg in args:
78
+ if isinstance(arg, dict):
79
+ for k, v in arg.items():
80
+ if isinstance(v, dict):
81
+ v = Map(v)
82
+ if isinstance(v, list):
83
+ self.__convert(v)
84
+ self[k] = v
85
+ if kwargs:
86
+ for k, v in kwargs.items():
87
+ if isinstance(v, dict):
88
+ v = Map(v)
89
+ elif isinstance(v, list):
90
+ self.__convert(v)
91
+ self[k] = v
92
+ def __convert(self, v):
93
+ for elem in range(0, len(v)): # pylint: disable=consider-using-enumerate
94
+ if isinstance(v[elem], dict):
95
+ v[elem] = Map(v[elem])
96
+ elif isinstance(v[elem], list):
97
+ self.__convert(v[elem])
98
+ def __getattr__(self, attr):
99
+ return self.get(attr)
100
+ def __setattr__(self, key, value):
101
+ self.__setitem__(key, value)
102
+ def __setitem__(self, key, value):
103
+ super(Map, self).__setitem__(key, value) # pylint: disable=super-with-arguments
104
+ self.__dict__.update({key: value})
105
+ def __delattr__(self, item):
106
+ self.__delitem__(item)
107
+ def __delitem__(self, key):
108
+ super(Map, self).__delitem__(key) # pylint: disable=super-with-arguments
109
+ del self.__dict__[key]
110
+
111
+
112
+ if __name__ == "__main__":
113
+ pass
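A brief usage sketch for Map, the dict subclass above with recursive attribute-style access:

m = Map({'state': {'job_count': 2}, 'nets': [{'name': 'lora'}]})
print(m.state.job_count)        # 2 -- nested dicts become Map instances
m.state.job_count = 0           # attribute writes go through __setitem__
print(m['state']['job_count'])  # 0
print(m.nets[0].name)           # dicts inside lists are converted as well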
cli/validate-locale.py ADDED
@@ -0,0 +1,40 @@
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import sys
5
+ import json
6
+ from rich import print # pylint: disable=redefined-builtin
7
+
8
+ if __name__ == "__main__":
9
+ sys.argv.pop(0)
10
+ fn = sys.argv[0] if len(sys.argv) > 0 else 'locale_en.json'
11
+ if not os.path.isfile(fn):
12
+ print(f'File not found: {fn}')
13
+ sys.exit(1)
14
+ with open(fn, 'r', encoding="utf-8") as f:
15
+ data = json.load(f)
16
+ keys = []
17
+ t_names = 0
18
+ t_hints = 0
19
+ t_localized = 0
20
+ t_long = 0
21
+ for k in data.keys():
22
+ names = len(data[k])
23
+ t_names += names
24
+ hints = len([v for v in data[k] if v["hint"] != ""])
25
+ t_hints += hints
26
+ localized = len([v for v in data[k] if v["localized"] != ""])
27
+ t_localized += localized
28
+ missing = names - hints
29
+ long = 0
30
+ for v in data[k]:
31
+ if v['label'] in keys:
32
+ print(f' Duplicate: {k}.{v["label"]}')
33
+ else:
34
+ if len(v['label']) > 63:
35
+ long += 1
36
+ print(f' Long label: {k}.{v["label"]}')
37
+ keys.append(v['label'])
38
+ t_long += long
39
+ print(f'Section: [bold magenta]{k.ljust(20)}[/bold magenta] entries={names} localized={"[bold green]" + str(localized) + "[/bold green]" if localized > 0 else "0"} long={"[bold red]" + str(long) + "[/bold red]" if long > 0 else "0"} hints={hints} missing={"[bold red]" + str(missing) + "[/bold red]" if missing > 0 else "[bold green]0[/bold green]"}')
40
+ print(f'Totals: entries={t_names} localized={t_localized} long={t_long} hints={t_hints} missing={t_names - t_hints}')
cli/video-extract.py ADDED
@@ -0,0 +1,71 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ use ffmpeg for animation processing
4
+ """
5
+ import os
6
+ import json
7
+ import subprocess
8
+ import pathlib
9
+ import argparse
10
+ import filetype
11
+ from util import log, Map
12
+
13
+
14
+ def probe(src: str):
15
+ cmd = f"ffprobe -hide_banner -loglevel 0 -print_format json -show_format -show_streams \"{src}\""
16
+ result = subprocess.run(cmd, shell = True, capture_output = True, text = True, check = True)
17
+ data = json.loads(result.stdout)
18
+ stream = [x for x in data['streams'] if x["codec_type"] == "video"][0]
19
+ fmt = data['format'] if 'format' in data else {}
20
+ res = {**stream, **fmt}
21
+ video = Map({
22
+ 'codec': res.get('codec_name', 'unknown') + '/' + res.get('codec_tag_string', ''),
23
+ 'resolution': [int(res.get('width', 0)), int(res.get('height', 0))],
24
+ 'duration': float(res.get('duration', 0)),
25
+ 'frames': int(res.get('nb_frames', 0)),
26
+ 'bitrate': round(float(res.get('bit_rate', 0)) / 1024),
27
+ })
28
+ return video
29
+
30
+
31
+ def extract(src: str, dst: str, rate: float = 0.015, fps: float = 0, start = 0, end = 0):
32
+ images = []
33
+ if not os.path.isfile(src) or not filetype.is_video(src):
34
+ log.error({ 'extract': 'input is not a movie file' })
35
+ return 0
36
+ dst = dst if dst.endswith('/') else dst + '/'
37
+
38
+ video = probe(src)
39
+ log.info({ 'extract': { 'source': src, **video } })
40
+
41
+ ssstart = f' -ss {start}' if start > 0 else ''
42
+ ssend = f' -to {video.duration - end}' if start > 0 else ''
43
+ filename = pathlib.Path(src).stem
44
+ if rate > 0:
45
+ cmd = f"ffmpeg -hide_banner -y -loglevel info {ssstart} {ssend} -i \"{src}\" -filter:v \"select='gt(scene,{rate})',metadata=print\" -vsync vfr -frame_pts 1 \"{dst}{filename}-%05d.jpg\""
46
+ elif fps > 0:
47
+ cmd = f"ffmpeg -hide_banner -y -loglevel info {ssstart} {ssend} -i \"{src}\" -r {fps} -vsync vfr -frame_pts 1 \"{dst}{filename}-%05d.jpg\""
48
+ else:
49
+ log.error({ 'extract': 'requires either rate or fps' })
50
+ return 0
51
+ log.debug({ 'extract': cmd })
52
+ pathlib.Path(dst).mkdir(parents = True, exist_ok = True)
53
+ result = subprocess.run(cmd, shell = True, capture_output = True, text = True, check = True)
54
+ for line in result.stderr.split('\n'):
55
+ if 'pts_time' in line:
56
+ log.debug({ 'extract': { 'keyframe': line.strip().split(' ')[-1].split(':')[-1] } })
57
+ images = next(os.walk(dst))[2]
58
+ log.info({ 'extract': { 'destination': dst, 'keyframes': len(images), 'rate': rate, 'fps': fps } })
59
+ return len(images)
60
+
61
+
62
+ if __name__ == "__main__":
63
+ parser = argparse.ArgumentParser(description="ffmpeg pipeline")
64
+ parser.add_argument("--input", type = str, required = True, help="input")
65
+ parser.add_argument("--output", type = str, required = True, help="output")
66
+ parser.add_argument("--rate", type = float, default = 0, required = False, help="extraction change rate threshold")
67
+ parser.add_argument("--fps", type = float, default = 0, required = False, help="extraction frames per second")
68
+ parser.add_argument("--skipstart", type = float, default = 1, required = False, help="skip time from start of video")
69
+ parser.add_argument("--skipend", type = float, default = 1, required = False, help="skip time to end of video")
70
+ params = parser.parse_args()
71
+ extract(src = params.input, dst = params.output, rate = params.rate, fps = params.fps, start = params.skipstart, end = params.skipend)
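
A hedged usage sketch mirroring the argparse interface above; the paths are placeholders, either --rate or --fps must be non-zero for extraction to run, and the script should be launched from the cli/ directory so that its `from util import log, Map` resolves:

    # hypothetical invocation of cli/video-extract.py via subprocess
    import subprocess
    subprocess.run([
        "python", "video-extract.py",
        "--input", "clip.mp4",    # source video (placeholder)
        "--output", "frames/",    # destination folder, created if missing
        "--rate", "0.1",          # scene-change threshold for keyframe selection
    ], cwd="cli", check=True)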
configs/alt-diffusion-inference.yaml ADDED
@@ -0,0 +1,72 @@
+ model:
+   base_learning_rate: 1.0e-04
+   target: ldm.models.diffusion.ddpm.LatentDiffusion
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: "jpg"
+     cond_stage_key: "txt"
+     image_size: 64
+     channels: 4
+     cond_stage_trainable: false # Note: different from the one we trained before
+     conditioning_key: crossattn
+     monitor: val/loss_simple_ema
+     scale_factor: 0.18215
+     use_ema: False
+
+     scheduler_config: # 10000 warmup steps
+       target: ldm.lr_scheduler.LambdaLinearScheduler
+       params:
+         warm_up_steps: [ 10000 ]
+         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+         f_start: [ 1.e-6 ]
+         f_max: [ 1. ]
+         f_min: [ 1. ]
+
+     unet_config:
+       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 4
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: modules.xlmr.BertSeriesModelWithTransformation
+       params:
+         name: "XLMR-Large"
configs/instruct-pix2pix.yaml ADDED
@@ -0,0 +1,98 @@
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
+ # See more details in LICENSE.
+
+ model:
+   base_learning_rate: 1.0e-04
+   target: modules.hijack.ddpm_edit.LatentDiffusion
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: edited
+     cond_stage_key: edit
+     # image_size: 64
+     # image_size: 32
+     image_size: 16
+     channels: 4
+     cond_stage_trainable: false # Note: different from the one we trained before
+     conditioning_key: hybrid
+     monitor: val/loss_simple_ema
+     scale_factor: 0.18215
+     use_ema: false
+
+     scheduler_config: # 10000 warmup steps
+       target: ldm.lr_scheduler.LambdaLinearScheduler
+       params:
+         warm_up_steps: [ 0 ]
+         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+         f_start: [ 1.e-6 ]
+         f_max: [ 1. ]
+         f_min: [ 1. ]
+
+     unet_config:
+       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 8
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+ data:
+   target: main.DataModuleFromConfig
+   params:
+     batch_size: 128
+     num_workers: 1
+     wrap: false
+     validation:
+       target: edit_dataset.EditDataset
+       params:
+         path: data/clip-filtered-dataset
+         cache_dir: data/
+         cache_name: data_10k
+         split: val
+         min_text_sim: 0.2
+         min_image_sim: 0.75
+         min_direction_sim: 0.2
+         max_samples_per_prompt: 1
+         min_resize_res: 512
+         max_resize_res: 512
+         crop_res: 512
+         output_as_edit: False
+         real_input: True
+ real_input: True