sjzhao committed
Commit bd63939
1 parent: 27c6484

update demo

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +4 -0
  3. .idea/.gitignore +8 -0
  4. .idea/SEED.iml +12 -0
  5. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  6. .idea/misc.xml +4 -0
  7. .idea/modules.xml +8 -0
  8. .idea/vcs.xml +6 -0
  9. .project-root +0 -0
  10. Dockerfile +3 -12
  11. License.txt +470 -0
  12. README-SEED-2.md +184 -0
  13. SEED-1.md +93 -0
  14. configs/llm/seed_llama_14b.yaml +5 -0
  15. configs/llm/seed_llama_14b_8bit.yaml +5 -0
  16. configs/llm/seed_llama_8b.yaml +5 -0
  17. configs/llm/seed_llama_8b_8bit.yaml +5 -0
  18. configs/tokenizer/seed_llama_tokenizer.yaml +4 -0
  19. configs/tokenizer/seed_llama_tokenizer_hf.yaml +6 -0
  20. configs/transform/clip_transform.yaml +4 -0
  21. gradio_demo/conversation.py +190 -0
  22. gradio_demo/seed_llama_flask.py +230 -0
  23. gradio_demo/seed_llama_gradio.py +497 -0
  24. gradio_demo/utils.py +82 -0
  25. images/cat.jpg +3 -0
  26. images/demo_example1.jpg +3 -0
  27. images/demo_example2.jpg +3 -0
  28. images/demo_example3.jpg +3 -0
  29. images/demo_example4.jpg +3 -0
  30. images/demo_example5.jpg +3 -0
  31. images/demo_example6.jpg +3 -0
  32. images/demo_example7.jpg +3 -0
  33. images/dogs_4.jpg +3 -0
  34. images/eagle.jpg +3 -0
  35. images/flower.png +3 -0
  36. images/spongebob.png +3 -0
  37. images/star.jpg +3 -0
  38. models/__init__.py +0 -0
  39. models/llama_xformer.py +906 -0
  40. models/model_tools.py +18 -0
  41. models/pipeline_stable_unclip_img2img.py +794 -0
  42. models/seed_llama_tokenizer.py +213 -0
  43. models/seed_qformer/blip2.py +186 -0
  44. models/seed_qformer/clip_vit.py +257 -0
  45. models/seed_qformer/eva_vit.py +486 -0
  46. models/seed_qformer/qformer_causual.py +1169 -0
  47. models/seed_qformer/qformer_quantizer.py +375 -0
  48. models/seed_qformer/utils.py +138 -0
  49. models/seed_qformer/vit.py +395 -0
  50. models/transforms.py +21 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ pretrained/*
+ !pretrained/.gitkeep
+ **/__pycache__/**
+ log/
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/SEED.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+ <component name="NewModuleRootManager">
+ <content url="file://$MODULE_DIR$" />
+ <orderEntry type="inheritedJdk" />
+ <orderEntry type="sourceFolder" forTests="false" />
+ </component>
+ <component name="PyDocumentationSettings">
+ <option name="format" value="GOOGLE" />
+ <option name="myDocStringFormat" value="Google" />
+ </component>
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+ <settings>
+ <option name="USE_PROJECT_PROFILE" value="false" />
+ <version value="1.0" />
+ </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="ProjectModuleManager">
+ <modules>
+ <module fileurl="file://$PROJECT_DIR$/.idea/SEED.iml" filepath="$PROJECT_DIR$/.idea/SEED.iml" />
+ </modules>
+ </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="VcsDirectoryMappings">
+ <mapping directory="" vcs="Git" />
+ </component>
+ </project>
.project-root ADDED
File without changes
Dockerfile CHANGED
@@ -4,8 +4,7 @@ FROM python:3.11
  # Set the working directory to /code
  WORKDIR /code

- RUN apt-get update && apt-get install -y screen git git-lfs
- RUN git lfs install
+ RUN apt-get update && apt-get install -y git git-lfs

  # Copy the current directory contents into the container at /code
  # COPY ./requirements.txt /code/requirements.txt
@@ -29,16 +28,8 @@ WORKDIR $HOME/app
  # Copy the current directory contents into the container at $HOME/app setting the owner to the user
  COPY --chown=user . $HOME/app

- RUN git clone https://github.com/AILab-CVC/SEED.git
-
- RUN mv SEED/* . && rm -rf SEED
+ RUN git lfs install

  RUN pip install -r requirements.txt

- # RUN git clone https://huggingface.co/AILab-CVC/SEED
-
- # RUN mv SEED/* pretrained/ && rm -rf SEED
-
- RUN chmod +x start.sh
-
- CMD ["./start.sh"]
+ CMD ["python", "start.py"]
License.txt ADDED
@@ -0,0 +1,470 @@
1
+ This license applies to the source codes that are open sourced in connection with the AI Lab research paper open-source project SEED.
2
+
3
+ Copyright (C) 2023 THL A29 Limited, a Tencent company.
4
+
5
+ Apache License
6
+ Version 2.0, January 2004
7
+ http://www.apache.org/licenses/
8
+
9
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10
+
11
+ 1. Definitions.
12
+
13
+ "License" shall mean the terms and conditions for use, reproduction,
14
+ and distribution as defined by Sections 1 through 9 of this document.
15
+
16
+ "Licensor" shall mean the copyright owner or entity authorized by
17
+ the copyright owner that is granting the License.
18
+
19
+ "Legal Entity" shall mean the union of the acting entity and all
20
+ other entities that control, are controlled by, or are under common
21
+ control with that entity. For the purposes of this definition,
22
+ "control" means (i) the power, direct or indirect, to cause the
23
+ direction or management of such entity, whether by contract or
24
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
25
+ outstanding shares, or (iii) beneficial ownership of such entity.
26
+
27
+ "You" (or "Your") shall mean an individual or Legal Entity
28
+ exercising permissions granted by this License.
29
+
30
+ "Source" form shall mean the preferred form for making modifications,
31
+ including but not limited to software source code, documentation
32
+ source, and configuration files.
33
+
34
+ "Object" form shall mean any form resulting from mechanical
35
+ transformation or translation of a Source form, including but
36
+ not limited to compiled object code, generated documentation,
37
+ and conversions to other media types.
38
+
39
+ "Work" shall mean the work of authorship, whether in Source or
40
+ Object form, made available under the License, as indicated by a
41
+ copyright notice that is included in or attached to the work
42
+ (an example is provided in the Appendix below).
43
+
44
+ "Derivative Works" shall mean any work, whether in Source or Object
45
+ form, that is based on (or derived from) the Work and for which the
46
+ editorial revisions, annotations, elaborations, or other modifications
47
+ represent, as a whole, an original work of authorship. For the purposes
48
+ of this License, Derivative Works shall not include works that remain
49
+ separable from, or merely link (or bind by name) to the interfaces of,
50
+ the Work and Derivative Works thereof.
51
+
52
+ "Contribution" shall mean any work of authorship, including
53
+ the original version of the Work and any modifications or additions
54
+ to that Work or Derivative Works thereof, that is intentionally
55
+ submitted to Licensor for inclusion in the Work by the copyright owner
56
+ or by an individual or Legal Entity authorized to submit on behalf of
57
+ the copyright owner. For the purposes of this definition, "submitted"
58
+ means any form of electronic, verbal, or written communication sent
59
+ to the Licensor or its representatives, including but not limited to
60
+ communication on electronic mailing lists, source code control systems,
61
+ and issue tracking systems that are managed by, or on behalf of, the
62
+ Licensor for the purpose of discussing and improving the Work, but
63
+ excluding communication that is conspicuously marked or otherwise
64
+ designated in writing by the copyright owner as "Not a Contribution."
65
+
66
+ "Contributor" shall mean Licensor and any individual or Legal Entity
67
+ on behalf of whom a Contribution has been received by Licensor and
68
+ subsequently incorporated within the Work.
69
+
70
+ 2. Grant of Copyright License. Subject to the terms and conditions of
71
+ this License, each Contributor hereby grants to You a perpetual,
72
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
73
+ copyright license to reproduce, prepare Derivative Works of,
74
+ publicly display, publicly perform, sublicense, and distribute the
75
+ Work and such Derivative Works in Source or Object form.
76
+
77
+ 3. Grant of Patent License. Subject to the terms and conditions of
78
+ this License, each Contributor hereby grants to You a perpetual,
79
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80
+ (except as stated in this section) patent license to make, have made,
81
+ use, offer to sell, sell, import, and otherwise transfer the Work,
82
+ where such license applies only to those patent claims licensable
83
+ by such Contributor that are necessarily infringed by their
84
+ Contribution(s) alone or by combination of their Contribution(s)
85
+ with the Work to which such Contribution(s) was submitted. If You
86
+ institute patent litigation against any entity (including a
87
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
88
+ or a Contribution incorporated within the Work constitutes direct
89
+ or contributory patent infringement, then any patent licenses
90
+ granted to You under this License for that Work shall terminate
91
+ as of the date such litigation is filed.
92
+
93
+ 4. Redistribution. You may reproduce and distribute copies of the
94
+ Work or Derivative Works thereof in any medium, with or without
95
+ modifications, and in Source or Object form, provided that You
96
+ meet the following conditions:
97
+
98
+ (a) You must give any other recipients of the Work or
99
+ Derivative Works a copy of this License; and
100
+
101
+ (b) You must cause any modified files to carry prominent notices
102
+ stating that You changed the files; and
103
+
104
+ (c) You must retain, in the Source form of any Derivative Works
105
+ that You distribute, all copyright, patent, trademark, and
106
+ attribution notices from the Source form of the Work,
107
+ excluding those notices that do not pertain to any part of
108
+ the Derivative Works; and
109
+
110
+ (d) If the Work includes a "NOTICE" text file as part of its
111
+ distribution, then any Derivative Works that You distribute must
112
+ include a readable copy of the attribution notices contained
113
+ within such NOTICE file, excluding those notices that do not
114
+ pertain to any part of the Derivative Works, in at least one
115
+ of the following places: within a NOTICE text file distributed
116
+ as part of the Derivative Works; within the Source form or
117
+ documentation, if provided along with the Derivative Works; or,
118
+ within a display generated by the Derivative Works, if and
119
+ wherever such third-party notices normally appear. The contents
120
+ of the NOTICE file are for informational purposes only and
121
+ do not modify the License. You may add Your own attribution
122
+ notices within Derivative Works that You distribute, alongside
123
+ or as an addendum to the NOTICE text from the Work, provided
124
+ that such additional attribution notices cannot be construed
125
+ as modifying the License.
126
+
127
+ You may add Your own copyright statement to Your modifications and
128
+ may provide additional or different license terms and conditions
129
+ for use, reproduction, or distribution of Your modifications, or
130
+ for any such Derivative Works as a whole, provided Your use,
131
+ reproduction, and distribution of the Work otherwise complies with
132
+ the conditions stated in this License.
133
+
134
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
135
+ any Contribution intentionally submitted for inclusion in the Work
136
+ by You to the Licensor shall be under the terms and conditions of
137
+ this License, without any additional terms or conditions.
138
+ Notwithstanding the above, nothing herein shall supersede or modify
139
+ the terms of any separate license agreement you may have executed
140
+ with Licensor regarding such Contributions.
141
+
142
+ 6. Trademarks. This License does not grant permission to use the trade
143
+ names, trademarks, service marks, or product names of the Licensor,
144
+ except as required for reasonable and customary use in describing the
145
+ origin of the Work and reproducing the content of the NOTICE file.
146
+
147
+ 7. Disclaimer of Warranty. Unless required by applicable law or
148
+ agreed to in writing, Licensor provides the Work (and each
149
+ Contributor provides its Contributions) on an "AS IS" BASIS,
150
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
151
+ implied, including, without limitation, any warranties or conditions
152
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
153
+ PARTICULAR PURPOSE. You are solely responsible for determining the
154
+ appropriateness of using or redistributing the Work and assume any
155
+ risks associated with Your exercise of permissions under this License.
156
+
157
+ 8. Limitation of Liability. In no event and under no legal theory,
158
+ whether in tort (including negligence), contract, or otherwise,
159
+ unless required by applicable law (such as deliberate and grossly
160
+ negligent acts) or agreed to in writing, shall any Contributor be
161
+ liable to You for damages, including any direct, indirect, special,
162
+ incidental, or consequential damages of any character arising as a
163
+ result of this License or out of the use or inability to use the
164
+ Work (including but not limited to damages for loss of goodwill,
165
+ work stoppage, computer failure or malfunction, or any and all
166
+ other commercial damages or losses), even if such Contributor
167
+ has been advised of the possibility of such damages.
168
+
169
+ 9. Accepting Warranty or Additional Liability. While redistributing
170
+ the Work or Derivative Works thereof, You may choose to offer,
171
+ and charge a fee for, acceptance of support, warranty, indemnity,
172
+ or other liability obligations and/or rights consistent with this
173
+ License. However, in accepting such obligations, You may act only
174
+ on Your own behalf and on Your sole responsibility, not on behalf
175
+ of any other Contributor, and only if You agree to indemnify,
176
+ defend, and hold each Contributor harmless for any liability
177
+ incurred by, or claims asserted against, such Contributor by reason
178
+ of your accepting any such warranty or additional liability.
179
+
180
+ 10. This code is provided for research purposes only and is
181
+ not to be used for any commercial purposes. By using this code,
182
+ you agree that it will be used solely for academic research, scholarly work,
183
+ and non-commercial activities. Any use of this code for commercial purposes,
184
+ including but not limited to, selling, distributing, or incorporating it into
185
+ commercial products or services, is strictly prohibited. Violation of this
186
+ clause may result in legal actions and penalties.
187
+
188
+ END OF TERMS AND CONDITIONS
189
+
190
+ APPENDIX: How to apply the Apache License to your work.
191
+
192
+ To apply the Apache License to your work, attach the following
193
+ boilerplate notice, with the fields enclosed by brackets "[]"
194
+ replaced with your own identifying information. (Don't include
195
+ the brackets!) The text should be enclosed in the appropriate
196
+ comment syntax for the file format. We also recommend that a
197
+ file or class name and description of purpose be included on the
198
+ same "printed page" as the copyright notice for easier
199
+ identification within third-party archives.
200
+
201
+ Copyright [yyyy] [name of copyright owner]
202
+
203
+ Licensed under the Apache License, Version 2.0 (the "License");
204
+ you may not use this file except in compliance with the License.
205
+ You may obtain a copy of the License at
206
+
207
+ http://www.apache.org/licenses/LICENSE-2.0
208
+
209
+ Unless required by applicable law or agreed to in writing, software
210
+ distributed under the License is distributed on an "AS IS" BASIS,
211
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
212
+ See the License for the specific language governing permissions and
213
+ limitations under the License.
214
+
215
+
216
+ Other dependencies and licenses (if such optional components are used):
217
+
218
+
219
+ Components under BSD 3-Clause License:
220
+ ------------------------------------------------
221
+ 1. numpy
222
+ Copyright (c) 2005-2022, NumPy Developers.
223
+ All rights reserved.
224
+
225
+ 2. pytorch
226
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
227
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
228
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
229
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
230
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
231
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
232
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
233
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
234
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
235
+
236
+ 3. torchvision
237
+ Copyright (c) Soumith Chintala 2016,
238
+ All rights reserved.
239
+
240
+ Redistribution and use in source and binary forms, with or without
241
+ modification, are permitted provided that the following conditions are met:
242
+
243
+ * Redistributions of source code must retain the above copyright notice, this
244
+ list of conditions and the following disclaimer.
245
+
246
+ * Redistributions in binary form must reproduce the above copyright notice,
247
+ this list of conditions and the following disclaimer in the documentation
248
+ and/or other materials provided with the distribution.
249
+
250
+ * Neither the name of the copyright holder nor the names of its
251
+ contributors may be used to endorse or promote products derived from
252
+ this software without specific prior written permission.
253
+
254
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
255
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
256
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
257
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
258
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
259
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
260
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
261
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
262
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
263
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
264
+
265
+ Component under Apache v2 License:
266
+ -----------------------------------------------------
267
+ 1. timm
268
+ Copyright 2019 Ross Wightman
269
+
270
+ Apache License
271
+ Version 2.0, January 2004
272
+ http://www.apache.org/licenses/
273
+
274
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
275
+
276
+ 1. Definitions.
277
+
278
+ "License" shall mean the terms and conditions for use, reproduction,
279
+ and distribution as defined by Sections 1 through 9 of this document.
280
+
281
+ "Licensor" shall mean the copyright owner or entity authorized by
282
+ the copyright owner that is granting the License.
283
+
284
+ "Legal Entity" shall mean the union of the acting entity and all
285
+ other entities that control, are controlled by, or are under common
286
+ control with that entity. For the purposes of this definition,
287
+ "control" means (i) the power, direct or indirect, to cause the
288
+ direction or management of such entity, whether by contract or
289
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
290
+ outstanding shares, or (iii) beneficial ownership of such entity.
291
+
292
+ "You" (or "Your") shall mean an individual or Legal Entity
293
+ exercising permissions granted by this License.
294
+
295
+ "Source" form shall mean the preferred form for making modifications,
296
+ including but not limited to software source code, documentation
297
+ source, and configuration files.
298
+
299
+ "Object" form shall mean any form resulting from mechanical
300
+ transformation or translation of a Source form, including but
301
+ not limited to compiled object code, generated documentation,
302
+ and conversions to other media types.
303
+
304
+ "Work" shall mean the work of authorship, whether in Source or
305
+ Object form, made available under the License, as indicated by a
306
+ copyright notice that is included in or attached to the work
307
+ (an example is provided in the Appendix below).
308
+
309
+ "Derivative Works" shall mean any work, whether in Source or Object
310
+ form, that is based on (or derived from) the Work and for which the
311
+ editorial revisions, annotations, elaborations, or other modifications
312
+ represent, as a whole, an original work of authorship. For the purposes
313
+ of this License, Derivative Works shall not include works that remain
314
+ separable from, or merely link (or bind by name) to the interfaces of,
315
+ the Work and Derivative Works thereof.
316
+
317
+ "Contribution" shall mean any work of authorship, including
318
+ the original version of the Work and any modifications or additions
319
+ to that Work or Derivative Works thereof, that is intentionally
320
+ submitted to Licensor for inclusion in the Work by the copyright owner
321
+ or by an individual or Legal Entity authorized to submit on behalf of
322
+ the copyright owner. For the purposes of this definition, "submitted"
323
+ means any form of electronic, verbal, or written communication sent
324
+ to the Licensor or its representatives, including but not limited to
325
+ communication on electronic mailing lists, source code control systems,
326
+ and issue tracking systems that are managed by, or on behalf of, the
327
+ Licensor for the purpose of discussing and improving the Work, but
328
+ excluding communication that is conspicuously marked or otherwise
329
+ designated in writing by the copyright owner as "Not a Contribution."
330
+
331
+ "Contributor" shall mean Licensor and any individual or Legal Entity
332
+ on behalf of whom a Contribution has been received by Licensor and
333
+ subsequently incorporated within the Work.
334
+
335
+ 2. Grant of Copyright License. Subject to the terms and conditions of
336
+ this License, each Contributor hereby grants to You a perpetual,
337
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
338
+ copyright license to reproduce, prepare Derivative Works of,
339
+ publicly display, publicly perform, sublicense, and distribute the
340
+ Work and such Derivative Works in Source or Object form.
341
+
342
+ 3. Grant of Patent License. Subject to the terms and conditions of
343
+ this License, each Contributor hereby grants to You a perpetual,
344
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
345
+ (except as stated in this section) patent license to make, have made,
346
+ use, offer to sell, sell, import, and otherwise transfer the Work,
347
+ where such license applies only to those patent claims licensable
348
+ by such Contributor that are necessarily infringed by their
349
+ Contribution(s) alone or by combination of their Contribution(s)
350
+ with the Work to which such Contribution(s) was submitted. If You
351
+ institute patent litigation against any entity (including a
352
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
353
+ or a Contribution incorporated within the Work constitutes direct
354
+ or contributory patent infringement, then any patent licenses
355
+ granted to You under this License for that Work shall terminate
356
+ as of the date such litigation is filed.
357
+
358
+ 4. Redistribution. You may reproduce and distribute copies of the
359
+ Work or Derivative Works thereof in any medium, with or without
360
+ modifications, and in Source or Object form, provided that You
361
+ meet the following conditions:
362
+
363
+ (a) You must give any other recipients of the Work or
364
+ Derivative Works a copy of this License; and
365
+
366
+ (b) You must cause any modified files to carry prominent notices
367
+ stating that You changed the files; and
368
+
369
+ (c) You must retain, in the Source form of any Derivative Works
370
+ that You distribute, all copyright, patent, trademark, and
371
+ attribution notices from the Source form of the Work,
372
+ excluding those notices that do not pertain to any part of
373
+ the Derivative Works; and
374
+
375
+ (d) If the Work includes a "NOTICE" text file as part of its
376
+ distribution, then any Derivative Works that You distribute must
377
+ include a readable copy of the attribution notices contained
378
+ within such NOTICE file, excluding those notices that do not
379
+ pertain to any part of the Derivative Works, in at least one
380
+ of the following places: within a NOTICE text file distributed
381
+ as part of the Derivative Works; within the Source form or
382
+ documentation, if provided along with the Derivative Works; or,
383
+ within a display generated by the Derivative Works, if and
384
+ wherever such third-party notices normally appear. The contents
385
+ of the NOTICE file are for informational purposes only and
386
+ do not modify the License. You may add Your own attribution
387
+ notices within Derivative Works that You distribute, alongside
388
+ or as an addendum to the NOTICE text from the Work, provided
389
+ that such additional attribution notices cannot be construed
390
+ as modifying the License.
391
+
392
+ You may add Your own copyright statement to Your modifications and
393
+ may provide additional or different license terms and conditions
394
+ for use, reproduction, or distribution of Your modifications, or
395
+ for any such Derivative Works as a whole, provided Your use,
396
+ reproduction, and distribution of the Work otherwise complies with
397
+ the conditions stated in this License.
398
+
399
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
400
+ any Contribution intentionally submitted for inclusion in the Work
401
+ by You to the Licensor shall be under the terms and conditions of
402
+ this License, without any additional terms or conditions.
403
+ Notwithstanding the above, nothing herein shall supersede or modify
404
+ the terms of any separate license agreement you may have executed
405
+ with Licensor regarding such Contributions.
406
+
407
+ 6. Trademarks. This License does not grant permission to use the trade
408
+ names, trademarks, service marks, or product names of the Licensor,
409
+ except as required for reasonable and customary use in describing the
410
+ origin of the Work and reproducing the content of the NOTICE file.
411
+
412
+ 7. Disclaimer of Warranty. Unless required by applicable law or
413
+ agreed to in writing, Licensor provides the Work (and each
414
+ Contributor provides its Contributions) on an "AS IS" BASIS,
415
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
416
+ implied, including, without limitation, any warranties or conditions
417
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
418
+ PARTICULAR PURPOSE. You are solely responsible for determining the
419
+ appropriateness of using or redistributing the Work and assume any
420
+ risks associated with Your exercise of permissions under this License.
421
+
422
+ 8. Limitation of Liability. In no event and under no legal theory,
423
+ whether in tort (including negligence), contract, or otherwise,
424
+ unless required by applicable law (such as deliberate and grossly
425
+ negligent acts) or agreed to in writing, shall any Contributor be
426
+ liable to You for damages, including any direct, indirect, special,
427
+ incidental, or consequential damages of any character arising as a
428
+ result of this License or out of the use or inability to use the
429
+ Work (including but not limited to damages for loss of goodwill,
430
+ work stoppage, computer failure or malfunction, or any and all
431
+ other commercial damages or losses), even if such Contributor
432
+ has been advised of the possibility of such damages.
433
+
434
+ 9. Accepting Warranty or Additional Liability. While redistributing
435
+ the Work or Derivative Works thereof, You may choose to offer,
436
+ and charge a fee for, acceptance of support, warranty, indemnity,
437
+ or other liability obligations and/or rights consistent with this
438
+ License. However, in accepting such obligations, You may act only
439
+ on Your own behalf and on Your sole responsibility, not on behalf
440
+ of any other Contributor, and only if You agree to indemnify,
441
+ defend, and hold each Contributor harmless for any liability
442
+ incurred by, or claims asserted against, such Contributor by reason
443
+ of your accepting any such warranty or additional liability.
444
+
445
+ END OF TERMS AND CONDITIONS
446
+
447
+ APPENDIX: How to apply the Apache License to your work.
448
+
449
+ To apply the Apache License to your work, attach the following
450
+ boilerplate notice, with the fields enclosed by brackets "[]"
451
+ replaced with your own identifying information. (Don't include
452
+ the brackets!) The text should be enclosed in the appropriate
453
+ comment syntax for the file format. We also recommend that a
454
+ file or class name and description of purpose be included on the
455
+ same "printed page" as the copyright notice for easier
456
+ identification within third-party archives.
457
+
458
+ Copyright [yyyy] [name of copyright owner]
459
+
460
+ Licensed under the Apache License, Version 2.0 (the "License");
461
+ you may not use this file except in compliance with the License.
462
+ You may obtain a copy of the License at
463
+
464
+ http://www.apache.org/licenses/LICENSE-2.0
465
+
466
+ Unless required by applicable law or agreed to in writing, software
467
+ distributed under the License is distributed on an "AS IS" BASIS,
468
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
469
+ See the License for the specific language governing permissions and
470
+ limitations under the License.
README-SEED-2.md ADDED
@@ -0,0 +1,184 @@
1
+ # :chestnut: SEED Multimodal
2
+
3
+ [![Project Homepage](https://img.shields.io/badge/Project-Homepage-green)](https://ailab-cvc.github.io/seed/)
4
+ [![arXiv](https://img.shields.io/badge/arXiv-2307.08041-b31b1b.svg)](https://arxiv.org/abs/2307.08041)
5
+ [![arXiv](https://img.shields.io/badge/arXiv-2310.01218-b31b1b.svg)](https://arxiv.org/abs/2310.01218)
6
+ [![Static Badge](https://img.shields.io/badge/Model-Huggingface-yellow)](https://huggingface.co/AILab-CVC/SEED/tree/main)
7
+ [![Demo](https://img.shields.io/badge/Gradio-Demo-orange)](https://10a4e7976e6fc2032c.gradio.live/)
8
+
9
+
10
+ **Powered by [CV Center, Tencent AI Lab](https://ailab-cvc.github.io), and [ARC Lab, Tencent PCG](https://github.com/TencentARC).**
11
+
12
+ ![image](https://github.com/AILab-CVC/SEED/blob/main/paper_images/milestone.jpg)
13
+
14
+ The repository provides the official implementation of [SEED](https://ailab-cvc.github.io/seed/seed.html), [SEED-LLaMA](https://ailab-cvc.github.io/seed/seed_llama.html). For any inquiries, please email [seed-x@googlegroups.com](mailto:seed-x@googlegroups.com).
15
+
16
+
17
+ ## News
18
+
19
+ **:beers: We are actively looking for self-motivated interns. Please feel free to reach out if you are interested. :beers:**
20
+
21
+ - [x] **2023-10-23** :hugs: We have optimized the memory overhead. Through 8-bit quantization and dynamic loading, SEED-LLaMA 8B/14B can run on a single **16GB/24GB** GPU.
22
+ - [x] **2023-10-23** :hugs: All model weights will be **downloaded automatically** when starting the demo.
23
+ - [x] **2023-10-20** :hugs: We release the [checkpoints](https://huggingface.co/AILab-CVC/SEED/tree/main) and code of the SEED-2 tokenizer, and SEED-LLaMA-8B/14B.
24
+ - [x] **2023-10-20** :space_invader: We release an online [gradio demo](https://10a4e7976e6fc2032c.gradio.live/); feel free to try it yourself.
25
+ - [x] **2023-10-02** :paperclip: We release the technical report of SEED-LLaMA on [arXiv](https://arxiv.org/abs/2310.01218), which is empowered by the improved SEED-2 tokenizer.
26
+ - [x] **2023-07-29** :octocat: We release the checkpoint of the SEED tokenizer and its inference code. Check it out via [SEED-1](./SEED-1.md).
27
+ - [x] **2023-07-16** :paperclip: We release the technical report of SEED on [arXiv](https://arxiv.org/abs/2307.08041).
28
+
29
+ Stay tuned for the updates!
30
+
31
+ ## Brief Introduction
32
+
33
+ It is recommended to check out our [papers](#citation) for technical details.
34
+
35
+ ### :speech_balloon: What can SEED-LLaMA do?
36
+
37
+ ![image](https://github.com/AILab-CVC/SEED/blob/main/paper_images/v2/teaser.jpg)
38
+
39
+ **SEED-LLaMA** is capable of both multimodal comprehension and generation, exhibiting compositional emergent abilities such as multi-turn in-context multimodal generation, acting like your AI assistant. [[Compare to SOTA]](https://ailab-cvc.github.io/seed/seed_llama_compare.html) [[More examples on X]](https://twitter.com/ge_yixiao/status/1710509538238157069?s=20)
40
+
41
+ <!-- We present **SEED-LLaMA** by large-scale pretraining and instruction tuning on the interleaved textual and visual data, which demonstrates impressive performance on a broad range of multimodal comprehension and generation tasks. More importantly, SEED-LLaMA has exhibited **compositional emergent abilities** such as multi-turn in-context multimodal generation, acting like your **AI assistant**. -->
42
+
43
+ ### :bulb: How does SEED-LLaMA achieve it?
44
+
45
+ ![image](https://github.com/AILab-CVC/SEED/blob/main/paper_images/seed_overview.jpg)
46
+
47
+ The core of SEED-LLaMA is the tailored **SEED** tokenizer, which properly quantizes visual signals into discrete visual tokens, capturing the necessary semantics while being produced under a 1D causal dependency. [[SEED-2 vs. SEED-1]](https://ailab-cvc.github.io/seed/seed_llama.html)
48
+
49
+ <!-- ### Compositional Emergent Ability
50
+ **Multi-turn in-context image and text generation.**
51
+ ![image](paper_images/v2/multi_turn1.jpg)
52
+ ![image](paper_images/v2/multi_turn2.jpg)
53
+
54
+ **Compositional image generation.**
55
+ ![image](paper_images/v2/results.jpg) -->
56
+
57
+ <!-- ### SEED Tokenizer v2
58
+ In SEED tokenizer v2, the generation embedding is aligned with the **image embedding** (1 token) of [unCLIP SD](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip), and can be decoded to realistic images with the unCLIP-SD-UNet. In SEED tokenizer v1, we train a visual tokenizer through aligning the **generation embeddings** with the text embeddings (77 tokens) of [SD](https://github.com/CompVis/stable-diffusion), and the generation embeddings can be decoded to images with the SD-UNet. The below figure shows the visual comparison of the reconstructed images between SEED tokenizer v2 (the third row) and SEED tokenizer v1 (the second row). We can observe that the images reconstructed by SEED tokenizer v2 can better preserve the visual information of the original images. The semantic representations of texts can not fully preserve the rich visual information of images.
59
+ ![image](paper_images/v2/seed_comparison.jpg) -->
60
+
61
+ <!-- ### Pretraining
62
+ We perform multimodal autoregressive pretraining on interleaved visual and textual data for SEED-LLaMA. Visual inputs are pre-processed into discrete tokens to conserve computational resources. Given the multimodal discrete sequence, a unified next-word-prediction objective is employed. During inference, visual codes are decoded into a realistic image by SEED De-Tokenization.
63
+ ![image](paper_images/v2/method_page.jpg) -->
64
+
65
+ ## Usage
66
+
67
+ ### Dependencies
68
+ - Python >= 3.8 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux))
69
+ - [PyTorch >= 1.11.0](https://pytorch.org/)
70
+ - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
71
+
72
+ ### Installation
73
+ Clone the repo and install dependent packages
74
+
75
+ ```bash
76
+ git clone https://github.com/AILab-CVC/SEED.git
77
+ cd SEED
78
+ pip install -r requirements.txt
79
+ ```
80
+
81
+
82
+ ### Model Weights
83
+ We release the pretrained SEED Tokenizer and De-Tokenizer, pretrained and instruction tuned SEED-LLaMA-8B and SEED-LLaMA-14B in [SEED Hugging Face](https://huggingface.co/AILab-CVC/SEED).
84
+
85
+ - Check the SEED tokenizer weights in [AILab-CVC/seed-tokenizer-2](https://huggingface.co/AILab-CVC/seed-tokenizer-2)
86
+ - Check the SEED LLaMA(8B) weights in [AILab-CVC/seed-llama-8b-sft](https://huggingface.co/AILab-CVC/seed-llama-8b-sft)
87
+ - Check the SEED LLaMA(14B) weights in [AILab-CVC/seed-llama-14b-sft](https://huggingface.co/AILab-CVC/seed-llama-14b-sft)
88
+
89
+ <!-- Please download the checkpoints and save under the folder `./pretrained`.
90
+
91
+ ```bash
92
+ cd pretrained # SEED/pretrained
93
+ git lfs install
94
+ git clone https://huggingface.co/AILab-CVC/SEED
95
+ mv SEED/* ./
96
+ ``` -->
97
+
98
+ The model weights of unCLIP SD-UNet which are used to reconstruct the image will be downloaded automatically.
99
+
100
+ <!-- To reconstruct the image from the SEED visual codes using unCLIP SD-UNet, please download the pretrained [unCLIP SD](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip). -->
101
+
102
+ <!-- To reconstruct the image from the SEED visual codes using unCLIP SD-UNet, please download the pretrained [unCLIP SD](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip).
103
+ Rename the checkpoint directory to **"diffusion_model"** and create a soft link to the "pretrained/seed_tokenizer" directory.
104
+
105
+ ```bash
106
+ # SEED/pretrained
107
+ git lfs install
108
+ git clone https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip
109
+ mv stable-diffusion-2-1-unclip seed_tokenizer/diffusion_model
110
+ ``` -->
111
+
112
+
113
+ ### Inference for visual tokenization and de-tokenization
114
+ To discretize an image to 1D visual codes with causal dependency, and reconstruct the image from the visual codes using the off-the-shelf unCLIP SD-UNet:
115
+
116
+ ```bash
117
+ cd .. # SEED/
118
+ python scripts/seed_tokenizer_inference.py
119
+ ```
120
+ ### Inference for SEED-LLaMA
121
+ Since SEED-LLaMA-8B is based on Vicuna-7B and SEED-LLaMA-14B on LLaMA2-Chat-13B, we use Vicuna-7B's ("USER:", "ASSISTANT:") and LLaMA2-Chat-13B's ([INST] [/INST]) prompts for the respective instruction tuning.
122
+
123
+ ```bash
124
+ # Inference for SEED-LLaMA-8B
125
+ python scripts/seed_llama_inference_8B.py
126
+ ```
127
+
128
+ ```bash
129
+ # Inference for SEED-LLaMA-14B
130
+ python scripts/seed_llama_inference_14B.py
131
+ ```
132
+
133
+
134
+ ### Launching Gradio Demo of SEED-LLaMA-14B Locally
135
+ 1. Building the local demo of SEED-LLaMA-14B currently requires a **single 24GB** GPU.
136
+
137
+ ```bash
138
+ # SEED/
139
+ # in first terminal
140
+ bash scripts/start_backend_14b.sh
141
+ # in second terminal
142
+ bash scripts/start_frontend_14b.sh
143
+ ```
144
+
145
+ 2. Building the local demo of SEED-LLaMA-8B currently requires a **single 16GB** GPU.
146
+
147
+ ```bash
148
+ # SEED/
149
+ # in first terminal
150
+ bash scripts/start_backend_8b.sh
151
+ # in second terminal
152
+ bash scripts/start_frontend_8b.sh
153
+ ```
154
+
155
+ Then the demo can be accessed through http://127.0.0.1:80
156
+
157
+ ## Citation
158
+ If you find the work helpful, please consider citing:
159
+ ```bibtex
160
+ @article{ge2023making,
161
+ title={Making LLaMA SEE and Draw with SEED Tokenizer},
162
+ author={Ge, Yuying and Zhao, Sijie and Zeng, Ziyun and Ge, Yixiao and Li, Chen and Wang, Xintao and Shan, Ying},
163
+ journal={arXiv preprint arXiv:2310.01218},
164
+ year={2023}
165
+ }
166
+
167
+ @article{ge2023planting,
168
+ title={Planting a seed of vision in large language model},
169
+ author={Ge, Yuying and Ge, Yixiao and Zeng, Ziyun and Wang, Xintao and Shan, Ying},
170
+ journal={arXiv preprint arXiv:2307.08041},
171
+ year={2023}
172
+ }
173
+ ```
174
+
175
+ The project is still in progress.
176
+
177
+ ## License
178
+ `SEED` is released under [Apache License Version 2.0](License.txt).
179
+
180
+ `SEED-LLaMA` is released under the original [License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) of [LLaMA2](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf).
181
+
182
+ ## Acknowledgement
183
+ We thank the great work from [unCLIP SD](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip) and [BLIP2](https://github.com/salesforce/LAVIS).
184
+
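For reference, a minimal sketch (illustrative only, not part of this commit) of the two instruction-prompt formats described in the README above; the separators follow what gradio_demo/conversation.py uses:

```python
# Hypothetical helpers showing the prompt conventions named in the README:
# Vicuna-style turns for SEED-LLaMA-8B, LLaMA2-Chat-style turns for SEED-LLaMA-14B.

def vicuna_prompt(user_text: str, sep: str = "\n") -> str:
    # SEED-LLaMA-8B (Vicuna-7B base): "USER:" / "ASSISTANT:" turns
    return f"USER: {user_text}{sep}ASSISTANT:"

def llama2_prompt(user_text: str, sep: str = "\n") -> str:
    # SEED-LLaMA-14B (LLaMA2-Chat-13B base): "[INST] ... [/INST]" turns
    return f"[INST] {user_text} [/INST]{sep}"

print(vicuna_prompt("Describe this image. <image>"))
print(llama2_prompt("Describe this image. <image>"))
```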
SEED-1.md ADDED
@@ -0,0 +1,93 @@
1
+ # SEED Tokenizer v1
2
+ [[arXiv]](https://arxiv.org/abs/2307.08041)
3
+
4
+ ![image](paper_images/teaser.jpg)
5
+ ## Abstract
6
+ We present SEED, an elaborate image tokenizer that empowers Large Language
7
+ Models (LLMs) with the emergent ability to **SEE** and **D**raw at the same time.
8
+ Research on image tokenizers has previously reached an impasse, as frameworks
9
+ employing quantized visual tokens have lost prominence due to subpar performance and convergence in multimodal comprehension (compared to BLIP-2, etc.)
10
+ or generation (compared to Stable Diffusion, etc.). Despite the limitations, we
11
+ remain confident in its natural capacity to unify visual and textual representations,
12
+ facilitating scalable multimodal training with LLM’s original recipe. In this study,
13
+ we identify two crucial principles for the architecture and training of SEED that
14
+ effectively ease subsequent alignment with LLMs. (1) Image tokens should be
15
+ independent of 2D physical patch positions and instead be produced with a 1D
16
+ causal dependency, exhibiting intrinsic interdependence that aligns with the left-to-right autoregressive prediction mechanism in LLMs. (2) Image tokens should
17
+ capture high-level semantics consistent with the degree of semantic abstraction in
18
+ words, and be optimized for both discriminativeness and reconstruction during the
19
+ tokenizer training phase. As a result, the off-the-shelf LLM is able to perform both
20
+ image-to-text and text-to-image generation by incorporating our SEED through
21
+ efficient LoRA tuning. Comprehensive multimodal pretraining and instruction
22
+ tuning, which may yield improved results, are reserved for future investigation.
23
+ This version of SEED was trained in 5.7 days using only 64 V100 GPUs and 5M
24
+ publicly available image-text pairs. Our preliminary study emphasizes the great
25
+ potential of discrete visual tokens in versatile multimodal LLMs and the importance
26
+ of proper image tokenizers in broader research.
27
+
28
+ ## SEED Tokenizer for Image Reconstruction
29
+ ![image](paper_images/reconstruction.jpg)
30
+
31
+ ## SEED-OPT<sub>2.7B </sub> for Multimodal Comprehension
32
+ ![image](paper_images/vqa.jpg)
33
+
34
+ ## SEED-OPT<sub>2.7B </sub> for Multimodal Generation
35
+ ![image](paper_images/generation.jpg)
36
+
37
+ ## Dependencies and Installation
38
+ - Python >= 3.8 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux))
39
+ - [PyTorch >= 1.11.0](https://pytorch.org/)
40
+ - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
41
+ ### Installation
42
+ 1. Clone repo
43
+
44
+ ```bash
45
+ git clone https://github.com/AILab-CVC/SEED.git
46
+ cd SEED
47
+ ```
48
+
49
+ 2. Install dependent packages
50
+
51
+ ```bash
52
+ sh install.sh
53
+ ```
54
+
55
+ ## Model Weights
56
+ We release the pre-trained SEED Visual Tokenizer in [google drive](https://drive.google.com/drive/folders/1xmVXuttQfBPBOe4ZR96Wu1X34uzPkxsS?usp=drive_link).
57
+
58
+ ## Inference
59
+ To discretize an image to 1D vision codes with causal dependency, and reconstruct the image
60
+ from the vision codes using stable diffusion UNet,
61
+
62
+ 1. Download the pre-trained SEED Visual Tokenizer and stable diffusion model in [google drive](https://drive.google.com/drive/folders/1xmVXuttQfBPBOe4ZR96Wu1X34uzPkxsS?usp=drive_link) and put them under the folder "pretrained".
63
+ 2. run the inference code.
64
+ ```bash
65
+ python demo_recon.py
66
+ ```
67
+
68
+ ## To Do
69
+ - [x] Release SEED Tokenizer
70
+
71
+ ## License
72
+ SEED is released under Apache License Version 2.0.
73
+
74
+ ## Acknowledgement
75
+ We utilize Stable Diffusion to decode images from our visual codes, and use its implementation and pre-trained model in https://github.com/CompVis/stable-diffusion.git.
76
+
77
+ Our code is based on the implementation of BLIP-2 in https://github.com/salesforce/LAVIS.git.
78
+
79
+
80
+ ## Citation
81
+ If you find the work helpful, please consider citing:
82
+ ```
83
+ @misc{ge2023planting,
84
+ title={Planting a SEED of Vision in Large Language Model},
85
+ author={Yuying Ge and Yixiao Ge and Ziyun Zeng and Xintao Wang and Ying Shan},
86
+ year={2023},
87
+ eprint={2307.08041},
88
+ archivePrefix={arXiv},
89
+ primaryClass={cs.CV}
90
+ }
91
+ ```
92
+
93
+ The project is still in progress. Stay tuned for more updates!
configs/llm/seed_llama_14b.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: models.model_tools.get_pretrained_llama_causal_model
+ pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_llama_14b_sft
+
+ torch_dtype: fp16
+ low_cpu_mem_usage: True
configs/llm/seed_llama_14b_8bit.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: transformers.LlamaForCausalLM.from_pretrained
+ pretrained_model_name_or_path: AILab-CVC/seed-llama-14b-sft
+ load_in_8bit: True
+ # device_map: auto
+ low_cpu_mem_usage: True
configs/llm/seed_llama_8b.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: models.model_tools.get_pretrained_llama_causal_model
+ pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_llama_8b_sft
+
+ torch_dtype: fp16
+ low_cpu_mem_usage: True
configs/llm/seed_llama_8b_8bit.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: transformers.LlamaForCausalLM.from_pretrained
+ pretrained_model_name_or_path: AILab-CVC/seed-llama-8b-sft
+ load_in_8bit: True
+ # device_map: auto
+ low_cpu_mem_usage: True
configs/tokenizer/seed_llama_tokenizer.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: models.seed_llama_tokenizer.SeedLlamaTokenizer.from_pretrained
+ pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_tokenizer
+ fp16: True
+ load_diffusion: True
configs/tokenizer/seed_llama_tokenizer_hf.yaml ADDED
@@ -0,0 +1,6 @@
+ _target_: models.seed_llama_tokenizer.SeedLlamaTokenizer.from_pretrained
+ pretrained_model_name_or_path: AILab-CVC/seed-tokenizer-2
+ fp16: True
+ load_diffusion: False
+ encoder_url: https://huggingface.co/AILab-CVC/seed-tokenizer-2/resolve/main/seed_quantizer.pt
+ diffusion_path: stabilityai/stable-diffusion-2-1-unclip
configs/transform/clip_transform.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: models.transforms.get_transform
+ type: clip
+ image_size: 224
+ keep_ratio: False
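For context, a minimal sketch (not part of this commit) of how these Hydra-style configs are consumed: gradio_demo/seed_llama_flask.py loads each YAML with OmegaConf and instantiates the callable named in `_target_` via hydra.utils.instantiate, with the remaining keys passed as keyword arguments. Paths assume the repo root as the working directory.

```python
import hydra
from omegaconf import OmegaConf

# The `_target_` key names a callable; the other keys become its kwargs.
transform_cfg = OmegaConf.load('configs/transform/clip_transform.yaml')
transform = hydra.utils.instantiate(transform_cfg)
# -> models.transforms.get_transform(type='clip', image_size=224, keep_ratio=False)

llm_cfg = OmegaConf.load('configs/llm/seed_llama_14b_8bit.yaml')
model = hydra.utils.instantiate(llm_cfg)
# -> transformers.LlamaForCausalLM.from_pretrained('AILab-CVC/seed-llama-14b-sft',
#    load_in_8bit=True, low_cpu_mem_usage=True)
```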
gradio_demo/conversation.py ADDED
@@ -0,0 +1,190 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+ import io
6
+ import base64
7
+ import os
8
+ from PIL import Image
9
+ import copy
10
+
11
+ IMG_FLAG = '<image>'
12
+
13
+
14
+ class SeparatorStyle(Enum):
15
+ """Different separator style."""
16
+ SINGLE = auto()
17
+ TWO = auto()
18
+ MPT = auto()
19
+ PLAIN = auto()
20
+ LLAMA_2 = auto()
21
+
22
+
23
+ def decode_image(encoded_image: str) -> Image:
24
+ decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
25
+ buffer = io.BytesIO(decoded_bytes)
26
+ image = Image.open(buffer)
27
+ return image
28
+
29
+
30
+ def encode_image(image: Image.Image, format: str = 'PNG') -> str:
31
+ with io.BytesIO() as buffer:
32
+ image.save(buffer, format=format)
33
+ encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
34
+ return encoded_image
35
+
36
+
37
+ @dataclasses.dataclass
38
+ class Conversation:
39
+ """A class that keeps all conversation history."""
40
+ system: str
41
+ roles: List[str]
42
+ messages: List[dict] # multi-turn -> user & assistant -> {'images': [PIL.Image,], 'text': str}
43
+ offset: int
44
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
45
+ sep: str = "###"
46
+ sep2: str = None
47
+ version: str = "Unknown"
48
+
49
+ skip_next: bool = False
50
+
51
+ def get_prompt(self):
52
+ messages = copy.deepcopy(self.messages)
53
+ if self.sep_style == SeparatorStyle.SINGLE:
54
+ if self.system is None or self.system == '':
55
+ text = ''
56
+ else:
57
+ text = self.system + self.sep
58
+ images = []
59
+ for message in messages:
60
+ text += message['role'] + ": " + message['message']['text'] + self.sep
61
+ for image_path, image_ids in zip(message['message']['images'], message['message']['images_ids']):
62
+ if image_ids is not None:
63
+ images.append(image_ids)
64
+ else:
65
+ image = Image.open(image_path).resize((256, 256))
66
+ image_base64 = encode_image(image)
67
+ images.append(image_base64)
68
+
69
+ text += self.roles[1] + ":"
70
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
71
+ b_token = "[INST] "
72
+ e_token = " [/INST]"
73
+ if self.system is None or self.system == '':
74
+ text = ''
75
+ else:
76
+ text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"
77
+ images = []
78
+ for idx, message in enumerate(messages):
79
+ # text += message['role'] + ": " + message['message']['text'] + self.sep
80
+ if idx % 2 == 0:
81
+ text += b_token + message['message']['text'] + e_token + self.sep
82
+ else:
83
+ text += message['message']['text'] + self.sep
84
+
85
+ for image_path, image_ids in zip(message['message']['images'], message['message']['images_ids']):
86
+ if image_ids is not None:
87
+ images.append(image_ids)
88
+ else:
89
+ image = Image.open(image_path).resize((256, 256))
90
+ image_base64 = encode_image(image)
91
+ images.append(image_base64)
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ return {'text': text, 'images': images}
96
+
97
+ def update_image_ids(self, images_ids):
98
+ image_count = 0
99
+ for message in self.messages:
100
+ for idx in range(len(message['message']['images_ids'])):
101
+ if message['message']["images_ids"][idx] is None:
102
+ message['message']["images_ids"][idx] = images_ids[image_count]
103
+ image_count += 1
104
+
105
+ assert len(images_ids) == image_count, print(len(images_ids), image_count)
106
+
107
+ def append_message(self, role, message):
108
+ self.messages.append([role, message])
109
+
110
+ def to_gradio_chatbot(self):
111
+ dialog = []
112
+ for i, single_turn in enumerate(self.messages[self.offset:]):
113
+ single_turn = single_turn['message']
114
+ text_list = single_turn['text'].split(IMG_FLAG)
115
+ assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images']))
116
+ message = ''
117
+ for image_idx in range(len(single_turn['images'])):
118
+ # image = single_turn['images'][image_idx]
119
+ # image_base64 = encode_image(image)
120
+ # image_str = f'<img src="data:image/png;base64,{image_base64}" alt="user upload image" />'
121
+ image_path = single_turn['images'][image_idx]
122
+ if image_path == '':
123
+ message += text_list[image_idx] + '<corrupt_image>'
124
+ else:
125
+ message += text_list[image_idx] + f'![](file={image_path})'
126
+ message += text_list[-1]
127
+
128
+ if i % 2 == 0:
129
+ dialog.append([message, None])
130
+ else:
131
+ dialog[-1][-1] = message
132
+
133
+ return dialog
134
+
135
+ def copy(self):
136
+ return Conversation(system=self.system,
137
+ roles=self.roles,
138
+ messages=copy.deepcopy(self.messages),
139
+ offset=self.offset,
140
+ sep_style=self.sep_style,
141
+ sep=self.sep,
142
+ sep2=self.sep2,
143
+ version=self.version)
144
+
145
+ def dict(self):
146
+ messages = copy.deepcopy(self.messages)
147
+ for message in messages:
148
+ if 'images_ids' in message:
149
+ message.pop('images_ids')
150
+ for i in range(len(message['message']['images'])):
151
+ message['message']['images'][i] = os.path.basename(message['message']['images'][i])
152
+ return {
153
+ "system": self.system,
154
+ "roles": self.roles,
155
+ "messages": messages,
156
+ "offset": self.offset,
157
+ "sep": self.sep,
158
+ "sep2": self.sep2,
159
+ }
160
+
161
+
162
+ conv_seed_vicuna = Conversation(
163
+ system="",
164
+ roles=("USER", "ASSISTANT"),
165
+ version="v2",
166
+ messages=[],
167
+ offset=0,
168
+ sep_style=SeparatorStyle.SINGLE,
169
+ sep='\n',
170
+ )
171
+
172
+ conv_seed_vicuna_system = Conversation(
173
+ system="A chat between a curious user and an artificial intelligence assistant. ",
174
+ roles=("USER", "ASSISTANT"),
175
+ version="v2",
176
+ messages=[],
177
+ offset=0,
178
+ sep_style=SeparatorStyle.SINGLE,
179
+ sep='\n',
180
+ )
181
+
182
+ conv_seed_llama2 = Conversation(
183
+ system="",
184
+ roles=("[INST]", "[/INST]"),
185
+ version="v2",
186
+ messages=[],
187
+ offset=0,
188
+ sep_style=SeparatorStyle.LLAMA_2,
189
+ sep='\n',
190
+ )
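The Conversation templates above are what the Gradio demo (seed_llama_gradio.py below) copies into its dialog state: each turn is stored as a {'role': ..., 'message': ...} dict whose message carries the raw text with '<image>' placeholders, the saved image paths, and the image token ids once the backend has produced them. A minimal sketch of that state handling, assuming it is run from the gradio_demo/ directory so conversation.py is importable, with an illustrative image path:

```python
from conversation import conv_seed_vicuna

IMG_FLAG = '<image>'  # placeholder used inside message text, one per image

dialog_state = conv_seed_vicuna.copy()

# Mirrors init_input_state() in seed_llama_gradio.py: text, image paths, image ids.
user_turn = {
    'text': 'Describe this picture. ' + IMG_FLAG,
    'images': ['images/cat.jpg'],   # path written by add_image(); illustrative here
    'images_ids': [None],           # filled in later via update_image_ids()
}
dialog_state.messages.append({'role': dialog_state.roles[0], 'message': user_turn})

# Chat history as a list of [user_message, assistant_message] pairs for gr.Chatbot.
print(dialog_state.to_gradio_chatbot())
```

get_prompt() then flattens this state into {'text': ..., 'images': ...} for the backend request, and update_image_ids() back-fills the ids returned by the server so the same image is not re-encoded on later turns.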
gradio_demo/seed_llama_flask.py ADDED
@@ -0,0 +1,230 @@
1
+ import hydra
2
+
3
+ import pyrootutils
4
+ import os
5
+ import torch
6
+
7
+ from omegaconf import OmegaConf
8
+ from flask import Flask, request
9
+ import json
10
+ from typing import Optional
11
+ import transformers
12
+ from dataclasses import dataclass, field
13
+ import io
14
+ import base64
15
+ from PIL import Image
16
+ import gc
17
+
18
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
19
+
20
+ BOI_TOKEN = '<img>'
21
+ EOI_TOKEN = '</img>'
22
+ IMG_TOKEN = '<img_{:05d}>'
23
+
24
+ IMG_FLAG = '<image>'
25
+ NUM_IMG_TOKNES = 32
26
+ NUM_IMG_CODES = 8192
27
+
28
+ app = Flask(__name__)
29
+
30
+
31
+ def decode_image(encoded_image: str) -> Image:
32
+ decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
33
+ buffer = io.BytesIO(decoded_bytes)
34
+ image = Image.open(buffer)
35
+ return image
36
+
37
+
38
+ def encode_image(image: Image.Image, format: str = 'PNG') -> str:
39
+ with io.BytesIO() as buffer:
40
+ image.save(buffer, format=format)
41
+ encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
42
+ return encoded_image
43
+
44
+
45
+ @dataclass
46
+ class Arguments:
47
+ image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
48
+ tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
49
+ model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
50
+ port: Optional[int] = field(default=80, metadata={"help": "network port"})
51
+ llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
52
+ tokenizer_device: Optional[str] = field(default='cuda:0', metadata={"help": "tokenizer device"})
53
+ offload_encoder: Optional[bool] = field(default=False, metadata={"help": "offload the image tokenizer's visual encoder to CPU when idle"})
54
+ offload_decoder: Optional[bool] = field(default=True, metadata={"help": "offload the image tokenizer's diffusion decoder to CPU when idle"})
55
+
56
+
57
+ parser = transformers.HfArgumentParser(Arguments)
58
+ args, = parser.parse_args_into_dataclasses()
59
+
60
+
61
+ class LLMService:
62
+ def __init__(self, args) -> None:
63
+ image_transform_cfg = OmegaConf.load(args.image_transform)
64
+ tokenizer_cfg = OmegaConf.load(args.tokenizer)
65
+ model_cfg = OmegaConf.load(args.model)
66
+ self.image_id_shift = 32000
67
+
68
+ self.image_transform = hydra.utils.instantiate(image_transform_cfg)
69
+ self.tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=args.tokenizer_device, load_diffusion=True)
70
+
71
+ if args.offload_encoder:
72
+ self.tokenizer.image_tokenizer.model.visual_encoder.to('cpu')
73
+ if args.offload_decoder:
74
+ self.tokenizer.image_tokenizer.diffusion_model.to('cpu')
75
+
76
+ # model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
77
+ # self.model = model.eval().to(args.llm_device)
78
+ model = hydra.utils.instantiate(model_cfg, device_map=args.llm_device).eval()
79
+ self.model = model
80
+ print(model.get_memory_footprint())
81
+ self.llm_device = args.llm_device
82
+ self.tokenizer_device = args.tokenizer_device
83
+ self.offload_encoder = args.offload_encoder
84
+ self.offload_decoder = args.offload_decoder
85
+ self.boi_token_id = self.tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0]
86
+ self.eoi_token_id = self.tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0]
87
+ print('Init Done...')
88
+
89
+
90
+ service = LLMService(args)
91
+
92
+
93
+ @app.route('/generate', methods=['GET', 'POST'])
94
+ def generate():
95
+
96
+ request_info = request.get_json()
97
+
98
+ text_list = request_info['text'].split(IMG_FLAG)
99
+ image_list = request_info['images']
100
+ temperature = request_info.get('temperature', 0.7)
101
+ num_beams = request_info.get('num_beams', 1)
102
+ max_new_tokens = request_info.get('max_new_tokens', 256)
103
+ top_p = request_info.get('top_p', 0.5)
104
+ force_boi = request_info.get('force_boi', False)
105
+
106
+ assert len(text_list) == len(image_list) + 1
107
+
108
+ if len(image_list) > 0:
109
+ images_tensor_list = []
110
+ images_tensor_indices = []
111
+ images_ids_list = []
112
+ images_ids_indices = []
113
+ for idx, image_item in enumerate(image_list):
114
+ if isinstance(image_item, str):
115
+ image = decode_image(image_item)
116
+ image_tensor = service.image_transform(image)
117
+ images_tensor_list.append(image_tensor)
118
+ images_tensor_indices.append(idx)
119
+ else:
120
+ images_ids_list.append(image_item)
121
+ images_ids_indices.append(idx)
122
+
123
+ if len(images_tensor_list) > 0:
124
+ images_tensor = torch.stack(images_tensor_list, dim=0).to(service.tokenizer_device)
125
+ if service.offload_encoder:
126
+ service.tokenizer.image_tokenizer.model.visual_encoder.to(service.tokenizer_device)
127
+
128
+ images_ids_1 = service.tokenizer.encode_image(image_torch=images_tensor).cpu()
129
+ if args.offload_encoder:
130
+ service.tokenizer.image_tokenizer.model.visual_encoder.to('cpu')
131
+ torch.cuda.empty_cache()
132
+ gc.collect()
133
+ num_image_ids = images_ids_1.shape[-1]
134
+ else:
135
+ num_image_ids = len(images_ids_list[-1])
136
+ images_ids_2 = torch.tensor(images_ids_list, dtype=torch.long)
137
+
138
+ images_ids = torch.zeros((len(image_list), num_image_ids), dtype=torch.long)
139
+ if len(images_tensor_indices) > 0:
140
+ images_ids[images_tensor_indices, :] = images_ids_1
141
+ if len(images_ids_indices) > 0:
142
+ images_ids[images_ids_indices, :] = images_ids_2
143
+
144
+ input_text = ''
145
+ for i in range(images_ids.shape[0]):
146
+ single_image_ids = images_ids[i].view(-1).tolist()
147
+ image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in single_image_ids]) + EOI_TOKEN
148
+ input_text += text_list[i] + image_tokens
149
+
150
+ input_text = service.tokenizer.bos_token + input_text + text_list[-1]
151
+
152
+ images_ids_list = images_ids.tolist()
153
+ else:
154
+
155
+ input_text = service.tokenizer.bos_token + ''.join(text_list)
156
+ images_ids_list = []
157
+
158
+ if force_boi:
159
+ input_text += BOI_TOKEN
160
+
161
+ print(input_text)
162
+ input_ids = service.tokenizer(input_text, add_special_tokens=False, return_tensors='pt').input_ids
163
+ input_ids = input_ids.to(service.llm_device)
164
+ generation_config = {
165
+ 'temperature': temperature,
166
+ 'num_beams': num_beams,
167
+ 'max_new_tokens': max_new_tokens,
168
+ 'top_p': top_p,
169
+ 'do_sample': True
170
+ }
171
+
172
+ generate_ids = service.model.generate(input_ids=input_ids, **generation_config)
173
+
174
+ if force_boi:
175
+ generate_ids = generate_ids[0][input_ids.shape[1] - 1:]
176
+ else:
177
+ generate_ids = generate_ids[0][input_ids.shape[1]:]
178
+ print('generated_ids: ', generate_ids)
179
+ boi_indices = torch.where(generate_ids == service.boi_token_id)[0].tolist()
180
+ eoi_indices = torch.where(generate_ids == service.eoi_token_id)[0].tolist()
181
+ # assert len(boi_indices) == len(eoi_indices)
182
+
183
+ generated_image_base64_list = []
184
+ text_mask = torch.ones_like(generate_ids, dtype=torch.bool)
185
+
186
+ error_msg = []
187
+ if len(boi_indices) != len(eoi_indices):
188
+ error_msg.append(
189
+ f'Num of BOI (begin of image) tokens: {len(boi_indices)} is not equal to num of EOI (end of image) tokens: {len(eoi_indices)}; some images will fail to decode.'
190
+ )
191
+
192
+ num_images = min(len(boi_indices), len(eoi_indices))
193
+ for idx in range(num_images):
194
+ boi_index, eoi_index = boi_indices[idx], eoi_indices[idx]
195
+ # for boi_index, eoi_index in zip(boi_indices, eoi_indices):
196
+ image_ids = generate_ids[boi_index + 1:eoi_index].unsqueeze(0).to(service.tokenizer_device)
197
+ image_ids = image_ids - service.image_id_shift
198
+ if image_ids.shape[-1] != NUM_IMG_TOKNES:
199
+ error_msg.append(f'Len(image_ids) {image_ids.shape[-1]} is not equal to {NUM_IMG_TOKNES}')
200
+ image_base64 = ''
201
+ elif (image_ids < 0).any() or (image_ids >= NUM_IMG_CODES).any():
202
+ error_msg.append(f'Some image_id out of range: [0, {NUM_IMG_CODES})')
203
+ image_base64 = ''
204
+ else:
205
+ if service.offload_decoder:
206
+ service.tokenizer.image_tokenizer.diffusion_model.to(service.tokenizer_device)
207
+ image = service.tokenizer.decode_image(image_ids)[0]
208
+ if service.offload_decoder:
209
+ service.tokenizer.image_tokenizer.diffusion_model.to('cpu')
210
+ torch.cuda.empty_cache()
211
+ gc.collect()
212
+ image_base64 = encode_image(image)
213
+
214
+ generated_image_base64_list.append(image_base64)
215
+ text_mask[boi_index + 1:eoi_index] = False
216
+ images_ids_list.append(image_ids.view(-1).tolist())
217
+ generate_ids = generate_ids[text_mask]
218
+
219
+ # print('generate_ids: ', generate_ids)
220
+ # generate_text = service.tokenizer.decode(generate_ids, skip_special_tokens=True)
221
+ generate_text = service.tokenizer.decode(generate_ids, skip_special_tokens=False)
222
+ # print('generate_text before: ', generate_text)
223
+ generate_text = generate_text.replace(BOI_TOKEN + ' ' + EOI_TOKEN + ' ', IMG_FLAG)
224
+ generate_text = generate_text.replace(service.tokenizer.eos_token, '')
225
+ print('generate_text: ', generate_text)
226
+ return {'text': generate_text, 'images': generated_image_base64_list, 'images_ids': images_ids_list, 'error_msg': error_msg}
227
+
228
+
229
+ if __name__ == '__main__':
230
+ app.run(host='0.0.0.0', port=args.port)
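For reference, a minimal client sketch for the /generate endpoint defined above. It assumes the service is reachable at http://127.0.0.1:7890/generate (the default address the Gradio frontend below targets; the Flask default port here is 80) and that images/cat.jpg exists — both are illustrative. Each '<image>' flag in the text must correspond to one entry in the images list.

```python
import base64
import io

import requests
from PIL import Image


def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')


payload = {
    'text': 'Describe this image. <image>',  # one '<image>' per entry in 'images'
    'images': [encode_image(Image.open('images/cat.jpg').convert('RGB'))],
    'temperature': 0.7,
    'top_p': 0.5,
    'num_beams': 1,
    'max_new_tokens': 256,
    'force_boi': False,  # set True to force the model to start an image
}

result = requests.post('http://127.0.0.1:7890/generate', json=payload).json()
print(result['text'])         # generated text, with '<image>' where images were produced
print(len(result['images']))  # base64-encoded generated images ('' for failed decodes)
print(result['error_msg'])    # decoding issues reported by the server
```

The Gradio app below assembles exactly this payload in http_bot().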
gradio_demo/seed_llama_gradio.py ADDED
@@ -0,0 +1,497 @@
1
+ import hydra
2
+
3
+ import pyrootutils
4
+ import os
5
+ import torch
6
+
7
+ import datetime
8
+ from omegaconf import OmegaConf
9
+ # from flask import Flask, request
10
+ import json
11
+ from typing import Optional
12
+ import transformers
13
+ from dataclasses import dataclass, field
14
+ import io
15
+ import base64
16
+ from PIL import Image
17
+ import gradio as gr
18
+ import random
19
+ import time
20
+ import hashlib
21
+ import requests
22
+
23
+ from utils import build_logger
24
+ from conversation import conv_seed_vicuna, conv_seed_llama2
25
+ # from conversation import conv_seed_llama
26
+
27
+ IMG_FLAG = '<image>'
28
+
29
+ # request_address = 'http://11.29.21.161:80/generate'
30
+ # request_address = 'http://0.0.0.0:7890/generate'
31
+ LOGDIR = 'log'
32
+
33
+ logger = build_logger("gradio_seed_llama", LOGDIR)
34
+ headers = {"User-Agent": "SEED LLaMA Client"}
35
+
36
+ no_change_btn = gr.Button.update()
37
+ enable_btn = gr.Button.update(interactive=True)
38
+ disable_btn = gr.Button.update(interactive=False)
39
+
40
+ @dataclass
41
+ class Arguments:
42
+ server_port: Optional[int] = field(default=7860, metadata={"help": "network port"})
43
+ server_name: Optional[str] = field(default='0.0.0.0', metadata={"help": "network address"})
44
+ request_address: Optional[str] = field(default='http://127.0.0.1:7890/generate', metadata={"help": "request address"})
45
+ model_type: Optional[str] = field(default='seed-llama-14b', metadata={"help": "choice: [seed-llama-8b, seed-llama-14b]"})
46
+
47
+ parser = transformers.HfArgumentParser(Arguments)
48
+ args, = parser.parse_args_into_dataclasses()
49
+
50
+ if args.model_type == 'seed-llama-8b':
51
+ conv_seed_llama = conv_seed_vicuna
52
+ elif args.model_type == 'seed-llama-14b':
53
+ conv_seed_llama = conv_seed_llama2
54
+ else:
55
+ raise ValueError
56
+
57
+
58
+ def decode_image(encoded_image: str) -> Image:
59
+ decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
60
+ # with io.BytesIO(decoded_bytes) as buffer:
61
+ # image = Image.open(buffer)
62
+ # return image
63
+ buffer = io.BytesIO(decoded_bytes)
64
+ image = Image.open(buffer)
65
+ return image
66
+
67
+
68
+ def encode_image(image: Image.Image, format: str = 'PNG') -> str:
69
+ with io.BytesIO() as buffer:
70
+ image.save(buffer, format=format)
71
+ encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
72
+ return encoded_image
73
+
74
+
75
+ def get_conv_log_filename():
76
+ t = datetime.datetime.now()
77
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
78
+ return name
79
+
80
+
81
+ def get_conv_image_dir():
82
+ name = os.path.join(LOGDIR, 'images')
83
+ os.makedirs(name, exist_ok=True)
84
+ return name
85
+
86
+
87
+ def get_image_name(image, image_dir=None):
88
+ buffer = io.BytesIO()
89
+ image.save(buffer, format='PNG')
90
+ image_bytes = buffer.getvalue()
91
+ md5 = hashlib.md5(image_bytes).hexdigest()
92
+
93
+ if image_dir is not None:
94
+ image_name = os.path.join(image_dir, md5 + '.png')
95
+ else:
96
+ image_name = md5 + '.png'
97
+
98
+ return image_name
99
+
100
+
101
+ def resize_image(image, max_size=512):
102
+ width, height = image.size
103
+ aspect_ratio = float(width) / float(height)
104
+
105
+ if width > height:
106
+ new_width = max_size
107
+ new_height = int(new_width / aspect_ratio)
108
+ else:
109
+ new_height = max_size
110
+ new_width = int(new_height * aspect_ratio)
111
+
112
+ resized_image = image.resize((new_width, new_height))
113
+ return resized_image
114
+
115
+
116
+ def center_crop_image(image, max_aspect_ratio=1.5):
117
+ width, height = image.size
118
+ aspect_ratio = max(width, height) / min(width, height)
119
+
120
+ if aspect_ratio >= max_aspect_ratio:
121
+ if width > height:
122
+ new_width = int(height * max_aspect_ratio)
123
+ left = (width - new_width) // 2
124
+ right = (width + new_width) // 2
125
+ top = 0
126
+ bottom = height
127
+ else:
128
+ new_height = int(width * max_aspect_ratio)
129
+ left = 0
130
+ right = width
131
+ top = (height - new_height) // 2
132
+ bottom = (height + new_height) // 2
133
+
134
+ cropped_image = image.crop((left, top, right, bottom))
135
+ return cropped_image
136
+ else:
137
+ return image
138
+
139
+ def vote_last_response(state, vote_type, request: gr.Request):
140
+ with open(get_conv_log_filename(), "a") as fout:
141
+ data = {
142
+ "tstamp": round(time.time(), 4),
143
+ "type": vote_type,
144
+ "state": state.dict(),
145
+ "ip": request.client.host,
146
+ }
147
+ fout.write(json.dumps(data) + "\n")
148
+
149
+
150
+ def upvote_last_response(state, request: gr.Request):
151
+ logger.info(f"upvote. ip: {request.client.host}")
152
+ vote_last_response(state, "upvote", request)
153
+ return (disable_btn, ) * 2
154
+
155
+
156
+ def downvote_last_response(state, request: gr.Request):
157
+ logger.info(f"downvote. ip: {request.client.host}")
158
+ vote_last_response(state, "downvote", request)
159
+ return (disable_btn, ) * 2
160
+
161
+
162
+ def regenerate(dialog_state, request: gr.Request):
163
+ logger.info(f"regenerate. ip: {request.client.host}")
164
+ if dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
165
+ dialog_state.messages.pop()
166
+ return (
167
+ dialog_state,
168
+ dialog_state.to_gradio_chatbot(),
169
+ ) + (disable_btn, ) * 4
170
+
171
+
172
+ def clear_history(request: gr.Request):
173
+ logger.info(f"clear_history. ip: {request.client.host}")
174
+ # state = None
175
+ # return (state, [], "") + (disable_btn, ) * 5
176
+ dialog_state = conv_seed_llama.copy()
177
+ input_state = init_input_state()
178
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn, ) * 4
179
+
180
+
181
+ def init_input_state():
182
+ return {'images': [], 'text': '', 'images_ids': []}
183
+
184
+
185
+ def add_text(dialog_state, input_state, text, request: gr.Request):
186
+ logger.info(f"add_text. ip: {request.client.host}.")
187
+ # if len(input_state['text']) == 0:
188
+ if text is None or len(text) == 0:
189
+ # dialog_state.skip_next = True
190
+ return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn, ) * 4
191
+ input_state['text'] += text
192
+
193
+ # dialog_state.skip_next = False
194
+
195
+ if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
196
+ dialog_state.messages[-1]['message'] = input_state
197
+ else:
198
+ dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
199
+ print('add_text: ', dialog_state.to_gradio_chatbot())
200
+
201
+ return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn, ) * 4
202
+
203
+
204
+ def add_image(dialog_state, input_state, image, request: gr.Request):
205
+ logger.info(f"add_image. ip: {request.client.host}.")
206
+ if image is None:
207
+ return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn, ) * 4
208
+
209
+ image = image.convert('RGB')
210
+ image = resize_image(image, max_size=512)
211
+ image = center_crop_image(image, max_aspect_ratio=1.3)
212
+ image_dir = get_conv_image_dir()
213
+ image_path = get_image_name(image=image, image_dir=image_dir)
214
+ if not os.path.exists(image_path):
215
+ image.save(image_path)
216
+
217
+ input_state['images'].append(image_path)
218
+ input_state['text'] += IMG_FLAG
219
+ input_state['images_ids'].append(None)
220
+
221
+ if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
222
+ dialog_state.messages[-1]['message'] = input_state
223
+ else:
224
+ dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
225
+
226
+ print('add_image:', dialog_state)
227
+
228
+ return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn, ) * 4
229
+
230
+
231
+ def http_bot_test(dialog_state, input_state, temperature, top_p, max_new_tokens, num_beams, max_turns, force_image_gen, request: gr.Request):
232
+ logger.info(f"http_bot. ip: {request.client.host}")
233
+ output_state = {}
234
+ output_state['text'] = 'This is test for frontend!'
235
+ output_state['images'] = []
236
+ if len(dialog_state.messages) > 0 and len(dialog_state.messages[-1]['message']['images']) != 0:
237
+ image = random.choice(dialog_state.messages[-1]['message']['images'])
238
+ output_state['images'].append(image)
239
+ output_state['text'] += IMG_FLAG
240
+
241
+ dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
242
+ input_state = init_input_state()
243
+
244
+ print('http_bot: ', dialog_state.to_gradio_chatbot())
245
+
246
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (enable_btn, ) * 4
247
+
248
+
249
+ def update_error_msg(chatbot, error_msg):
250
+ if len(error_msg) > 0:
251
+ info = '\n-------------\nSome errors occurred during the response; please clear the history and restart.\n' + '\n'.join(
252
+ error_msg)
253
+ chatbot[-1][-1] = chatbot[-1][-1] + info
254
+
255
+ return chatbot
256
+
257
+
258
+ def http_bot(dialog_state, input_state, temperature, top_p, max_new_tokens, num_beams, max_turns, force_image_gen, request: gr.Request):
259
+ logger.info(f"http_bot. ip: {request.client.host}")
260
+ print('input_state:', input_state)
261
+
262
+ if len(dialog_state.messages) == 0 or dialog_state.messages[-1]['role'] != dialog_state.roles[0] or len(
263
+ dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
264
+ # if len(input_state['text']) == 0:
265
+ # dialog_state.skip_next = True
266
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn, ) * 4
267
+
268
+ if len(dialog_state.messages) > max_turns * 2:
269
+ output_state = init_input_state()
270
+ output_state['text'] = 'Error: History exceeds the maximum number of rounds; please clear the history and restart.'
271
+ dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
272
+ input_state = init_input_state()
273
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn, ) * 3 + (enable_btn, )
274
+
275
+ prompt = dialog_state.get_prompt()
276
+ payload = {
277
+ 'text': prompt['text'],
278
+ 'temperature': float(temperature),
279
+ 'top_p': float(top_p),
280
+ 'max_new_tokens': int(max_new_tokens),
281
+ 'num_beams': int(num_beams),
282
+ 'images': prompt['images'],
283
+ 'force_boi': force_image_gen,
284
+ }
285
+
286
+ print(
287
+ 'request: ', {
288
+ 'text': prompt['text'],
289
+ 'temperature': float(temperature),
290
+ 'top_p': float(top_p),
291
+ 'max_new_tokens': int(max_new_tokens),
292
+ 'num_beams': int(num_beams)
293
+ })
294
+ print('request_address', args.request_address)
295
+ response = requests.request(method="POST", url=args.request_address, headers=headers, json=payload)
296
+ results = response.json()
297
+ print('response: ', {'text': results['text'], 'images_ids': results['images_ids'], 'error_msg': results['error_msg']})
298
+
299
+ output_state = init_input_state()
300
+ image_dir = get_conv_image_dir()
301
+ output_state['text'] = results['text']
302
+
303
+ for image_base64 in results['images']:
304
+ if image_base64 == '':
305
+ image_path = ''
306
+ else:
307
+ image = decode_image(image_base64)
308
+ image = image.convert('RGB')
309
+ image_path = get_image_name(image=image, image_dir=image_dir)
310
+ if not os.path.exists(image_path):
311
+ image.save(image_path)
312
+ output_state['images'].append(image_path)
313
+ output_state['images_ids'].append(None)
314
+
315
+ dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
316
+ dialog_state.update_image_ids(results['images_ids'])
317
+
318
+ vote_last_response(dialog_state, 'common', request)
319
+ input_state = init_input_state()
320
+ chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
321
+ return (dialog_state, input_state, chatbot) + (enable_btn, ) * 4
322
+
323
+
324
+ def load_demo(request: gr.Request):
325
+ logger.info(f"load_demo. ip: {request.client.host}")
326
+ dialog_state = conv_seed_llama.copy()
327
+ input_state = init_input_state()
328
+ return dialog_state, input_state
329
+
330
+
331
+ title = ("""
332
+ # SEED-LLaMA
333
+ [[Project Page]](https://ailab-cvc.github.io/seed/seed_llama.html) [[Paper]](https://arxiv.org/pdf/2310.01218.pdf) [[Code]](https://github.com/AILab-CVC/SEED/tree/main)
334
+
335
+ ## Tips:
336
+ * Check out the conversation examples (at the bottom) for inspiration.
337
+
338
+ * You can adjust "Max History Rounds" to try a conversation with up to five rounds. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference.
339
+
340
+ * Our demo supports a mix of images and text as input. You can freely upload an image or enter text and then click "Add Image"/"Add Text". You can repeat these steps multiple times, then click "Submit" to run model inference.
341
+
342
+ * If you are not satisfied with the output, especially the generated image, you may click on "Regenerate" for another chance.
343
+
344
+ * You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
345
+ * SEED-LLaMA was trained on English-only data. It may handle other languages thanks to the inherent capabilities of LLaMA, but the results might not be stable.
346
+ """)
347
+
348
+ css = """
349
+ img {
350
+ font-family: 'Helvetica';
351
+ font-weight: 300;
352
+ line-height: 2;
353
+ text-align: center;
354
+
355
+ width: auto;
356
+ height: auto;
357
+ display: block;
358
+ position: relative;
359
+ }
360
+
361
+ img:before {
362
+ content: " ";
363
+ display: block;
364
+
365
+ position: absolute;
366
+ top: -10px;
367
+ left: 0;
368
+ height: calc(100% + 10px);
369
+ width: 100%;
370
+ background-color: rgb(230, 230, 230);
371
+ border: 2px dotted rgb(200, 200, 200);
372
+ border-radius: 5px;
373
+ }
374
+
375
+ img:after {
376
+ content: " ";
377
+ display: block;
378
+ font-size: 16px;
379
+ font-style: normal;
380
+ font-family: FontAwesome;
381
+ color: rgb(100, 100, 100);
382
+
383
+ position: absolute;
384
+ top: 5px;
385
+ left: 0;
386
+ width: 100%;
387
+ text-align: center;
388
+ }
389
+
390
+ """
391
+
392
+ if __name__ == '__main__':
393
+
394
+ examples_mix = [
395
+ ['images/cat.jpg', 'Add sunglasses to the animal.'],
396
+ ['images/eagle.jpg', 'Transform this image into cartoon style'],
397
+ [None, 'Generate an image of dog on green grass.'],
398
+ [None, 'Draw a painting of sunflowers in Van Gogh style.'],
399
+ ['images/dogs_4.jpg', 'How many dogs in the image?'],
400
+ ['images/spongebob.png', 'Who are they?'],
401
+ ['images/star.jpg', 'Do you know this painting?'],
402
+ ]
403
+
404
+ examples_conv = [
405
+ ['images/demo_example1.jpg'],
406
+ ['images/demo_example2.jpg'],
407
+ ['images/demo_example3.jpg'],
408
+ ['images/demo_example7.jpg'],
409
+ ['images/demo_example5.jpg'],
410
+ ['images/demo_example6.jpg'],
411
+ ]
412
+
413
+ with gr.Blocks(css=css) as demo:
414
+ gr.Markdown(title)
415
+ dialog_state = gr.State()
416
+ input_state = gr.State()
417
+ with gr.Row():
418
+ with gr.Column(scale=3):
419
+ with gr.Row():
420
+ image = gr.Image(type='pil', label='input_image')
421
+ with gr.Row():
422
+ text = gr.Textbox(lines=5,
423
+ show_label=False,
424
+ label='input_text',
425
+ elem_id='textbox',
426
+ placeholder="Enter text or add an image, then press Submit.").style(container=False)
427
+ with gr.Row():
428
+ add_image_btn = gr.Button("Add Image")
429
+ add_text_btn = gr.Button("Add Text")
430
+
431
+ submit_btn = gr.Button("Submit")
432
+
433
+ with gr.Row():
434
+ num_beams = gr.Slider(minimum=1, maximum=4, value=1, step=1, interactive=True, label="Num of Beams")
435
+ max_new_tokens = gr.Slider(minimum=64,
436
+ maximum=1024,
437
+ value=256,
438
+ step=64,
439
+ interactive=True,
440
+ label="Max New Tokens")
441
+ temperature = gr.Slider(minimum=0.0,
442
+ maximum=1.0,
443
+ value=1.0,
444
+ step=0.1,
445
+ interactive=True,
446
+ label="Temperature")
447
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1, interactive=True, label="Top P")
448
+ max_turns = gr.Slider(minimum=1, maximum=5, value=3, step=1, interactive=True, label="Max History Rounds")
449
+ force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
450
+
451
+ with gr.Column(scale=7):
452
+ chatbot = gr.Chatbot(elem_id='chatbot', label="SEED LLaMA").style(height=700)
453
+ with gr.Row():
454
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
455
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
456
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
457
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
458
+
459
+ # with gr.Row():
460
+ # gr.Examples(examples=examples_image, label='Image examples', inputs=[image])
461
+ with gr.Row():
462
+ # with gr.Column(scale=6):
463
+ gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text])
464
+ # with gr.Column(scale=0.4):
465
+ # gr.Examples(examples=examples_text, inputs=[text])
466
+
467
+
468
+ # with gr.Row():
469
+ # gr.Examples(examples=examples_2, inputs=[image])
470
+
471
+ with gr.Row():
472
+ # gr.Gallery(value=[Image.open(e[0]) for e in examples_conv], show_label=True, label="Example Conversations", elem_id="gallery",height=1400, object_fit='contain').style(grid=[3], height='auto')
473
+ gr.Gallery(value=[Image.open(e[0]) for e in examples_conv], show_label=True, label="Example Conversations", elem_id="gallery",height=1500, columns=[3], rows=[2])
474
+
475
+ # Register listeners
476
+ btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
477
+ upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
478
+ downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
479
+ regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
480
+ http_bot, [dialog_state, input_state, temperature, top_p, max_new_tokens, num_beams, max_turns, force_img_gen],
481
+ [dialog_state, input_state, chatbot] + btn_list)
482
+ add_image_btn.click(add_image, [dialog_state, input_state, image],
483
+ [dialog_state, input_state, image, chatbot] + btn_list)
484
+
485
+ add_text_btn.click(add_text, [dialog_state, input_state, text], [dialog_state, input_state, text, chatbot] + btn_list)
486
+
487
+ submit_btn.click(
488
+ add_image, [dialog_state, input_state, image], [dialog_state, input_state, image, chatbot] + btn_list).then(
489
+ add_text, [dialog_state, input_state, text],
490
+ [dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
491
+ http_bot, [dialog_state, input_state, temperature, top_p, max_new_tokens, num_beams, max_turns, force_img_gen],
492
+ [dialog_state, input_state, chatbot] + btn_list)
493
+ clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
494
+
495
+ demo.load(load_demo, None, [dialog_state, input_state])
496
+
497
+ demo.launch(server_name=args.server_name, server_port=args.server_port, enable_queue=True)
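As a side note, uploaded images never enter the conversation as raw bytes: add_image() above normalizes them and stores them under a content-addressed filename so identical uploads are written only once (the demo additionally center-crops very elongated images to an aspect ratio of at most 1.3). A simplified, self-contained sketch of that path; the input path example.jpg and the output directory log/images are illustrative.

```python
import hashlib
import io
import os

from PIL import Image

image = Image.open('example.jpg').convert('RGB')

# resize_image(): scale so the longer side is 512 px, keeping the aspect ratio.
w, h = image.size
scale = 512 / max(w, h)
image = image.resize((round(w * scale), round(h * scale)))

# get_image_name(): name the file by the MD5 of its PNG bytes (deduplication).
buffer = io.BytesIO()
image.save(buffer, format='PNG')
md5 = hashlib.md5(buffer.getvalue()).hexdigest()

image_dir = os.path.join('log', 'images')
os.makedirs(image_dir, exist_ok=True)
image_path = os.path.join(image_dir, md5 + '.png')
if not os.path.exists(image_path):
    image.save(image_path)
print(image_path)
```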
gradio_demo/utils.py ADDED
@@ -0,0 +1,82 @@
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ handler = None
8
+
9
+
10
+ def build_logger(logger_name, logger_dir):
11
+ global handler
12
+
13
+ formatter = logging.Formatter(
14
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
15
+ datefmt="%Y-%m-%d %H:%M:%S",
16
+ )
17
+
18
+ # Set the format of root handlers
19
+ if not logging.getLogger().handlers:
20
+ logging.basicConfig(level=logging.INFO)
21
+ logging.getLogger().handlers[0].setFormatter(formatter)
22
+
23
+ # Redirect stdout and stderr to loggers
24
+ stdout_logger = logging.getLogger("stdout")
25
+ stdout_logger.setLevel(logging.INFO)
26
+ sl = StreamToLogger(stdout_logger, logging.INFO)
27
+ sys.stdout = sl
28
+
29
+ stderr_logger = logging.getLogger("stderr")
30
+ stderr_logger.setLevel(logging.ERROR)
31
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
32
+ sys.stderr = sl
33
+
34
+ # Get logger
35
+ logger = logging.getLogger(logger_name)
36
+ logger.setLevel(logging.INFO)
37
+
38
+ # Add a file handler for all loggers
39
+ if handler is None:
40
+ os.makedirs(logger_dir, exist_ok=True)
41
+ filename = os.path.join(logger_dir, logger_name + '.log')
42
+ handler = logging.handlers.TimedRotatingFileHandler(filename, when='D', utc=True)
43
+ handler.setFormatter(formatter)
44
+
45
+ for name, item in logging.root.manager.loggerDict.items():
46
+ if isinstance(item, logging.Logger):
47
+ item.addHandler(handler)
48
+
49
+ return logger
50
+
51
+
52
+ class StreamToLogger(object):
53
+ """
54
+ Fake file-like stream object that redirects writes to a logger instance.
55
+ """
56
+ def __init__(self, logger, log_level=logging.INFO):
57
+ self.terminal = sys.stdout
58
+ self.logger = logger
59
+ self.log_level = log_level
60
+ self.linebuf = ''
61
+
62
+ def __getattr__(self, attr):
63
+ return getattr(self.terminal, attr)
64
+
65
+ def write(self, buf):
66
+ temp_linebuf = self.linebuf + buf
67
+ self.linebuf = ''
68
+ for line in temp_linebuf.splitlines(True):
69
+ # From the io.TextIOWrapper docs:
70
+ # On output, if newline is None, any '\n' characters written
71
+ # are translated to the system default line separator.
72
+ # By default sys.stdout.write() expects '\n' newlines and then
73
+ # translates them so this is still cross platform.
74
+ if line[-1] == '\n':
75
+ self.logger.log(self.log_level, line.rstrip())
76
+ else:
77
+ self.linebuf += line
78
+
79
+ def flush(self):
80
+ if self.linebuf != '':
81
+ self.logger.log(self.log_level, self.linebuf.rstrip())
82
+ self.linebuf = ''
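A minimal usage sketch for build_logger(), assuming gradio_demo/ is on sys.path; the logger name and directory are illustrative. After the call, plain print() output is captured too, because sys.stdout and sys.stderr are wrapped in StreamToLogger.

```python
from utils import build_logger

logger = build_logger("example_app", "log")

logger.info("written through the logger")
print("print() is redirected into the 'stdout' logger as well")
# Both lines reach the console handler and log/example_app.log,
# which TimedRotatingFileHandler rotates daily (UTC).
```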
images/cat.jpg ADDED

Git LFS Details

  • SHA256: 0e0c19a34a640f9b83514d243fc6135d9afcb8c7de08d58a90aa35e0684237a0
  • Pointer size: 130 Bytes
  • Size of remote file: 59.7 kB
images/demo_example1.jpg ADDED

Git LFS Details

  • SHA256: f3f74364c2ad5461a611b3c71cb1e61b5579e648870a24075e5cf59705ff66e8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.89 MB
images/demo_example2.jpg ADDED

Git LFS Details

  • SHA256: 37a33958f00f8f70cabcdd29cc401b5403b43fc95263db997faf6dd96cfaad62
  • Pointer size: 132 Bytes
  • Size of remote file: 2.71 MB
images/demo_example3.jpg ADDED

Git LFS Details

  • SHA256: 10a845ab517d70d891124f247aa270ccb27428681cf53044721c4cceee097b94
  • Pointer size: 132 Bytes
  • Size of remote file: 4.02 MB
images/demo_example4.jpg ADDED

Git LFS Details

  • SHA256: a855bbc69f23d66b9e0a5e07b7477886a6400c8879a336834e58a5b8beef2109
  • Pointer size: 132 Bytes
  • Size of remote file: 4.37 MB
images/demo_example5.jpg ADDED

Git LFS Details

  • SHA256: efc93e5ce3380dec20d51cc40c85a5ace973357579c947c7182e3caec0ac521a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.09 MB
images/demo_example6.jpg ADDED

Git LFS Details

  • SHA256: 5be59761291a66cedbea829181ec7b759785e33d2f590c1c91c3b1f599426a34
  • Pointer size: 132 Bytes
  • Size of remote file: 3.64 MB
images/demo_example7.jpg ADDED

Git LFS Details

  • SHA256: 969246370751053970d8d37bb2da74540eefd9f3fd3d57798ea0409edd4b563c
  • Pointer size: 132 Bytes
  • Size of remote file: 3.44 MB
images/dogs_4.jpg ADDED

Git LFS Details

  • SHA256: 1e6acbeb446426d5e2ef1a01f9178760d2d97e6e76cef44bc46c925eecef37a0
  • Pointer size: 131 Bytes
  • Size of remote file: 274 kB
images/eagle.jpg ADDED

Git LFS Details

  • SHA256: 37c7669dfbd066afcc2535d1672da31d7050179fb0200f1160b1b1a97416a06a
  • Pointer size: 130 Bytes
  • Size of remote file: 90 kB
images/flower.png ADDED

Git LFS Details

  • SHA256: 70097caecb5a023cf876060c5449a75dead9d8d92777e5a458b2a661a653bab4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.72 MB
images/spongebob.png ADDED

Git LFS Details

  • SHA256: 45e14eaea4cd27702574cf6ad31f0aeaf513527a04092bebb2a454585be9fd86
  • Pointer size: 132 Bytes
  • Size of remote file: 1.98 MB
images/star.jpg ADDED

Git LFS Details

  • SHA256: e57b6fb40aecfd6407d4e919c60213e254c7b0a7641f64de6a9443e9c3179cdb
  • Pointer size: 131 Bytes
  • Size of remote file: 658 kB
models/__init__.py ADDED
File without changes
models/llama_xformer.py ADDED
@@ -0,0 +1,906 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPast,
31
+ CausalLMOutputWithPast,
32
+ SequenceClassifierOutputWithPast,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.utils import (
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ logging,
39
+ replace_return_docstrings,
40
+ )
41
+ from transformers.models.llama.configuration_llama import LlamaConfig
42
+ import xformers.ops as xops
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CONFIG_FOR_DOC = "LlamaConfig"
47
+
48
+
49
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
50
+ def _make_causal_mask(
51
+ input_ids_shape: torch.Size,
52
+ dtype: torch.dtype,
53
+ device: torch.device,
54
+ past_key_values_length: int = 0,
55
+ ):
56
+ """
57
+ Make causal mask used for bi-directional self-attention.
58
+ """
59
+ bsz, tgt_len = input_ids_shape
60
+ mask = torch.full(
61
+ (tgt_len, tgt_len),
62
+ torch.tensor(torch.finfo(dtype).min, device=device),
63
+ device=device,
64
+ )
65
+ mask_cond = torch.arange(mask.size(-1), device=device)
66
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
67
+ mask = mask.to(dtype)
68
+
69
+ if past_key_values_length > 0:
70
+ mask = torch.cat(
71
+ [
72
+ torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
73
+ mask,
74
+ ],
75
+ dim=-1,
76
+ )
77
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
78
+
79
+
80
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
81
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
82
+ """
83
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
84
+ """
85
+ bsz, src_len = mask.size()
86
+ tgt_len = tgt_len if tgt_len is not None else src_len
87
+
88
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
89
+
90
+ inverted_mask = 1.0 - expanded_mask
91
+
92
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
93
+
94
+
95
+ class LlamaRMSNorm(nn.Module):
96
+
97
+ def __init__(self, hidden_size, eps=1e-6):
98
+ """
99
+ LlamaRMSNorm is equivalent to T5LayerNorm
100
+ """
101
+ super().__init__()
102
+ self.weight = nn.Parameter(torch.ones(hidden_size))
103
+ self.variance_epsilon = eps
104
+
105
+ def forward(self, hidden_states):
106
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
107
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
108
+
109
+ # convert into half-precision if necessary
110
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
111
+ hidden_states = hidden_states.to(self.weight.dtype)
112
+
113
+ return self.weight * hidden_states
114
+
115
+
116
+ class LlamaRotaryEmbedding(torch.nn.Module):
117
+
118
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
119
+ super().__init__()
120
+ inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float().to(device) / dim))
121
+ self.register_buffer("inv_freq", inv_freq)
122
+
123
+ # Build here to make `torch.jit.trace` work.
124
+ self.max_seq_len_cached = max_position_embeddings
125
+ t = torch.arange(
126
+ self.max_seq_len_cached,
127
+ device=self.inv_freq.device,
128
+ dtype=self.inv_freq.dtype,
129
+ )
130
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
131
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
132
+ emb = torch.cat((freqs, freqs), dim=-1)
133
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
134
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
135
+
136
+ def forward(self, x, seq_len=None):
137
+ # x: [bs, num_attention_heads, seq_len, head_size]
138
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
139
+ if seq_len > self.max_seq_len_cached:
140
+ self.max_seq_len_cached = seq_len
141
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
142
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
143
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
144
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
145
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
146
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
147
+ return (
148
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
149
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
150
+ )
151
+
152
+
153
+ def rotate_half(x):
154
+ """Rotates half the hidden dims of the input."""
155
+ x1 = x[..., :x.shape[-1] // 2]
156
+ x2 = x[..., x.shape[-1] // 2:]
157
+ return torch.cat((-x2, x1), dim=-1)
158
+
159
+
160
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
161
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
162
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
163
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
164
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
165
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
166
+ q_embed = (q * cos) + (rotate_half(q) * sin)
167
+ k_embed = (k * cos) + (rotate_half(k) * sin)
168
+ return q_embed, k_embed
169
+
170
+
171
+ class LlamaMLP(nn.Module):
172
+
173
+ def __init__(
174
+ self,
175
+ hidden_size: int,
176
+ intermediate_size: int,
177
+ hidden_act: str,
178
+ ):
179
+ super().__init__()
180
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
181
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
182
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
183
+ self.act_fn = ACT2FN[hidden_act]
184
+
185
+ def forward(self, x):
186
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
187
+
188
+
189
+ class LlamaAttention(nn.Module):
190
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
191
+
192
+ def __init__(self, config: LlamaConfig):
193
+ super().__init__()
194
+ self.config = config
195
+ self.hidden_size = config.hidden_size
196
+ self.num_heads = config.num_attention_heads
197
+ self.head_dim = self.hidden_size // self.num_heads
198
+ self.max_position_embeddings = config.max_position_embeddings
199
+
200
+ if (self.head_dim * self.num_heads) != self.hidden_size:
201
+ raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
202
+ f" and `num_heads`: {self.num_heads}).")
203
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
204
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
205
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
206
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
207
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
208
+
209
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
210
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states: torch.Tensor,
215
+ attention_mask: Optional[torch.Tensor] = None,
216
+ position_ids: Optional[torch.LongTensor] = None,
217
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
218
+ output_attentions: bool = False,
219
+ use_cache: bool = False,
220
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
221
+ bsz, q_len, _ = hidden_states.size()
222
+
223
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
224
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
225
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
226
+
227
+ kv_seq_len = key_states.shape[-2]
228
+ if past_key_value is not None:
229
+ kv_seq_len += past_key_value[0].shape[-2]
230
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
231
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
232
+ # [bsz, nh, t, hd]
233
+
234
+ if past_key_value is not None:
235
+ # reuse k, v, self_attention
236
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
237
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
238
+
239
+ past_key_value = (key_states, value_states) if use_cache else None
240
+ query_states = query_states.transpose(1, 2)
241
+ key_states = key_states.transpose(1, 2)
242
+ value_states = value_states.transpose(1, 2)
243
+ if self.training:
244
+ attn_output = xops.memory_efficient_attention(
245
+ query_states,
246
+ key_states,
247
+ value_states,
248
+ attn_bias=xops.LowerTriangularMask(),
249
+ )
250
+ else:
251
+ attn_output = xops.memory_efficient_attention(
252
+ query_states,
253
+ key_states,
254
+ value_states,
255
+ attn_bias=None if attention_mask.sum() == 0 else xops.LowerTriangularMask(),
256
+ )
257
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
258
+ attn_output = self.o_proj(attn_output)
259
+
260
+ if not output_attentions:
261
+ attn_weights = None
262
+
263
+ return attn_output, attn_weights, past_key_value
264
+
265
+
266
+ class LlamaDecoderLayer(nn.Module):
267
+
268
+ def __init__(self, config: LlamaConfig):
269
+ super().__init__()
270
+ self.hidden_size = config.hidden_size
271
+ self.self_attn = LlamaAttention(config=config)
272
+ self.mlp = LlamaMLP(
273
+ hidden_size=self.hidden_size,
274
+ intermediate_size=config.intermediate_size,
275
+ hidden_act=config.hidden_act,
276
+ )
277
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
278
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
279
+
280
+ def forward(
281
+ self,
282
+ hidden_states: torch.Tensor,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ position_ids: Optional[torch.LongTensor] = None,
285
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
286
+ output_attentions: Optional[bool] = False,
287
+ use_cache: Optional[bool] = False,
288
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
289
+ """
290
+ Args:
291
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
292
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
293
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
294
+ output_attentions (`bool`, *optional*):
295
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
296
+ returned tensors for more detail.
297
+ use_cache (`bool`, *optional*):
298
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
299
+ (see `past_key_values`).
300
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
301
+ """
302
+
303
+ residual = hidden_states
304
+
305
+ hidden_states = self.input_layernorm(hidden_states)
306
+
307
+ # Self Attention
308
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
309
+ hidden_states=hidden_states,
310
+ attention_mask=attention_mask,
311
+ position_ids=position_ids,
312
+ past_key_value=past_key_value,
313
+ output_attentions=output_attentions,
314
+ use_cache=use_cache,
315
+ )
316
+ hidden_states = residual + hidden_states
317
+
318
+ # Fully Connected
319
+ residual = hidden_states
320
+ hidden_states = self.post_attention_layernorm(hidden_states)
321
+ hidden_states = self.mlp(hidden_states)
322
+ hidden_states = residual + hidden_states
323
+
324
+ outputs = (hidden_states, )
325
+
326
+ if output_attentions:
327
+ outputs += (self_attn_weights, )
328
+
329
+ if use_cache:
330
+ outputs += (present_key_value, )
331
+
332
+ return outputs
333
+
334
+
335
+ LLAMA_START_DOCSTRING = r"""
336
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
337
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
338
+ etc.)
339
+
340
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
341
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
342
+ and behavior.
343
+
344
+ Parameters:
345
+ config ([`LlamaConfig`]):
346
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
347
+ load the weights associated with the model, only the configuration. Check out the
348
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
349
+ """
350
+
351
+
352
+ @add_start_docstrings(
353
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
354
+ LLAMA_START_DOCSTRING,
355
+ )
356
+ class LlamaPreTrainedModel(PreTrainedModel):
357
+ config_class = LlamaConfig
358
+ base_model_prefix = "model"
359
+ supports_gradient_checkpointing = True
360
+ _no_split_modules = ["LlamaDecoderLayer"]
361
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
362
+
363
+ def _init_weights(self, module):
364
+ std = self.config.initializer_range
365
+ if isinstance(module, nn.Linear):
366
+ module.weight.data.normal_(mean=0.0, std=std)
367
+ if module.bias is not None:
368
+ module.bias.data.zero_()
369
+ elif isinstance(module, nn.Embedding):
370
+ module.weight.data.normal_(mean=0.0, std=std)
371
+ if module.padding_idx is not None:
372
+ module.weight.data[module.padding_idx].zero_()
373
+
374
+ def _set_gradient_checkpointing(self, module, value=False):
375
+ if isinstance(module, LlamaModel):
376
+ module.gradient_checkpointing = value
377
+
378
+
379
+ LLAMA_INPUTS_DOCSTRING = r"""
380
+ Args:
381
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
382
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
383
+ it.
384
+
385
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
386
+ [`PreTrainedTokenizer.__call__`] for details.
387
+
388
+ [What are input IDs?](../glossary#input-ids)
389
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
390
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
391
+
392
+ - 1 for tokens that are **not masked**,
393
+ - 0 for tokens that are **masked**.
394
+
395
+ [What are attention masks?](../glossary#attention-mask)
396
+
397
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
398
+ [`PreTrainedTokenizer.__call__`] for details.
399
+
400
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
401
+ `past_key_values`).
402
+
403
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
404
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
405
+ information on the default strategy.
406
+
407
+ - 1 indicates the head is **not masked**,
408
+ - 0 indicates the head is **masked**.
409
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
410
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
411
+ config.n_positions - 1]`.
412
+
413
+ [What are position IDs?](../glossary#position-ids)
414
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
415
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
416
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
417
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
418
+
419
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
420
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
421
+
422
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
423
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
424
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
425
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
426
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
427
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
428
+ model's internal embedding lookup matrix.
429
+ use_cache (`bool`, *optional*):
430
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
431
+ `past_key_values`).
432
+ output_attentions (`bool`, *optional*):
433
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
434
+ tensors for more detail.
435
+ output_hidden_states (`bool`, *optional*):
436
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
437
+ more detail.
438
+ return_dict (`bool`, *optional*):
439
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
440
+ """
441
+
442
+
443
+ @add_start_docstrings(
444
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
445
+ LLAMA_START_DOCSTRING,
446
+ )
447
+ class LlamaModel(LlamaPreTrainedModel):
448
+ """
449
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
450
+
451
+ Args:
452
+ config: LlamaConfig
453
+ """
454
+
455
+ def __init__(self, config: LlamaConfig):
456
+ super().__init__(config)
457
+ self.padding_idx = config.pad_token_id
458
+ self.vocab_size = config.vocab_size
459
+
460
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
461
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
462
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
463
+
464
+ self.gradient_checkpointing = False
465
+ # Initialize weights and apply final processing
466
+ self.post_init()
467
+
468
+ def get_input_embeddings(self):
469
+ return self.embed_tokens
470
+
471
+ def set_input_embeddings(self, value):
472
+ self.embed_tokens = value
473
+
474
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
475
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
476
+ # create causal mask
477
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
478
+ combined_attention_mask = None
479
+ if input_shape[-1] > 1:
480
+ combined_attention_mask = _make_causal_mask(
481
+ input_shape,
482
+ inputs_embeds.dtype,
483
+ device=inputs_embeds.device,
484
+ past_key_values_length=past_key_values_length,
485
+ )
486
+
487
+ if attention_mask is not None:
488
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
489
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
490
+ tgt_len=input_shape[-1]).to(inputs_embeds.device)
491
+ combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
492
+
493
+ return combined_attention_mask
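A self-contained illustration (not the `transformers` helpers referenced above) of what the causal part of this mask looks like for a short sequence: positions to the right of the diagonal receive a large negative value so the attention softmax effectively ignores them.

```python
import torch

seq_len = 4
min_value = torch.finfo(torch.float32).min
causal_mask = torch.triu(torch.full((seq_len, seq_len), min_value), diagonal=1)
# row i may attend to columns 0..i; entries above the diagonal are ~-inf
# Broadcast to [bsz, 1, tgt_seq_len, src_seq_len] and added to the expanded padding
# mask, this matches the shape produced by _prepare_decoder_attention_mask.
print(causal_mask)
```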
494
+
495
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
496
+ def forward(
497
+ self,
498
+ input_ids: torch.LongTensor = None,
499
+ attention_mask: Optional[torch.Tensor] = None,
500
+ position_ids: Optional[torch.LongTensor] = None,
501
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
502
+ inputs_embeds: Optional[torch.FloatTensor] = None,
503
+ use_cache: Optional[bool] = None,
504
+ output_attentions: Optional[bool] = None,
505
+ output_hidden_states: Optional[bool] = None,
506
+ return_dict: Optional[bool] = None,
507
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
508
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
509
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
510
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
511
+
512
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
513
+
514
+ # retrieve input_ids and inputs_embeds
515
+ if input_ids is not None and inputs_embeds is not None:
516
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
517
+ elif input_ids is not None:
518
+ batch_size, seq_length = input_ids.shape
519
+ elif inputs_embeds is not None:
520
+ batch_size, seq_length, _ = inputs_embeds.shape
521
+ else:
522
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
523
+
524
+ seq_length_with_past = seq_length
525
+ past_key_values_length = 0
526
+
527
+ if past_key_values is not None:
528
+ past_key_values_length = past_key_values[0][0].shape[2]
529
+ seq_length_with_past = seq_length_with_past + past_key_values_length
530
+
531
+ if position_ids is None:
532
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
533
+ position_ids = torch.arange(
534
+ past_key_values_length,
535
+ seq_length + past_key_values_length,
536
+ dtype=torch.long,
537
+ device=device,
538
+ )
539
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
540
+ else:
541
+ position_ids = position_ids.view(-1, seq_length).long()
542
+
543
+ if inputs_embeds is None:
544
+ inputs_embeds = self.embed_tokens(input_ids)
545
+ # embed positions
546
+ if attention_mask is None:
547
+ attention_mask = torch.ones(
548
+ (batch_size, seq_length_with_past),
549
+ dtype=torch.bool,
550
+ device=inputs_embeds.device,
551
+ )
552
+ attention_mask = self._prepare_decoder_attention_mask(
553
+ attention_mask,
554
+ (batch_size, seq_length),
555
+ inputs_embeds,
556
+ past_key_values_length,
557
+ )
558
+
559
+ hidden_states = inputs_embeds
560
+
561
+ if self.gradient_checkpointing and self.training:
562
+ if use_cache:
563
+ logger.warning_once(
564
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
565
+ use_cache = False
566
+
567
+ # decoder layers
568
+ all_hidden_states = () if output_hidden_states else None
569
+ all_self_attns = () if output_attentions else None
570
+ next_decoder_cache = () if use_cache else None
571
+
572
+ for idx, decoder_layer in enumerate(self.layers):
573
+ if output_hidden_states:
574
+ all_hidden_states += (hidden_states, )
575
+
576
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
577
+
578
+ if self.gradient_checkpointing and self.training:
579
+
580
+ def create_custom_forward(module):
581
+
582
+ def custom_forward(*inputs):
583
+ # None for past_key_value
584
+ return module(*inputs, output_attentions, None)
585
+
586
+ return custom_forward
587
+
588
+ layer_outputs = torch.utils.checkpoint.checkpoint(
589
+ create_custom_forward(decoder_layer),
590
+ hidden_states,
591
+ attention_mask,
592
+ position_ids,
593
+ None,
594
+ )
595
+ else:
596
+ layer_outputs = decoder_layer(
597
+ hidden_states,
598
+ attention_mask=attention_mask,
599
+ position_ids=position_ids,
600
+ past_key_value=past_key_value,
601
+ output_attentions=output_attentions,
602
+ use_cache=use_cache,
603
+ )
604
+
605
+ hidden_states = layer_outputs[0]
606
+
607
+ if use_cache:
608
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1], )
609
+
610
+ if output_attentions:
611
+ all_self_attns += (layer_outputs[1], )
612
+
613
+ hidden_states = self.norm(hidden_states)
614
+
615
+ # add hidden states from the last decoder layer
616
+ if output_hidden_states:
617
+ all_hidden_states += (hidden_states, )
618
+
619
+ next_cache = next_decoder_cache if use_cache else None
620
+ if not return_dict:
621
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
622
+ return BaseModelOutputWithPast(
623
+ last_hidden_state=hidden_states,
624
+ past_key_values=next_cache,
625
+ hidden_states=all_hidden_states,
626
+ attentions=all_self_attns,
627
+ )
628
+
629
+
630
+ class LlamaForCausalLM(LlamaPreTrainedModel):
631
+
632
+ def __init__(self, config):
633
+ super().__init__(config)
634
+ self.model = LlamaModel(config)
635
+
636
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
637
+
638
+ # Initialize weights and apply final processing
639
+ self.post_init()
640
+
641
+ def get_input_embeddings(self):
642
+ return self.model.embed_tokens
643
+
644
+ def set_input_embeddings(self, value):
645
+ self.model.embed_tokens = value
646
+
647
+ def get_output_embeddings(self):
648
+ return self.lm_head
649
+
650
+ def set_output_embeddings(self, new_embeddings):
651
+ self.lm_head = new_embeddings
652
+
653
+ def set_decoder(self, decoder):
654
+ self.model = decoder
655
+
656
+ def get_decoder(self):
657
+ return self.model
658
+
659
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
660
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
661
+ def forward(
662
+ self,
663
+ input_ids: torch.LongTensor = None,
664
+ attention_mask: Optional[torch.Tensor] = None,
665
+ position_ids: Optional[torch.LongTensor] = None,
666
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
667
+ inputs_embeds: Optional[torch.FloatTensor] = None,
668
+ labels: Optional[torch.LongTensor] = None,
669
+ use_cache: Optional[bool] = None,
670
+ output_attentions: Optional[bool] = None,
671
+ output_hidden_states: Optional[bool] = None,
672
+ return_dict: Optional[bool] = None,
673
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
674
+ r"""
675
+ Args:
676
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
677
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
678
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
679
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
680
+
681
+ Returns:
682
+
683
+ Example:
684
+
685
+ ```python
686
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
687
+
688
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
689
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
690
+
691
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
692
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
693
+
694
+ >>> # Generate
695
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
696
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
697
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
698
+ ```"""
699
+
700
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
701
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
702
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
703
+
704
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
705
+ outputs = self.model(
706
+ input_ids=input_ids,
707
+ attention_mask=attention_mask,
708
+ position_ids=position_ids,
709
+ past_key_values=past_key_values,
710
+ inputs_embeds=inputs_embeds,
711
+ use_cache=use_cache,
712
+ output_attentions=output_attentions,
713
+ output_hidden_states=output_hidden_states,
714
+ return_dict=return_dict,
715
+ )
716
+
717
+ hidden_states = outputs[0]
718
+ logits = self.lm_head(hidden_states)
719
+
720
+ loss = None
721
+ if labels is not None:
722
+ # Shift so that tokens < n predict n
723
+ shift_logits = logits[..., :-1, :].contiguous()
724
+ shift_labels = labels[..., 1:].contiguous()
725
+ # Flatten the tokens
726
+ loss_fct = CrossEntropyLoss()
727
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
728
+ shift_labels = shift_labels.view(-1)
729
+ # Enable model parallelism
730
+ shift_labels = shift_labels.to(shift_logits.device)
731
+ loss = loss_fct(shift_logits, shift_labels)
732
+
733
+ if not return_dict:
734
+ output = (logits, ) + outputs[1:]
735
+ return (loss, ) + output if loss is not None else output
736
+
737
+ return CausalLMOutputWithPast(
738
+ loss=loss,
739
+ logits=logits,
740
+ past_key_values=outputs.past_key_values,
741
+ hidden_states=outputs.hidden_states,
742
+ attentions=outputs.attentions,
743
+ )
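A small self-contained check of the shift-by-one loss computed above (random logits and labels; `-100` marks a position excluded from the loss):

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq))
labels[0, -1] = -100                              # ignored by CrossEntropyLoss

shift_logits = logits[..., :-1, :].contiguous()   # prediction at position t ...
shift_labels = labels[..., 1:].contiguous()       # ... is scored against token t+1
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)
```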
744
+
745
+ def prepare_inputs_for_generation(
746
+ self,
747
+ input_ids,
748
+ past_key_values=None,
749
+ attention_mask=None,
750
+ inputs_embeds=None,
751
+ **kwargs,
752
+ ):
753
+ if past_key_values:
754
+ input_ids = input_ids[:, -1:]
755
+
756
+ position_ids = kwargs.get("position_ids", None)
757
+ if attention_mask is not None and position_ids is None:
758
+ # create position_ids on the fly for batch generation
759
+ position_ids = attention_mask.long().cumsum(-1) - 1
760
+ position_ids.masked_fill_(attention_mask == 0, 1)
761
+ if past_key_values:
762
+ position_ids = position_ids[:, -1].unsqueeze(-1)
763
+
764
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
765
+ if inputs_embeds is not None and past_key_values is None:
766
+ model_inputs = {"inputs_embeds": inputs_embeds}
767
+ else:
768
+ model_inputs = {"input_ids": input_ids}
769
+
770
+ model_inputs.update({
771
+ "position_ids": position_ids,
772
+ "past_key_values": past_key_values,
773
+ "use_cache": kwargs.get("use_cache"),
774
+ "attention_mask": attention_mask,
775
+ })
776
+ return model_inputs
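The on-the-fly `position_ids` above, traced on a left-padded toy batch: the cumulative sum of the mask assigns positions to real tokens, and padded slots are clamped to 1.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```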
777
+
778
+ @staticmethod
779
+ def _reorder_cache(past_key_values, beam_idx):
780
+ reordered_past = ()
781
+ for layer_past in past_key_values:
782
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), )
783
+ return reordered_past
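A toy check of `_reorder_cache`: `index_select(0, beam_idx)` simply re-shuffles the batch dimension of each cached key/value tensor so the cache follows the surviving beams.

```python
import torch

# one layer's cache: (key, value), each of shape [batch, heads, seq, head_dim]
layer_past = (torch.arange(3.0).view(3, 1, 1, 1),) * 2
beam_idx = torch.tensor([2, 2, 0])   # beams 0 and 1 are replaced by copies of beam 2
reordered = tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
print(reordered[0].flatten())        # tensor([2., 2., 0.])
```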
784
+
785
+
786
+ @add_start_docstrings(
787
+ """
788
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
789
+
790
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
791
+ (e.g. GPT-2) do.
792
+
793
+ Since it does classification on the last token, it requires to know the position of the last token. If a
794
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
795
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
796
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
797
+ each row of the batch).
798
+ """,
799
+ LLAMA_START_DOCSTRING,
800
+ )
801
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
802
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
803
+
804
+ def __init__(self, config):
805
+ super().__init__(config)
806
+ self.num_labels = config.num_labels
807
+ self.model = LlamaModel(config)
808
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
809
+
810
+ # Initialize weights and apply final processing
811
+ self.post_init()
812
+
813
+ def get_input_embeddings(self):
814
+ return self.model.embed_tokens
815
+
816
+ def set_input_embeddings(self, value):
817
+ self.model.embed_tokens = value
818
+
819
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
820
+ def forward(
821
+ self,
822
+ input_ids: torch.LongTensor = None,
823
+ attention_mask: Optional[torch.Tensor] = None,
824
+ position_ids: Optional[torch.LongTensor] = None,
825
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
826
+ inputs_embeds: Optional[torch.FloatTensor] = None,
827
+ labels: Optional[torch.LongTensor] = None,
828
+ use_cache: Optional[bool] = None,
829
+ output_attentions: Optional[bool] = None,
830
+ output_hidden_states: Optional[bool] = None,
831
+ return_dict: Optional[bool] = None,
832
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
833
+ r"""
834
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
835
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
836
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
837
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
838
+ """
839
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
840
+
841
+ transformer_outputs = self.model(
842
+ input_ids,
843
+ attention_mask=attention_mask,
844
+ position_ids=position_ids,
845
+ past_key_values=past_key_values,
846
+ inputs_embeds=inputs_embeds,
847
+ use_cache=use_cache,
848
+ output_attentions=output_attentions,
849
+ output_hidden_states=output_hidden_states,
850
+ return_dict=return_dict,
851
+ )
852
+ hidden_states = transformer_outputs[0]
853
+ logits = self.score(hidden_states)
854
+
855
+ if input_ids is not None:
856
+ batch_size = input_ids.shape[0]
857
+ else:
858
+ batch_size = inputs_embeds.shape[0]
859
+
860
+ if self.config.pad_token_id is None and batch_size != 1:
861
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
862
+ if self.config.pad_token_id is None:
863
+ sequence_lengths = -1
864
+ else:
865
+ if input_ids is not None:
866
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
867
+ else:
868
+ sequence_lengths = -1
869
+
870
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
871
+
872
+ loss = None
873
+ if labels is not None:
874
+ labels = labels.to(logits.device)
875
+ if self.config.problem_type is None:
876
+ if self.num_labels == 1:
877
+ self.config.problem_type = "regression"
878
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
879
+ self.config.problem_type = "single_label_classification"
880
+ else:
881
+ self.config.problem_type = "multi_label_classification"
882
+
883
+ if self.config.problem_type == "regression":
884
+ loss_fct = MSELoss()
885
+ if self.num_labels == 1:
886
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
887
+ else:
888
+ loss = loss_fct(pooled_logits, labels)
889
+ elif self.config.problem_type == "single_label_classification":
890
+ loss_fct = CrossEntropyLoss()
891
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
892
+ elif self.config.problem_type == "multi_label_classification":
893
+ loss_fct = BCEWithLogitsLoss()
894
+ loss = loss_fct(pooled_logits, labels)
895
+ if not return_dict:
896
+ output = (pooled_logits, ) + transformer_outputs[1:]
897
+ return ((loss, ) + output) if loss is not None else output
898
+
899
+ return SequenceClassifierOutputWithPast(
900
+ loss=loss,
901
+ logits=pooled_logits,
902
+ past_key_values=transformer_outputs.past_key_values,
903
+ hidden_states=transformer_outputs.hidden_states,
904
+ attentions=transformer_outputs.attentions,
905
+ )
906
+
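The last-non-padding-token pooling described in the class docstring above, spelled out on a toy batch (pad id assumed to be 0 here):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 9, 3, 0, 0],
                          [7, 2, 8, 4, 6]])
logits = torch.randn(2, 5, 3)                                      # [batch, seq, num_labels]
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1   # tensor([2, 4])
pooled_logits = logits[torch.arange(2), sequence_lengths]          # logits at the last real token
print(pooled_logits.shape)                                         # torch.Size([2, 3])
```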
models/model_tools.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+ from .llama_xformer import LlamaForCausalLM
3
+
4
+
5
+ def get_pretrained_llama_causal_model(pretrained_model_name_or_path=None, torch_dtype='fp16', **kwargs):
6
+ if torch_dtype == 'fp16' or torch_dtype == 'float16':
7
+ torch_dtype = torch.float16
8
+ elif torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
9
+ torch_dtype = torch.bfloat16
10
+ else:
11
+ torch_dtype = torch.float32  # default to full precision
12
+ model = LlamaForCausalLM.from_pretrained(
13
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
14
+ torch_dtype=torch_dtype,
15
+ **kwargs,
16
+ )
17
+
18
+ return model
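A minimal usage sketch for the helper above; the checkpoint path is a placeholder and `low_cpu_mem_usage` is just one example of a kwarg forwarded to `from_pretrained`:

```python
from models.model_tools import get_pretrained_llama_causal_model  # assumes the repo root is on PYTHONPATH

model = get_pretrained_llama_causal_model(
    pretrained_model_name_or_path="path/to/seed-llama-checkpoint",  # placeholder path
    torch_dtype="bf16",
    low_cpu_mem_usage=True,
)
model = model.eval().cuda()
```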
models/pipeline_stable_unclip_img2img.py ADDED
@@ -0,0 +1,794 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import warnings
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+
19
+ import PIL
20
+ import torch
21
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
22
+
23
+ from diffusers.utils.import_utils import is_accelerate_available
24
+
25
+ from diffusers.image_processor import VaeImageProcessor
27
+
28
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
29
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
30
+ from diffusers.models.embeddings import get_timestep_embedding
31
+ from diffusers.schedulers import KarrasDiffusionSchedulers
32
+ from diffusers.utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring
33
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
34
+ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+ EXAMPLE_DOC_STRING = """
39
+ Examples:
40
+ ```py
41
+ >>> import requests
42
+ >>> import torch
43
+ >>> from PIL import Image
44
+ >>> from io import BytesIO
45
+
46
+ >>> from diffusers import StableUnCLIPImg2ImgPipeline
47
+
48
+ >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
49
+ ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
50
+ ... ) # TODO update model path
51
+ >>> pipe = pipe.to("cuda")
52
+
53
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
54
+
55
+ >>> response = requests.get(url)
56
+ >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
57
+ >>> init_image = init_image.resize((768, 512))
58
+
59
+ >>> prompt = "A fantasy landscape, trending on artstation"
60
+
61
+ >>> images = pipe(prompt, init_image).images
62
+ >>> images[0].save("fantasy_landscape.png")
63
+ ```
64
+ """
65
+
66
+
67
+ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
68
+ """
69
+ Pipeline for text-guided image-to-image generation using stable unCLIP.
70
+
71
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
72
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
73
+
74
+ Args:
75
+ feature_extractor ([`CLIPImageProcessor`]):
76
+ Feature extractor for image pre-processing before being encoded.
77
+ image_encoder ([`CLIPVisionModelWithProjection`]):
78
+ CLIP vision model for encoding images.
79
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
80
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
81
+ embeddings after the noise has been applied.
82
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
83
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
84
+ by the `noise_level`.
85
+ tokenizer (`~transformers.CLIPTokenizer`):
86
+ A [`~transformers.CLIPTokenizer`].
87
+ text_encoder ([`~transformers.CLIPTextModel`]):
88
+ Frozen [`~transformers.CLIPTextModel`] text-encoder.
89
+ unet ([`UNet2DConditionModel`]):
90
+ A [`UNet2DConditionModel`] to denoise the encoded image latents.
91
+ scheduler ([`KarrasDiffusionSchedulers`]):
92
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
93
+ vae ([`AutoencoderKL`]):
94
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
95
+ """
96
+
97
+ _exclude_from_cpu_offload = ["image_normalizer"]
98
+
99
+ # image encoding components
100
+ feature_extractor: CLIPImageProcessor
101
+ image_encoder: CLIPVisionModelWithProjection
102
+
103
+ # image noising components
104
+ image_normalizer: StableUnCLIPImageNormalizer
105
+ image_noising_scheduler: KarrasDiffusionSchedulers
106
+
107
+ # regular denoising components
108
+ tokenizer: CLIPTokenizer
109
+ text_encoder: CLIPTextModel
110
+ unet: UNet2DConditionModel
111
+ scheduler: KarrasDiffusionSchedulers
112
+
113
+ vae: AutoencoderKL
114
+
115
+ def __init__(
116
+ self,
117
+ # image encoding components
118
+ feature_extractor: CLIPImageProcessor,
119
+ image_encoder: CLIPVisionModelWithProjection,
120
+ # image noising components
121
+ image_normalizer: StableUnCLIPImageNormalizer,
122
+ image_noising_scheduler: KarrasDiffusionSchedulers,
123
+ # regular denoising components
124
+ tokenizer: CLIPTokenizer,
125
+ text_encoder: CLIPTextModel,
126
+ unet: UNet2DConditionModel,
127
+ scheduler: KarrasDiffusionSchedulers,
128
+ # vae
129
+ vae: AutoencoderKL,
130
+ ):
131
+ super().__init__()
132
+
133
+ self.register_modules(
134
+ feature_extractor=feature_extractor,
135
+ image_encoder=image_encoder,
136
+ image_normalizer=image_normalizer,
137
+ image_noising_scheduler=image_noising_scheduler,
138
+ tokenizer=tokenizer,
139
+ text_encoder=text_encoder,
140
+ unet=unet,
141
+ scheduler=scheduler,
142
+ vae=vae,
143
+ )
144
+
145
+ self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1)
146
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
147
+
148
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
149
+ def enable_vae_slicing(self):
150
+ r"""
151
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
152
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
153
+ """
154
+ self.vae.enable_slicing()
155
+
156
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
157
+ def disable_vae_slicing(self):
158
+ r"""
159
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
160
+ computing decoding in one step.
161
+ """
162
+ self.vae.disable_slicing()
163
+
164
+ def enable_model_cpu_offload(self, gpu_id=0):
165
+ r"""
166
+ Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
167
+ time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
168
+ Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
169
+ iterative execution of the `unet`.
170
+ """
171
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
172
+ from accelerate import cpu_offload_with_hook
173
+ else:
174
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
175
+
176
+ device = torch.device(f"cuda:{gpu_id}")
177
+
178
+ if self.device.type != "cpu":
179
+ self.to("cpu", silence_dtype_warnings=True)
180
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
181
+
182
+ hook = None
183
+ for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]:
184
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
185
+
186
+ # We'll offload the last model manually.
187
+ self.final_offload_hook = hook
188
+
189
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
190
+ def _encode_prompt(
191
+ self,
192
+ prompt,
193
+ device,
194
+ num_images_per_prompt,
195
+ do_classifier_free_guidance,
196
+ negative_prompt=None,
197
+ prompt_embeds: Optional[torch.FloatTensor] = None,
198
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
199
+ lora_scale: Optional[float] = None,
200
+ ):
201
+ r"""
202
+ Encodes the prompt into text encoder hidden states.
203
+
204
+ Args:
205
+ prompt (`str` or `List[str]`, *optional*):
206
+ prompt to be encoded
207
+ device: (`torch.device`):
208
+ torch device
209
+ num_images_per_prompt (`int`):
210
+ number of images that should be generated per prompt
211
+ do_classifier_free_guidance (`bool`):
212
+ whether to use classifier free guidance or not
213
+ negative_prompt (`str` or `List[str]`, *optional*):
214
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
215
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
216
+ less than `1`).
217
+ prompt_embeds (`torch.FloatTensor`, *optional*):
218
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
219
+ provided, text embeddings will be generated from `prompt` input argument.
220
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
221
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
222
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
223
+ argument.
224
+ lora_scale (`float`, *optional*):
225
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
226
+ """
227
+ # set lora scale so that monkey patched LoRA
228
+ # function of text encoder can correctly access it
229
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
230
+ self._lora_scale = lora_scale
231
+
232
+ if prompt is not None and isinstance(prompt, str):
233
+ batch_size = 1
234
+ elif prompt is not None and isinstance(prompt, list):
235
+ batch_size = len(prompt)
236
+ else:
237
+ batch_size = prompt_embeds.shape[0]
238
+
239
+ if prompt_embeds is None:
240
+ # textual inversion: process multi-vector tokens if necessary
241
+ if isinstance(self, TextualInversionLoaderMixin):
242
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
243
+
244
+ text_inputs = self.tokenizer(
245
+ prompt,
246
+ padding="max_length",
247
+ max_length=self.tokenizer.model_max_length,
248
+ truncation=True,
249
+ return_tensors="pt",
250
+ )
251
+ text_input_ids = text_inputs.input_ids
252
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
253
+
254
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
255
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
256
+ logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to"
257
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}")
258
+
259
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
260
+ attention_mask = text_inputs.attention_mask.to(device)
261
+ else:
262
+ attention_mask = None
263
+
264
+ prompt_embeds = self.text_encoder(
265
+ text_input_ids.to(device),
266
+ attention_mask=attention_mask,
267
+ )
268
+ prompt_embeds = prompt_embeds[0]
269
+
270
+ if self.text_encoder is not None:
271
+ prompt_embeds_dtype = self.text_encoder.dtype
272
+ elif self.unet is not None:
273
+ prompt_embeds_dtype = self.unet.dtype
274
+ else:
275
+ prompt_embeds_dtype = prompt_embeds.dtype
276
+
277
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
278
+
279
+ bs_embed, seq_len, _ = prompt_embeds.shape
280
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
281
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
282
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
283
+
284
+ # get unconditional embeddings for classifier free guidance
285
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
286
+ uncond_tokens: List[str]
287
+ if negative_prompt is None:
288
+ uncond_tokens = [""] * batch_size
289
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
290
+ raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
291
+ f" {type(prompt)}.")
292
+ elif isinstance(negative_prompt, str):
293
+ uncond_tokens = [negative_prompt]
294
+ elif batch_size != len(negative_prompt):
295
+ raise ValueError(
296
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
297
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
298
+ " the batch size of `prompt`.")
299
+ else:
300
+ uncond_tokens = negative_prompt
301
+
302
+ # textual inversion: process multi-vector tokens if necessary
303
+ if isinstance(self, TextualInversionLoaderMixin):
304
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
305
+
306
+ max_length = prompt_embeds.shape[1]
307
+ uncond_input = self.tokenizer(
308
+ uncond_tokens,
309
+ padding="max_length",
310
+ max_length=max_length,
311
+ truncation=True,
312
+ return_tensors="pt",
313
+ )
314
+
315
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
316
+ attention_mask = uncond_input.attention_mask.to(device)
317
+ else:
318
+ attention_mask = None
319
+
320
+ negative_prompt_embeds = self.text_encoder(
321
+ uncond_input.input_ids.to(device),
322
+ attention_mask=attention_mask,
323
+ )
324
+ negative_prompt_embeds = negative_prompt_embeds[0]
325
+
326
+ if do_classifier_free_guidance:
327
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
328
+ seq_len = negative_prompt_embeds.shape[1]
329
+
330
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
331
+
332
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
333
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
334
+
335
+ # For classifier free guidance, we need to do two forward passes.
336
+ # Here we concatenate the unconditional and text embeddings into a single batch
337
+ # to avoid doing two forward passes
338
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
339
+
340
+ return prompt_embeds
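The classifier-free-guidance batching used above, reduced to plain tensors: unconditional and conditional embeddings ride through the UNet as one batch and are split again when guidance is applied (shapes are illustrative):

```python
import torch

cond = torch.randn(2, 77, 1024)            # text embeddings for 2 prompts
uncond = torch.randn(2, 77, 1024)          # embeddings of the empty / negative prompt
prompt_embeds = torch.cat([uncond, cond])  # one forward pass with batch size 4

noise_pred = torch.randn(4, 4, 96, 96)     # stand-in for the UNet output
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guidance_scale = 10.0
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```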
341
+
342
+ def _encode_image(
343
+ self,
344
+ image,
345
+ device,
346
+ batch_size,
347
+ num_images_per_prompt,
348
+ do_classifier_free_guidance,
349
+ noise_level,
350
+ generator,
351
+ image_embeds,
352
+ negative_image_embeds,
353
+ ):
354
+ dtype = next(self.image_encoder.parameters()).dtype
355
+
356
+ if isinstance(image, PIL.Image.Image):
357
+ # the image embedding should be repeated so it matches the total batch size of the prompt
358
+ repeat_by = batch_size
359
+ else:
360
+ # assume the image input is already properly batched and just needs to be repeated so
361
+ # it matches the num_images_per_prompt.
362
+ #
363
+ # NOTE(will) this is probably missing a few side cases, e.g. batched vs. non-batched
364
+ # `image_embeds`. If those happen to be common use cases, let's think harder about
365
+ # what the expected dimensions of inputs should be and how we handle the encoding.
366
+ repeat_by = num_images_per_prompt
367
+
368
+ if image_embeds is None:
369
+ if not isinstance(image, torch.Tensor):
370
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
371
+
372
+ image = image.to(device=device, dtype=dtype)
373
+ image_embeds = self.image_encoder(image).image_embeds
374
+
375
+ image_embeds = self.noise_image_embeddings(
376
+ image_embeds=image_embeds,
377
+ noise_level=noise_level,
378
+ generator=generator,
379
+ )
380
+
381
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
382
+ image_embeds = image_embeds.unsqueeze(1)
383
+ bs_embed, seq_len, _ = image_embeds.shape
384
+ image_embeds = image_embeds.repeat(1, repeat_by, 1)
385
+ image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
386
+ image_embeds = image_embeds.squeeze(1)
387
+
388
+ if negative_image_embeds is not None:
389
+ negative_image_embeds = self.noise_image_embeddings(
390
+ image_embeds=negative_image_embeds,
391
+ noise_level=0,
392
+ generator=generator,
393
+ )
394
+ # duplicate negative image embeddings for each generation per prompt, using mps friendly method
395
+ negative_image_embeds = negative_image_embeds.unsqueeze(1)
396
+ bs_embed, seq_len, _ = negative_image_embeds.shape
397
+ negative_image_embeds = negative_image_embeds.repeat(1, repeat_by, 1)
398
+ negative_image_embeds = negative_image_embeds.view(bs_embed * repeat_by, seq_len, -1)
399
+ negative_image_embeds = negative_image_embeds.squeeze(1)
400
+
401
+ if do_classifier_free_guidance:
402
+ if negative_image_embeds is None:
403
+ negative_image_embeds = torch.zeros_like(image_embeds)
404
+
405
+ # For classifier free guidance, we need to do two forward passes.
406
+ # Here we concatenate the unconditional and text embeddings into a single batch
407
+ # to avoid doing two forward passes
408
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
409
+
410
+ return image_embeds
411
+
412
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
413
+ def decode_latents(self, latents):
414
+ warnings.warn(
415
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
416
+ " use VaeImageProcessor instead",
417
+ FutureWarning,
418
+ )
419
+ latents = 1 / self.vae.config.scaling_factor * latents
420
+ image = self.vae.decode(latents, return_dict=False)[0]
421
+ image = (image / 2 + 0.5).clamp(0, 1)
422
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
423
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
424
+ return image
425
+
426
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
427
+ def prepare_extra_step_kwargs(self, generator, eta):
428
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
429
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
430
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
431
+ # and should be between [0, 1]
432
+
433
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
434
+ extra_step_kwargs = {}
435
+ if accepts_eta:
436
+ extra_step_kwargs["eta"] = eta
437
+
438
+ # check if the scheduler accepts generator
439
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
440
+ if accepts_generator:
441
+ extra_step_kwargs["generator"] = generator
442
+ return extra_step_kwargs
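The signature probing above works for any scheduler; a tiny standalone sketch with a stub scheduler class:

```python
import inspect

class DummyScheduler:
    def step(self, noise_pred, t, latents, eta=0.0, generator=None):  # stub signature
        return latents

params = inspect.signature(DummyScheduler.step).parameters
extra_step_kwargs = {}
if "eta" in params:
    extra_step_kwargs["eta"] = 0.0
if "generator" in params:
    extra_step_kwargs["generator"] = None
print(extra_step_kwargs)   # {'eta': 0.0, 'generator': None}
```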
443
+
444
+ def check_inputs(
445
+ self,
446
+ prompt,
447
+ image,
448
+ height,
449
+ width,
450
+ callback_steps,
451
+ noise_level,
452
+ negative_prompt=None,
453
+ prompt_embeds=None,
454
+ negative_prompt_embeds=None,
455
+ image_embeds=None,
456
+ ):
457
+ if height % 8 != 0 or width % 8 != 0:
458
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
459
+
460
+ if (callback_steps is None) or (callback_steps is not None and
461
+ (not isinstance(callback_steps, int) or callback_steps <= 0)):
462
+ raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
463
+ f" {type(callback_steps)}.")
464
+
465
+ if prompt is not None and prompt_embeds is not None:
466
+ raise ValueError("Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two.")
467
+
468
+ if prompt is None and prompt_embeds is None:
469
+ raise ValueError(
470
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.")
471
+
472
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
473
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
474
+
475
+ if negative_prompt is not None and negative_prompt_embeds is not None:
476
+ raise ValueError(
477
+ "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined."
478
+ )
479
+
480
+ if prompt is not None and negative_prompt is not None:
481
+ if type(prompt) is not type(negative_prompt):
482
+ raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
483
+ f" {type(prompt)}.")
484
+
485
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
486
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
487
+ raise ValueError(
488
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
489
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
490
+ f" {negative_prompt_embeds.shape}.")
491
+
492
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
493
+ raise ValueError(
494
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
495
+ )
496
+
497
+ if image is not None and image_embeds is not None:
498
+ raise ValueError("Provide either `image` or `image_embeds`. Please make sure to define only one of the two.")
499
+
500
+ if image is None and image_embeds is None:
501
+ raise ValueError(
502
+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined.")
503
+
504
+ if image is not None:
505
+ if (not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list)):
506
+ raise ValueError(
507
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
508
+ f" {type(image)}")
509
+
510
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
511
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
512
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
513
+ if isinstance(generator, list) and len(generator) != batch_size:
514
+ raise ValueError(
515
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
516
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators.")
517
+
518
+ if latents is None:
519
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
520
+ else:
521
+ latents = latents.to(device)
522
+
523
+ # scale the initial noise by the standard deviation required by the scheduler
524
+ latents = latents * self.scheduler.init_noise_sigma
525
+ return latents
526
+
527
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
528
+ def noise_image_embeddings(
529
+ self,
530
+ image_embeds: torch.Tensor,
531
+ noise_level: int,
532
+ noise: Optional[torch.FloatTensor] = None,
533
+ generator: Optional[torch.Generator] = None,
534
+ ):
535
+ """
536
+ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
537
+ `noise_level` increases the variance in the final un-noised images.
538
+
539
+ The noise is applied in two ways:
540
+ 1. A noise schedule is applied directly to the embeddings.
541
+ 2. A vector of sinusoidal time embeddings are appended to the output.
542
+
543
+ In both cases, the amount of noise is controlled by the same `noise_level`.
544
+
545
+ The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
546
+ """
547
+ if noise is None:
548
+ noise = randn_tensor(image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype)
549
+
550
+ noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
551
+
552
+ self.image_normalizer.to(image_embeds.device)
553
+ image_embeds = self.image_normalizer.scale(image_embeds)
554
+
555
+ image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
556
+
557
+ image_embeds = self.image_normalizer.unscale(image_embeds)
558
+
559
+ noise_level = get_timestep_embedding(timesteps=noise_level,
560
+ embedding_dim=image_embeds.shape[-1],
561
+ flip_sin_to_cos=True,
562
+ downscale_freq_shift=0)
563
+
564
+ # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
565
+ # but we might actually be running in fp16. so we need to cast here.
566
+ # there might be better ways to encapsulate this.
567
+ noise_level = noise_level.to(image_embeds.dtype)
568
+
569
+ image_embeds = torch.cat((image_embeds, noise_level), 1)
570
+
571
+ return image_embeds
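The second half of the method appends a sinusoidal embedding of `noise_level` to the image embedding. A rough, self-contained sketch of that kind of embedding (the real pipeline uses `get_timestep_embedding` from diffusers, which additionally handles sin/cos flipping and frequency shifts):

```python
import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

noise_level = torch.tensor([100])                          # one timestep per batch element
image_embeds = torch.randn(1, 1024)
time_emb = sinusoidal_embedding(noise_level, image_embeds.shape[-1])
conditioned = torch.cat([image_embeds, time_emb], dim=1)   # shape [1, 2048]
```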
572
+
573
+ @torch.no_grad()
574
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
575
+ def __call__(
576
+ self,
577
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
578
+ prompt: Union[str, List[str]] = None,
579
+ height: Optional[int] = None,
580
+ width: Optional[int] = None,
581
+ num_inference_steps: int = 20,
582
+ guidance_scale: float = 10,
583
+ negative_prompt: Optional[Union[str, List[str]]] = None,
584
+ num_images_per_prompt: Optional[int] = 1,
585
+ eta: float = 0.0,
586
+ generator: Optional[torch.Generator] = None,
587
+ latents: Optional[torch.FloatTensor] = None,
588
+ prompt_embeds: Optional[torch.FloatTensor] = None,
589
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
590
+ output_type: Optional[str] = "pil",
591
+ return_dict: bool = True,
592
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
593
+ callback_steps: int = 1,
594
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
595
+ noise_level: int = 0,
596
+ image_embeds: Optional[torch.FloatTensor] = None,
597
+ negative_image_embeds: Optional[torch.FloatTensor] = None,
598
+ ):
599
+ r"""
600
+ The call function to the pipeline for generation.
601
+
602
+ Args:
603
+ prompt (`str` or `List[str]`, *optional*):
604
+ The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
605
+ used or prompt is initialized to `""`.
606
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
607
+ `Image` or tensor representing an image batch. The image is encoded to its CLIP embedding which the
608
+ `unet` is conditioned on. The image is _not_ encoded by the `vae` and then used as the latents in the
609
+ denoising process like it is in the standard Stable Diffusion text-guided image variation process.
610
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
611
+ The height in pixels of the generated image.
612
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
613
+ The width in pixels of the generated image.
614
+ num_inference_steps (`int`, *optional*, defaults to 20):
615
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
616
+ expense of slower inference.
617
+ guidance_scale (`float`, *optional*, defaults to 10.0):
618
+ A higher guidance scale value encourages the model to generate images closely linked to the text
619
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
620
+ negative_prompt (`str` or `List[str]`, *optional*):
621
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
622
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
623
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
624
+ The number of images to generate per prompt.
625
+ eta (`float`, *optional*, defaults to 0.0):
626
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
627
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
628
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
629
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
630
+ generation deterministic.
631
+ latents (`torch.FloatTensor`, *optional*):
632
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
633
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
634
+ tensor is generated by sampling using the supplied random `generator`.
635
+ prompt_embeds (`torch.FloatTensor`, *optional*):
636
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
637
+ provided, text embeddings are generated from the `prompt` input argument.
638
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
639
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
640
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
641
+ output_type (`str`, *optional*, defaults to `"pil"`):
642
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
643
+ return_dict (`bool`, *optional*, defaults to `True`):
644
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
645
+ callback (`Callable`, *optional*):
646
+ A function that calls every `callback_steps` steps during inference. The function is called with the
647
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
648
+ callback_steps (`int`, *optional*, defaults to 1):
649
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
650
+ every step.
651
+ cross_attention_kwargs (`dict`, *optional*):
652
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
653
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
654
+ noise_level (`int`, *optional*, defaults to `0`):
655
+ The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
656
+ the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details.
657
+ image_embeds (`torch.FloatTensor`, *optional*):
658
+ Pre-generated CLIP embeddings to condition the `unet` on. These latents are not used in the denoising
659
+ process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`.
660
+
661
+ Examples:
662
+
663
+ Returns:
664
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
665
+ [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning
666
+ a tuple, the first element is a list with the generated images.
667
+ """
668
+ # 0. Default height and width to unet
669
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
670
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
671
+
672
+ if prompt is None and prompt_embeds is None:
673
+ prompt = len(image) * [""] if isinstance(image, list) else ""
674
+
675
+ # 1. Check inputs. Raise error if not correct
676
+ self.check_inputs(
677
+ prompt=prompt,
678
+ image=image,
679
+ height=height,
680
+ width=width,
681
+ callback_steps=callback_steps,
682
+ noise_level=noise_level,
683
+ negative_prompt=negative_prompt,
684
+ prompt_embeds=prompt_embeds,
685
+ negative_prompt_embeds=negative_prompt_embeds,
686
+ image_embeds=image_embeds,
687
+ )
688
+
689
+ # 2. Define call parameters
690
+ if prompt is not None and isinstance(prompt, str):
691
+ batch_size = 1
692
+ elif prompt is not None and isinstance(prompt, list):
693
+ batch_size = len(prompt)
694
+ else:
695
+ batch_size = prompt_embeds.shape[0]
696
+
697
+ batch_size = batch_size * num_images_per_prompt
698
+
699
+ device = self._execution_device
700
+
701
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
702
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
703
+ # corresponds to doing no classifier free guidance.
704
+ do_classifier_free_guidance = guidance_scale > 1.0
705
+
706
+ # 3. Encode input prompt
707
+ text_encoder_lora_scale = (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
708
+ prompt_embeds = self._encode_prompt(
709
+ prompt=prompt,
710
+ device=device,
711
+ num_images_per_prompt=num_images_per_prompt,
712
+ do_classifier_free_guidance=do_classifier_free_guidance,
713
+ negative_prompt=negative_prompt,
714
+ prompt_embeds=prompt_embeds,
715
+ negative_prompt_embeds=negative_prompt_embeds,
716
+ lora_scale=text_encoder_lora_scale,
717
+ )
718
+
719
+ # 4. Encode input image
720
+ noise_level = torch.tensor([noise_level], device=device)
721
+ image_embeds = self._encode_image(
722
+ image=image,
723
+ device=device,
724
+ batch_size=batch_size,
725
+ num_images_per_prompt=num_images_per_prompt,
726
+ do_classifier_free_guidance=do_classifier_free_guidance,
727
+ noise_level=noise_level,
728
+ generator=generator,
729
+ image_embeds=image_embeds,
730
+ negative_image_embeds=negative_image_embeds,
731
+ )
732
+
733
+ # 5. Prepare timesteps
734
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
735
+ timesteps = self.scheduler.timesteps
736
+
737
+ # 6. Prepare latent variables
738
+ num_channels_latents = self.unet.config.in_channels
739
+ latents = self.prepare_latents(
740
+ batch_size=batch_size,
741
+ num_channels_latents=num_channels_latents,
742
+ height=height,
743
+ width=width,
744
+ dtype=prompt_embeds.dtype,
745
+ device=device,
746
+ generator=generator,
747
+ latents=latents,
748
+ )
749
+
750
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
751
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
752
+
753
+ # 8. Denoising loop
754
+ for i, t in enumerate(self.progress_bar(timesteps)):
755
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
756
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
757
+
758
+ # predict the noise residual
759
+ noise_pred = self.unet(
760
+ latent_model_input,
761
+ t,
762
+ encoder_hidden_states=prompt_embeds,
763
+ class_labels=image_embeds,
764
+ cross_attention_kwargs=cross_attention_kwargs,
765
+ return_dict=False,
766
+ )[0]
767
+
768
+ # perform guidance
769
+ if do_classifier_free_guidance:
770
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
771
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
772
+
773
+ # compute the previous noisy sample x_t -> x_t-1
774
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
775
+
776
+ if callback is not None and i % callback_steps == 0:
777
+ callback(i, t, latents)
778
+
779
+ # 9. Post-processing
780
+ if not output_type == "latent":
781
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
782
+ else:
783
+ image = latents
784
+
785
+ image = self.image_processor.postprocess(image, output_type=output_type)
786
+
787
+ # Offload last model to CPU
788
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
789
+ self.final_offload_hook.offload()
790
+
791
+ if not return_dict:
792
+ return (image, )
793
+
794
+ return ImagePipelineOutput(images=image)
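A hedged end-to-end sketch of calling this pipeline the way the SEED decoder does, conditioning on pre-computed `image_embeds` rather than a PIL image. The base checkpoint name and the 1024-dim embedding are assumptions (stabilityai/stable-diffusion-2-1-unclip uses a CLIP ViT-H image encoder with 1024-dim image embeddings):

```python
import torch
# StableUnCLIPImg2ImgPipeline here is the patched class defined in this file.

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip",   # assumed base checkpoint
    torch_dtype=torch.float16,
).to("cuda")

image_embeds = torch.randn(1, 1024, device="cuda", dtype=torch.float16)  # e.g. a decoded SEED embedding
images = pipe(
    image_embeds=image_embeds,
    guidance_scale=10,
    num_inference_steps=20,
    noise_level=0,
).images
images[0].save("reconstruction.png")
```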
models/seed_llama_tokenizer.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ # import math
4
+ # from torchvision import transforms
5
+ import os
6
+ # from timm.models import create_model
7
+ from typing import Any, Dict, List, Optional, Union
8
+ from transformers import LlamaTokenizer
9
+ from diffusers import DiffusionPipeline
10
+ # from torchvision.transforms.functional import pil_to_tensor
11
+
12
+ # import torch
13
+ from PIL import Image
14
+ from torchvision import transforms
15
+
16
+ # from qformer.qformer_quantizer import Blip2QformerQuantizer
17
+ # from diffusers import StableUnCLIPImg2ImgPipeline
18
+ from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
19
+
20
+ WEIGHTS_NAME = 'seed_quantizer.pt'
21
+ DIFFUSION_NAME = 'diffusion_model'
22
+
23
+
24
+ class ImageTokenizer(nn.Module):
25
+ def __init__(self,
26
+ model_path,
27
+ diffusion_model_path=None,
28
+ load_diffusion=False,
29
+ image_size=224,
30
+ device='cuda',
31
+ fp16=True,
32
+ **kwargs):
33
+ super().__init__()
34
+ from .seed_qformer.qformer_quantizer import Blip2QformerQuantizer
35
+
36
+ model = Blip2QformerQuantizer.from_pretrained(pretrained_model_path=model_path,
37
+ vit_precision='fp16' if fp16 else 'fp32',
38
+ **kwargs).eval()
39
+ if diffusion_model_path is not None and load_diffusion:
40
+ # diffusion_model = DiffusionPipeline.from_pretrained(diffusion_model_path,
41
+ # torch_dtype=torch.float16 if fp16 else torch.float32)
42
+ diffusion_model = StableUnCLIPImg2ImgPipeline.from_pretrained(diffusion_model_path,
43
+ torch_dtype=torch.float16 if fp16 else torch.float32)
44
+ self.diffusion_model = diffusion_model.to(device)
45
+ else:
46
+ self.diffusion_model = None
47
+
48
+ model = model.to(device)
49
+
50
+ processor = transforms.Compose([
51
+ transforms.Resize((image_size, image_size), interpolation=3),
52
+ # transforms.Resize(image_size, interpolation=3),
53
+ # transforms.CenterCrop(image_size),
54
+ transforms.ToTensor(),
55
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
56
+ ])
57
+
58
+ if fp16:
59
+ model = model.half()
60
+
61
+ shape_latents = torch.Size([1, 4, 96, 96])
62
+ self.latents = torch.randn(shape_latents, generator=None, device=device, dtype=torch.float16, layout=torch.strided)
63
+
64
+ shape_noise = torch.Size([1, 1024])
65
+ self.noise = torch.randn(shape_noise, generator=None, device=device, dtype=torch.float16, layout=torch.strided)
66
+
67
+ self.model = model
68
+ self.processor = processor
69
+ self.device = device
70
+ self.fp16 = fp16
71
+
72
+ def __len__(self):
73
+ return self.model.n_embed
74
+
75
+ def encode(self, image_torch):
76
+ '''Convert a batch of img to code
77
+ Args:
78
+ model: The tokenizer model.
79
+ img: [b, c, h, w]
80
+ '''
81
+ if len(image_torch.shape) == 3:
82
+ image_torch = image_torch.unsqueeze(0)
83
+
84
+ # img = image_torch.to(self.device)
85
+ img = image_torch
86
+ if self.fp16:
87
+ img = img.half()
88
+ with torch.no_grad():
89
+ id, _ = self.model.get_codebook_indices(img)
90
+ return id.view(img.shape[0], -1)
91
+
92
+ def decode(self, indices, negative_indices=None, guidance_scale=10, num_inference_steps=20):
93
+ image_embeds = self.model.get_codebook_entry(indices)
94
+ # image = self.diffusion_model(image_embeds=image_embed,
95
+ # noise_level=0,
96
+ # num_inference_steps=20,
97
+ # latents=self.latents,
98
+ # noise=self.noise).images
99
+ if negative_indices is not None:
100
+ assert indices.shape == negative_indices.shape, 'Negative indices must have the same shape as indices'
101
+ negative_image_embeds = self.model.get_codebook_entry(negative_indices)
102
+ else:
103
+ negative_image_embeds = None
104
+
105
+ image = self.diffusion_model(
106
+ image_embeds=image_embeds,
107
+ negative_image_embeds=negative_image_embeds,
108
+ guidance_scale=guidance_scale,
109
+ noise_level=0,
110
+ num_inference_steps=num_inference_steps,
111
+ latents=self.latents,
112
+ ).images
113
+ return image
114
+
115
+
116
+ class SeedLlamaTokenizer(LlamaTokenizer):
117
+ def __init__(self,
118
+ vocab_file,
119
+ unk_token="<unk>",
120
+ bos_token="<s>",
121
+ eos_token="</s>",
122
+ pad_token=None,
123
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
124
+ add_bos_token=True,
125
+ add_eos_token=False,
126
+ clean_up_tokenization_spaces=False,
127
+ device='cuda',
128
+ fp16=True,
129
+ load_diffusion=False,
130
+ encoder_url=None,
131
+ diffusion_path=None,
132
+ **kwargs):
133
+ super().__init__(vocab_file, unk_token, bos_token, eos_token, pad_token, sp_model_kwargs, add_bos_token, add_eos_token,
134
+ clean_up_tokenization_spaces, **kwargs)
135
+ self.device = device
136
+ self.fp16 = fp16
137
+ self.pad_token = self.unk_token
138
+ self.load_diffusion = load_diffusion
139
+ self.encoder_url = encoder_url
140
+ self.diffusion_path = diffusion_path
141
+
142
+ self.load_image_tokenizer()
143
+
144
+ def load_image_tokenizer(self):
145
+ if not hasattr(self, '_image_tokenizer'):
146
+ if self.encoder_url is not None:
147
+ model_path = self.encoder_url
148
+ else:
149
+ assert hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path)
150
+ model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
151
+ # diffusion_model_path = os.path.join(self.name_or_path, DIFFUSION_NAME)
152
+ # diffusion_model_path = 'stabilityai/stable-diffusion-2-1-unclip'
153
+ self._image_tokenizer = ImageTokenizer(model_path=model_path,
154
+ diffusion_model_path=self.diffusion_path,
155
+ load_diffusion=self.load_diffusion,
156
+ device=self.device,
157
+ fp16=self.fp16)
158
+
159
+ @property
160
+ def image_tokenizer(self):
161
+ if not hasattr(self, '_image_tokenizer'):
162
+ if self.encoder_url is not None:
163
+ model_path = self.encoder_url
164
+ else:
165
+ assert hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path)
166
+ model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
167
+ # diffusion_model_path = os.path.join(self.name_or_path, DIFFUSION_NAME)
168
+ # diffusion_model_path = 'stabilityai/stable-diffusion-2-1-unclip'
169
+ self._image_tokenizer = ImageTokenizer(model_path=model_path,
170
+ diffusion_model_path=self.diffusion_path,
171
+ load_diffusion=self.load_diffusion,
172
+ device=self.device,
173
+ fp16=self.fp16)
174
+ return self._image_tokenizer
175
+
176
+ @property
177
+ def num_image_tokens(self):
178
+ return 8192 # codebook size of the image tokenizer; hard-coded so it can be queried without loading the tokenizer
179
+
180
+ def to(self, device):
181
+ self.device = device
182
+ if hasattr(self, '_image_tokenizer'):
183
+ self._image_tokenizer.to(device=device)
184
+
185
+ def encode_image(
186
+ self,
187
+ image_path=None,
188
+ image_pil=None,
189
+ image_torch=None,
190
+ image_size: int = 224,
191
+ ):
192
+ assert (image_path is None) + (image_pil is None) + (image_torch is None) == 2
193
+
194
+ # need_norm_to_1 = False
195
+ if image_path is not None:
196
+ image_pil = Image.open(image_path).convert('RGB')
197
+
198
+ if image_pil is not None:
199
+ image_torch = self.image_tokenizer.processor(image_pil)
200
+
201
+ image_torch = image_torch.to(self.device)
202
+ return self.image_tokenizer.encode(image_torch)
203
+
204
+ def decode_image(self, indices, negative_indices=None, guidance_scale=10):
205
+ indices = indices.to(self.device)
206
+ if negative_indices is not None:
207
+ negative_indices = negative_indices.to(self.device)
208
+ image = self.image_tokenizer.decode(
209
+ indices,
210
+ negative_indices=negative_indices,
211
+ guidance_scale=guidance_scale,
212
+ )
213
+ return image
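Not part of the diff: a sketch of the image round trip this tokenizer enables. The pretrained path is a placeholder for a local directory containing the LLaMA tokenizer files plus seed_quantizer.pt; the diffusion checkpoint name comes from the comments above.

from models.seed_llama_tokenizer import SeedLlamaTokenizer

tokenizer = SeedLlamaTokenizer.from_pretrained(
    "path/to/seed-tokenizer",                 # placeholder local directory
    diffusion_path="stabilityai/stable-diffusion-2-1-unclip",
    load_diffusion=True,                      # required so decode_image() has a diffusion decoder
    device="cuda",
    fp16=True,
)

image_ids = tokenizer.encode_image(image_path="images/cat.jpg")  # discrete codes, shape [1, num_codes]
recon = tokenizer.decode_image(image_ids, guidance_scale=10)     # list of PIL images from the unCLIP decoder
recon[0].save("cat_recon.png")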
models/seed_qformer/blip2.py ADDED
@@ -0,0 +1,186 @@
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import logging
9
+ import os
10
+ import time
11
+ import datetime
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ import torch.nn.functional as F
17
+
18
+
19
+ from .qformer_causual import BertConfig, BertLMHeadModel
20
+
21
+ from .utils import download_cached_file, get_rank, get_dist_info, get_world_size, main_process, is_dist_avail_and_initialized, is_url
22
+ from .eva_vit import create_eva_vit_g
23
+ from .clip_vit import create_clip_vit_L
24
+ from transformers import BertTokenizer
25
+
26
+
27
+ # class Blip2Base(BaseModel):
28
+ class Blip2Base(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+
32
+ @property
33
+ def device(self):
34
+ return list(self.parameters())[0].device
35
+
36
+ @classmethod
37
+ def init_tokenizer(cls, truncation_side="right"):
38
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side)
39
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
40
+ return tokenizer
41
+
42
+ def maybe_autocast(self, dtype=torch.float16):
43
+ # if on cpu, don't use autocast
44
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
45
+ enable_autocast = self.device != torch.device("cpu")
46
+
47
+ if enable_autocast:
48
+ return torch.cuda.amp.autocast(dtype=dtype)
49
+ else:
50
+ return contextlib.nullcontext()
51
+
52
+ @classmethod
53
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
54
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
55
+ encoder_config.encoder_width = vision_width
56
+ # insert cross-attention layer every other block
57
+ encoder_config.add_cross_attention = True
58
+ encoder_config.cross_attention_freq = cross_attention_freq
59
+ encoder_config.query_length = num_query_token
60
+ Qformer = BertLMHeadModel.from_pretrained("bert-base-uncased", config=encoder_config)
61
+ query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
62
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
63
+ return Qformer, query_tokens
64
+
65
+ def init_vision_encoder(self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision):
66
+ assert model_name in [
67
+ "eva_clip_g",
68
+ "eva2_clip_L",
69
+ "clip_L",
70
+ ], "vit model must be eva_clip_g, eva2_clip_L or clip_L"
71
+ if model_name == "eva_clip_g":
72
+ visual_encoder = create_eva_vit_g(img_size, drop_path_rate, use_grad_checkpoint, precision)
73
+
74
+ elif model_name == "clip_L":
75
+ visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
76
+ ln_vision = LayerNorm(visual_encoder.num_features)
77
+ self.vit_name = model_name
78
+ return visual_encoder, ln_vision
79
+
80
+ def load_from_pretrained(self, url_or_filename):
81
+ if is_url(url_or_filename):
82
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
83
+ checkpoint = torch.load(cached_file, map_location="cpu")
84
+ elif os.path.isfile(url_or_filename):
85
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
86
+ else:
87
+ raise RuntimeError("checkpoint url or path is invalid")
88
+
89
+ state_dict = checkpoint["model"]
90
+
91
+ msg = self.load_state_dict(state_dict, strict=False)
92
+
93
+ # logging.info("Missing keys {}".format(msg.missing_keys))
94
+ logging.info("load checkpoint from %s" % url_or_filename)
95
+
96
+ return msg
97
+
98
+ def get_optimizer_params(self, weight_decay, lr_scale=1):
99
+ if self.vit_name == "eva_clip_g":
100
+ vit_num_layers = self.visual_encoder.get_num_layer()
101
+ lr_scales = list(lr_scale**(vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))
102
+
103
+ parameter_group_names = {}
104
+ parameter_group_vars = {}
105
+
106
+ for name, param in self.named_parameters():
107
+ if not param.requires_grad:
108
+ continue # frozen weights
109
+ if len(param.shape) == 1 or name.endswith(".bias"):
110
+ group_name = "no_decay"
111
+ this_weight_decay = 0.
112
+ else:
113
+ group_name = "decay"
114
+ this_weight_decay = weight_decay
115
+ if 'visual_encoder' in name:
116
+ layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.', ''))
117
+ group_name = "vit_layer_%d_%s" % (layer_id, group_name)
118
+ else:
119
+ layer_id = None
120
+
121
+ if group_name not in parameter_group_names:
122
+ if layer_id is not None:
123
+ scale = lr_scales[layer_id]
124
+ else:
125
+ scale = 1
126
+ parameter_group_names[group_name] = {"weight_decay": this_weight_decay, "params": [], "lr_scale": scale}
127
+ parameter_group_vars[group_name] = {"weight_decay": this_weight_decay, "params": [], "lr_scale": scale}
128
+ parameter_group_vars[group_name]["params"].append(param)
129
+ parameter_group_names[group_name]["params"].append(name)
130
+ # import json
131
+ # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
132
+ optim_params = list(parameter_group_vars.values())
133
+ return optim_params
134
+ else:
135
+ return super().get_optimizer_params(weight_decay, lr_scale)
136
+
137
+ def _lemmatize(self, answers):
138
+ def apply(answer):
139
+ doc = self.lemmatizer(answer)
140
+
141
+ words = []
142
+ for token in doc:
143
+ if token.pos_ in ["NOUN", "VERB"]:
144
+ words.append(token.lemma_)
145
+ else:
146
+ words.append(token.text)
147
+ answer = " ".join(words)
148
+
149
+ return answer
150
+
151
+ return [apply(answer) for answer in answers]
152
+
153
+ @property
154
+ def lemmatizer(self):
155
+ if self._lemmatizer is None:
156
+ try:
157
+ import spacy
158
+
159
+ self._lemmatizer = spacy.load("en_core_web_sm")
160
+ except ImportError:
161
+ logging.error("""
162
+ Please install spacy and en_core_web_sm model to apply lemmatization.
163
+ python -m spacy download en_core_web_sm
164
+ OR
165
+ import spacy.cli
166
+ spacy.cli.download("en_core_web_sm")
167
+ """)
168
+ exit(1)
169
+
170
+ return self._lemmatizer
171
+
172
+
173
+ def disabled_train(self, mode=True):
174
+ """Overwrite model.train with this function to make sure train/eval mode
175
+ does not change anymore."""
176
+ return self
177
+
178
+
179
+ class LayerNorm(nn.LayerNorm):
180
+ """Subclass torch's LayerNorm to handle fp16."""
181
+ def forward(self, x: torch.Tensor):
182
+ orig_type = x.dtype
183
+ ret = super().forward(x.type(torch.float32))
184
+ return ret.type(orig_type)
185
+
186
+
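Not part of the diff: a sketch of how Blip2Base is typically subclassed to pair a frozen ViT with the causal Q-Former (the quantizer in qformer_quantizer.py follows this pattern). Class name and sizes are illustrative only, following the standard BLIP-2 calling convention.

import torch
from models.seed_qformer.blip2 import Blip2Base

class ToyQformerBackbone(Blip2Base):
    def __init__(self, num_query_token=32, img_size=224):
        super().__init__()
        # Frozen-style EVA ViT-g encoder + LayerNorm, then a Q-Former sized to its feature width.
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            "eva_clip_g", img_size, drop_path_rate=0.0, use_grad_checkpoint=False, precision="fp32")
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, vision_width=self.visual_encoder.num_features)

    def forward(self, image):                          # image: [B, 3, H, W]
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image.device)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        out = self.Qformer.bert(query_embeds=query_tokens,
                                encoder_hidden_states=image_embeds,
                                encoder_attention_mask=image_atts,
                                return_dict=True)
        return out.last_hidden_state                   # [B, num_query_token, hidden_size]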
models/seed_qformer/clip_vit.py ADDED
@@ -0,0 +1,257 @@
1
+ from collections import OrderedDict
2
+ from itertools import repeat
3
+ import collections.abc
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+
11
+ from .eva_vit import convert_weights_to_fp16
12
+ from .utils import download_cached_file
13
+
14
+
15
+ class Bottleneck(nn.Module):
16
+ expansion = 4
17
+
18
+ def __init__(self, inplanes, planes, stride=1):
19
+ super().__init__()
20
+
21
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
22
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
23
+ self.bn1 = nn.BatchNorm2d(planes)
24
+ self.relu1 = nn.ReLU(inplace=True)
25
+
26
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
27
+ self.bn2 = nn.BatchNorm2d(planes)
28
+ self.relu2 = nn.ReLU(inplace=True)
29
+
30
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
31
+
32
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
33
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
34
+ self.relu3 = nn.ReLU(inplace=True)
35
+
36
+ self.downsample = None
37
+ self.stride = stride
38
+
39
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
40
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
41
+ self.downsample = nn.Sequential(
42
+ OrderedDict([("-1", nn.AvgPool2d(stride)),
43
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
44
+ ("1", nn.BatchNorm2d(planes * self.expansion))]))
45
+
46
+ def forward(self, x: torch.Tensor):
47
+ identity = x
48
+
49
+ out = self.relu1(self.bn1(self.conv1(x)))
50
+ out = self.relu2(self.bn2(self.conv2(out)))
51
+ out = self.avgpool(out)
52
+ out = self.bn3(self.conv3(out))
53
+
54
+ if self.downsample is not None:
55
+ identity = self.downsample(x)
56
+
57
+ out += identity
58
+ out = self.relu3(out)
59
+ return out
60
+
61
+
62
+ class AttentionPool2d(nn.Module):
63
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
64
+ super().__init__()
65
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
66
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
67
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
68
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
69
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
70
+ self.num_heads = num_heads
71
+
72
+ def forward(self, x):
73
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
74
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
75
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
76
+ x, _ = F.multi_head_attention_forward(query=x,
77
+ key=x,
78
+ value=x,
79
+ embed_dim_to_check=x.shape[-1],
80
+ num_heads=self.num_heads,
81
+ q_proj_weight=self.q_proj.weight,
82
+ k_proj_weight=self.k_proj.weight,
83
+ v_proj_weight=self.v_proj.weight,
84
+ in_proj_weight=None,
85
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
86
+ bias_k=None,
87
+ bias_v=None,
88
+ add_zero_attn=False,
89
+ dropout_p=0,
90
+ out_proj_weight=self.c_proj.weight,
91
+ out_proj_bias=self.c_proj.bias,
92
+ use_separate_proj_weight=True,
93
+ training=self.training,
94
+ need_weights=False)
95
+
96
+ return x[0]
97
+
98
+
99
+ class LayerNorm(nn.LayerNorm):
100
+ """Subclass torch's LayerNorm to handle fp16."""
101
+ def forward(self, x: torch.Tensor):
102
+ orig_type = x.dtype
103
+ ret = super().forward(x.type(torch.float32))
104
+ return ret.type(orig_type)
105
+
106
+
107
+ class QuickGELU(nn.Module):
108
+ def forward(self, x: torch.Tensor):
109
+ return x * torch.sigmoid(1.702 * x)
110
+
111
+
112
+ class ResidualAttentionBlock(nn.Module):
113
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
114
+ super().__init__()
115
+
116
+ self.attn = nn.MultiheadAttention(d_model, n_head)
117
+ self.ln_1 = LayerNorm(d_model)
118
+ self.mlp = nn.Sequential(
119
+ OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
120
+ ("c_proj", nn.Linear(d_model * 4, d_model))]))
121
+ self.ln_2 = LayerNorm(d_model)
122
+ self.attn_mask = attn_mask
123
+
124
+ # if use_grad_checkpointing:
125
+ # self.attn = checkpoint_wrapper(self.attn)
126
+ # self.mlp = checkpoint_wrapper(self.mlp)
127
+ # raise NotImplementedError
128
+
129
+ def attention(self, x: torch.Tensor):
130
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
131
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
132
+
133
+ def forward(self, x: torch.Tensor):
134
+ x = x + self.attention(self.ln_1(x))
135
+ x = x + self.mlp(self.ln_2(x))
136
+ return x
137
+
138
+
139
+ class Transformer(nn.Module):
140
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
141
+ super().__init__()
142
+ self.width = width
143
+ self.layers = layers
144
+ self.resblocks = nn.Sequential(
145
+ *[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)])
146
+
147
+ def forward(self, x: torch.Tensor):
148
+ return self.resblocks(x)
149
+
150
+
151
+ class VisionTransformer(nn.Module):
152
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int,
153
+ use_grad_checkpointing: bool):
154
+ super().__init__()
155
+ self.input_resolution = input_resolution
156
+ self.num_features = width
157
+ self.num_heads = heads
158
+ self.num_patches = (input_resolution // patch_size)**2
159
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
160
+
161
+ scale = width**-0.5
162
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
163
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
164
+ self.ln_pre = LayerNorm(width)
165
+
166
+ self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
167
+
168
+ # self.ln_final = LayerNorm(width)
169
+
170
+ def forward(self, x: torch.Tensor):
171
+
172
+ x = self.conv1(x) # shape = [*, width, grid, grid]
173
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
174
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
175
+ x = torch.cat(
176
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
177
+ dim=1) # shape = [*, grid ** 2 + 1, width]
178
+ x = x + self.positional_embedding.to(x.dtype)
179
+ x = self.ln_pre(x)
180
+
181
+ x = x.permute(1, 0, 2) # NLD -> LND
182
+ x = self.transformer(x)
183
+ x = x.permute(1, 0, 2) # LND -> NLD
184
+
185
+ # x = self.ln_final(x)
186
+ return x
187
+
188
+
189
+ # From PyTorch internals
190
+ def _ntuple(n):
191
+ def parse(x):
192
+ if isinstance(x, collections.abc.Iterable):
193
+ return x
194
+ return tuple(repeat(x, n))
195
+
196
+ return parse
197
+
198
+
199
+ to_2tuple = _ntuple(2)
200
+
201
+
202
+ def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
203
+ # Rescale the grid of position embeddings when loading from state_dict
204
+ old_pos_embed = state_dict.get('positional_embedding', None)
205
+
206
+ grid_size = round((model.positional_embedding.shape[0] - 1)**0.5)
207
+ if old_pos_embed is None:
208
+ return
209
+ grid_size = to_2tuple(grid_size)
210
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
211
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
212
+ if new_seq_len == old_pos_embed.shape[0]:
213
+ return
214
+
215
+ if extra_tokens:
216
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
217
+ else:
218
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
219
+
220
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
221
+
222
+ print('Resizing position embedding grid-size from %s to %s' % (old_grid_size, grid_size))
223
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
224
+ pos_emb_img = F.interpolate(
225
+ pos_emb_img,
226
+ size=grid_size,
227
+ mode=interpolation,
228
+ align_corners=True,
229
+ )
230
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
231
+ if pos_emb_tok is not None:
232
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
233
+ else:
234
+ new_pos_embed = pos_emb_img
235
+ state_dict['positional_embedding'] = new_pos_embed
236
+
237
+
238
+ def create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp16"):
239
+ model = VisionTransformer(
240
+ input_resolution=img_size,
241
+ patch_size=14,
242
+ width=1024,
243
+ layers=23,
244
+ heads=16,
245
+ use_grad_checkpointing=use_checkpoint,
246
+ )
247
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth"
248
+ cached_file = download_cached_file(url, check_hash=False, progress=True)
249
+ state_dict = torch.load(cached_file, map_location="cpu")
250
+ interpolate_pos_embed(model, state_dict)
251
+
252
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
253
+ # print(incompatible_keys)
254
+
255
+ if precision == "fp16":
256
+ convert_weights_to_fp16(model)
257
+ return model
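Not part of the diff: a quick sanity check of what create_clip_vit_L returns — per-token features rather than pooled CLIP embeddings. This sketch downloads the LAVIS checkpoint on first run; fp32 is used so it also runs on CPU.

import torch
from models.seed_qformer.clip_vit import create_clip_vit_L

vit = create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp32").eval()
with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected [1, 257, 1024]: 1 CLS token + (224/14)^2 patches, width 1024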
models/seed_qformer/eva_vit.py ADDED
@@ -0,0 +1,486 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+
17
+
18
+ from .utils import download_cached_file
19
+
20
+
21
+ def _cfg(url='', **kwargs):
22
+ return {
23
+ 'url': url,
24
+ 'num_classes': 1000,
25
+ 'input_size': (3, 224, 224),
26
+ 'pool_size': None,
27
+ 'crop_pct': .9,
28
+ 'interpolation': 'bicubic',
29
+ 'mean': (0.5, 0.5, 0.5),
30
+ 'std': (0.5, 0.5, 0.5),
31
+ **kwargs
32
+ }
33
+
34
+
35
+ class DropPath(nn.Module):
36
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
37
+ """
38
+ def __init__(self, drop_prob=None):
39
+ super(DropPath, self).__init__()
40
+ self.drop_prob = drop_prob
41
+
42
+ def forward(self, x):
43
+ return drop_path(x, self.drop_prob, self.training)
44
+
45
+ def extra_repr(self) -> str:
46
+ return 'p={}'.format(self.drop_prob)
47
+
48
+
49
+ class Mlp(nn.Module):
50
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51
+ super().__init__()
52
+ out_features = out_features or in_features
53
+ hidden_features = hidden_features or in_features
54
+ self.fc1 = nn.Linear(in_features, hidden_features)
55
+ self.act = act_layer()
56
+ self.fc2 = nn.Linear(hidden_features, out_features)
57
+ self.drop = nn.Dropout(drop)
58
+
59
+ def forward(self, x):
60
+ x = self.fc1(x)
61
+ x = self.act(x)
62
+ # x = self.drop(x)
63
+ # dropout after the activation is commented out, following the original BERT implementation
64
+ x = self.fc2(x)
65
+ x = self.drop(x)
66
+ return x
67
+
68
+
69
+ class Attention(nn.Module):
70
+ def __init__(self,
71
+ dim,
72
+ num_heads=8,
73
+ qkv_bias=False,
74
+ qk_scale=None,
75
+ attn_drop=0.,
76
+ proj_drop=0.,
77
+ window_size=None,
78
+ attn_head_dim=None):
79
+ super().__init__()
80
+ self.num_heads = num_heads
81
+ head_dim = dim // num_heads
82
+ if attn_head_dim is not None:
83
+ head_dim = attn_head_dim
84
+ all_head_dim = head_dim * self.num_heads
85
+ self.scale = qk_scale or head_dim**-0.5
86
+
87
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
88
+ if qkv_bias:
89
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
90
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
91
+ else:
92
+ self.q_bias = None
93
+ self.v_bias = None
94
+
95
+ if window_size:
96
+ self.window_size = window_size
97
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
98
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
99
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
100
+ # cls to token & token 2 cls & cls to cls
101
+
102
+ # get pair-wise relative position index for each token inside the window
103
+ coords_h = torch.arange(window_size[0])
104
+ coords_w = torch.arange(window_size[1])
105
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
106
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
107
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
108
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
109
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
110
+ relative_coords[:, :, 1] += window_size[1] - 1
111
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
112
+ relative_position_index = \
113
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
114
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
115
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
116
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
117
+ relative_position_index[0, 0] = self.num_relative_distance - 1
118
+
119
+ self.register_buffer("relative_position_index", relative_position_index)
120
+ else:
121
+ self.window_size = None
122
+ self.relative_position_bias_table = None
123
+ self.relative_position_index = None
124
+
125
+ self.attn_drop = nn.Dropout(attn_drop)
126
+ self.proj = nn.Linear(all_head_dim, dim)
127
+ self.proj_drop = nn.Dropout(proj_drop)
128
+
129
+ def forward(self, x, rel_pos_bias=None):
130
+ B, N, C = x.shape
131
+ qkv_bias = None
132
+ if self.q_bias is not None:
133
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
134
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
135
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
136
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
137
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
138
+
139
+ q = q * self.scale
140
+ attn = (q @ k.transpose(-2, -1))
141
+
142
+ if self.relative_position_bias_table is not None:
143
+ relative_position_bias = \
144
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
145
+ self.window_size[0] * self.window_size[1] + 1,
146
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
147
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
148
+ attn = attn + relative_position_bias.unsqueeze(0)
149
+
150
+ if rel_pos_bias is not None:
151
+ attn = attn + rel_pos_bias
152
+
153
+ attn = attn.softmax(dim=-1)
154
+ attn = self.attn_drop(attn)
155
+
156
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
157
+ x = self.proj(x)
158
+ x = self.proj_drop(x)
159
+ return x
160
+
161
+
162
+ class Block(nn.Module):
163
+ def __init__(self,
164
+ dim,
165
+ num_heads,
166
+ mlp_ratio=4.,
167
+ qkv_bias=False,
168
+ qk_scale=None,
169
+ drop=0.,
170
+ attn_drop=0.,
171
+ drop_path=0.,
172
+ init_values=None,
173
+ act_layer=nn.GELU,
174
+ norm_layer=nn.LayerNorm,
175
+ window_size=None,
176
+ attn_head_dim=None):
177
+ super().__init__()
178
+ self.norm1 = norm_layer(dim)
179
+ self.attn = Attention(dim,
180
+ num_heads=num_heads,
181
+ qkv_bias=qkv_bias,
182
+ qk_scale=qk_scale,
183
+ attn_drop=attn_drop,
184
+ proj_drop=drop,
185
+ window_size=window_size,
186
+ attn_head_dim=attn_head_dim)
187
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
188
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
189
+ self.norm2 = norm_layer(dim)
190
+ mlp_hidden_dim = int(dim * mlp_ratio)
191
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
192
+
193
+ if init_values is not None and init_values > 0:
194
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
195
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
196
+ else:
197
+ self.gamma_1, self.gamma_2 = None, None
198
+
199
+ def forward(self, x, rel_pos_bias=None):
200
+ if self.gamma_1 is None:
201
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
202
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
203
+ else:
204
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
205
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
206
+ return x
207
+
208
+
209
+ class PatchEmbed(nn.Module):
210
+ """ Image to Patch Embedding
211
+ """
212
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
213
+ super().__init__()
214
+ img_size = to_2tuple(img_size)
215
+ patch_size = to_2tuple(patch_size)
216
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
217
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
218
+ self.img_size = img_size
219
+ self.patch_size = patch_size
220
+ self.num_patches = num_patches
221
+
222
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
223
+
224
+ def forward(self, x, **kwargs):
225
+ B, C, H, W = x.shape
226
+ # FIXME look at relaxing size constraints
227
+ assert H == self.img_size[0] and W == self.img_size[1], \
228
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
229
+ x = self.proj(x).flatten(2).transpose(1, 2)
230
+ return x
231
+
232
+
233
+ class RelativePositionBias(nn.Module):
234
+ def __init__(self, window_size, num_heads):
235
+ super().__init__()
236
+ self.window_size = window_size
237
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
238
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
239
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
240
+ # cls to token & token 2 cls & cls to cls
241
+
242
+ # get pair-wise relative position index for each token inside the window
243
+ coords_h = torch.arange(window_size[0])
244
+ coords_w = torch.arange(window_size[1])
245
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
246
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
247
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
248
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
249
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
250
+ relative_coords[:, :, 1] += window_size[1] - 1
251
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
252
+ relative_position_index = \
253
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
254
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
255
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
256
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
257
+ relative_position_index[0, 0] = self.num_relative_distance - 1
258
+
259
+ self.register_buffer("relative_position_index", relative_position_index)
260
+
261
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
262
+
263
+ def forward(self):
264
+ relative_position_bias = \
265
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
266
+ self.window_size[0] * self.window_size[1] + 1,
267
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
268
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
269
+
270
+
271
+ class VisionTransformer(nn.Module):
272
+ """ Vision Transformer with support for patch or hybrid CNN input stage
273
+ """
274
+ def __init__(self,
275
+ img_size=224,
276
+ patch_size=16,
277
+ in_chans=3,
278
+ num_classes=1000,
279
+ embed_dim=768,
280
+ depth=12,
281
+ num_heads=12,
282
+ mlp_ratio=4.,
283
+ qkv_bias=False,
284
+ qk_scale=None,
285
+ drop_rate=0.,
286
+ attn_drop_rate=0.,
287
+ drop_path_rate=0.,
288
+ norm_layer=nn.LayerNorm,
289
+ init_values=None,
290
+ use_abs_pos_emb=True,
291
+ use_rel_pos_bias=False,
292
+ use_shared_rel_pos_bias=False,
293
+ use_mean_pooling=True,
294
+ init_scale=0.001,
295
+ use_checkpoint=False):
296
+ super().__init__()
297
+ self.image_size = img_size
298
+ self.num_classes = num_classes
299
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
300
+
301
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
302
+ num_patches = self.patch_embed.num_patches
303
+
304
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
305
+ if use_abs_pos_emb:
306
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
307
+ else:
308
+ self.pos_embed = None
309
+ self.pos_drop = nn.Dropout(p=drop_rate)
310
+
311
+ if use_shared_rel_pos_bias:
312
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
313
+ else:
314
+ self.rel_pos_bias = None
315
+ self.use_checkpoint = use_checkpoint
316
+
317
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
318
+ self.use_rel_pos_bias = use_rel_pos_bias
319
+ self.blocks = nn.ModuleList([
320
+ Block(dim=embed_dim,
321
+ num_heads=num_heads,
322
+ mlp_ratio=mlp_ratio,
323
+ qkv_bias=qkv_bias,
324
+ qk_scale=qk_scale,
325
+ drop=drop_rate,
326
+ attn_drop=attn_drop_rate,
327
+ drop_path=dpr[i],
328
+ norm_layer=norm_layer,
329
+ init_values=init_values,
330
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)
331
+ ])
332
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
333
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
334
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
335
+
336
+ if self.pos_embed is not None:
337
+ trunc_normal_(self.pos_embed, std=.02)
338
+ trunc_normal_(self.cls_token, std=.02)
339
+ # trunc_normal_(self.mask_token, std=.02)
340
+ # if isinstance(self.head, nn.Linear):
341
+ # trunc_normal_(self.head.weight, std=.02)
342
+ self.apply(self._init_weights)
343
+ self.fix_init_weight()
344
+
345
+ def fix_init_weight(self):
346
+ def rescale(param, layer_id):
347
+ param.div_(math.sqrt(2.0 * layer_id))
348
+
349
+ for layer_id, layer in enumerate(self.blocks):
350
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
351
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
352
+
353
+ def _init_weights(self, m):
354
+ if isinstance(m, nn.Linear):
355
+ trunc_normal_(m.weight, std=.02)
356
+ if isinstance(m, nn.Linear) and m.bias is not None:
357
+ nn.init.constant_(m.bias, 0)
358
+ elif isinstance(m, nn.LayerNorm):
359
+ nn.init.constant_(m.bias, 0)
360
+ nn.init.constant_(m.weight, 1.0)
361
+
362
+ def get_classifier(self):
363
+ return self.head
364
+
365
+ def reset_classifier(self, num_classes, global_pool=''):
366
+ self.num_classes = num_classes
367
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
368
+
369
+ def forward_features(self, x):
370
+ x = self.patch_embed(x)
371
+ batch_size, seq_len, _ = x.size()
372
+
373
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
374
+ x = torch.cat((cls_tokens, x), dim=1)
375
+ if self.pos_embed is not None:
376
+ x = x + self.pos_embed
377
+ x = self.pos_drop(x)
378
+
379
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
380
+ for blk in self.blocks:
381
+ if self.use_checkpoint:
382
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
383
+ else:
384
+ x = blk(x, rel_pos_bias)
385
+ return x
386
+
387
+ def forward(self, x):
388
+ x = self.forward_features(x)
389
+ # x = self.head(x)
390
+ return x
391
+
392
+ def get_intermediate_layers(self, x):
393
+ x = self.patch_embed(x)
394
+ batch_size, seq_len, _ = x.size()
395
+
396
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
397
+ x = torch.cat((cls_tokens, x), dim=1)
398
+ if self.pos_embed is not None:
399
+ x = x + self.pos_embed
400
+ x = self.pos_drop(x)
401
+
402
+ features = []
403
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
404
+ for blk in self.blocks:
405
+ x = blk(x, rel_pos_bias)
406
+ features.append(x)
407
+
408
+ return features
409
+
410
+ def get_num_layer(self, var_name=""):
411
+ if var_name in ("cls_token", "mask_token", "pos_embed"):
412
+ return 0
413
+ elif var_name.startswith("patch_embed"):
414
+ return 0
415
+ elif var_name.startswith("rel_pos_bias"):
416
+ return len(self.blocks) - 1
417
+ elif var_name.startswith("blocks"):
418
+ layer_id = int(var_name.split('.')[1])
419
+ return layer_id + 1
420
+ else:
421
+ return len(self.blocks)
422
+
423
+
424
+ def interpolate_pos_embed(model, checkpoint_model):
425
+ if 'pos_embed' in checkpoint_model:
426
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
427
+ embedding_size = pos_embed_checkpoint.shape[-1]
428
+ num_patches = model.patch_embed.num_patches
429
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
430
+ # height (== width) for the checkpoint position embedding
431
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
432
+ # height (== width) for the new position embedding
433
+ new_size = int(num_patches**0.5)
434
+ # class_token and dist_token are kept unchanged
435
+ if orig_size != new_size:
436
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
437
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
438
+ # only the position tokens are interpolated
439
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
440
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
441
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens,
442
+ size=(new_size, new_size),
443
+ mode='bicubic',
444
+ align_corners=False)
445
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
446
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
447
+ checkpoint_model['pos_embed'] = new_pos_embed
448
+
449
+
450
+ def convert_weights_to_fp16(model: nn.Module):
451
+ """Convert applicable model parameters to fp16"""
452
+ def _convert_weights_to_fp16(l):
453
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
454
+ l.weight.data = l.weight.data.half()
455
+ if l.bias is not None:
456
+ l.bias.data = l.bias.data.half()
457
+
458
+ model.apply(_convert_weights_to_fp16)
459
+
460
+
461
+ def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
462
+ model = VisionTransformer(
463
+ img_size=img_size,
464
+ patch_size=14,
465
+ use_mean_pooling=False,
466
+ embed_dim=1408,
467
+ depth=39,
468
+ num_heads=1408 // 88,
469
+ mlp_ratio=4.3637,
470
+ qkv_bias=True,
471
+ drop_path_rate=drop_path_rate,
472
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
473
+ use_checkpoint=use_checkpoint,
474
+ )
475
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
476
+ cached_file = download_cached_file(url, check_hash=False, progress=True)
477
+ state_dict = torch.load(cached_file, map_location="cpu")
478
+ interpolate_pos_embed(model, state_dict)
479
+
480
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
481
+ # print(incompatible_keys)
482
+
483
+ if precision == "fp16":
484
+ # model.to("cuda")
485
+ convert_weights_to_fp16(model)
486
+ return model
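Not part of the diff: a sketch of loading the EVA ViT-g backbone in half precision. convert_weights_to_fp16 only halves Conv/Linear weights, so the forward pass is wrapped in autocast (as Blip2Base.maybe_autocast does) to keep mixed-dtype ops consistent; it needs a CUDA device and downloads the LAVIS checkpoint on first run.

import torch
from models.seed_qformer.eva_vit import create_eva_vit_g

vit = create_eva_vit_g(img_size=224, drop_path_rate=0.0, use_checkpoint=False,
                       precision="fp16").to("cuda").eval()
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
    feats = vit(torch.randn(1, 3, 224, 224, device="cuda"))
print(feats.shape)  # expected [1, 257, 1408]: 1 CLS token + (224/14)^2 patches, embed_dim 1408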
models/seed_qformer/qformer_causual.py ADDED
@@ -0,0 +1,1169 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch.nn import CrossEntropyLoss
21
+ import torch.nn.functional as F
22
+ import numpy as np
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput, )
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPoolingAndCrossAttentions,
30
+ CausalLMOutputWithCrossAttentions,
31
+ MaskedLMOutput,
32
+ MultipleChoiceModelOutput,
33
+ NextSentencePredictorOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import (
39
+ PreTrainedModel,
40
+ apply_chunking_to_forward,
41
+ find_pruneable_heads_and_indices,
42
+ prune_linear_layer,
43
+ )
44
+ from transformers.utils import logging
45
+ from transformers.models.bert.configuration_bert import BertConfig
46
+
47
+ #torch.set_printoptions(profile="full")
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class BertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word and position embeddings."""
53
+ def __init__(self, config):
54
+ super().__init__()
55
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
56
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
57
+
58
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
59
+ # any TensorFlow checkpoint file
60
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
61
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
62
+
63
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
64
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
65
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
66
+
67
+ self.config = config
68
+
69
+ def forward(
70
+ self,
71
+ input_ids=None,
72
+ position_ids=None,
73
+ query_embeds=None,
74
+ past_key_values_length=0,
75
+ ):
76
+ if input_ids is not None:
77
+ seq_length = input_ids.size()[1]
78
+ else:
79
+ seq_length = 0
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length].clone()
83
+
84
+ if input_ids is not None:
85
+ embeddings = self.word_embeddings(input_ids)
86
+ if self.position_embedding_type == "absolute":
87
+ position_embeddings = self.position_embeddings(position_ids)
88
+ embeddings = embeddings + position_embeddings
89
+
90
+ if query_embeds is not None:
91
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
92
+ #print(query_embeds.shape, embeddings.shape)
93
+ else:
94
+ embeddings = query_embeds
95
+
96
+ embeddings = self.LayerNorm(embeddings)
97
+ embeddings = self.dropout(embeddings)
98
+ return embeddings
99
+
100
+
101
+ class BertSelfAttention(nn.Module):
102
+ def __init__(self, config, is_cross_attention):
103
+ super().__init__()
104
+ self.config = config
105
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
106
+ raise ValueError("The hidden size (%d) is not a multiple of the number of attention "
107
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
108
+
109
+ self.num_attention_heads = config.num_attention_heads
110
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
111
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
112
+
113
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
114
+ if is_cross_attention:
115
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
116
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
117
+ else:
118
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
119
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
120
+
121
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
122
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
123
+ if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"):
124
+ self.max_position_embeddings = config.max_position_embeddings
125
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
126
+ self.save_attention = False
127
+
128
+ def save_attn_gradients(self, attn_gradients):
129
+ self.attn_gradients = attn_gradients
130
+
131
+ def get_attn_gradients(self):
132
+ return self.attn_gradients
133
+
134
+ def save_attention_map(self, attention_map):
135
+ self.attention_map = attention_map
136
+
137
+ def get_attention_map(self):
138
+ return self.attention_map
139
+
140
+ def transpose_for_scores(self, x):
141
+ new_x_shape = x.size()[:-1] + (
142
+ self.num_attention_heads,
143
+ self.attention_head_size,
144
+ )
145
+ x = x.view(*new_x_shape)
146
+ return x.permute(0, 2, 1, 3)
147
+
148
+ def forward(
149
+ self,
150
+ hidden_states,
151
+ attention_mask=None,
152
+ head_mask=None,
153
+ encoder_hidden_states=None,
154
+ encoder_attention_mask=None,
155
+ past_key_value=None,
156
+ output_attentions=False,
157
+ ):
158
+
159
+ # If this is instantiated as a cross-attention module, the keys
160
+ # and values come from an encoder; the attention mask needs to be
161
+ # such that the encoder's padding tokens are not attended to.
162
+ is_cross_attention = encoder_hidden_states is not None
163
+
164
+ if is_cross_attention:
165
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
167
+ #print(key_layer.shape, value_layer.shape)
168
+ attention_mask = encoder_attention_mask
169
+ elif past_key_value is not None:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
173
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
174
+ #print(past_key_value[0].shape, key_layer.shape)
175
+ else:
176
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
177
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
178
+
179
+ mixed_query_layer = self.query(hidden_states)
180
+
181
+ query_layer = self.transpose_for_scores(mixed_query_layer)
182
+ # if past_key_value is not None:
183
+ # print(query_layer.shape)
184
+
185
+ past_key_value = (key_layer, value_layer)
186
+ #print(key_layer.shape, value_layer.shape)
187
+
188
+ # Take the dot product between "query" and "key" to get the raw attention scores.
189
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
190
+ #if is_cross_attention:
191
+ # if attention_scores.shape[2] == 32:
192
+ # attention_scores_save = attention_scores[0].detach().cpu().numpy()
193
+ # print(attention_scores_save.shape)
194
+ # np.save('attention_scores_causal_text_child.npy', attention_scores_save)
195
+
196
+ if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"):
197
+ seq_length = hidden_states.size()[1]
198
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
199
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
200
+ distance = position_ids_l - position_ids_r
201
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
202
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
203
+
204
+ if self.position_embedding_type == "relative_key":
205
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
206
+ attention_scores = attention_scores + relative_position_scores
207
+ elif self.position_embedding_type == "relative_key_query":
208
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
209
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
210
+ attention_scores = (attention_scores + relative_position_scores_query + relative_position_scores_key)
211
+
212
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
213
+ if attention_mask is not None:
214
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
215
+ attention_scores = attention_scores + attention_mask
216
+
217
+ # Normalize the attention scores to probabilities.
218
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
219
+
220
+ if is_cross_attention and self.save_attention:
221
+ self.save_attention_map(attention_probs)
222
+ attention_probs.register_hook(self.save_attn_gradients)
223
+
224
+ # This is actually dropping out entire tokens to attend to, which might
225
+ # seem a bit unusual, but is taken from the original Transformer paper.
226
+ attention_probs_dropped = self.dropout(attention_probs)
227
+
228
+ # Mask heads if we want to
229
+ if head_mask is not None:
230
+ attention_probs_dropped = attention_probs_dropped * head_mask
231
+
232
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
233
+
234
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
235
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
236
+ context_layer = context_layer.view(*new_context_layer_shape)
237
+
238
+ outputs = ((context_layer, attention_probs) if output_attentions else (context_layer, ))
239
+
240
+ outputs = outputs + (past_key_value, )
241
+ return outputs
242
+
243
+
244
+ class BertSelfOutput(nn.Module):
245
+ def __init__(self, config):
246
+ super().__init__()
247
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
248
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
249
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
250
+
251
+ def forward(self, hidden_states, input_tensor):
252
+ hidden_states = self.dense(hidden_states)
253
+ hidden_states = self.dropout(hidden_states)
254
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
255
+ return hidden_states
256
+
257
+
258
+ class BertAttention(nn.Module):
259
+ def __init__(self, config, is_cross_attention=False):
260
+ super().__init__()
261
+ self.self = BertSelfAttention(config, is_cross_attention)
262
+ self.output = BertSelfOutput(config)
263
+ self.pruned_heads = set()
264
+
265
+ def prune_heads(self, heads):
266
+ if len(heads) == 0:
267
+ return
268
+ heads, index = find_pruneable_heads_and_indices(
269
+ heads,
270
+ self.self.num_attention_heads,
271
+ self.self.attention_head_size,
272
+ self.pruned_heads,
273
+ )
274
+
275
+ # Prune linear layers
276
+ self.self.query = prune_linear_layer(self.self.query, index)
277
+ self.self.key = prune_linear_layer(self.self.key, index)
278
+ self.self.value = prune_linear_layer(self.self.value, index)
279
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
280
+
281
+ # Update hyper params and store pruned heads
282
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
283
+ self.self.all_head_size = (self.self.attention_head_size * self.self.num_attention_heads)
284
+ self.pruned_heads = self.pruned_heads.union(heads)
285
+
286
+ def forward(
287
+ self,
288
+ hidden_states,
289
+ attention_mask=None,
290
+ head_mask=None,
291
+ encoder_hidden_states=None,
292
+ encoder_attention_mask=None,
293
+ past_key_value=None,
294
+ output_attentions=False,
295
+ ):
296
+ self_outputs = self.self(
297
+ hidden_states,
298
+ attention_mask,
299
+ head_mask,
300
+ encoder_hidden_states,
301
+ encoder_attention_mask,
302
+ past_key_value,
303
+ output_attentions,
304
+ )
305
+ attention_output = self.output(self_outputs[0], hidden_states)
306
+
307
+ outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them
308
+ return outputs
309
+
310
+
311
+ class BertIntermediate(nn.Module):
312
+ def __init__(self, config):
313
+ super().__init__()
314
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
315
+ if isinstance(config.hidden_act, str):
316
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
317
+ else:
318
+ self.intermediate_act_fn = config.hidden_act
319
+
320
+ def forward(self, hidden_states):
321
+ hidden_states = self.dense(hidden_states)
322
+ hidden_states = self.intermediate_act_fn(hidden_states)
323
+ return hidden_states
324
+
325
+
326
+ class BertOutput(nn.Module):
327
+ def __init__(self, config):
328
+ super().__init__()
329
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
330
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
331
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
332
+
333
+ def forward(self, hidden_states, input_tensor):
334
+ hidden_states = self.dense(hidden_states)
335
+ hidden_states = self.dropout(hidden_states)
336
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
337
+ return hidden_states
338
+
339
+
340
+ class BertLayer(nn.Module):
341
+ def __init__(self, config, layer_num):
342
+ super().__init__()
343
+ self.config = config
344
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
345
+ self.seq_len_dim = 1
346
+ self.attention = BertAttention(config)
347
+ self.layer_num = layer_num
348
+ if (self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0):
349
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
350
+ self.has_cross_attention = True
351
+ else:
352
+ self.has_cross_attention = False
353
+ self.intermediate = BertIntermediate(config)
354
+ self.output = BertOutput(config)
355
+
356
+ self.intermediate_query = BertIntermediate(config)
357
+ self.output_query = BertOutput(config)
358
+
359
+ def forward(
360
+ self,
361
+ hidden_states,
362
+ attention_mask=None,
363
+ head_mask=None,
364
+ encoder_hidden_states=None,
365
+ encoder_attention_mask=None,
366
+ past_key_value=None,
367
+ output_attentions=False,
368
+ query_length=0,
369
+ ):
370
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
371
+ self_attn_past_key_value = (past_key_value[:2] if past_key_value is not None else None)
372
+ # if past_key_value is not None:
373
+ # print(hidden_states.shape, attention_mask.shape)
374
+ #print(hidden_states.shape, attention_mask.shape)
375
+ # causal attention for query embeds with self-attention
376
+ self_attention_outputs = self.attention(
377
+ hidden_states,
378
+ attention_mask,
379
+ head_mask,
380
+ output_attentions=output_attentions,
381
+ past_key_value=self_attn_past_key_value,
382
+ )
383
+ #print('attention_mask', attention_mask.shape)
384
+ # if attention_mask.shape[-1] == 77:
385
+ # print('attention_mask', attention_mask[0])
386
+ attention_output = self_attention_outputs[0]
387
+ outputs = self_attention_outputs[1:-1]
388
+
389
+ present_key_value = self_attention_outputs[-1]
390
+ #print(present_key_value[0].shape)
391
+
392
+ if query_length > 0:
393
+ query_attention_output = attention_output[:, :query_length, :]
394
+
395
+ if self.has_cross_attention:
396
+ assert (encoder_hidden_states is not None), "encoder_hidden_states must be given for cross-attention layers"
397
+ #print(attention_mask.shape)
398
+ cross_attention_outputs = self.crossattention(
399
+ query_attention_output,
400
+ attention_mask,
401
+ head_mask,
402
+ encoder_hidden_states,
403
+ encoder_attention_mask,
404
+ output_attentions=output_attentions,
405
+ )
406
+ query_attention_output = cross_attention_outputs[0]
407
+ outputs = (outputs + cross_attention_outputs[1:-1]) # add cross attentions if we output attention weights
408
+
409
+ layer_output = apply_chunking_to_forward(
410
+ self.feed_forward_chunk_query,
411
+ self.chunk_size_feed_forward,
412
+ self.seq_len_dim,
413
+ query_attention_output,
414
+ )
415
+ if attention_output.shape[1] > query_length:
416
+ layer_output_text = apply_chunking_to_forward(
417
+ self.feed_forward_chunk,
418
+ self.chunk_size_feed_forward,
419
+ self.seq_len_dim,
420
+ attention_output[:, query_length:, :],
421
+ )
422
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
423
+ else:
424
+ layer_output = apply_chunking_to_forward(
425
+ self.feed_forward_chunk,
426
+ self.chunk_size_feed_forward,
427
+ self.seq_len_dim,
428
+ attention_output,
429
+ )
430
+ outputs = (layer_output, ) + outputs
431
+
432
+ outputs = outputs + (present_key_value, )
433
+
434
+ return outputs
435
+
436
+ def feed_forward_chunk(self, attention_output):
437
+ intermediate_output = self.intermediate(attention_output)
438
+ layer_output = self.output(intermediate_output, attention_output)
439
+ return layer_output
440
+
441
+ def feed_forward_chunk_query(self, attention_output):
442
+ intermediate_output = self.intermediate_query(attention_output)
443
+ layer_output = self.output_query(intermediate_output, attention_output)
444
+ return layer_output
445
+
446
+
447
+ class BertEncoder(nn.Module):
448
+ def __init__(self, config):
449
+ super().__init__()
450
+ self.config = config
451
+ self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
452
+
453
+ def forward(
454
+ self,
455
+ hidden_states,
456
+ attention_mask=None,
457
+ head_mask=None,
458
+ encoder_hidden_states=None,
459
+ encoder_attention_mask=None,
460
+ past_key_values=None,
461
+ use_cache=None,
462
+ output_attentions=False,
463
+ output_hidden_states=False,
464
+ return_dict=True,
465
+ query_length=0,
466
+ ):
467
+ all_hidden_states = () if output_hidden_states else None
468
+ all_self_attentions = () if output_attentions else None
469
+ all_cross_attentions = (() if output_attentions and self.config.add_cross_attention else None)
470
+
471
+ next_decoder_cache = () if use_cache else None
472
+
473
+ for i in range(self.config.num_hidden_layers):
474
+ layer_module = self.layer[i]
475
+ if output_hidden_states:
476
+ all_hidden_states = all_hidden_states + (hidden_states, )
477
+
478
+ layer_head_mask = head_mask[i] if head_mask is not None else None
479
+ past_key_value = past_key_values[i] if past_key_values is not None else None
480
+ # if past_key_value is not None:
481
+ # print(past_key_value[0].shape, past_key_value[1].shape)
482
+
483
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
484
+
485
+ if use_cache:
486
+ logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
487
+ use_cache = False
488
+
489
+ def create_custom_forward(module):
490
+ def custom_forward(*inputs):
491
+ return module(*inputs, past_key_value, output_attentions, query_length)
492
+
493
+ return custom_forward
494
+
495
+ layer_outputs = torch.utils.checkpoint.checkpoint(
496
+ create_custom_forward(layer_module),
497
+ hidden_states,
498
+ attention_mask,
499
+ layer_head_mask,
500
+ encoder_hidden_states,
501
+ encoder_attention_mask,
502
+ )
503
+ else:
504
+ layer_outputs = layer_module(
505
+ hidden_states,
506
+ attention_mask,
507
+ layer_head_mask,
508
+ encoder_hidden_states,
509
+ encoder_attention_mask,
510
+ past_key_value,
511
+ output_attentions,
512
+ query_length,
513
+ )
514
+ # if past_key_value is not None:
515
+ # print(hidden_states.shape, attention_mask.shape)
516
+ # print(len(past_key_value))
517
+
518
+ hidden_states = layer_outputs[0]
519
+ if use_cache:
520
+ next_decoder_cache += (layer_outputs[-1], )
521
+ #print(layer_outputs[-1][0].shape)
522
+ if output_attentions:
523
+ all_self_attentions = all_self_attentions + (layer_outputs[1], )
524
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2], )
525
+
526
+ if output_hidden_states:
527
+ all_hidden_states = all_hidden_states + (hidden_states, )
528
+
529
+ if not return_dict:
530
+ return tuple(v for v in [
531
+ hidden_states,
532
+ next_decoder_cache,
533
+ all_hidden_states,
534
+ all_self_attentions,
535
+ all_cross_attentions,
536
+ ] if v is not None)
537
+ return BaseModelOutputWithPastAndCrossAttentions(
538
+ last_hidden_state=hidden_states,
539
+ past_key_values=next_decoder_cache,
540
+ hidden_states=all_hidden_states,
541
+ attentions=all_self_attentions,
542
+ cross_attentions=all_cross_attentions,
543
+ )
544
+
545
+
546
+ class BertPooler(nn.Module):
547
+ def __init__(self, config):
548
+ super().__init__()
549
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
550
+ self.activation = nn.Tanh()
551
+
552
+ def forward(self, hidden_states):
553
+ # We "pool" the model by simply taking the hidden state corresponding
554
+ # to the first token.
555
+ first_token_tensor = hidden_states[:, 0]
556
+ pooled_output = self.dense(first_token_tensor)
557
+ pooled_output = self.activation(pooled_output)
558
+ return pooled_output
559
+
560
+
561
+ class BertPredictionHeadTransform(nn.Module):
562
+ def __init__(self, config):
563
+ super().__init__()
564
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
565
+ if isinstance(config.hidden_act, str):
566
+ self.transform_act_fn = ACT2FN[config.hidden_act]
567
+ else:
568
+ self.transform_act_fn = config.hidden_act
569
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
570
+
571
+ def forward(self, hidden_states):
572
+ hidden_states = self.dense(hidden_states)
573
+ hidden_states = self.transform_act_fn(hidden_states)
574
+ hidden_states = self.LayerNorm(hidden_states)
575
+ return hidden_states
576
+
577
+
578
+ class BertLMPredictionHead(nn.Module):
579
+ def __init__(self, config):
580
+ super().__init__()
581
+ self.transform = BertPredictionHeadTransform(config)
582
+
583
+ # The output weights are the same as the input embeddings, but there is
584
+ # an output-only bias for each token.
585
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
586
+
587
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
588
+
589
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
590
+ self.decoder.bias = self.bias
591
+
592
+ def forward(self, hidden_states):
593
+ hidden_states = self.transform(hidden_states)
594
+ hidden_states = self.decoder(hidden_states)
595
+ return hidden_states
596
+
597
+
598
+ class BertOnlyMLMHead(nn.Module):
599
+ def __init__(self, config):
600
+ super().__init__()
601
+ self.predictions = BertLMPredictionHead(config)
602
+
603
+ def forward(self, sequence_output):
604
+ prediction_scores = self.predictions(sequence_output)
605
+ return prediction_scores
606
+
607
+
608
+ class BertPreTrainedModel(PreTrainedModel):
609
+ """
610
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
611
+ models.
612
+ """
613
+
614
+ config_class = BertConfig
615
+ base_model_prefix = "bert"
616
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
617
+
618
+ def _init_weights(self, module):
619
+ """Initialize the weights"""
620
+ if isinstance(module, (nn.Linear, nn.Embedding)):
621
+ # Slightly different from the TF version which uses truncated_normal for initialization
622
+ # cf https://github.com/pytorch/pytorch/pull/5617
623
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
624
+ elif isinstance(module, nn.LayerNorm):
625
+ module.bias.data.zero_()
626
+ module.weight.data.fill_(1.0)
627
+ if isinstance(module, nn.Linear) and module.bias is not None:
628
+ module.bias.data.zero_()
629
+
630
+
631
+ class BertModel(BertPreTrainedModel):
632
+ """
633
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
634
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
635
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
636
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
637
+ To behave as a decoder, the model needs to be initialized with the :obj:`is_decoder`
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
639
+ """
640
+ def __init__(self, config, add_pooling_layer=False):
641
+ super().__init__(config)
642
+ self.config = config
643
+
644
+ self.embeddings = BertEmbeddings(config)
645
+
646
+ self.encoder = BertEncoder(config)
647
+
648
+ self.pooler = BertPooler(config) if add_pooling_layer else None
649
+
650
+ self.init_weights()
651
+
652
+ def get_input_embeddings(self):
653
+ return self.embeddings.word_embeddings
654
+
655
+ def set_input_embeddings(self, value):
656
+ self.embeddings.word_embeddings = value
657
+
658
+ def _prune_heads(self, heads_to_prune):
659
+ """
660
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
661
+ class PreTrainedModel
662
+ """
663
+ for layer, heads in heads_to_prune.items():
664
+ self.encoder.layer[layer].attention.prune_heads(heads)
665
+
666
+ def get_extended_attention_mask(
667
+ self,
668
+ attention_mask: Tensor,
669
+ input_shape: Tuple[int],
670
+ device: device,
671
+ is_decoder: bool,
672
+ is_casual: bool,
673
+ has_query: bool = False,
674
+ ) -> Tensor:
675
+ """
676
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
677
+
678
+ Arguments:
679
+ attention_mask (:obj:`torch.Tensor`):
680
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
681
+ input_shape (:obj:`Tuple[int]`):
682
+ The shape of the input to the model.
683
+ device: (:obj:`torch.device`):
684
+ The device of the input to the model.
685
+
686
+ Returns:
687
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
688
+ """
689
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
690
+ # ourselves in which case we just need to make it broadcastable to all heads.
691
+ #print(attention_mask.dim())
692
+ if attention_mask.dim() == 3:
693
+ extended_attention_mask = attention_mask[:, None, :, :]
694
+ elif attention_mask.dim() == 2:
695
+ # Provided a padding mask of dimensions [batch_size, seq_length]
696
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
697
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
698
+ if is_decoder or is_casual:
699
+ batch_size, seq_length = input_shape
700
+ #print(input_shape)
701
+ if not is_decoder and seq_length > 32:
702
+ query_length = 32
703
+ text_length = seq_length - query_length
704
+ query_ids = torch.arange(query_length, device=device)
705
+ query_causal_mask = (query_ids[None, None, :].repeat(batch_size, query_length, 1) <= query_ids[None, :,
706
+ None])
707
+ causal_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
708
+ causal_mask[:, :query_length, :query_length] = query_causal_mask
709
+ # print(query_causal_mask.shape, causal_mask.shape)
710
+ #print(causal_mask[0])
711
+
712
+ else:
713
+ seq_ids = torch.arange(seq_length, device=device)
714
+ causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None])
715
+
716
+ # add a prefix ones mask to the causal mask
717
+ # causal and attention masks must have same type with pytorch version < 1.3
718
+ causal_mask = causal_mask.to(attention_mask.dtype)
719
+ # if is_decoder:
720
+ # print(causal_mask.shape, attention_mask.shape)
721
+ #print(causal_mask.shape, attention_mask.shape)
722
+
723
+ if causal_mask.shape[1] < attention_mask.shape[1]:
724
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
725
+ if has_query: # UniLM style attention mask
726
+ causal_mask = torch.cat(
727
+ [
728
+ torch.zeros(
729
+ (batch_size, prefix_seq_len, seq_length),
730
+ device=device,
731
+ dtype=causal_mask.dtype,
732
+ ),
733
+ causal_mask,
734
+ ],
735
+ axis=1,
736
+ )
737
+ causal_mask = torch.cat(
738
+ [
739
+ torch.ones(
740
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
741
+ device=device,
742
+ dtype=causal_mask.dtype,
743
+ ),
744
+ causal_mask,
745
+ ],
746
+ axis=-1,
747
+ )
748
+ #print(has_query, causal_mask.shape)
749
+ #print(causal_mask[0])
750
+ extended_attention_mask = (causal_mask[:, None, :, :] * attention_mask[:, None, None, :])
751
+ #print(extended_attention_mask[0])
752
+ #print('extended_attention_mask', extended_attention_mask.shape)
753
+ else:
754
+ extended_attention_mask = attention_mask[:, None, None, :]
755
+ #print(attention_mask.shape, extended_attention_mask.shape)
756
+ else:
757
+ raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
758
+ input_shape, attention_mask.shape))
759
+
760
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
761
+ # masked positions, this operation will create a tensor which is 0.0 for
762
+ # positions we want to attend and -10000.0 for masked positions.
763
+ # Since we are adding it to the raw scores before the softmax, this is
764
+ # effectively the same as removing these entirely.
765
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
766
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
767
+ return extended_attention_mask
768
+
769
+ def forward(
770
+ self,
771
+ input_ids=None,
772
+ attention_mask=None,
773
+ position_ids=None,
774
+ head_mask=None,
775
+ query_embeds=None,
776
+ encoder_hidden_states=None,
777
+ encoder_attention_mask=None,
778
+ past_key_values=None,
779
+ use_cache=None,
780
+ output_attentions=None,
781
+ output_hidden_states=None,
782
+ return_dict=None,
783
+ is_decoder=False,
784
+ ):
785
+ r"""
786
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
787
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
788
+ the model is configured as a decoder.
789
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
790
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
791
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
792
+ - 1 for tokens that are **not masked**,
793
+ - 0 for tokens that are **masked**.
794
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
795
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
796
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
797
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
798
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
799
+ use_cache (:obj:`bool`, `optional`):
800
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
801
+ decoding (see :obj:`past_key_values`).
802
+ """
803
+ output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
804
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
805
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
806
+
807
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
808
+
809
+ if input_ids is None:
810
+ assert (query_embeds is not None), "You have to specify query_embeds when input_ids is None"
811
+
812
+ #if query_embeds is not None:
813
+ if query_embeds is not None and query_embeds.shape[1] == 32:
814
+ is_casual = True
815
+ else:
816
+ is_casual = False
817
+ past_key_values_length = (past_key_values[0][0].shape[2] -
818
+ self.config.query_length if past_key_values is not None else 0)
819
+
820
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
821
+
822
+ embedding_output = self.embeddings(
823
+ input_ids=input_ids,
824
+ position_ids=position_ids,
825
+ query_embeds=query_embeds,
826
+ past_key_values_length=past_key_values_length,
827
+ )
828
+
829
+ input_shape = embedding_output.size()[:-1]
830
+ batch_size, seq_length = input_shape
831
+ device = embedding_output.device
832
+
833
+ #print('attention_mask', attention_mask)
834
+ if attention_mask is None:
835
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
836
+ #print(seq_length, past_key_values_length)
837
+
838
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
839
+ # ourselves in which case we just need to make it broadcastable to all heads.
840
+ if is_decoder:
841
+ #print(attention_mask.shape, input_ids.shape)
842
+ extended_attention_mask = self.get_extended_attention_mask(
843
+ attention_mask,
844
+ input_ids.shape,
845
+ device,
846
+ is_decoder,
847
+ is_casual,
848
+ has_query=(query_embeds is not None),
849
+ )
850
+ else:
851
+ extended_attention_mask = self.get_extended_attention_mask(
852
+ attention_mask,
853
+ input_shape,
854
+ device,
855
+ is_decoder,
856
+ is_casual,
857
+ )
858
+ #print(is_decoder, extended_attention_mask.shape)
859
+ # if is_decoder:
860
+ # print(extended_attention_mask[0,0,:,32:])
861
+ # if attention_mask is not None:
862
+ # print(input_ids, embedding_output.shape, extended_attention_mask.shape)
863
+
864
+ # If a 2D or 3D attention mask is provided for the cross-attention
865
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
866
+ if encoder_hidden_states is not None:
867
+ if type(encoder_hidden_states) == list:
868
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
869
+ else:
870
+ (
871
+ encoder_batch_size,
872
+ encoder_sequence_length,
873
+ _,
874
+ ) = encoder_hidden_states.size()
875
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
876
+
877
+ if type(encoder_attention_mask) == list:
878
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
879
+ elif encoder_attention_mask is None:
880
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
881
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
882
+ else:
883
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
884
+ #print(is_casual, extended_attention_mask.shape, encoder_attention_mask.shape, encoder_extended_attention_mask.shape)
885
+ else:
886
+ encoder_extended_attention_mask = None
887
+
888
+ # if input_ids is not None and query_embeds is not None:
889
+ # print(extended_attention_mask.shape, encoder_extended_attention_mask.shape)
890
+ # Prepare head mask if needed
891
+ # 1.0 in head_mask indicate we keep the head
892
+ # attention_probs has shape bsz x n_heads x N x N
893
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
894
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
895
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
896
+ #print(head_mask)
897
+
898
+ encoder_outputs = self.encoder(
899
+ embedding_output,
900
+ attention_mask=extended_attention_mask,
901
+ head_mask=head_mask,
902
+ encoder_hidden_states=encoder_hidden_states,
903
+ encoder_attention_mask=encoder_extended_attention_mask,
904
+ past_key_values=past_key_values,
905
+ use_cache=use_cache,
906
+ output_attentions=output_attentions,
907
+ output_hidden_states=output_hidden_states,
908
+ return_dict=return_dict,
909
+ query_length=query_length,
910
+ )
911
+ # if is_decoder:
912
+ # print(embedding_output.shape, attention_mask.shape, len(past_key_values))
913
+ #print(embedding_output.shape, extended_attention_mask.shape, encoder_hidden_states.shape, encoder_extended_attention_mask.shape)
914
+ #print(extended_attention_mask[0], encoder_extended_attention_mask[0])
915
+
916
+ #print(query_embeds.shape, encoder_hidden_states.shape)
917
+
918
+ sequence_output = encoder_outputs[0]
919
+ pooled_output = (self.pooler(sequence_output) if self.pooler is not None else None)
920
+
921
+ if not return_dict:
922
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
923
+
924
+ return BaseModelOutputWithPoolingAndCrossAttentions(
925
+ last_hidden_state=sequence_output,
926
+ pooler_output=pooled_output,
927
+ past_key_values=encoder_outputs.past_key_values,
928
+ hidden_states=encoder_outputs.hidden_states,
929
+ attentions=encoder_outputs.attentions,
930
+ cross_attentions=encoder_outputs.cross_attentions,
931
+ )
932
+
933
+
934
+ class BertLMHeadModel(BertPreTrainedModel):
935
+
936
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
937
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
938
+
939
+ def __init__(self, config):
940
+ super().__init__(config)
941
+
942
+ self.bert = BertModel(config, add_pooling_layer=False)
943
+ self.cls = BertOnlyMLMHead(config)
944
+
945
+ self.init_weights()
946
+
947
+ def get_output_embeddings(self):
948
+ return self.cls.predictions.decoder
949
+
950
+ def set_output_embeddings(self, new_embeddings):
951
+ self.cls.predictions.decoder = new_embeddings
952
+
953
+ def forward(
954
+ self,
955
+ input_ids=None,
956
+ attention_mask=None,
957
+ position_ids=None,
958
+ head_mask=None,
959
+ query_embeds=None,
960
+ encoder_hidden_states=None,
961
+ encoder_attention_mask=None,
962
+ labels=None,
963
+ past_key_values=None,
964
+ use_cache=True,
965
+ output_attentions=None,
966
+ output_hidden_states=None,
967
+ return_dict=None,
968
+ return_logits=False,
969
+ is_decoder=True,
970
+ reduction="mean",
971
+ ):
972
+ r"""
973
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
974
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
975
+ the model is configured as a decoder.
976
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
977
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
978
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
979
+ - 1 for tokens that are **not masked**,
980
+ - 0 for tokens that are **masked**.
981
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
982
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
983
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
984
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
985
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
986
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
987
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
988
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
989
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
990
+ use_cache (:obj:`bool`, `optional`):
991
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
992
+ decoding (see :obj:`past_key_values`).
993
+ Returns:
994
+ Example::
995
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
996
+ >>> import torch
997
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
998
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
999
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1000
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1001
+ >>> outputs = model(**inputs)
1002
+ >>> prediction_logits = outputs.logits
1003
+ """
1004
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1005
+ if labels is not None:
1006
+ use_cache = False
1007
+ if past_key_values is not None:
1008
+ query_embeds = None
1009
+ #print(len(past_key_values))
1010
+ #print('attention_mask', attention_mask)
1011
+ outputs = self.bert(
1012
+ input_ids,
1013
+ attention_mask=attention_mask,
1014
+ position_ids=position_ids,
1015
+ head_mask=head_mask,
1016
+ query_embeds=query_embeds,
1017
+ encoder_hidden_states=encoder_hidden_states,
1018
+ encoder_attention_mask=encoder_attention_mask,
1019
+ past_key_values=past_key_values,
1020
+ use_cache=use_cache,
1021
+ output_attentions=output_attentions,
1022
+ output_hidden_states=output_hidden_states,
1023
+ return_dict=return_dict,
1024
+ is_decoder=is_decoder,
1025
+ )
1026
+
1027
+ sequence_output = outputs[0]
1028
+ if query_embeds is not None:
1029
+ sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
1030
+
1031
+ prediction_scores = self.cls(sequence_output)
1032
+
1033
+ if return_logits:
1034
+ return prediction_scores[:, :-1, :].contiguous()
1035
+
1036
+ lm_loss = None
1037
+ if labels is not None:
1038
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1039
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1040
+ labels = labels[:, 1:].contiguous()
1041
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1042
+ lm_loss = loss_fct(
1043
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1044
+ labels.view(-1),
1045
+ )
1046
+ if reduction == "none":
1047
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1048
+
1049
+ if not return_dict:
1050
+ output = (prediction_scores, ) + outputs[2:]
1051
+ return ((lm_loss, ) + output) if lm_loss is not None else output
1052
+
1053
+ return CausalLMOutputWithCrossAttentions(
1054
+ loss=lm_loss,
1055
+ logits=prediction_scores,
1056
+ past_key_values=outputs.past_key_values,
1057
+ hidden_states=outputs.hidden_states,
1058
+ attentions=outputs.attentions,
1059
+ cross_attentions=outputs.cross_attentions,
1060
+ )
1061
+
1062
+ def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
1063
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1064
+ if attention_mask is None:
1065
+ attention_mask = input_ids.new_ones(input_ids.shape)
1066
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1067
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1068
+
1069
+ # cut decoder_input_ids if past is used
1070
+ if past is not None:
1071
+ input_ids = input_ids[:, -1:]
1072
+
1073
+ return {
1074
+ "input_ids": input_ids,
1075
+ "query_embeds": query_embeds,
1076
+ "attention_mask": attention_mask,
1077
+ "past_key_values": past,
1078
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1079
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1080
+ "is_decoder": True,
1081
+ }
1082
+
1083
+ def _reorder_cache(self, past, beam_idx):
1084
+ reordered_past = ()
1085
+ for layer_past in past:
1086
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), )
1087
+ return reordered_past
1088
+
1089
+
1090
+ class BertForMaskedLM(BertPreTrainedModel):
1091
+
1092
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1093
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1094
+
1095
+ def __init__(self, config):
1096
+ super().__init__(config)
1097
+
1098
+ self.bert = BertModel(config, add_pooling_layer=False)
1099
+ self.cls = BertOnlyMLMHead(config)
1100
+
1101
+ self.init_weights()
1102
+
1103
+ def get_output_embeddings(self):
1104
+ return self.cls.predictions.decoder
1105
+
1106
+ def set_output_embeddings(self, new_embeddings):
1107
+ self.cls.predictions.decoder = new_embeddings
1108
+
1109
+ def forward(
1110
+ self,
1111
+ input_ids=None,
1112
+ attention_mask=None,
1113
+ position_ids=None,
1114
+ head_mask=None,
1115
+ query_embeds=None,
1116
+ encoder_hidden_states=None,
1117
+ encoder_attention_mask=None,
1118
+ labels=None,
1119
+ output_attentions=None,
1120
+ output_hidden_states=None,
1121
+ return_dict=None,
1122
+ return_logits=False,
1123
+ is_decoder=False,
1124
+ ):
1125
+ r"""
1126
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1127
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1128
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1129
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1130
+ """
1131
+
1132
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1133
+
1134
+ outputs = self.bert(
1135
+ input_ids,
1136
+ attention_mask=attention_mask,
1137
+ position_ids=position_ids,
1138
+ head_mask=head_mask,
1139
+ query_embeds=query_embeds,
1140
+ encoder_hidden_states=encoder_hidden_states,
1141
+ encoder_attention_mask=encoder_attention_mask,
1142
+ output_attentions=output_attentions,
1143
+ output_hidden_states=output_hidden_states,
1144
+ return_dict=return_dict,
1145
+ is_decoder=is_decoder,
1146
+ )
1147
+
1148
+ if query_embeds is not None:
1149
+ sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
1150
+ prediction_scores = self.cls(sequence_output)
1151
+
1152
+ if return_logits:
1153
+ return prediction_scores
1154
+
1155
+ masked_lm_loss = None
1156
+ if labels is not None:
1157
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1158
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1159
+
1160
+ if not return_dict:
1161
+ output = (prediction_scores, ) + outputs[2:]
1162
+ return (((masked_lm_loss, ) + output) if masked_lm_loss is not None else output)
1163
+
1164
+ return MaskedLMOutput(
1165
+ loss=masked_lm_loss,
1166
+ logits=prediction_scores,
1167
+ hidden_states=outputs.hidden_states,
1168
+ attentions=outputs.attentions,
1169
+ )
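A minimal standalone sketch of the causal query mask built by `get_extended_attention_mask` above, assuming a 32-token query prefix, an 8-token text suffix, and no padding (the batch size, lengths, and tensors here are illustrative only, not part of the commit):

import torch

batch_size, query_length, text_length = 2, 32, 8
seq_length = query_length + text_length

query_ids = torch.arange(query_length)
# query token i may attend to query tokens j <= i (lower-triangular block)
query_causal_mask = query_ids[None, None, :].repeat(batch_size, query_length, 1) <= query_ids[None, :, None]

causal_mask = torch.ones(batch_size, seq_length, seq_length)
causal_mask[:, :query_length, :query_length] = query_causal_mask.float()

attention_mask = torch.ones(batch_size, seq_length)          # no padding in this sketch
extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
extended = (1.0 - extended) * -10000.0                        # additive mask for the attention scores

print(extended.shape)                        # torch.Size([2, 1, 40, 40])
print((extended[0, 0, :4, :4] == 0).int())   # lower-triangular pattern inside the query block

The -10000.0 additive values match what the softmax in BertSelfAttention receives: blocked positions are effectively removed, so query token i only sees query tokens 0..i plus the text that follows, while text tokens see everything.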
models/seed_qformer/qformer_quantizer.py ADDED
@@ -0,0 +1,375 @@
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import logging
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+ import torch.nn as nn
12
+ from torch.cuda.amp import autocast as autocast
13
+ from torch.nn import functional as F
14
+ import numpy as np
15
+ from functools import partial
16
+ from einops import rearrange
17
+
18
+ from .blip2 import Blip2Base, disabled_train
19
+ from .vit import Block
20
+ from .utils import download_cached_file, is_url
21
+
22
+ class VectorQuantizer2(nn.Module):
23
+ """
24
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
25
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
26
+ """
27
+
28
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
29
+ # backwards compatibility we use the buggy version by default, but you can
30
+ # specify legacy=False to fix it.
31
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
32
+ super().__init__()
33
+ self.n_e = n_e
34
+ self.e_dim = e_dim
35
+ self.beta = beta
36
+ self.legacy = legacy
37
+
38
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
39
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
40
+
41
+ self.remap = remap
42
+ if self.remap is not None:
43
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
44
+ self.re_embed = self.used.shape[0]
45
+ self.unknown_index = unknown_index # "random" or "extra" or integer
46
+ if self.unknown_index == "extra":
47
+ self.unknown_index = self.re_embed
48
+ self.re_embed = self.re_embed + 1
49
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
50
+ f"Using {self.unknown_index} for unknown indices.")
51
+ else:
52
+ self.re_embed = n_e
53
+
54
+ self.sane_index_shape = sane_index_shape
55
+
56
+ def remap_to_used(self, inds):
57
+ ishape = inds.shape
58
+ assert len(ishape) > 1
59
+ inds = inds.reshape(ishape[0], -1)
60
+ used = self.used.to(inds)
61
+ match = (inds[:, :, None] == used[None, None, ...]).long()
62
+ new = match.argmax(-1)
63
+ unknown = match.sum(2) < 1
64
+ if self.unknown_index == "random":
65
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
66
+ else:
67
+ new[unknown] = self.unknown_index
68
+ return new.reshape(ishape)
69
+
70
+ def unmap_to_all(self, inds):
71
+ ishape = inds.shape
72
+ assert len(ishape) > 1
73
+ inds = inds.reshape(ishape[0], -1)
74
+ used = self.used.to(inds)
75
+ if self.re_embed > self.used.shape[0]: # extra token
76
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
77
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
78
+ return back.reshape(ishape)
79
+
80
+ # def l2norm(self, t):
81
+ # return F.normalize(t, p = 2, dim = -1)
82
+
83
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
84
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
85
+ assert rescale_logits is False, "Only for interface compatible with Gumbel"
86
+ assert return_logits is False, "Only for interface compatible with Gumbel"
87
+ # reshape z -> (batch, height, width, channel) and flatten
88
+ #z = rearrange(z, 'b c h w -> b h w c').contiguous()
89
+ bz = z.shape[0]
90
+ z_flattened = z.view(-1, self.e_dim)
91
+ #print('z_flattened', z_flattened.shape)
92
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
93
+
94
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
95
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
96
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
97
+
98
+ min_encoding_indices = torch.argmin(d, dim=1)
99
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
100
+ perplexity = None
101
+ min_encodings = None
102
+
103
+ # compute loss for embedding
104
+ if not self.legacy:
105
+ loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean((z_q - z.detach())**2)
106
+ else:
107
+ loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
108
+
109
+ # preserve gradients
110
+ z_q = z + (z_q - z).detach()
111
+
112
+ # reshape back to match original input shape
113
+ #z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
114
+ z_q = z_q.reshape(bz, -1, z_q.shape[-1])
115
+ if self.remap is not None:
116
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
117
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
118
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
119
+
120
+ if self.sane_index_shape:
121
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
122
+
123
+ return z_q, loss, min_encoding_indices
124
+
125
+ def get_codebook_entry(self, indices, shape=None):
126
+ # shape specifying (batch, height, width, channel)
127
+ if self.remap is not None:
128
+ indices = indices.reshape(shape[0], -1) # add batch axis
129
+ indices = self.unmap_to_all(indices)
130
+ indices = indices.reshape(-1) # flatten again
131
+
132
+ # get quantized latent vectors
133
+ z_q = self.embedding(indices)
134
+
135
+ if shape is not None:
136
+ z_q = z_q.view(shape)
137
+ # reshape back to match original input shape
138
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
139
+
140
+ return z_q
141
+
142
+
143
+ class Blip2QformerQuantizer(Blip2Base):
144
+ """
145
+ BLIP2 first-stage model with Q-former and ViT.
146
+ Supported model types:
147
+ - pretrained: pretrained model with vit-g
148
+ - pretrain_vitL: pretrained model with vit-large
149
+ - coco: finetuned model on coco
150
+ Usage:
151
+ >>> from lavis.models import load_model
152
+ >>> model = load_model("blip2", "pretrain")
153
+ """
154
+
155
+ PRETRAINED_MODEL_CONFIG_DICT = {
156
+ "pretrain": "configs/models/blip2/blip2_pretrain.yaml",
157
+ "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml",
158
+ "coco": "configs/models/blip2/blip2_coco.yaml",
159
+ }
160
+
161
+ def __init__(self,
162
+ vit_model="eva_clip_g",
163
+ img_size=224,
164
+ drop_path_rate=0,
165
+ use_grad_checkpoint=False,
166
+ vit_precision="fp16",
167
+ freeze_vit=True,
168
+ num_query_token=32,
169
+ cross_attention_freq=2,
170
+ embed_dim=256,
171
+ max_txt_len=32,
172
+ codebook_embed_dim=32,
173
+ n_embed=8192,
174
+ recon_s=True,
175
+ blocks_for_image=True,
176
+ decode_depth=4,
177
+ use_recon_s_for_image=False,
178
+ use_qformer_image=False,
179
+ image_features_dim=1024):
180
+ super().__init__()
181
+
182
+ self.tokenizer = self.init_tokenizer()
183
+
184
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size, drop_path_rate, use_grad_checkpoint,
185
+ vit_precision)
186
+ if freeze_vit:
187
+ for name, param in self.visual_encoder.named_parameters():
188
+ param.requires_grad = False
189
+ self.visual_encoder = self.visual_encoder.eval()
190
+ self.visual_encoder.train = disabled_train
191
+ logging.info("freeze vision encoder")
192
+ self.ln_vision.weight.requires_grad = False
193
+ self.ln_vision.bias.requires_grad = False
194
+
195
+ self.codebook_embed_dim = codebook_embed_dim
196
+ self.n_embed = n_embed
197
+ self.recon_s = recon_s
198
+ self.blocks_for_image = blocks_for_image
199
+ self.use_recon_s_for_image = use_recon_s_for_image
200
+ self.depth = decode_depth
201
+ self.image_features_dim = image_features_dim
202
+ self.use_qformer_image = use_qformer_image
203
+
204
+ self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, self.visual_encoder.num_features)
205
+
206
+ self.Qformer.cls = None
207
+ self.Qformer.bert.embeddings.word_embeddings = None
208
+ self.Qformer.bert.embeddings.position_embeddings = None
209
+ for layer in self.Qformer.bert.encoder.layer:
210
+ layer.output = None
211
+ layer.intermediate = None
212
+
213
+ for name, param in self.Qformer.named_parameters():
214
+ param.requires_grad = False
215
+ self.query_tokens.requires_grad = False
216
+
217
+ self.quantize = VectorQuantizer2(n_embed, codebook_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
218
+
219
+ self.encode_task_layer = nn.Sequential(
220
+ nn.Linear(self.Qformer.config.hidden_size, self.Qformer.config.hidden_size),
221
+ nn.Tanh(),
222
+ nn.Linear(self.Qformer.config.hidden_size, codebook_embed_dim) # for quantize
223
+ )
224
+
225
+ self.decode_task_layer = nn.Sequential(
226
+ nn.Linear(codebook_embed_dim, codebook_embed_dim),
227
+ nn.Tanh(),
228
+ nn.Linear(codebook_embed_dim, self.Qformer.config.hidden_size) # for quantize
229
+ )
230
+
231
+ self.quantize = self.quantize.eval()
232
+ self.quantize.training = False
233
+ for name, param in self.named_parameters():
234
+ if 'quantize' in name or 'encode_task_layer' in name or 'decode_task_layer' in name:
235
+ #print('freeze params', name)
236
+ param.requires_grad = False
237
+
238
+ if self.recon_s:
239
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_query_token, self.Qformer.config.hidden_size))
240
+ self.blocks = nn.ModuleList([
241
+ Block(dim=self.Qformer.config.hidden_size,
242
+ num_heads=12,
243
+ mlp_ratio=4.0,
244
+ qkv_bias=True,
245
+ qk_scale=None,
246
+ drop=0.0,
247
+ attn_drop=0.0,
248
+ drop_path=0.0,
249
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)) for i in range(self.depth)
250
+ ])
251
+
252
+ if self.blocks_for_image:
253
+ self.pos_embed_image = nn.Parameter(torch.zeros(1, num_query_token, self.Qformer.config.hidden_size))
254
+ self.blocks_image = nn.ModuleList([
255
+ Block(dim=self.Qformer.config.hidden_size,
256
+ num_heads=12,
257
+ mlp_ratio=4.0,
258
+ qkv_bias=True,
259
+ qk_scale=None,
260
+ drop=0.0,
261
+ attn_drop=0.0,
262
+ drop_path=0.0,
263
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)) for i in range(self.depth)
264
+ ])
265
+
266
+ if self.use_qformer_image:
267
+ num_reverse_token = 1
268
+ self.Reverse_Qformer, self.reverse_tokens = self.init_Qformer(num_reverse_token, self.Qformer.config.hidden_size)
269
+
270
+ self.Reverse_Qformer.cls = None
271
+ self.Reverse_Qformer.bert.embeddings.word_embeddings = None
272
+ self.Reverse_Qformer.bert.embeddings.position_embeddings = None
273
+ for layer in self.Reverse_Qformer.bert.encoder.layer:
274
+ layer.output = None
275
+ layer.intermediate = None
276
+ self.distill_image_proj = nn.Linear(self.Qformer.config.hidden_size, image_features_dim)
277
+
278
+ else:
279
+ self.image_down = nn.Sequential(
280
+ nn.Linear(self.Qformer.config.hidden_size, 256, bias=False),
281
+ nn.ReLU(),
282
+ nn.Linear(256, 128, bias=False),
283
+ nn.ReLU(),
284
+ nn.Linear(128, 32, bias=False),
285
+ )
286
+ self.distill_image_proj = nn.Linear(num_query_token * 32, image_features_dim)
287
+
288
+ def get_codebook_indices(self, image):
289
+ with torch.no_grad():
290
+ with self.maybe_autocast():
291
+ image_embeds = self.ln_vision(self.visual_encoder(image))
292
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
293
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
294
+ query_output = self.Qformer.bert(
295
+ query_embeds=query_tokens,
296
+ encoder_hidden_states=image_embeds,
297
+ encoder_attention_mask=image_atts,
298
+ return_dict=True,
299
+ )
300
+
301
+ query_output_down = self.encode_task_layer(query_output.last_hidden_state)
302
+ quant, loss_embed, embed_ind = self.quantize(query_output_down)
303
+ embed_ind = embed_ind.reshape(quant.shape[0], -1)
304
+
305
+ query_output_up = self.decode_task_layer(quant)
306
+
307
+ return embed_ind, query_output_up
308
+
309
+ def get_codebook_entry(self, indices):
310
+ quant_embedding = self.quantize.get_codebook_entry(indices)
311
+ # print('quant_embedding_shape: ', quant_embedding.shape)
312
+ # print(self.decode_task_layer)
313
+ # exit()
314
+ query_output_up = self.decode_task_layer(quant_embedding)
315
+
316
+ pos_embed_image = self.pos_embed_image.repeat(query_output_up.shape[0], 1, 1)
317
+ query_output_up_pos_image = query_output_up + pos_embed_image
318
+ for blk in self.blocks_image:
319
+ query_output_up_pos_image = blk(query_output_up_pos_image)
320
+ query_output_up = query_output_up_pos_image
321
+
322
+ if self.use_qformer_image:
323
+ query_atts = torch.ones(query_output_up.size()[:-1], dtype=torch.long).to(query_output_up.device)
324
+ reverse_tokens = self.reverse_tokens.expand(query_output_up.shape[0], -1, -1)
325
+ reverse_output = self.Reverse_Qformer.bert(
326
+ query_embeds=reverse_tokens,
327
+ encoder_hidden_states=query_output_up,
328
+ encoder_attention_mask=query_atts,
329
+ return_dict=True,
330
+ )
331
+ reverse_output = reverse_output.last_hidden_state
332
+ reverse_output_proj = self.distill_image_proj(reverse_output).squeeze(1)
333
+ else:
334
+ reverse_output = self.image_down(query_output_up)
335
+ reverse_output = reverse_output.reshape(reverse_output.shape[0], -1)
336
+ reverse_output_proj = self.distill_image_proj(reverse_output)
337
+
338
+ return reverse_output_proj
339
+
340
+ @classmethod
341
+ def from_pretrained(cls, pretrained_model_path, **kwargs):
342
+ vit_model = kwargs.get("vit_model", "eva_clip_g")
343
+ img_size = kwargs.get("image_size", 224)
344
+ num_query_token = kwargs.get("num_query_token", 32)
345
+ cross_attention_freq = kwargs.get("cross_attention_freq", 2)
346
+
347
+ drop_path_rate = kwargs.get("drop_path_rate", 0)
348
+ use_grad_checkpoint = kwargs.get("use_grad_checkpoint", False)
349
+ vit_precision = kwargs.get("vit_precision", "fp16")
350
+ freeze_vit = kwargs.get("freeze_vit", True)
351
+
352
+ max_txt_len = kwargs.get("max_txt_len", 32)
353
+
354
+ model = cls(
355
+ vit_model=vit_model,
356
+ img_size=img_size,
357
+ drop_path_rate=drop_path_rate,
358
+ use_grad_checkpoint=use_grad_checkpoint,
359
+ vit_precision=vit_precision,
360
+ freeze_vit=freeze_vit,
361
+ num_query_token=num_query_token,
362
+ cross_attention_freq=cross_attention_freq,
363
+ max_txt_len=max_txt_len,
364
+ )
365
+
366
+ if pretrained_model_path.startswith('http'):
367
+ print('start downloading SEED model...')
368
+ cached_file = download_cached_file(pretrained_model_path, check_hash=False, progress=True)
369
+ print(cached_file)
370
+ ckpt = torch.load(cached_file, map_location="cpu")
371
+ else:
372
+ ckpt = torch.load(pretrained_model_path, map_location="cpu")
373
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
374
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
375
+ return model
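A minimal sketch of the nearest-codebook lookup performed in `VectorQuantizer2.forward` above, assuming random inputs and an untrained codebook with the default n_embed=8192 and codebook_embed_dim=32. It expands ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e so no pairwise difference tensor is materialized, then copies gradients through the quantization with the straight-through trick:

import torch
import torch.nn as nn

n_embed, embed_dim, beta = 8192, 32, 0.25
codebook = nn.Embedding(n_embed, embed_dim)

z = torch.randn(4, 32, embed_dim)          # (batch, num query tokens, codebook dim)
z_flat = z.view(-1, embed_dim)

# squared distances ||z - e||^2 for every (token, codebook entry) pair
d = (z_flat.pow(2).sum(dim=1, keepdim=True)
     + codebook.weight.pow(2).sum(dim=1)
     - 2 * z_flat @ codebook.weight.t())

indices = d.argmin(dim=1)                  # discrete codes, as returned by get_codebook_indices
z_q = codebook(indices).view(z.shape)

# commitment loss (legacy ordering) and straight-through gradient copy
loss = ((z_q.detach() - z) ** 2).mean() + beta * ((z_q - z.detach()) ** 2).mean()
z_q = z + (z_q - z).detach()

print(indices.view(4, 32).shape, z_q.shape, round(loss.item(), 3))

In the frozen tokenizer above the codebook and task layers are loaded from the checkpoint and kept fixed; this sketch only illustrates the lookup and loss terms, not training.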
models/seed_qformer/utils.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+ from urllib.parse import urlparse
16
+
17
+
18
+ def setup_for_distributed(is_master):
19
+ """
20
+ This function disables printing when not in master process
21
+ """
22
+ import builtins as __builtin__
23
+
24
+ builtin_print = __builtin__.print
25
+
26
+ def print(*args, **kwargs):
27
+ force = kwargs.pop("force", False)
28
+ if is_master or force:
29
+ builtin_print(*args, **kwargs)
30
+
31
+ __builtin__.print = print
32
+
33
+
34
+ def is_dist_avail_and_initialized():
35
+ if not dist.is_available():
36
+ return False
37
+ if not dist.is_initialized():
38
+ return False
39
+ return True
40
+
41
+
42
+ def get_world_size():
43
+ if not is_dist_avail_and_initialized():
44
+ return 1
45
+ return dist.get_world_size()
46
+
47
+
48
+ def get_rank():
49
+ if not is_dist_avail_and_initialized():
50
+ return 0
51
+ return dist.get_rank()
52
+
53
+
54
+ def is_main_process():
55
+ return get_rank() == 0
56
+
57
+
58
+ def init_distributed_mode(args):
59
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
60
+ args.rank = int(os.environ["RANK"])
61
+ args.world_size = int(os.environ["WORLD_SIZE"])
62
+ args.gpu = int(os.environ["LOCAL_RANK"])
63
+ elif "SLURM_PROCID" in os.environ:
64
+ args.rank = int(os.environ["SLURM_PROCID"])
65
+ args.gpu = args.rank % torch.cuda.device_count()
66
+ else:
67
+ print("Not using distributed mode")
68
+ args.distributed = False
69
+ return
70
+
71
+ args.distributed = True
72
+
73
+ torch.cuda.set_device(args.gpu)
74
+ args.dist_backend = "nccl"
75
+ print(
76
+ "| distributed init (rank {}, world {}): {}".format(args.rank, args.world_size, args.dist_url),
77
+ flush=True,
78
+ )
79
+ torch.distributed.init_process_group(
80
+ backend=args.dist_backend,
81
+ init_method=args.dist_url,
82
+ world_size=args.world_size,
83
+ rank=args.rank,
84
+ timeout=datetime.timedelta(days=365), # allow auto-downloading and de-compressing
85
+ )
86
+ torch.distributed.barrier()
87
+ setup_for_distributed(args.rank == 0)
88
+
89
+
90
+ def get_dist_info():
91
+ if torch.__version__ < "1.0":
92
+ initialized = dist._initialized
93
+ else:
94
+ initialized = dist.is_initialized()
95
+ if initialized:
96
+ rank = dist.get_rank()
97
+ world_size = dist.get_world_size()
98
+ else: # non-distributed training
99
+ rank = 0
100
+ world_size = 1
101
+ return rank, world_size
102
+
103
+
104
+ def main_process(func):
105
+ @functools.wraps(func)
106
+ def wrapper(*args, **kwargs):
107
+ rank, _ = get_dist_info()
108
+ if rank == 0:
109
+ return func(*args, **kwargs)
110
+
111
+ return wrapper
112
+
113
+
114
+ def download_cached_file(url, check_hash=True, progress=False):
115
+ """
116
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
117
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
118
+ """
119
+ def get_cached_file_path():
120
+ # a hack to sync the file path across processes
121
+ parts = urlparse(url)
122
+ filename = os.path.basename(parts.path)
123
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
124
+
125
+ return cached_file
126
+
127
+ if is_main_process():
128
+ timm_hub.download_cached_file(url, check_hash, progress)
129
+
130
+ if is_dist_avail_and_initialized():
131
+ dist.barrier()
132
+
133
+ return get_cached_file_path()
134
+
135
+
136
+ def is_url(url_or_filename):
137
+ parsed = urlparse(url_or_filename)
138
+ return parsed.scheme in ("http", "https")
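
The helpers above are standard LAVIS-style distributed utilities: the rank and world-size queries fall back to single-process values when no process group exists, main_process restricts a function to rank 0, and download_cached_file fetches a file once (on rank 0) into the timm cache. A minimal single-process sketch, assuming the file is importable as models.seed_qformer.utils and using a placeholder URL:

from models.seed_qformer.utils import (
    download_cached_file, get_rank, get_world_size, is_url, main_process)

# No process group is initialized here, so the helpers return single-process values.
assert get_world_size() == 1 and get_rank() == 0
assert is_url("https://example.com/weights.pth")  # placeholder URL, not a real checkpoint

@main_process
def log_once(msg):
    # Executed only on rank 0; a no-op on every other rank.
    print(msg)

log_once("loading checkpoint")
# cached_path = download_cached_file("https://example.com/weights.pth", progress=True)  # rank 0 downloads, others wait at the barrier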
models/seed_qformer/vit.py ADDED
@@ -0,0 +1,395 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+
7
+ Based on timm code base
8
+ https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ """
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from functools import partial
16
+
17
+ from timm.models.vision_transformer import _cfg, PatchEmbed
18
+ from timm.models.registry import register_model
19
+ from timm.models.layers import trunc_normal_, DropPath
20
+ from timm.models.helpers import named_apply, adapt_input_conv
21
+
22
+
23
+ class Mlp(nn.Module):
24
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
25
+ def __init__(
26
+ self,
27
+ in_features,
28
+ hidden_features=None,
29
+ out_features=None,
30
+ act_layer=nn.GELU,
31
+ drop=0.0,
32
+ ):
33
+ super().__init__()
34
+ out_features = out_features or in_features
35
+ hidden_features = hidden_features or in_features
36
+ self.fc1 = nn.Linear(in_features, hidden_features)
37
+ self.act = act_layer()
38
+ self.fc2 = nn.Linear(hidden_features, out_features)
39
+ self.drop = nn.Dropout(drop)
40
+
41
+ def forward(self, x):
42
+ x = self.fc1(x)
43
+ x = self.act(x)
44
+ x = self.drop(x)
45
+ x = self.fc2(x)
46
+ x = self.drop(x)
47
+ return x
48
+
49
+
50
+ class Attention(nn.Module):
51
+ def __init__(
52
+ self,
53
+ dim,
54
+ num_heads=8,
55
+ qkv_bias=False,
56
+ qk_scale=None,
57
+ attn_drop=0.0,
58
+ proj_drop=0.0,
59
+ ):
60
+ super().__init__()
61
+ self.num_heads = num_heads
62
+ head_dim = dim // num_heads
63
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
64
+ self.scale = qk_scale or head_dim**-0.5
65
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
66
+ self.attn_drop = nn.Dropout(attn_drop)
67
+ self.proj = nn.Linear(dim, dim)
68
+ self.proj_drop = nn.Dropout(proj_drop)
69
+ self.attn_gradients = None
70
+ self.attention_map = None
71
+
72
+ def save_attn_gradients(self, attn_gradients):
73
+ self.attn_gradients = attn_gradients
74
+
75
+ def get_attn_gradients(self):
76
+ return self.attn_gradients
77
+
78
+ def save_attention_map(self, attention_map):
79
+ self.attention_map = attention_map
80
+
81
+ def get_attention_map(self):
82
+ return self.attention_map
83
+
84
+ def forward(self, x, register_hook=False):
85
+ B, N, C = x.shape
86
+ qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4))
87
+ q, k, v = (
88
+ qkv[0],
89
+ qkv[1],
90
+ qkv[2],
91
+ ) # make torchscript happy (cannot use tensor as tuple)
92
+
93
+ attn = (q @ k.transpose(-2, -1)) * self.scale
94
+ attn = attn.softmax(dim=-1)
95
+ attn = self.attn_drop(attn)
96
+
97
+ if register_hook:
98
+ self.save_attention_map(attn)
99
+ attn.register_hook(self.save_attn_gradients)
100
+
101
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
102
+ x = self.proj(x)
103
+ x = self.proj_drop(x)
104
+ return x
105
+
106
+
107
+ class Block(nn.Module):
108
+ def __init__(
109
+ self,
110
+ dim,
111
+ num_heads,
112
+ mlp_ratio=4.0,
113
+ qkv_bias=False,
114
+ qk_scale=None,
115
+ drop=0.0,
116
+ attn_drop=0.0,
117
+ drop_path=0.0,
118
+ act_layer=nn.GELU,
119
+ norm_layer=nn.LayerNorm,
120
+ use_grad_checkpointing=False,
121
+ ):
122
+ super().__init__()
123
+ self.norm1 = norm_layer(dim)
124
+ self.attn = Attention(
125
+ dim,
126
+ num_heads=num_heads,
127
+ qkv_bias=qkv_bias,
128
+ qk_scale=qk_scale,
129
+ attn_drop=attn_drop,
130
+ proj_drop=drop,
131
+ )
132
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
133
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
134
+ self.norm2 = norm_layer(dim)
135
+ mlp_hidden_dim = int(dim * mlp_ratio)
136
+ self.mlp = Mlp(
137
+ in_features=dim,
138
+ hidden_features=mlp_hidden_dim,
139
+ act_layer=act_layer,
140
+ drop=drop,
141
+ )
142
+
143
+ # if use_grad_checkpointing:
144
+ # self.attn = checkpoint_wrapper(self.attn)
145
+ # self.mlp = checkpoint_wrapper(self.mlp)
146
+
147
+ def forward(self, x, register_hook=False):
148
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
149
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
150
+ return x
151
+
152
+
153
+ class VisionTransformer(nn.Module):
154
+ """Vision Transformer
155
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
156
+ https://arxiv.org/abs/2010.11929
157
+ """
158
+ def __init__(
159
+ self,
160
+ img_size=224,
161
+ patch_size=16,
162
+ in_chans=3,
163
+ num_classes=1000,
164
+ embed_dim=768,
165
+ depth=12,
166
+ num_heads=12,
167
+ mlp_ratio=4.0,
168
+ qkv_bias=True,
169
+ qk_scale=None,
170
+ representation_size=None,
171
+ drop_rate=0.0,
172
+ attn_drop_rate=0.0,
173
+ drop_path_rate=0.0,
174
+ norm_layer=None,
175
+ use_grad_checkpointing=False,
176
+ ckpt_layer=0,
177
+ ):
178
+ """
179
+ Args:
180
+ img_size (int, tuple): input image size
181
+ patch_size (int, tuple): patch size
182
+ in_chans (int): number of input channels
183
+ num_classes (int): number of classes for classification head
184
+ embed_dim (int): embedding dimension
185
+ depth (int): depth of transformer
186
+ num_heads (int): number of attention heads
187
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
188
+ qkv_bias (bool): enable bias for qkv if True
189
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
190
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
191
+ drop_rate (float): dropout rate
192
+ attn_drop_rate (float): attention dropout rate
193
+ drop_path_rate (float): stochastic depth rate
194
+ norm_layer: (nn.Module): normalization layer
195
+ """
196
+ super().__init__()
197
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
198
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
199
+
200
+ self.patch_embed = PatchEmbed(
201
+ img_size=img_size,
202
+ patch_size=patch_size,
203
+ in_chans=in_chans,
204
+ embed_dim=embed_dim,
205
+ )
206
+
207
+ num_patches = self.patch_embed.num_patches
208
+
209
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
210
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
211
+ self.pos_drop = nn.Dropout(p=drop_rate)
212
+
213
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
214
+ self.blocks = nn.ModuleList([
215
+ Block(
216
+ dim=embed_dim,
217
+ num_heads=num_heads,
218
+ mlp_ratio=mlp_ratio,
219
+ qkv_bias=qkv_bias,
220
+ qk_scale=qk_scale,
221
+ drop=drop_rate,
222
+ attn_drop=attn_drop_rate,
223
+ drop_path=dpr[i],
224
+ norm_layer=norm_layer,
225
+ use_grad_checkpointing=(use_grad_checkpointing and i >= depth - ckpt_layer),
226
+ ) for i in range(depth)
227
+ ])
228
+ self.norm = norm_layer(embed_dim)
229
+
230
+ trunc_normal_(self.pos_embed, std=0.02)
231
+ trunc_normal_(self.cls_token, std=0.02)
232
+ self.apply(self._init_weights)
233
+
234
+ def _init_weights(self, m):
235
+ if isinstance(m, nn.Linear):
236
+ trunc_normal_(m.weight, std=0.02)
237
+ if isinstance(m, nn.Linear) and m.bias is not None:
238
+ nn.init.constant_(m.bias, 0)
239
+ elif isinstance(m, nn.LayerNorm):
240
+ nn.init.constant_(m.bias, 0)
241
+ nn.init.constant_(m.weight, 1.0)
242
+
243
+ @torch.jit.ignore
244
+ def no_weight_decay(self):
245
+ return {"pos_embed", "cls_token"}
246
+
247
+ def forward(self, x, register_blk=-1):
248
+ B = x.shape[0]
249
+ x = self.patch_embed(x)
250
+
251
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
252
+ x = torch.cat((cls_tokens, x), dim=1)
253
+
254
+ x = x + self.pos_embed[:, :x.size(1), :]
255
+ x = self.pos_drop(x)
256
+
257
+ for i, blk in enumerate(self.blocks):
258
+ x = blk(x, register_blk == i)
259
+ x = self.norm(x)
260
+
261
+ return x
262
+
263
+ @torch.jit.ignore()
264
+ def load_pretrained(self, checkpoint_path, prefix=""):
265
+ _load_weights(self, checkpoint_path, prefix)
266
+
267
+
268
+ @torch.no_grad()
269
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
270
+ """Load weights from .npz checkpoints for official Google Brain Flax implementation"""
271
+ import numpy as np
272
+
273
+ def _n2p(w, t=True):
274
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
275
+ w = w.flatten()
276
+ if t:
277
+ if w.ndim == 4:
278
+ w = w.transpose([3, 2, 0, 1])
279
+ elif w.ndim == 3:
280
+ w = w.transpose([2, 0, 1])
281
+ elif w.ndim == 2:
282
+ w = w.transpose([1, 0])
283
+ return torch.from_numpy(w)
284
+
285
+ w = np.load(checkpoint_path)
286
+ if not prefix and "opt/target/embedding/kernel" in w:
287
+ prefix = "opt/target/"
288
+
289
+ if hasattr(model.patch_embed, "backbone"):
290
+ # hybrid
291
+ backbone = model.patch_embed.backbone
292
+ stem_only = not hasattr(backbone, "stem")
293
+ stem = backbone if stem_only else backbone.stem
294
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])))
295
+ stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
296
+ stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
297
+ if not stem_only:
298
+ for i, stage in enumerate(backbone.stages):
299
+ for j, block in enumerate(stage.blocks):
300
+ bp = f"{prefix}block{i + 1}/unit{j + 1}/"
301
+ for r in range(3):
302
+ getattr(block, f"conv{r + 1}").weight.copy_(_n2p(w[f"{bp}conv{r + 1}/kernel"]))
303
+ getattr(block, f"norm{r + 1}").weight.copy_(_n2p(w[f"{bp}gn{r + 1}/scale"]))
304
+ getattr(block, f"norm{r + 1}").bias.copy_(_n2p(w[f"{bp}gn{r + 1}/bias"]))
305
+ if block.downsample is not None:
306
+ block.downsample.conv.weight.copy_(_n2p(w[f"{bp}conv_proj/kernel"]))
307
+ block.downsample.norm.weight.copy_(_n2p(w[f"{bp}gn_proj/scale"]))
308
+ block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
309
+ embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
310
+ else:
311
+ embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"]))
312
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
313
+ model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
314
+ model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
315
+ pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
316
+ if pos_embed_w.shape != model.pos_embed.shape:
317
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
318
+ pos_embed_w,
319
+ model.pos_embed,
320
+ getattr(model, "num_tokens", 1),
321
+ model.patch_embed.grid_size,
322
+ )
323
+ model.pos_embed.copy_(pos_embed_w)
324
+ model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
325
+ model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
326
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
327
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
328
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
329
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
330
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
331
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
332
+ for i, block in enumerate(model.blocks.children()):
333
+ block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
334
+ mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
335
+ block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
336
+ block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
337
+ block.attn.qkv.weight.copy_(
338
+ torch.cat([_n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T for n in ("query", "key", "value")]))
339
+ block.attn.qkv.bias.copy_(
340
+ torch.cat([_n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1) for n in ("query", "key", "value")]))
341
+ block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
342
+ block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
343
+ for r in range(2):
344
+ getattr(block.mlp, f"fc{r + 1}").weight.copy_(_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"]))
345
+ getattr(block.mlp, f"fc{r + 1}").bias.copy_(_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"]))
346
+ block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
347
+ block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))
348
+
349
+
350
+ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
351
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
352
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
353
+ print("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
354
+ ntok_new = posemb_new.shape[1]
355
+ if num_tokens:
356
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
357
+ ntok_new -= num_tokens
358
+ else:
359
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
360
+ gs_old = int(math.sqrt(len(posemb_grid)))
361
+ if not len(gs_new): # backwards compatibility
362
+ gs_new = [int(math.sqrt(ntok_new))] * 2
363
+ assert len(gs_new) >= 2
364
+ print("Position embedding grid-size from %s to %s", [gs_old, gs_old], gs_new)
365
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
366
+ posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode="bicubic", align_corners=False)
367
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
368
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
369
+ return posemb
370
+
371
+
372
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
373
+ # interpolate position embedding
374
+ embedding_size = pos_embed_checkpoint.shape[-1]
375
+ num_patches = visual_encoder.patch_embed.num_patches
376
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
377
+ # height (== width) for the checkpoint position embedding
378
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
379
+ # height (== width) for the new position embedding
380
+ new_size = int(num_patches**0.5)
381
+
382
+ if orig_size != new_size:
383
+ # class_token and dist_token are kept unchanged
384
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
385
+ # only the position tokens are interpolated
386
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
387
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
388
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
389
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
390
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
391
+ print("reshape position embedding from %d to %d" % (orig_size**2, new_size**2))
392
+
393
+ return new_pos_embed
394
+ else:
395
+ return pos_embed_checkpoint
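
This is the BLIP/timm-style Vision Transformer bundled with the Q-Former code. A minimal sketch of a forward pass and of resizing a 224px position embedding for a larger input, assuming the file is importable as models.seed_qformer.vit; the ViT-B/16 sizes below simply follow the constructor defaults:

import torch
from models.seed_qformer.vit import VisionTransformer, interpolate_pos_embed

vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
x = torch.randn(2, 3, 224, 224)
feats = vit(x)  # (2, 197, 768): CLS token + 196 patch tokens after the final LayerNorm
print(feats.shape)

# Reusing a 224px checkpoint at 384px resolution: bicubically resize its position
# embedding to the new 24x24 patch grid before load_state_dict on the larger model.
vit_384 = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
pos_384 = interpolate_pos_embed(vit.pos_embed.data, vit_384)  # (1, 577, 768)
print(pos_384.shape)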
models/transforms.py ADDED
@@ -0,0 +1,21 @@
1
+ from torchvision import transforms
2
+
3
+
4
+ def get_transform(type='clip', keep_ratio=True, image_size=224):
5
+ if type == 'clip':
6
+ transform = []
7
+ if keep_ratio:
8
+ transform.extend([
9
+ transforms.Resize(image_size),
10
+ transforms.CenterCrop(image_size),
11
+ ])
12
+ else:
13
+ transform.append(transforms.Resize((image_size, image_size)))
14
+ transform.extend([
15
+ transforms.ToTensor(),
16
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
17
+ ])
18
+
19
+ return transforms.Compose(transform)
20
+ else:
21
+ raise NotImplementedError
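
A short usage sketch for get_transform, assuming it is importable as models.transforms and using one of the sample images shipped in the images/ folder. With keep_ratio=True the short side is resized and the image is center-cropped; with keep_ratio=False it is squashed to a square before CLIP normalization:

from PIL import Image
from models.transforms import get_transform

transform = get_transform(type='clip', keep_ratio=True, image_size=224)  # CLIP mean/std normalization
image = Image.open('images/cat.jpg').convert('RGB')  # sample image from the repo
pixel_values = transform(image)  # float tensor of shape (3, 224, 224)
print(pixel_values.shape, pixel_values.dtype)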