maccmaccmaccc committed on
Commit
2bf74f8
1 Parent(s): 18971d1

Upload 45 files

Files changed (46)
  1. .gitattributes +8 -0
  2. data/README.md +87 -0
  3. data/flux/.github/workflows/ci.yaml +20 -0
  4. data/flux/.gitignore +230 -0
  5. data/flux/LICENSE +201 -0
  6. data/flux/README.md +87 -0
  7. data/flux/assets/cup.png +3 -0
  8. data/flux/assets/cup_mask.png +0 -0
  9. data/flux/assets/dev_grid.jpg +3 -0
  10. data/flux/assets/docs/canny.png +3 -0
  11. data/flux/assets/docs/depth.png +3 -0
  12. data/flux/assets/docs/inpainting.png +3 -0
  13. data/flux/assets/docs/outpainting.png +3 -0
  14. data/flux/assets/docs/redux.png +0 -0
  15. data/flux/assets/grid.jpg +3 -0
  16. data/flux/assets/robot.webp +0 -0
  17. data/flux/assets/schnell_grid.jpg +3 -0
  18. data/flux/demo_gr.py +247 -0
  19. data/flux/demo_st.py +293 -0
  20. data/flux/demo_st_fill.py +487 -0
  21. data/flux/docs/fill.md +44 -0
  22. data/flux/docs/image-variation.md +33 -0
  23. data/flux/docs/structural-conditioning.md +40 -0
  24. data/flux/docs/text-to-image.md +93 -0
  25. data/flux/model_cards/FLUX.1-dev.md +46 -0
  26. data/flux/model_cards/FLUX.1-schnell.md +41 -0
  27. data/flux/model_licenses/LICENSE-FLUX1-dev +42 -0
  28. data/flux/model_licenses/LICENSE-FLUX1-schnell +54 -0
  29. data/flux/pyproject.toml +99 -0
  30. data/flux/setup.py +3 -0
  31. data/flux/src/flux/__init__.py +13 -0
  32. data/flux/src/flux/__main__.py +4 -0
  33. data/flux/src/flux/api.py +225 -0
  34. data/flux/src/flux/cli.py +238 -0
  35. data/flux/src/flux/cli_control.py +347 -0
  36. data/flux/src/flux/cli_fill.py +334 -0
  37. data/flux/src/flux/cli_redux.py +279 -0
  38. data/flux/src/flux/math.py +30 -0
  39. data/flux/src/flux/model.py +143 -0
  40. data/flux/src/flux/modules/autoencoder.py +312 -0
  41. data/flux/src/flux/modules/conditioner.py +37 -0
  42. data/flux/src/flux/modules/image_embedders.py +103 -0
  43. data/flux/src/flux/modules/layers.py +253 -0
  44. data/flux/src/flux/modules/lora.py +94 -0
  45. data/flux/src/flux/sampling.py +282 -0
  46. data/flux/src/flux/util.py +447 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/flux/assets/cup.png filter=lfs diff=lfs merge=lfs -text
+ data/flux/assets/dev_grid.jpg filter=lfs diff=lfs merge=lfs -text
+ data/flux/assets/docs/canny.png filter=lfs diff=lfs merge=lfs -text
+ data/flux/assets/docs/depth.png filter=lfs diff=lfs merge=lfs -text
+ data/flux/assets/docs/inpainting.png filter=lfs diff=lfs merge=lfs -text
+ data/flux/assets/docs/outpainting.png filter=lfs diff=lfs merge=lfs -text
+ data/flux/assets/grid.jpg filter=lfs diff=lfs merge=lfs -text
+ data/flux/assets/schnell_grid.jpg filter=lfs diff=lfs merge=lfs -text
data/README.md ADDED
@@ -0,0 +1,87 @@
+ # FLUX
+ by Black Forest Labs: https://blackforestlabs.ai. Documentation for our API can be found here: [docs.bfl.ml](https://docs.bfl.ml/).
+
+ ![grid](assets/grid.jpg)
+
+ This repo contains minimal inference code to run image generation & editing with our Flux models.
+
+ ## Local installation
+
+ ```bash
+ cd $HOME && git clone https://github.com/black-forest-labs/flux
+ cd $HOME/flux
+ python3.10 -m venv .venv
+ source .venv/bin/activate
+ pip install -e ".[all]"
+ ```
+
+ ### Models
+
+ We are offering an extensive suite of models. For more information about the individual models, please refer to the link under **Usage**.
+
+ | Name | Usage | HuggingFace repo | License |
+ | --------------------------- | ---------------------------------------------------------- | ------------------------------------------------------------- | --------------------------------------------------------------------- |
+ | `FLUX.1 [schnell]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell) |
+ | `FLUX.1 [dev]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Fill [dev]` | [In/Out-painting](docs/fill.md) | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Canny [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Depth [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Canny [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Depth [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Redux [dev]` | [Image variation](docs/image-variation.md) | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 [pro]` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX1.1 [pro]` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX1.1 [pro] Ultra/raw` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX.1 Fill [pro]` | [In/Out-painting](docs/fill.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX.1 Canny [pro]` | [Structural Conditioning](docs/structural-conditioning.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX.1 Depth [pro]` | [Structural Conditioning](docs/structural-conditioning.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX1.1 Redux [pro]` | [Image variation](docs/image-variation.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX1.1 Redux [pro] Ultra` | [Image variation](docs/image-variation.md) | [Available in our API.](https://docs.bfl.ml/) | |
+
+ The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.
+
+ ## API usage
+
+ Our API offers access to our models. It is documented here:
+ [docs.bfl.ml](https://docs.bfl.ml/).
+
+ In this repository we also offer an easy Python interface. To use this, you
+ first need to register with the API on [api.bfl.ml](https://api.bfl.ml/) and
+ create a new API key.
+
+ To use the API key, either run `export BFL_API_KEY=<your_key_here>` or provide
+ it via the `api_key=<your_key_here>` parameter. It is also expected that you
+ have installed the package as above.
+
+ Usage from Python:
+
+ ```python
+ from flux.api import ImageRequest
+
+ # this will create an api request directly but not block until the generation is finished
+ request = ImageRequest("A beautiful beach", name="flux.1.1-pro")
+ # or: request = ImageRequest("A beautiful beach", name="flux.1.1-pro", api_key="your_key_here")
+
+ # any of the following will block until the generation is finished
+ request.url
+ # -> https:<...>/sample.jpg
+ request.bytes
+ # -> b"..." bytes for the generated image
+ request.save("outputs/api.jpg")
+ # saves the sample to local storage
+ request.image
+ # -> a PIL image
+ ```
+
+ Usage from the command line:
+
+ ```bash
+ $ python -m flux.api --prompt="A beautiful beach" url
+ https:<...>/sample.jpg
+
+ # generate and save the result
+ $ python -m flux.api --prompt="A beautiful beach" save outputs/api
+
+ # open the image directly
+ $ python -m flux.api --prompt="A beautiful beach" image show
+ ```
data/flux/.github/workflows/ci.yaml ADDED
@@ -0,0 +1,20 @@
+ name: CI
+ on: push
+ jobs:
+   lint:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v2
+       - uses: actions/setup-python@v2
+         with:
+           python-version: "3.10"
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install ruff==0.6.8
+       - name: Run Ruff
+         run: ruff check --output-format=github .
+       - name: Check imports
+         run: ruff check --select I --output-format=github .
+       - name: Check formatting
+         run: ruff format --check .
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python
3
+
4
+ ### Linux ###
5
+ *~
6
+
7
+ # temporary files which can be created if a process still has a handle open of a deleted file
8
+ .fuse_hidden*
9
+
10
+ # KDE directory preferences
11
+ .directory
12
+
13
+ # Linux trash folder which might appear on any partition or disk
14
+ .Trash-*
15
+
16
+ # .nfs files are created when an open file is removed but is still being accessed
17
+ .nfs*
18
+
19
+ ### macOS ###
20
+ # General
21
+ .DS_Store
22
+ .AppleDouble
23
+ .LSOverride
24
+
25
+ # Icon must end with two \r
26
+ Icon
27
+
28
+
29
+ # Thumbnails
30
+ ._*
31
+
32
+ # Files that might appear in the root of a volume
33
+ .DocumentRevisions-V100
34
+ .fseventsd
35
+ .Spotlight-V100
36
+ .TemporaryItems
37
+ .Trashes
38
+ .VolumeIcon.icns
39
+ .com.apple.timemachine.donotpresent
40
+
41
+ # Directories potentially created on remote AFP share
42
+ .AppleDB
43
+ .AppleDesktop
44
+ Network Trash Folder
45
+ Temporary Items
46
+ .apdisk
47
+
48
+ ### Python ###
49
+ # Byte-compiled / optimized / DLL files
50
+ __pycache__/
51
+ *.py[cod]
52
+ *$py.class
53
+
54
+ # C extensions
55
+ *.so
56
+
57
+ # Distribution / packaging
58
+ .Python
59
+ build/
60
+ develop-eggs/
61
+ dist/
62
+ downloads/
63
+ eggs/
64
+ .eggs/
65
+ lib/
66
+ lib64/
67
+ parts/
68
+ sdist/
69
+ var/
70
+ wheels/
71
+ share/python-wheels/
72
+ *.egg-info/
73
+ .installed.cfg
74
+ *.egg
75
+ MANIFEST
76
+
77
+ # PyInstaller
78
+ # Usually these files are written by a python script from a template
79
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
80
+ *.manifest
81
+ *.spec
82
+
83
+ # Installer logs
84
+ pip-log.txt
85
+ pip-delete-this-directory.txt
86
+
87
+ # Unit test / coverage reports
88
+ htmlcov/
89
+ .tox/
90
+ .nox/
91
+ .coverage
92
+ .coverage.*
93
+ .cache
94
+ nosetests.xml
95
+ coverage.xml
96
+ *.cover
97
+ *.py,cover
98
+ .hypothesis/
99
+ .pytest_cache/
100
+ cover/
101
+
102
+ # Translations
103
+ *.mo
104
+ *.pot
105
+
106
+ # Django stuff:
107
+ *.log
108
+ local_settings.py
109
+ db.sqlite3
110
+ db.sqlite3-journal
111
+
112
+ # Flask stuff:
113
+ instance/
114
+ .webassets-cache
115
+
116
+ # Scrapy stuff:
117
+ .scrapy
118
+
119
+ # Sphinx documentation
120
+ docs/_build/
121
+
122
+ # PyBuilder
123
+ .pybuilder/
124
+ target/
125
+
126
+ # Jupyter Notebook
127
+ .ipynb_checkpoints
128
+
129
+ # IPython
130
+ profile_default/
131
+ ipython_config.py
132
+
133
+ # pyenv
134
+ # For a library or package, you might want to ignore these files since the code is
135
+ # intended to run in multiple environments; otherwise, check them in:
136
+ # .python-version
137
+
138
+ # pipenv
139
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
140
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
141
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
142
+ # install all needed dependencies.
143
+ #Pipfile.lock
144
+
145
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
146
+ __pypackages__/
147
+
148
+ # Celery stuff
149
+ celerybeat-schedule
150
+ celerybeat.pid
151
+
152
+ # SageMath parsed files
153
+ *.sage.py
154
+
155
+ # Environments
156
+ .env
157
+ .venv
158
+ env/
159
+ venv/
160
+ ENV/
161
+ env.bak/
162
+ venv.bak/
163
+
164
+ # Spyder project settings
165
+ .spyderproject
166
+ .spyproject
167
+
168
+ # Rope project settings
169
+ .ropeproject
170
+
171
+ # mkdocs documentation
172
+ /site
173
+
174
+ # mypy
175
+ .mypy_cache/
176
+ .dmypy.json
177
+ dmypy.json
178
+
179
+ # Pyre type checker
180
+ .pyre/
181
+
182
+ # pytype static type analyzer
183
+ .pytype/
184
+
185
+ # Cython debug symbols
186
+ cython_debug/
187
+
188
+ ### VisualStudioCode ###
189
+ .vscode/*
190
+ !.vscode/settings.json
191
+ !.vscode/tasks.json
192
+ !.vscode/launch.json
193
+ !.vscode/extensions.json
194
+ *.code-workspace
195
+
196
+ # Local History for Visual Studio Code
197
+ .history/
198
+
199
+ ### VisualStudioCode Patch ###
200
+ # Ignore all local history of files
201
+ .history
202
+ .ionide
203
+
204
+ ### Windows ###
205
+ # Windows thumbnail cache files
206
+ Thumbs.db
207
+ Thumbs.db:encryptable
208
+ ehthumbs.db
209
+ ehthumbs_vista.db
210
+
211
+ # Dump file
212
+ *.stackdump
213
+
214
+ # Folder config file
215
+ [Dd]esktop.ini
216
+
217
+ # Recycle Bin used on file shares
218
+ $RECYCLE.BIN/
219
+
220
+ # Windows Installer files
221
+ *.cab
222
+ *.msi
223
+ *.msix
224
+ *.msm
225
+ *.msp
226
+
227
+ # Windows shortcuts
228
+ *.lnk
229
+
230
+ # End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
data/flux/LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
data/flux/README.md ADDED
@@ -0,0 +1,87 @@
+ # FLUX
+ by Black Forest Labs: https://blackforestlabs.ai. Documentation for our API can be found here: [docs.bfl.ml](https://docs.bfl.ml/).
+
+ ![grid](assets/grid.jpg)
+
+ This repo contains minimal inference code to run image generation & editing with our Flux models.
+
+ ## Local installation
+
+ ```bash
+ cd $HOME && git clone https://github.com/black-forest-labs/flux
+ cd $HOME/flux
+ python3.10 -m venv .venv
+ source .venv/bin/activate
+ pip install -e ".[all]"
+ ```
+
+ ### Models
+
+ We are offering an extensive suite of models. For more information about the individual models, please refer to the link under **Usage**.
+
+ | Name | Usage | HuggingFace repo | License |
+ | --------------------------- | ---------------------------------------------------------- | ------------------------------------------------------------- | --------------------------------------------------------------------- |
+ | `FLUX.1 [schnell]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell) |
+ | `FLUX.1 [dev]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Fill [dev]` | [In/Out-painting](docs/fill.md) | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Canny [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Depth [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Canny [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Depth [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 Redux [dev]` | [Image variation](docs/image-variation.md) | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
+ | `FLUX.1 [pro]` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX1.1 [pro]` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX1.1 [pro] Ultra/raw` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX.1 Fill [pro]` | [In/Out-painting](docs/fill.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX.1 Canny [pro]` | [Structural Conditioning](docs/structural-conditioning.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX.1 Depth [pro]` | [Structural Conditioning](docs/structural-conditioning.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX1.1 Redux [pro]` | [Image variation](docs/image-variation.md) | [Available in our API.](https://docs.bfl.ml/) | |
+ | `FLUX1.1 Redux [pro] Ultra` | [Image variation](docs/image-variation.md) | [Available in our API.](https://docs.bfl.ml/) | |
+
+ The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.
+
+ ## API usage
+
+ Our API offers access to our models. It is documented here:
+ [docs.bfl.ml](https://docs.bfl.ml/).
+
+ In this repository we also offer an easy Python interface. To use this, you
+ first need to register with the API on [api.bfl.ml](https://api.bfl.ml/) and
+ create a new API key.
+
+ To use the API key, either run `export BFL_API_KEY=<your_key_here>` or provide
+ it via the `api_key=<your_key_here>` parameter. It is also expected that you
+ have installed the package as above.
+
+ Usage from Python:
+
+ ```python
+ from flux.api import ImageRequest
+
+ # this will create an api request directly but not block until the generation is finished
+ request = ImageRequest("A beautiful beach", name="flux.1.1-pro")
+ # or: request = ImageRequest("A beautiful beach", name="flux.1.1-pro", api_key="your_key_here")
+
+ # any of the following will block until the generation is finished
+ request.url
+ # -> https:<...>/sample.jpg
+ request.bytes
+ # -> b"..." bytes for the generated image
+ request.save("outputs/api.jpg")
+ # saves the sample to local storage
+ request.image
+ # -> a PIL image
+ ```
+
+ Usage from the command line:
+
+ ```bash
+ $ python -m flux.api --prompt="A beautiful beach" url
+ https:<...>/sample.jpg
+
+ # generate and save the result
+ $ python -m flux.api --prompt="A beautiful beach" save outputs/api
+
+ # open the image directly
+ $ python -m flux.api --prompt="A beautiful beach" image show
+ ```
data/flux/assets/cup.png ADDED

Git LFS Details

  • SHA256: bdffdc181455480d3c9dcb3ca9d91e088c1ac162a12a9c4f45fc56db0ef86787
  • Pointer size: 132 Bytes
  • Size of remote file: 4.36 MB
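
For context on the "pointer size" above: with the `filter=lfs` attributes added in `.gitattributes`, Git stores only a small text pointer in place of each binary. A representative pointer for this file would look like the following (the `oid` is the SHA256 listed above; the `size` byte count is illustrative, since only the rounded 4.36 MB figure is shown):

```
version https://git-lfs.github.com/spec/v1
oid sha256:bdffdc181455480d3c9dcb3ca9d91e088c1ac162a12a9c4f45fc56db0ef86787
size 4360000
```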
data/flux/assets/cup_mask.png ADDED
data/flux/assets/dev_grid.jpg ADDED

Git LFS Details

  • SHA256: 37a587d0ff3d9dda0d8ab59d65342c0242ffb909573d8d998d599e3401d3d7e9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
data/flux/assets/docs/canny.png ADDED

Git LFS Details

  • SHA256: 3d0d6a4ca1f5bdb1149a8e9cd919df0895f014b3e8abbb5a182d55aa2a940af5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.02 MB
data/flux/assets/docs/depth.png ADDED

Git LFS Details

  • SHA256: d9343b1f1e7328300af8af74b1cf47b1e05566148ba18692682b6389b22c6196
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
data/flux/assets/docs/inpainting.png ADDED

Git LFS Details

  • SHA256: 7d405d7029c2571879106e1efbeda054b1635c4f70fe5945c357f608b8c78ca5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.95 MB
data/flux/assets/docs/outpainting.png ADDED

Git LFS Details

  • SHA256: a15f0e9c0a3794d5a7a6ba95601adf5a0d06ec836dd3c88860fdbafc27ebe99a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.15 MB
data/flux/assets/docs/redux.png ADDED
data/flux/assets/grid.jpg ADDED

Git LFS Details

  • SHA256: 0e0a4c0c510f3659452aed0a25fbea6a62bfdefc05c40b096b92d615a3608a3d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.14 MB
data/flux/assets/robot.webp ADDED
data/flux/assets/schnell_grid.jpg ADDED

Git LFS Details

  • SHA256: 5bff488da88933a529825e4fd72e10c3824e22508dc1d6f36b49850fa517ac44
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
data/flux/demo_gr.py ADDED
@@ -0,0 +1,247 @@
+ import os
+ import time
+ import uuid
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from einops import rearrange
+ from PIL import ExifTags, Image
+ from transformers import pipeline
+
+ from flux.cli import SamplingOptions
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
+ from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5
+
+ NSFW_THRESHOLD = 0.85
+
+
+ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
+     t5 = load_t5(device, max_length=256 if is_schnell else 512)
+     clip = load_clip(device)
+     model = load_flow_model(name, device="cpu" if offload else device)
+     ae = load_ae(name, device="cpu" if offload else device)
+     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
+     return model, ae, t5, clip, nsfw_classifier
+
+
+ class FluxGenerator:
+     def __init__(self, model_name: str, device: str, offload: bool):
+         self.device = torch.device(device)
+         self.offload = offload
+         self.model_name = model_name
+         self.is_schnell = model_name == "flux-schnell"
+         self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
+             model_name,
+             device=self.device,
+             offload=self.offload,
+             is_schnell=self.is_schnell,
+         )
+
+     @torch.inference_mode()
+     def generate_image(
+         self,
+         width,
+         height,
+         num_steps,
+         guidance,
+         seed,
+         prompt,
+         init_image=None,
+         image2image_strength=0.0,
+         add_sampling_metadata=True,
+     ):
+         seed = int(seed)
+         if seed == -1:
+             seed = None
+
+         opts = SamplingOptions(
+             prompt=prompt,
+             width=width,
+             height=height,
+             num_steps=num_steps,
+             guidance=guidance,
+             seed=seed,
+         )
+
+         if opts.seed is None:
+             opts.seed = torch.Generator(device="cpu").seed()
+         print(f"Generating '{opts.prompt}' with seed {opts.seed}")
+         t0 = time.perf_counter()
+
+         if init_image is not None:
+             if isinstance(init_image, np.ndarray):
+                 init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0
+                 init_image = init_image.unsqueeze(0)
+             init_image = init_image.to(self.device)
+             init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
+             if self.offload:
+                 self.ae.encoder.to(self.device)
+             init_image = self.ae.encode(init_image.to())
+             if self.offload:
+                 self.ae = self.ae.cpu()
+                 torch.cuda.empty_cache()
+
+         # prepare input
+         x = get_noise(
+             1,
+             opts.height,
+             opts.width,
+             device=self.device,
+             dtype=torch.bfloat16,
+             seed=opts.seed,
+         )
+         timesteps = get_schedule(
+             opts.num_steps,
+             x.shape[-1] * x.shape[-2] // 4,
+             shift=(not self.is_schnell),
+         )
+         if init_image is not None:
+             t_idx = int((1 - image2image_strength) * num_steps)
+             t = timesteps[t_idx]
+             timesteps = timesteps[t_idx:]
+             x = t * x + (1.0 - t) * init_image.to(x.dtype)
+
+         if self.offload:
+             self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+         inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)
+
+         # offload TEs to CPU, load model to gpu
+         if self.offload:
+             self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
+             torch.cuda.empty_cache()
+             self.model = self.model.to(self.device)
+
+         # denoise initial noise
+         x = denoise(self.model, **inp, timesteps=timesteps, guidance=opts.guidance)
+
+         # offload model, load autoencoder to gpu
+         if self.offload:
+             self.model.cpu()
+             torch.cuda.empty_cache()
+             self.ae.decoder.to(x.device)
+
+         # decode latents to pixel space
+         x = unpack(x.float(), opts.height, opts.width)
+         with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
+             x = self.ae.decode(x)
+
+         if self.offload:
+             self.ae.decoder.cpu()
+             torch.cuda.empty_cache()
+
+         t1 = time.perf_counter()
+
+         print(f"Done in {t1 - t0:.1f}s.")
+         # bring into PIL format
+         x = x.clamp(-1, 1)
+         x = embed_watermark(x.float())
+         x = rearrange(x[0], "c h w -> h w c")
+
+         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+         nsfw_score = [x["score"] for x in self.nsfw_classifier(img) if x["label"] == "nsfw"][0]
+
+         if nsfw_score < NSFW_THRESHOLD:
+             filename = f"output/gradio/{uuid.uuid4()}.jpg"
+             os.makedirs(os.path.dirname(filename), exist_ok=True)
+             exif_data = Image.Exif()
+             if init_image is None:
+                 exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+             else:
+                 exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
+             exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+             exif_data[ExifTags.Base.Model] = self.model_name
+             if add_sampling_metadata:
+                 exif_data[ExifTags.Base.ImageDescription] = prompt
+
+             img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
+
+             return img, str(opts.seed), filename, None
+         else:
+             return None, str(opts.seed), None, "Your generated image may contain NSFW content."
+
+
+ def create_demo(
+     model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False
+ ):
+     generator = FluxGenerator(model_name, device, offload)
+     is_schnell = model_name == "flux-schnell"
+
+     with gr.Blocks() as demo:
+         gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}")
+
+         with gr.Row():
+             with gr.Column():
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
+                 )
+                 do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
+                 init_image = gr.Image(label="Input Image", visible=False)
+                 image2image_strength = gr.Slider(
+                     0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
+                 )
+
+                 with gr.Accordion("Advanced Options", open=False):
+                     width = gr.Slider(128, 8192, 1360, step=16, label="Width")
+                     height = gr.Slider(128, 8192, 768, step=16, label="Height")
+                     num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps")
+                     guidance = gr.Slider(
+                         1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell
+                     )
+                     seed = gr.Textbox(-1, label="Seed (-1 for random)")
+                     add_sampling_metadata = gr.Checkbox(
+                         label="Add sampling parameters to metadata?", value=True
+                     )
+
+                 generate_btn = gr.Button("Generate")
+
+             with gr.Column():
+                 output_image = gr.Image(label="Generated Image")
+                 seed_output = gr.Number(label="Used Seed")
+                 warning_text = gr.Textbox(label="Warning", visible=False)
+                 download_btn = gr.File(label="Download full-resolution")
+
+         def update_img2img(do_img2img):
+             return {
+                 init_image: gr.update(visible=do_img2img),
+                 image2image_strength: gr.update(visible=do_img2img),
+             }
+
+         do_img2img.change(update_img2img, do_img2img, [init_image, image2image_strength])
+
+         generate_btn.click(
+             fn=generator.generate_image,
+             inputs=[
+                 width,
+                 height,
+                 num_steps,
+                 guidance,
+                 seed,
+                 prompt,
+                 init_image,
+                 image2image_strength,
+                 add_sampling_metadata,
+             ],
+             outputs=[output_image, seed_output, download_btn, warning_text],
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Flux")
+     parser.add_argument(
+         "--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name"
+     )
+     parser.add_argument(
+         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use"
+     )
+     parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
+     parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
+     args = parser.parse_args()
+
+     demo = create_demo(args.name, args.device, args.offload)
+     demo.launch(share=args.share)
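
The Gradio demo is launched directly from the command line with the flags defined in the `argparse` block above; a sketch, assuming the package was installed as in the README:

```bash
python demo_gr.py --name flux-schnell --device cuda --offload --share
```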
data/flux/demo_st.py ADDED
@@ -0,0 +1,293 @@
+ import os
+ import re
+ import time
+ from glob import iglob
+ from io import BytesIO
+
+ import streamlit as st
+ import torch
+ from einops import rearrange
+ from fire import Fire
+ from PIL import ExifTags, Image
+ from st_keyup import st_keyup
+ from torchvision import transforms
+ from transformers import pipeline
+
+ from flux.cli import SamplingOptions
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
+ from flux.util import (
+     configs,
+     embed_watermark,
+     load_ae,
+     load_clip,
+     load_flow_model,
+     load_t5,
+ )
+
+ NSFW_THRESHOLD = 0.85
+
+
+ @st.cache_resource()
+ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
+     t5 = load_t5(device, max_length=256 if is_schnell else 512)
+     clip = load_clip(device)
+     model = load_flow_model(name, device="cpu" if offload else device)
+     ae = load_ae(name, device="cpu" if offload else device)
+     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
+     return model, ae, t5, clip, nsfw_classifier
+
+
+ def get_image() -> torch.Tensor | None:
+     image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
+     if image is None:
+         return None
+     image = Image.open(image).convert("RGB")
+
+     transform = transforms.Compose(
+         [
+             transforms.ToTensor(),
+             transforms.Lambda(lambda x: 2.0 * x - 1.0),
+         ]
+     )
+     img: torch.Tensor = transform(image)
+     return img[None, ...]
+
+
+ @torch.inference_mode()
+ def main(
+     device: str = "cuda" if torch.cuda.is_available() else "cpu",
+     offload: bool = False,
+     output_dir: str = "output",
+ ):
+     torch_device = torch.device(device)
+     names = list(configs.keys())
+     name = st.selectbox("Which model to load?", names)
+     if name is None or not st.checkbox("Load model", False):
+         return
+
+     is_schnell = name == "flux-schnell"
+     model, ae, t5, clip, nsfw_classifier = get_models(
+         name,
+         device=torch_device,
+         offload=offload,
+         is_schnell=is_schnell,
+     )
+
+     do_img2img = (
+         st.checkbox(
+             "Image to Image",
+             False,
+             disabled=is_schnell,
+             help="Partially noise an image and denoise again to get variations.\n\nOnly works for flux-dev",
+         )
+         and not is_schnell
+     )
+     if do_img2img:
+         init_image = get_image()
+         if init_image is None:
+             st.warning("Please add an image to do image to image")
+         image2image_strength = st.number_input("Noising strength", min_value=0.0, max_value=1.0, value=0.8)
+         if init_image is not None:
+             h, w = init_image.shape[-2:]
+             st.write(f"Got image of size {w}x{h} ({h*w/1e6:.2f}MP)")
+         resize_img = st.checkbox("Resize image", False) or init_image is None
+     else:
+         init_image = None
+         resize_img = True
+         image2image_strength = 0.0
+
+     # allow for packing and conversion to latent space
+     width = int(
+         16 * (st.number_input("Width", min_value=128, value=1360, step=16, disabled=not resize_img) // 16)
+     )
+     height = int(
+         16 * (st.number_input("Height", min_value=128, value=768, step=16, disabled=not resize_img) // 16)
+     )
+     num_steps = int(st.number_input("Number of steps", min_value=1, value=(4 if is_schnell else 50)))
+     guidance = float(st.number_input("Guidance", min_value=1.0, value=3.5, disabled=is_schnell))
+     seed_str = st.text_input("Seed", disabled=is_schnell)
+     if seed_str.isdecimal():
+         seed = int(seed_str)
+     else:
+         st.info("No seed set, set to positive integer to enable")
+         seed = None
+     save_samples = st.checkbox("Save samples?", not is_schnell)
+     add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)
+
+     default_prompt = (
+         "a photo of a forest with mist swirling around the tree trunks. The word "
+         '"FLUX" is painted over it in big, red brush strokes with visible texture'
+     )
+     prompt = st_keyup("Enter a prompt", value=default_prompt, debounce=300, key="interactive_text")
+
+     output_name = os.path.join(output_dir, "img_{idx}.jpg")
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+         idx = 0
+     else:
+         fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
+         if len(fns) > 0:
+             idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
+         else:
+             idx = 0
+
+     rng = torch.Generator(device="cpu")
+
+     if "seed" not in st.session_state:
+         st.session_state.seed = rng.seed()
+
+     def increment_counter():
+         st.session_state.seed += 1
+
+     def decrement_counter():
+         if st.session_state.seed > 0:
+             st.session_state.seed -= 1
+
+     opts = SamplingOptions(
+         prompt=prompt,
+         width=width,
+         height=height,
+         num_steps=num_steps,
+         guidance=guidance,
+         seed=seed,
+     )
+
+     if name == "flux-schnell":
+         cols = st.columns([5, 1, 1, 5])
+         with cols[1]:
+             st.button("↩", on_click=increment_counter)
+         with cols[2]:
+             st.button("↪", on_click=decrement_counter)
+     if is_schnell or st.button("Sample"):
+         if is_schnell:
+             opts.seed = st.session_state.seed
+         elif opts.seed is None:
+             opts.seed = rng.seed()
+         print(f"Generating '{opts.prompt}' with seed {opts.seed}")
+         t0 = time.perf_counter()
+
+         if init_image is not None:
+             if resize_img:
+                 init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
+             else:
+                 h, w = init_image.shape[-2:]
+                 init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)]
+                 opts.height = init_image.shape[-2]
+                 opts.width = init_image.shape[-1]
+             if offload:
+                 ae.encoder.to(torch_device)
+             init_image = ae.encode(init_image.to(torch_device))
+             if offload:
+                 ae = ae.cpu()
+                 torch.cuda.empty_cache()
+
+         # prepare input
+         x = get_noise(
+             1,
+             opts.height,
+             opts.width,
+             device=torch_device,
+             dtype=torch.bfloat16,
+             seed=opts.seed,
+         )
+         # divide pixel space by 16**2 to account for latent space conversion
+         timesteps = get_schedule(
+             opts.num_steps,
+             (x.shape[-1] * x.shape[-2]) // 4,
+             shift=(not is_schnell),
+         )
+         if init_image is not None:
+             t_idx = int((1 - image2image_strength) * num_steps)
+             t = timesteps[t_idx]
+             timesteps = timesteps[t_idx:]
+             x = t * x + (1.0 - t) * init_image.to(x.dtype)
+
+         if offload:
+             t5, clip = t5.to(torch_device), clip.to(torch_device)
+         inp = prepare(t5=t5, clip=clip, img=x, prompt=opts.prompt)
+
+         # offload TEs to CPU, load model to gpu
+         if offload:
+             t5, clip = t5.cpu(), clip.cpu()
+             torch.cuda.empty_cache()
+             model = model.to(torch_device)
+
+         # denoise initial noise
+         x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
+
+         # offload model, load autoencoder to gpu
+         if offload:
+             model.cpu()
+             torch.cuda.empty_cache()
+             ae.decoder.to(x.device)
+
+         # decode latents to pixel space
+         x = unpack(x.float(), opts.height, opts.width)
+         with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+             x = ae.decode(x)
+
+         if offload:
+             ae.decoder.cpu()
+             torch.cuda.empty_cache()
+
+         t1 = time.perf_counter()
+
+         fn = output_name.format(idx=idx)
+         print(f"Done in {t1 - t0:.1f}s.")
+         # bring into PIL format and save
+         x = x.clamp(-1, 1)
+         x = embed_watermark(x.float())
+         x = rearrange(x[0], "c h w -> h w c")
+
+         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+         nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
+
+         if nsfw_score < NSFW_THRESHOLD:
+             buffer = BytesIO()
+             exif_data = Image.Exif()
+             if init_image is None:
+                 exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+             else:
+                 exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
+             exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+             exif_data[ExifTags.Base.Model] = name
+             if add_sampling_metadata:
+                 exif_data[ExifTags.Base.ImageDescription] = prompt
+             img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0)
+
+             img_bytes = buffer.getvalue()
+             if save_samples:
+                 print(f"Saving {fn}")
+                 with open(fn, "wb") as file:
+                     file.write(img_bytes)
+                 idx += 1
+
+             st.session_state["samples"] = {
+                 "prompt": opts.prompt,
+                 "img": img,
+                 "seed": opts.seed,
+                 "bytes": img_bytes,
+             }
+             opts.seed = None
+         else:
+             st.warning("Your generated image may contain NSFW content.")
+             st.session_state["samples"] = None
+
+     samples = st.session_state.get("samples", None)
+     if samples is not None:
+         st.image(samples["img"], caption=samples["prompt"])
+         st.download_button(
+             "Download full-resolution",
+             samples["bytes"],
+             file_name="generated.jpg",
+             mime="image/jpg",
+         )
+         st.write(f"Seed: {samples['seed']}")
+
+
+ def app():
+     Fire(main)
+
+
+ if __name__ == "__main__":
+     app()
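
Because `main()` is wrapped in `Fire`, its keyword arguments can be passed on the command line after Streamlit's `--` separator; a sketch, assuming Streamlit and the package extras are installed:

```bash
streamlit run demo_st.py -- --device cuda --offload --output_dir output
```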
data/flux/demo_st_fill.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import tempfile
4
+ import time
5
+ from glob import iglob
6
+ from io import BytesIO
7
+
8
+ import numpy as np
9
+ import streamlit as st
10
+ import torch
11
+ from einops import rearrange
12
+ from PIL import ExifTags, Image
13
+ from st_keyup import st_keyup
14
+ from streamlit_drawable_canvas import st_canvas
15
+ from transformers import pipeline
16
+
17
+ from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
18
+ from flux.util import embed_watermark, load_ae, load_clip, load_flow_model, load_t5
19
+
20
+ NSFW_THRESHOLD = 0.85
21
+
22
+
23
+ def add_border_and_mask(image, zoom_all=1.0, zoom_left=0, zoom_right=0, zoom_up=0, zoom_down=0, overlap=0):
24
+ """Adds a black border around the image with individual side control and mask overlap"""
25
+ orig_width, orig_height = image.size
26
+
27
+ # Calculate padding for each side (in pixels)
28
+ left_pad = int(orig_width * zoom_left)
29
+ right_pad = int(orig_width * zoom_right)
30
+ top_pad = int(orig_height * zoom_up)
31
+ bottom_pad = int(orig_height * zoom_down)
32
+
33
+ # Calculate overlap in pixels
34
+ overlap_left = int(orig_width * overlap)
35
+ overlap_right = int(orig_width * overlap)
36
+ overlap_top = int(orig_height * overlap)
37
+ overlap_bottom = int(orig_height * overlap)
38
+
39
+ # If using the all-sides zoom, add it to each side
40
+ if zoom_all > 1.0:
41
+ extra_each_side = (zoom_all - 1.0) / 2
42
+ left_pad += int(orig_width * extra_each_side)
43
+ right_pad += int(orig_width * extra_each_side)
44
+ top_pad += int(orig_height * extra_each_side)
45
+ bottom_pad += int(orig_height * extra_each_side)
46
+
47
+ # Calculate new dimensions (ensure they're multiples of 32)
+     new_width = 32 * round((orig_width + left_pad + right_pad) / 32)
+     new_height = 32 * round((orig_height + top_pad + bottom_pad) / 32)
+
+     # Create new image with black border
+     bordered_image = Image.new("RGB", (new_width, new_height), (0, 0, 0))
+     # Paste original image in position
+     paste_x = left_pad
+     paste_y = top_pad
+     bordered_image.paste(image, (paste_x, paste_y))
+
+     # Create mask (white where the border is, black where the original image was)
+     mask = Image.new("L", (new_width, new_height), 255)  # White background
+     # Paste black rectangle with overlap adjustment
+     mask.paste(
+         0,
+         (
+             paste_x + overlap_left,  # Left edge moves right
+             paste_y + overlap_top,  # Top edge moves down
+             paste_x + orig_width - overlap_right,  # Right edge moves left
+             paste_y + orig_height - overlap_bottom,  # Bottom edge moves up
+         ),
+     )
+
+     return bordered_image, mask
+
+
+ @st.cache_resource()
+ def get_models(name: str, device: torch.device, offload: bool):
+     t5 = load_t5(device, max_length=128)
+     clip = load_clip(device)
+     model = load_flow_model(name, device="cpu" if offload else device)
+     ae = load_ae(name, device="cpu" if offload else device)
+     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
+     return model, ae, t5, clip, nsfw_classifier
+
+
+ def resize(img: Image.Image, min_mp: float = 0.5, max_mp: float = 2.0) -> Image.Image:
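+     """Resize so the area stays within [min_mp, max_mp] megapixels and both dimensions are multiples of 32"""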
+     width, height = img.size
+     mp = (width * height) / 1_000_000  # Current megapixels
+
+     if min_mp <= mp <= max_mp:
+         # Even if MP is in range, ensure dimensions are multiples of 32
+         new_width = int(32 * round(width / 32))
+         new_height = int(32 * round(height / 32))
+         if new_width != width or new_height != height:
+             return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+         return img
+
+     # Calculate scaling factor
+     if mp < min_mp:
+         scale = (min_mp / mp) ** 0.5
+     else:  # mp > max_mp
+         scale = (max_mp / mp) ** 0.5
+
+     new_width = int(32 * round(width * scale / 32))
+     new_height = int(32 * round(height * scale / 32))
+
+     return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+
+ def clear_canvas_state():
+     """Clear all canvas-related state"""
+     keys_to_clear = ["canvas", "last_image_dims"]
+     for key in keys_to_clear:
+         if key in st.session_state:
+             del st.session_state[key]
+
+
+ def set_new_image(img: Image.Image):
+     """Safely set a new image and clear relevant state"""
+     st.session_state["current_image"] = img
+     clear_canvas_state()
+     st.rerun()
+
+
+ def downscale_image(img: Image.Image, scale_factor: float) -> Image.Image:
+     """Downscale image by a given factor while maintaining 32-pixel multiple dimensions"""
+     if scale_factor >= 1.0:
+         return img
+
+     width, height = img.size
+     new_width = int(32 * round(width * scale_factor / 32))
+     new_height = int(32 * round(height * scale_factor / 32))
+
+     # Ensure minimum dimensions
+     new_width = max(64, new_width)  # minimum 64 pixels
+     new_height = max(64, new_height)  # minimum 64 pixels
+
+     return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+
+ @torch.inference_mode()
+ def main(
+     device: str = "cuda" if torch.cuda.is_available() else "cpu",
+     offload: bool = False,
+     output_dir: str = "output",
+ ):
+     torch_device = torch.device(device)
+     st.title("Flux Fill: Inpainting & Outpainting")
+
+     # Model selection and loading
+     name = "flux-dev-fill"
+     if not st.checkbox("Load model", False):
+         return
+
+     try:
+         model, ae, t5, clip, nsfw_classifier = get_models(
+             name,
+             device=torch_device,
+             offload=offload,
+         )
+     except Exception as e:
+         st.error(f"Error loading models: {e}")
+         return
+
+     # Mode selection
+     mode = st.radio("Select Mode", ["Inpainting", "Outpainting"])
+
+     # Image handling - either from previous generation or new upload
+     if "input_image" in st.session_state:
+         image = st.session_state["input_image"]
+         del st.session_state["input_image"]
+         set_new_image(image)
+         st.write("Continuing from previous result")
+     else:
+         uploaded_image = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
+         if uploaded_image is None:
+             st.warning("Please upload an image")
+             return
+
+         if (
+             "current_image_name" not in st.session_state
+             or st.session_state["current_image_name"] != uploaded_image.name
+         ):
+             try:
+                 image = Image.open(uploaded_image).convert("RGB")
+                 st.session_state["current_image_name"] = uploaded_image.name
+                 set_new_image(image)
+             except Exception as e:
+                 st.error(f"Error loading image: {e}")
+                 return
+         else:
+             image = st.session_state.get("current_image")
+             if image is None:
+                 st.error("Error: Image state is invalid. Please reupload the image.")
+                 clear_canvas_state()
+                 return
+
+     # Add downscale control
+     with st.expander("Image Size Control"):
+         current_mp = (image.size[0] * image.size[1]) / 1_000_000
+         st.write(f"Current image size: {image.size[0]}x{image.size[1]} ({current_mp:.1f}MP)")
+
+         scale_factor = st.slider(
+             "Downscale Factor",
+             min_value=0.1,
+             max_value=1.0,
+             value=1.0,
+             step=0.1,
+             help="1.0 = original size, 0.5 = half size, etc.",
+         )
+
+         if scale_factor < 1.0 and st.button("Apply Downscaling"):
+             image = downscale_image(image, scale_factor)
+             set_new_image(image)
+             st.rerun()
+
+     # Resize image with validation
+     try:
+         original_mp = (image.size[0] * image.size[1]) / 1_000_000
+         image = resize(image)
+         width, height = image.size
+         current_mp = (width * height) / 1_000_000
+
+         if width % 32 != 0 or height % 32 != 0:
+             st.error("Error: Image dimensions must be multiples of 32")
+             return
+
+         st.write(f"Image dimensions: {width}x{height} pixels")
+         if original_mp != current_mp:
+             st.write(
+                 f"Image has been resized from {original_mp:.1f}MP to {current_mp:.1f}MP to stay within bounds (0.5MP - 2MP)"
+             )
+     except Exception as e:
+         st.error(f"Error processing image: {e}")
+         return
+
+     if mode == "Outpainting":
+         # Outpainting controls
+         zoom_all = st.slider("Zoom Out Amount (All Sides)", min_value=1.0, max_value=3.0, value=1.0, step=0.1)
+
+         with st.expander("Advanced Zoom Controls"):
+             st.info("These controls add additional zoom to specific sides")
+             col1, col2 = st.columns(2)
+             with col1:
+                 zoom_left = st.slider("Left", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
+                 zoom_right = st.slider("Right", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
+             with col2:
+                 zoom_up = st.slider("Up", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
+                 zoom_down = st.slider("Down", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
+
+         overlap = st.slider("Overlap", min_value=0.01, max_value=0.25, value=0.01, step=0.01)
+
+         # Generate bordered image and mask
+         image_for_generation, mask = add_border_and_mask(
+             image,
+             zoom_all=zoom_all,
+             zoom_left=zoom_left,
+             zoom_right=zoom_right,
+             zoom_up=zoom_up,
+             zoom_down=zoom_down,
+             overlap=overlap,
+         )
+         width, height = image_for_generation.size
+
+         # Show preview
+         col1, col2 = st.columns(2)
+         with col1:
+             st.image(image_for_generation, caption="Image with Border")
+         with col2:
+             st.image(mask, caption="Mask (white areas will be generated)")
+
+     else:  # Inpainting mode
+         # Canvas setup with dimension tracking
+         canvas_key = f"canvas_{width}_{height}"
+         if "last_image_dims" not in st.session_state:
+             st.session_state.last_image_dims = (width, height)
+         elif st.session_state.last_image_dims != (width, height):
+             clear_canvas_state()
+             st.session_state.last_image_dims = (width, height)
+             st.rerun()
+
+         try:
+             canvas_result = st_canvas(
+                 fill_color="rgba(255, 255, 255, 0.0)",
+                 stroke_width=st.slider("Brush size", 1, 500, 50),
+                 stroke_color="#fff",
+                 background_image=image,
+                 height=height,
+                 width=width,
+                 drawing_mode="freedraw",
+                 key=canvas_key,
+                 display_toolbar=True,
+             )
+         except Exception as e:
+             st.error(f"Error creating canvas: {e}")
+             clear_canvas_state()
+             st.rerun()
+             return
+
+     # Sampling parameters
+     num_steps = int(st.number_input("Number of steps", min_value=1, value=50))
+     guidance = float(st.number_input("Guidance", min_value=1.0, value=30.0))
+     seed_str = st.text_input("Seed")
+     if seed_str.isdecimal():
+         seed = int(seed_str)
+     else:
+         st.info("No seed set, using random seed")
+         seed = None
+
+     save_samples = st.checkbox("Save samples?", True)
+     add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)
+
+     # Prompt input
+     prompt = st_keyup("Enter a prompt", value="", debounce=300, key="interactive_text")
+
+     # Setup output path
+     output_name = os.path.join(output_dir, "img_{idx}.jpg")
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+         idx = 0
+     else:
+         fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
+         idx = len(fns)
+
+     if st.button("Generate"):
+         valid_input = False
+
+         if mode == "Inpainting" and canvas_result.image_data is not None:
+             valid_input = True
+             # Create mask from canvas
+             try:
+                 mask = Image.fromarray(canvas_result.image_data)
+                 mask = mask.getchannel("A")  # Get alpha channel
+                 mask_array = np.array(mask)
+                 mask_array = (mask_array > 0).astype(np.uint8) * 255
+                 mask = Image.fromarray(mask_array)
+                 image_for_generation = image
+             except Exception as e:
+                 st.error(f"Error creating mask: {e}")
+                 return
+
+         elif mode == "Outpainting":
+             valid_input = True
+             # image_for_generation and mask are already set above
+
+         if not valid_input:
+             st.error("Please draw a mask or configure outpainting settings")
+             return
+
+         # Create temporary files
+         with (
+             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img,
+             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_mask,
+         ):
+             try:
+                 image_for_generation.save(tmp_img.name)
+                 mask.save(tmp_mask.name)
+             except Exception as e:
+                 st.error(f"Error saving temporary files: {e}")
+                 return
+
+         try:
+             # Generate inpainting/outpainting
+             rng = torch.Generator(device="cpu")
+             if seed is None:
+                 seed = rng.seed()
+
+             print(f"Generating with seed {seed}:\n{prompt}")
+             t0 = time.perf_counter()
+
+             x = get_noise(
+                 1,
+                 height,
+                 width,
+                 device=torch_device,
+                 dtype=torch.bfloat16,
+                 seed=seed,
+             )
+
+             if offload:
+                 t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
+
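+             # prepare_fill encodes the conditioning image and mask with the
+             # autoencoder and packs them, together with the text embeddings,
+             # into the model inputs (see flux.sampling.prepare_fill)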
+             inp = prepare_fill(
+                 t5,
+                 clip,
+                 x,
+                 prompt=prompt,
+                 ae=ae,
+                 img_cond_path=tmp_img.name,
+                 mask_path=tmp_mask.name,
+             )
+
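+             # shift=True enables the resolution-dependent timestep shift
+             # (used for all non-schnell FLUX models)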
+             timesteps = get_schedule(num_steps, inp["img"].shape[1], shift=True)
+
+             if offload:
+                 t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
+                 torch.cuda.empty_cache()
+                 model = model.to(torch_device)
+
+             x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
+
+             if offload:
+                 model.cpu()
+                 torch.cuda.empty_cache()
+                 ae.decoder.to(x.device)
+
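+             # unpack rearranges the packed sequence of 2x2 latent patches back
+             # into a latent image that the autoencoder can decode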
+             x = unpack(x.float(), height, width)
+             with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+                 x = ae.decode(x)
+
+             t1 = time.perf_counter()
+             print(f"Done in {t1 - t0:.1f}s")
+
+             # Process and display result
+             x = x.clamp(-1, 1)
+             x = embed_watermark(x.float())
+             x = rearrange(x[0], "c h w -> h w c")
+             img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+
+             nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
+
+             if nsfw_score < NSFW_THRESHOLD:
+                 buffer = BytesIO()
+                 exif_data = Image.Exif()
+                 exif_data[ExifTags.Base.Software] = "AI generated;inpainting;flux"
+                 exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+                 exif_data[ExifTags.Base.Model] = name
+                 if add_sampling_metadata:
+                     exif_data[ExifTags.Base.ImageDescription] = prompt
+                 img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0)
+
+                 img_bytes = buffer.getvalue()
+                 if save_samples:
+                     fn = output_name.format(idx=idx)
+                     print(f"Saving {fn}")
+                     with open(fn, "wb") as file:
+                         file.write(img_bytes)
+
+                 st.session_state["samples"] = {
+                     "prompt": prompt,
+                     "img": img,
+                     "seed": seed,
+                     "bytes": img_bytes,
+                 }
+             else:
+                 st.warning("Your generated image may contain NSFW content.")
+                 st.session_state["samples"] = None
+
+         except Exception as e:
+             st.error(f"Error during generation: {e}")
+             return
+         finally:
+             # Clean up temporary files
+             try:
+                 os.unlink(tmp_img.name)
+                 os.unlink(tmp_mask.name)
+             except Exception as e:
+                 print(f"Error cleaning up temporary files: {e}")
+
+     # Display results
+     samples = st.session_state.get("samples", None)
+     if samples is not None:
+         st.image(samples["img"], caption=samples["prompt"])
+         col1, col2 = st.columns(2)
+         with col1:
+             st.download_button(
+                 "Download full-resolution",
+                 samples["bytes"],
+                 file_name="generated.jpg",
+                 mime="image/jpg",
+             )
+         with col2:
+             if st.button("Continue from this image"):
+                 # Store the generated image
+                 new_image = samples["img"]
+                 # Clear ALL canvas state
+                 clear_canvas_state()
+                 if "samples" in st.session_state:
+                     del st.session_state["samples"]
+                 # Set as current image
+                 st.session_state["current_image"] = new_image
+                 st.rerun()
+
+         st.write(f"Seed: {samples['seed']}")
+
+
+ if __name__ == "__main__":
+     st.set_page_config(layout="wide")
+     main()
data/flux/docs/fill.md ADDED
@@ -0,0 +1,44 @@
+ ## Models
+
+ FLUX.1 Fill introduces advanced inpainting and outpainting capabilities. It allows for seamless edits that integrate naturally with existing images.
+
+ | Name | HuggingFace repo | License | sha256sum |
+ | ------------------- | -------------------------------------------------------- | ---------------------------------------------------------------------- | ---------------------------------------------------------------- |
+ | `FLUX.1 Fill [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 03e289f530df51d014f48e675a9ffa2141bc003259bf5f25d75b957e920a41ca |
+ | `FLUX.1 Fill [pro]` | Only available in our API. |
+
+ ## Examples
+
+ ![inpainting](../assets/docs/inpainting.png)
+ ![outpainting](../assets/docs/outpainting.png)
+
+ ## Open-weights usage
+
+ The weights will be downloaded automatically from HuggingFace once you start one of the demos. To download `FLUX.1 Fill [dev]`, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). Alternatively, if you have downloaded the model weights manually from [here](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev), you can specify the downloaded paths via environment variables:
+
+ ```bash
+ export FLUX_DEV_FILL=<path_to_flux_dev_fill_sft_file>
+ export AE=<path_to_ae_sft_file>
+ ```
+
+ For interactive sampling, run
+
+ ```bash
+ python -m src.flux.cli_fill --loop
+ ```
+
+ Or to generate a single sample, run
+
+ ```bash
+ python -m src.flux.cli_fill \
+     --img_cond_path <path_to_input_image> \
+     --img_cond_mask <path_to_input_mask>
+ ```
+
+ The input mask should be an image of the same size as the conditioning image that contains only black and white pixels; see [an example mask](../assets/cup_mask.png) for [this image](../assets/cup.png).
+
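+ If you want to build such a mask programmatically, here is a minimal PIL sketch (the region coordinates and output file name are illustrative; white marks the area to be regenerated):
+
+ ```python
+ from PIL import Image
+
+ img = Image.open("assets/cup.png")
+ # Start from an all-black mask (keep every pixel) ...
+ mask = Image.new("L", img.size, 0)
+ # ... then paint the region to fill in white: (left, upper, right, lower)
+ mask.paste(255, (100, 100, 300, 300))
+ mask.save("my_mask.png")  # pass this via --img_cond_mask
+ ```
+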
+ We also provide an interactive streamlit demo. The demo can be run via
+
+ ```bash
+ streamlit run demo_st_fill.py
+ ```
data/flux/docs/image-variation.md ADDED
@@ -0,0 +1,33 @@
+ ## Models
+
+ FLUX.1 Redux is an adapter for the FLUX.1 text-to-image base models, FLUX.1 [dev] and FLUX.1 [schnell], which can be used to generate image variations.
+ In addition, FLUX.1 Redux [pro] is available in our API; going beyond the [dev] adapter, the API endpoint also allows users to modify an image given a textual description. The feature is supported in our latest model, FLUX1.1 [pro] Ultra, which combines input images and text prompts to create high-quality 4-megapixel outputs with flexible aspect ratios.
+
+ | Name | HuggingFace repo | License | sha256sum |
+ | --------------------------- | ----------------------------------------------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
+ | `FLUX.1 Redux [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | a1b3bdcb4bdc58ce04874b9ca776d61fc3e914bb6beab41efb63e4e2694dca45 |
+ | `FLUX.1 Redux [pro]` | [Available in our API.](https://docs.bfl.ml/) Supports image variations. |
+ | `FLUX1.1 Redux [pro] Ultra` | [Available in our API.](https://docs.bfl.ml/) Supports image variations based on a text prompt. |
+
+ ## Examples
+
+ ![redux](../assets/docs/redux.png)
+
+ ## Open-weights usage
+
+ The text-to-image base model weights and the autoencoder weights will be downloaded automatically from HuggingFace once you start the demo. To download `FLUX.1 [dev]`, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). You need to manually download the adapter weights from [here](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) and specify them via the environment variable `export FLUX_REDUX=<path_to_flux_redux_sft_file>`. In general, you may specify any manually downloaded weights via environment variables:
+
+ ```bash
+ export FLUX_REDUX=<path_to_flux_redux_sft_file>
+ export FLUX_SCHNELL=<path_to_flux_schnell_sft_file>
+ export FLUX_DEV=<path_to_flux_dev_sft_file>
+ export AE=<path_to_ae_sft_file>
+ ```
+
+ For interactive sampling, run
+
+ ```bash
+ python -m src.flux.cli_redux --loop --name <name>
+ ```
+
+ where `name` is one of `flux-dev` or `flux-schnell`.
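+
+ To generate a single variation of a specific image, you can pass the conditioning image directly; the `--img_cond_path` flag below mirrors the other CLI demos (verify the exact options via `python -m src.flux.cli_redux --help`):
+
+ ```bash
+ python -m src.flux.cli_redux --name <name> \
+     --img_cond_path <path_to_input_image>
+ ```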
data/flux/docs/structural-conditioning.md ADDED
@@ -0,0 +1,40 @@
+ ## Models
+
+ Structural conditioning uses canny edge or depth detection to maintain precise control during image transformations. By preserving the original image's structure through edge or depth maps, users can make text-guided edits while keeping the core composition intact. This is particularly effective for retexturing images. We release four variations: two based on edge maps (full model and LoRA for FLUX.1 [dev]) and two based on depth maps (full model and LoRA for FLUX.1 [dev]).
+
+ | Name | HuggingFace repo | License | sha256sum |
+ | ------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
+ | `FLUX.1 Canny [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 996876670169591cb412b937fbd46ea14cbed6933aef17c48a2dcd9685c98cdb |
+ | `FLUX.1 Depth [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 41360d1662f44ca45bc1b665fe6387e91802f53911001630d970a4f8be8dac21 |
+ | `FLUX.1 Canny [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 8eaa21b9c43d5e7242844deb64b8cf22ae9010f813f955ca8c05f240b8a98f7e |
+ | `FLUX.1 Depth [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 1938b38ea0fdd98080fa3e48beb2bedfbc7ad102d8b65e6614de704a46d8b907 |
+ | `FLUX.1 Canny [pro]` | [Available in our API](https://docs.bfl.ml/). |
+ | `FLUX.1 Depth [pro]` | [Available in our API](https://docs.bfl.ml/). |
+
+ ## Examples
+
+ ![canny](../assets/docs/canny.png)
+ ![depth](../assets/docs/depth.png)
+
+ ## Open-weights usage
+
+ The full model weights (`FLUX.1 Canny [dev]`, `FLUX.1 Depth [dev]`, `FLUX.1 [dev]`, and the autoencoder) will be downloaded automatically from HuggingFace once you start one of the demos. To download them, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). The LoRA weights are not downloaded automatically, but can be downloaded manually [here (Canny)](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) and [here (Depth)](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora). You may specify any manually downloaded weights via environment variables (**necessary for LoRAs**):
+
+ ```bash
+ export FLUX_DEV_DEPTH=<path_to_flux_dev_depth_sft_file>
+ export FLUX_DEV_CANNY=<path_to_flux_dev_canny_sft_file>
+ export FLUX_DEV_DEPTH_LORA=<path_to_flux_dev_depth_lora_sft_file>
+ export FLUX_DEV_CANNY_LORA=<path_to_flux_dev_canny_lora_sft_file>
+ export FLUX_REDUX=<path_to_flux_redux_sft_file>
+ export FLUX_SCHNELL=<path_to_flux_schnell_sft_file>
+ export FLUX_DEV=<path_to_flux_dev_sft_file>
+ export AE=<path_to_ae_sft_file>
+ ```
+
+ For interactive sampling, run
+
+ ```bash
+ python -m src.flux.cli_control --loop --name <name>
+ ```
+
+ where `name` is one of `flux-dev-canny`, `flux-dev-depth`, `flux-dev-canny-lora`, or `flux-dev-depth-lora`.
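+
+ To generate a single sample from a prompt and a conditioning image, the flags below mirror the other CLI demos (verify the exact options via `python -m src.flux.cli_control --help`):
+
+ ```bash
+ python -m src.flux.cli_control --name flux-dev-depth \
+     --img_cond_path <path_to_input_image> \
+     --prompt "<prompt>"
+ ```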
data/flux/docs/text-to-image.md ADDED
@@ -0,0 +1,93 @@
+ ## Models
+
+ We currently offer the following text-to-image models. `FLUX1.1 [pro]` is our most capable model, which can generate images at up to 4MP while maintaining an impressive generation time of only 10 seconds per sample.
+
+ | Name | HuggingFace repo | License | sha256sum |
+ | ------------------------- | ------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
+ | `FLUX.1 [schnell]` | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell) | 9403429e0052277ac2a87ad800adece5481eecefd9ed334e1f348723621d2a0a |
+ | `FLUX.1 [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 4610115bb0c89560703c892c59ac2742fa821e60ef5871b33493ba544683abd7 |
+ | `FLUX.1 [pro]` | [Available in our API](https://docs.bfl.ml/). |
+ | `FLUX1.1 [pro]` | [Available in our API](https://docs.bfl.ml/). |
+ | `FLUX1.1 [pro] Ultra/raw` | [Available in our API](https://docs.bfl.ml/). |
+
+ ## Open-weights usage
+
+ The weights will be downloaded automatically from HuggingFace once you start one of the demos. To download `FLUX.1 [dev]`, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
+ If you have downloaded the model weights manually, you can specify the downloaded paths via environment variables:
+
+ ```bash
+ export FLUX_SCHNELL=<path_to_flux_schnell_sft_file>
+ export FLUX_DEV=<path_to_flux_dev_sft_file>
+ export AE=<path_to_ae_sft_file>
+ ```
+
+ For interactive sampling, run
+
+ ```bash
+ python -m flux --name <name> --loop
+ ```
+
+ Or to generate a single sample, run
+
+ ```bash
+ python -m flux --name <name> \
+     --height <height> --width <width> \
+     --prompt "<prompt>"
+ ```
+
+ We also provide a streamlit demo that does both text-to-image and image-to-image. The demo can be run via
+
+ ```bash
+ streamlit run demo_st.py
+ ```
+
+ We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo:
+
+ ```bash
+ python demo_gr.py --name flux-schnell --device cuda
+ ```
+
+ Options:
+
+ - `--name`: Choose the model to use (options: "flux-schnell", "flux-dev")
+ - `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu")
+ - `--offload`: Offload model to CPU when not in use
+ - `--share`: Create a public link to your demo
+
+ To run the demo with the dev model and create a public link:
+
+ ```bash
+ python demo_gr.py --name flux-dev --share
+ ```
+
+ ## Diffusers integration
+
+ `FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use them with diffusers, install it:
+
+ ```shell
+ pip install git+https://github.com/huggingface/diffusers.git
+ ```
+
+ Then you can use `FluxPipeline` to run the model:
+
+ ```python
+ import torch
+ from diffusers import FluxPipeline
+
+ model_id = "black-forest-labs/FLUX.1-schnell"  # you can also use "black-forest-labs/FLUX.1-dev"
+
+ pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU; remove this if you have enough GPU memory
+
+ prompt = "A cat holding a sign that says hello world"
+ seed = 42
+ image = pipe(
+     prompt,
+     output_type="pil",
+     num_inference_steps=4,  # use a larger number if you are using [dev]
+     generator=torch.Generator("cpu").manual_seed(seed)
+ ).images[0]
+ image.save("flux-schnell.png")
+ ```
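+
+ If you run `FLUX.1 [dev]` instead, the distilled guidance is exposed as a regular pipeline argument. Below is a variant of the call above with settings commonly used for [dev]; treat the exact values as starting points rather than requirements:
+
+ ```python
+ image = pipe(
+     prompt,
+     output_type="pil",
+     num_inference_steps=50,  # [dev] needs more steps than [schnell]
+     guidance_scale=3.5,  # distilled guidance value commonly used with [dev]
+     generator=torch.Generator("cpu").manual_seed(seed)
+ ).images[0]
+ ```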
+
+ To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation.
data/flux/model_cards/FLUX.1-dev.md ADDED
@@ -0,0 +1,46 @@
+ ![FLUX.1 [dev] Grid](../assets/dev_grid.jpg)
+
+ `FLUX.1 [dev]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
+ For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).
+
+ # Key Features
+ 1. Cutting-edge output quality, second only to our state-of-the-art model `FLUX.1 [pro]`.
+ 2. Competitive prompt following, matching the performance of closed source alternatives.
+ 3. Trained using guidance distillation, making `FLUX.1 [dev]` more efficient.
+ 4. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
+ 5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [flux-1-dev-non-commercial-license](./licence.md).
+
+ # Usage
+ We provide a reference implementation of `FLUX.1 [dev]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
+ Developers and creatives looking to build on top of `FLUX.1 [dev]` are encouraged to use this as a starting point.
+
+ ## API Endpoints
+ The FLUX.1 models are also available via API from the following sources
+ 1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`)
+ 2. [replicate.com](https://replicate.com/collections/flux)
+ 3. [fal.ai](https://fal.ai/models/fal-ai/flux/dev)
+
+ ## ComfyUI
+ `FLUX.1 [dev]` is also available in [ComfyUI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow.
+
+ ---
+ # Limitations
+ - This model is not intended or able to provide factual information.
+ - As a statistical model this checkpoint might amplify existing societal biases.
+ - The model may fail to generate output that matches the prompts.
+ - Prompt following is heavily influenced by the prompting style.
+
+ # Out-of-Scope Use
+ The model and its derivatives may not be used
+
+ - In any way that violates any applicable national, federal, state, local or international law or regulation.
+ - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
+ - To generate or disseminate verifiably false information and/or content with the purpose of harming others.
+ - To generate or disseminate personal identifiable information that can be used to harm an individual.
+ - To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
+ - To create non-consensual nudity or illegal pornographic content.
+ - For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
+ - Generating or facilitating large-scale disinformation campaigns.
+
+ # License
+ This model falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
data/flux/model_cards/FLUX.1-schnell.md ADDED
@@ -0,0 +1,41 @@
+ ![FLUX.1 [schnell] Grid](../assets/schnell_grid.jpg)
+
+ `FLUX.1 [schnell]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
+ For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).
+
+ # Key Features
+ 1. Cutting-edge output quality and competitive prompt following, matching the performance of closed source alternatives.
+ 2. Trained using latent adversarial diffusion distillation, `FLUX.1 [schnell]` can generate high-quality images in only 1 to 4 steps.
+ 3. Released under the `apache-2.0` licence, the model can be used for personal, scientific, and commercial purposes.
+
+ # Usage
+ We provide a reference implementation of `FLUX.1 [schnell]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
+ Developers and creatives looking to build on top of `FLUX.1 [schnell]` are encouraged to use this as a starting point.
+
+ ## API Endpoints
+ The FLUX.1 models are also available via API from the following sources
+ 1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`)
+ 2. [replicate.com](https://replicate.com/collections/flux)
+ 3. [fal.ai](https://fal.ai/models/fal-ai/flux/schnell)
+
+ ## ComfyUI
+ `FLUX.1 [schnell]` is also available in [ComfyUI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow.
+
+ ---
+ # Limitations
+ - This model is not intended or able to provide factual information.
+ - As a statistical model this checkpoint might amplify existing societal biases.
+ - The model may fail to generate output that matches the prompts.
+ - Prompt following is heavily influenced by the prompting style.
+
+ # Out-of-Scope Use
+ The model and its derivatives may not be used
+
+ - In any way that violates any applicable national, federal, state, local or international law or regulation.
+ - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
+ - To generate or disseminate verifiably false information and/or content with the purpose of harming others.
+ - To generate or disseminate personal identifiable information that can be used to harm an individual.
+ - To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
+ - To create non-consensual nudity or illegal pornographic content.
+ - For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
+ - Generating or facilitating large-scale disinformation campaigns.
data/flux/model_licenses/LICENSE-FLUX1-dev ADDED
@@ -0,0 +1,42 @@
+ FLUX.1 [dev] Non-Commercial License
+ Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models, including FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA and FLUX.1 Depth [dev] LoRA, and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”).
+ By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity.
+ 1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings:
+ a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
+ b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be.
+ c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output: (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment, (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use for revenue-generating activity or direct interactions with or impacts on end users, or use to train, fine tune or distill other models for commercial use is not a Non-Commercial purpose.
+ d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters.
+ e. “you” or “your” means the individual or entity entering into this License with Company.
+ 2. License Grant.
+ a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein in regarding the FLUX.1 [dev] Model also applies to any Derivative you create or that are created on your behalf.
+ b. Non-Commercial Use Only. You may only access, use, Distribute, or creative Derivatives of or the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If You want to use a FLUX.1 [dev] Model a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please contact Company at the following e-mail address if you want to discuss such a license: info@blackforestlabs.ai.
+ c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
+ d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model.
+ 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions:
+ a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
+ b. you must make prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”):
+ “The FLUX.1 [dev] Model is licensed by Black Forest Labs. Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs. Inc.
+ IN NO EVENT SHALL BLACK FOREST LABS, INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”
+ c. in the case of Distribution of Derivatives made by you, you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; and
+ d. in the case of Distribution of Derivatives made by you, any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions.
+ e. In the case of Distribution of Derivatives made by you, you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.
+ 4. Restrictions. You will not, and will not permit, assist or cause any third party to
+ a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
+ b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model;
+ c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; or
+ d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License.
+ e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model;
+ f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
+ 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL IS PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
+ 6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
+ 7. INDEMNIFICATION
+
+ You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (as well as any Output, results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.
+ 8. Termination; Survival.
+ a. This License will automatically terminate upon any breach by you of the terms of this License.
+ b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
+ c. If You initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model or any Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
+ d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model and any Derivatives. The following sections survive termination of this License 2(c), 2(d), 4-11.
+ 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
+ 10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name or mark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators.
+ 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Company.
data/flux/model_licenses/LICENSE-FLUX1-schnell ADDED
@@ -0,0 +1,54 @@
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
data/flux/pyproject.toml ADDED
@@ -0,0 +1,99 @@
1
+ [project]
2
+ name = "flux"
3
+ authors = [
4
+ { name = "Black Forest Labs", email = "support@blackforestlabs.ai" },
5
+ ]
6
+ description = "Inference codebase for FLUX"
7
+ readme = "README.md"
8
+ requires-python = ">=3.10"
9
+ license = { file = "LICENSE.md" }
10
+ dynamic = ["version"]
11
+ dependencies = [
12
+ "torch == 2.5.1",
13
+ "torchvision",
14
+ "einops",
15
+ "fire >= 0.6.0",
16
+ "huggingface-hub",
17
+ "safetensors",
18
+ "sentencepiece",
19
+ "transformers",
20
+ "tokenizers",
21
+ "protobuf",
22
+ "requests",
23
+ "invisible-watermark",
24
+ "ruff == 0.6.8",
25
+ ]
26
+
27
+ [project.optional-dependencies]
28
+ streamlit = [
29
+ "streamlit",
30
+ "streamlit-drawable-canvas",
31
+ "streamlit-keyup",
32
+ ]
33
+ gradio = [
34
+ "gradio",
35
+ ]
36
+ all = [
37
+ "flux[streamlit]",
38
+ "flux[gradio]",
39
+ ]
40
+
41
+ [project.scripts]
42
+ flux = "flux.cli:app"
43
+
44
+ [build-system]
45
+ build-backend = "setuptools.build_meta"
46
+ requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
47
+
48
+ [tool.ruff]
49
+ line-length = 110
50
+ target-version = "py310"
51
+ extend-exclude = ["/usr/lib/*"]
52
+
53
+ [tool.ruff.lint]
54
+ ignore = [
55
+ "E501", # line too long - will be fixed in format
56
+ ]
57
+
58
+ [tool.ruff.format]
59
+ quote-style = "double"
60
+ indent-style = "space"
61
+ line-ending = "auto"
62
+ skip-magic-trailing-comma = false
63
+ docstring-code-format = true
64
+ exclude = [
65
+ "src/flux/_version.py", # generated by setuptools_scm
66
+ ]
67
+
68
+ [tool.ruff.lint.isort]
69
+ combine-as-imports = true
70
+ force-wrap-aliases = true
71
+ known-local-folder = ["src"]
72
+ known-first-party = ["flux"]
73
+
74
+ [tool.pyright]
75
+ include = ["src"]
76
+ exclude = [
77
+ "**/__pycache__", # cache directories
78
+ "./typings", # generated type stubs
79
+ ]
80
+ stubPath = "./typings"
81
+
82
+ [tool.tomlsort]
83
+ in_place = true
84
+ no_sort_tables = true
85
+ spaces_before_inline_comment = 1
86
+ spaces_indent_inline_array = 2
87
+ trailing_comma_inline_array = true
88
+ sort_first = [
89
+ "project",
90
+ "build-system",
91
+ "tool.setuptools",
92
+ ]
93
+
94
+ # needs to be last for CI reasons
95
+ [tool.setuptools_scm]
96
+ write_to = "src/flux/_version.py"
97
+ parentdir_prefix_version = "flux-"
98
+ fallback_version = "0.0.0"
99
+ version_scheme = "post-release"
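Note: the `[project.optional-dependencies]` table above defines `streamlit`, `gradio`, and an umbrella `all` extra, so the demos install with e.g. `pip install -e ".[all]"`. A minimal post-install check (a sketch; it only assumes the version file that `[tool.setuptools_scm]` writes to `src/flux/_version.py`, plus the attributes defined in `src/flux/__init__.py` below):

```python
# Sanity check after installing the package; __version__ and PACKAGE_ROOT
# are defined in src/flux/__init__.py.
import flux

print(flux.__version__)   # falls back to "0.0.0" when no git metadata exists
print(flux.PACKAGE_ROOT)  # on-disk location of the installed package
```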
data/flux/setup.py ADDED
@@ -0,0 +1,3 @@
1
+ import setuptools
2
+
3
+ setuptools.setup()
data/flux/src/flux/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ try:
2
+ from ._version import (
3
+ version as __version__, # type: ignore
4
+ version_tuple,
5
+ )
6
+ except ImportError:
7
+ __version__ = "unknown (no version information available)"
8
+ version_tuple = (0, 0, "unknown", "noinfo")
9
+
10
+ from pathlib import Path
11
+
12
+ PACKAGE = __package__.replace("_", "-")
13
+ PACKAGE_ROOT = Path(__file__).parent
data/flux/src/flux/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .cli import app
2
+
3
+ if __name__ == "__main__":
4
+ app()
data/flux/src/flux/api.py ADDED
@@ -0,0 +1,225 @@
1
+ import io
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+ API_URL = "https://api.bfl.ml"
10
+ API_ENDPOINTS = {
11
+ "flux.1-pro": "flux-pro",
12
+ "flux.1-dev": "flux-dev",
13
+ "flux.1.1-pro": "flux-pro-1.1",
14
+ }
15
+
16
+
17
+ class ApiException(Exception):
18
+ def __init__(self, status_code: int, detail: str | list[dict] | None = None):
19
+ super().__init__()
20
+ self.detail = detail
21
+ self.status_code = status_code
22
+
23
+ def __str__(self) -> str:
24
+ return self.__repr__()
25
+
26
+ def __repr__(self) -> str:
27
+ if self.detail is None:
28
+ message = None
29
+ elif isinstance(self.detail, str):
30
+ message = self.detail
31
+ else:
32
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
33
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
34
+
35
+
36
+ class ImageRequest:
37
+ def __init__(
38
+ self,
39
+ # api inputs
40
+ prompt: str,
41
+ name: str = "flux.1.1-pro",
42
+ width: int | None = None,
43
+ height: int | None = None,
44
+ num_steps: int | None = None,
45
+ prompt_upsampling: bool | None = None,
46
+ seed: int | None = None,
47
+ guidance: float | None = None,
48
+ interval: float | None = None,
49
+ safety_tolerance: int | None = None,
50
+ # behavior of this class
51
+ validate: bool = True,
52
+ launch: bool = True,
53
+ api_key: str | None = None,
54
+ ):
55
+ """
56
+ Manages an image generation request to the API.
57
+
58
+ All parameters not specified will use the API defaults.
59
+
60
+ Args:
61
+ prompt: Text prompt for image generation.
62
+ width: Width of the generated image in pixels. Must be a multiple of 32.
63
+ height: Height of the generated image in pixels. Must be a multiple of 32.
64
+ name: Which model version to use
65
+ num_steps: Number of steps for the image generation process.
66
+ prompt_upsampling: Whether to perform upsampling on the prompt.
67
+ seed: Optional seed for reproducibility.
68
+ guidance: Guidance scale for image generation.
69
+ safety_tolerance: Tolerance level for input and output moderation.
70
+ Between 0 and 6, 0 being most strict, 6 being least strict.
71
+ validate: Run input validation
72
+ launch: Directly launches request
73
+ api_key: Your API key if not provided by the environment
74
+
75
+ Raises:
76
+ ValueError: For invalid input, when `validate`
77
+ ApiException: For errors raised from the API
78
+ """
79
+ if validate:
80
+ if name not in API_ENDPOINTS:
81
+ raise ValueError(f"Invalid model {name}")
82
+ elif width is not None and width % 32 != 0:
83
+ raise ValueError(f"width must be divisible by 32, got {width}")
84
+ elif width is not None and not (256 <= width <= 1440):
85
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
86
+ elif height is not None and height % 32 != 0:
87
+ raise ValueError(f"height must be divisible by 32, got {height}")
88
+ elif height is not None and not (256 <= height <= 1440):
89
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
90
+ elif num_steps is not None and not (1 <= num_steps <= 50):
91
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
92
+ elif guidance is not None and not (1.5 <= guidance <= 5.0):
93
+ raise ValueError(f"guidance must be between 1.5 and 5, got {guidance}")
94
+ elif interval is not None and not (1.0 <= interval <= 4.0):
95
+ raise ValueError(f"interval must be between 1 and 4, got {interval}")
96
+ elif safety_tolerance is not None and not (0 <= safety_tolerance <= 6.0):
97
+ raise ValueError(f"safety_tolerance must be between 0 and 6, got {safety_tolerance}")
98
+
99
+ if name == "flux.1-dev":
100
+ if interval is not None:
101
+ raise ValueError("Interval is not supported for flux.1-dev")
102
+ if name == "flux.1.1-pro":
103
+ if interval is not None or num_steps is not None or guidance is not None:
104
+ raise ValueError("Interval, num_steps and guidance are not supported for flux.1.1-pro")
105
+
106
+ self.name = name
107
+ self.request_json = {
108
+ "prompt": prompt,
109
+ "width": width,
110
+ "height": height,
111
+ "steps": num_steps,
112
+ "prompt_upsampling": prompt_upsampling,
113
+ "seed": seed,
114
+ "guidance": guidance,
115
+ "interval": interval,
116
+ "safety_tolerance": safety_tolerance,
117
+ }
118
+ self.request_json = {key: value for key, value in self.request_json.items() if value is not None}
119
+
120
+ self.request_id: str | None = None
121
+ self.result: dict | None = None
122
+ self._image_bytes: bytes | None = None
123
+ self._url: str | None = None
124
+ if api_key is None:
125
+ self.api_key = os.environ.get("BFL_API_KEY")
126
+ else:
127
+ self.api_key = api_key
128
+
129
+ if launch:
130
+ self.request()
131
+
132
+ def request(self):
133
+ """
134
+ Request to generate the image.
135
+ """
136
+ if self.request_id is not None:
137
+ return
138
+ response = requests.post(
139
+ f"{API_URL}/v1/{API_ENDPOINTS[self.name]}",
140
+ headers={
141
+ "accept": "application/json",
142
+ "x-key": self.api_key,
143
+ "Content-Type": "application/json",
144
+ },
145
+ json=self.request_json,
146
+ )
147
+ result = response.json()
148
+ if response.status_code != 200:
149
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
150
+ self.request_id = result["id"]
151
+
152
+ def retrieve(self) -> dict:
153
+ """
154
+ Wait for the generation to finish and retrieve response.
155
+ """
156
+ if self.request_id is None:
157
+ self.request()
158
+ while self.result is None:
159
+ response = requests.get(
160
+ f"{API_URL}/v1/get_result",
161
+ headers={
162
+ "accept": "application/json",
163
+ "x-key": self.api_key,
164
+ },
165
+ params={
166
+ "id": self.request_id,
167
+ },
168
+ )
169
+ result = response.json()
170
+ if "status" not in result:
171
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
172
+ elif result["status"] == "Ready":
173
+ self.result = result["result"]
174
+ elif result["status"] == "Pending":
175
+ time.sleep(0.5)
176
+ else:
177
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
178
+ return self.result
179
+
180
+ @property
181
+ def bytes(self) -> bytes:
182
+ """
183
+ Generated image as bytes.
184
+ """
185
+ if self._image_bytes is None:
186
+ response = requests.get(self.url)
187
+ if response.status_code == 200:
188
+ self._image_bytes = response.content
189
+ else:
190
+ raise ApiException(status_code=response.status_code)
191
+ return self._image_bytes
192
+
193
+ @property
194
+ def url(self) -> str:
195
+ """
196
+ Public url to retrieve the image from
197
+ """
198
+ if self._url is None:
199
+ result = self.retrieve()
200
+ self._url = result["sample"]
201
+ return self._url
202
+
203
+ @property
204
+ def image(self) -> Image.Image:
205
+ """
206
+ Load the image as a PIL Image
207
+ """
208
+ return Image.open(io.BytesIO(self.bytes))
209
+
210
+ def save(self, path: str):
211
+ """
212
+ Save the generated image to a local path
213
+ """
214
+ suffix = Path(self.url).suffix
215
+ if not path.endswith(suffix):
216
+ path = path + suffix
217
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
218
+ with open(path, "wb") as file:
219
+ file.write(self.bytes)
220
+
221
+
222
+ if __name__ == "__main__":
223
+ from fire import Fire
224
+
225
+ Fire(ImageRequest)
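A minimal usage sketch for the client above (it assumes `BFL_API_KEY` is set in the environment and the API is reachable; the prompt and output path are made up):

```python
# Request an image from the hosted API and save it locally.
# ImageRequest launches the request in __init__ (launch=True by default);
# .save() blocks inside retrieve() until the result is ready.
from flux.api import ImageRequest

request = ImageRequest(
    prompt="a cat wearing a wizard hat",  # hypothetical prompt
    width=1024,   # must be a multiple of 32 within [256, 1440]
    height=768,
)
request.save("output/wizard_cat")  # the file suffix is taken from the result URL
```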
data/flux/src/flux/cli.py ADDED
@@ -0,0 +1,238 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from fire import Fire
9
+ from transformers import pipeline
10
+
11
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
12
+ from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
13
+
14
+ NSFW_THRESHOLD = 0.85
15
+
16
+
17
+ @dataclass
18
+ class SamplingOptions:
19
+ prompt: str
20
+ width: int
21
+ height: int
22
+ num_steps: int
23
+ guidance: float
24
+ seed: int | None
25
+
26
+
27
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
28
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
29
+ usage = (
30
+ "Usage: Either write your prompt directly, leave this field empty "
31
+ "to repeat the prompt or write a command starting with a slash:\n"
32
+ "- '/w <width>' will set the width of the generated image\n"
33
+ "- '/h <height>' will set the height of the generated image\n"
34
+ "- '/s <seed>' sets the next seed\n"
35
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
36
+ "- '/n <steps>' sets the number of steps\n"
37
+ "- '/q' to quit"
38
+ )
39
+
40
+ while (prompt := input(user_question)).startswith("/"):
41
+ if prompt.startswith("/w"):
42
+ if prompt.count(" ") != 1:
43
+ print(f"Got invalid command '{prompt}'\n{usage}")
44
+ continue
45
+ _, width = prompt.split()
46
+ options.width = 16 * (int(width) // 16)
47
+ print(
48
+ f"Setting resolution to {options.width} x {options.height} "
49
+ f"({options.height * options.width / 1e6:.2f}MP)"
50
+ )
51
+ elif prompt.startswith("/h"):
52
+ if prompt.count(" ") != 1:
53
+ print(f"Got invalid command '{prompt}'\n{usage}")
54
+ continue
55
+ _, height = prompt.split()
56
+ options.height = 16 * (int(height) // 16)
57
+ print(
58
+ f"Setting resolution to {options.width} x {options.height} "
59
+ f"({options.height * options.width / 1e6:.2f}MP)"
60
+ )
61
+ elif prompt.startswith("/g"):
62
+ if prompt.count(" ") != 1:
63
+ print(f"Got invalid command '{prompt}'\n{usage}")
64
+ continue
65
+ _, guidance = prompt.split()
66
+ options.guidance = float(guidance)
67
+ print(f"Setting guidance to {options.guidance}")
68
+ elif prompt.startswith("/s"):
69
+ if prompt.count(" ") != 1:
70
+ print(f"Got invalid command '{prompt}'\n{usage}")
71
+ continue
72
+ _, seed = prompt.split()
73
+ options.seed = int(seed)
74
+ print(f"Setting seed to {options.seed}")
75
+ elif prompt.startswith("/n"):
76
+ if prompt.count(" ") != 1:
77
+ print(f"Got invalid command '{prompt}'\n{usage}")
78
+ continue
79
+ _, steps = prompt.split()
80
+ options.num_steps = int(steps)
81
+ print(f"Setting number of steps to {options.num_steps}")
82
+ elif prompt.startswith("/q"):
83
+ print("Quitting")
84
+ return None
85
+ else:
86
+ if not prompt.startswith("/h"):
87
+ print(f"Got invalid command '{prompt}'\n{usage}")
88
+ print(usage)
89
+ if prompt != "":
90
+ options.prompt = prompt
91
+ return options
92
+
93
+
94
+ @torch.inference_mode()
95
+ def main(
96
+ name: str = "flux-schnell",
97
+ width: int = 1360,
98
+ height: int = 768,
99
+ seed: int | None = None,
100
+ prompt: str = (
101
+ "a photo of a forest with mist swirling around the tree trunks. The word "
102
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
103
+ ),
104
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
105
+ num_steps: int | None = None,
106
+ loop: bool = False,
107
+ guidance: float = 3.5,
108
+ offload: bool = False,
109
+ output_dir: str = "output",
110
+ add_sampling_metadata: bool = True,
111
+ ):
112
+ """
113
+ Sample the flux model. Either interactively (set `--loop`) or run for a
114
+ single image.
115
+
116
+ Args:
117
+ name: Name of the model to load
118
+ height: height of the sample in pixels (should be a multiple of 16)
119
+ width: width of the sample in pixels (should be a multiple of 16)
120
+ seed: Set a seed for sampling
121
+ output_dir: directory to save the output images in, `{idx}` in the file name will be replaced
122
+ by the index of the sample
123
+ prompt: Prompt used for sampling
124
+ device: Pytorch device
125
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
126
+ loop: start an interactive session and sample multiple times
127
+ guidance: guidance value used for guidance distillation
128
+ add_sampling_metadata: Add the prompt to the image Exif metadata
129
+ """
130
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
131
+
132
+ if name not in configs:
133
+ available = ", ".join(configs.keys())
134
+ raise ValueError(f"Got unknown model name: {name}, choose from {available}")
135
+
136
+ torch_device = torch.device(device)
137
+ if num_steps is None:
138
+ num_steps = 4 if name == "flux-schnell" else 50
139
+
140
+ # allow for packing and conversion to latent space
141
+ height = 16 * (height // 16)
142
+ width = 16 * (width // 16)
143
+
144
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
145
+ if not os.path.exists(output_dir):
146
+ os.makedirs(output_dir)
147
+ idx = 0
148
+ else:
149
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
150
+ if len(fns) > 0:
151
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
152
+ else:
153
+ idx = 0
154
+
155
+ # init all components
156
+ t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
157
+ clip = load_clip(torch_device)
158
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
159
+ ae = load_ae(name, device="cpu" if offload else torch_device)
160
+
161
+ rng = torch.Generator(device="cpu")
162
+ opts = SamplingOptions(
163
+ prompt=prompt,
164
+ width=width,
165
+ height=height,
166
+ num_steps=num_steps,
167
+ guidance=guidance,
168
+ seed=seed,
169
+ )
170
+
171
+ if loop:
172
+ opts = parse_prompt(opts)
173
+
174
+ while opts is not None:
175
+ if opts.seed is None:
176
+ opts.seed = rng.seed()
177
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
178
+ t0 = time.perf_counter()
179
+
180
+ # prepare input
181
+ x = get_noise(
182
+ 1,
183
+ opts.height,
184
+ opts.width,
185
+ device=torch_device,
186
+ dtype=torch.bfloat16,
187
+ seed=opts.seed,
188
+ )
189
+ opts.seed = None
190
+ if offload:
191
+ ae = ae.cpu()
192
+ torch.cuda.empty_cache()
193
+ t5, clip = t5.to(torch_device), clip.to(torch_device)
194
+ inp = prepare(t5, clip, x, prompt=opts.prompt)
195
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
196
+
197
+ # offload TEs to CPU, load model to gpu
198
+ if offload:
199
+ t5, clip = t5.cpu(), clip.cpu()
200
+ torch.cuda.empty_cache()
201
+ model = model.to(torch_device)
202
+
203
+ # denoise initial noise
204
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
205
+
206
+ # offload model, load autoencoder to gpu
207
+ if offload:
208
+ model.cpu()
209
+ torch.cuda.empty_cache()
210
+ ae.decoder.to(x.device)
211
+
212
+ # decode latents to pixel space
213
+ x = unpack(x.float(), opts.height, opts.width)
214
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
215
+ x = ae.decode(x)
216
+
217
+ if torch.cuda.is_available():
218
+ torch.cuda.synchronize()
219
+ t1 = time.perf_counter()
220
+
221
+ fn = output_name.format(idx=idx)
222
+ print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
223
+
224
+ idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
225
+
226
+ if loop:
227
+ print("-" * 80)
228
+ opts = parse_prompt(opts)
229
+ else:
230
+ opts = None
231
+
232
+
233
+ def app():
234
+ Fire(main)
235
+
236
+
237
+ if __name__ == "__main__":
238
+ app()
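For reference, `Fire(main)` maps command-line flags directly onto `main`'s keyword arguments, so `python -m flux --name flux-schnell --loop` and the direct call below are equivalent (a sketch; it assumes the model weights are available locally or via Hugging Face Hub):

```python
# Programmatic equivalent of the CLI entry point above.
from flux.cli import main

main(
    name="flux-schnell",
    prompt="a photo of a forest with mist",
    width=1360,
    height=768,
    num_steps=4,   # the schnell default
    offload=True,  # keep the flow model on CPU until it is needed
)
```

The control, fill, and redux CLIs below follow the same `Fire(main)` pattern with additional image-conditioning arguments.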
data/flux/src/flux/cli_control.py ADDED
@@ -0,0 +1,347 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from fire import Fire
9
+ from transformers import pipeline
10
+
11
+ from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
12
+ from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
13
+ from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
+
15
+
16
+ @dataclass
17
+ class SamplingOptions:
18
+ prompt: str
19
+ width: int
20
+ height: int
21
+ num_steps: int
22
+ guidance: float
23
+ seed: int | None
24
+ img_cond_path: str
25
+ lora_scale: float | None
26
+
27
+
28
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
29
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
30
+ usage = (
31
+ "Usage: Either write your prompt directly, leave this field empty "
32
+ "to repeat the prompt or write a command starting with a slash:\n"
33
+ "- '/w <width>' will set the width of the generated image\n"
34
+ "- '/h <height>' will set the height of the generated image\n"
35
+ "- '/s <seed>' sets the next seed\n"
36
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
37
+ "- '/n <steps>' sets the number of steps\n"
38
+ "- '/q' to quit"
39
+ )
40
+
41
+ while (prompt := input(user_question)).startswith("/"):
42
+ if prompt.startswith("/w"):
43
+ if prompt.count(" ") != 1:
44
+ print(f"Got invalid command '{prompt}'\n{usage}")
45
+ continue
46
+ _, width = prompt.split()
47
+ options.width = 16 * (int(width) // 16)
48
+ print(
49
+ f"Setting resolution to {options.width} x {options.height} "
50
+ f"({options.height * options.width / 1e6:.2f}MP)"
51
+ )
52
+ elif prompt.startswith("/h"):
53
+ if prompt.count(" ") != 1:
54
+ print(f"Got invalid command '{prompt}'\n{usage}")
55
+ continue
56
+ _, height = prompt.split()
57
+ options.height = 16 * (int(height) // 16)
58
+ print(
59
+ f"Setting resolution to {options.width} x {options.height} "
60
+ f"({options.height * options.width / 1e6:.2f}MP)"
61
+ )
62
+ elif prompt.startswith("/g"):
63
+ if prompt.count(" ") != 1:
64
+ print(f"Got invalid command '{prompt}'\n{usage}")
65
+ continue
66
+ _, guidance = prompt.split()
67
+ options.guidance = float(guidance)
68
+ print(f"Setting guidance to {options.guidance}")
69
+ elif prompt.startswith("/s"):
70
+ if prompt.count(" ") != 1:
71
+ print(f"Got invalid command '{prompt}'\n{usage}")
72
+ continue
73
+ _, seed = prompt.split()
74
+ options.seed = int(seed)
75
+ print(f"Setting seed to {options.seed}")
76
+ elif prompt.startswith("/n"):
77
+ if prompt.count(" ") != 1:
78
+ print(f"Got invalid command '{prompt}'\n{usage}")
79
+ continue
80
+ _, steps = prompt.split()
81
+ options.num_steps = int(steps)
82
+ print(f"Setting number of steps to {options.num_steps}")
83
+ elif prompt.startswith("/q"):
84
+ print("Quitting")
85
+ return None
86
+ else:
87
+ if not prompt.startswith("/h"):
88
+ print(f"Got invalid command '{prompt}'\n{usage}")
89
+ print(usage)
90
+ if prompt != "":
91
+ options.prompt = prompt
92
+ return options
93
+
94
+
95
+ def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
96
+ if options is None:
97
+ return None
98
+
99
+ user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
100
+ usage = (
101
+ "Usage: Either write the path to the conditioning image directly, leave this field empty "
102
+ "to repeat the conditioning image or write a command starting with a slash:\n"
103
+ "- '/q' to quit"
104
+ )
105
+
106
+ while True:
107
+ img_cond_path = input(user_question)
108
+
109
+ if img_cond_path.startswith("/"):
110
+ if img_cond_path.startswith("/q"):
111
+ print("Quitting")
112
+ return None
113
+ else:
114
+ if not img_cond_path.startswith("/h"):
115
+ print(f"Got invalid command '{img_cond_path}'\n{usage}")
116
+ print(usage)
117
+ continue
118
+
119
+ if img_cond_path == "":
120
+ break
121
+
122
+ if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
123
+ (".jpg", ".jpeg", ".png", ".webp")
124
+ ):
125
+ print(f"File '{img_cond_path}' does not exist or is not a valid image file")
126
+ continue
127
+
128
+ options.img_cond_path = img_cond_path
129
+ break
130
+
131
+ return options
132
+
133
+
134
+ def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]:
135
+ changed = False
136
+
137
+ if options is None:
138
+ return None, changed
139
+
140
+ user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n"
141
+ usage = (
142
+ "Usage: Either write the new lora scale directly, leave this field empty "
143
+ "to repeat the lora scale or write a command starting with a slash:\n"
144
+ "- '/q' to quit"
145
+ )
146
+
147
+ while (prompt := input(user_question)).startswith("/"):
148
+ if prompt.startswith("/q"):
149
+ print("Quitting")
150
+ return None, changed
151
+ else:
152
+ if not prompt.startswith("/h"):
153
+ print(f"Got invalid command '{prompt}'\n{usage}")
154
+ print(usage)
155
+ if prompt != "":
156
+ options.lora_scale = float(prompt)
157
+ changed = True
158
+ return options, changed
159
+
160
+
161
+ @torch.inference_mode()
162
+ def main(
163
+ name: str,
164
+ width: int = 1024,
165
+ height: int = 1024,
166
+ seed: int | None = None,
167
+ prompt: str = "a robot made out of gold",
168
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
169
+ num_steps: int = 50,
170
+ loop: bool = False,
171
+ guidance: float | None = None,
172
+ offload: bool = False,
173
+ output_dir: str = "output",
174
+ add_sampling_metadata: bool = True,
175
+ img_cond_path: str = "assets/robot.webp",
176
+ lora_scale: float | None = 0.85,
177
+ ):
178
+ """
179
+ Sample the flux model. Either interactively (set `--loop`) or run for a
180
+ single image.
181
+
182
+ Args:
183
+ height: height of the sample in pixels (should be a multiple of 16)
184
+ width: width of the sample in pixels (should be a multiple of 16)
185
+ seed: Set a seed for sampling
186
+ output_dir: directory to save the output images in, `{idx}` in the file name will be replaced
187
+ by the index of the sample
188
+ prompt: Prompt used for sampling
189
+ device: Pytorch device
190
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
191
+ loop: start an interactive session and sample multiple times
192
+ guidance: guidance value used for guidance distillation
193
+ add_sampling_metadata: Add the prompt to the image Exif metadata
194
+ img_cond_path: path to conditioning image (jpeg/png/webp)
195
+ """
196
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
197
+
198
+ assert name in [
199
+ "flux-dev-canny",
200
+ "flux-dev-depth",
201
+ "flux-dev-canny-lora",
202
+ "flux-dev-depth-lora",
203
+ ], f"Got unknown model name: {name}"
204
+ if guidance is None:
205
+ if name in ["flux-dev-canny", "flux-dev-canny-lora"]:
206
+ guidance = 30.0
207
+ elif name in ["flux-dev-depth", "flux-dev-depth-lora"]:
208
+ guidance = 10.0
209
+ else:
210
+ raise NotImplementedError()
211
+
212
+ if name not in configs:
213
+ available = ", ".join(configs.keys())
214
+ raise ValueError(f"Got unknown model name: {name}, choose from {available}")
215
+
216
+ torch_device = torch.device(device)
217
+
218
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
219
+ if not os.path.exists(output_dir):
220
+ os.makedirs(output_dir)
221
+ idx = 0
222
+ else:
223
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
224
+ if len(fns) > 0:
225
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
226
+ else:
227
+ idx = 0
228
+
229
+ # init all components
230
+ t5 = load_t5(torch_device, max_length=512)
231
+ clip = load_clip(torch_device)
232
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
233
+ ae = load_ae(name, device="cpu" if offload else torch_device)
234
+
235
+ # set lora scale
236
+ if "lora" in name and lora_scale is not None:
237
+ for _, module in model.named_modules():
238
+ if hasattr(module, "set_scale"):
239
+ module.set_scale(lora_scale)
240
+
241
+ if name in ["flux-dev-depth", "flux-dev-depth-lora"]:
242
+ img_embedder = DepthImageEncoder(torch_device)
243
+ elif name in ["flux-dev-canny", "flux-dev-canny-lora"]:
244
+ img_embedder = CannyImageEncoder(torch_device)
245
+ else:
246
+ raise NotImplementedError()
247
+
248
+ rng = torch.Generator(device="cpu")
249
+ opts = SamplingOptions(
250
+ prompt=prompt,
251
+ width=width,
252
+ height=height,
253
+ num_steps=num_steps,
254
+ guidance=guidance,
255
+ seed=seed,
256
+ img_cond_path=img_cond_path,
257
+ lora_scale=lora_scale,
258
+ )
259
+
260
+ if loop:
261
+ opts = parse_prompt(opts)
262
+ opts = parse_img_cond_path(opts)
263
+ if "lora" in name:
264
+ opts, changed = parse_lora_scale(opts)
265
+ if changed:
266
+ # update the lora scale:
267
+ for _, module in model.named_modules():
268
+ if hasattr(module, "set_scale"):
269
+ module.set_scale(opts.lora_scale)
270
+
271
+ while opts is not None:
272
+ if opts.seed is None:
273
+ opts.seed = rng.seed()
274
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
275
+ t0 = time.perf_counter()
276
+
277
+ # prepare input
278
+ x = get_noise(
279
+ 1,
280
+ opts.height,
281
+ opts.width,
282
+ device=torch_device,
283
+ dtype=torch.bfloat16,
284
+ seed=opts.seed,
285
+ )
286
+ opts.seed = None
287
+ if offload:
288
+ t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
289
+ inp = prepare_control(
290
+ t5,
291
+ clip,
292
+ x,
293
+ prompt=opts.prompt,
294
+ ae=ae,
295
+ encoder=img_embedder,
296
+ img_cond_path=opts.img_cond_path,
297
+ )
298
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
299
+
300
+ # offload TEs and AE to CPU, load model to gpu
301
+ if offload:
302
+ t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
303
+ torch.cuda.empty_cache()
304
+ model = model.to(torch_device)
305
+
306
+ # denoise initial noise
307
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
308
+
309
+ # offload model, load autoencoder to gpu
310
+ if offload:
311
+ model.cpu()
312
+ torch.cuda.empty_cache()
313
+ ae.decoder.to(x.device)
314
+
315
+ # decode latents to pixel space
316
+ x = unpack(x.float(), opts.height, opts.width)
317
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
318
+ x = ae.decode(x)
319
+
320
+ if torch.cuda.is_available():
321
+ torch.cuda.synchronize()
322
+ t1 = time.perf_counter()
323
+ print(f"Done in {t1 - t0:.1f}s")
324
+
325
+ idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
326
+
327
+ if loop:
328
+ print("-" * 80)
329
+ opts = parse_prompt(opts)
330
+ opts = parse_img_cond_path(opts)
331
+ if "lora" in name:
332
+ opts, changed = parse_lora_scale(opts)
333
+ if changed:
334
+ # update the lora scale:
335
+ for _, module in model.named_modules():
336
+ if hasattr(module, "set_scale"):
337
+ module.set_scale(opts.lora_scale)
338
+ else:
339
+ opts = None
340
+
341
+
342
+ def app():
343
+ Fire(main)
344
+
345
+
346
+ if __name__ == "__main__":
347
+ app()
data/flux/src/flux/cli_fill.py ADDED
@@ -0,0 +1,334 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from fire import Fire
9
+ from PIL import Image
10
+ from transformers import pipeline
11
+
12
+ from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
13
+ from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
+
15
+
16
+ @dataclass
17
+ class SamplingOptions:
18
+ prompt: str
19
+ width: int
20
+ height: int
21
+ num_steps: int
22
+ guidance: float
23
+ seed: int | None
24
+ img_cond_path: str
25
+ img_mask_path: str
26
+
27
+
28
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
29
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
30
+ usage = (
31
+ "Usage: Either write your prompt directly, leave this field empty "
32
+ "to repeat the prompt or write a command starting with a slash:\n"
33
+ "- '/s <seed>' sets the next seed\n"
34
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
35
+ "- '/n <steps>' sets the number of steps\n"
36
+ "- '/q' to quit"
37
+ )
38
+
39
+ while (prompt := input(user_question)).startswith("/"):
40
+ if prompt.startswith("/g"):
41
+ if prompt.count(" ") != 1:
42
+ print(f"Got invalid command '{prompt}'\n{usage}")
43
+ continue
44
+ _, guidance = prompt.split()
45
+ options.guidance = float(guidance)
46
+ print(f"Setting guidance to {options.guidance}")
47
+ elif prompt.startswith("/s"):
48
+ if prompt.count(" ") != 1:
49
+ print(f"Got invalid command '{prompt}'\n{usage}")
50
+ continue
51
+ _, seed = prompt.split()
52
+ options.seed = int(seed)
53
+ print(f"Setting seed to {options.seed}")
54
+ elif prompt.startswith("/n"):
55
+ if prompt.count(" ") != 1:
56
+ print(f"Got invalid command '{prompt}'\n{usage}")
57
+ continue
58
+ _, steps = prompt.split()
59
+ options.num_steps = int(steps)
60
+ print(f"Setting number of steps to {options.num_steps}")
61
+ elif prompt.startswith("/q"):
62
+ print("Quitting")
63
+ return None
64
+ else:
65
+ if not prompt.startswith("/h"):
66
+ print(f"Got invalid command '{prompt}'\n{usage}")
67
+ print(usage)
68
+ if prompt != "":
69
+ options.prompt = prompt
70
+ return options
71
+
72
+
73
+ def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
74
+ if options is None:
75
+ return None
76
+
77
+ user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
78
+ usage = (
79
+ "Usage: Either write the path to the conditioning image directly, leave this field empty "
80
+ "to repeat the conditioning image or write a command starting with a slash:\n"
81
+ "- '/q' to quit"
82
+ )
83
+
84
+ while True:
85
+ img_cond_path = input(user_question)
86
+
87
+ if img_cond_path.startswith("/"):
88
+ if img_cond_path.startswith("/q"):
89
+ print("Quitting")
90
+ return None
91
+ else:
92
+ if not img_cond_path.startswith("/h"):
93
+ print(f"Got invalid command '{img_cond_path}'\n{usage}")
94
+ print(usage)
95
+ continue
96
+
97
+ if img_cond_path == "":
98
+ break
99
+
100
+ if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
101
+ (".jpg", ".jpeg", ".png", ".webp")
102
+ ):
103
+ print(f"File '{img_cond_path}' does not exist or is not a valid image file")
104
+ continue
105
+ else:
106
+ with Image.open(img_cond_path) as img:
107
+ width, height = img.size
108
+
109
+ if width % 32 != 0 or height % 32 != 0:
110
+ print(f"Image dimensions must be divisible by 32, got {width}x{height}")
111
+ continue
112
+
113
+ options.img_cond_path = img_cond_path
114
+ break
115
+
116
+ return options
117
+
118
+
119
+ def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
120
+ if options is None:
121
+ return None
122
+
123
+ user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n"
124
+ usage = (
125
+ "Usage: Either write the path to the conditioning mask directly, leave this field empty "
126
+ "to repeat the conditioning mask or write a command starting with a slash:\n"
127
+ "- '/q' to quit"
128
+ )
129
+
130
+ while True:
131
+ img_mask_path = input(user_question)
132
+
133
+ if img_mask_path.startswith("/"):
134
+ if img_mask_path.startswith("/q"):
135
+ print("Quitting")
136
+ return None
137
+ else:
138
+ if not img_mask_path.startswith("/h"):
139
+ print(f"Got invalid command '{img_mask_path}'\n{usage}")
140
+ print(usage)
141
+ continue
142
+
143
+ if img_mask_path == "":
144
+ break
145
+
146
+ if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith(
147
+ (".jpg", ".jpeg", ".png", ".webp")
148
+ ):
149
+ print(f"File '{img_mask_path}' does not exist or is not a valid image file")
150
+ continue
151
+ else:
152
+ with Image.open(img_mask_path) as img:
153
+ width, height = img.size
154
+
155
+ if width % 32 != 0 or height % 32 != 0:
156
+ print(f"Image dimensions must be divisible by 32, got {width}x{height}")
157
+ continue
158
+ else:
159
+ with Image.open(options.img_cond_path) as img_cond:
160
+ img_cond_width, img_cond_height = img_cond.size
161
+
162
+ if width != img_cond_width or height != img_cond_height:
163
+ print(
164
+ f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}"
165
+ )
166
+ continue
167
+
168
+ options.img_mask_path = img_mask_path
169
+ break
170
+
171
+ return options
172
+
173
+
174
+ @torch.inference_mode()
175
+ def main(
176
+ seed: int | None = None,
177
+ prompt: str = "a white paper cup",
178
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
179
+ num_steps: int = 50,
180
+ loop: bool = False,
181
+ guidance: float = 30.0,
182
+ offload: bool = False,
183
+ output_dir: str = "output",
184
+ add_sampling_metadata: bool = True,
185
+ img_cond_path: str = "assets/cup.png",
186
+ img_mask_path: str = "assets/cup_mask.png",
187
+ ):
188
+ """
189
+ Sample the flux model. Either interactively (set `--loop`) or run for a
190
+ single image. This demo assumes that the conditioning image and mask have
191
+ the same shape and that height and width are divisible by 32.
192
+
193
+ Args:
194
+ seed: Set a seed for sampling
195
+ output_dir: directory to save the output images in, `{idx}` in the file name will be replaced
196
+ by the index of the sample
197
+ prompt: Prompt used for sampling
198
+ device: Pytorch device
199
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
200
+ loop: start an interactive session and sample multiple times
201
+ guidance: guidance value used for guidance distillation
202
+ add_sampling_metadata: Add the prompt to the image Exif metadata
203
+ img_cond_path: path to conditioning image (jpeg/png/webp)
204
+ img_mask_path: path to conditioning mask (jpeg/png/webp)
205
+ """
206
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
207
+
208
+ name = "flux-dev-fill"
209
+ if name not in configs:
210
+ available = ", ".join(configs.keys())
211
+ raise ValueError(f"Got unknown model name: {name}, choose from {available}")
212
+
213
+ torch_device = torch.device(device)
214
+
215
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
216
+ if not os.path.exists(output_dir):
217
+ os.makedirs(output_dir)
218
+ idx = 0
219
+ else:
220
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
221
+ if len(fns) > 0:
222
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
223
+ else:
224
+ idx = 0
225
+
226
+ # init all components
227
+ t5 = load_t5(torch_device, max_length=128)
228
+ clip = load_clip(torch_device)
229
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
230
+ ae = load_ae(name, device="cpu" if offload else torch_device)
231
+
232
+ rng = torch.Generator(device="cpu")
233
+ with Image.open(img_cond_path) as img:
234
+ width, height = img.size
235
+ opts = SamplingOptions(
236
+ prompt=prompt,
237
+ width=width,
238
+ height=height,
239
+ num_steps=num_steps,
240
+ guidance=guidance,
241
+ seed=seed,
242
+ img_cond_path=img_cond_path,
243
+ img_mask_path=img_mask_path,
244
+ )
245
+
246
+ if loop:
247
+ opts = parse_prompt(opts)
248
+ opts = parse_img_cond_path(opts)
249
+
250
+ with Image.open(opts.img_cond_path) as img:
251
+ width, height = img.size
252
+ opts.height = height
253
+ opts.width = width
254
+
255
+ opts = parse_img_mask_path(opts)
256
+
257
+ while opts is not None:
258
+ if opts.seed is None:
259
+ opts.seed = rng.seed()
260
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
261
+ t0 = time.perf_counter()
262
+
263
+ # prepare input
264
+ x = get_noise(
265
+ 1,
266
+ opts.height,
267
+ opts.width,
268
+ device=torch_device,
269
+ dtype=torch.bfloat16,
270
+ seed=opts.seed,
271
+ )
272
+ opts.seed = None
273
+ if offload:
274
+ t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
275
+ inp = prepare_fill(
276
+ t5,
277
+ clip,
278
+ x,
279
+ prompt=opts.prompt,
280
+ ae=ae,
281
+ img_cond_path=opts.img_cond_path,
282
+ mask_path=opts.img_mask_path,
283
+ )
284
+
285
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
286
+
287
+ # offload TEs and AE to CPU, load model to gpu
288
+ if offload:
289
+ t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
290
+ torch.cuda.empty_cache()
291
+ model = model.to(torch_device)
292
+
293
+ # denoise initial noise
294
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
295
+
296
+ # offload model, load autoencoder to gpu
297
+ if offload:
298
+ model.cpu()
299
+ torch.cuda.empty_cache()
300
+ ae.decoder.to(x.device)
301
+
302
+ # decode latents to pixel space
303
+ x = unpack(x.float(), opts.height, opts.width)
304
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
305
+ x = ae.decode(x)
306
+
307
+ if torch.cuda.is_available():
308
+ torch.cuda.synchronize()
309
+ t1 = time.perf_counter()
310
+ print(f"Done in {t1 - t0:.1f}s")
311
+
312
+ idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
313
+
314
+ if loop:
315
+ print("-" * 80)
316
+ opts = parse_prompt(opts)
317
+ opts = parse_img_cond_path(opts)
318
+
319
+ with Image.open(opts.img_cond_path) as img:
320
+ width, height = img.size
321
+ opts.height = height
322
+ opts.width = width
323
+
324
+ opts = parse_img_mask_path(opts)
325
+ else:
326
+ opts = None
327
+
328
+
329
+ def app():
330
+ Fire(main)
331
+
332
+
333
+ if __name__ == "__main__":
334
+ app()
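Since `cli_fill` rejects conditioning images and masks whose sides are not divisible by 32, a small preprocessing step can pad both to the next multiple (a hypothetical helper, not part of the repo; apply it to the image and the mask identically so their sizes stay matched):

```python
# Pad an image (or mask) with black pixels so both sides become multiples of 32.
from PIL import Image


def pad_to_multiple_of_32(img: Image.Image) -> Image.Image:
    w, h = img.size
    new_w = (w + 31) // 32 * 32
    new_h = (h + 31) // 32 * 32
    out = Image.new(img.mode, (new_w, new_h))  # zero-filled canvas
    out.paste(img, (0, 0))
    return out
```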
data/flux/src/flux/cli_redux.py ADDED
@@ -0,0 +1,279 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from fire import Fire
9
+ from transformers import pipeline
10
+
11
+ from flux.modules.image_embedders import ReduxImageEncoder
12
+ from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
13
+ from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
+
15
+
16
+ @dataclass
17
+ class SamplingOptions:
18
+ prompt: str
19
+ width: int
20
+ height: int
21
+ num_steps: int
22
+ guidance: float
23
+ seed: int | None
24
+ img_cond_path: str
25
+
26
+
27
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
28
+ user_question = "Write /h for help, /q to quit and leave empty to repeat:\n"
29
+ usage = (
30
+ "Usage: Leave this field empty to do nothing "
31
+ "or write a command starting with a slash:\n"
32
+ "- '/w <width>' will set the width of the generated image\n"
33
+ "- '/h <height>' will set the height of the generated image\n"
34
+ "- '/s <seed>' sets the next seed\n"
35
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
36
+ "- '/n <steps>' sets the number of steps\n"
37
+ "- '/q' to quit"
38
+ )
39
+
40
+ while (prompt := input(user_question)).startswith("/"):
41
+ if prompt.startswith("/w"):
42
+ if prompt.count(" ") != 1:
43
+ print(f"Got invalid command '{prompt}'\n{usage}")
44
+ continue
45
+ _, width = prompt.split()
46
+ options.width = 16 * (int(width) // 16)
47
+ print(
48
+ f"Setting resolution to {options.width} x {options.height} "
49
+ f"({options.height * options.width / 1e6:.2f}MP)"
50
+ )
51
+ elif prompt.startswith("/h"):
52
+ if prompt.count(" ") != 1:
53
+ print(f"Got invalid command '{prompt}'\n{usage}")
54
+ continue
55
+ _, height = prompt.split()
56
+ options.height = 16 * (int(height) // 16)
57
+ print(
58
+ f"Setting resolution to {options.width} x {options.height} "
59
+ f"({options.height * options.width / 1e6:.2f}MP)"
60
+ )
61
+ elif prompt.startswith("/g"):
62
+ if prompt.count(" ") != 1:
63
+ print(f"Got invalid command '{prompt}'\n{usage}")
64
+ continue
65
+ _, guidance = prompt.split()
66
+ options.guidance = float(guidance)
67
+ print(f"Setting guidance to {options.guidance}")
68
+ elif prompt.startswith("/s"):
69
+ if prompt.count(" ") != 1:
70
+ print(f"Got invalid command '{prompt}'\n{usage}")
71
+ continue
72
+ _, seed = prompt.split()
73
+ options.seed = int(seed)
74
+ print(f"Setting seed to {options.seed}")
75
+ elif prompt.startswith("/n"):
76
+ if prompt.count(" ") != 1:
77
+ print(f"Got invalid command '{prompt}'\n{usage}")
78
+ continue
79
+ _, steps = prompt.split()
80
+ options.num_steps = int(steps)
81
+ print(f"Setting number of steps to {options.num_steps}")
82
+ elif prompt.startswith("/q"):
83
+ print("Quitting")
84
+ return None
85
+ else:
86
+ if not prompt.startswith("/h"):
87
+ print(f"Got invalid command '{prompt}'\n{usage}")
88
+ print(usage)
89
+ return options
90
+
91
+
92
+ def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
93
+ if options is None:
94
+ return None
95
+
96
+ user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
97
+ usage = (
98
+ "Usage: Either write the path to the conditioning image directly, leave this field empty "
99
+ "to repeat the conditioning image or write a command starting with a slash:\n"
100
+ "- '/q' to quit"
101
+ )
102
+
103
+ while True:
104
+ img_cond_path = input(user_question)
105
+
106
+ if img_cond_path.startswith("/"):
107
+ if img_cond_path.startswith("/q"):
108
+ print("Quitting")
109
+ return None
110
+ else:
111
+ if not img_cond_path.startswith("/h"):
112
+ print(f"Got invalid command '{img_cond_path}'\n{usage}")
113
+ print(usage)
114
+ continue
115
+
116
+ if img_cond_path == "":
117
+ break
118
+
119
+ if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
120
+ (".jpg", ".jpeg", ".png", ".webp")
121
+ ):
122
+ print(f"File '{img_cond_path}' does not exist or is not a valid image file")
123
+ continue
124
+
125
+ options.img_cond_path = img_cond_path
126
+ break
127
+
128
+ return options
129
+
130
+
131
+ @torch.inference_mode()
132
+ def main(
133
+ name: str = "flux-dev",
134
+ width: int = 1360,
135
+ height: int = 768,
136
+ seed: int | None = None,
137
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
138
+ num_steps: int | None = None,
139
+ loop: bool = False,
140
+ guidance: float = 2.5,
141
+ offload: bool = False,
142
+ output_dir: str = "output",
143
+ add_sampling_metadata: bool = True,
144
+ img_cond_path: str = "assets/robot.webp",
145
+ ):
146
+ """
147
+ Sample the flux model. Either interactively (set `--loop`) or run for a
148
+ single image.
149
+
150
+ Args:
151
+ name: Name of the model to load
152
+ height: height of the sample in pixels (should be a multiple of 16)
153
+ width: width of the sample in pixels (should be a multiple of 16)
154
+ seed: Set a seed for sampling
155
+ output_dir: directory to save the output images in, `{idx}` in the file name will be replaced
156
+ by the index of the sample
157
+ prompt: Prompt used for sampling
158
+ device: Pytorch device
159
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
160
+ loop: start an interactive session and sample multiple times
161
+ guidance: guidance value used for guidance distillation
162
+ add_sampling_metadata: Add the prompt to the image Exif metadata
163
+ img_cond_path: path to conditioning image (jpeg/png/webp)
164
+ """
165
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
166
+
167
+ if name not in configs:
168
+ available = ", ".join(configs.keys())
169
+ raise ValueError(f"Got unknown model name: {name}, choose from {available}")
170
+
171
+ torch_device = torch.device(device)
172
+ if num_steps is None:
173
+ num_steps = 4 if name == "flux-schnell" else 50
174
+
175
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
176
+ if not os.path.exists(output_dir):
177
+ os.makedirs(output_dir)
178
+ idx = 0
179
+ else:
180
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
181
+ if len(fns) > 0:
182
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
183
+ else:
184
+ idx = 0
185
+
186
+ # init all components
187
+ t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
188
+ clip = load_clip(torch_device)
189
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
190
+ ae = load_ae(name, device="cpu" if offload else torch_device)
191
+ img_embedder = ReduxImageEncoder(torch_device)
192
+
193
+ rng = torch.Generator(device="cpu")
194
+ prompt = ""
195
+ opts = SamplingOptions(
196
+ prompt=prompt,
197
+ width=width,
198
+ height=height,
199
+ num_steps=num_steps,
200
+ guidance=guidance,
201
+ seed=seed,
202
+ img_cond_path=img_cond_path,
203
+ )
204
+
205
+ if loop:
206
+ opts = parse_prompt(opts)
207
+ opts = parse_img_cond_path(opts)
208
+
209
+ while opts is not None:
210
+ if opts.seed is None:
211
+ opts.seed = rng.seed()
212
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
213
+ t0 = time.perf_counter()
214
+
215
+ # prepare input
216
+ x = get_noise(
217
+ 1,
218
+ opts.height,
219
+ opts.width,
220
+ device=torch_device,
221
+ dtype=torch.bfloat16,
222
+ seed=opts.seed,
223
+ )
224
+ opts.seed = None
225
+ if offload:
226
+ ae = ae.cpu()
227
+ torch.cuda.empty_cache()
228
+ t5, clip = t5.to(torch_device), clip.to(torch_device)
229
+ inp = prepare_redux(
230
+ t5,
231
+ clip,
232
+ x,
233
+ prompt=opts.prompt,
234
+ encoder=img_embedder,
235
+ img_cond_path=opts.img_cond_path,
236
+ )
237
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
238
+
239
+ # offload TEs to CPU, load model to gpu
240
+ if offload:
241
+ t5, clip = t5.cpu(), clip.cpu()
242
+ torch.cuda.empty_cache()
243
+ model = model.to(torch_device)
244
+
245
+ # denoise initial noise
246
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
247
+
248
+ # offload model, load autoencoder to gpu
249
+ if offload:
250
+ model.cpu()
251
+ torch.cuda.empty_cache()
252
+ ae.decoder.to(x.device)
253
+
254
+ # decode latents to pixel space
255
+ x = unpack(x.float(), opts.height, opts.width)
256
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
257
+ x = ae.decode(x)
258
+
259
+ if torch.cuda.is_available():
260
+ torch.cuda.synchronize()
261
+ t1 = time.perf_counter()
262
+ print(f"Done in {t1 - t0:.1f}s")
263
+
264
+ idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
265
+
266
+ if loop:
267
+ print("-" * 80)
268
+ opts = parse_prompt(opts)
269
+ opts = parse_img_cond_path(opts)
270
+ else:
271
+ opts = None
272
+
273
+
274
+ def app():
275
+ Fire(main)
276
+
277
+
278
+ if __name__ == "__main__":
279
+ app()
data/flux/src/flux/math.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ q, k = apply_rope(q, k, pe)
8
+
9
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
+ x = rearrange(x, "B H L D -> B L (H D)")
11
+
12
+ return x
13
+
14
+
15
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
16
+ assert dim % 2 == 0
17
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
18
+ omega = 1.0 / (theta**scale)
19
+ out = torch.einsum("...n,d->...nd", pos, omega)
20
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
21
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
22
+ return out.float()
23
+
24
+
25
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
26
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
27
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
28
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
29
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
30
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
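In matrix form, `rope` precomputes, for position $p$ and channel pair $i$ of a $d$-dimensional head, the frequency $\omega_i = \theta^{-2i/d}$ and the rotation that `apply_rope` then applies to each even/odd pair of query and key channels:

$$
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix}
\cos(p\,\omega_i) & -\sin(p\,\omega_i) \\
\sin(p\,\omega_i) & \cos(p\,\omega_i)
\end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}
$$

This matches the `torch.stack([cos, -sin, sin, cos])` layout above after the `(i j)` rearrange into 2×2 blocks.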
data/flux/src/flux/model.py ADDED
@@ -0,0 +1,143 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flux.modules.layers import (
7
+ DoubleStreamBlock,
8
+ EmbedND,
9
+ LastLayer,
10
+ MLPEmbedder,
11
+ SingleStreamBlock,
12
+ timestep_embedding,
13
+ )
14
+ from flux.modules.lora import LinearLora, replace_linear_with_lora
15
+
16
+
17
+ @dataclass
18
+ class FluxParams:
19
+ in_channels: int
20
+ out_channels: int
21
+ vec_in_dim: int
22
+ context_in_dim: int
23
+ hidden_size: int
24
+ mlp_ratio: float
25
+ num_heads: int
26
+ depth: int
27
+ depth_single_blocks: int
28
+ axes_dim: list[int]
29
+ theta: int
30
+ qkv_bias: bool
31
+ guidance_embed: bool
32
+
33
+
34
+ class Flux(nn.Module):
35
+ """
36
+ Transformer model for flow matching on sequences.
37
+ """
38
+
39
+ def __init__(self, params: FluxParams):
40
+ super().__init__()
41
+
42
+ self.params = params
43
+ self.in_channels = params.in_channels
44
+ self.out_channels = params.out_channels
45
+ if params.hidden_size % params.num_heads != 0:
46
+ raise ValueError(
47
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
48
+ )
49
+ pe_dim = params.hidden_size // params.num_heads
50
+ if sum(params.axes_dim) != pe_dim:
51
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
52
+ self.hidden_size = params.hidden_size
53
+ self.num_heads = params.num_heads
54
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
55
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
56
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
57
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
58
+ self.guidance_in = (
59
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
60
+ )
61
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
62
+
63
+ self.double_blocks = nn.ModuleList(
64
+ [
65
+ DoubleStreamBlock(
66
+ self.hidden_size,
67
+ self.num_heads,
68
+ mlp_ratio=params.mlp_ratio,
69
+ qkv_bias=params.qkv_bias,
70
+ )
71
+ for _ in range(params.depth)
72
+ ]
73
+ )
74
+
75
+ self.single_blocks = nn.ModuleList(
76
+ [
77
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
78
+ for _ in range(params.depth_single_blocks)
79
+ ]
80
+ )
81
+
82
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
83
+
84
+ def forward(
85
+ self,
86
+ img: Tensor,
87
+ img_ids: Tensor,
88
+ txt: Tensor,
89
+ txt_ids: Tensor,
90
+ timesteps: Tensor,
91
+ y: Tensor,
92
+ guidance: Tensor | None = None,
93
+ ) -> Tensor:
94
+ if img.ndim != 3 or txt.ndim != 3:
95
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
96
+
97
+ # running on sequences img
98
+ img = self.img_in(img)
99
+ vec = self.time_in(timestep_embedding(timesteps, 256))
100
+ if self.params.guidance_embed:
101
+ if guidance is None:
102
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
103
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
104
+ vec = vec + self.vector_in(y)
105
+ txt = self.txt_in(txt)
106
+
107
+ ids = torch.cat((txt_ids, img_ids), dim=1)
108
+ pe = self.pe_embedder(ids)
109
+
110
+ for block in self.double_blocks:
111
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
112
+
113
+ img = torch.cat((txt, img), 1)
114
+ for block in self.single_blocks:
115
+ img = block(img, vec=vec, pe=pe)
116
+ img = img[:, txt.shape[1] :, ...]
117
+
118
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
119
+ return img
120
+
121
+
122
+ class FluxLoraWrapper(Flux):
123
+ def __init__(
124
+ self,
125
+ lora_rank: int = 128,
126
+ lora_scale: float = 1.0,
127
+ *args,
128
+ **kwargs,
129
+ ) -> None:
130
+ super().__init__(*args, **kwargs)
131
+
132
+ self.lora_rank = lora_rank
133
+
134
+ replace_linear_with_lora(
135
+ self,
136
+ max_rank=lora_rank,
137
+ scale=lora_scale,
138
+ )
139
+
140
+ def set_lora_scale(self, scale: float) -> None:
141
+ for module in self.modules():
142
+ if isinstance(module, LinearLora):
143
+ module.set_scale(scale=scale)
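
The forward pass projects both token streams into `hidden_size`, runs them through the double-stream blocks, then concatenates text and image tokens for the single-stream blocks before slicing the image tokens back out. A toy forward pass as a sketch (not part of the commit); the sizes below are made up for speed, whereas real checkpoints use the configs in `flux/util.py`.

```python
import torch

from flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=64, out_channels=64, vec_in_dim=768, context_in_dim=4096,
    hidden_size=256, mlp_ratio=4.0, num_heads=4, depth=1, depth_single_blocks=1,
    axes_dim=[16, 24, 24], theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params)

img = torch.randn(1, 32, 64)     # (B, L_img, in_channels): packed 2x2 latent patches
txt = torch.randn(1, 8, 4096)    # (B, L_txt, context_in_dim): T5 token features
img_ids = torch.zeros(1, 32, 3)  # 3-axis position ids consumed by EmbedND
txt_ids = torch.zeros(1, 8, 3)
y = torch.randn(1, 768)          # pooled CLIP vector
t = torch.tensor([0.5])          # current flow timestep

pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, timesteps=t, y=y)
print(pred.shape)                # torch.Size([1, 32, 64]): predicted velocity per token
```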
data/flux/src/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,312 @@
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+
+
+@dataclass
+class AutoEncoderParams:
+    resolution: int
+    in_channels: int
+    ch: int
+    out_ch: int
+    ch_mult: list[int]
+    num_res_blocks: int
+    z_channels: int
+    scale_factor: float
+    shift_factor: float
+
+
+def swish(x: Tensor) -> Tensor:
+    return x * torch.sigmoid(x)
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+
+    def attention(self, h_: Tensor) -> Tensor:
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        b, c, h, w = q.shape
+        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
+        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
+        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
+        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
+
+        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x + self.proj_out(self.attention(x))
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+
+        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        h = x
+        h = self.norm1(h)
+        h = swish(h)
+        h = self.conv1(h)
+
+        h = self.norm2(h)
+        h = swish(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x + h
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        # no asymmetric padding in torch conv, must do it ourselves
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x: Tensor):
+        pad = (0, 1, 0, 1)
+        x = nn.functional.pad(x, pad, mode="constant", value=0)
+        x = self.conv(x)
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x: Tensor):
+        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        x = self.conv(x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        resolution: int,
+        in_channels: int,
+        ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        # downsampling
+        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = nn.ModuleList()
+        block_in = self.ch
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # end
+        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1])
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+        # end
+        h = self.norm_out(h)
+        h = swish(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        ch: int,
+        out_ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        in_channels: int,
+        resolution: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.ffactor = 2 ** (self.num_resolutions - 1)
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+
+        # z to block_in
+        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks + 1):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, z: Tensor) -> Tensor:
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](h)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = swish(h)
+        h = self.conv_out(h)
+        return h
+
+
+class DiagonalGaussian(nn.Module):
+    def __init__(self, sample: bool = True, chunk_dim: int = 1):
+        super().__init__()
+        self.sample = sample
+        self.chunk_dim = chunk_dim
+
+    def forward(self, z: Tensor) -> Tensor:
+        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
+        if self.sample:
+            std = torch.exp(0.5 * logvar)
+            return mean + std * torch.randn_like(mean)
+        else:
+            return mean
+
+
+class AutoEncoder(nn.Module):
+    def __init__(self, params: AutoEncoderParams):
+        super().__init__()
+        self.encoder = Encoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.decoder = Decoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            out_ch=params.out_ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.reg = DiagonalGaussian()
+
+        self.scale_factor = params.scale_factor
+        self.shift_factor = params.shift_factor
+
+    def encode(self, x: Tensor) -> Tensor:
+        z = self.reg(self.encoder(x))
+        z = self.scale_factor * (z - self.shift_factor)
+        return z
+
+    def decode(self, z: Tensor) -> Tensor:
+        z = z / self.scale_factor + self.shift_factor
+        return self.decoder(z)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.decode(self.encode(x))
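
`encode` normalizes sampled latents with `scale_factor`/`shift_factor` and `decode` inverts that before running the decoder; with four `ch_mult` entries the encoder downsamples 8x spatially. A round-trip sketch (not part of the commit), reusing the parameter values from the `flux-dev` config in `flux/util.py` on randomly initialized weights, so the output is meaningless but the shapes are real:

```python
import torch

from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

ae = AutoEncoder(AutoEncoderParams(
    resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4],
    num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159,
))

x = torch.randn(1, 3, 64, 64)  # image in [-1, 1]; height/width must be divisible by 8
z = ae.encode(x)               # (1, 16, 8, 8): 8x spatial downsampling
x_rec = ae.decode(z)           # (1, 3, 64, 64)
```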
data/flux/src/flux/modules/conditioner.py ADDED
@@ -0,0 +1,37 @@
+from torch import Tensor, nn
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+
+
+class HFEmbedder(nn.Module):
+    def __init__(self, version: str, max_length: int, **hf_kwargs):
+        super().__init__()
+        self.is_clip = version.startswith("openai")
+        self.max_length = max_length
+        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+
+        if self.is_clip:
+            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
+            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+        else:
+            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
+            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
+
+        self.hf_module = self.hf_module.eval().requires_grad_(False)
+
+    def forward(self, text: list[str]) -> Tensor:
+        batch_encoding = self.tokenizer(
+            text,
+            truncation=True,
+            max_length=self.max_length,
+            return_length=False,
+            return_overflowing_tokens=False,
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+        outputs = self.hf_module(
+            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+            attention_mask=None,
+            output_hidden_states=False,
+        )
+        return outputs[self.output_key]
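
`HFEmbedder` returns CLIP's pooled output or T5's per-token states depending on the model name prefix. A sketch (not part of the commit) mirroring `load_clip`/`load_t5` from `flux/util.py`; both require a Hugging Face Hub download, and T5-XXL is very large, so it is left commented out here:

```python
import torch

from flux.modules.conditioner import HFEmbedder

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)
vec = clip(["a photo of a forest clearing"])
print(vec.shape)  # torch.Size([1, 768]): the pooled vector fed to Flux as "y"

# Same interface for T5 (a very large download):
# t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16)
# txt = t5(["a photo of a forest clearing"])  # (1, 512, 4096), fed to Flux as "txt"
```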
data/flux/src/flux/modules/image_embedders.py ADDED
@@ -0,0 +1,103 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from PIL import Image
+from safetensors.torch import load_file as load_sft
+from torch import nn
+from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
+
+from flux.util import print_load_warning
+
+
+class DepthImageEncoder:
+    depth_model_name = "LiheYoung/depth-anything-large-hf"
+
+    def __init__(self, device):
+        self.device = device
+        self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
+        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
+
+    def __call__(self, img: torch.Tensor) -> torch.Tensor:
+        hw = img.shape[-2:]
+
+        img = torch.clamp(img, -1.0, 1.0)
+        img_byte = ((img + 1.0) * 127.5).byte()
+
+        img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
+        depth = self.depth_model(img.to(self.device)).predicted_depth
+        depth = repeat(depth, "b h w -> b 3 h w")
+        depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
+
+        depth = depth / 127.5 - 1.0
+        return depth
+
+
+class CannyImageEncoder:
+    def __init__(
+        self,
+        device,
+        min_t: int = 50,
+        max_t: int = 200,
+    ):
+        self.device = device
+        self.min_t = min_t
+        self.max_t = max_t
+
+    def __call__(self, img: torch.Tensor) -> torch.Tensor:
+        assert img.shape[0] == 1, "Only batch size 1 is supported"
+
+        img = rearrange(img[0], "c h w -> h w c")
+        img = torch.clamp(img, -1.0, 1.0)
+        img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
+
+        # Apply Canny edge detection
+        canny = cv2.Canny(img_np, self.min_t, self.max_t)
+
+        # Convert back to torch tensor and reshape
+        canny = torch.from_numpy(canny).float() / 127.5 - 1.0
+        canny = rearrange(canny, "h w -> 1 1 h w")
+        canny = repeat(canny, "b 1 ... -> b 3 ...")
+        return canny.to(self.device)
+
+
+class ReduxImageEncoder(nn.Module):
+    siglip_model_name = "google/siglip-so400m-patch14-384"
+
+    def __init__(
+        self,
+        device,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+        redux_path: str | None = os.getenv("FLUX_REDUX"),
+        dtype=torch.bfloat16,
+    ) -> None:
+        assert redux_path is not None, "Redux path must be provided"
+
+        super().__init__()
+
+        self.redux_dim = redux_dim
+        self.device = device if isinstance(device, torch.device) else torch.device(device)
+        self.dtype = dtype
+
+        with self.device:
+            self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
+            self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
+
+        sd = load_sft(redux_path, device=str(device))
+        missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
+        print_load_warning(missing, unexpected)
+
+        self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype)
+        self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)
+
+    def __call__(self, x: Image.Image) -> torch.Tensor:
+        imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
+
+        _encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state
+
+        projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))
+
+        return projected_x
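
A sketch (not part of the commit) of the Canny encoder, which maps a `[-1, 1]` image tensor to an edge map replicated across three channels; the input resolution here is an arbitrary choice:

```python
import torch

from flux.modules.image_embedders import CannyImageEncoder

encoder = CannyImageEncoder(device=torch.device("cpu"), min_t=50, max_t=200)
img = torch.rand(1, 3, 64, 64) * 2 - 1  # one image in [-1, 1]; batch size must be 1
edges = encoder(img)                    # (1, 3, 64, 64) edge map, also in [-1, 1]
```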
data/flux/src/flux/modules/layers.py ADDED
@@ -0,0 +1,253 @@
+import math
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+
+from flux.math import attention, rope
+
+
+class EmbedND(nn.Module):
+    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: Tensor) -> Tensor:
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
+
+        return emb.unsqueeze(1)
+
+
+def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+    """
+    Create sinusoidal timestep embeddings.
+    :param t: a 1-D Tensor of N indices, one per batch element.
+        These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an (N, D) Tensor of positional embeddings.
+    """
+    t = time_factor * t
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+        t.device
+    )
+
+    args = t[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    if torch.is_floating_point(t):
+        embedding = embedding.to(t)
+    return embedding
+
+
+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int):
+        super().__init__()
+        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.silu = nn.SiLU()
+        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.scale = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x: Tensor):
+        x_dtype = x.dtype
+        x = x.float()
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+        return (x * rrms).to(dtype=x_dtype) * self.scale
+
+
+class QKNorm(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.query_norm = RMSNorm(dim)
+        self.key_norm = RMSNorm(dim)
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+        q = self.query_norm(q)
+        k = self.key_norm(k)
+        return q.to(v), k.to(v)
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
+        qkv = self.qkv(x)
+        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k = self.norm(q, k, v)
+        x = attention(q, k, v, pe=pe)
+        x = self.proj(x)
+        return x
+
+
+@dataclass
+class ModulationOut:
+    shift: Tensor
+    scale: Tensor
+    gate: Tensor
+
+
+class Modulation(nn.Module):
+    def __init__(self, dim: int, double: bool):
+        super().__init__()
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+
+        return (
+            ModulationOut(*out[:3]),
+            ModulationOut(*out[3:]) if self.is_double else None,
+        )
+
+
+class DoubleStreamBlock(nn.Module):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
+        super().__init__()
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.img_mod = Modulation(hidden_size, double=True)
+        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.img_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+        self.txt_mod = Modulation(hidden_size, double=True)
+        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.txt_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
+        img_mod1, img_mod2 = self.img_mod(vec)
+        txt_mod1, txt_mod2 = self.txt_mod(vec)
+
+        # prepare image for attention
+        img_modulated = self.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = self.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+
+        # prepare txt for attention
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = self.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+
+        # run actual attention
+        q = torch.cat((txt_q, img_q), dim=2)
+        k = torch.cat((txt_k, img_k), dim=2)
+        v = torch.cat((txt_v, img_v), dim=2)
+
+        attn = attention(q, k, v, pe=pe)
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+
+        # calculate the img blocks
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+
+        # calculate the txt blocks
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+    """
+    A DiT block with parallel linear layers as described in
+    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qk_scale: float | None = None,
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+        # proj and mlp_out
+        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+        self.norm = QKNorm(head_dim)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+        self.mlp_act = nn.GELU(approximate="tanh")
+        self.modulation = Modulation(hidden_size, double=False)
+
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+        mod, _ = self.modulation(vec)
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+
+        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k = self.norm(q, k, v)
+
+        # compute attention
+        attn = attention(q, k, v, pe=pe)
+        # compute activation in mlp stream, cat again and run second linear layer
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        return x + mod.gate * output
+
+
+class LastLayer(nn.Module):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
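
`Modulation` implements the adaLN-style conditioning used by both block types: a SiLU plus linear projection of the conditioning vector, chunked into shift/scale/gate triples. A shape sketch (not part of the commit) with an illustrative dimension:

```python
import torch

from flux.modules.layers import Modulation, timestep_embedding

emb = timestep_embedding(torch.tensor([0.25]), dim=256)  # (1, 256) sinusoidal features
mod = Modulation(dim=256, double=False)
out, second = mod(torch.randn(1, 256))                   # second is None when double=False
print(out.shift.shape, out.scale.shape, out.gate.shape)  # each torch.Size([1, 1, 256])
```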
data/flux/src/flux/modules/lora.py ADDED
@@ -0,0 +1,94 @@
+import torch
+from torch import nn
+
+
+def replace_linear_with_lora(
+    module: nn.Module,
+    max_rank: int,
+    scale: float = 1.0,
+) -> None:
+    for name, child in module.named_children():
+        if isinstance(child, nn.Linear):
+            new_lora = LinearLora(
+                in_features=child.in_features,
+                out_features=child.out_features,
+                bias=child.bias,
+                rank=max_rank,
+                scale=scale,
+                dtype=child.weight.dtype,
+                device=child.weight.device,
+            )
+
+            new_lora.weight = child.weight
+            new_lora.bias = child.bias if child.bias is not None else None
+
+            setattr(module, name, new_lora)
+        else:
+            replace_linear_with_lora(
+                module=child,
+                max_rank=max_rank,
+                scale=scale,
+            )
+
+
+class LinearLora(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        rank: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        lora_bias: bool = True,
+        scale: float = 1.0,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias is not None,
+            device=device,
+            dtype=dtype,
+            *args,
+            **kwargs,
+        )
+
+        assert isinstance(scale, float), "scale must be a float"
+
+        self.scale = scale
+        self.rank = rank
+        self.lora_bias = lora_bias
+        self.dtype = dtype
+        self.device = device
+
+        if rank > (new_rank := min(self.out_features, self.in_features)):
+            self.rank = new_rank
+
+        self.lora_A = nn.Linear(
+            in_features=in_features,
+            out_features=self.rank,
+            bias=False,
+            dtype=dtype,
+            device=device,
+        )
+        self.lora_B = nn.Linear(
+            in_features=self.rank,
+            out_features=out_features,
+            bias=self.lora_bias,
+            dtype=dtype,
+            device=device,
+        )
+
+    def set_scale(self, scale: float) -> None:
+        assert isinstance(scale, float), "scalar value must be a float"
+        self.scale = scale
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        base_out = super().forward(input)
+
+        _lora_out_B = self.lora_B(self.lora_A(input))
+        lora_update = _lora_out_B * self.scale
+
+        return base_out + lora_update
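
`LinearLora.forward` computes the base affine map plus a scaled low-rank correction, i.e. `y = Wx + b + scale * B(A(x))`. A small self-check (not part of the commit) under made-up sizes; note that unlike common LoRA setups, `lora_B` here carries a default-initialized bias, so the update is nonzero at initialization:

```python
import torch

from flux.modules.lora import LinearLora

lin = LinearLora(
    in_features=64, out_features=32, bias=True, rank=8,
    dtype=torch.float32, device=torch.device("cpu"), scale=1.0,
)
x = torch.randn(4, 64)
y = lin(x)

# forward() is the base linear map plus the scaled low-rank correction
expected = torch.nn.functional.linear(x, lin.weight, lin.bias) + lin.scale * lin.lora_B(lin.lora_A(x))
assert torch.allclose(y, expected)
```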
data/flux/src/flux/sampling.py ADDED
@@ -0,0 +1,282 @@
+import math
+from typing import Callable
+
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from PIL import Image
+from torch import Tensor
+
+from .model import Flux
+from .modules.autoencoder import AutoEncoder
+from .modules.conditioner import HFEmbedder
+from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
+
+
+def get_noise(
+    num_samples: int,
+    height: int,
+    width: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    seed: int,
+):
+    return torch.randn(
+        num_samples,
+        16,
+        # allow for packing
+        2 * math.ceil(height / 16),
+        2 * math.ceil(width / 16),
+        device=device,
+        dtype=dtype,
+        generator=torch.Generator(device=device).manual_seed(seed),
+    )
+
+
+def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
+    bs, c, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img.shape[0] == 1 and bs > 1:
+        img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+    img_ids = torch.zeros(h // 2, w // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    if isinstance(prompt, str):
+        prompt = [prompt]
+    txt = t5(prompt)
+    if txt.shape[0] == 1 and bs > 1:
+        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+    txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+    vec = clip(prompt)
+    if vec.shape[0] == 1 and bs > 1:
+        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+    return {
+        "img": img,
+        "img_ids": img_ids.to(img.device),
+        "txt": txt.to(img.device),
+        "txt_ids": txt_ids.to(img.device),
+        "vec": vec.to(img.device),
+    }
+
+
+def prepare_control(
+    t5: HFEmbedder,
+    clip: HFEmbedder,
+    img: Tensor,
+    prompt: str | list[str],
+    ae: AutoEncoder,
+    encoder: DepthImageEncoder | CannyImageEncoder,
+    img_cond_path: str,
+) -> dict[str, Tensor]:
+    # load and encode the conditioning image
+    bs, _, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img_cond = Image.open(img_cond_path).convert("RGB")
+
+    width = w * 8
+    height = h * 8
+    img_cond = img_cond.resize((width, height), Image.LANCZOS)
+    img_cond = np.array(img_cond)
+    img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
+    img_cond = rearrange(img_cond, "h w c -> 1 c h w")
+
+    with torch.no_grad():
+        img_cond = encoder(img_cond)
+        img_cond = ae.encode(img_cond)
+
+    img_cond = img_cond.to(torch.bfloat16)
+    img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img_cond.shape[0] == 1 and bs > 1:
+        img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
+
+    return_dict = prepare(t5, clip, img, prompt)
+    return_dict["img_cond"] = img_cond
+    return return_dict
+
+
+def prepare_fill(
+    t5: HFEmbedder,
+    clip: HFEmbedder,
+    img: Tensor,
+    prompt: str | list[str],
+    ae: AutoEncoder,
+    img_cond_path: str,
+    mask_path: str,
+) -> dict[str, Tensor]:
+    # load and encode the conditioning image and the mask
+    bs, _, _, _ = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img_cond = Image.open(img_cond_path).convert("RGB")
+    img_cond = np.array(img_cond)
+    img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
+    img_cond = rearrange(img_cond, "h w c -> 1 c h w")
+
+    mask = Image.open(mask_path).convert("L")
+    mask = np.array(mask)
+    mask = torch.from_numpy(mask).float() / 255.0
+    mask = rearrange(mask, "h w -> 1 1 h w")
+
+    with torch.no_grad():
+        img_cond = img_cond.to(img.device)
+        mask = mask.to(img.device)
+        img_cond = img_cond * (1 - mask)
+        img_cond = ae.encode(img_cond)
+    mask = mask[:, 0, :, :]
+    mask = mask.to(torch.bfloat16)
+    mask = rearrange(
+        mask,
+        "b (h ph) (w pw) -> b (ph pw) h w",
+        ph=8,
+        pw=8,
+    )
+    mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if mask.shape[0] == 1 and bs > 1:
+        mask = repeat(mask, "1 ... -> bs ...", bs=bs)
+
+    img_cond = img_cond.to(torch.bfloat16)
+    img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img_cond.shape[0] == 1 and bs > 1:
+        img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
+
+    img_cond = torch.cat((img_cond, mask), dim=-1)
+
+    return_dict = prepare(t5, clip, img, prompt)
+    return_dict["img_cond"] = img_cond.to(img.device)
+    return return_dict
+
+
+def prepare_redux(
+    t5: HFEmbedder,
+    clip: HFEmbedder,
+    img: Tensor,
+    prompt: str | list[str],
+    encoder: ReduxImageEncoder,
+    img_cond_path: str,
+) -> dict[str, Tensor]:
+    bs, _, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img_cond = Image.open(img_cond_path).convert("RGB")
+    with torch.no_grad():
+        img_cond = encoder(img_cond)
+
+    img_cond = img_cond.to(torch.bfloat16)
+    if img_cond.shape[0] == 1 and bs > 1:
+        img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
+
+    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img.shape[0] == 1 and bs > 1:
+        img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+    img_ids = torch.zeros(h // 2, w // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    if isinstance(prompt, str):
+        prompt = [prompt]
+    txt = t5(prompt)
+    txt = torch.cat((txt, img_cond.to(txt)), dim=-2)
+    if txt.shape[0] == 1 and bs > 1:
+        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+    txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+    vec = clip(prompt)
+    if vec.shape[0] == 1 and bs > 1:
+        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+    return {
+        "img": img,
+        "img_ids": img_ids.to(img.device),
+        "txt": txt.to(img.device),
+        "txt_ids": txt_ids.to(img.device),
+        "vec": vec.to(img.device),
+    }
+
+
+def time_shift(mu: float, sigma: float, t: Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_lin_function(
+    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+) -> Callable[[float], float]:
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+
+def get_schedule(
+    num_steps: int,
+    image_seq_len: int,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+    shift: bool = True,
+) -> list[float]:
+    # extra step for zero
+    timesteps = torch.linspace(1, 0, num_steps + 1)
+
+    # shifting the schedule to favor high timesteps for higher signal images
+    if shift:
+        # estimate mu based on linear estimation between two points
+        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+        timesteps = time_shift(mu, 1.0, timesteps)
+
+    return timesteps.tolist()
+
+
+def denoise(
+    model: Flux,
+    # model input
+    img: Tensor,
+    img_ids: Tensor,
+    txt: Tensor,
+    txt_ids: Tensor,
+    vec: Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    guidance: float = 4.0,
+    # extra img tokens
+    img_cond: Tensor | None = None,
+):
+    # this is ignored for schnell
+    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+        pred = model(
+            img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img,
+            img_ids=img_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t_vec,
+            guidance=guidance_vec,
+        )
+
+        img = img + (t_prev - t_curr) * pred
+
+    return img
+
+
+def unpack(x: Tensor, height: int, width: int) -> Tensor:
+    return rearrange(
+        x,
+        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+        h=math.ceil(height / 16),
+        w=math.ceil(width / 16),
+        ph=2,
+        pw=2,
+    )
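
`get_schedule` returns `num_steps + 1` timesteps from 1 down to 0, shifted toward high noise levels for longer image sequences, and `denoise` takes one Euler step of the flow ODE per adjacent pair: `img <- img + (t_prev - t_curr) * pred`. A small sketch (not part of the commit) of the schedule; the sequence length below corresponds to a 1024x1024 image packed into 2x2 latent patches, i.e. 64x64 tokens:

```python
from flux.sampling import get_schedule

ts = get_schedule(num_steps=4, image_seq_len=64 * 64, shift=True)
print(len(ts))        # 5: num_steps + 1, with an extra entry for t = 0
print(ts[0], ts[-1])  # 1.0 ... 0.0
```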
data/flux/src/flux/util.py ADDED
@@ -0,0 +1,447 @@
+import os
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from imwatermark import WatermarkEncoder
+from PIL import ExifTags, Image
+from safetensors.torch import load_file as load_sft
+
+from flux.model import Flux, FluxLoraWrapper, FluxParams
+from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
+from flux.modules.conditioner import HFEmbedder
+
+
+def save_image(
+    nsfw_classifier,
+    name: str,
+    output_name: str,
+    idx: int,
+    x: torch.Tensor,
+    add_sampling_metadata: bool,
+    prompt: str,
+    nsfw_threshold: float = 0.85,
+) -> int:
+    fn = output_name.format(idx=idx)
+    print(f"Saving {fn}")
+    # bring into PIL format and save
+    x = x.clamp(-1, 1)
+    x = embed_watermark(x.float())
+    x = rearrange(x[0], "c h w -> h w c")
+
+    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+    nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
+
+    if nsfw_score < nsfw_threshold:
+        exif_data = Image.Exif()
+        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+        exif_data[ExifTags.Base.Model] = name
+        if add_sampling_metadata:
+            exif_data[ExifTags.Base.ImageDescription] = prompt
+        img.save(fn, exif=exif_data, quality=95, subsampling=0)
+        idx += 1
+    else:
+        print("Your generated image may contain NSFW content.")
+
+    return idx
+
+
+@dataclass
+class ModelSpec:
+    params: FluxParams
+    ae_params: AutoEncoderParams
+    ckpt_path: str | None
+    lora_path: str | None
+    ae_path: str | None
+    repo_id: str | None
+    repo_flow: str | None
+    repo_ae: str | None
+
+
+configs = {
+    "flux-dev": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV"),
+        lora_path=None,
+        params=FluxParams(
+            in_channels=64,
+            out_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-schnell": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-schnell",
+        repo_flow="flux1-schnell.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_SCHNELL"),
+        lora_path=None,
+        params=FluxParams(
+            in_channels=64,
+            out_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=False,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-dev-canny": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-Canny-dev",
+        repo_flow="flux1-canny-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV_CANNY"),
+        lora_path=None,
+        params=FluxParams(
+            in_channels=128,
+            out_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-dev-canny-lora": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV"),
+        lora_path=os.getenv("FLUX_DEV_CANNY_LORA"),
+        params=FluxParams(
+            in_channels=128,
+            out_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-dev-depth": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-Depth-dev",
+        repo_flow="flux1-depth-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV_DEPTH"),
+        lora_path=None,
+        params=FluxParams(
+            in_channels=128,
+            out_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-dev-depth-lora": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV"),
+        lora_path=os.getenv("FLUX_DEV_DEPTH_LORA"),
+        params=FluxParams(
+            in_channels=128,
+            out_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-dev-fill": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-Fill-dev",
+        repo_flow="flux1-fill-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV_FILL"),
+        lora_path=None,
+        params=FluxParams(
+            in_channels=384,
+            out_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+}
+
+
+def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+    if len(missing) > 0 and len(unexpected) > 0:
+        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+        print("\n" + "-" * 79 + "\n")
+        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+    elif len(missing) > 0:
+        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+    elif len(unexpected) > 0:
+        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+
+
+def load_flow_model(
+    name: str, device: str | torch.device = "cuda", hf_download: bool = True, verbose: bool = False
+) -> Flux:
+    # Loading Flux
+    print("Init model")
+    ckpt_path = configs[name].ckpt_path
+    lora_path = configs[name].lora_path
+    if (
+        ckpt_path is None
+        and configs[name].repo_id is not None
+        and configs[name].repo_flow is not None
+        and hf_download
+    ):
+        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+
+    with torch.device("meta" if ckpt_path is not None else device):
+        if lora_path is not None:
+            model = FluxLoraWrapper(params=configs[name].params).to(torch.bfloat16)
+        else:
+            model = Flux(configs[name].params).to(torch.bfloat16)
+
+    if ckpt_path is not None:
+        print("Loading checkpoint")
+        # load_sft doesn't support torch.device
+        sd = load_sft(ckpt_path, device=str(device))
+        sd = optionally_expand_state_dict(model, sd)
+        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+        if verbose:
+            print_load_warning(missing, unexpected)
+
+    if configs[name].lora_path is not None:
+        print("Loading LoRA")
+        lora_sd = load_sft(configs[name].lora_path, device=str(device))
+        # loading the lora params + overwriting scale values in the norms
+        missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True)
+        if verbose:
+            print_load_warning(missing, unexpected)
+    return model
+
+
+def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
+    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+
+
+def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
+    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
+
+
+def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
+    ckpt_path = configs[name].ae_path
+    if (
+        ckpt_path is None
+        and configs[name].repo_id is not None
+        and configs[name].repo_ae is not None
+        and hf_download
+    ):
+        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
+
+    # Loading the autoencoder
+    print("Init AE")
+    with torch.device("meta" if ckpt_path is not None else device):
+        ae = AutoEncoder(configs[name].ae_params)
+
+    if ckpt_path is not None:
+        sd = load_sft(ckpt_path, device=str(device))
+        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+        print_load_warning(missing, unexpected)
+    return ae
+
+
+def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
+    """
+    Optionally expand the state dict to match the model's parameters shapes.
+    """
+    for name, param in model.named_parameters():
+        if name in state_dict:
+            if state_dict[name].shape != param.shape:
+                print(
+                    f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}."
+                )
+                # expand with zeros:
+                expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device)
+                slices = tuple(slice(0, dim) for dim in state_dict[name].shape)
+                expanded_state_dict_weight[slices] = state_dict[name]
+                state_dict[name] = expanded_state_dict_weight
+
+    return state_dict
+
+
+class WatermarkEmbedder:
+    def __init__(self, watermark):
+        self.watermark = watermark
+        self.num_bits = len(WATERMARK_BITS)
+        self.encoder = WatermarkEncoder()
+        self.encoder.set_watermark("bits", self.watermark)
+
+    def __call__(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Adds a predefined watermark to the input image
+
+        Args:
+            image: ([N,] B, RGB, H, W) in range [-1, 1]
+
+        Returns:
+            same as input but watermarked
+        """
+        image = 0.5 * image + 0.5
+        squeeze = len(image.shape) == 4
+        if squeeze:
+            image = image[None, ...]
+        n = image.shape[0]
+        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
+        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+        # watermarking library expects input as cv2 BGR format
+        for k in range(image_np.shape[0]):
+            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
+            image.device
+        )
+        image = torch.clamp(image / 255, min=0.0, max=1.0)
+        if squeeze:
+            image = image[0]
+        image = 2 * image - 1
+        return image
+
+
+# A fixed 48-bit message that was chosen at random
+WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
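
For completeness, a sketch (not part of the commit) of loading a full inference stack with these helpers. This assumes the weights are reachable, either through the `FLUX_SCHNELL`/`AE` environment variables or a multi-gigabyte Hugging Face Hub download, and enough memory to hold the models:

```python
import torch

from flux.util import load_ae, load_clip, load_flow_model, load_t5

device = "cuda" if torch.cuda.is_available() else "cpu"

t5 = load_t5(device, max_length=256)    # a shorter T5 context; 512 is the default
clip = load_clip(device)
model = load_flow_model("flux-schnell", device=device)
ae = load_ae("flux-schnell", device=device)
```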