WithAnyone committed on
Commit
4910a8a
·
verified ·
1 Parent(s): db24270

Upload 29 files

Browse files
.gitignore ADDED
@@ -0,0 +1,232 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+
176
+ # User config files
177
+ .vscode/
178
+ output/
179
+
180
+ # ckpt
181
+ *.bin
182
+ *.pt
183
+ *.pth
184
+ ckpts/
185
+ ckpt-*
186
+ ckpts/*
187
+
188
+ # legacy code
189
+ legacy/
190
+ legacy/*
191
+
192
+ # wandb
193
+ wandb/
194
+ wandb/*
195
+
196
+ # arcface models
197
+ models/
198
+
199
+ # debug
200
+ debug*
201
+
202
+
203
+ data_single/
204
+ data_single_10/
205
+ lora_attampt/
206
+ lora_attampt/*
207
+
208
+ *.safetensors
209
+ *.ckpt
210
+
211
+ .output/
212
+ for_bbox/
213
+
214
+ # data
215
+ data/
216
+ datasets/
217
+
218
+ nohup.out
219
+
220
+ 10**
221
+ temp_generated.png
222
+
223
+ facenet_pytorch/
224
+ facenet_pytorch/*
225
+
226
+ # AdaFace/
227
+ # AdaFace/*
228
+
229
+ pretrained/
230
+
231
+ git_backup/
232
+ git_backup/*
.gitmodules ADDED
File without changes
LICENSE ADDED
@@ -0,0 +1,54 @@
1
+ FLUX.1 [dev] Non-Commercial License v1.1.1
2
+
3
+ Black Forest Labs Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models and models denoted as FLUX.1 [dev], including but not limited to FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA, FLUX.1 Depth [dev] LoRA, and FLUX.1 Kontext [dev], and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). Note that we may also make available certain elements of what is included in the definition of “FLUX.1 [dev] Model” under a separate license, such as the inference code, and nothing in this License will be deemed to restrict or limit any other licenses granted by us in such elements.
4
+
5
+ By downloading, accessing, using, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity.
6
+
7
+ 1. Definitions.
8
+ - a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
9
+ - b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be.
10
+ - c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the FLUX.1 [dev] Model, Derivatives, or FLUX Content Filters (as defined below): (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment; and (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use (a) for revenue-generating activity, (b) in direct interactions with or that has impact on end users, or (c) to train, fine tune or distill other models for commercial use, in each case is not a Non-Commercial Purpose.
11
+ - d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from an input (such as an image input) or prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of the FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters.
12
+ - e. “you” or “your” means the individual or entity entering into this License with Company.
13
+
14
+ 2. License Grant.
15
+ - a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models and Derivatives solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein regarding the FLUX.1 [dev] Model also apply to any Derivative you create or that are created on your behalf.
16
+ - b. Non-Commercial Use Only. You may only access, use, Distribute, or create Derivatives of the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If you want to use a FLUX.1 [dev] Model or a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please see www.bfl.ai if you would like a commercial license.
17
+ - c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
18
+ - d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model or the FLUX.1 Kontext [dev] Model.
19
+ - e. You may access, use, Distribute, or create Output of the FLUX.1 [dev] Model or Derivatives if you: (i) (A) implement and maintain content filtering measures (“Content Filters”) for your use of the FLUX.1 [dev] Model or Derivatives to prevent the creation, display, transmission, generation, or dissemination of unlawful or infringing content, which may include Content Filters that we may make available for use with the FLUX.1 [dev] Model (“FLUX Content Filters”), or (B) ensure Output undergoes review for unlawful or infringing content before public or non-public distribution, display, transmission or dissemination; and (ii) ensure Output includes disclosure (or other indication) that the Output was generated or modified using artificial intelligence technologies to the extent required under applicable law.
20
+
21
+ 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions:
22
+ - a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
23
+ - b. you must prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”):
24
+
25
+ “The FLUX.1 [dev] Model is licensed by Black Forest Labs Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs Inc.
26
+ IN NO EVENT SHALL BLACK FOREST LABS INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”
27
+
28
+ - c. in the case of Distribution of Derivatives made by you: (i) you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; (ii) any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions and must include disclaimer of warranties and limitation of liability provisions that are at least as protective of Company as those set forth herein; and (iii) you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.
29
+
30
+ 4. Restrictions. You will not, and will not permit, assist or cause any third party to
31
+ - a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, (i) for any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates (or is likely to infringe, misappropriate, or otherwise violate) any third party’s legal rights, including rights of publicity or “digital replica” rights, (vi) in any unlawful, fraudulent, defamatory, or abusive activity, (vii) to generate unlawful content, including child sexual abuse material, or non-consensual intimate images; or (viii) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, any and all laws governing the processing of biometric information, and the EU Artificial Intelligence Act (Regulation (EU) 2024/1689), as well as all amendments and successor laws to any of the foregoing;
32
+ - b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model;
33
+ - c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model;
34
+ - d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License;
35
+ - e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model;
36
+ - f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (i) to any individual, entity, or country prohibited by Export Laws; (ii) to anyone on U.S. or non-U.S. government restricted parties lists; (iii) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; (iv) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (v) will not disguise your location through IP proxying or other methods.
37
+
38
+ 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
39
+
40
+ 6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, FLUX CONTENT FILTERS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
41
+
42
+ 7. INDEMNIFICATION. You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (including in connection with any Output, results or data generated from such access or use, or from your access or use of any FLUX Content Filters), including any High-Risk Use; (b) your Content Filters, including your failure to implement any Content Filters where required by this License such as in Section 2(e); (c) your violation of this License; or (d) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.
43
+
44
+ 8. Termination; Survival.
45
+ - a. This License will automatically terminate upon any breach by you of the terms of this License.
46
+ - b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
47
+ - c. If you initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model, any Derivative, or FLUX Content Filters, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
48
+ - d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model, any Derivatives, and any FLUX Content Filters. The following sections survive termination of this License 2(c), 2(d), 4-11.
49
+
50
+ 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
51
+
52
+ 10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name, logo or trademark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators.
53
+
54
+ 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter.
app.py ADDED
@@ -0,0 +1,586 @@
1
+ # Copyright (c) 2025 Fudan University. All rights reserved.
2
+
3
+
4
+
5
+ import dataclasses
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import List, Literal, Optional
10
+
11
+ import cv2
12
+ import gradio as gr
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image, ImageDraw
16
+
17
+ from withanyone.flux.pipeline import WithAnyonePipeline
18
+ from util import extract_moref, face_preserving_resize
19
+ import insightface
20
+
21
+
22
+ def captioner(prompt: str, num_person = 1) -> List[List[float]]:
23
+ # randomly choose a preset layout for testing
24
+ # coordinates lie within a 512x512 canvas
25
+ if num_person == 1:
26
+ bbox_choices = [
27
+ # expanded, centered and quadrant placements
28
+ [96, 96, 288, 288],
29
+ [128, 128, 320, 320],
30
+ [160, 96, 352, 288],
31
+ [96, 160, 288, 352],
32
+ [208, 96, 400, 288],
33
+ [96, 208, 288, 400],
34
+ [192, 160, 368, 336],
35
+ [64, 128, 224, 320],
36
+ [288, 128, 448, 320],
37
+ [128, 256, 320, 448],
38
+ [80, 80, 240, 272],
39
+ [196, 196, 380, 380],
40
+ # originals
41
+ [100, 100, 300, 300],
42
+ [150, 50, 450, 350],
43
+ [200, 100, 500, 400],
44
+ [250, 150, 512, 450],
45
+ ]
46
+ return [bbox_choices[np.random.randint(0, len(bbox_choices))]]
47
+ elif num_person == 2:
48
+ # realistic side-by-side rows (no vertical stacks or diagonals)
49
+ bbox_choices = [
50
+ [[64, 112, 224, 304], [288, 112, 448, 304]],
51
+ [[48, 128, 208, 320], [304, 128, 464, 320]],
52
+ [[32, 144, 192, 336], [320, 144, 480, 336]],
53
+ [[80, 96, 240, 288], [272, 96, 432, 288]],
54
+ [[80, 160, 240, 352], [272, 160, 432, 352]],
55
+ [[64, 128, 240, 336], [272, 144, 432, 320]], # slight stagger, same row
56
+ [[96, 160, 256, 352], [288, 160, 448, 352]],
57
+ [[64, 192, 224, 384], [288, 192, 448, 384]], # lower row
58
+ [[16, 128, 176, 320], [336, 128, 496, 320]], # near edges
59
+ [[48, 120, 232, 328], [280, 120, 464, 328]],
60
+ [[96, 160, 240, 336], [272, 160, 416, 336]], # tighter faces
61
+ [[72, 136, 232, 328], [280, 152, 440, 344]], # small vertical offset
62
+ [[48, 120, 224, 344], [288, 144, 448, 336]], # asymmetric sizes
63
+ [[80, 224, 240, 416], [272, 224, 432, 416]], # bottom row
64
+ [[80, 64, 240, 256], [272, 64, 432, 256]], # top row
65
+ [[96, 176, 256, 368], [288, 176, 448, 368]],
66
+ ]
67
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
68
+
69
+ elif num_person == 3:
70
+ # Non-overlapping 3-person layouts within 512x512
71
+ bbox_choices = [
72
+ [[20, 140, 150, 360], [180, 120, 330, 360], [360, 130, 500, 360]],
73
+ [[30, 100, 160, 300], [190, 90, 320, 290], [350, 110, 480, 310]],
74
+ [[40, 180, 150, 330], [200, 180, 310, 330], [360, 180, 470, 330]],
75
+ [[60, 120, 170, 300], [210, 110, 320, 290], [350, 140, 480, 320]],
76
+ [[50, 80, 170, 250], [200, 130, 320, 300], [350, 80, 480, 250]],
77
+ [[40, 260, 170, 480], [190, 60, 320, 240], [350, 260, 490, 480]],
78
+ [[30, 120, 150, 320], [200, 140, 320, 340], [360, 160, 500, 360]],
79
+ [[80, 140, 200, 300], [220, 80, 350, 260], [370, 160, 500, 320]],
80
+ ]
81
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
82
+ elif num_person == 4:
83
+ # Non-overlapping 4-person layouts within 512x512
84
+ bbox_choices = [
85
+ [[20, 100, 120, 240], [140, 100, 240, 240], [260, 100, 360, 240], [380, 100, 480, 240]],
86
+ [[40, 60, 200, 260], [220, 60, 380, 260], [40, 280, 200, 480], [220, 280, 380, 480]],
87
+ [[180, 30, 330, 170], [30, 220, 150, 380], [200, 220, 320, 380], [360, 220, 490, 380]],
88
+ [[30, 60, 140, 200], [370, 60, 480, 200], [30, 320, 140, 460], [370, 320, 480, 460]],
89
+ [[20, 120, 120, 380], [140, 100, 240, 360], [260, 120, 360, 380], [380, 100, 480, 360]],
90
+ [[30, 80, 150, 240], [180, 120, 300, 280], [330, 80, 450, 240], [200, 300, 320, 460]],
91
+ [[30, 140, 110, 330], [140, 140, 220, 330], [250, 140, 330, 330], [370, 140, 450, 330]],
92
+ [[40, 80, 150, 240], [40, 260, 150, 420], [200, 80, 310, 240], [370, 80, 480, 240]],
93
+ ]
94
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
95
+
96
+
97
+
98
+
99
+ class FaceExtractor:
100
+ def __init__(self, model_path="./"):
101
+ try:
102
+ self.model = insightface.app.FaceAnalysis(name = "antelopev2", root=model_path, providers=['CUDAExecutionProvider'])
103
+ except Exception as e:
104
+ print(f"Error loading insightface model: {e}. There might be an issue with the directory structure. Trying to fix it...")
105
+ antelopev2_nested_path = os.path.join(model_path, "models", "antelopev2", "antelopev2")
106
+ print(f"Checking for nested path: {antelopev2_nested_path}")
107
+ if os.path.exists(antelopev2_nested_path):
108
+ import subprocess
109
+ print("Detected nested antelopev2 directory, fixing directory structure...")
110
+ # Change to the model_path directory to execute commands
111
+ current_dir = os.getcwd()
112
+ os.chdir(model_path)
113
+ # Execute the commands as specified by the user
114
+ subprocess.run(["mv", "models/antelopev2/", "models/antelopev2_"])
115
+ subprocess.run(["mv", "models/antelopev2_/antelopev2/", "models/antelopev2/"])
116
+ # Return to the original directory
117
+ os.chdir(current_dir)
118
+ print("Directory structure fixed.")
119
+ self.model = insightface.app.FaceAnalysis(name="antelopev2", root="./")
120
+ self.model.prepare(ctx_id=0)
121
+
122
+ def extract(self, image: Image.Image):
123
+ """Extract single face and embedding from an image"""
124
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
125
+ res = self.model.get(image_np)
126
+ if len(res) == 0:
127
+ return None, None
128
+ res = res[0]
129
+ bbox = res["bbox"]
130
+ moref = extract_moref(image, {"bboxes": [bbox]}, 1)
131
+ return moref[0], res["embedding"]
132
+
133
+ def extract_refs(self, image: Image.Image):
134
+ """Extract multiple faces and embeddings from an image"""
135
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
136
+ res = self.model.get(image_np)
137
+ if len(res) == 0:
138
+ return None, None, None, None
139
+ ref_imgs = []
140
+ arcface_embeddings = []
141
+ bboxes = []
142
+ for r in res:
143
+ bbox = r["bbox"]
144
+ bboxes.append(bbox)
145
+ moref = extract_moref(image, {"bboxes": [bbox]}, 1)
146
+ ref_imgs.append(moref[0])
147
+ arcface_embeddings.append(r["embedding"])
148
+
149
+ # Convert bboxes to the correct format
150
+ new_img, new_bboxes = face_preserving_resize(image, bboxes, 512)
151
+ return ref_imgs, arcface_embeddings, new_bboxes, new_img
152
+
153
+
154
+ def resize_bbox(bbox, ori_width, ori_height, new_width, new_height):
155
+ """Resize bounding box coordinates while preserving aspect ratio"""
156
+ x1, y1, x2, y2 = bbox
157
+
158
+ # Calculate scaling factors
159
+ width_scale = new_width / ori_width
160
+ height_scale = new_height / ori_height
161
+
162
+ # Use minimum scaling factor to preserve aspect ratio
163
+ min_scale = min(width_scale, height_scale)
164
+
165
+ # Calculate offsets for centering the scaled box
166
+ width_offset = (new_width - ori_width * min_scale) / 2
167
+ height_offset = (new_height - ori_height * min_scale) / 2
168
+
169
+ # Scale and adjust coordinates
170
+ new_x1 = int(x1 * min_scale + width_offset)
171
+ new_y1 = int(y1 * min_scale + height_offset)
172
+ new_x2 = int(x2 * min_scale + width_offset)
173
+ new_y2 = int(y2 * min_scale + height_offset)
174
+
175
+ return [new_x1, new_y1, new_x2, new_y2]
176
+
177
+
178
+ def draw_bboxes_on_image(image, bboxes):
179
+ """Draw bounding boxes on image for visualization"""
180
+ if bboxes is None:
181
+ return image
182
+
183
+ # Create a copy to draw on
184
+ img_draw = image.copy()
185
+ draw = ImageDraw.Draw(img_draw)
186
+
187
+ # Draw each bbox with a different color
188
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
189
+
190
+ for i, bbox in enumerate(bboxes):
191
+ color = colors[i % len(colors)]
192
+ x1, y1, x2, y2 = [int(coord) for coord in bbox]
193
+ # Draw rectangle
194
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
195
+ # Draw label
196
+ draw.text((x1, y1-15), f"Face {i+1}", fill=color)
197
+
198
+ return img_draw
199
+
200
+
201
+ def create_demo(
202
+ model_type: str = "flux-dev",
203
+ ipa_path: str = "./ckpt/ipa.safetensors",
204
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
205
+ offload: bool = False,
206
+ lora_rank: int = 64,
207
+ additional_lora_ckpt: Optional[str] = None,
208
+ lora_scale: float = 1.0,
209
+ clip_path: str = "openai/clip-vit-large-patch14",
210
+ t5_path: str = "xlabs-ai/xflux_text_encoders",
211
+ flux_path: str = "black-forest-labs/FLUX.1-dev",
212
+ ):
213
+
214
+ face_extractor = FaceExtractor()
215
+ # Initialize pipeline and face extractor
216
+ pipeline = WithAnyonePipeline(
217
+ model_type,
218
+ ipa_path,
219
+ device,
220
+ offload,
221
+ only_lora=True,
222
+ no_lora=True,
223
+ lora_rank=lora_rank,
224
+ additional_lora_ckpt=additional_lora_ckpt,
225
+ lora_weight=lora_scale,
226
+ face_extractor=face_extractor,
227
+ clip_path=clip_path,
228
+ t5_path=t5_path,
229
+ flux_path=flux_path,
230
+ )
231
+
232
+
233
+
234
+ # Add project badges
235
+ # badges_text = r"""
236
+ # <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
237
+ # <a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
238
+ # <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
239
+ # <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
240
+ # </div>
241
+ # """.strip()
242
+
243
+ def parse_bboxes(bbox_text):
244
+ """Parse bounding box text input"""
245
+ if not bbox_text or bbox_text.strip() == "":
246
+ return None
247
+
248
+ try:
249
+ bboxes = []
250
+ lines = bbox_text.strip().split("\n")
251
+ for line in lines:
252
+ if not line.strip():
253
+ continue
254
+ coords = [float(x) for x in line.strip().split(",")]
255
+ if len(coords) != 4:
256
+ raise ValueError(f"Each bbox must have 4 coordinates (x1,y1,x2,y2), got: {line}")
257
+ bboxes.append(coords)
258
+ # print(f"\nParsed bboxes: {bboxes}\n")
259
+ return bboxes
260
+ except Exception as e:
261
+ raise gr.Error(f"Invalid bbox format: {e}")
262
+
263
+ def extract_from_multi_person(multi_person_image):
264
+ """Extract references and bboxes from a multi-person image"""
265
+ if multi_person_image is None:
266
+ return None, None, None, None
267
+
268
+ # Convert from numpy to PIL if needed
269
+ if isinstance(multi_person_image, np.ndarray):
270
+ multi_person_image = Image.fromarray(multi_person_image)
271
+
272
+ ref_imgs, arcface_embeddings, bboxes, new_img = face_extractor.extract_refs(multi_person_image)
273
+
274
+ if ref_imgs is None or len(ref_imgs) == 0:
275
+ raise gr.Error("No faces detected in the multi-person image")
276
+
277
+ # Limit to max 4 faces
278
+ ref_imgs = ref_imgs[:4]
279
+ arcface_embeddings = arcface_embeddings[:4]
280
+ bboxes = bboxes[:4]
281
+
282
+ # Create visualization with bboxes
283
+ viz_image = draw_bboxes_on_image(new_img, bboxes)
284
+
285
+ # Format bboxes as string for display
286
+ bbox_text = "\n".join([f"{bbox[0]:.1f},{bbox[1]:.1f},{bbox[2]:.1f},{bbox[3]:.1f}" for bbox in bboxes])
287
+
288
+ return ref_imgs, arcface_embeddings, bboxes, viz_image
289
+
290
+ def process_and_generate(
291
+ prompt,
292
+ width, height,
293
+ guidance, num_steps, seed,
294
+ ref_img1, ref_img2, ref_img3, ref_img4,
295
+ manual_bboxes_text,
296
+ multi_person_image,
297
+ # use_text_prompt,
298
+ # id_weight,
299
+ siglip_weight
300
+ ):
301
+ # Collect and validate reference images
302
+ ref_images = [img for img in [ref_img1, ref_img2, ref_img3, ref_img4] if img is not None]
303
+
304
+ if not ref_images:
305
+ raise gr.Error("At least one reference image is required")
306
+
307
+ # Process reference images to extract face and embeddings
308
+ ref_imgs = []
309
+ arcface_embeddings = []
310
+
311
+ # Modified bbox handling logic
312
+ if multi_person_image is not None:
313
+ # Extract from multi-person image mode
314
+ extracted_refs, extracted_embeddings, bboxes_, _ = extract_from_multi_person(multi_person_image)
315
+ if extracted_refs is None:
316
+ raise gr.Error("Failed to extract faces from the multi-person image")
317
+
318
+ print("bboxes from multi-person image:", bboxes_)
319
+ # need to resize bboxes from 512 512 to width height
320
+ bboxes_ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes_]
321
+
322
+ else:
323
+ # Parse manual bboxes
324
+ bboxes_ = parse_bboxes(manual_bboxes_text)
325
+
326
+ # If no manual bboxes provided, use automatic captioner
327
+ if bboxes_ is None:
328
+ print("No multi-person image or manual bboxes provided. Using automatic captioner.")
329
+ # Generate automatic bboxes based on image dimensions
330
+ bboxes__ = captioner(prompt, num_person=len(ref_images))
331
+ # resize to width height
332
+ bboxes_ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes__]
333
+ print("Automatically generated bboxes:", bboxes_)
334
+
335
+ bboxes = [bboxes_] # wrap as a single-item batch
336
+ # else:
337
+ # Manual mode: process each reference image
338
+ for img in ref_images:
339
+ if isinstance(img, np.ndarray):
340
+ img = Image.fromarray(img)
341
+
342
+ ref_img, embedding = face_extractor.extract(img)
343
+ if ref_img is None or embedding is None:
344
+ raise gr.Error("Failed to extract face from one of the reference images")
345
+
346
+ ref_imgs.append(ref_img)
347
+ arcface_embeddings.append(embedding)
348
+
349
+ # pad arcface_embeddings to 4 if less than 4
350
+ # while len(arcface_embeddings) < 4:
351
+ # arcface_embeddings.append(np.zeros_like(arcface_embeddings[0]))
352
+
353
+
354
+ if bboxes is None:
355
+ raise gr.Error("Either provide manual bboxes or a multi-person image for bbox extraction")
356
+
357
+ if len(bboxes[0]) != len(ref_imgs):
358
+ raise gr.Error(f"Number of bboxes ({len(bboxes[0])}) must match number of reference images ({len(ref_imgs)})")
359
+
360
+ # Convert arcface embeddings to tensor
361
+ arcface_embeddings = [torch.tensor(embedding) for embedding in arcface_embeddings]
362
+ arcface_embeddings = torch.stack(arcface_embeddings).to(device)
363
+
364
+ # Generate image
365
+ final_prompt = prompt
366
+
367
+ print(f"Generating image of size {width}x{height} with bboxes: {bboxes} ")
368
+
369
+ if seed < 0:
370
+ seed = np.random.randint(0, 1000000)
371
+
372
+ image_gen = pipeline(
373
+ prompt=final_prompt,
374
+ width=width,
375
+ height=height,
376
+ guidance=guidance,
377
+ num_steps=num_steps,
378
+ seed=seed if seed > 0 else None,
379
+ ref_imgs=ref_imgs,
380
+ arcface_embeddings=arcface_embeddings,
381
+ bboxes=bboxes,
382
+ id_weight = 1 - siglip_weight,
383
+ siglip_weight=siglip_weight,
384
+ )
385
+
386
+ # Save temp file for download
387
+ temp_path = "temp_generated.png"
388
+ image_gen.save(temp_path)
389
+
390
+ # draw bboxes on the generated image for debug
391
+ debug_face = draw_bboxes_on_image(image_gen, bboxes[0])
392
+
393
+ return image_gen, debug_face, temp_path
394
+
395
+ def update_bbox_display(multi_person_image):
396
+ if multi_person_image is None:
397
+ return None, gr.update(visible=True), gr.update(visible=False)
398
+
399
+ try:
400
+ _, _, _, viz_image = extract_from_multi_person(multi_person_image)
401
+ return viz_image, gr.update(visible=False), gr.update(visible=True)
402
+ except Exception as e:
403
+ return None, gr.update(visible=True), gr.update(visible=False)
404
+
405
+ # Create Gradio interface
406
+ with gr.Blocks() as demo:
407
+ gr.Markdown("# WithAnyone Demo")
408
+ # gr.Markdown(badges_text)
409
+
410
+ with gr.Row():
411
+
412
+ with gr.Column():
413
+ # Input controls
414
+ generate_btn = gr.Button("Generate", variant="primary")
415
+ with gr.Row():
416
+ with gr.Column():
417
+ siglip_weight = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Spiritual Resemblance <--> Formal Resemblance")
418
+ with gr.Row():
419
+ prompt = gr.Textbox(label="Prompt", value="a person in a beautiful garden. High resolution, extremely detailed")
420
+ # use_text_prompt = gr.Checkbox(label="Use text prompt", value=True)
421
+
422
+
423
+ with gr.Row():
424
+ # Image generation settings
425
+ with gr.Column():
426
+ width = gr.Slider(512, 1024, 768, step=64, label="Generation Width")
427
+ height = gr.Slider(512, 1024, 768, step=64, label="Generation Height")
428
+
429
+ with gr.Accordion("Advanced Options", open=False):
430
+ with gr.Row():
431
+ num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
432
+ guidance = gr.Slider(1.0, 10.0, 4.0, step=0.1, label="Guidance")
433
+ seed = gr.Number(-1, label="Seed (-1 for random)")
434
+
435
+ # start_at = gr.Slider(0, 50, 0, step=1, label="Start Identity at Step")
436
+ # end_at = gr.Number(-1, label="End Identity at Step (-1 for last)")
437
+
438
+ # with gr.Row():
439
+ # # skip_every = gr.Number(-1, label="Skip Identity Every N Steps (-1 for no skip)")
440
+
441
+ # siglip_weight = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Siglip Weight")
442
+
443
+
444
+ with gr.Row():
445
+ with gr.Column():
446
+ # Reference image inputs
447
+ gr.Markdown("### Face References (1-4 required)")
448
+ ref_img1 = gr.Image(label="Reference 1", type="pil")
449
+ ref_img2 = gr.Image(label="Reference 2", type="pil", visible=True)
450
+ ref_img3 = gr.Image(label="Reference 3", type="pil", visible=True)
451
+ ref_img4 = gr.Image(label="Reference 4", type="pil", visible=True)
452
+
453
+ with gr.Column():
454
+ # Bounding box inputs
455
+ gr.Markdown("### Mask Configuration (Option 1: Automatic)")
456
+ multi_person_image = gr.Image(label="Multi-person image (for automatic bbox extraction)", type="pil")
457
+ bbox_preview = gr.Image(label="Detected Faces", type="pil")
458
+
459
+
460
+ gr.Markdown("### Mask Configuration (Option 2: Manual)")
461
+ manual_bbox_input = gr.Textbox(
462
+ label="Manual Bounding Boxes (one per line, format: x1,y1,x2,y2)",
463
+ lines=4,
464
+ placeholder="100,100,200,200\n300,100,400,200"
465
+ )
466
+
467
+
468
+
469
+
470
+
471
+ # generate_btn = gr.Button("Generate", variant="primary")
472
+
473
+ with gr.Column():
474
+ # Output display
475
+ output_image = gr.Image(label="Generated Image")
476
+ debug_face = gr.Image(label="Debug. Faces are expected to be generated in these boxes")
477
+ download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
478
+
479
+ # Examples section
480
+ with gr.Row():
481
+
482
+ gr.Markdown("""
483
+ # Example Configurations
484
+
485
+ ### Tips for Better Results
486
+ Don't be discouraged if the first few runs are not very satisfying.
487
+
488
+ - Provide detailed prompts describing the identity. WithAnyone is "controllable", so it needs more information to be controlled. Here are some attributes that may go wrong if not specified:
489
+ - Skin tone (ethnicity is generally handled well, but for people of Asian descent it may generate a darker skin tone if not specified);
490
+ - Age (e.g., instead of "a man", try "a young man". If not specified, it may generate an older figure);
491
+ - Body build;
492
+ - Hairstyle;
493
+ - Accessories (glasses, hats, earrings, etc.);
494
+ - Makeup
495
+ - Use the slider to balance between "Resemblance in Spirit" and "Resemblance in Form" according to your needs. If you want to preserve more details in the reference image, move the slider to the right; if you want more freedom and creativity, move it to the left.
496
+ - Try it with LoRAs from the community. They are usually fantastic.
497
+ """)
498
+ with gr.Row():
499
+ examples = gr.Examples(
500
+ examples=[
501
+ [
502
+ "a highly detailed portrait of a woman shown in profile. Her long, dark hair flows elegantly, intricately decorated with an abundant array of colorful flowers—ranging from soft light pinks and vibrant light oranges to delicate greyish blues—and lush green leaves, giving a sense of natural beauty and charm. Her bright blue eyes are striking, and her lips are painted a vivid red, adding to her alluring appearance. She is clad in an ornate garment with intricate floral patterns in warm hues like pink and orange, featuring exquisite detailing that speaks of fine craftsmanship. Around her neck, she wears a decorative choker with intricate designs, and dangling from her ears are beautiful blue teardrop earrings that catch the light. The background is filled with a profusion of flowers in various shades, creating a rich, vibrant, and romantic atmosphere that complements the woman's elegant and enchanting look.", # prompt
503
+ 1024, 1024, # width, height
504
+ 4.0, 25, 42, # guidance, num_steps, seed
505
+ "assets/ref1.jpg", None, None, None, # ref images
506
+ "240,180,540,500", None, # manual_bbox_input, multi_person_image
507
+ # True, # use_text_prompt
508
+ 0.0, # siglip_weight
509
+ ],
510
+ [
511
+ "High resolution and extremely detailed image of two elegant ladies enjoying a serene afternoon in a quaint Parisian café. They both wear fashionable trench coats and stylish berets, exuding an air of sophistication. One lady gently sips on a cappuccino, while her companion reads an intriguing novel with a subtle smile. The café is framed by charming antique furniture and vintage posters adorning the walls. Soft, warm light filters through a window, casting delicate shadows and creating a cozy, inviting atmosphere. Captured from a slightly elevated angle, the composition highlights the warmth of the scene in a gentle watercolor illustrative style. ", # prompt
512
+ 1024, 1024, # width, height
513
+ 4.0, 25, 42, # guidance, num_steps, seed
514
+ "assets/ref1.jpg", "assets/ref2.jpg", None, None, # ref images
515
+ "248,172,428,498\n554,128,728,464", None, # manual_bbox_input, multi_person_image
516
+ # True, # use_text_prompt
517
+ 0.0, # siglip_weight
518
+ ]
519
+ ],
520
+ inputs=[
521
+ prompt, width, height, guidance, num_steps, seed,
522
+ ref_img1, ref_img2, ref_img3, ref_img4,
523
+ manual_bbox_input, multi_person_image,
524
+ siglip_weight
525
+ ],
526
+ label="Click to load example configurations"
527
+ )
528
+ # Set up event handlers
529
+ multi_person_image.change(
530
+ fn=update_bbox_display,
531
+ inputs=[multi_person_image],
532
+ outputs=[bbox_preview, manual_bbox_input, bbox_preview]
533
+ )
534
+
535
+ generate_btn.click(
536
+ fn=process_and_generate,
537
+ inputs=[
538
+ prompt, width, height, guidance, num_steps, seed,
539
+ ref_img1, ref_img2, ref_img3, ref_img4,
540
+ manual_bbox_input, multi_person_image,
541
+ siglip_weight
542
+ ],
543
+ outputs=[output_image,debug_face, download_btn]
544
+ )
545
+
546
+ return demo
547
+
548
+
549
+ if __name__ == "__main__":
550
+ from transformers import HfArgumentParser
551
+
552
+ @dataclasses.dataclass
553
+ class AppArgs:
554
+ model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
555
+ device: Literal["cuda", "cpu"] = (
556
+ "cuda" if torch.cuda.is_available()
557
+ else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
558
+ else "cpu"
559
+ )
560
+ offload: bool = False
561
+ lora_rank: int = 64
562
+ port: int = 7860
563
+ additional_lora: str = None
564
+ lora_scale: float = 1.0
565
+ ipa_path: str = "WithAnyone/WithAnyone"
566
+ clip_path: str = "openai/clip-vit-large-patch14"
567
+ t5_path: str = "xlabs-ai/xflux_text_encoders"
568
+ flux_path: str = "black-forest-labs/FLUX.1-dev"
569
+
570
+ parser = HfArgumentParser([AppArgs])
571
+ args = parser.parse_args_into_dataclasses()[0]
572
+
573
+ demo = create_demo(
574
+ args.model_type,
575
+ args.ipa_path,
576
+ args.device,
577
+ args.offload,
578
+ args.lora_rank,
579
+ args.additional_lora,
580
+ args.lora_scale,
581
+ args.clip_path,
582
+ args.t5_path,
583
+ args.flux_path,
584
+ )
585
+ demo.launch(server_port=args.port)
586
+
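For context on the `__main__` block above: `HfArgumentParser` turns every `AppArgs` field into a command-line flag of the same name, so the demo's defaults (model type, paths, port, offload) can be overridden at launch. The sketch below is illustrative only and is not part of this commit; `DemoArgs` is a hypothetical, trimmed-down stand-in for `AppArgs`.

# Minimal sketch (an assumption, not code from this commit) of how
# HfArgumentParser maps a dataclass such as AppArgs onto CLI flags, e.g.
#   python app.py --model_type flux-dev --port 7861 --offload True
import dataclasses
from transformers import HfArgumentParser

@dataclasses.dataclass
class DemoArgs:  # hypothetical stand-in for AppArgs
    model_type: str = "flux-dev"
    port: int = 7860
    offload: bool = False

parser = HfArgumentParser([DemoArgs])
# An explicit argv list is passed here only for illustration; with no
# argument it reads sys.argv, exactly as app.py does.
(args,) = parser.parse_args_into_dataclasses(["--port", "7861", "--offload", "True"])
print(args.model_type, args.port, args.offload)  # -> flux-dev 7861 True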
gradio_app.py ADDED
@@ -0,0 +1,568 @@
1
+ # Copyright (c) 2025 Fudan University. All rights reserved.
2
+
3
+
4
+
5
+ import dataclasses
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import List, Literal, Optional
10
+
11
+ import cv2
12
+ import gradio as gr
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image, ImageDraw
16
+
17
+ from withanyone.flux.pipeline import WithAnyonePipeline
18
+ from util import extract_moref, face_preserving_resize
19
+ import insightface
20
+
21
+
22
+ def captioner(prompt: str, num_person = 1) -> List[List[float]]:
23
+ # randomly choose a preset layout for testing
24
+ # coordinates lie within a 512x512 canvas
25
+ if num_person == 1:
26
+ bbox_choices = [
27
+ # expanded, centered and quadrant placements
28
+ [96, 96, 288, 288],
29
+ [128, 128, 320, 320],
30
+ [160, 96, 352, 288],
31
+ [96, 160, 288, 352],
32
+ [208, 96, 400, 288],
33
+ [96, 208, 288, 400],
34
+ [192, 160, 368, 336],
35
+ [64, 128, 224, 320],
36
+ [288, 128, 448, 320],
37
+ [128, 256, 320, 448],
38
+ [80, 80, 240, 272],
39
+ [196, 196, 380, 380],
40
+ # originals
41
+ [100, 100, 300, 300],
42
+ [150, 50, 450, 350],
43
+ [200, 100, 500, 400],
44
+ [250, 150, 512, 450],
45
+ ]
46
+ return [bbox_choices[np.random.randint(0, len(bbox_choices))]]
47
+ elif num_person == 2:
48
+ # realistic side-by-side rows (no vertical stacks or diagonals)
49
+ bbox_choices = [
50
+ [[64, 112, 224, 304], [288, 112, 448, 304]],
51
+ [[48, 128, 208, 320], [304, 128, 464, 320]],
52
+ [[32, 144, 192, 336], [320, 144, 480, 336]],
53
+ [[80, 96, 240, 288], [272, 96, 432, 288]],
54
+ [[80, 160, 240, 352], [272, 160, 432, 352]],
55
+ [[64, 128, 240, 336], [272, 144, 432, 320]], # slight stagger, same row
56
+ [[96, 160, 256, 352], [288, 160, 448, 352]],
57
+ [[64, 192, 224, 384], [288, 192, 448, 384]], # lower row
58
+ [[16, 128, 176, 320], [336, 128, 496, 320]], # near edges
59
+ [[48, 120, 232, 328], [280, 120, 464, 328]],
60
+ [[96, 160, 240, 336], [272, 160, 416, 336]], # tighter faces
61
+ [[72, 136, 232, 328], [280, 152, 440, 344]], # small vertical offset
62
+ [[48, 120, 224, 344], [288, 144, 448, 336]], # asymmetric sizes
63
+ [[80, 224, 240, 416], [272, 224, 432, 416]], # bottom row
64
+ [[80, 64, 240, 256], [272, 64, 432, 256]], # top row
65
+ [[96, 176, 256, 368], [288, 176, 448, 368]],
66
+ ]
67
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
68
+
69
+ elif num_person == 3:
70
+ # Non-overlapping 3-person layouts within 512x512
71
+ bbox_choices = [
72
+ [[20, 140, 150, 360], [180, 120, 330, 360], [360, 130, 500, 360]],
73
+ [[30, 100, 160, 300], [190, 90, 320, 290], [350, 110, 480, 310]],
74
+ [[40, 180, 150, 330], [200, 180, 310, 330], [360, 180, 470, 330]],
75
+ [[60, 120, 170, 300], [210, 110, 320, 290], [350, 140, 480, 320]],
76
+ [[50, 80, 170, 250], [200, 130, 320, 300], [350, 80, 480, 250]],
77
+ [[40, 260, 170, 480], [190, 60, 320, 240], [350, 260, 490, 480]],
78
+ [[30, 120, 150, 320], [200, 140, 320, 340], [360, 160, 500, 360]],
79
+ [[80, 140, 200, 300], [220, 80, 350, 260], [370, 160, 500, 320]],
80
+ ]
81
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
82
+ elif num_person == 4:
83
+ # Non-overlapping 4-person layouts within 512x512
84
+ bbox_choices = [
85
+ [[20, 100, 120, 240], [140, 100, 240, 240], [260, 100, 360, 240], [380, 100, 480, 240]],
86
+ [[40, 60, 200, 260], [220, 60, 380, 260], [40, 280, 200, 480], [220, 280, 380, 480]],
87
+ [[180, 30, 330, 170], [30, 220, 150, 380], [200, 220, 320, 380], [360, 220, 490, 380]],
88
+ [[30, 60, 140, 200], [370, 60, 480, 200], [30, 320, 140, 460], [370, 320, 480, 460]],
89
+ [[20, 120, 120, 380], [140, 100, 240, 360], [260, 120, 360, 380], [380, 100, 480, 360]],
90
+ [[30, 80, 150, 240], [180, 120, 300, 280], [330, 80, 450, 240], [200, 300, 320, 460]],
91
+ [[30, 140, 110, 330], [140, 140, 220, 330], [250, 140, 330, 330], [370, 140, 450, 330]],
92
+ [[40, 80, 150, 240], [40, 260, 150, 420], [200, 80, 310, 240], [370, 80, 480, 240]],
93
+ ]
94
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
95
+
96
+
97
+
98
+
99
+ class FaceExtractor:
100
+ def __init__(self, model_path="./"):
101
+ self.model = insightface.app.FaceAnalysis(name="antelopev2", root="./")
102
+ self.model.prepare(ctx_id=0)
103
+
104
+ def extract(self, image: Image.Image):
105
+ """Extract single face and embedding from an image"""
106
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
107
+ res = self.model.get(image_np)
108
+ if len(res) == 0:
109
+ return None, None
110
+ res = res[0]
111
+ bbox = res["bbox"]
112
+ moref = extract_moref(image, {"bboxes": [bbox]}, 1)
113
+ return moref[0], res["embedding"]
114
+
115
+ def extract_refs(self, image: Image.Image):
116
+ """Extract multiple faces and embeddings from an image"""
117
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
118
+ res = self.model.get(image_np)
119
+ if len(res) == 0:
120
+ return None, None, None
121
+ ref_imgs = []
122
+ arcface_embeddings = []
123
+ bboxes = []
124
+ for r in res:
125
+ bbox = r["bbox"]
126
+ bboxes.append(bbox)
127
+ moref = extract_moref(image, {"bboxes": [bbox]}, 1)
128
+ ref_imgs.append(moref[0])
129
+ arcface_embeddings.append(r["embedding"])
130
+
131
+ # Convert bboxes to the correct format
132
+ new_img, new_bboxes = face_preserving_resize(image, bboxes, 512)
133
+ return ref_imgs, arcface_embeddings, new_bboxes, new_img
134
+
135
+
136
+ def resize_bbox(bbox, ori_width, ori_height, new_width, new_height):
137
+ """Resize bounding box coordinates while preserving aspect ratio"""
138
+ x1, y1, x2, y2 = bbox
139
+
140
+ # Calculate scaling factors
141
+ width_scale = new_width / ori_width
142
+ height_scale = new_height / ori_height
143
+
144
+ # Use minimum scaling factor to preserve aspect ratio
145
+ min_scale = min(width_scale, height_scale)
146
+
147
+ # Calculate offsets for centering the scaled box
148
+ width_offset = (new_width - ori_width * min_scale) / 2
149
+ height_offset = (new_height - ori_height * min_scale) / 2
150
+
151
+ # Scale and adjust coordinates
152
+ new_x1 = int(x1 * min_scale + width_offset)
153
+ new_y1 = int(y1 * min_scale + height_offset)
154
+ new_x2 = int(x2 * min_scale + width_offset)
155
+ new_y2 = int(y2 * min_scale + height_offset)
156
+
157
+ return [new_x1, new_y1, new_x2, new_y2]
158
+
159
+
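+ # Worked example: mapping a box from the 512x512 layout canvas into a 1024x768 generation.
+ # The limiting scale is min(1024/512, 768/512) = 1.5 and the horizontal offset is
+ # (1024 - 512 * 1.5) / 2 = 128, so:
+ # resize_bbox([100, 100, 300, 300], 512, 512, 1024, 768) -> [278, 150, 578, 450]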
160
+ def draw_bboxes_on_image(image, bboxes):
161
+ """Draw bounding boxes on image for visualization"""
162
+ if bboxes is None:
163
+ return image
164
+
165
+ # Create a copy to draw on
166
+ img_draw = image.copy()
167
+ draw = ImageDraw.Draw(img_draw)
168
+
169
+ # Draw each bbox with a different color
170
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
171
+
172
+ for i, bbox in enumerate(bboxes):
173
+ color = colors[i % len(colors)]
174
+ x1, y1, x2, y2 = [int(coord) for coord in bbox]
175
+ # Draw rectangle
176
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
177
+ # Draw label
178
+ draw.text((x1, y1-15), f"Face {i+1}", fill=color)
179
+
180
+ return img_draw
181
+
182
+
183
+ def create_demo(
184
+ model_type: str = "flux-dev",
185
+ ipa_path: str = "./ckpt/ipa.safetensors",
186
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
187
+ offload: bool = False,
188
+ lora_rank: int = 64,
189
+ additional_lora_ckpt: Optional[str] = None,
190
+ lora_scale: float = 1.0,
191
+ clip_path: str = "openai/clip-vit-large-patch14",
192
+ t5_path: str = "xlabs-ai/xflux_text_encoders",
193
+ flux_path: str = "black-forest-labs/FLUX.1-dev",
194
+ ):
195
+
196
+ face_extractor = FaceExtractor()
197
+ # Initialize pipeline and face extractor
198
+ pipeline = WithAnyonePipeline(
199
+ model_type,
200
+ ipa_path,
201
+ device,
202
+ offload,
203
+ only_lora=True,
204
+ no_lora=True,
205
+ lora_rank=lora_rank,
206
+ additional_lora_ckpt=additional_lora_ckpt,
207
+ lora_weight=lora_scale,
208
+ face_extractor=face_extractor,
209
+ clip_path=clip_path,
210
+ t5_path=t5_path,
211
+ flux_path=flux_path,
212
+ )
213
+
214
+
215
+
216
+ # Add project badges
217
+ # badges_text = r"""
218
+ # <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
219
+ # <a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
220
+ # <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
221
+ # <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
222
+ # </div>
223
+ # """.strip()
224
+
225
+ def parse_bboxes(bbox_text):
226
+ """Parse bounding box text input"""
227
+ if not bbox_text or bbox_text.strip() == "":
228
+ return None
229
+
230
+ try:
231
+ bboxes = []
232
+ lines = bbox_text.strip().split("\n")
233
+ for line in lines:
234
+ if not line.strip():
235
+ continue
236
+ coords = [float(x) for x in line.strip().split(",")]
237
+ if len(coords) != 4:
238
+ raise ValueError(f"Each bbox must have 4 coordinates (x1,y1,x2,y2), got: {line}")
239
+ bboxes.append(coords)
240
+ # print(f"\nParsed bboxes: {bboxes}\n")
241
+ return bboxes
242
+ except Exception as e:
243
+ raise gr.Error(f"Invalid bbox format: {e}")
244
+
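+ # Example: the textbox contents "240,180,540,500\n554,128,728,464" parse to
+ # [[240.0, 180.0, 540.0, 500.0], [554.0, 128.0, 728.0, 464.0]] (one box per line).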
245
+ def extract_from_multi_person(multi_person_image):
246
+ """Extract references and bboxes from a multi-person image"""
247
+ if multi_person_image is None:
248
+ return None, None, None, None
249
+
250
+ # Convert from numpy to PIL if needed
251
+ if isinstance(multi_person_image, np.ndarray):
252
+ multi_person_image = Image.fromarray(multi_person_image)
253
+
254
+ ref_imgs, arcface_embeddings, bboxes, new_img = face_extractor.extract_refs(multi_person_image)
255
+
256
+ if ref_imgs is None or len(ref_imgs) == 0:
257
+ raise gr.Error("No faces detected in the multi-person image")
258
+
259
+ # Limit to max 4 faces
260
+ ref_imgs = ref_imgs[:4]
261
+ arcface_embeddings = arcface_embeddings[:4]
262
+ bboxes = bboxes[:4]
263
+
264
+ # Create visualization with bboxes
265
+ viz_image = draw_bboxes_on_image(new_img, bboxes)
266
+
267
+ # Format bboxes as string for display
268
+ bbox_text = "\n".join([f"{bbox[0]:.1f},{bbox[1]:.1f},{bbox[2]:.1f},{bbox[3]:.1f}" for bbox in bboxes])
269
+
270
+ return ref_imgs, arcface_embeddings, bboxes, viz_image
271
+
272
+ def process_and_generate(
273
+ prompt,
274
+ width, height,
275
+ guidance, num_steps, seed,
276
+ ref_img1, ref_img2, ref_img3, ref_img4,
277
+ manual_bboxes_text,
278
+ multi_person_image,
279
+ # use_text_prompt,
280
+ # id_weight,
281
+ siglip_weight
282
+ ):
283
+ # Collect and validate reference images
284
+ ref_images = [img for img in [ref_img1, ref_img2, ref_img3, ref_img4] if img is not None]
285
+
286
+ if not ref_images:
287
+ raise gr.Error("At least one reference image is required")
288
+
289
+ # Process reference images to extract face and embeddings
290
+ ref_imgs = []
291
+ arcface_embeddings = []
292
+
293
+ # Modified bbox handling logic
294
+ if multi_person_image is not None:
295
+ # Extract from multi-person image mode
296
+ extracted_refs, extracted_embeddings, bboxes_, _ = extract_from_multi_person(multi_person_image)
297
+ if extracted_refs is None:
298
+ raise gr.Error("Failed to extract faces from the multi-person image")
299
+
300
+ print("bboxes from multi-person image:", bboxes_)
301
+ # need to resize bboxes from 512 512 to width height
302
+ bboxes_ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes_]
303
+
304
+ else:
305
+ # Parse manual bboxes
306
+ bboxes_ = parse_bboxes(manual_bboxes_text)
307
+
308
+ # If no manual bboxes provided, use automatic captioner
309
+ if bboxes_ is None:
310
+ print("No multi-person image or manual bboxes provided. Using automatic captioner.")
311
+ # Generate automatic bboxes based on image dimensions
312
+ bboxes__ = captioner(prompt, num_person=len(ref_images))
313
+ # resize to width height
314
+ bboxes_ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes__]
315
+ print("Automatically generated bboxes:", bboxes_)
316
+
317
+ bboxes = [bboxes_]  # wrap as a single-item batch
318
+ # else:
319
+ # Manual mode: process each reference image
320
+ for img in ref_images:
321
+ if isinstance(img, np.ndarray):
322
+ img = Image.fromarray(img)
323
+
324
+ ref_img, embedding = face_extractor.extract(img)
325
+ if ref_img is None or embedding is None:
326
+ raise gr.Error("Failed to extract face from one of the reference images")
327
+
328
+ ref_imgs.append(ref_img)
329
+ arcface_embeddings.append(embedding)
330
+
331
+ # pad arcface_embeddings to 4 if less than 4
332
+ # while len(arcface_embeddings) < 4:
333
+ # arcface_embeddings.append(np.zeros_like(arcface_embeddings[0]))
334
+
335
+
336
+ if bboxes is None:
337
+ raise gr.Error("Either provide manual bboxes or a multi-person image for bbox extraction")
338
+
339
+ if len(bboxes[0]) != len(ref_imgs):
340
+ raise gr.Error(f"Number of bboxes ({len(bboxes[0])}) must match number of reference images ({len(ref_imgs)})")
341
+
342
+ # Convert arcface embeddings to tensor
343
+ arcface_embeddings = [torch.tensor(embedding) for embedding in arcface_embeddings]
344
+ arcface_embeddings = torch.stack(arcface_embeddings).to(device)
345
+
346
+ # Generate image
347
+ final_prompt = prompt
348
+
349
+ print(f"Generating image of size {width}x{height} with bboxes: {bboxes} ")
350
+
351
+ if seed < 0:
352
+ seed = np.random.randint(0, 1000000)
353
+
354
+ image_gen = pipeline(
355
+ prompt=final_prompt,
356
+ width=width,
357
+ height=height,
358
+ guidance=guidance,
359
+ num_steps=num_steps,
360
+ seed=seed if seed > 0 else None,
361
+ ref_imgs=ref_imgs,
362
+ arcface_embeddings=arcface_embeddings,
363
+ bboxes=bboxes,
364
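+ # id_weight and siglip_weight are complementary (they sum to 1): the slider value
+ # feeds the SigLIP branch and the remainder feeds the ArcFace identity branch.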
+ id_weight=1 - siglip_weight,
365
+ siglip_weight=siglip_weight,
366
+ )
367
+
368
+ # Save temp file for download
369
+ temp_path = "temp_generated.png"
370
+ image_gen.save(temp_path)
371
+
372
+ # draw bboxes on the generated image for debug
373
+ debug_face = draw_bboxes_on_image(image_gen, bboxes[0])
374
+
375
+ return image_gen, debug_face, temp_path
376
+
377
+ def update_bbox_display(multi_person_image):
378
+ if multi_person_image is None:
379
+ return None, gr.update(visible=True), gr.update(visible=False)
380
+
381
+ try:
382
+ _, _, _, viz_image = extract_from_multi_person(multi_person_image)
383
+ return viz_image, gr.update(visible=False), gr.update(visible=True)
384
+ except Exception as e:
385
+ return None, gr.update(visible=True), gr.update(visible=False)
386
+
387
+ # Create Gradio interface
388
+ with gr.Blocks() as demo:
389
+ gr.Markdown("# WithAnyone Demo")
390
+ # gr.Markdown(badges_text)
391
+
392
+ with gr.Row():
393
+
394
+ with gr.Column():
395
+ # Input controls
396
+ generate_btn = gr.Button("Generate", variant="primary")
397
+ with gr.Row():
398
+ with gr.Column():
399
+ siglip_weight = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Spiritual Resemblance <--> Formal Resemblance")
400
+ with gr.Row():
401
+ prompt = gr.Textbox(label="Prompt", value="a person in a beautiful garden. High resolution, extremely detailed")
402
+ # use_text_prompt = gr.Checkbox(label="Use text prompt", value=True)
403
+
404
+
405
+ with gr.Row():
406
+ # Image generation settings
407
+ with gr.Column():
408
+ width = gr.Slider(512, 1024, 768, step=64, label="Generation Width")
409
+ height = gr.Slider(512, 1024, 768, step=64, label="Generation Height")
410
+
411
+ with gr.Accordion("Advanced Options", open=False):
412
+ with gr.Row():
413
+ num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
414
+ guidance = gr.Slider(1.0, 10.0, 4.0, step=0.1, label="Guidance")
415
+ seed = gr.Number(-1, label="Seed (-1 for random)")
416
+
417
+ # start_at = gr.Slider(0, 50, 0, step=1, label="Start Identity at Step")
418
+ # end_at = gr.Number(-1, label="End Identity at Step (-1 for last)")
419
+
420
+ # with gr.Row():
421
+ # # skip_every = gr.Number(-1, label="Skip Identity Every N Steps (-1 for no skip)")
422
+
423
+ # siglip_weight = gr.Slider(0.0, 1.0, 1.0, step=0.05, label="Siglip Weight")
424
+
425
+
426
+ with gr.Row():
427
+ with gr.Column():
428
+ # Reference image inputs
429
+ gr.Markdown("### Face References (1-4 required)")
430
+ ref_img1 = gr.Image(label="Reference 1", type="pil")
431
+ ref_img2 = gr.Image(label="Reference 2", type="pil", visible=True)
432
+ ref_img3 = gr.Image(label="Reference 3", type="pil", visible=True)
433
+ ref_img4 = gr.Image(label="Reference 4", type="pil", visible=True)
434
+
435
+ with gr.Column():
436
+ # Bounding box inputs
437
+ gr.Markdown("### Mask Configuration (Option 1: Automatic)")
438
+ multi_person_image = gr.Image(label="Multi-person image (for automatic bbox extraction)", type="pil")
439
+ bbox_preview = gr.Image(label="Detected Faces", type="pil")
440
+
441
+
442
+ gr.Markdown("### Mask Configuration (Option 2: Manual)")
443
+ manual_bbox_input = gr.Textbox(
444
+ label="Manual Bounding Boxes (one per line, format: x1,y1,x2,y2)",
445
+ lines=4,
446
+ placeholder="100,100,200,200\n300,100,400,200"
447
+ )
448
+
449
+
450
+
451
+
452
+
453
+ # generate_btn = gr.Button("Generate", variant="primary")
454
+
455
+ with gr.Column():
456
+ # Output display
457
+ output_image = gr.Image(label="Generated Image")
458
+ debug_face = gr.Image(label="Debug. Faces are expected to be generated in these boxes")
459
+ download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
460
+
461
+ # Examples section
462
+ with gr.Row():
463
+
464
+ gr.Markdown("""
465
+ # Example Configurations
466
+
467
+ ### Tips for Better Results
468
+ The first few runs may not be very satisfying; refine the prompt and settings before judging the results.
469
+
470
+ - Provide detailed prompts describing the identity. WithAnyone is "controllable", so it needs more information to be controlled. Here are some attributes that may go wrong if left unspecified:
471
+ - Skin color (generally the race is fine, but for people of Asian descent, if not specified, it may generate a darker skin tone);
472
+ - Age (e.g., instead of "a man", try "a young man". If not specified, it may generate an older figure);
473
+ - Body build;
474
+ - Hairstyle;
475
+ - Accessories (glasses, hats, earrings, etc.);
476
+ - Makeup
477
+ - Use the slider to balance between "Resemblance in Spirit" and "Resemblance in Form" according to your needs. If you want to preserve more details in the reference image, move the slider to the right; if you want more freedom and creativity, move it to the left.
478
+ - Try it with LoRAs from the community. They are usually fantastic.
479
+ """)
480
+ with gr.Row():
481
+ examples = gr.Examples(
482
+ examples=[
483
+ [
484
+ "a highly detailed portrait of a woman shown in profile. Her long, dark hair flows elegantly, intricately decorated with an abundant array of colorful flowers—ranging from soft light pinks and vibrant light oranges to delicate greyish blues—and lush green leaves, giving a sense of natural beauty and charm. Her bright blue eyes are striking, and her lips are painted a vivid red, adding to her alluring appearance. She is clad in an ornate garment with intricate floral patterns in warm hues like pink and orange, featuring exquisite detailing that speaks of fine craftsmanship. Around her neck, she wears a decorative choker with intricate designs, and dangling from her ears are beautiful blue teardrop earrings that catch the light. The background is filled with a profusion of flowers in various shades, creating a rich, vibrant, and romantic atmosphere that complements the woman's elegant and enchanting look.", # prompt
485
+ 1024, 1024, # width, height
486
+ 4.0, 25, 42, # guidance, num_steps, seed
487
+ "assets/ref1.jpg", None, None, None, # ref images
488
+ "240,180,540,500", None, # manual_bbox_input, multi_person_image
489
+ # True, # use_text_prompt
490
+ 0.0, # siglip_weight
491
+ ],
492
+ [
493
+ "High resolution anfd extremely detailed image of two elegant ladies enjoying a serene afternoon in a quaint Parisian café. They both wear fashionable trench coats and stylish berets, exuding an air of sophistication. One lady gently sips on a cappuccino, while her companion reads an intriguing novel with a subtle smile. The café is framed by charming antique furniture and vintage posters adorning the walls. Soft, warm light filters through a window, casting delicate shadows and creating a cozy, inviting atmosphere. Captured from a slightly elevated angle, the composition highlights the warmth of the scene in a gentle watercolor illustrative style. ", # prompt
494
+ 1024, 1024, # width, height
495
+ 4.0, 25, 42, # guidance, num_steps, seed
496
+ "assets/ref1.jpg", "assets/ref2.jpg", None, None, # ref images
497
+ "248,172,428,498\n554,128,728,464", None, # manual_bbox_input, multi_person_image
498
+ # True, # use_text_prompt
499
+ 0.0, # siglip_weight
500
+ ]
501
+ ],
502
+ inputs=[
503
+ prompt, width, height, guidance, num_steps, seed,
504
+ ref_img1, ref_img2, ref_img3, ref_img4,
505
+ manual_bbox_input, multi_person_image,
506
+ siglip_weight
507
+ ],
508
+ label="Click to load example configurations"
509
+ )
510
+ # Set up event handlers
511
+ multi_person_image.change(
512
+ fn=update_bbox_display,
513
+ inputs=[multi_person_image],
514
+ outputs=[bbox_preview, manual_bbox_input, bbox_preview]
515
+ )
516
+
517
+ generate_btn.click(
518
+ fn=process_and_generate,
519
+ inputs=[
520
+ prompt, width, height, guidance, num_steps, seed,
521
+ ref_img1, ref_img2, ref_img3, ref_img4,
522
+ manual_bbox_input, multi_person_image,
523
+ siglip_weight
524
+ ],
525
+ outputs=[output_image, debug_face, download_btn]
526
+ )
527
+
528
+ return demo
529
+
530
+
531
+ if __name__ == "__main__":
532
+ from transformers import HfArgumentParser
533
+
534
+ @dataclasses.dataclass
535
+ class AppArgs:
536
+ model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
537
+ device: Literal["cuda", "cpu"] = (
538
+ "cuda" if torch.cuda.is_available()
539
+ else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
540
+ else "cpu"
541
+ )
542
+ offload: bool = False
543
+ lora_rank: int = 64
544
+ port: int = 7860
545
+ additional_lora: str = None
546
+ lora_scale: float = 1.0
547
+ ipa_path: str = "./ckpt/ipa.safetensors"
548
+ clip_path: str = "openai/clip-vit-large-patch14"
549
+ t5_path: str = "xlabs-ai/xflux_text_encoders"
550
+ flux_path: str = "black-forest-labs/FLUX.1-dev"
551
+
552
+ parser = HfArgumentParser([AppArgs])
553
+ args = parser.parse_args_into_dataclasses()[0]
554
+
555
+ demo = create_demo(
556
+ args.model_type,
557
+ args.ipa_path,
558
+ args.device,
559
+ args.offload,
560
+ args.lora_rank,
561
+ args.additional_lora,
562
+ args.lora_scale,
563
+ args.clip_path,
564
+ args.t5_path,
565
+ args.flux_path,
566
+ )
567
+ demo.launch(server_port=args.port)
568
+
gradio_edit.py ADDED
@@ -0,0 +1,563 @@
1
+ # Copyright (c) 2025 Fudan University. All rights reserved.
2
+
3
+
4
+
5
+ import dataclasses
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import List, Literal, Optional, Tuple, Union
10
+ from io import BytesIO
11
+
12
+ import cv2
13
+ import gradio as gr
14
+ import numpy as np
15
+ import torch
16
+ from PIL import Image, ImageDraw, ImageFilter
17
+ from PIL.JpegImagePlugin import JpegImageFile
18
+
19
+ from withanyone_kontext_s.flux.pipeline import WithAnyonePipeline
20
+ from util import extract_moref, face_preserving_resize
21
+ import insightface
22
+
23
+
24
+ def blur_faces_in_image(img, json_data, face_size_threshold=100, blur_radius=15):
25
+ """
26
+ Blurs facial areas directly in the original image for privacy protection.
27
+
28
+ Args:
29
+ img: PIL Image or image data
30
+ json_data: JSON object with 'bboxes' and 'crop' information
31
+ face_size_threshold: Minimum size for faces to be considered (default: 100 pixels)
32
+ blur_radius: Strength of the blur effect (higher = more blurred)
33
+
34
+ Returns:
35
+ PIL Image with faces blurred
36
+ """
37
+ # Ensure img is a PIL Image
38
+ if not isinstance(img, Image.Image) and not isinstance(img, torch.Tensor) and not isinstance(img, JpegImageFile):
39
+ img = Image.open(BytesIO(img))
40
+
41
+ new_bboxes = json_data['bboxes']
42
+ # crop = json_data['crop'] if 'crop' in json_data else [0, 0, img.width, img.height]
43
+
44
+ # # Recalculate bounding boxes based on crop info
45
+ # new_bboxes = [recalculate_bbox(bbox, crop) for bbox in bboxes]
46
+
47
+ # Check face sizes and filter out faces that are too small
48
+ valid_bboxes = []
49
+ for bbox in new_bboxes:
50
+ x1, y1, x2, y2 = bbox
51
+ if x2 - x1 >= face_size_threshold and y2 - y1 >= face_size_threshold:
52
+ valid_bboxes.append(bbox)
53
+
54
+ # If no valid faces found, return original image
55
+ if not valid_bboxes:
56
+ return img
57
+
58
+ # Create a copy of the original image to modify
59
+ blurred_img = img.copy()
60
+
61
+ # Process each face
62
+ for bbox in valid_bboxes:
63
+ # Convert coordinates to integers
64
+ x1, y1, x2, y2 = map(int, bbox)
65
+
66
+ # Ensure coordinates are within image boundaries
67
+ img_width, img_height = img.size
68
+ x1 = max(0, x1)
69
+ y1 = max(0, y1)
70
+ x2 = min(img_width, x2)
71
+ y2 = min(img_height, y2)
72
+
73
+ # Extract the face region
74
+ face_region = img.crop((x1, y1, x2, y2))
75
+
76
+ # Apply blur to the face region
77
+ blurred_face = face_region.filter(ImageFilter.GaussianBlur(radius=blur_radius))
78
+
79
+ # Paste the blurred face back into the image
80
+ blurred_img.paste(blurred_face, (x1, y1))
81
+
82
+ return blurred_img
83
+
84
+
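+ # Minimal usage sketch: blur every detected face of at least 100x100 px in a PIL image.
+ # The box values below are illustrative.
+ # canvas = Image.open("assets/canvas.jpg")
+ # blurred = blur_faces_in_image(canvas, {"bboxes": [[248, 172, 428, 498]]}, blur_radius=15)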
85
+ def captioner(prompt: str, num_person = 1) -> List[List[float]]:
86
+ # use random choose for testing
87
+ # within 512
88
+ if num_person == 1:
89
+ bbox_choices = [
90
+ # expanded, centered and quadrant placements
91
+ [96, 96, 288, 288],
92
+ [128, 128, 320, 320],
93
+ [160, 96, 352, 288],
94
+ [96, 160, 288, 352],
95
+ [208, 96, 400, 288],
96
+ [96, 208, 288, 400],
97
+ [192, 160, 368, 336],
98
+ [64, 128, 224, 320],
99
+ [288, 128, 448, 320],
100
+ [128, 256, 320, 448],
101
+ [80, 80, 240, 272],
102
+ [196, 196, 380, 380],
103
+ # originals
104
+ [100, 100, 300, 300],
105
+ [150, 50, 450, 350],
106
+ [200, 100, 500, 400],
107
+ [250, 150, 512, 450],
108
+ ]
109
+ return [bbox_choices[np.random.randint(0, len(bbox_choices))]]
110
+ elif num_person == 2:
111
+ # realistic side-by-side rows (no vertical stacks or diagonals)
112
+ bbox_choices = [
113
+ [[64, 112, 224, 304], [288, 112, 448, 304]],
114
+ [[48, 128, 208, 320], [304, 128, 464, 320]],
115
+ [[32, 144, 192, 336], [320, 144, 480, 336]],
116
+ [[80, 96, 240, 288], [272, 96, 432, 288]],
117
+ [[80, 160, 240, 352], [272, 160, 432, 352]],
118
+ [[64, 128, 240, 336], [272, 144, 432, 320]], # slight stagger, same row
119
+ [[96, 160, 256, 352], [288, 160, 448, 352]],
120
+ [[64, 192, 224, 384], [288, 192, 448, 384]], # lower row
121
+ [[16, 128, 176, 320], [336, 128, 496, 320]], # near edges
122
+ [[48, 120, 232, 328], [280, 120, 464, 328]],
123
+ [[96, 160, 240, 336], [272, 160, 416, 336]], # tighter faces
124
+ [[72, 136, 232, 328], [280, 152, 440, 344]], # small vertical offset
125
+ [[48, 120, 224, 344], [288, 144, 448, 336]], # asymmetric sizes
126
+ [[80, 224, 240, 416], [272, 224, 432, 416]], # bottom row
127
+ [[80, 64, 240, 256], [272, 64, 432, 256]], # top row
128
+ [[96, 176, 256, 368], [288, 176, 448, 368]],
129
+ ]
130
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
131
+
132
+ elif num_person == 3:
133
+ # Non-overlapping 3-person layouts within 512x512
134
+ bbox_choices = [
135
+ [[20, 140, 150, 360], [180, 120, 330, 360], [360, 130, 500, 360]],
136
+ [[30, 100, 160, 300], [190, 90, 320, 290], [350, 110, 480, 310]],
137
+ [[40, 180, 150, 330], [200, 180, 310, 330], [360, 180, 470, 330]],
138
+ [[60, 120, 170, 300], [210, 110, 320, 290], [350, 140, 480, 320]],
139
+ [[50, 80, 170, 250], [200, 130, 320, 300], [350, 80, 480, 250]],
140
+ [[40, 260, 170, 480], [190, 60, 320, 240], [350, 260, 490, 480]],
141
+ [[30, 120, 150, 320], [200, 140, 320, 340], [360, 160, 500, 360]],
142
+ [[80, 140, 200, 300], [220, 80, 350, 260], [370, 160, 500, 320]],
143
+ ]
144
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
145
+ elif num_person == 4:
146
+ # Non-overlapping 4-person layouts within 512x512
147
+ bbox_choices = [
148
+ [[20, 100, 120, 240], [140, 100, 240, 240], [260, 100, 360, 240], [380, 100, 480, 240]],
149
+ [[40, 60, 200, 260], [220, 60, 380, 260], [40, 280, 200, 480], [220, 280, 380, 480]],
150
+ [[180, 30, 330, 170], [30, 220, 150, 380], [200, 220, 320, 380], [360, 220, 490, 380]],
151
+ [[30, 60, 140, 200], [370, 60, 480, 200], [30, 320, 140, 460], [370, 320, 480, 460]],
152
+ [[20, 120, 120, 380], [140, 100, 240, 360], [260, 120, 360, 380], [380, 100, 480, 360]],
153
+ [[30, 80, 150, 240], [180, 120, 300, 280], [330, 80, 450, 240], [200, 300, 320, 460]],
154
+ [[30, 140, 110, 330], [140, 140, 220, 330], [250, 140, 330, 330], [370, 140, 450, 330]],
155
+ [[40, 80, 150, 240], [40, 260, 150, 420], [200, 80, 310, 240], [370, 80, 480, 240]],
156
+ ]
157
+ return bbox_choices[np.random.randint(0, len(bbox_choices))]
158
+
159
+
160
+ class FaceExtractor:
161
+ def __init__(self, model_path="./"):
162
+ self.model = insightface.app.FaceAnalysis(name="antelopev2", root=model_path)
163
+ self.model.prepare(ctx_id=0)
164
+
165
+ def extract(self, image: Image.Image):
166
+ """Extract single face and embedding from an image"""
167
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
168
+ res = self.model.get(image_np)
169
+ if len(res) == 0:
170
+ return None, None
171
+ res = res[0]
172
+ bbox = res["bbox"]
173
+ moref = extract_moref(image, {"bboxes": [bbox]}, 1)
174
+ return moref[0], res["embedding"]
175
+
176
+ def extract_refs(self, image: Image.Image):
177
+ """Extract multiple faces and embeddings from an image"""
178
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
179
+ res = self.model.get(image_np)
180
+ if len(res) == 0:
181
+ return None, None, None, None
182
+ ref_imgs = []
183
+ arcface_embeddings = []
184
+ bboxes = []
185
+ for r in res:
186
+ bbox = r["bbox"]
187
+ bboxes.append(bbox)
188
+ moref = extract_moref(image, {"bboxes": [bbox]}, 1)
189
+ ref_imgs.append(moref[0])
190
+ arcface_embeddings.append(r["embedding"])
191
+
192
+ # Convert bboxes to the correct format
193
+ new_img, new_bboxes = face_preserving_resize(image, bboxes, 512)
194
+ return ref_imgs, arcface_embeddings, new_bboxes, new_img
195
+
196
+
197
+ def resize_bbox(bbox, ori_width, ori_height, new_width, new_height):
198
+ """Resize bounding box coordinates while preserving aspect ratio"""
199
+ x1, y1, x2, y2 = bbox
200
+
201
+ # Calculate scaling factors
202
+ width_scale = new_width / ori_width
203
+ height_scale = new_height / ori_height
204
+
205
+ # Use minimum scaling factor to preserve aspect ratio
206
+ min_scale = min(width_scale, height_scale)
207
+
208
+ # Calculate offsets for centering the scaled box
209
+ width_offset = (new_width - ori_width * min_scale) / 2
210
+ height_offset = (new_height - ori_height * min_scale) / 2
211
+
212
+ # Scale and adjust coordinates
213
+ new_x1 = int(x1 * min_scale + width_offset)
214
+ new_y1 = int(y1 * min_scale + height_offset)
215
+ new_x2 = int(x2 * min_scale + width_offset)
216
+ new_y2 = int(y2 * min_scale + height_offset)
217
+
218
+ return [new_x1, new_y1, new_x2, new_y2]
219
+
220
+
221
+ def draw_bboxes_on_image(image, bboxes):
222
+ """Draw bounding boxes on image for visualization"""
223
+ if bboxes is None:
224
+ return image
225
+
226
+ # Create a copy to draw on
227
+ img_draw = image.copy()
228
+ draw = ImageDraw.Draw(img_draw)
229
+
230
+ # Draw each bbox with a different color
231
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
232
+
233
+ for i, bbox in enumerate(bboxes):
234
+ color = colors[i % len(colors)]
235
+ x1, y1, x2, y2 = [int(coord) for coord in bbox]
236
+ # Draw rectangle
237
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
238
+ # Draw label
239
+ draw.text((x1, y1-15), f"Face {i+1}", fill=color)
240
+
241
+ return img_draw
242
+
243
+
244
+ def create_demo(
245
+ model_type: str = "flux-dev",
246
+ ipa_path: str = "./ckpt/ipa.safetensors",
247
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
248
+ offload: bool = False,
249
+ lora_rank: int = 64,
250
+ additional_lora_ckpt: Optional[str] = None,
251
+ lora_scale: float = 1.0,
252
+ clip_path: str = "openai/clip-vit-large-patch14",
253
+ t5_path: str = "xlabs-ai/xflux_text_encoders",
254
+ flux_path: str = "black-forest-labs/FLUX.1-dev",
255
+ ):
256
+
257
+ face_extractor = FaceExtractor()
258
+ # Initialize pipeline and face extractor
259
+ pipeline = WithAnyonePipeline(
260
+ model_type,
261
+ ipa_path,
262
+ device,
263
+ offload,
264
+ only_lora=True,
265
+ no_lora=True,
266
+ lora_rank=lora_rank,
267
+ additional_lora_ckpt=additional_lora_ckpt,
268
+ lora_weight=lora_scale,
269
+ face_extractor=face_extractor,
270
+ clip_path=clip_path,
271
+ t5_path=t5_path,
272
+ flux_path=flux_path,
273
+ )
274
+
275
+
276
+ def parse_bboxes(bbox_text):
277
+ """Parse bounding box text input"""
278
+ if not bbox_text or bbox_text.strip() == "":
279
+ return None
280
+
281
+ try:
282
+ bboxes = []
283
+ lines = bbox_text.strip().split("\n")
284
+ for line in lines:
285
+ if not line.strip():
286
+ continue
287
+ coords = [float(x) for x in line.strip().split(",")]
288
+ if len(coords) != 4:
289
+ raise ValueError(f"Each bbox must have 4 coordinates (x1,y1,x2,y2), got: {line}")
290
+ bboxes.append(coords)
291
+ return bboxes
292
+ except Exception as e:
293
+ raise gr.Error(f"Invalid bbox format: {e}")
294
+
295
+ def extract_from_base_image(base_img):
296
+ """Extract references and bboxes from the base image"""
297
+ if base_img is None:
298
+ return None, None, None, None, None
299
+
300
+ # Convert from numpy to PIL if needed
301
+ if isinstance(base_img, np.ndarray):
302
+ base_img = Image.fromarray(base_img)
303
+
304
+ ref_imgs, arcface_embeddings, bboxes, new_img = face_extractor.extract_refs(base_img)
305
+
306
+ if ref_imgs is None or len(ref_imgs) == 0:
307
+ raise gr.Error("No faces detected in the base image")
308
+
309
+ # Limit to max 4 faces
310
+ ref_imgs = ref_imgs[:4]
311
+ arcface_embeddings = arcface_embeddings[:4]
312
+ bboxes = bboxes[:4]
313
+
314
+ # Create visualization with bboxes
315
+ viz_image = draw_bboxes_on_image(new_img, bboxes)
316
+
317
+ # Format bboxes as string for display
318
+ bbox_text = "\n".join([f"{bbox[0]:.1f},{bbox[1]:.1f},{bbox[2]:.1f},{bbox[3]:.1f}" for bbox in bboxes])
319
+
320
+ return ref_imgs, arcface_embeddings, bboxes, viz_image, bbox_text
321
+
322
+ def process_and_generate(
323
+ prompt,
324
+ guidance, num_steps, seed,
325
+ ref_img1, ref_img2, ref_img3, ref_img4,
326
+ base_img,
327
+ manual_bboxes_text,
328
+ use_text_prompt,
329
+ siglip_weight=0.0
330
+ ):
331
+ # Validate base_img is provided
332
+ if base_img is None:
333
+ raise gr.Error("Base image is required")
334
+
335
+ # Convert numpy to PIL if needed
336
+ if isinstance(base_img, np.ndarray):
337
+ base_img = Image.fromarray(base_img)
338
+
339
+ # Get dimensions from base_img
340
+ width, height = base_img.size
341
+
342
+
343
+ # Collect and validate reference images
344
+ ref_images = [img for img in [ref_img1, ref_img2, ref_img3, ref_img4] if img is not None]
345
+
346
+ if not ref_images:
347
+ raise gr.Error("At least one reference image is required")
348
+
349
+ # Process reference images to extract face and embeddings
350
+ ref_imgs = []
351
+ arcface_embeddings = []
352
+
353
+ # Extract bboxes from the base image
354
+ extracted_refs, extracted_embeddings, bboxes_, _, _ = extract_from_base_image(base_img)
355
+ bboxes__ = [resize_bbox(bbox, 512, 512, width, height) for bbox in bboxes_]
356
+ if extracted_refs is None:
357
+ raise gr.Error("No faces detected in the base image. Please provide a different base image with clear faces.")
358
+
359
+ # Create blurred canvas by blurring faces in the base image
360
+ blurred_canvas = blur_faces_in_image(base_img, {'bboxes': bboxes__})
361
+
362
+
363
+ bboxes = [bboxes__] # Wrap in list for batch input format
364
+
365
+ # Process each reference image
366
+ for img in ref_images:
367
+ if isinstance(img, np.ndarray):
368
+ img = Image.fromarray(img)
369
+
370
+ ref_img, embedding = face_extractor.extract(img)
371
+ if ref_img is None or embedding is None:
372
+ raise gr.Error("Failed to extract face from one of the reference images")
373
+
374
+ ref_imgs.append(ref_img)
375
+ arcface_embeddings.append(embedding)
376
+
377
+ if len(bboxes[0]) != len(ref_imgs):
378
+ raise gr.Error(f"Number of bboxes ({len(bboxes[0])}) must match number of reference images ({len(ref_imgs)})")
379
+
380
+ # Convert arcface embeddings to tensor
381
+ arcface_embeddings = [torch.tensor(embedding) for embedding in arcface_embeddings]
382
+ arcface_embeddings = torch.stack(arcface_embeddings).to(device)
383
+
384
+ # Generate image
385
+ final_prompt = prompt if use_text_prompt else ""
386
+
387
+
388
+ if seed < 0:
389
+ seed = np.random.randint(0, 1000000)
390
+
391
+ image_gen = pipeline(
392
+ prompt=final_prompt,
393
+ width=width,
394
+ height=height,
395
+ guidance=guidance,
396
+ num_steps=num_steps,
397
+ seed=seed if seed > 0 else None,
398
+ ref_imgs=ref_imgs,
399
+ img_cond=blurred_canvas, # Pass the blurred canvas image
400
+ arcface_embeddings=arcface_embeddings,
401
+ bboxes=bboxes,
402
+ max_num_ids=len(ref_imgs),
403
+ siglip_weight=0,
404
+ id_weight=1, # only arcface supported now
405
+ arc_only=True,
406
+ )
407
+
408
+ # Save temp file for download
409
+ temp_path = "temp_generated.png"
410
+ image_gen.save(temp_path)
411
+
412
+ # draw bboxes on the generated image for debug
413
+ debug_face = draw_bboxes_on_image(image_gen, bboxes[0])
414
+
415
+ return image_gen, debug_face, temp_path
416
+
417
+ def update_bbox_display(base_img):
418
+ if base_img is None:
419
+ return None, None
420
+
421
+ try:
422
+ _, _, _, viz_image, bbox_text = extract_from_base_image(base_img)
423
+ return viz_image, bbox_text
424
+ except Exception as e:
425
+ return None, None
426
+
427
+ # Create Gradio interface
428
+ with gr.Blocks() as demo:
429
+ gr.Markdown("# WithAnyone Kontext Demo")
430
+
431
+ with gr.Row():
432
+
433
+ with gr.Column():
434
+ # Input controls
435
+ generate_btn = gr.Button("Generate", variant="primary")
436
+ siglip_weight = 0.0
437
+ with gr.Row():
438
+ prompt = gr.Textbox(label="Prompt", value="a person in a beautiful garden. High resolution, extremely detailed")
439
+ use_text_prompt = gr.Checkbox(label="Use text prompt", value=True)
440
+
441
+ with gr.Accordion("Advanced Options", open=False):
442
+ with gr.Row():
443
+ num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
444
+ guidance = gr.Slider(1.0, 10.0, 4.0, step=0.1, label="Guidance")
445
+ seed = gr.Number(-1, label="Seed (-1 for random)")
446
+
447
+ with gr.Row():
448
+ with gr.Column():
449
+ # Reference image inputs
450
+ gr.Markdown("### Face References (1-4 required)")
451
+ ref_img1 = gr.Image(label="Reference 1", type="pil")
452
+ ref_img2 = gr.Image(label="Reference 2", type="pil", visible=True)
453
+ ref_img3 = gr.Image(label="Reference 3", type="pil", visible=True)
454
+ ref_img4 = gr.Image(label="Reference 4", type="pil", visible=True)
455
+
456
+ with gr.Column():
457
+ # Base image input - combines the previous canvas and multi-person image
458
+ gr.Markdown("### Base Image (Required)")
459
+ base_img = gr.Image(label="Base Image - faces will be detected and replaced", type="pil")
460
+
461
+ bbox_preview = gr.Image(label="Detected Faces", type="pil")
462
+
463
+ gr.Markdown("### Manual Bounding Box Override (Optional)")
464
+ manual_bbox_input = gr.Textbox(
465
+ label="Manual Bounding Boxes (one per line, format: x1,y1,x2,y2)",
466
+ lines=4,
467
+ placeholder="100,100,200,200\n300,100,400,200"
468
+ )
469
+
470
+
471
+ with gr.Column():
472
+ # Output display
473
+ output_image = gr.Image(label="Generated Image")
474
+ debug_face = gr.Image(label="Debug: Faces are expected to be generated in these boxes")
475
+ download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
476
+
477
+ # Examples section
478
+ with gr.Row():
479
+
480
+ gr.Markdown("""
481
+ # Example Configurations
482
+
483
+ ### Tips for Better Results
484
+ - Base image is required - faces in this image will be detected, blurred, and then replaced
485
+ - Provide clear reference images with visible faces
486
+ - Use detailed prompts describing the desired output
487
+ - Adjust the resemblance slider based on your needs - more to the right for closer facial resemblance
488
+ """)
489
+ with gr.Row():
490
+ examples = gr.Examples(
491
+ examples=[
492
+ [
493
+ "", # prompt
494
+ 4.0, 25, 42, # guidance, num_steps, seed
495
+ "assets/ref3.jpg", "assets/ref1.jpg", None, None, # ref images
496
+ "assets/canvas.jpg", # base image
497
+ False, # use_text_prompt
498
+ ]
499
+ ],
500
+ inputs=[
501
+ prompt, guidance, num_steps, seed,
502
+ ref_img1, ref_img2, ref_img3, ref_img4,
503
+ base_img, use_text_prompt
504
+ ],
505
+ label="Click to load example configurations"
506
+ )
507
+ # Set up event handlers
508
+ base_img.change(
509
+ fn=update_bbox_display,
510
+ inputs=[base_img],
511
+ outputs=[bbox_preview, manual_bbox_input]
512
+ )
513
+
514
+ generate_btn.click(
515
+ fn=process_and_generate,
516
+ inputs=[
517
+ prompt, guidance, num_steps, seed,
518
+ ref_img1, ref_img2, ref_img3, ref_img4,
519
+ base_img, manual_bbox_input, use_text_prompt,
520
+ ],
521
+ outputs=[output_image, debug_face, download_btn]
522
+ )
523
+
524
+ return demo
525
+
526
+
527
+ if __name__ == "__main__":
528
+ from transformers import HfArgumentParser
529
+
530
+ @dataclasses.dataclass
531
+ class AppArgs:
532
+ model_type: Literal["flux-dev", "flux-kontext", "flux-schnell"] = "flux-kontext"
533
+ device: Literal["cuda", "cpu"] = (
534
+ "cuda" if torch.cuda.is_available()
535
+ else "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
536
+ else "cpu"
537
+ )
538
+ offload: bool = False
539
+ lora_rank: int = 64
540
+ port: int = 7860
541
+ additional_lora: str = None
542
+ lora_scale: float = 1.0
543
+ ipa_path: str = "./ckpt/ipa.safetensors"
544
+ clip_path: str = "openai/clip-vit-large-patch14"
545
+ t5_path: str = "xlabs-ai/xflux_text_encoders"
546
+ flux_path: str = "black-forest-labs/FLUX.1-dev"
547
+
548
+ parser = HfArgumentParser([AppArgs])
549
+ args = parser.parse_args_into_dataclasses()[0]
550
+
551
+ demo = create_demo(
552
+ args.model_type,
553
+ args.ipa_path,
554
+ args.device,
555
+ args.offload,
556
+ args.lora_rank,
557
+ args.additional_lora,
558
+ args.lora_scale,
559
+ args.clip_path,
560
+ args.t5_path,
561
+ args.flux_path,
562
+ )
563
+ demo.launch(server_port=args.port)
infer_withanyone.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright (c) 2025 Fudan University. All rights reserved.
2
+
3
+
4
+
5
+ import os
6
+ import dataclasses
7
+ from typing import Literal
8
+
9
+ from accelerate import Accelerator
10
+ from transformers import HfArgumentParser
11
+ from PIL import Image
12
+ import json
13
+ import itertools
14
+
15
+ from withanyone.flux.pipeline import WithAnyonePipeline
16
+
17
+ from util import extract_moref, general_face_preserving_resize, horizontal_concat, extract_object, FaceExtractor
18
+
19
+ import numpy as np
20
+
21
+ import random
22
+ import torch
23
+
24
+
25
+
26
+ from transformers import AutoModelForImageSegmentation
27
+ from torch.cuda.amp import autocast
28
+
29
+
30
+ BACK_UP_BBOXES_DOUBLE = [
31
+
32
+ [[100,100,200,200], [300,100,400,200]], # 2 faces
33
+ [[150,100,250,200], [300,100,400,200]]
34
+ ]
35
+
36
+ BACK_UP_BBOXES = [ # for single face
37
+ [[150,100,250,200]],
38
+ [[100,100,200,200]],
39
+ [[200,100,300,200]],
40
+ [[250,100,350,200]],
41
+ [[300,100,400,200]],
42
+ ]
43
+
44
+
45
+
46
+
47
+
48
+
49
+ @dataclasses.dataclass
50
+ class InferenceArgs:
51
+ prompt: str | None = None
52
+ image_paths: list[str] | None = None
53
+ eval_json_path: str | None = None
54
+ offload: bool = False
55
+ num_images_per_prompt: int = 1
56
+ model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
57
+ width: int = 512
58
+ height: int = 512
59
+ ref_size: int = -1
60
+ num_steps: int = 25
61
+ guidance: float = 4
62
+ seed: int = 1234
63
+ save_path: str = "output/inference"
64
+ only_lora: bool = True
65
+ concat_refs: bool = False
66
+ lora_rank: int = 64
67
+ data_resolution: int = 512
68
+ save_iter: str = "500"
69
+ use_rec: bool = False
70
+ drop_text: bool = False
71
+ use_matting: bool = False
72
+ id_weight: float = 1.0
73
+ siglip_weight: float = 1.0
74
+ bbox_from_json: bool = True
75
+ data_root: str = "./"
76
+ # for lora
77
+ additional_lora: str | None = None
78
+ trigger: str = ""
79
+ lora_weight: float = 1.0
80
+
81
+ # path to the ipa model
82
+ ipa_path: str = "./ckpt/ipa.safetensors"
83
+ clip_path: str = "openai/clip-vit-large-patch14"
84
+ t5_path: str = "xlabs-ai/xflux_text_encoders"
85
+ flux_path: str = "black-forest-labs/FLUX.1-dev"
86
+ siglip_path: str = "google/siglip-base-patch16-256-i18n"
87
+
88
+
89
+
90
+ def main(args: InferenceArgs):
91
+ accelerator = Accelerator()
92
+
93
+ face_extractor = FaceExtractor()
94
+
95
+ pipeline = WithAnyonePipeline(
96
+ args.model_type,
97
+ args.ipa_path,
98
+ accelerator.device,
99
+ args.offload,
100
+ only_lora=args.only_lora,
101
+ face_extractor=face_extractor,
102
+ additional_lora_ckpt=args.additional_lora,
103
+ lora_weight=args.lora_weight,
104
+ clip_path=args.clip_path,
105
+ t5_path=args.t5_path,
106
+ flux_path=args.flux_path,
107
+ siglip_path=args.siglip_path,
108
+ )
109
+
110
+
111
+
112
+ if args.use_matting:
113
+ birefnet = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True).to('cuda', dtype=torch.bfloat16)
114
+
115
+
116
+ assert args.prompt is not None or args.eval_json_path is not None, \
117
+ "Please provide either prompt or eval_json_path"
118
+
119
+ # if args.eval_json_path is not None:
120
+ assert args.eval_json_path is not None, "Please provide eval_json_path. This script only supports batch inference."
121
+ with open(args.eval_json_path, "rt") as f:
122
+ data_dicts = json.load(f)
123
+ data_root = args.data_root
124
+
125
+
126
+
127
+ metadata = {}
128
+ for (i, data_dict), j in itertools.product(enumerate(data_dicts), range(args.num_images_per_prompt)):
129
+
130
+
131
+ if (i * args.num_images_per_prompt + j) % accelerator.num_processes != accelerator.process_index:
132
+ continue
133
+ # check if exist, if this image is already generated, skip it
134
+
135
+
136
+
137
+ # if any of the images are None, skip this image
138
+ if not os.path.exists(os.path.join(data_root, data_dict["image_paths"][0])):
139
+ print(f"Image {i} does not exist, skipping...")
140
+ print("path:", os.path.join(data_root, data_dict["image_paths"][0]))
141
+ continue
142
+
143
+
144
+ # extract bbox
145
+
146
+ ori_img_path = data_dict.get("ori_img_path", None)
147
+ # ori_img = Image.open(os.path.join(data_root, data_dict["ori_img_path"]))
148
+
149
+ # basename = data_dict["ori_img_path"].split(".")[0].replace("/", "_")
150
+ if ori_img_path is None:
151
+ basename = f"{i}_{j}"
152
+ else:
153
+ basename = data_dict["ori_img_path"].split(".")[0].replace("/", "_")
154
+ ori_img = Image.open(os.path.join(data_root, ori_img_path))
155
+ bboxes = None
156
+ print("Processing image", i, basename)
157
+ if not args.bbox_from_json:
158
+ if ori_img_path is None:
159
+ print(f"Image {i} has no ori_img_path, cannot extract bbox, skipping...")
160
+ continue
161
+ ori_img = Image.open(os.path.join(data_root, ori_img_path))
162
+ bboxes = face_extractor.locate_bboxes(ori_img)
163
+ # cut bbox length to num of imgae_paths
164
+ if bboxes is not None and len(bboxes) > len(data_dict["image_paths"]):
165
+ bboxes = bboxes[:len(data_dict["image_paths"])]
166
+ elif bboxes is not None and len(bboxes) < len(data_dict["image_paths"]):
167
+ print(f"Image {i} has less faces than image_paths, continuing...")
168
+ continue
169
+ else:
170
+ json_file_path = os.path.join(data_root, basename + ".json")
171
+ if os.path.exists(json_file_path):
172
+ with open(json_file_path, "r") as f:
173
+ json_data = json.load(f)
174
+ old_bboxes = json_data.get("bboxes", None)
175
+
176
+ if old_bboxes is None:
177
+ print(f"Image {i} has no bboxes in json file, using backup bboxes...")
178
+ # v202 -> 2 faces v200_single -> 1 face
179
+ if "v202" in args.eval_json_path:
180
+ old_bboxes = random.choice(BACK_UP_BBOXES_DOUBLE)
181
+ elif "v200_single" in args.eval_json_path:
182
+ old_bboxes = random.choice(BACK_UP_BBOXES)
183
+
184
+
185
+ def recalculate_bbox(bbox, crop):
186
+ """
187
+ The image is cropped, so we need to recalculate the bbox.
188
+ bbox: [x1, y1, x2, y2]
189
+ crop: [x1c, y1c, x2c, y2c]
190
+ we just need to subtract x1c from the x coordinates and y1c from the y coordinates.
191
+ """
192
+ x1, y1, x2, y2 = bbox
193
+ x1c, y1c, x2c, y2c = crop
194
+ return [x1-x1c, y1-y1c, x2-x1c, y2-y1c]
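+ # e.g. with an illustrative crop: recalculate_bbox([100, 100, 200, 200], [50, 20, 562, 532]) -> [50, 80, 150, 180]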
195
+ crop = json_data.get("crop", None)
196
+ rec_bboxes = [
197
+ recalculate_bbox(bbox, crop) if crop is not None else bbox for bbox in old_bboxes]
198
+ # face_preserving_resize(image, bboxes, 512)
199
+ if ori_img_path is not None:
200
+ _, bboxes = general_face_preserving_resize(ori_img, rec_bboxes, 512)
201
+ # else we consider the provided bbox is already in target size
202
+ else:
203
+ bboxes = rec_bboxes
204
+
205
+ if bboxes is None:
206
+
207
+ print(f"Image {i} has no face, bboxes are None, using backup bboxes..., basename: {basename}")
208
+
209
+ bboxes = random.choice(BACK_UP_BBOXES_DOUBLE)
210
+ print(f"Use backup bboxes: {bboxes}")
211
+
212
+
213
+ ref_imgs = []
214
+ arcface_embeddings = []
215
+ if not args.use_rec:
216
+ break_flag = False
217
+ for img_path in data_dict["image_paths"]:
218
+ img = Image.open(os.path.join(data_root, img_path))
219
+
220
+
221
+ ref_img, arcface_embedding = face_extractor.extract(img)
222
+
223
+ if ref_img is not None and arcface_embedding is not None:
224
+ if args.use_matting:
225
+ ref_img, _ = extract_object(birefnet, ref_img)
226
+ ref_imgs.append(ref_img)
227
+ arcface_embeddings.append(arcface_embedding)
228
+ else:
229
+ print(f"Image {i} has no face, skipping...")
230
+ break_flag = True
231
+ break
232
+ if break_flag:
233
+ continue
234
+ else:
235
+ ref_imgs, arcface_embeddings = face_extractor.extract_refs(ori_img)
236
+
237
+ if ref_imgs is None or arcface_embeddings is None:
238
+ print(f"Image {i} has no face, skipping...")
239
+ continue
240
+
241
+ if args.use_matting:
242
+ ref_imgs = [extract_object(birefnet, ref_img)[0] for ref_img in ref_imgs]
243
+
244
+
245
+ # arcface to tensor
246
+ arcface_embeddings = [torch.tensor(arcface_embedding) for arcface_embedding in arcface_embeddings]
247
+ arcface_embeddings = torch.stack(arcface_embeddings).to(accelerator.device)
248
+
249
+
250
+ # check, if any of the images are None, if so, skip this image
251
+ if any(ref_img is None for ref_img in ref_imgs):
252
+ print(f"Image {i}: failed to extract face, skipping...")
253
+ continue
254
+
255
+
256
+ if args.ref_size==-1:
257
+ args.ref_size = 512 if len(ref_imgs)==1 else 320
258
+
259
+
260
+ if args.trigger != "" and args.trigger is not None:
261
+ data_dict["prompt"] = args.trigger + " " + data_dict["prompt"]
262
+
263
+
264
+ image_gen = pipeline(
265
+ prompt=data_dict["prompt"] if not args.drop_text else "",
266
+ width=args.width,
267
+ height=args.height,
268
+ guidance=args.guidance,
269
+ num_steps=args.num_steps,
270
+ seed=args.seed,
271
+ ref_imgs=ref_imgs,
272
+ arcface_embeddings=arcface_embeddings,
273
+ bboxes=[bboxes],
274
+ id_weight=args.id_weight,
275
+ siglip_weight=args.siglip_weight,
276
+
277
+ )
278
+
279
+
280
+ if args.concat_refs:
281
+ image_gen = horizontal_concat([image_gen, *ref_imgs])
282
+
283
+ os.makedirs(args.save_path, exist_ok=True)
284
+
285
+
286
+ save_path = os.path.join(args.save_path, basename)
287
+ os.makedirs(os.path.join(args.save_path, basename), exist_ok=True)
288
+
289
+ # save refs, image_gen and original image
290
+ for k, ref_img in enumerate(ref_imgs):
291
+ ref_img.save(os.path.join(save_path, f"ref_{k}.jpg"))
292
+ image_gen.save(os.path.join(save_path, f"out.jpg"))
293
+ # original image
294
+ ori_img = Image.open(os.path.join(data_root, data_dict["ori_img_path"])) if "ori_img_path" in data_dict else None
295
+ if ori_img is not None:
296
+ ori_img.save(os.path.join(save_path, f"ori.jpg"))
297
+ # save config
298
+ args_dict = vars(args)
299
+ args_dict['prompt'] = data_dict["prompt"]
300
+ args_dict["name"] = data_dict["name"] if "name" in data_dict else None
301
+ json.dump(args_dict, open(os.path.join(save_path, f"meta.json"), 'w'), indent=4, ensure_ascii=False)
302
+
303
+
304
+ if __name__ == "__main__":
305
+ parser = HfArgumentParser([InferenceArgs])
306
+ args = parser.parse_args_into_dataclasses()[0]
307
+ main(args)
308
+
309
+
nohup.out ADDED
@@ -0,0 +1,2 @@
1
+ Only 2 GPUs available, exiting.
2
+ Only 2 GPUs available, exiting.
requirements.txt ADDED
@@ -0,0 +1,24 @@
1
+ accelerate==1.6.0
2
+ einops
3
+ gradio
4
+ huggingface_hub
5
+ insightface
6
+ matplotlib
7
+ numpy
8
+ opencv-python
9
+ opencv-python-headless
10
+ optimum
11
+ optimum_quanto
12
+ Pillow
13
+ PyYAML
15
+ safetensors
16
+ seaborn
17
+ scikit-image
18
+ torch==2.5.1
19
+ torchvision==0.20.1
20
+ tqdm
21
+ transformers==4.45.2
22
+ onnxruntime
23
+ onnxruntime-gpu
24
+ sentencepiece
util.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright (c) 2025 Fudan University. All rights reserved.
2
+
3
+
4
+ from io import BytesIO
5
+ import random
6
+ from PIL import Image
+ from PIL.JpegImagePlugin import JpegImageFile
7
+ import numpy as np
8
+ import cv2
9
+ import insightface
10
+ import torch
11
+ from torchvision import transforms
12
+ from torch.cuda.amp import autocast
13
+
14
+ def face_preserving_resize(img, face_bboxes, target_size=512):
15
+ """
16
+ Resize image while ensuring all faces are preserved in the output.
17
+
18
+ Args:
19
+ img: PIL Image
20
+ face_bboxes: List of [x1, y1, x2, y2] face coordinates (expects at least two faces; the first two anchor the crop)
21
+ target_size: Maximum dimension for resizing
22
+
23
+ Returns:
24
+ Tuple of (resized image, new_bboxes) or (None, None) if faces can't fit
25
+ """
26
+
27
+ x1_1, y1_1, x2_1, y2_1 = map(int, face_bboxes[0])
28
+ x1_2, y1_2, x2_2, y2_2 = map(int, face_bboxes[1])
29
+ min_x1 = min(x1_1, x1_2)
30
+ min_y1 = min(y1_1, y1_2)
31
+ max_x2 = max(x2_1, x2_2)
32
+ max_y2 = max(y2_1, y2_2)
33
+ # print("min_x1:", min_x1, "min_y1:", min_y1, "max_x2:", max_x2, "max_y2:", max_y2)
34
+ # if any coordinate is negative we cannot resize reliably (this occasionally happens), so bail out
35
+ if min_x1 < 0 or min_y1 < 0 or max_x2 < 0 or max_y2 < 0:
36
+ return None, None
37
+
38
+ # if face width is longer than the image height, or the face height is longer than the image width, we cannot resize
39
+ face_width = max_x2 - min_x1
40
+ face_height = max_y2 - min_y1
41
+ if face_width > img.height or face_height > img.width:
42
+ return None, None
43
+
44
+ # Create a copy of face_bboxes for transformation
45
+ new_bboxes = []
46
+ for bbox in face_bboxes:
47
+ new_bboxes.append(list(map(int, bbox)))
48
+
49
+ # Choose cropping strategy based on image aspect ratio
50
+ if img.width > img.height:
51
+ # We need to crop width to make a square
52
+ square_size = img.height
53
+
54
+ # Calculate valid horizontal crop range that preserves all faces
55
+ left_max = min_x1 # Leftmost position that includes leftmost face
56
+ right_min = max_x2 - square_size # Rightmost position that includes rightmost face
57
+
58
+ if right_min <= left_max:
59
+ # We can find a valid crop window
60
+ start = random.randint(int(right_min), int(left_max)) if right_min < left_max else int(right_min)
61
+ start = max(0, min(start, img.width - square_size)) # Ensure within image bounds
62
+ else:
63
+ # Faces are too far apart for square crop - use center of faces
64
+ face_center = (min_x1 + max_x2) // 2
65
+ start = max(0, min(face_center - (square_size // 2), img.width - square_size))
66
+
67
+ cropped_img = img.crop((start, 0, start + square_size, square_size))
68
+
69
+ # Adjust bounding box coordinates based on crop
70
+ for bbox in new_bboxes:
71
+ bbox[0] -= start # x1 adjustment
72
+ bbox[2] -= start # x2 adjustment
73
+ # y coordinates remain unchanged
74
+ else:
75
+ # We need to crop height to make a square
76
+ square_size = img.width
77
+
78
+ # Calculate valid vertical crop range that preserves all faces
79
+ top_max = min_y1 # Topmost position that includes topmost face
80
+ bottom_min = max_y2 - square_size # Bottommost position that includes bottommost face
81
+
82
+ if bottom_min <= top_max:
83
+ # We can find a valid crop window
84
+ start = random.randint(int(bottom_min), int(top_max)) if bottom_min < top_max else int(bottom_min)
85
+ start = max(0, min(start, img.height - square_size)) # Ensure within image bounds
86
+ else:
87
+ # Faces are too far apart for square crop - use center of faces
88
+ face_center = (min_y1 + max_y2) // 2
89
+ start = max(0, min(face_center - (square_size // 2), img.height - square_size))
90
+
91
+ cropped_img = img.crop((0, start, square_size, start + square_size))
92
+
93
+ # Adjust bounding box coordinates based on crop
94
+ for bbox in new_bboxes:
95
+ bbox[1] -= start # y1 adjustment
96
+ bbox[3] -= start # y2 adjustment
97
+ # x coordinates remain unchanged
98
+
99
+ # Calculate scale factor for resizing from square_size to target_size
100
+ scale_factor = target_size / square_size
101
+
102
+ # Adjust bounding boxes for the resize operation
103
+ for bbox in new_bboxes:
104
+ bbox[0] = int(bbox[0] * scale_factor)
105
+ bbox[1] = int(bbox[1] * scale_factor)
106
+ bbox[2] = int(bbox[2] * scale_factor)
107
+ bbox[3] = int(bbox[3] * scale_factor)
108
+
109
+ # Final resize to target size
110
+ resized_img = cropped_img.resize((target_size, target_size), Image.Resampling.LANCZOS)
111
+
112
+ # Make sure all coordinates are within bounds (0 to target_size)
113
+ # for bbox in new_bboxes:
114
+ # bbox[0] = max(0, min(bbox[0], target_size - 1))
115
+ # bbox[1] = max(0, min(bbox[1], target_size - 1))
116
+ # bbox[2] = max(1, min(bbox[2], target_size))
117
+ # bbox[3] = max(1, min(bbox[3], target_size))
118
+
119
+ return resized_img, new_bboxes
120
+
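+ # Usage sketch (illustrative box values): given a 1280x720 image with two detected faces,
+ # this crops a 720x720 square that keeps both boxes when possible, then scales it to 512x512:
+ # resized, boxes_512 = face_preserving_resize(img, [[300, 120, 420, 300], [700, 100, 820, 280]])
+ # boxes_512 holds the same two boxes re-expressed in the 512x512 output.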
121
+ def extract_moref(img, json_data, face_size_restriction=100):
122
+ """
123
+ Extract faces from an image based on bounding boxes in JSON data.
124
+ Makes each face square and resizes to 512x512.
125
+
126
+ Args:
127
+ img: PIL Image or image data
128
+ json_data: JSON object with 'bboxes' and 'crop' information
129
+
130
+ Returns:
131
+ List of PIL Images, each 512x512, containing extracted faces
132
+ """
133
+ # Ensure img is a PIL Image
134
+ try:
135
+ if not isinstance(img, Image.Image) and not isinstance(img, torch.Tensor) and not isinstance(img, JpegImageFile):
136
+ img = Image.open(BytesIO(img))
137
+
138
+ bboxes = json_data['bboxes']
139
+ # crop = json_data['crop']
140
+ # print("len of bboxes:", len(bboxes))
141
+ # Recalculate bounding boxes based on crop info
142
+ # new_bboxes = [recalculate_bbox(bbox, crop) for bbox in bboxes]
143
+ new_bboxes = bboxes
144
+ # any of the face is less than 100 * 100, we ignore this image
145
+ for bbox in new_bboxes:
146
+ x1, y1, x2, y2 = bbox
147
+ if x2 - x1 < face_size_restriction or y2 - y1 < face_size_restriction:
148
+ return []
149
+ # print("len of new_bboxes:", len(new_bboxes))
150
+ faces = []
151
+ for bbox in new_bboxes:
152
+ # print("processing bbox")
153
+ # Convert coordinates to integers
154
+ x1, y1, x2, y2 = map(int, bbox)
155
+
156
+ # Calculate width and height
157
+ width = x2 - x1
158
+ height = y2 - y1
159
+
160
+ # Make the bounding box square by expanding the shorter dimension
161
+ if width > height:
162
+ # Height is shorter, expand it
163
+ diff = width - height
164
+ y1 -= diff // 2
165
+ y2 += diff - (diff // 2) # Handle odd differences
166
+ elif height > width:
167
+ # Width is shorter, expand it
168
+ diff = height - width
169
+ x1 -= diff // 2
170
+ x2 += diff - (diff // 2) # Handle odd differences
171
+
172
+ # Ensure coordinates are within image boundaries
173
+ img_width, img_height = img.size
174
+ x1 = max(0, x1)
175
+ y1 = max(0, y1)
176
+ x2 = min(img_width, x2)
177
+ y2 = min(img_height, y2)
178
+
179
+ # Extract face region
180
+ face_region = img.crop((x1, y1, x2, y2))
181
+
182
+ # Resize to 512x512
183
+ face_region = face_region.resize((512, 512), Image.Resampling.LANCZOS)
184
+
185
+ faces.append(face_region)
186
+ # print("len of faces:", len(faces))
187
+ return faces
188
+ except Exception as e:
189
+ print(f"Error processing image: {e}")
190
+ return []
191
+
192
+ def general_face_preserving_resize(img, face_bboxes, target_size=512):
193
+ """
194
+ Resize image while ensuring all faces are preserved in the output.
195
+ Handles any number of faces (1-5).
196
+
197
+ Args:
198
+ img: PIL Image
199
+ face_bboxes: List of [x1, y1, x2, y2] face coordinates
200
+ target_size: Maximum dimension for resizing
201
+
202
+ Returns:
203
+ Tuple of (resized image, new_bboxes) or (None, None) if faces can't fit
204
+ """
205
+ # Find bounding region containing all faces
206
+ if not face_bboxes:
207
+ print("Warning: No face bounding boxes provided.")
208
+ return None, None
209
+
210
+ min_x1 = min(bbox[0] for bbox in face_bboxes)
211
+ min_y1 = min(bbox[1] for bbox in face_bboxes)
212
+ max_x2 = max(bbox[2] for bbox in face_bboxes)
213
+ max_y2 = max(bbox[3] for bbox in face_bboxes)
214
+
215
+ # Check for negative coordinates
216
+ if min_x1 < 0 or min_y1 < 0 or max_x2 < 0 or max_y2 < 0:
217
+ # print("Warning: Negative coordinates found in face bounding boxes.")
218
+ # return None, None
219
+ min_x1 = max(min_x1, 0)
220
+ min_y1 = max(min_y1, 0)
221
+
222
+ # Check if faces fit within image
223
+ face_width = max_x2 - min_x1
224
+ face_height = max_y2 - min_y1
225
+ if face_width > img.height or face_height > img.width:
226
+ # print("Warning: Faces are too large for the image dimensions.")
227
+ # return None, None
228
+ # Instead of returning None, we will crop the image to fit the faces
229
+ max_x2 = min(max_x2, img.width)
230
+ max_y2 = min(max_y2, img.height)
231
+ min_x1 = max(min_x1, 0)
232
+ min_y1 = max(min_y1, 0)
233
+ # Create a copy of face_bboxes for transformation
234
+ new_bboxes = []
235
+ for bbox in face_bboxes:
236
+ new_bboxes.append(list(map(int, bbox)))
237
+
238
+ # Choose cropping strategy based on image aspect ratio
239
+ if img.width > img.height:
240
+ # Crop width to make a square
241
+ square_size = img.height
242
+
243
+ # Calculate valid horizontal crop range
244
+ left_max = min_x1
245
+ right_min = max_x2 - square_size
246
+
247
+ if right_min <= left_max:
248
+ # We can find a valid crop window
249
+ start = random.randint(int(right_min), int(left_max)) if right_min < left_max else int(right_min)
250
+ start = max(0, min(start, img.width - square_size))
251
+ else:
252
+ # Faces are too far apart - use center of faces
253
+ face_center = (min_x1 + max_x2) // 2
254
+ start = max(0, min(face_center - (square_size // 2), img.width - square_size))
255
+
256
+ cropped_img = img.crop((start, 0, start + square_size, square_size))
257
+
258
+ # Adjust bounding box coordinates
259
+ for bbox in new_bboxes:
260
+ bbox[0] -= start
261
+ bbox[2] -= start
262
+ else:
263
+ # Crop height to make a square
264
+ square_size = img.width
265
+
266
+ # Calculate valid vertical crop range
267
+ top_max = min_y1
268
+ bottom_min = max_y2 - square_size
269
+
270
+ if bottom_min <= top_max:
271
+ start = random.randint(int(bottom_min), int(top_max)) if bottom_min < top_max else int(bottom_min)
272
+ start = max(0, min(start, img.height - square_size))
273
+ else:
274
+ face_center = (min_y1 + max_y2) // 2
275
+ start = max(0, min(face_center - (square_size // 2), img.height - square_size))
276
+
277
+ cropped_img = img.crop((0, start, square_size, start + square_size))
278
+
279
+ # Adjust bounding box coordinates
280
+ for bbox in new_bboxes:
281
+ bbox[1] -= start
282
+ bbox[3] -= start
283
+
284
+ # Calculate scale factor and adjust bounding boxes
285
+ scale_factor = target_size / square_size
286
+
287
+ for bbox in new_bboxes:
288
+ bbox[0] = int(bbox[0] * scale_factor)
289
+ bbox[1] = int(bbox[1] * scale_factor)
290
+ bbox[2] = int(bbox[2] * scale_factor)
291
+ bbox[3] = int(bbox[3] * scale_factor)
292
+
293
+ # Final resize to target size
294
+ resized_img = cropped_img.resize((target_size, target_size), Image.Resampling.LANCZOS)
295
+
296
+ # Make sure all coordinates are within bounds
297
+ for bbox in new_bboxes:
298
+ bbox[0] = max(0, min(bbox[0], target_size - 1))
299
+ bbox[1] = max(0, min(bbox[1], target_size - 1))
300
+ bbox[2] = max(1, min(bbox[2], target_size))
301
+ bbox[3] = max(1, min(bbox[3], target_size))
302
+
303
+ return resized_img, new_bboxes
304
+
305
+
306
+
307
+ def horizontal_concat(images):
308
+ widths, heights = zip(*(img.size for img in images))
309
+
310
+ total_width = sum(widths)
311
+ max_height = max(heights)
312
+
313
+ new_im = Image.new('RGB', (total_width, max_height))
314
+
315
+ x_offset = 0
316
+ for img in images:
317
+ new_im.paste(img, (x_offset, 0))
318
+ x_offset += img.size[0]
319
+
320
+ return new_im
321
+
322
+ def extract_object(birefnet, image):
323
+
324
+
325
+ if image.mode != 'RGB':
326
+ image = image.convert('RGB')
327
+ input_images = transforms.ToTensor()(image).unsqueeze(0).to('cuda', dtype=torch.bfloat16)
328
+
329
+ # Prediction
330
+ with torch.no_grad(), autocast(dtype=torch.bfloat16):
331
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
332
+ pred = preds[0].squeeze().float()
333
+ pred_pil = transforms.ToPILImage()(pred)
334
+ mask = pred_pil.resize(image.size)
335
+
336
+ # Create a binary mask (0 or 255)
337
+ binary_mask = mask.convert("L")
338
+
339
+ # Create a new image with black background
340
+ result = Image.new("RGB", image.size, (0, 0, 0))
341
+
342
+ # Paste the original image onto the black background using the mask
343
+ result.paste(image, (0, 0), binary_mask)
344
+
345
+ return result, mask
346
+
347
+ class FaceExtractor:
348
+ def __init__(self):
349
+ self.model = insightface.app.FaceAnalysis(name = "antelopev2", root="./")
350
+ self.model.prepare(ctx_id=0, det_thresh=0.4)
351
+
352
+ def extract(self, image: Image.Image):
353
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
354
+ res = self.model.get(image_np)
355
+ if len(res) == 0:
356
+ return None, None
357
+ res = res[0]
358
+ # print(res.keys())
359
+ bbox = res["bbox"]
360
+ # print("len(bbox)", len(bbox))
361
+
362
+ moref = extract_moref(image, {"bboxes": [bbox]}, 1)
363
+ # print("len(moref)", len(moref))
364
+ return moref[0], res["embedding"]
365
+
366
+ def locate_bboxes(self, image: Image.Image):
367
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
368
+ res = self.model.get(image_np)
369
+ if len(res) == 0:
370
+ return None
371
+ bboxes = []
372
+ for r in res:
373
+ bbox = r["bbox"]
374
+ bboxes.append(bbox)
375
+
376
+ _, new_bboxes_ = general_face_preserving_resize(image, bboxes, 512)
377
+
378
+ # ensure the bbox is square
379
+ new_bboxes = []
380
+ for bbox in new_bboxes_:
381
+ x1, y1, x2, y2 = bbox
382
+ w = x2 - x1
383
+ h = y2 - y1
384
+ if w > h:
385
+ diff = w - h
386
+ y1 = max(0, y1 - diff // 2)
387
+ y2 = min(512, y2 + diff // 2 + diff % 2)
388
+ else:
389
+ diff = h - w
390
+ x1 = max(0, x1 - diff // 2)
391
+ x2 = min(512, x2 + diff // 2 + diff % 2)
392
+ new_bboxes.append([x1, y1, x2, y2])
393
+
394
+ return new_bboxes
395
+ def extract_refs(self, image: Image.Image):
396
+ """
397
+ Extracts reference faces from the image.
398
+ Returns a list of reference images and their arcface embeddings.
399
+ """
400
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
401
+ res = self.model.get(image_np)
402
+ if len(res) == 0:
403
+ return None, None
404
+ ref_imgs = []
405
+ arcface_embeddings = []
406
+ for r in res:
407
+ bbox = r["bbox"]
408
+ moref = extract_moref(image, {"bboxes": [bbox]}, 1)
409
+ ref_imgs.append(moref[0])
410
+ arcface_embeddings.append(r["embedding"])
411
+ return ref_imgs, arcface_embeddings
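A minimal usage sketch for the face helpers above (not part of the uploaded files). It assumes the antelopev2 InsightFace weights are available under the root="./" directory used by FaceExtractor and that "group_photo.jpg" is a local test image; both are illustrative.

from PIL import Image

extractor = FaceExtractor()
image = Image.open("group_photo.jpg").convert("RGB")

# Per-face 512x512 crops plus their ArcFace embeddings
ref_imgs, arcface_embeddings = extractor.extract_refs(image)
if ref_imgs is None:
    print("no face detected")
else:
    print(f"{len(ref_imgs)} reference faces, embedding dim {arcface_embeddings[0].shape[0]}")

# Square face boxes expressed in the 512x512 face-preserving resize of the full image
bboxes = extractor.locate_bboxes(image)
print(bboxes)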
withanyone/flux/__pycache__/math.cpython-310.pyc ADDED
Binary file (2.03 kB)
 
withanyone/flux/__pycache__/model.cpython-310.pyc ADDED
Binary file (14.3 kB)
 
withanyone/flux/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (8.58 kB)
 
withanyone/flux/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (4.12 kB)
 
withanyone/flux/__pycache__/util.cpython-310.pyc ADDED
Binary file (11 kB)
 
withanyone/flux/math.py ADDED
@@ -0,0 +1,49 @@
1
+
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor
6
+
7
+ import torch
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import os
11
+ import seaborn as sns
12
+ from torch import Tensor
13
+ from matplotlib.colors import LinearSegmentedColormap
14
+ from dataclasses import dataclass
15
+ # a return class
16
+ @dataclass
17
+ class AttentionReturnQAndMAP:
18
+ result: Tensor
19
+ attention_map: Tensor
20
+ Q: Tensor
21
+
22
+
23
+
24
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask = None, token_aug_idx = -1, text_length = None, image_length = None, return_map = False) -> Tensor:
25
+ q, k = apply_rope(q, k, pe)
26
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
27
+ x = rearrange(x, "B H L D -> B L (H D)")
28
+
29
+ return x
30
+
31
+
32
+
33
+
34
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
35
+ assert dim % 2 == 0
36
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
37
+ omega = 1.0 / (theta**scale)
38
+ out = torch.einsum("...n,d->...nd", pos, omega)
39
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
40
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
41
+ return out.float()
42
+
43
+
44
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
45
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
46
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
47
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
48
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
49
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
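A quick shape check for the helpers in math.py, using a single positional axis for simplicity; the sizes below (24 heads, head dim 128, 16 tokens, theta 10_000) are assumed values, not taken from the repository.

import torch

B, H, L, D = 1, 24, 16, 128            # batch, heads, tokens, head dim (assumed)
pos = torch.arange(L)[None].float()     # (B, L) token positions
pe = rope(pos, D, theta=10_000)         # (B, L, D/2, 2, 2) rotation blocks
pe = pe.unsqueeze(1)                    # add a head axis, as EmbedND.forward does

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
out = attention(q, k, v, pe=pe)         # RoPE applied to q/k, then SDPA -> (B, L, H*D)
print(pe.shape, out.shape)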
withanyone/flux/model.py ADDED
@@ -0,0 +1,610 @@
1
+
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+ from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding, PerceiverAttentionCA
9
+ from transformers import AutoTokenizer, AutoProcessor, SiglipModel
10
+ import math
11
+ from transformers import AutoModelForImageSegmentation
12
+ from einops import rearrange
13
+
14
+ from torchvision import transforms
15
+ from PIL import Image
16
+ from torch.cuda.amp import autocast
17
+
18
+
19
+
20
+ def create_person_cross_attention_mask_varlen(
21
+ batch_size, img_len, id_len,
22
+ bbox_lists, original_width, original_height,
23
+ max_num_ids=2, # Default to support 2 identities
24
+ vae_scale_factor=8, patch_size=2, num_heads = 24
25
+ ):
26
+ """
27
+ Create boolean attention masks limiting image tokens to interact only with corresponding person ID tokens
28
+
29
+ Parameters:
30
+ - batch_size: Number of samples in batch
31
+ - num_heads: Number of attention heads
32
+ - img_len: Length of image token sequence
33
+ - id_len: Length of EACH identity embedding (not total)
34
+ - bbox_lists: List where bbox_lists[i] contains all bboxes for batch i
35
+ Each batch may have a different number of bboxes/identities
36
+ - max_num_ids: Maximum number of identities to support (for padding)
37
+ - original_width/height: Original image dimensions
38
+ - vae_scale_factor: VAE downsampling factor (default 8)
39
+ - patch_size: Patch size for token creation (default 2)
40
+
41
+ Returns:
42
+ - Boolean attention mask of shape [batch_size, num_heads, img_len, total_id_len]
43
+ """
44
+ # Total length of ID tokens based on maximum number of identities
45
+ total_id_len = max_num_ids * id_len
46
+
47
+ # Initialize mask to block all attention
48
+ mask = torch.zeros((batch_size, num_heads, img_len, total_id_len), dtype=torch.bool)
49
+
50
+ # Calculate VAE dimensions
51
+ latent_width = original_width // vae_scale_factor
52
+ latent_height = original_height // vae_scale_factor
53
+ patches_width = latent_width // patch_size
54
+ patches_height = latent_height // patch_size
55
+
56
+
57
+
58
+ # Convert boundary box to token indices
59
+ def bbox_to_token_indices(bbox):
60
+ x1, y1, x2, y2 = bbox
61
+
62
+ # Convert to patch space coordinates
63
+ if isinstance(x1, torch.Tensor):
64
+ x1_patch = max(0, int(x1.item()) // vae_scale_factor // patch_size)
65
+ y1_patch = max(0, int(y1.item()) // vae_scale_factor // patch_size)
66
+ x2_patch = min(patches_width, math.ceil(int(x2.item()) / vae_scale_factor / patch_size))
67
+ y2_patch = min(patches_height, math.ceil(int(y2.item()) / vae_scale_factor / patch_size))
68
+ elif isinstance(x1, int):
69
+ x1_patch = max(0, x1 // vae_scale_factor // patch_size)
70
+ y1_patch = max(0, y1 // vae_scale_factor // patch_size)
71
+ x2_patch = min(patches_width, math.ceil(x2 / vae_scale_factor / patch_size))
72
+ y2_patch = min(patches_height, math.ceil(y2 / vae_scale_factor / patch_size))
73
+ elif isinstance(x1, float):
74
+ x1_patch = max(0, int(x1) // vae_scale_factor // patch_size)
75
+ y1_patch = max(0, int(y1) // vae_scale_factor // patch_size)
76
+ x2_patch = min(patches_width, math.ceil(x2 / vae_scale_factor / patch_size))
77
+ y2_patch = min(patches_height, math.ceil(y2 / vae_scale_factor / patch_size))
78
+ else:
79
+ raise TypeError(f"Unsupported type: {type(x1)}")
80
+
81
+ # Create list of all token indices in this region
82
+ indices = []
83
+ for y in range(y1_patch, y2_patch):
84
+ for x in range(x1_patch, x2_patch):
85
+ idx = y * patches_width + x
86
+ indices.append(idx)
87
+
88
+ return indices
89
+
90
+ for b in range(batch_size):
91
+ # Get all bboxes for this batch item
92
+ batch_bboxes = bbox_lists[b] if b < len(bbox_lists) else []
93
+
94
+ # Process each bbox in the batch up to max_num_ids
95
+ for identity_idx, bbox in enumerate(batch_bboxes[:max_num_ids]):
96
+ # Get image token indices for this bbox
97
+ image_indices = bbox_to_token_indices(bbox)
98
+
99
+ # Calculate ID token slice for this identity
100
+ id_start = identity_idx * id_len
101
+ id_end = id_start + id_len
102
+ id_slice = slice(id_start, id_end)
103
+
104
+ # Enable attention between this region's image tokens and the identity's tokens
105
+ for h in range(num_heads):
106
+ for idx in image_indices:
107
+ mask[b, h, idx, id_slice] = True
108
+
109
+ return mask
110
+
111
+
112
+
113
+
114
+ # FFN
115
+ def FeedForward(dim, mult=4):
116
+ inner_dim = int(dim * mult)
117
+ return nn.Sequential(
118
+ nn.LayerNorm(dim),
119
+ nn.Linear(dim, inner_dim, bias=False),
120
+ nn.GELU(),
121
+ nn.Linear(inner_dim, dim, bias=False),
122
+ )
123
+
124
+
125
+
126
+ @dataclass
127
+ class FluxParams:
128
+ in_channels: int
129
+ vec_in_dim: int
130
+ context_in_dim: int
131
+ hidden_size: int
132
+ mlp_ratio: float
133
+ num_heads: int
134
+ depth: int
135
+ depth_single_blocks: int
136
+ axes_dim: list[int]
137
+ theta: int
138
+ qkv_bias: bool
139
+ guidance_embed: bool
140
+
141
+
142
+ class SiglipEmbedding(nn.Module):
143
+ def __init__(self, siglip_path = "google/siglip-base-patch16-256-i18n", use_matting=False):
144
+ super().__init__()
145
+ self.model = SiglipModel.from_pretrained(siglip_path).vision_model.to(torch.bfloat16)
146
+ self.processor = AutoProcessor.from_pretrained(siglip_path)
147
+ self.model.to(torch.cuda.current_device())
148
+
149
+ # BiRefNet matting setup
150
+ self.use_matting = use_matting
151
+ if self.use_matting:
152
+ self.birefnet = AutoModelForImageSegmentation.from_pretrained(
153
+ 'briaai/RMBG-2.0', trust_remote_code=True).to(torch.cuda.current_device(), dtype=torch.bfloat16)
154
+ # Apply half precision to the entire model after loading
155
+ self.matting_transform = transforms.Compose([
156
+ # transforms.Resize((512, 512)),
157
+ transforms.ToTensor(),
158
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
159
+ ])
160
+
161
+ def apply_matting(self, image):
162
+ """Apply BiRefNet matting to remove background from image"""
163
+ if not self.use_matting:
164
+ return image
165
+
166
+ # Convert to input format and move to GPU
167
+ input_image = self.matting_transform(image).unsqueeze(0).to(torch.cuda.current_device(), dtype=torch.bfloat16)
168
+
169
+ # Generate prediction
170
+ with torch.no_grad(), autocast(dtype=torch.bfloat16):
171
+ preds = self.birefnet(input_image)[-1].sigmoid().cpu()
172
+
173
+ # Process the mask
174
+ pred = preds[0].squeeze().float()
175
+ pred_pil = transforms.ToPILImage()(pred)
176
+ mask = pred_pil.resize(image.size)
177
+ binary_mask = mask.convert("L")
178
+
179
+ # Create a new image with black background
180
+ result = Image.new("RGB", image.size, (0, 0, 0))
181
+ result.paste(image, (0, 0), binary_mask)
182
+
183
+
184
+ return result
185
+
186
+
187
+ def get_id_embedding(self, refimage):
188
+ '''
189
+ refimage is a list (batch) of list (num of person) of PIL images
190
+ considering the whole batch, the number of person is fixed
191
+ '''
192
+ siglip_embedding = []
193
+
194
+
195
+ if isinstance(refimage, list):
196
+ batch_size = len(refimage)
197
+ for batch_idx, refimage_batch in enumerate(refimage):
198
+ # Apply matting if enabled
199
+ if self.use_matting:
200
+
201
+ processed_images = [self.apply_matting(img) for img in refimage_batch]
202
+ else:
203
+ processed_images = refimage_batch
204
+
205
+ pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
206
+ # device
207
+ pixel_values = pixel_values.to(torch.cuda.current_device(), dtype=torch.bfloat16)
208
+ last_hidden_state = self.model(pixel_values).last_hidden_state # 2, 256 768
209
+ # pooled_output = self.model(pixel_values).pooler_output # 2, 768
210
+ siglip_embedding.append(last_hidden_state)
211
+ # siglip_embedding.append(pooled_output) # 2, 768
212
+ siglip_embedding = torch.stack(siglip_embedding, dim=0) # shape ([batch_size, num_of_person, 256, 768])
213
+
214
+ if batch_size < 4:
215
+ # run additional times to avoid the first time cuda memory allocation overhead
216
+ for _ in range(4 - batch_size):
217
+ pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
218
+ # device
219
+ pixel_values = pixel_values.to(torch.cuda.current_device(), dtype=torch.bfloat16)
220
+ last_hidden_state = self.model(pixel_values).last_hidden_state
221
+
222
+ elif isinstance(refimage, torch.Tensor):
223
+ # refimage is a tensor of shape (batch_size, num_of_person, 3, H, W)
224
+ batch_size, num_of_person, C, H, W = refimage.shape
225
+ refimage = refimage.view(batch_size * num_of_person, C, H, W)
226
+ refimage = refimage.to(torch.cuda.current_device(), dtype=torch.bfloat16)
227
+ last_hidden_state = self.model(refimage).last_hidden_state
228
+ siglip_embedding = last_hidden_state.view(batch_size, num_of_person, 256, 768)
229
+
230
+ return siglip_embedding
231
+
232
+ def forward(self, refimage):
233
+ return self.get_id_embedding(refimage)
234
+
235
+ class Flux(nn.Module):
236
+ """
237
+ Transformer model for flow matching on sequences.
238
+ """
239
+ _supports_gradient_checkpointing = True
240
+
241
+ def __init__(self, params: FluxParams):
242
+ super().__init__()
243
+
244
+ self.params = params
245
+ self.in_channels = params.in_channels
246
+ self.out_channels = self.in_channels
247
+ if params.hidden_size % params.num_heads != 0:
248
+ raise ValueError(
249
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
250
+ )
251
+ pe_dim = params.hidden_size // params.num_heads
252
+ if sum(params.axes_dim) != pe_dim:
253
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
254
+ self.hidden_size = params.hidden_size
255
+ self.num_heads = params.num_heads
256
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
257
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
258
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
259
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
260
+ self.guidance_in = (
261
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
262
+ )
263
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
264
+
265
+ self.double_blocks = nn.ModuleList(
266
+ [
267
+ DoubleStreamBlock(
268
+ self.hidden_size,
269
+ self.num_heads,
270
+ mlp_ratio=params.mlp_ratio,
271
+ qkv_bias=params.qkv_bias,
272
+ )
273
+ for _ in range(params.depth)
274
+ ]
275
+ )
276
+
277
+ self.single_blocks = nn.ModuleList(
278
+ [
279
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
280
+ for _ in range(params.depth_single_blocks)
281
+ ]
282
+ )
283
+
284
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
285
+ self.gradient_checkpointing = False
286
+
287
+
288
+
289
+
290
+ # use cross attention
291
+ self.ipa_arc = nn.ModuleList([
292
+ PerceiverAttentionCA(dim=self.hidden_size, kv_dim=self.hidden_size, heads=self.num_heads)
293
+ for _ in range(self.params.depth_single_blocks + self.params.depth)
294
+ ])
295
+ self.ipa_sig = nn.ModuleList([
296
+ PerceiverAttentionCA(dim=self.hidden_size, kv_dim=self.hidden_size, heads=self.num_heads)
297
+ for _ in range(self.params.depth_single_blocks + self.params.depth)
298
+ ])
299
+
300
+
301
+
302
+ self.arcface_in_arc = nn.Sequential(
303
+ nn.Linear(512, 4 * self.hidden_size, bias=True),
304
+ nn.GELU(),
305
+ nn.LayerNorm(4 * self.hidden_size),
306
+ nn.Linear(4 * self.hidden_size, 8 * self.hidden_size, bias=True),
307
+ )
308
+
309
+
310
+ self.arcface_in_sig = nn.Sequential(
311
+ nn.Linear(512, 4 * self.hidden_size, bias=True),
312
+ nn.GELU(),
313
+ nn.LayerNorm(4 * self.hidden_size),
314
+ nn.Linear(4 * self.hidden_size, 8 * self.hidden_size, bias=True),
315
+ )
316
+
317
+ self.siglip_in_sig = nn.Sequential(
318
+ nn.Linear(768, self.hidden_size, bias=True),
319
+ nn.GELU(),
320
+ nn.LayerNorm(self.hidden_size),
321
+ nn.Linear(self.hidden_size, self.hidden_size, bias=True),
322
+ )
323
+
324
+
325
+ def lq_in_arc(self, txt_lq, siglip_embeddings, arcface_embeddings):
326
+ """
327
+ Process the siglip and arcface embeddings.
328
+ """
329
+
330
+ # shape of arcface: (num_refs, bs, 512)
331
+ arcface_embeddings = self.arcface_in_arc(arcface_embeddings)
332
+ # shape of arcface: (num_refs, bs, 8*hidden_size)
333
+ # 8*hidden_size -> 8 tokens of hidden_size
334
+ arcface_embeddings = rearrange(arcface_embeddings, 'b n (t d) -> b n t d', t=8, d=self.hidden_size)
335
+ # (num_refs, bs, 8, hidden_size) -> (bs, num_refs*8, hidden_size) via the permute and rearrange below
336
+
337
+
338
+ arcface_embeddings = arcface_embeddings.permute(1, 0, 2, 3) # (n, b, t, d) -> (b, n, t, d)
339
+
340
+ arcface_embeddings = rearrange(arcface_embeddings, 'b n t d -> b (n t) d')
341
+
342
+
343
+
344
+ return arcface_embeddings
345
+
346
+ def lq_in_sig(self, txt_lq, siglip_embeddings, arcface_embeddings):
347
+ """
348
+ Process the siglip and arcface embeddings.
349
+ """
350
+
351
+
352
+ # shape of arcface: (num_refs, bs, 512)
353
+ arcface_embeddings = self.arcface_in_sig(arcface_embeddings)
354
+
355
+ arcface_embeddings = rearrange(arcface_embeddings, 'b n (t d) -> b n t d', t=8, d=self.hidden_size)
356
+ # (num_refs, bs, 8, hidden_size) -> permuted to (bs, num_refs, 8, hidden_size) below
357
+
358
+ arcface_embeddings = arcface_embeddings.permute(1, 0, 2, 3) # (n, b, t, d) -> (b, n, t, d)
359
+
360
+ siglip_embeddings = self.siglip_in_sig(siglip_embeddings) # (bs, num_refs, 256, 768) -> (bs, num_refs, 256, hidden_size)
361
+
362
+ # concat in token dimension
363
+ arcface_embeddings = torch.cat((siglip_embeddings, arcface_embeddings), dim=2) # (bs, num_refs, 256, hidden_size) cat (bs, num_refs, 8, hidden_size) -> (bs, num_refs, 264, hidden_size)
364
+
365
+
366
+ arcface_embeddings = rearrange(arcface_embeddings, 'b n t d -> b (n t) d')
367
+ return arcface_embeddings
368
+
369
+
370
+
371
+ def _set_gradient_checkpointing(self, module, value=False):
372
+ if hasattr(module, "gradient_checkpointing"):
373
+ module.gradient_checkpointing = value
374
+
375
+ @property
376
+ def attn_processors(self):
377
+ # set recursively
378
+ processors = {} # type: dict[str, nn.Module]
379
+
380
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
381
+ if hasattr(module, "set_processor"):
382
+ processors[f"{name}.processor"] = module.processor
383
+
384
+ for sub_name, child in module.named_children():
385
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
386
+
387
+ return processors
388
+
389
+ for name, module in self.named_children():
390
+ fn_recursive_add_processors(name, module, processors)
391
+
392
+ return processors
393
+
394
+ def set_attn_processor(self, processor):
395
+ r"""
396
+ Sets the attention processor to use to compute attention.
397
+
398
+ Parameters:
399
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
400
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
401
+ for **all** `Attention` layers.
402
+
403
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
404
+ processor. This is strongly recommended when setting trainable attention processors.
405
+
406
+ """
407
+ count = len(self.attn_processors.keys())
408
+
409
+ if isinstance(processor, dict) and len(processor) != count:
410
+ raise ValueError(
411
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
412
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
413
+ )
414
+
415
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
416
+ if hasattr(module, "set_processor"):
417
+ if not isinstance(processor, dict):
418
+ module.set_processor(processor)
419
+ else:
420
+ module.set_processor(processor.pop(f"{name}.processor"))
421
+
422
+ for sub_name, child in module.named_children():
423
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
424
+
425
+ for name, module in self.named_children():
426
+ fn_recursive_attn_processor(name, module, processor)
427
+
428
+
429
+
430
+ def forward(
431
+ self,
432
+ img: Tensor,
433
+ img_ids: Tensor,
434
+ txt: Tensor,
435
+ txt_ids: Tensor,
436
+ timesteps: Tensor,
437
+ y: Tensor,
438
+ guidance: Tensor | None = None,
439
+ siglip_embeddings: Tensor | None = None, # (bs, num_refs, 256, 768)
440
+ arcface_embeddings: Tensor | None = None, # (bs, num_refs, 512)
441
+ bbox_lists: list | None = None, # list of list of bboxes, bbox_lists[i] is for the i-th batch, each has different number of bboxes (ids), which should align with the dim1 of arcface_embeddings. This is used to replace bbox_A and bbox_B, which should be discarded, but remained for compatibility.
442
+ use_mask: bool = True,
443
+ id_weight: float = 1.0,
444
+ siglip_weight: float = 1.0,
445
+ siglip_mask = None,
446
+ arc_mask = None,
447
+
448
+ img_height: int = 512,
449
+ img_width: int = 512,
450
+ ) -> Tensor:
451
+ if img.ndim != 3 or txt.ndim != 3:
452
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
453
+
454
+ # running on sequences img
455
+ img = self.img_in(img)
456
+ vec = self.time_in(timestep_embedding(timesteps, 256))
457
+ if self.params.guidance_embed:
458
+ if guidance is None:
459
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
460
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
461
+ vec = vec + self.vector_in(y)
462
+ txt = self.txt_in(txt)
463
+
464
+
465
+
466
+
467
+ text_length = txt.shape[1]
468
+ img_length = img.shape[1]
469
+
470
+ img_end = img.shape[1]
471
+
472
+
473
+ use_ip = arcface_embeddings is not None
474
+
475
+ if use_ip:
476
+
477
+ id_embeddings = self.lq_in_arc(None, siglip_embeddings, arcface_embeddings)
478
+ siglip_embeddings = self.lq_in_sig(None, siglip_embeddings, arcface_embeddings)
479
+
480
+ text_length = txt.shape[1] # update text_length after adding learnable query
481
+
482
+
483
+ # 8 ArcFace tokens per identity; the SigLIP stream carries 256 SigLIP tokens plus those 8 ArcFace tokens (264 total)
484
+ id_len = 8
485
+ siglip_len = 256 + 8
486
+
487
+
488
+
489
+ if bbox_lists is not None and use_mask and (arc_mask is None or siglip_mask is None):
490
+ arc_mask = create_person_cross_attention_mask_varlen(
491
+ batch_size=img.shape[0],
492
+ num_heads=self.params.num_heads,
493
+ # txt_len=text_length,
494
+ img_len=img_length,
495
+ id_len=id_len,
496
+ bbox_lists=bbox_lists,
497
+ max_num_ids=len(bbox_lists[0]),
498
+ original_width=img_width,
499
+ original_height= img_height,
500
+ ).to(img.device)
501
+ siglip_mask = create_person_cross_attention_mask_varlen(
502
+ batch_size=img.shape[0],
503
+ num_heads=self.params.num_heads,
504
+ # txt_len=text_length,
505
+ img_len=img_length,
506
+ id_len=siglip_len,
507
+ bbox_lists=bbox_lists,
508
+ max_num_ids=len(bbox_lists[0]),
509
+ original_width=img_width,
510
+ original_height= img_height,
511
+ ).to(img.device)
512
+ else:
513
+ arc_mask = None
514
+ siglip_mask = None
515
+
516
+
517
+
518
+ # update text_ids and id_ids
519
+ txt_ids = torch.zeros((txt.shape[0], text_length, 3)).to(img_ids.device) # (bs, T, 3)
520
+
521
+ ids = torch.cat((txt_ids, img_ids), dim=1) # (bs, T + I + ID, 3)
522
+
523
+
524
+ pe = self.pe_embedder(ids)
525
+
526
+ # ipa
527
+ ipa_idx = 0
528
+
529
+ for index_block, block in enumerate(self.double_blocks):
530
+ if self.training and self.gradient_checkpointing:
531
+ img, txt = torch.utils.checkpoint.checkpoint(
532
+ block,
533
+ img=img,
534
+ txt=txt,
535
+ vec=vec,
536
+ pe=pe,
537
+ # mask=mask,
538
+ text_length=text_length,
539
+ image_length=img_length,
540
+ # return_map = False,
541
+ use_reentrant=False,
542
+ )
543
+
544
+
545
+
546
+ else:
547
+ img, txt= block(
548
+ img=img,
549
+ txt=txt,
550
+ vec=vec,
551
+ pe=pe,
552
+ text_length=text_length,
553
+ image_length=img_length,
554
+ # return_map=False,
555
+ )
556
+
557
+
558
+ if use_ip:
559
+
560
+ img = img + id_weight * self.ipa_arc[ipa_idx](id_embeddings, img, mask=arc_mask) + siglip_weight * self.ipa_sig[ipa_idx](siglip_embeddings, img, mask=siglip_mask)
561
+ ipa_idx += 1
562
+
563
+
564
+
565
+
566
+
567
+
568
+
569
+ # for block in self.single_blocks:
570
+ img = torch.cat((txt, img), 1)
571
+
572
+
573
+ for index_block, block in enumerate(self.single_blocks):
574
+ if self.training and self.gradient_checkpointing:
575
+ img = torch.utils.checkpoint.checkpoint(
576
+ block,
577
+ img, vec=vec, pe=pe, #mask=mask,
578
+ text_length=text_length,
579
+ image_length=img_length,
580
+ return_map=False,
581
+ use_reentrant=False
582
+ )
583
+
584
+ else:
585
+ img = block(img, vec=vec, pe=pe,text_length=text_length, image_length=img_length, return_map=False)
586
+
587
+
588
+
589
+
590
+ # IPA
591
+ if use_ip:
592
+ txt, real_img = img[:, :text_length, :], img[:, text_length:, :]
593
+
594
+ id_ca = id_weight * self.ipa_arc[ipa_idx](id_embeddings, real_img, mask=arc_mask) + siglip_weight * self.ipa_sig[ipa_idx](siglip_embeddings, real_img, mask=siglip_mask)
595
+
596
+ real_img = real_img + id_ca
597
+ img = torch.cat((txt, real_img), dim=1)
598
+ ipa_idx += 1
599
+
600
+
601
+
602
+
603
+
604
+ img = img[:, txt.shape[1] :, ...]
605
+ # index img
606
+ img = img[:, :img_end, ...]
607
+
608
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
609
+
610
+ return img
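A shape sketch for create_person_cross_attention_mask_varlen defined above: for an assumed 512x512 generation with VAE factor 8 and patch size 2, the image becomes a 32x32 grid of tokens (img_len = 1024), and with two identities at id_len = 8 each, the boolean mask is (batch, heads, 1024, 16). The pixel bboxes below are made up for the check.

import torch

bbox_lists = [[[64, 64, 192, 192], [320, 96, 448, 256]]]  # one batch item, two face boxes in pixels
mask = create_person_cross_attention_mask_varlen(
    batch_size=1, img_len=32 * 32, id_len=8,
    bbox_lists=bbox_lists, original_width=512, original_height=512,
    max_num_ids=2, num_heads=24,
)
print(mask.shape)  # torch.Size([1, 24, 1024, 16])
# Each face's image tokens may only attend to that identity's 8 ID tokens
print(mask[0, 0, :, :8].any(), mask[0, 0, :, 8:].any())  # tensor(True) tensor(True)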
withanyone/flux/modules/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (9.09 kB)
 
withanyone/flux/modules/__pycache__/conditioner.cpython-310.pyc ADDED
Binary file (1.52 kB)
 
withanyone/flux/modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (18 kB)
 
withanyone/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,327 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from einops import rearrange
20
+ from torch import Tensor, nn
21
+
22
+
23
+ @dataclass
24
+ class AutoEncoderParams:
25
+ resolution: int
26
+ in_channels: int
27
+ ch: int
28
+ out_ch: int
29
+ ch_mult: list[int]
30
+ num_res_blocks: int
31
+ z_channels: int
32
+ scale_factor: float
33
+ shift_factor: float
34
+
35
+
36
+ def swish(x: Tensor) -> Tensor:
37
+ return x * torch.sigmoid(x)
38
+
39
+
40
+ class AttnBlock(nn.Module):
41
+ def __init__(self, in_channels: int):
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+
45
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
46
+
47
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
51
+
52
+ def attention(self, h_: Tensor) -> Tensor:
53
+ h_ = self.norm(h_)
54
+ q = self.q(h_)
55
+ k = self.k(h_)
56
+ v = self.v(h_)
57
+
58
+ b, c, h, w = q.shape
59
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
60
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
61
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
62
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
63
+
64
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
65
+
66
+ def forward(self, x: Tensor) -> Tensor:
67
+ return x + self.proj_out(self.attention(x))
68
+
69
+
70
+ class ResnetBlock(nn.Module):
71
+ def __init__(self, in_channels: int, out_channels: int):
72
+ super().__init__()
73
+ self.in_channels = in_channels
74
+ out_channels = in_channels if out_channels is None else out_channels
75
+ self.out_channels = out_channels
76
+
77
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
79
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
80
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
81
+ if self.in_channels != self.out_channels:
82
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
83
+
84
+ def forward(self, x):
85
+ h = x
86
+ h = self.norm1(h)
87
+ h = swish(h)
88
+ h = self.conv1(h)
89
+
90
+ h = self.norm2(h)
91
+ h = swish(h)
92
+ h = self.conv2(h)
93
+
94
+ if self.in_channels != self.out_channels:
95
+ x = self.nin_shortcut(x)
96
+
97
+ return x + h
98
+
99
+
100
+ class Downsample(nn.Module):
101
+ def __init__(self, in_channels: int):
102
+ super().__init__()
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
105
+
106
+ def forward(self, x: Tensor):
107
+ pad = (0, 1, 0, 1)
108
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
109
+ x = self.conv(x)
110
+ return x
111
+
112
+
113
+ class Upsample(nn.Module):
114
+ def __init__(self, in_channels: int):
115
+ super().__init__()
116
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
117
+
118
+ def forward(self, x: Tensor):
119
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
120
+ x = self.conv(x)
121
+ return x
122
+
123
+
124
+ class Encoder(nn.Module):
125
+ def __init__(
126
+ self,
127
+ resolution: int,
128
+ in_channels: int,
129
+ ch: int,
130
+ ch_mult: list[int],
131
+ num_res_blocks: int,
132
+ z_channels: int,
133
+ ):
134
+ super().__init__()
135
+ self.ch = ch
136
+ self.num_resolutions = len(ch_mult)
137
+ self.num_res_blocks = num_res_blocks
138
+ self.resolution = resolution
139
+ self.in_channels = in_channels
140
+ # downsampling
141
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
142
+
143
+ curr_res = resolution
144
+ in_ch_mult = (1,) + tuple(ch_mult)
145
+ self.in_ch_mult = in_ch_mult
146
+ self.down = nn.ModuleList()
147
+ block_in = self.ch
148
+ for i_level in range(self.num_resolutions):
149
+ block = nn.ModuleList()
150
+ attn = nn.ModuleList()
151
+ block_in = ch * in_ch_mult[i_level]
152
+ block_out = ch * ch_mult[i_level]
153
+ for _ in range(self.num_res_blocks):
154
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
155
+ block_in = block_out
156
+ down = nn.Module()
157
+ down.block = block
158
+ down.attn = attn
159
+ if i_level != self.num_resolutions - 1:
160
+ down.downsample = Downsample(block_in)
161
+ curr_res = curr_res // 2
162
+ self.down.append(down)
163
+
164
+ # middle
165
+ self.mid = nn.Module()
166
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
167
+ self.mid.attn_1 = AttnBlock(block_in)
168
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
169
+
170
+ # end
171
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
172
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
173
+
174
+ def forward(self, x: Tensor) -> Tensor:
175
+ # downsampling
176
+ hs = [self.conv_in(x)]
177
+ for i_level in range(self.num_resolutions):
178
+ for i_block in range(self.num_res_blocks):
179
+ h = self.down[i_level].block[i_block](hs[-1])
180
+ if len(self.down[i_level].attn) > 0:
181
+ h = self.down[i_level].attn[i_block](h)
182
+ hs.append(h)
183
+ if i_level != self.num_resolutions - 1:
184
+ hs.append(self.down[i_level].downsample(hs[-1]))
185
+
186
+ # middle
187
+ h = hs[-1]
188
+ h = self.mid.block_1(h)
189
+ h = self.mid.attn_1(h)
190
+ h = self.mid.block_2(h)
191
+ # end
192
+ h = self.norm_out(h)
193
+ h = swish(h)
194
+ h = self.conv_out(h)
195
+ return h
196
+
197
+
198
+ class Decoder(nn.Module):
199
+ def __init__(
200
+ self,
201
+ ch: int,
202
+ out_ch: int,
203
+ ch_mult: list[int],
204
+ num_res_blocks: int,
205
+ in_channels: int,
206
+ resolution: int,
207
+ z_channels: int,
208
+ ):
209
+ super().__init__()
210
+ self.ch = ch
211
+ self.num_resolutions = len(ch_mult)
212
+ self.num_res_blocks = num_res_blocks
213
+ self.resolution = resolution
214
+ self.in_channels = in_channels
215
+ self.ffactor = 2 ** (self.num_resolutions - 1)
216
+
217
+ # compute in_ch_mult, block_in and curr_res at lowest res
218
+ block_in = ch * ch_mult[self.num_resolutions - 1]
219
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
220
+ self.z_shape = (1, z_channels, curr_res, curr_res)
221
+
222
+ # z to block_in
223
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
224
+
225
+ # middle
226
+ self.mid = nn.Module()
227
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
228
+ self.mid.attn_1 = AttnBlock(block_in)
229
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
230
+
231
+ # upsampling
232
+ self.up = nn.ModuleList()
233
+ for i_level in reversed(range(self.num_resolutions)):
234
+ block = nn.ModuleList()
235
+ attn = nn.ModuleList()
236
+ block_out = ch * ch_mult[i_level]
237
+ for _ in range(self.num_res_blocks + 1):
238
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
239
+ block_in = block_out
240
+ up = nn.Module()
241
+ up.block = block
242
+ up.attn = attn
243
+ if i_level != 0:
244
+ up.upsample = Upsample(block_in)
245
+ curr_res = curr_res * 2
246
+ self.up.insert(0, up) # prepend to get consistent order
247
+
248
+ # end
249
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
250
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
251
+
252
+ def forward(self, z: Tensor) -> Tensor:
253
+ # z to block_in
254
+ h = self.conv_in(z)
255
+
256
+ # middle
257
+ h = self.mid.block_1(h)
258
+ h = self.mid.attn_1(h)
259
+ h = self.mid.block_2(h)
260
+
261
+ # upsampling
262
+ for i_level in reversed(range(self.num_resolutions)):
263
+ for i_block in range(self.num_res_blocks + 1):
264
+ h = self.up[i_level].block[i_block](h)
265
+ if len(self.up[i_level].attn) > 0:
266
+ h = self.up[i_level].attn[i_block](h)
267
+ if i_level != 0:
268
+ h = self.up[i_level].upsample(h)
269
+
270
+ # end
271
+ h = self.norm_out(h)
272
+ h = swish(h)
273
+ h = self.conv_out(h)
274
+ return h
275
+
276
+
277
+ class DiagonalGaussian(nn.Module):
278
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
279
+ super().__init__()
280
+ self.sample = sample
281
+ self.chunk_dim = chunk_dim
282
+
283
+ def forward(self, z: Tensor) -> Tensor:
284
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
285
+ if self.sample:
286
+ std = torch.exp(0.5 * logvar)
287
+ return mean + std * torch.randn_like(mean)
288
+ else:
289
+ return mean
290
+
291
+
292
+ class AutoEncoder(nn.Module):
293
+ def __init__(self, params: AutoEncoderParams):
294
+ super().__init__()
295
+ self.encoder = Encoder(
296
+ resolution=params.resolution,
297
+ in_channels=params.in_channels,
298
+ ch=params.ch,
299
+ ch_mult=params.ch_mult,
300
+ num_res_blocks=params.num_res_blocks,
301
+ z_channels=params.z_channels,
302
+ )
303
+ self.decoder = Decoder(
304
+ resolution=params.resolution,
305
+ in_channels=params.in_channels,
306
+ ch=params.ch,
307
+ out_ch=params.out_ch,
308
+ ch_mult=params.ch_mult,
309
+ num_res_blocks=params.num_res_blocks,
310
+ z_channels=params.z_channels,
311
+ )
312
+ self.reg = DiagonalGaussian()
313
+
314
+ self.scale_factor = params.scale_factor
315
+ self.shift_factor = params.shift_factor
316
+
317
+ def encode(self, x: Tensor) -> Tensor:
318
+ z = self.reg(self.encoder(x))
319
+ z = self.scale_factor * (z - self.shift_factor)
320
+ return z
321
+
322
+ def decode(self, z: Tensor) -> Tensor:
323
+ z = z / self.scale_factor + self.shift_factor
324
+ return self.decoder(z)
325
+
326
+ def forward(self, x: Tensor) -> Tensor:
327
+ return self.decode(self.encode(x))
withanyone/flux/modules/conditioner.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from torch import Tensor, nn
17
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
18
+ T5Tokenizer)
19
+
20
+
21
+ class HFEmbedder(nn.Module):
22
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
23
+ super().__init__()
24
+ self.is_clip = "clip" in version.lower()
25
+ self.max_length = max_length
26
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
27
+
28
+ if self.is_clip:
29
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
30
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
31
+ else:
32
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
33
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
34
+
35
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
36
+
37
+ def forward(self, text: list[str]) -> Tensor:
38
+ batch_encoding = self.tokenizer(
39
+ text,
40
+ truncation=True,
41
+ max_length=self.max_length,
42
+ return_length=False,
43
+ return_overflowing_tokens=False,
44
+ padding="max_length",
45
+ return_tensors="pt",
46
+ )
47
+
48
+ outputs = self.hf_module(
49
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
50
+ attention_mask=None,
51
+ output_hidden_states=False,
52
+ )
53
+ return outputs[self.output_key]
withanyone/flux/modules/layers.py ADDED
@@ -0,0 +1,530 @@
1
+
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+ from einops import rearrange
8
+ from torch import Tensor, nn
9
+
10
+ # from ..math import attention, rope
11
+ from ..math import rope
12
+ from ..math import attention
13
+ # from ..math import attention
14
+ import torch.nn.functional as F
15
+
16
+ TOKEN_AUG_IDX = 2048
17
+
18
+ class EmbedND(nn.Module):
19
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
20
+ super().__init__()
21
+ self.dim = dim
22
+ self.theta = theta
23
+ self.axes_dim = axes_dim
24
+
25
+ def forward(self, ids: Tensor) -> Tensor:
26
+ n_axes = ids.shape[-1]
27
+ emb = torch.cat(
28
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
29
+ dim=-3,
30
+ )
31
+
32
+ return emb.unsqueeze(1)
33
+
34
+
35
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
36
+ """
37
+ Create sinusoidal timestep embeddings.
38
+ :param t: a 1-D Tensor of N indices, one per batch element.
39
+ These may be fractional.
40
+ :param dim: the dimension of the output.
41
+ :param max_period: controls the minimum frequency of the embeddings.
42
+ :return: an (N, D) Tensor of positional embeddings.
43
+ """
44
+ t = time_factor * t
45
+ half = dim // 2
46
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
47
+ t.device
48
+ )
49
+
50
+ args = t[:, None].float() * freqs[None]
51
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
52
+ if dim % 2:
53
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
54
+ if torch.is_floating_point(t):
55
+ embedding = embedding.to(t)
56
+ return embedding
57
+
58
+
59
+ class MLPEmbedder(nn.Module):
60
+ def __init__(self, in_dim: int, hidden_dim: int):
61
+ super().__init__()
62
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
63
+ self.silu = nn.SiLU()
64
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
65
+
66
+ def forward(self, x: Tensor) -> Tensor:
67
+ return self.out_layer(self.silu(self.in_layer(x)))
68
+
69
+ def reshape_tensor(x, heads):
70
+ # print("x in reshape_tensor", x.shape)
71
+ bs, length, width = x.shape
72
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
73
+ x = x.view(bs, length, heads, -1)
74
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
75
+ x = x.transpose(1, 2)
76
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
77
+ x = x.reshape(bs, heads, length, -1)
78
+ return x
79
+ class PerceiverAttentionCA(nn.Module):
80
+ def __init__(self, *, dim=3072, dim_head=64, heads=16, kv_dim=2048):
81
+ super().__init__()
82
+ self.scale = dim_head ** -0.5
83
+ self.dim_head = dim_head
84
+ self.heads = heads
85
+ inner_dim = dim_head * heads
86
+
87
+ self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
88
+ self.norm2 = nn.LayerNorm(dim)
89
+
90
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
91
+ self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
92
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
93
+
94
+ def forward(self, x, latents, mask=None):
95
+ """
96
+ Args:
97
+ x (torch.Tensor): image features
98
+ shape (b, n1, D)
99
+ latent (torch.Tensor): latent features
100
+ shape (b, n2, D)
101
+ """
102
+ x = self.norm1(x)
103
+ latents = self.norm2(latents)
104
+
105
+ # print("x, latents in PerceiverAttentionCA", x.shape, latents.shape)
106
+
107
+ b, seq_len, _ = latents.shape
108
+
109
+ q = self.to_q(latents)
110
+ k, v = self.to_kv(x).chunk(2, dim=-1)
111
+
112
+ # print("q, k, v in PerceiverAttentionCA", q.shape, k.shape, v.shape)
113
+
114
+ q = reshape_tensor(q, self.heads)
115
+ k = reshape_tensor(k, self.heads)
116
+ v = reshape_tensor(v, self.heads)
117
+
118
+ # # attention
119
+ # scale = 1 / math.sqrt(math.sqrt(self.dim_head))
120
+ # weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
121
+ # print("is there any nan in weight:", torch.isnan(weight).any())
122
+ # if mask is not None:
123
+ # # Mask shape should be [batch_size, num_heads, q_len, kv_len]
124
+ # # weight = weight.masked_fill(mask == 0, float("-inf"))
125
+ # if mask.dtype == torch.bool:
126
+ # # Boolean mask: False values are masked out
127
+ # # print("Got boolean mask")
128
+ # weight = weight.masked_fill(~mask, -float('inf'))
129
+ # else:
130
+ # # Float mask: values are added directly to scores
131
+ # weight = weight + mask
132
+ # print("is there any nan in weight after mask:", torch.isnan(weight).any())
133
+ # weight = torch.softmax(weight, dim=-1)
134
+ # print("is there any nan in weight after softmax:", torch.isnan(weight).any())
135
+ # out = weight @ v
136
+
137
+ # use sdpa
138
+ # if mask is not None:
139
+ # print("mask shape in PerceiverAttentionCA", mask.shape)
140
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
141
+
142
+ out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
143
+
144
+ return self.to_out(out)
145
+
146
+
147
+
148
+
149
+ class RMSNorm(torch.nn.Module):
150
+ def __init__(self, dim: int):
151
+ super().__init__()
152
+ self.scale = nn.Parameter(torch.ones(dim))
153
+
154
+ def forward(self, x: Tensor):
155
+ x_dtype = x.dtype
156
+ x = x.float()
157
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
158
+ return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
159
+
160
+
161
+ class QKNorm(torch.nn.Module):
162
+ def __init__(self, dim: int):
163
+ super().__init__()
164
+ self.query_norm = RMSNorm(dim)
165
+ self.key_norm = RMSNorm(dim)
166
+
167
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
168
+ q = self.query_norm(q)
169
+ k = self.key_norm(k)
170
+ return q.to(v), k.to(v)
171
+
172
+ class LoRALinearLayer(nn.Module):
173
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
174
+ super().__init__()
175
+
176
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
177
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
178
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
179
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
180
+ self.network_alpha = network_alpha
181
+ self.rank = rank
182
+
183
+ nn.init.normal_(self.down.weight, std=1 / rank)
184
+ nn.init.zeros_(self.up.weight)
185
+
186
+ def forward(self, hidden_states):
187
+ orig_dtype = hidden_states.dtype
188
+ dtype = self.down.weight.dtype
189
+
190
+ down_hidden_states = self.down(hidden_states.to(dtype))
191
+ up_hidden_states = self.up(down_hidden_states)
192
+
193
+ if self.network_alpha is not None:
194
+ up_hidden_states *= self.network_alpha / self.rank
195
+
196
+ return up_hidden_states.to(orig_dtype)
197
+
198
+ class FLuxSelfAttnProcessor:
199
+ def __call__(self, attn, x, pe, **attention_kwargs):
200
+ qkv = attn.qkv(x)
201
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
202
+ q, k = attn.norm(q, k, v)
203
+ x = attention(q, k, v, pe=pe)
204
+ x = attn.proj(x)
205
+ return x
206
+
207
+ class LoraFluxAttnProcessor(nn.Module):
208
+
209
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
210
+ super().__init__()
211
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
212
+ self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
213
+ self.lora_weight = lora_weight
214
+
215
+
216
+ def __call__(self, attn, x, pe, **attention_kwargs):
217
+ qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
218
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
219
+ q, k = attn.norm(q, k, v)
220
+ x = attention(q, k, v, pe=pe)
221
+ x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
222
+ return x
223
+
224
+
225
+ class SelfAttention(nn.Module):
226
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
227
+ super().__init__()
228
+ self.num_heads = num_heads
229
+ head_dim = dim // num_heads
230
+
231
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
232
+ self.norm = QKNorm(head_dim)
233
+ self.proj = nn.Linear(dim, dim)
234
+ def forward(self):
235
+ pass
236
+
237
+ @dataclass
238
+ class ModulationOut:
239
+ shift: Tensor
240
+ scale: Tensor
241
+ gate: Tensor
242
+
243
+
244
+ class Modulation(nn.Module):
245
+ def __init__(self, dim: int, double: bool):
246
+ super().__init__()
247
+ self.is_double = double
248
+ self.multiplier = 6 if double else 3
249
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
250
+
251
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
252
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
253
+
254
+ return (
255
+ ModulationOut(*out[:3]),
256
+ ModulationOut(*out[3:]) if self.is_double else None,
257
+ )
258
+
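+ # Sketch: with double=True and hidden size D, self.lin maps vec (B, D) -> (B, 1, 6*D); the six chunks become
+ # (shift, scale, gate) for the attention branch and a second (shift, scale, gate) triple for the MLP branch.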
259
+ class DoubleStreamBlockLoraProcessor(nn.Module):
260
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
261
+ super().__init__()
262
+ self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
263
+ self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
264
+ self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
265
+ self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
266
+ self.lora_weight = lora_weight
267
+
268
+ def forward(self, attn, img, txt, vec, pe, mask, text_length, image_length, **attention_kwargs):
269
+ img_mod1, img_mod2 = attn.img_mod(vec)
270
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
271
+
272
+ # prepare image for attention
273
+ img_modulated = attn.img_norm1(img)
274
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
275
+ img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
276
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
277
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
278
+
279
+ # prepare txt for attention
280
+ txt_modulated = attn.txt_norm1(txt)
281
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
282
+ txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
283
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
284
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
285
+
286
+ # run actual attention
287
+ q = torch.cat((txt_q, img_q), dim=2)
288
+ k = torch.cat((txt_k, img_k), dim=2)
289
+ v = torch.cat((txt_v, img_v), dim=2)
290
+
291
+ attn1 = attention(q, k, v, pe=pe, mask=mask, token_aug_idx=TOKEN_AUG_IDX, text_length=text_length, image_length=image_length)
292
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
293
+
294
+ # calculate the img blocks
295
+ img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
296
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
297
+
298
+ # calculate the txt blocks
299
+ txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
300
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
301
+
302
+
303
+ return img, txt
304
+
305
+ class DoubleStreamBlockProcessor:
306
+ def __call__(self, attn, img, txt, vec, pe, mask, text_length, image_length, **attention_kwargs):
307
+ img_mod1, img_mod2 = attn.img_mod(vec)
308
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
309
+
310
+ # prepare image for attention
311
+ img_modulated = attn.img_norm1(img)
312
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
313
+ img_qkv = attn.img_attn.qkv(img_modulated)
314
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
315
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
316
+
317
+ # prepare txt for attention
318
+ txt_modulated = attn.txt_norm1(txt)
319
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
320
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
321
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
322
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
323
+
324
+ # run actual attention
325
+ q = torch.cat((txt_q, img_q), dim=2)
326
+ k = torch.cat((txt_k, img_k), dim=2)
327
+ v = torch.cat((txt_v, img_v), dim=2)
328
+
329
+
330
+ attn1 = attention(q, k, v, pe=pe, mask=mask, token_aug_idx=TOKEN_AUG_IDX, text_length=text_length, image_length=image_length)
331
+
332
+
333
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
334
+
335
+ # calculate the img blocks
336
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
337
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
338
+
339
+ # calculate the txt blocks
340
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
341
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
342
+
343
+
344
+ return img, txt
345
+
346
+ class DoubleStreamBlock(nn.Module):
347
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
348
+ super().__init__()
349
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
350
+ self.num_heads = num_heads
351
+ self.hidden_size = hidden_size
352
+ self.head_dim = hidden_size // num_heads
353
+
354
+ self.img_mod = Modulation(hidden_size, double=True)
355
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
356
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
357
+
358
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
359
+ self.img_mlp = nn.Sequential(
360
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
361
+ nn.GELU(approximate="tanh"),
362
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
363
+ )
364
+
365
+ self.txt_mod = Modulation(hidden_size, double=True)
366
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
367
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
368
+
369
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
370
+ self.txt_mlp = nn.Sequential(
371
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
372
+ nn.GELU(approximate="tanh"),
373
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
374
+ )
375
+ processor = DoubleStreamBlockProcessor()
376
+ self.set_processor(processor)
377
+
378
+ def set_processor(self, processor) -> None:
379
+ self.processor = processor
380
+
381
+ def get_processor(self):
382
+ return self.processor
383
+
384
+ def forward(
385
+ self,
386
+ img: Tensor,
387
+ txt: Tensor,
388
+ vec: Tensor,
389
+ pe: Tensor,
390
+ image_proj: Tensor = None,
391
+ ip_scale: float =1.0,
392
+ mask: Tensor | None = None,
393
+ text_length: int = None,
394
+ image_length: int = None,
395
+ return_map: bool = False,
396
+ **attention_kwargs
397
+ ) -> tuple[Tensor, Tensor]:
398
+ if image_proj is None:
399
+
400
+ return self.processor(self, img, txt, vec, pe, mask, text_length, image_length)
401
+ else:
402
+
403
+ return self.processor(self, img, txt, vec, pe, mask, text_length, image_length, image_proj, ip_scale)
404
+
405
+
406
+ class SingleStreamBlockLoraProcessor(nn.Module):
407
+ def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1):
408
+ super().__init__()
409
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
410
+ # 15360 = hidden_size + mlp_hidden_dim for FLUX-dev (3072 + 4 * 3072): this LoRA acts on the concatenated attention/MLP activations fed to linear2
+ self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
411
+ self.lora_weight = lora_weight
412
+
413
+ def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, mask = None, text_length = None, image_length = None, return_map=False) -> Tensor:
414
+
415
+ mod, _ = attn.modulation(vec)
416
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
417
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
418
+ qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
419
+
420
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
421
+ q, k = attn.norm(q, k, v)
422
+
423
+ # compute attention
424
+
425
+ attn_1 = attention(q, k, v, pe=pe, mask=mask, token_aug_idx=TOKEN_AUG_IDX,text_length=text_length, image_length=image_length)
426
+
427
+ # compute activation in mlp stream, cat again and run second linear layer
428
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
429
+ output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
430
+ output = x + mod.gate * output
431
+
432
+
433
+ return output
434
+
435
+
436
+ class SingleStreamBlockProcessor:
437
+ def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor, text_length, image_length, return_map=False, **attention_kwargs) -> Tensor:
438
+
439
+ mod, _ = attn.modulation(vec)
440
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
441
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
442
+
443
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
444
+ q, k = attn.norm(q, k, v)
445
+
446
+ # compute attention
447
+
448
+
449
+ attn_1 = attention(q, k, v, pe=pe, mask=mask, token_aug_idx=TOKEN_AUG_IDX,text_length=text_length, image_length=image_length)
450
+
451
+ # compute activation in mlp stream, cat again and run second linear layer
452
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
453
+ output = x + mod.gate * output
454
+
455
+ return output
456
+
457
+ class SingleStreamBlock(nn.Module):
458
+ """
459
+ A DiT block with parallel linear layers as described in
460
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
461
+ """
462
+
463
+ def __init__(
464
+ self,
465
+ hidden_size: int,
466
+ num_heads: int,
467
+ mlp_ratio: float = 4.0,
468
+ qk_scale: float | None = None,
469
+ ):
470
+ super().__init__()
471
+ self.hidden_dim = hidden_size
472
+ self.num_heads = num_heads
473
+ self.head_dim = hidden_size // num_heads
474
+ self.scale = qk_scale or self.head_dim**-0.5
475
+
476
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
477
+ # qkv and mlp_in
478
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
479
+ # proj and mlp_out
480
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
481
+
482
+ self.norm = QKNorm(self.head_dim)
483
+
484
+ self.hidden_size = hidden_size
485
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
486
+
487
+ self.mlp_act = nn.GELU(approximate="tanh")
488
+ self.modulation = Modulation(hidden_size, double=False)
489
+
490
+ processor = SingleStreamBlockProcessor()
491
+ self.set_processor(processor)
492
+
493
+
494
+ def set_processor(self, processor) -> None:
495
+ self.processor = processor
496
+
497
+ def get_processor(self):
498
+ return self.processor
499
+
500
+ def forward(
501
+ self,
502
+ x: Tensor,
503
+ vec: Tensor,
504
+ pe: Tensor,
505
+ image_proj: Tensor | None = None,
506
+ ip_scale: float = 1.0,
507
+ mask: Tensor | None = None,
508
+ text_length: int | None = None,
509
+ image_length: int | None = None,
510
+ return_map: bool = False,
511
+ ) -> Tensor:
512
+ if image_proj is None:
513
+ return self.processor(self, x, vec, pe, mask, text_length=text_length, image_length=image_length)
514
+ else:
515
+ return self.processor(self, x, vec, pe, mask, image_proj, ip_scale, text_length=text_length, image_length=image_length)
516
+
517
+
518
+
519
+ class LastLayer(nn.Module):
520
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
521
+ super().__init__()
522
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
523
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
524
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
525
+
526
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
527
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
528
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
529
+ x = self.linear(x)
530
+ return x
withanyone/flux/pipeline.py ADDED
@@ -0,0 +1,406 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from typing import Literal
18
+
19
+ import torch
20
+ from einops import rearrange
21
+ from PIL import ExifTags, Image
22
+ import torchvision.transforms.functional as TVF
23
+
24
+
25
+ from withanyone.flux.modules.layers import (
26
+ DoubleStreamBlockLoraProcessor,
27
+ DoubleStreamBlockProcessor,
28
+ SingleStreamBlockLoraProcessor,
29
+ SingleStreamBlockProcessor,
30
+ )
31
+ from withanyone.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
32
+ from withanyone.flux.util import (
33
+ load_ae,
34
+ load_clip,
35
+ load_flow_model_no_lora,
36
+ load_flow_model_diffusers,
37
+ load_t5,
38
+ )
39
+
40
+ from withanyone.flux.model import SiglipEmbedding, create_person_cross_attention_mask_varlen
41
+
42
+
43
+ def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
44
+
45
+ image_w, image_h = raw_image.size
46
+
47
+ if image_w >= image_h:
48
+ new_w = long_size
49
+ new_h = int((long_size / image_w) * image_h)
50
+ else:
51
+ new_h = long_size
52
+ new_w = int((long_size / image_h) * image_w)
53
+
54
+ raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
55
+ target_w = new_w // 16 * 16
56
+ target_h = new_h // 16 * 16
57
+
58
+ left = (new_w - target_w) // 2
59
+ top = (new_h - target_h) // 2
60
+ right = left + target_w
61
+ bottom = top + target_h
62
+
63
+
64
+ raw_image = raw_image.crop((left, top, right, bottom))
65
+
66
+
67
+ raw_image = raw_image.convert("RGB")
68
+ return raw_image
69
+
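+ # Example (a minimal sketch; "ref.jpg" is a placeholder path): preprocess_ref keeps the aspect ratio,
+ # caps the long side at 512 and center-crops so both sides are multiples of 16:
+ #   ref = preprocess_ref(Image.open("ref.jpg"), long_size=512)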
70
+
71
+ from io import BytesIO
72
+ import insightface
73
+ import numpy as np
74
+ class FaceExtractor:
75
+ def __init__(self, model_path = "./"):
76
+
77
+ self.model = insightface.app.FaceAnalysis(name = "antelopev2", root=model_path, providers=['CUDAExecutionProvider'])
78
+ self.model.prepare(ctx_id=0, det_thresh=0.45)
79
+
80
+ def extract_moref(self, img, bboxes, face_size_restriction=1):
81
+ """
82
+ Extract faces from an image based on a list of bounding boxes.
84
+ Makes each face square and resizes it to 512x512.
85
+
86
+ Args:
87
+ img: PIL Image or raw image bytes
88
+ bboxes: list of [x1, y1, x2, y2] boxes; if any box is smaller than face_size_restriction, an empty list is returned
88
+
89
+ Returns:
90
+ List of PIL Images, each 512x512, containing extracted faces
91
+ """
92
+ # Ensure img is a PIL Image
93
+ try:
94
+ if not isinstance(img, Image.Image) and not isinstance(img, torch.Tensor):
95
+ img = Image.open(BytesIO(img))
96
+
97
+ # bboxes = json_data['bboxes']
98
+ # crop = json_data['crop']
99
+ # print("len of bboxes:", len(bboxes))
100
+ # Recalculate bounding boxes based on crop info
101
+ # new_bboxes = [recalculate_bbox(bbox, crop) for bbox in bboxes]
102
+ new_bboxes = bboxes
103
+ # if any face is smaller than the size restriction, skip this image and return no crops
104
+ for bbox in new_bboxes:
105
+ x1, y1, x2, y2 = bbox
106
+ if x2 - x1 < face_size_restriction or y2 - y1 < face_size_restriction:
107
+ return []
108
+ # print("len of new_bboxes:", len(new_bboxes))
109
+ faces = []
110
+ for bbox in new_bboxes:
111
+ # print("processing bbox")
112
+ # Convert coordinates to integers
113
+ x1, y1, x2, y2 = map(int, bbox)
114
+
115
+ # Calculate width and height
116
+ width = x2 - x1
117
+ height = y2 - y1
118
+
119
+ # Make the bounding box square by expanding the shorter dimension
120
+ if width > height:
121
+ # Height is shorter, expand it
122
+ diff = width - height
123
+ y1 -= diff // 2
124
+ y2 += diff - (diff // 2) # Handle odd differences
125
+ elif height > width:
126
+ # Width is shorter, expand it
127
+ diff = height - width
128
+ x1 -= diff // 2
129
+ x2 += diff - (diff // 2) # Handle odd differences
130
+
131
+ # Ensure coordinates are within image boundaries
132
+ img_width, img_height = img.size
133
+ x1 = max(0, x1)
134
+ y1 = max(0, y1)
135
+ x2 = min(img_width, x2)
136
+ y2 = min(img_height, y2)
137
+
138
+ # Extract face region
139
+ face_region = img.crop((x1, y1, x2, y2))
140
+
141
+ # Resize to 512x512
142
+ face_region = face_region.resize((512, 512), Image.LANCZOS)
143
+
144
+ faces.append(face_region)
145
+ # print("len of faces:", len(faces))
146
+ return faces
147
+ except Exception as e:
148
+ print(f"Error processing image: {e}")
149
+ return []
150
+
151
+ def __call__(self, img):
152
+ # normalize the input so we have both a PIL image and a numpy array view
153
+ if isinstance(img, torch.Tensor):
154
+ img_np = img.cpu().numpy()
155
+ img_pil = Image.fromarray(img_np)
156
+ elif isinstance(img, Image.Image):
157
+ img_pil = img
158
+ img_np = np.array(img)
159
+ elif isinstance(img, np.ndarray):
160
+ img_np = img
161
+ img_pil = Image.fromarray(img)
162
+
163
+ else:
164
+ raise ValueError("Unsupported image format. Please provide a PIL Image or numpy array.")
165
+ # Detect faces in the image
166
+ faces = self.model.get(img_np)
167
+ # use one
168
+ if len(faces) > 0:
169
+ bboxes = []
170
+ face = faces[0]
171
+ bbox = face.bbox.astype(int)
172
+ bboxes.append(bbox)
173
+ return self.extract_moref(img_pil, bboxes)[0]
174
+ else:
175
+ print("Warning: No faces detected in the image.")
176
+ return img_pil
177
+
178
+
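+ # Example (a minimal sketch, assuming the antelopev2 weights are available under model_path):
+ #   extractor = FaceExtractor(model_path="./")
+ #   face_crop = extractor(Image.open("person.jpg"))  # 512x512 crop of the first detected face, or the input image if none is found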
179
+ class WithAnyonePipeline:
180
+ def __init__(
181
+ self,
182
+ model_type: str,
183
+ ipa_path: str,
184
+ device: torch.device,
185
+ offload: bool = False,
186
+ only_lora: bool = False,
187
+ no_lora: bool = False,
188
+ lora_rank: int = 16,
189
+ face_extractor = None,
190
+ additional_lora_ckpt: str = None,
191
+ lora_weight: float = 1.0,
192
+ clip_path: str = "openai/clip-vit-large-patch14",
193
+ t5_path: str = "xlabs-ai/xflux_text_encoders",
194
+ flux_path: str = "black-forest-labs/FLUX.1-dev",
195
+ siglip_path: str = "google/siglip-base-patch16-256-i18n",
196
+
197
+ ):
198
+ self.device = device
199
+ self.offload = offload
200
+ self.model_type = model_type
201
+
202
+ self.clip = load_clip(clip_path, self.device)
203
+ self.t5 = load_t5(t5_path, self.device, max_length=512)
204
+ self.ae = load_ae(flux_path, model_type, device="cpu" if offload else self.device)
205
+ self.use_fp8 = "fp8" in model_type
206
+
207
+
208
+ if additional_lora_ckpt is not None:
209
+ self.model = load_flow_model_diffusers(
210
+ model_type,
211
+ flux_path,
212
+ ipa_path,
213
+ device="cpu" if offload else self.device,
214
+ lora_rank=lora_rank,
215
+ use_fp8=self.use_fp8,
216
+ additional_lora_ckpt=additional_lora_ckpt,
217
+ lora_weight=lora_weight,
218
+
219
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
220
+ else:
221
+ self.model = load_flow_model_no_lora(
222
+ model_type,
223
+ flux_path,
224
+ ipa_path,
225
+ device="cpu" if offload else self.device,
226
+ use_fp8=self.use_fp8
227
+ )
228
+
229
+ if face_extractor is not None:
230
+ self.face_extractor = face_extractor
231
+ else:
232
+ self.face_extractor = FaceExtractor()
233
+
234
+ self.siglip = SiglipEmbedding(siglip_path=siglip_path)
235
+
236
+
237
+ def load_ckpt(self, ckpt_path):
238
+ if ckpt_path is not None:
239
+ from safetensors.torch import load_file as load_sft
240
+ print("Loading checkpoint to replace old keys")
241
+ # load_sft doesn't support torch.device
242
+ if ckpt_path.endswith('safetensors'):
243
+ sd = load_sft(ckpt_path, device='cpu')
244
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
245
+ else:
246
+ dit_state = torch.load(ckpt_path, map_location='cpu')
247
+ sd = {}
248
+ for k in dit_state.keys():
249
+ sd[k.replace('module.','')] = dit_state[k]
250
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
251
+ self.model.to(str(self.device))
252
+ print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}")
253
+
254
+
255
+
256
+ def __call__(
257
+ self,
258
+ prompt: str,
259
+ width: int = 512,
260
+ height: int = 512,
261
+ guidance: float = 4,
262
+ num_steps: int = 50,
263
+ seed: int = 123456789,
264
+ **kwargs
265
+ ):
266
+ width = 16 * (width // 16)
267
+ height = 16 * (height // 16)
268
+
269
+ device_type = self.device if isinstance(self.device, str) else self.device.type
270
+ if device_type == "mps":
271
+ device_type = "cpu" # for support macos mps
272
+ with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16):
273
+ return self.forward(
274
+ prompt,
275
+ width,
276
+ height,
277
+ guidance,
278
+ num_steps,
279
+ seed,
280
+ **kwargs
281
+ )
282
+
283
+
284
+
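+ # Example (a minimal sketch; paths, reference images and bboxes are placeholders):
+ #   pipe = WithAnyonePipeline("flux-dev", "WithAnyone/WithAnyone", device=torch.device("cuda"))
+ #   image = pipe("two people at a cafe", width=768, height=768,
+ #                ref_imgs=[ref1, ref2], arcface_embeddings=arc, bboxes=[[[64, 64, 256, 256], [320, 64, 512, 256]]])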
285
+ @torch.inference_mode()
286
+ def forward(
287
+ self,
288
+ prompt: str,
289
+ width: int,
290
+ height: int,
291
+ guidance: float,
292
+ num_steps: int,
293
+ seed: int,
294
+ ref_imgs: list[Image.Image] | None = None,
295
+ arcface_embeddings: torch.Tensor | None = None,  # (num_ref, 512) stacked ArcFace embeddings
296
+ bboxes = None,
297
+ id_weight: float = 1.0,
298
+ siglip_weight: float = 1.0,
299
+ ):
300
+ x = get_noise(
301
+ 1, height, width, device=self.device,
302
+ dtype=torch.bfloat16, seed=seed
303
+ )
304
+ timesteps = get_schedule(
305
+ num_steps,
306
+ (width // 8) * (height // 8) // (16 * 16),
307
+ shift=True,
308
+ )
309
+ if self.offload:
310
+ self.ae.encoder = self.ae.encoder.to(self.device)
311
+
312
+ if ref_imgs is None:
313
+ siglip_embeddings = None
314
+ else:
315
+ siglip_embeddings = self.siglip(ref_imgs).to(self.device, torch.bfloat16).permute(1,0,2,3)
316
+ # num_ref, (1), n, d
317
+
318
+
319
+ if arcface_embeddings is not None:
320
+ arcface_embeddings = arcface_embeddings.unsqueeze(1)
321
+ # num_ref, 1, 512
322
+ arcface_embeddings = arcface_embeddings.to(self.device, torch.bfloat16)
323
+
324
+
325
+ if self.offload:
326
+ self.offload_model_to_cpu(self.ae.encoder)
327
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
328
+
329
+
330
+ inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
331
+
332
+ if self.offload:
333
+ self.offload_model_to_cpu(self.t5, self.clip)
334
+ self.model = self.model.to(self.device)
335
+
336
+
337
+
338
+ img = inp_cond["img"]
339
+ img_length = img.shape[1]
340
+ ##### create mask for siglip and arcface #####
341
+ if bboxes is not None:
342
+ arc_mask = create_person_cross_attention_mask_varlen(
343
+ batch_size=img.shape[0],
344
+ # num_heads=self.params.num_heads,
345
+ # txt_len=text_length,
346
+ img_len=img_length,
347
+ id_len=8,
348
+ bbox_lists=bboxes,
349
+ max_num_ids=len(bboxes[0]),
350
+ original_width=width,
351
+ original_height= height,
352
+ ).to(img.device)
353
+ siglip_mask = create_person_cross_attention_mask_varlen(
354
+ batch_size=img.shape[0],
355
+ # num_heads=self.params.num_heads,
356
+ # txt_len=text_length,
357
+ img_len=img_length,
358
+ id_len=256+8,
359
+ bbox_lists=bboxes,
360
+ max_num_ids=len(bboxes[0]),
361
+ original_width=width,
362
+ original_height= height,
363
+ ).to(img.device)
364
+
365
+
366
+
367
+
368
+
369
+ results = denoise(
370
+ self.model,
371
+ **inp_cond,
372
+ timesteps=timesteps,
373
+ guidance=guidance,
374
+ arcface_embeddings=arcface_embeddings,
375
+ siglip_embeddings=siglip_embeddings,
376
+ bboxes=bboxes,
377
+ id_weight=id_weight,
378
+ siglip_weight=siglip_weight,
379
+ img_height=height,
380
+ img_width=width,
381
+ arc_mask=arc_mask if bboxes is not None else None,
382
+ siglip_mask=siglip_mask if bboxes is not None else None,
383
+ )
384
+
385
+ x = results
386
+
387
+
388
+ if self.offload:
389
+ self.offload_model_to_cpu(self.model)
390
+ self.ae.decoder.to(x.device)
391
+ x = unpack(x.float(), height, width)
392
+ x = self.ae.decode(x)
393
+ self.offload_model_to_cpu(self.ae.decoder)
394
+
395
+ x1 = x.clamp(-1, 1)
396
+ x1 = rearrange(x1[-1], "c h w -> h w c")
397
+ output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
398
+
399
+
400
+ return output_img
401
+
402
+ def offload_model_to_cpu(self, *models):
403
+ if not self.offload: return
404
+ for model in models:
405
+ model.cpu()
406
+ torch.cuda.empty_cache()
withanyone/flux/sampling.py ADDED
@@ -0,0 +1,171 @@
1
+
2
+
3
+ import math
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from einops import rearrange, repeat
8
+ from torch import Tensor
9
+ from tqdm import tqdm
10
+
11
+ from .model import Flux
12
+ from .modules.conditioner import HFEmbedder
13
+
14
+
15
+ def get_noise(
16
+ num_samples: int,
17
+ height: int,
18
+ width: int,
19
+ device: torch.device,
20
+ dtype: torch.dtype,
21
+ seed: int,
22
+ ):
23
+ return torch.randn(
24
+ num_samples,
25
+ 16,
26
+ # allow for packing
27
+ 2 * math.ceil(height / 16),
28
+ 2 * math.ceil(width / 16),
29
+ device=device,
30
+ dtype=dtype,
31
+ generator=torch.Generator(device=device).manual_seed(seed),
32
+ )
33
+
34
+
35
+ def prepare(
36
+ t5: HFEmbedder,
37
+ clip: HFEmbedder,
38
+ img: Tensor,
39
+ prompt: str | list[str],
40
+ ) -> dict[str, Tensor]:
41
+ bs, c, h, w = img.shape
42
+ if bs == 1 and not isinstance(prompt, str):
43
+ bs = len(prompt)
44
+
45
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
46
+ if img.shape[0] == 1 and bs > 1:
47
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
48
+
49
+ img_ids = torch.zeros(h // 2, w // 2, 3)
50
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
51
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
52
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
53
+
54
+
55
+ if isinstance(prompt, str):
56
+ prompt = [prompt]
57
+ txt = t5(prompt)
58
+ if txt.shape[0] == 1 and bs > 1:
59
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
60
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
61
+
62
+ vec = clip(prompt)
63
+ if vec.shape[0] == 1 and bs > 1:
64
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
65
+
66
+
67
+
68
+ return {
69
+ "img": img,
70
+ "img_ids": img_ids.to(img.device),
71
+ "txt": txt.to(img.device),
72
+ "txt_ids": txt_ids.to(img.device),
73
+ "vec": vec.to(img.device),
74
+ }
75
+
76
+
77
+
78
+
79
+ def time_shift(mu: float, sigma: float, t: Tensor):
80
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
81
+
82
+
83
+ def get_lin_function(
84
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
85
+ ):
86
+ m = (y2 - y1) / (x2 - x1)
87
+ b = y1 - m * x1
88
+ return lambda x: m * x + b
89
+
90
+
91
+ def get_schedule(
92
+ num_steps: int,
93
+ image_seq_len: int,
94
+ base_shift: float = 0.5,
95
+ max_shift: float = 1.15,
96
+ shift: bool = True,
97
+ ) -> list[float]:
98
+ # extra step for zero
99
+ timesteps = torch.linspace(1, 0, num_steps + 1)
100
+
101
+ # shifting the schedule to favor high timesteps for higher signal images
102
+ if shift:
103
+ # estimate mu by linear interpolation between two reference points
104
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
105
+ timesteps = time_shift(mu, 1.0, timesteps)
106
+
107
+ return timesteps.tolist()
108
+
109
+
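+ # Example (sketch): for a 512x512 image the pipeline calls
+ #   get_schedule(num_steps, (512 // 8) * (512 // 8) // (16 * 16), shift=True)
+ # which returns num_steps + 1 timesteps descending from 1.0 to 0.0, shifted toward high noise for longer sequences.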
110
+ def denoise(
111
+ model: Flux,
112
+ # model input
113
+ img: Tensor,
114
+ img_ids: Tensor,
115
+ txt: Tensor,
116
+ txt_ids: Tensor,
117
+ vec: Tensor,
118
+
119
+ timesteps: list[float],
120
+ guidance: float = 4.0,
121
+
122
+ arcface_embeddings = None,
123
+ siglip_embeddings = None,
124
+ bboxes: Tensor = None,
125
+
126
+ id_weight: float = 1.0, # weight for identity embeddings
127
+ siglip_weight: float = 1.0, # weight for siglip embeddings
128
+ img_height: int = 512,
129
+ img_width: int = 512,
130
+ arc_mask = None,
131
+ siglip_mask = None,
132
+ ):
133
+ i = 0
134
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
135
+ for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
136
+
137
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
138
+
139
+ pred = model(
140
+ img=img,
141
+ img_ids=img_ids,
142
+ siglip_embeddings=siglip_embeddings,
143
+ txt=txt,
144
+ txt_ids=txt_ids,
145
+ y=vec,
146
+ timesteps=t_vec,
147
+ guidance=guidance_vec,
148
+ arcface_embeddings=arcface_embeddings,
149
+ bbox_lists=bboxes,
150
+ id_weight=id_weight,
151
+ siglip_weight=siglip_weight,
152
+ img_height=img_height,
153
+ img_width=img_width,
154
+ arc_mask=arc_mask,
155
+ siglip_mask=siglip_mask,
156
+ )
157
+ img = img + (t_prev - t_curr) * pred
158
+ i += 1
159
+
160
+ return img
161
+
162
+
163
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
164
+ return rearrange(
165
+ x,
166
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
167
+ h=math.ceil(height / 16),
168
+ w=math.ceil(width / 16),
169
+ ph=2,
170
+ pw=2,
171
+ )
withanyone/flux/util.py ADDED
@@ -0,0 +1,518 @@
1
+
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+ import json
8
+ import numpy as np
9
+ from huggingface_hub import hf_hub_download
10
+ from safetensors import safe_open
11
+ from safetensors.torch import load_file as load_sft
12
+
13
+ from withanyone.flux.model import Flux, FluxParams
14
+ from .modules.autoencoder import AutoEncoder, AutoEncoderParams
15
+ from .modules.conditioner import HFEmbedder
16
+
17
+ import re
18
+ from withanyone.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
19
+
20
+
21
+
22
+
23
+ def c_crop(image):
24
+ width, height = image.size
25
+ new_size = min(width, height)
26
+ left = (width - new_size) / 2
27
+ top = (height - new_size) / 2
28
+ right = (width + new_size) / 2
29
+ bottom = (height + new_size) / 2
30
+ return image.crop((left, top, right, bottom))
31
+
32
+ def pad64(x):
33
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
34
+
35
+ def HWC3(x):
36
+ assert x.dtype == np.uint8
37
+ if x.ndim == 2:
38
+ x = x[:, :, None]
39
+ assert x.ndim == 3
40
+ H, W, C = x.shape
41
+ assert C == 1 or C == 3 or C == 4
42
+ if C == 3:
43
+ return x
44
+ if C == 1:
45
+ return np.concatenate([x, x, x], axis=2)
46
+ if C == 4:
47
+ color = x[:, :, 0:3].astype(np.float32)
48
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
49
+ y = color * alpha + 255.0 * (1.0 - alpha)
50
+ y = y.clip(0, 255).astype(np.uint8)
51
+ return y
52
+
53
+ @dataclass
54
+ class ModelSpec:
55
+ params: FluxParams
56
+ ae_params: AutoEncoderParams
57
+ repo_id: str | None
58
+ repo_flow: str | None
59
+ repo_ae: str | None
60
+ repo_id_ae: str | None
61
+
62
+
63
+ configs = {
64
+ "flux-dev": ModelSpec(
65
+ repo_id="black-forest-labs/FLUX.1-dev",
66
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
67
+ repo_flow="flux1-dev.safetensors",
68
+ repo_ae="ae.safetensors",
69
+ params=FluxParams(
70
+ in_channels=64,
71
+ vec_in_dim=768,
72
+ context_in_dim=4096,
73
+ hidden_size=3072,
74
+ mlp_ratio=4.0,
75
+ num_heads=24,
76
+ depth=19,
77
+ depth_single_blocks=38,
78
+ axes_dim=[16, 56, 56],
79
+ theta=10_000,
80
+ qkv_bias=True,
81
+ guidance_embed=True,
82
+ ),
83
+ ae_params=AutoEncoderParams(
84
+ resolution=256,
85
+ in_channels=3,
86
+ ch=128,
87
+ out_ch=3,
88
+ ch_mult=[1, 2, 4, 4],
89
+ num_res_blocks=2,
90
+ z_channels=16,
91
+ scale_factor=0.3611,
92
+ shift_factor=0.1159,
93
+ ),
94
+ ),
95
+ "flux-dev-fp8": ModelSpec(
96
+ repo_id="black-forest-labs/FLUX.1-dev",
97
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
98
+ repo_flow="flux1-dev.safetensors",
99
+ repo_ae="ae.safetensors",
100
+ params=FluxParams(
101
+ in_channels=64,
102
+ vec_in_dim=768,
103
+ context_in_dim=4096,
104
+ hidden_size=3072,
105
+ mlp_ratio=4.0,
106
+ num_heads=24,
107
+ depth=19,
108
+ depth_single_blocks=38,
109
+ axes_dim=[16, 56, 56],
110
+ theta=10_000,
111
+ qkv_bias=True,
112
+ guidance_embed=True,
113
+ ),
114
+ ae_params=AutoEncoderParams(
115
+ resolution=256,
116
+ in_channels=3,
117
+ ch=128,
118
+ out_ch=3,
119
+ ch_mult=[1, 2, 4, 4],
120
+ num_res_blocks=2,
121
+ z_channels=16,
122
+ scale_factor=0.3611,
123
+ shift_factor=0.1159,
124
+ ),
125
+ ),
126
+ "flux-krea": ModelSpec(
127
+ repo_id="black-forest-labs/FLUX.1-Krea-dev",
128
+ repo_id_ae="black-forest-labs/FLUX.1-Krea-dev",
129
+ repo_flow="flux1-krea-dev.safetensors",
130
+ repo_ae="ae.safetensors",
131
+ params=FluxParams(
132
+ in_channels=64,
133
+ vec_in_dim=768,
134
+ context_in_dim=4096,
135
+ hidden_size=3072,
136
+ mlp_ratio=4.0,
137
+ num_heads=24,
138
+ depth=19,
139
+ depth_single_blocks=38,
140
+ axes_dim=[16, 56, 56],
141
+ theta=10_000,
142
+ qkv_bias=True,
143
+ guidance_embed=True,
144
+ ),
145
+ ae_params=AutoEncoderParams(
146
+ resolution=256,
147
+ in_channels=3,
148
+ ch=128,
149
+ out_ch=3,
150
+ ch_mult=[1, 2, 4, 4],
151
+ num_res_blocks=2,
152
+ z_channels=16,
153
+ scale_factor=0.3611,
154
+ shift_factor=0.1159,
155
+ ),
156
+ ),
157
+ "flux-schnell": ModelSpec(
158
+ repo_id="black-forest-labs/FLUX.1-schnell",
159
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
160
+ repo_flow="flux1-schnell.safetensors",
161
+ repo_ae="ae.safetensors",
162
+ params=FluxParams(
163
+ in_channels=64,
164
+ vec_in_dim=768,
165
+ context_in_dim=4096,
166
+ hidden_size=3072,
167
+ mlp_ratio=4.0,
168
+ num_heads=24,
169
+ depth=19,
170
+ depth_single_blocks=38,
171
+ axes_dim=[16, 56, 56],
172
+ theta=10_000,
173
+ qkv_bias=True,
174
+ guidance_embed=False,
175
+ ),
176
+ ae_params=AutoEncoderParams(
177
+ resolution=256,
178
+ in_channels=3,
179
+ ch=128,
180
+ out_ch=3,
181
+ ch_mult=[1, 2, 4, 4],
182
+ num_res_blocks=2,
183
+ z_channels=16,
184
+ scale_factor=0.3611,
185
+ shift_factor=0.1159,
186
+ ),
187
+ ),
188
+ }
189
+
190
+
191
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
192
+ if len(missing) > 0 and len(unexpected) > 0:
193
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
194
+ print("\n" + "-" * 79 + "\n")
195
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
196
+ elif len(missing) > 0:
197
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
198
+ elif len(unexpected) > 0:
199
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
200
+
201
+ def load_from_repo_id(repo_id, checkpoint_name):
202
+ ckpt_path = hf_hub_download(repo_id, checkpoint_name)
203
+ sd = load_sft(ckpt_path, device='cpu')
204
+ return sd
205
+
206
+
207
+
208
+
209
+
210
+ def load_flow_model_no_lora(
211
+ name: str,
212
+ path: str,
213
+ ipa_path: str ,
214
+ device: str | torch.device = "cuda",
215
+ hf_download: bool = True,
216
+ lora_rank: int = 16,
217
+ use_fp8: bool = False
218
+ ):
219
+ # Loading Flux
220
+ print("Init model")
221
+ ckpt_path = path
222
+ if ckpt_path == "black-forest-labs/FLUX.1-dev" or (
223
+ ckpt_path is None
224
+ and configs[name].repo_id is not None
225
+ and configs[name].repo_flow is not None
226
+ and hf_download
227
+ ):
228
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
229
+ print("Downloading checkpoint from HF:", ckpt_path)
230
+ else:
231
+ ckpt_path = os.path.join(path, "flux1-dev.safetensors") if path is not None else None
232
+
233
+
234
+
235
+
236
+ ipa_ckpt_path = ipa_path
237
+
238
+
239
+ with torch.device("meta" if ckpt_path is not None else device):
240
+ model = Flux(configs[name].params)
241
+
242
+
243
+ # model = set_lora(model, lora_rank, device="meta" if ipa_ckpt_path is not None else device)
244
+
245
+ if ckpt_path is not None:
246
+ if ipa_ckpt_path == 'WithAnyone/WithAnyone':
247
+ ipa_ckpt_path = hf_hub_download("WithAnyone/WithAnyone", "withanyone.safetensors")
248
+
249
+ lora_sd = load_sft(ipa_ckpt_path, device=str(device)) if ipa_ckpt_path.endswith("safetensors")\
250
+ else torch.load(ipa_ckpt_path, map_location='cpu')
251
+
252
+
253
+ print("Loading main checkpoint")
254
+ # load_sft doesn't support torch.device
255
+
256
+ if ckpt_path.endswith('safetensors'):
257
+ if use_fp8:
258
+ print(
259
+ "####\n"
260
+ "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n"
261
+ "we convert the fp8 checkpoint on flight from bf16 checkpoint\n"
262
+ "If your storage is constrained"
263
+ "you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n"
264
+ )
265
+ sd = load_sft(ckpt_path, device="cpu")
266
+ sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
267
+ else:
268
+ sd = load_sft(ckpt_path, device=str(device))
269
+
270
+
271
+
272
+ # # Then proceed with the update
273
+ sd.update(lora_sd)
274
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
275
+ else:
276
+ dit_state = torch.load(ckpt_path, map_location='cpu')
277
+ sd = {}
278
+ for k in dit_state.keys():
279
+ sd[k.replace('module.','')] = dit_state[k]
280
+ sd.update(lora_sd)
281
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
282
+ model.to(str(device))
283
+ print_load_warning(missing, unexpected)
284
+ return model
285
+
286
+
287
+ def merge_to_flux_model(
288
+ loading_device, working_device, flux_state_dict, model, ratio, merge_dtype, save_dtype, mem_eff_load_save=False
289
+ ):
290
+
291
+ lora_name_to_module_key = {}
292
+ keys = list(flux_state_dict.keys())
293
+ for key in keys:
294
+ if key.endswith(".weight"):
295
+ module_name = ".".join(key.split(".")[:-1])
296
+ lora_name = "lora_unet" + "_" + module_name.replace(".", "_")
297
+ lora_name_to_module_key[lora_name] = key
298
+
299
+
300
+ print(f"loading: {model}")
301
+ lora_sd = load_sft(model, device=loading_device) if model.endswith("safetensors")\
302
+ else torch.load(model, map_location='cpu')
303
+
304
+ print(f"merging...")
305
+ for key in list(lora_sd.keys()):
306
+ if "lora_down" in key:
307
+ lora_name = key[: key.rfind(".lora_down")]
308
+ up_key = key.replace("lora_down", "lora_up")
309
+ alpha_key = key[: key.index("lora_down")] + "alpha"
310
+
311
+ if lora_name not in lora_name_to_module_key:
312
+ print(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.")
313
+ continue
314
+
315
+ down_weight = lora_sd.pop(key)
316
+ up_weight = lora_sd.pop(up_key)
317
+
318
+ dim = down_weight.size()[0]
319
+ alpha = lora_sd.pop(alpha_key, dim)
320
+ scale = alpha / dim
321
+
322
+ # W <- W + U * D
323
+ module_weight_key = lora_name_to_module_key[lora_name]
324
+ if module_weight_key not in flux_state_dict:
325
+ # weight = flux_file.get_tensor(module_weight_key)
326
+ print(f"no module found for LoRA weight: {module_weight_key}")
327
+ else:
328
+ weight = flux_state_dict[module_weight_key]
329
+
330
+ weight = weight.to(working_device, merge_dtype)
331
+ up_weight = up_weight.to(working_device, merge_dtype)
332
+ down_weight = down_weight.to(working_device, merge_dtype)
333
+
334
+
335
+ if len(weight.size()) == 2:
336
+ # linear
337
+ weight = weight + ratio * (up_weight @ down_weight) * scale
338
+ elif down_weight.size()[2:4] == (1, 1):
339
+ # conv2d 1x1
340
+ weight = (
341
+ weight
342
+ + ratio
343
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
344
+ * scale
345
+ )
346
+ else:
347
+ # conv2d 3x3
348
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
349
+ weight = weight + ratio * conved * scale
350
+
351
+ flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
352
+ del up_weight
353
+ del down_weight
354
+ del weight
355
+
356
+ if len(lora_sd) > 0:
357
+ print(f"Unused keys in LoRA model: {list(lora_sd.keys())}")
358
+
359
+ return flux_state_dict
360
+
361
+
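+ # For a plain linear layer, the merge above amounts to (with scale = alpha / rank):
+ #   W_merged = W + ratio * (lora_up @ lora_down) * scale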
362
+
363
+ def load_flow_model_diffusers(
364
+ name: str,
365
+ path: str,
366
+ ipa_path: str ,
367
+ device: str | torch.device = "cuda",
368
+ hf_download: bool = True,
369
+ lora_rank: int = 16,
370
+ use_fp8: bool = False,
371
+ additional_lora_ckpt: str | None = None,
372
+ lora_weight: float = 1.0,
373
+ ):
374
+ # Loading Flux
375
+ print("Init model")
376
+
377
+ ckpt_path = os.path.join(path, "flux1-dev.safetensors") if path is not None else None
378
+ print("Loading checkpoint from", ckpt_path)
379
+ if (
380
+ ckpt_path is None
381
+ and configs[name].repo_id is not None
382
+ and configs[name].repo_flow is not None
383
+ and hf_download
384
+ ):
385
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
386
+
387
+
388
+ ipa_ckpt_path = ipa_path
389
+
390
+
391
+
392
+ with torch.device("meta" if ckpt_path is not None else device):
393
+ model = Flux(configs[name].params)
394
+
395
+ # if additional_lora_ckpt is not None:
396
+ # model = set_lora(model, lora_rank, device="meta" if ipa_ckpt_path is not None else device)
397
+ assert additional_lora_ckpt is not None, "additional_lora_ckpt should have been provided. this must be a bug"
398
+
399
+ if ckpt_path is not None:
400
+ if ipa_ckpt_path == 'WithAnyone/WithAnyone':
401
+ ipa_ckpt_path = hf_hub_download("WithAnyone/WithAnyone", "withanyone.safetensors")
402
+ else:
403
+ lora_sd = load_sft(ipa_ckpt_path, device=str(device)) if ipa_ckpt_path.endswith("safetensors")\
404
+ else torch.load(ipa_ckpt_path, map_location='cpu')
405
+
406
+ extra_lora_path = additional_lora_ckpt
407
+
408
+ print("Loading main checkpoint")
409
+ # load_sft doesn't support torch.device
410
+
411
+ if ckpt_path.endswith('safetensors'):
412
+ if use_fp8:
413
+ print(
414
+ "####\n"
415
+ "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n"
416
+ "we convert the fp8 checkpoint on flight from bf16 checkpoint\n"
417
+ "If your storage is constrained"
418
+ "you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n"
419
+ )
420
+ sd = load_sft(ckpt_path, device="cpu")
421
+ sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
422
+ else:
423
+ sd = load_sft(ckpt_path, device=str(device))
424
+
425
+ if extra_lora_path is not None:
426
+ print("Merging extra lora to main checkpoint")
427
+ lora_ckpt_path = extra_lora_path
428
+ sd = merge_to_flux_model("cpu", device, sd, lora_ckpt_path, lora_weight, torch.float8_e4m3fn if use_fp8 else torch.bfloat16, torch.float8_e4m3fn if use_fp8 else torch.bfloat16)
429
+ # # Then proceed with the update
430
+ sd.update(ipa_lora_sd)
431
+
432
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
433
+ model.to(str(device))
434
+ else:
435
+ dit_state = torch.load(ckpt_path, map_location='cpu')
436
+ sd = {}
437
+ for k in dit_state.keys():
438
+ sd[k.replace('module.','')] = dit_state[k]
439
+
440
+ if extra_lora_path is not None:
441
+ print("Merging extra lora to main checkpoint")
442
+ lora_ckpt_path = extra_lora_path
443
+ sd = merge_to_flux_model("cpu", device, sd, lora_ckpt_path, 1.0, torch.float8_e4m3fn if use_fp8 else torch.bfloat16, torch.float8_e4m3fn if use_fp8 else torch.bfloat16)
444
+ sd.update(ipa_lora_sd)
445
+
446
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
447
+ model.to(str(device))
448
+ print_load_warning(missing, unexpected)
449
+
450
+ return model
451
+
452
+
453
+ def set_lora(
454
+ model: Flux,
455
+ lora_rank: int,
456
+ double_blocks_indices: list[int] | None = None,
457
+ single_blocks_indices: list[int] | None = None,
458
+ device: str | torch.device = "cpu",
459
+ ) -> Flux:
460
+ double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
461
+ single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
462
+ else single_blocks_indices
463
+
464
+ lora_attn_procs = {}
465
+ with torch.device(device):
466
+ for name, attn_processor in model.attn_processors.items():
467
+ match = re.search(r'\.(\d+)\.', name)
468
+ if match:
469
+ layer_index = int(match.group(1))
470
+
471
+ if name.startswith("double_blocks") and layer_index in double_blocks_indices:
472
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
473
+ elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
474
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
475
+ else:
476
+ lora_attn_procs[name] = attn_processor
477
+ model.set_attn_processor(lora_attn_procs)
478
+ return model
479
+
480
+
481
+
482
+
483
+ def load_t5(t5_path, device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
484
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
485
+ version = t5_path
486
+ return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
487
+
488
+ def load_clip(clip_path, device: str | torch.device = "cuda") -> HFEmbedder:
489
+ version = clip_path
490
+
491
+ return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
492
+
493
+
494
+ def load_ae(flux_path, name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
495
+
496
+
497
+ if flux_path == "black-forest-labs/FLUX.1-dev" or flux_path == "black-forest-labs/FLUX.1-schnell" or flux_path == "black-forest-labs/FLUX.1-Krea-dev" or flux_path == "black-forest-labs/FLUX.1-Kontext-dev":
498
+ ckpt_path = hf_hub_download("black-forest-labs/FLUX.1-dev", "ae.safetensors")
499
+ else:
500
+ ckpt_path = os.path.join(flux_path, "ae.safetensors")
501
+ if not os.path.exists(ckpt_path):
502
+ # try diffusion_pytorch_model.safetensors
503
+ ckpt_path = os.path.join(flux_path, "vae", "ae.safetensors")
504
+ if not os.path.exists(ckpt_path):
505
+ raise FileNotFoundError(f"Cannot find ae checkpoint in {flux_path}/ae.safetensors or {flux_path}/vae/ae.safetensors")
506
+
507
+
508
+ # Loading the autoencoder
509
+ print("Init AE")
510
+ with torch.device("meta" if ckpt_path is not None else device):
511
+ ae = AutoEncoder(configs[name].ae_params)
512
+
513
+ # if ckpt_path is not None:
514
+ assert ckpt_path is not None, "ckpt_path should have been provided. this must be a bug"
515
+ sd = load_sft(ckpt_path, device=str(device))
516
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
517
+ print_load_warning(missing, unexpected)
518
+ return ae
withanyone/utils/convert_yaml_to_args_file.py ADDED
@@ -0,0 +1,22 @@
1
+
2
+
3
+ import argparse
4
+ import yaml
5
+
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument("--yaml", type=str, required=True)
8
+ parser.add_argument("--arg", type=str, required=True)
9
+ args = parser.parse_args()
10
+
11
+
12
+ with open(args.yaml, "r") as f:
13
+ data = yaml.safe_load(f)
14
+
15
+ with open(args.arg, "w") as f:
16
+ for k, v in data.items():
17
+ if isinstance(v, list):
18
+ v = list(map(str, v))
19
+ v = " ".join(v)
20
+ if v is None:
21
+ continue
22
+ print(f"--{k} {v}", end=" ", file=f)