Update README.md
#1
by
SFXX
- opened
- .gitignore +3 -0
- LICENSE.txt +437 -0
- README.md +139 -0
- added_tokens.json +14 -0
- config.json +26 -0
- configuration_blip_3.py +161 -0
- demo.ipynb +0 -0
- generation_config.json +7 -0
- image_processing_blip_3.py +409 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +673 -0
- modeling_blip_3.py +110 -0
- preprocessor_config.json +23 -0
- special_tokens_map.json +30 -0
- test_samples/images/1074.jpg +0 -0
- test_samples/images/1148.jpg +0 -0
- test_samples/images/152.jpg +0 -0
- test_samples/images/1614.jpg +0 -0
- test_samples/images/26302.jpg +0 -0
- test_samples/images/45711.jpg +0 -0
- test_samples/test.json +42 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +137 -0
- utils.py +383 -0
- vlm.py +1531 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/**
|
2 |
+
debug.py
|
3 |
+
sanity_check.ipynb
|
LICENSE.txt
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Attribution-NonCommercial-ShareAlike 4.0 International
|
2 |
+
|
3 |
+
=======================================================================
|
4 |
+
|
5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
6 |
+
does not provide legal services or legal advice. Distribution of
|
7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
8 |
+
other relationship. Creative Commons makes its licenses and related
|
9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
10 |
+
warranties regarding its licenses, any material licensed under their
|
11 |
+
terms and conditions, or any related information. Creative Commons
|
12 |
+
disclaims all liability for damages resulting from their use to the
|
13 |
+
fullest extent possible.
|
14 |
+
|
15 |
+
Using Creative Commons Public Licenses
|
16 |
+
|
17 |
+
Creative Commons public licenses provide a standard set of terms and
|
18 |
+
conditions that creators and other rights holders may use to share
|
19 |
+
original works of authorship and other material subject to copyright
|
20 |
+
and certain other rights specified in the public license below. The
|
21 |
+
following considerations are for informational purposes only, are not
|
22 |
+
exhaustive, and do not form part of our licenses.
|
23 |
+
|
24 |
+
Considerations for licensors: Our public licenses are
|
25 |
+
intended for use by those authorized to give the public
|
26 |
+
permission to use material in ways otherwise restricted by
|
27 |
+
copyright and certain other rights. Our licenses are
|
28 |
+
irrevocable. Licensors should read and understand the terms
|
29 |
+
and conditions of the license they choose before applying it.
|
30 |
+
Licensors should also secure all rights necessary before
|
31 |
+
applying our licenses so that the public can reuse the
|
32 |
+
material as expected. Licensors should clearly mark any
|
33 |
+
material not subject to the license. This includes other CC-
|
34 |
+
licensed material, or material used under an exception or
|
35 |
+
limitation to copyright. More considerations for licensors:
|
36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
37 |
+
|
38 |
+
Considerations for the public: By using one of our public
|
39 |
+
licenses, a licensor grants the public permission to use the
|
40 |
+
licensed material under specified terms and conditions. If
|
41 |
+
the licensor's permission is not necessary for any reason--for
|
42 |
+
example, because of any applicable exception or limitation to
|
43 |
+
copyright--then that use is not regulated by the license. Our
|
44 |
+
licenses grant only permissions under copyright and certain
|
45 |
+
other rights that a licensor has authority to grant. Use of
|
46 |
+
the licensed material may still be restricted for other
|
47 |
+
reasons, including because others have copyright or other
|
48 |
+
rights in the material. A licensor may make special requests,
|
49 |
+
such as asking that all changes be marked or described.
|
50 |
+
Although not required by our licenses, you are encouraged to
|
51 |
+
respect those requests where reasonable. More considerations
|
52 |
+
for the public:
|
53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
54 |
+
|
55 |
+
=======================================================================
|
56 |
+
|
57 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
58 |
+
Public License
|
59 |
+
|
60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
61 |
+
to be bound by the terms and conditions of this Creative Commons
|
62 |
+
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
63 |
+
("Public License"). To the extent this Public License may be
|
64 |
+
interpreted as a contract, You are granted the Licensed Rights in
|
65 |
+
consideration of Your acceptance of these terms and conditions, and the
|
66 |
+
Licensor grants You such rights in consideration of benefits the
|
67 |
+
Licensor receives from making the Licensed Material available under
|
68 |
+
these terms and conditions.
|
69 |
+
|
70 |
+
|
71 |
+
Section 1 -- Definitions.
|
72 |
+
|
73 |
+
a. Adapted Material means material subject to Copyright and Similar
|
74 |
+
Rights that is derived from or based upon the Licensed Material
|
75 |
+
and in which the Licensed Material is translated, altered,
|
76 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
77 |
+
permission under the Copyright and Similar Rights held by the
|
78 |
+
Licensor. For purposes of this Public License, where the Licensed
|
79 |
+
Material is a musical work, performance, or sound recording,
|
80 |
+
Adapted Material is always produced where the Licensed Material is
|
81 |
+
synched in timed relation with a moving image.
|
82 |
+
|
83 |
+
b. Adapter's License means the license You apply to Your Copyright
|
84 |
+
and Similar Rights in Your contributions to Adapted Material in
|
85 |
+
accordance with the terms and conditions of this Public License.
|
86 |
+
|
87 |
+
c. BY-NC-SA Compatible License means a license listed at
|
88 |
+
creativecommons.org/compatiblelicenses, approved by Creative
|
89 |
+
Commons as essentially the equivalent of this Public License.
|
90 |
+
|
91 |
+
d. Copyright and Similar Rights means copyright and/or similar rights
|
92 |
+
closely related to copyright including, without limitation,
|
93 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
94 |
+
Rights, without regard to how the rights are labeled or
|
95 |
+
categorized. For purposes of this Public License, the rights
|
96 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
97 |
+
Rights.
|
98 |
+
|
99 |
+
e. Effective Technological Measures means those measures that, in the
|
100 |
+
absence of proper authority, may not be circumvented under laws
|
101 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
102 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
103 |
+
agreements.
|
104 |
+
|
105 |
+
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
106 |
+
any other exception or limitation to Copyright and Similar Rights
|
107 |
+
that applies to Your use of the Licensed Material.
|
108 |
+
|
109 |
+
g. License Elements means the license attributes listed in the name
|
110 |
+
of a Creative Commons Public License. The License Elements of this
|
111 |
+
Public License are Attribution, NonCommercial, and ShareAlike.
|
112 |
+
|
113 |
+
h. Licensed Material means the artistic or literary work, database,
|
114 |
+
or other material to which the Licensor applied this Public
|
115 |
+
License.
|
116 |
+
|
117 |
+
i. Licensed Rights means the rights granted to You subject to the
|
118 |
+
terms and conditions of this Public License, which are limited to
|
119 |
+
all Copyright and Similar Rights that apply to Your use of the
|
120 |
+
Licensed Material and that the Licensor has authority to license.
|
121 |
+
|
122 |
+
j. Licensor means the individual(s) or entity(ies) granting rights
|
123 |
+
under this Public License.
|
124 |
+
|
125 |
+
k. NonCommercial means not primarily intended for or directed towards
|
126 |
+
commercial advantage or monetary compensation. For purposes of
|
127 |
+
this Public License, the exchange of the Licensed Material for
|
128 |
+
other material subject to Copyright and Similar Rights by digital
|
129 |
+
file-sharing or similar means is NonCommercial provided there is
|
130 |
+
no payment of monetary compensation in connection with the
|
131 |
+
exchange.
|
132 |
+
|
133 |
+
l. Share means to provide material to the public by any means or
|
134 |
+
process that requires permission under the Licensed Rights, such
|
135 |
+
as reproduction, public display, public performance, distribution,
|
136 |
+
dissemination, communication, or importation, and to make material
|
137 |
+
available to the public including in ways that members of the
|
138 |
+
public may access the material from a place and at a time
|
139 |
+
individually chosen by them.
|
140 |
+
|
141 |
+
m. Sui Generis Database Rights means rights other than copyright
|
142 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
143 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
144 |
+
as amended and/or succeeded, as well as other essentially
|
145 |
+
equivalent rights anywhere in the world.
|
146 |
+
|
147 |
+
n. You means the individual or entity exercising the Licensed Rights
|
148 |
+
under this Public License. Your has a corresponding meaning.
|
149 |
+
|
150 |
+
|
151 |
+
Section 2 -- Scope.
|
152 |
+
|
153 |
+
a. License grant.
|
154 |
+
|
155 |
+
1. Subject to the terms and conditions of this Public License,
|
156 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
157 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
158 |
+
exercise the Licensed Rights in the Licensed Material to:
|
159 |
+
|
160 |
+
a. reproduce and Share the Licensed Material, in whole or
|
161 |
+
in part, for NonCommercial purposes only; and
|
162 |
+
|
163 |
+
b. produce, reproduce, and Share Adapted Material for
|
164 |
+
NonCommercial purposes only.
|
165 |
+
|
166 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
167 |
+
Exceptions and Limitations apply to Your use, this Public
|
168 |
+
License does not apply, and You do not need to comply with
|
169 |
+
its terms and conditions.
|
170 |
+
|
171 |
+
3. Term. The term of this Public License is specified in Section
|
172 |
+
6(a).
|
173 |
+
|
174 |
+
4. Media and formats; technical modifications allowed. The
|
175 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
176 |
+
all media and formats whether now known or hereafter created,
|
177 |
+
and to make technical modifications necessary to do so. The
|
178 |
+
Licensor waives and/or agrees not to assert any right or
|
179 |
+
authority to forbid You from making technical modifications
|
180 |
+
necessary to exercise the Licensed Rights, including
|
181 |
+
technical modifications necessary to circumvent Effective
|
182 |
+
Technological Measures. For purposes of this Public License,
|
183 |
+
simply making modifications authorized by this Section 2(a)
|
184 |
+
(4) never produces Adapted Material.
|
185 |
+
|
186 |
+
5. Downstream recipients.
|
187 |
+
|
188 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
189 |
+
recipient of the Licensed Material automatically
|
190 |
+
receives an offer from the Licensor to exercise the
|
191 |
+
Licensed Rights under the terms and conditions of this
|
192 |
+
Public License.
|
193 |
+
|
194 |
+
b. Additional offer from the Licensor -- Adapted Material.
|
195 |
+
Every recipient of Adapted Material from You
|
196 |
+
automatically receives an offer from the Licensor to
|
197 |
+
exercise the Licensed Rights in the Adapted Material
|
198 |
+
under the conditions of the Adapter's License You apply.
|
199 |
+
|
200 |
+
c. No downstream restrictions. You may not offer or impose
|
201 |
+
any additional or different terms or conditions on, or
|
202 |
+
apply any Effective Technological Measures to, the
|
203 |
+
Licensed Material if doing so restricts exercise of the
|
204 |
+
Licensed Rights by any recipient of the Licensed
|
205 |
+
Material.
|
206 |
+
|
207 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
208 |
+
may be construed as permission to assert or imply that You
|
209 |
+
are, or that Your use of the Licensed Material is, connected
|
210 |
+
with, or sponsored, endorsed, or granted official status by,
|
211 |
+
the Licensor or others designated to receive attribution as
|
212 |
+
provided in Section 3(a)(1)(A)(i).
|
213 |
+
|
214 |
+
b. Other rights.
|
215 |
+
|
216 |
+
1. Moral rights, such as the right of integrity, are not
|
217 |
+
licensed under this Public License, nor are publicity,
|
218 |
+
privacy, and/or other similar personality rights; however, to
|
219 |
+
the extent possible, the Licensor waives and/or agrees not to
|
220 |
+
assert any such rights held by the Licensor to the limited
|
221 |
+
extent necessary to allow You to exercise the Licensed
|
222 |
+
Rights, but not otherwise.
|
223 |
+
|
224 |
+
2. Patent and trademark rights are not licensed under this
|
225 |
+
Public License.
|
226 |
+
|
227 |
+
3. To the extent possible, the Licensor waives any right to
|
228 |
+
collect royalties from You for the exercise of the Licensed
|
229 |
+
Rights, whether directly or through a collecting society
|
230 |
+
under any voluntary or waivable statutory or compulsory
|
231 |
+
licensing scheme. In all other cases the Licensor expressly
|
232 |
+
reserves any right to collect such royalties, including when
|
233 |
+
the Licensed Material is used other than for NonCommercial
|
234 |
+
purposes.
|
235 |
+
|
236 |
+
|
237 |
+
Section 3 -- License Conditions.
|
238 |
+
|
239 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
240 |
+
following conditions.
|
241 |
+
|
242 |
+
a. Attribution.
|
243 |
+
|
244 |
+
1. If You Share the Licensed Material (including in modified
|
245 |
+
form), You must:
|
246 |
+
|
247 |
+
a. retain the following if it is supplied by the Licensor
|
248 |
+
with the Licensed Material:
|
249 |
+
|
250 |
+
i. identification of the creator(s) of the Licensed
|
251 |
+
Material and any others designated to receive
|
252 |
+
attribution, in any reasonable manner requested by
|
253 |
+
the Licensor (including by pseudonym if
|
254 |
+
designated);
|
255 |
+
|
256 |
+
ii. a copyright notice;
|
257 |
+
|
258 |
+
iii. a notice that refers to this Public License;
|
259 |
+
|
260 |
+
iv. a notice that refers to the disclaimer of
|
261 |
+
warranties;
|
262 |
+
|
263 |
+
v. a URI or hyperlink to the Licensed Material to the
|
264 |
+
extent reasonably practicable;
|
265 |
+
|
266 |
+
b. indicate if You modified the Licensed Material and
|
267 |
+
retain an indication of any previous modifications; and
|
268 |
+
|
269 |
+
c. indicate the Licensed Material is licensed under this
|
270 |
+
Public License, and include the text of, or the URI or
|
271 |
+
hyperlink to, this Public License.
|
272 |
+
|
273 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
274 |
+
reasonable manner based on the medium, means, and context in
|
275 |
+
which You Share the Licensed Material. For example, it may be
|
276 |
+
reasonable to satisfy the conditions by providing a URI or
|
277 |
+
hyperlink to a resource that includes the required
|
278 |
+
information.
|
279 |
+
3. If requested by the Licensor, You must remove any of the
|
280 |
+
information required by Section 3(a)(1)(A) to the extent
|
281 |
+
reasonably practicable.
|
282 |
+
|
283 |
+
b. ShareAlike.
|
284 |
+
|
285 |
+
In addition to the conditions in Section 3(a), if You Share
|
286 |
+
Adapted Material You produce, the following conditions also apply.
|
287 |
+
|
288 |
+
1. The Adapter's License You apply must be a Creative Commons
|
289 |
+
license with the same License Elements, this version or
|
290 |
+
later, or a BY-NC-SA Compatible License.
|
291 |
+
|
292 |
+
2. You must include the text of, or the URI or hyperlink to, the
|
293 |
+
Adapter's License You apply. You may satisfy this condition
|
294 |
+
in any reasonable manner based on the medium, means, and
|
295 |
+
context in which You Share Adapted Material.
|
296 |
+
|
297 |
+
3. You may not offer or impose any additional or different terms
|
298 |
+
or conditions on, or apply any Effective Technological
|
299 |
+
Measures to, Adapted Material that restrict exercise of the
|
300 |
+
rights granted under the Adapter's License You apply.
|
301 |
+
|
302 |
+
|
303 |
+
Section 4 -- Sui Generis Database Rights.
|
304 |
+
|
305 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
306 |
+
apply to Your use of the Licensed Material:
|
307 |
+
|
308 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
309 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
310 |
+
portion of the contents of the database for NonCommercial purposes
|
311 |
+
only;
|
312 |
+
|
313 |
+
b. if You include all or a substantial portion of the database
|
314 |
+
contents in a database in which You have Sui Generis Database
|
315 |
+
Rights, then the database in which You have Sui Generis Database
|
316 |
+
Rights (but not its individual contents) is Adapted Material,
|
317 |
+
including for purposes of Section 3(b); and
|
318 |
+
|
319 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
320 |
+
all or a substantial portion of the contents of the database.
|
321 |
+
|
322 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
323 |
+
replace Your obligations under this Public License where the Licensed
|
324 |
+
Rights include other Copyright and Similar Rights.
|
325 |
+
|
326 |
+
|
327 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
328 |
+
|
329 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
330 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
331 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
332 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
333 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
334 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
335 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
336 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
337 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
338 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
339 |
+
|
340 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
341 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
342 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
343 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
344 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
345 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
346 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
347 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
348 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
349 |
+
|
350 |
+
c. The disclaimer of warranties and limitation of liability provided
|
351 |
+
above shall be interpreted in a manner that, to the extent
|
352 |
+
possible, most closely approximates an absolute disclaimer and
|
353 |
+
waiver of all liability.
|
354 |
+
|
355 |
+
|
356 |
+
Section 6 -- Term and Termination.
|
357 |
+
|
358 |
+
a. This Public License applies for the term of the Copyright and
|
359 |
+
Similar Rights licensed here. However, if You fail to comply with
|
360 |
+
this Public License, then Your rights under this Public License
|
361 |
+
terminate automatically.
|
362 |
+
|
363 |
+
b. Where Your right to use the Licensed Material has terminated under
|
364 |
+
Section 6(a), it reinstates:
|
365 |
+
|
366 |
+
1. automatically as of the date the violation is cured, provided
|
367 |
+
it is cured within 30 days of Your discovery of the
|
368 |
+
violation; or
|
369 |
+
|
370 |
+
2. upon express reinstatement by the Licensor.
|
371 |
+
|
372 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
373 |
+
right the Licensor may have to seek remedies for Your violations
|
374 |
+
of this Public License.
|
375 |
+
|
376 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
377 |
+
Licensed Material under separate terms or conditions or stop
|
378 |
+
distributing the Licensed Material at any time; however, doing so
|
379 |
+
will not terminate this Public License.
|
380 |
+
|
381 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
382 |
+
License.
|
383 |
+
|
384 |
+
|
385 |
+
Section 7 -- Other Terms and Conditions.
|
386 |
+
|
387 |
+
a. The Licensor shall not be bound by any additional or different
|
388 |
+
terms or conditions communicated by You unless expressly agreed.
|
389 |
+
|
390 |
+
b. Any arrangements, understandings, or agreements regarding the
|
391 |
+
Licensed Material not stated herein are separate from and
|
392 |
+
independent of the terms and conditions of this Public License.
|
393 |
+
|
394 |
+
|
395 |
+
Section 8 -- Interpretation.
|
396 |
+
|
397 |
+
a. For the avoidance of doubt, this Public License does not, and
|
398 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
399 |
+
conditions on any use of the Licensed Material that could lawfully
|
400 |
+
be made without permission under this Public License.
|
401 |
+
|
402 |
+
b. To the extent possible, if any provision of this Public License is
|
403 |
+
deemed unenforceable, it shall be automatically reformed to the
|
404 |
+
minimum extent necessary to make it enforceable. If the provision
|
405 |
+
cannot be reformed, it shall be severed from this Public License
|
406 |
+
without affecting the enforceability of the remaining terms and
|
407 |
+
conditions.
|
408 |
+
|
409 |
+
c. No term or condition of this Public License will be waived and no
|
410 |
+
failure to comply consented to unless expressly agreed to by the
|
411 |
+
Licensor.
|
412 |
+
|
413 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
414 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
415 |
+
that apply to the Licensor or You, including from the legal
|
416 |
+
processes of any jurisdiction or authority.
|
417 |
+
|
418 |
+
=======================================================================
|
419 |
+
|
420 |
+
Creative Commons is not a party to its public
|
421 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
422 |
+
its public licenses to material it publishes and in those instances
|
423 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
424 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
425 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
426 |
+
material is shared under a Creative Commons public license or as
|
427 |
+
otherwise permitted by the Creative Commons policies published at
|
428 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
429 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
430 |
+
of Creative Commons without its prior written consent including,
|
431 |
+
without limitation, in connection with any unauthorized modifications
|
432 |
+
to any of its public licenses or any other arrangements,
|
433 |
+
understandings, or agreements concerning use of licensed material. For
|
434 |
+
the avoidance of doubt, this paragraph does not form part of the
|
435 |
+
public licenses.
|
436 |
+
|
437 |
+
Creative Commons may be contacted at creativecommons.org.
|
README.md
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: cc-by-nc-4.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
pipeline_tag: image-text-to-text
|
6 |
+
---
|
7 |
+
|
8 |
+
|
9 |
+
# Model description
|
10 |
+
|
11 |
+
BLIP3 is a series of foundational vision-language models (VLMs) developed by Salesforce AI Research. \
|
12 |
+
These models have been trained at scale on high-quality image caption datasets and interleaved image-text data. BLIP3 highlights a few features below,
|
13 |
+
|
14 |
+
* The pretrained foundation model, blip3-phi3-mini-base-r-v1, achieves state-of-the-art performance under 5b parameters and demonstrates strong in-context learning capabilities.
|
15 |
+
* The instruct fine-tuned model, blip3-phi3-mini-instruct-r-v1, achieves state-of-the-art performance among open-source and closed-source VLMs under 5b parameters.
|
16 |
+
* blip3-phi3-mini-instruct-r-v1 supports flexible high-resolution image encoding with efficient visual token sampling.
|
17 |
+
|
18 |
+
More technical details will come with a technical report soon.
|
19 |
+
|
20 |
+
|
21 |
+
# Datasets
|
22 |
+
|
23 |
+
| Dataset Type| Dataset(s) Used |
|
24 |
+
|--------|------------------------------------------|
|
25 |
+
| Pretrain | datacomp, cc12m, cc3m, SBU, vg, obelics |
|
26 |
+
| Instruction Tuning | LLaVA-Instruct-150K, ShareGPT4V captions, a mixture of academic VQA data including OCR/Document/Chart-focused tasks, publicly available text-only instruction data |
|
27 |
+
|
28 |
+
# Results
|
29 |
+
|
30 |
+
### Pretrain
|
31 |
+
| Model | Shot | COCO (val) | NoCaps (val) | TextCaps (val) | OKVQA (val) | TextVQA (val) | VizWiz (testdev) | VQAv2 (testdev) |
|
32 |
+
|-------------|------|------------|--------------|----------------|--------------|---------------|------------------|-----------------|
|
33 |
+
| Flamingo-3B | 4 | 85.0 | - | - | 43.3 | 32.7 | 34 | 53.2 |
|
34 |
+
| | 8 | 90.6 | - | - | 44.6 | 32.4 | 38.4 | 55.4 |
|
35 |
+
| MM1-3B | 0 | 73.5 | 55.6 | 63.3 | 26.1 | 29.4 | 15.6 | 46.2 |
|
36 |
+
| | 4 | 112.3 | 99.7 | 84.1 | 48.6 | 45.3 | 38.0 | 57.9 |
|
37 |
+
| | 8 | 114.6 | 104.7 | 88.8 | 48.4 | 44.6 | 46.4 | 63.6 |
|
38 |
+
| BLIP3-phi3-mini-base-r-v1 (Ours)| 0 | **81.7** | **80.2** | 60.7 | **26.5** | **36.0** | **21.2** | **48.1** |
|
39 |
+
| | 4 | 110.5 | **101.7** | **84.6** | **49.2** | **46.1** | **38.4** | **63.9** |
|
40 |
+
| | 8 | 112.1 | 104.4 | 87.7 | **49.1** | **46.4** | 44.3 | **63.8** |
|
41 |
+
|
42 |
+
### Instruct
|
43 |
+
| Model | SEED-IMG | MMBench(dev) | MME-total | MME-P | MME-C | MMStar | MMMU (val) | MMVet | MathVista (mini) | ScienceQA (test) | POPE | TextVQA | AI2D | |
|
44 |
+
|----------------------------|----------|--------------|-----------|----------|---------|----------|------------|----------|------------------|------------------|----------|----------|----------|---|
|
45 |
+
| MM1-3B-Chat | 68.8 | 75.9 | 1761 | **1482** | 279 | - | 33.9 | 43.7 | - | - | **87.4** | - | - | |
|
46 |
+
| openbmb/MiniCPM-V-2 | 67.1 | 69.6 | 1808 | - | - | - | 38.2 | - | 38.7 | - | - | 74.1 | - | |
|
47 |
+
| VILA1.5-3B | 67.9 | 63.4 | - | 1442 | - | - | 33.3 | 35.4 | - | 69.0 | 85.9 | 70.4 | - | |
|
48 |
+
| xtuner/llava-phi-3-mini-hf | 70.0 | 69.2 | 1790 | 1477 | 313 | 43.7 | **41.4** | - | - | 73.7 | 87.3 | 57.8 | 69.3 | |
|
49 |
+
| BLIP3-phi3-mini-instruct-r-v1 (Ours) | **72.1** | **74.1** | **1827** | 1467 | **360** | **44.6** | 39.8 | **45.1** | **39.3** | **74.2** | 87.2 | 64.6 | **75.8** | |
|
50 |
+
|
51 |
+
|
52 |
+
# Bias, Risks, Limitations, and Ethical Considerations
|
53 |
+
We removed Laion from our training data due to known CSAM concerns.
|
54 |
+
The other main data sources are from the internet, including webpages,
|
55 |
+
image stock sites, and curated datasets released by the research community.
|
56 |
+
The model may be subject to bias from the original data source, as well as bias from LLMs and commercial APIs.
|
57 |
+
We strongly recommend users conduct an assessment of safety and fairness before applying to downstream applications.
|
58 |
+
# How to use
|
59 |
+
|
60 |
+
> We require the use of the development version (`"4.41.0.dev0"`) of the `transformers` library. To get it, as of 05/07/2024, one can use `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers.`
|
61 |
+
|
62 |
+
```python
|
63 |
+
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor, StoppingCriteria
|
64 |
+
import torch
|
65 |
+
import requests
|
66 |
+
from PIL import Image
|
67 |
+
|
68 |
+
# define the prompt template
|
69 |
+
def apply_prompt_template(prompt):
|
70 |
+
s = (
|
71 |
+
'<|system|>\nA chat between a curious user and an artificial intelligence assistant. '
|
72 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
|
73 |
+
f'<|user|>\n<image>\n{prompt}<|end|>\n<|assistant|>\n'
|
74 |
+
)
|
75 |
+
return s
|
76 |
+
class EosListStoppingCriteria(StoppingCriteria):
|
77 |
+
def __init__(self, eos_sequence = [32007]):
|
78 |
+
self.eos_sequence = eos_sequence
|
79 |
+
|
80 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
81 |
+
last_ids = input_ids[:,-len(self.eos_sequence):].tolist()
|
82 |
+
return self.eos_sequence in last_ids
|
83 |
+
|
84 |
+
# load models
|
85 |
+
model_name_or_path = "Salesforce/blip3-phi3-mini-instruct-r-v1"
|
86 |
+
model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, trust_remote_code=True)
|
87 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False, legacy=False)
|
88 |
+
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
|
89 |
+
tokenizer = model.update_special_tokens(tokenizer)
|
90 |
+
|
91 |
+
# craft a test sample
|
92 |
+
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
|
93 |
+
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
|
94 |
+
query = "how many dogs are in the picture?"
|
95 |
+
|
96 |
+
model = model.cuda()
|
97 |
+
inputs = image_processor([raw_image], return_tensors="pt", image_aspect_ratio='anyres')
|
98 |
+
prompt = apply_prompt_template(query)
|
99 |
+
language_inputs = tokenizer([prompt], return_tensors="pt")
|
100 |
+
inputs.update(language_inputs)
|
101 |
+
inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
|
102 |
+
generated_text = model.generate(**inputs, image_size=[raw_image.size],
|
103 |
+
pad_token_id=tokenizer.pad_token_id,
|
104 |
+
do_sample=False, max_new_tokens=768, top_p=None, num_beams=1,
|
105 |
+
stopping_criteria = [EosListStoppingCriteria()],
|
106 |
+
)
|
107 |
+
prediction = tokenizer.decode(generated_text[0], skip_special_tokens=True).split("<|end|>")[0]
|
108 |
+
print("==> prediction: ", prediction)
|
109 |
+
# output: ==> prediction: There is one dog in the picture.
|
110 |
+
```
|
111 |
+
|
112 |
+
More comprehensive examples can be found in the [notebook](demo.ipynb).
|
113 |
+
|
114 |
+
# Reproducibility:
|
115 |
+
|
116 |
+
Our SFT evaluation is based on the VLMEvalKit, in which we fixed some inconsistencies with the official benchmarks (e.g., LLM judge API). During our development, we noticed that the raw resolution of the input image would noticeably affect the model output in some cases.
|
117 |
+
|
118 |
+
|
119 |
+
# License
|
120 |
+
|
121 |
+
Our code and weights are released under the Creative Commons Attribution Non Commercial 4.0 [LICENSE](LICENSE.txt). Please fill out a form at [here](https://forms.gle/ffPc9oZC2ZGeJ1N68) to consult the commercial use of model weights.
|
122 |
+
|
123 |
+
Code acknowledgement
|
124 |
+
|
125 |
+
[LAVIS](https://github.com/salesforce/LAVIS) \
|
126 |
+
[openflamingo](https://github.com/mlfoundations/open_flamingo) \
|
127 |
+
[VLMEvalKit](https://github.com/open-compass/VLMEvalKit/tree/main)
|
128 |
+
|
129 |
+
|
130 |
+
# Troubleshoot
|
131 |
+
|
132 |
+
1. If you missed any packages, please consider the following
|
133 |
+
|
134 |
+
```
|
135 |
+
pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
|
136 |
+
pip install open_clip_torch==2.24.0
|
137 |
+
pip install einops
|
138 |
+
pip install einops-exts
|
139 |
+
```
|
added_tokens.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<pad>": 32011,
|
3 |
+
"<|assistant|>": 32001,
|
4 |
+
"<|endoftext|>": 32000,
|
5 |
+
"<|end|>": 32007,
|
6 |
+
"<|placeholder1|>": 32002,
|
7 |
+
"<|placeholder2|>": 32003,
|
8 |
+
"<|placeholder3|>": 32004,
|
9 |
+
"<|placeholder4|>": 32005,
|
10 |
+
"<|placeholder5|>": 32008,
|
11 |
+
"<|placeholder6|>": 32009,
|
12 |
+
"<|system|>": 32006,
|
13 |
+
"<|user|>": 32010
|
14 |
+
}
|
config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Blip3ModelForConditionalGeneration"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_blip_3.Blip3Config",
|
7 |
+
"AutoModelForVision2Seq": "modeling_blip_3.Blip3ModelForConditionalGeneration"
|
8 |
+
},
|
9 |
+
"model_type": "blip_3",
|
10 |
+
"text_config": {
|
11 |
+
"initial_tokenizer_len": 32012,
|
12 |
+
"model_type": "phi3",
|
13 |
+
"sliding_window": 2047,
|
14 |
+
"torch_dtype": "bfloat16"
|
15 |
+
},
|
16 |
+
"torch_dtype": "float32",
|
17 |
+
"transformers_version": "4.41.0.dev0",
|
18 |
+
"vision_encoder_config": {
|
19 |
+
"anyres_patch_sampling": true,
|
20 |
+
"image_aspect_ratio": "anyres",
|
21 |
+
"model_type": "blip_3_vision_encoder"
|
22 |
+
},
|
23 |
+
"vision_tokenizer_config": {
|
24 |
+
"model_type": "blip_3_vision_tokenizer"
|
25 |
+
}
|
26 |
+
}
|
configuration_blip_3.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from transformers import logging
|
3 |
+
from transformers import CONFIG_MAPPING
|
4 |
+
|
5 |
+
logger = logging.get_logger(__name__)
|
6 |
+
|
7 |
+
class Blip3VisionEncoderConfig(PretrainedConfig):
|
8 |
+
model_type = "blip_3_vision_encoder"
|
9 |
+
|
10 |
+
def __init__(self,
|
11 |
+
model_name: str = 'ViT-H-14-378-quickgelu',
|
12 |
+
force_image_size: int = 378,
|
13 |
+
**kwargs):
|
14 |
+
self.model_name = model_name
|
15 |
+
self.force_image_size = force_image_size
|
16 |
+
super().__init__(**kwargs)
|
17 |
+
|
18 |
+
|
19 |
+
class Blip3VisionTokenizerConfig(PretrainedConfig):
|
20 |
+
model_type = "blip_3_vision_tokenizer"
|
21 |
+
|
22 |
+
def __init__(self,
|
23 |
+
vis_feature_dim: int = 1280,
|
24 |
+
lang_embedding_dim: int = 3072,
|
25 |
+
num_vis_tokens: int = 128,
|
26 |
+
image_aspect_ratio: str = 'anyres',
|
27 |
+
repeat_latents: bool = False,
|
28 |
+
**kwargs):
|
29 |
+
self.vis_feature_dim = vis_feature_dim
|
30 |
+
self.lang_embedding_dim = lang_embedding_dim
|
31 |
+
self.num_vis_tokens = num_vis_tokens
|
32 |
+
self.image_aspect_ratio = image_aspect_ratio
|
33 |
+
self.repeat_latents = repeat_latents
|
34 |
+
super().__init__(**kwargs)
|
35 |
+
|
36 |
+
|
37 |
+
class Blip3Config(PretrainedConfig):
|
38 |
+
model_type = "blip_3"
|
39 |
+
|
40 |
+
def __init__(self,
|
41 |
+
vision_encoder_config: dict = None,
|
42 |
+
vision_tokenizer_config: dict = None,
|
43 |
+
text_config: dict = None,
|
44 |
+
**kwargs):
|
45 |
+
|
46 |
+
if vision_encoder_config is None:
|
47 |
+
vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
|
48 |
+
logger.info("vision_encoder_config is None. initializing the Blip3VisionEncoderConfig with default values.")
|
49 |
+
|
50 |
+
if vision_tokenizer_config is None:
|
51 |
+
vision_tokenizer_config = {}
|
52 |
+
logger.info("vision_tokenizer_config is None. Initializing the Blip3VisionTokenizerConfig with default values.")
|
53 |
+
|
54 |
+
if text_config is None:
|
55 |
+
text_config = {
|
56 |
+
'initial_tokenizer_len':32012,
|
57 |
+
'pad_token_id':32011,
|
58 |
+
'bos_token_id':1,
|
59 |
+
'eos_token_id':32000,
|
60 |
+
'vocab_size': 32064,
|
61 |
+
'hidden_size': 3072,
|
62 |
+
'intermediate_size': 8192,
|
63 |
+
'num_hidden_layers': 32,
|
64 |
+
'num_attention_heads': 32,
|
65 |
+
'num_key_value_heads': 32,
|
66 |
+
'resid_pdrop': 0.0,
|
67 |
+
'embd_pdrop': 0.0,
|
68 |
+
'attention_dropout': 0.0,
|
69 |
+
'hidden_act': 'silu',
|
70 |
+
'max_position_embeddings': 4096,
|
71 |
+
'original_max_position_embeddings': 4096,
|
72 |
+
'initializer_range': 0.02,
|
73 |
+
'rms_norm_eps': 1e-05,
|
74 |
+
'use_cache': True,
|
75 |
+
'rope_theta': 10000.0,
|
76 |
+
'rope_scaling': None,
|
77 |
+
'sliding_window': 2047,
|
78 |
+
'return_dict': True,
|
79 |
+
'output_hidden_states': False,
|
80 |
+
'output_attentions': False,
|
81 |
+
'torchscript': False,
|
82 |
+
'torch_dtype': 'bfloat16',
|
83 |
+
'use_bfloat16': False,
|
84 |
+
'tf_legacy_loss': False,
|
85 |
+
'pruned_heads': {},
|
86 |
+
'tie_word_embeddings': False,
|
87 |
+
'chunk_size_feed_forward': 0,
|
88 |
+
'is_encoder_decoder': False,
|
89 |
+
'is_decoder': False,
|
90 |
+
'cross_attention_hidden_size': None,
|
91 |
+
'add_cross_attention': False,
|
92 |
+
'tie_encoder_decoder': False,
|
93 |
+
'max_length': 20,
|
94 |
+
'min_length': 0,
|
95 |
+
'do_sample': False,
|
96 |
+
'early_stopping': False,
|
97 |
+
'num_beams': 1,
|
98 |
+
'num_beam_groups': 1,
|
99 |
+
'diversity_penalty': 0.0,
|
100 |
+
'temperature': 1.0,
|
101 |
+
'top_k': 50,
|
102 |
+
'top_p': 1.0,
|
103 |
+
'typical_p': 1.0,
|
104 |
+
'repetition_penalty': 1.0,
|
105 |
+
'length_penalty': 1.0,
|
106 |
+
'no_repeat_ngram_size': 0,
|
107 |
+
'encoder_no_repeat_ngram_size': 0,
|
108 |
+
'bad_words_ids': None,
|
109 |
+
'num_return_sequences': 1,
|
110 |
+
'output_scores': False,
|
111 |
+
'return_dict_in_generate': False,
|
112 |
+
'forced_bos_token_id': None,
|
113 |
+
'forced_eos_token_id': None,
|
114 |
+
'remove_invalid_values': False,
|
115 |
+
'exponential_decay_length_penalty': None,
|
116 |
+
'suppress_tokens': None,
|
117 |
+
'begin_suppress_tokens': None,
|
118 |
+
'finetuning_task': None,
|
119 |
+
'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
|
120 |
+
'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
|
121 |
+
'tokenizer_class': None,
|
122 |
+
'prefix': None,
|
123 |
+
'bos_token_id': 1,
|
124 |
+
'pad_token_id': 32000,
|
125 |
+
'eos_token_id': 32000,
|
126 |
+
'sep_token_id': None,
|
127 |
+
'decoder_start_token_id': None,
|
128 |
+
'task_specific_params': None,
|
129 |
+
'problem_type': None,
|
130 |
+
'model_type': 'phi3'
|
131 |
+
}
|
132 |
+
logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
|
133 |
+
|
134 |
+
self.vision_encoder_config = Blip3VisionEncoderConfig(**vision_encoder_config)
|
135 |
+
|
136 |
+
self.vision_tokenizer_config = Blip3VisionTokenizerConfig(**vision_tokenizer_config)
|
137 |
+
|
138 |
+
text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
|
139 |
+
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
140 |
+
|
141 |
+
for key in ['initial_tokenizer_len', 'pad_token_id']:
|
142 |
+
if key not in self.text_config.to_dict():
|
143 |
+
raise ValueError(f"The key `{key}` is missing in the text_config.")
|
144 |
+
|
145 |
+
super().__init__(**kwargs)
|
146 |
+
|
147 |
+
@classmethod
|
148 |
+
def from_vision_encoder_vision_tokenizer_text_configs(
|
149 |
+
cls,
|
150 |
+
vision_encoder_config: Blip3VisionEncoderConfig,
|
151 |
+
vision_tokenizer_config: Blip3VisionTokenizerConfig,
|
152 |
+
text_config: PretrainedConfig,
|
153 |
+
**kwargs):
|
154 |
+
|
155 |
+
return cls(
|
156 |
+
vision_encoder_config=vision_encoder_config.to_dict(),
|
157 |
+
vision_tokenizer_config=vision_tokenizer_config.to_dict(),
|
158 |
+
text_config=text_config.to_dict(),
|
159 |
+
**kwargs,
|
160 |
+
)
|
161 |
+
|
demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 1,
|
4 |
+
"eos_token_id": 32000,
|
5 |
+
"pad_token_id": 32000,
|
6 |
+
"transformers_version": "4.41.0.dev0"
|
7 |
+
}
|
image_processing_blip_3.py
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
3 |
+
import torchvision.transforms.functional as F
|
4 |
+
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
5 |
+
CenterCrop, ColorJitter, Grayscale
|
6 |
+
import numbers
|
7 |
+
import torch
|
8 |
+
import ast
|
9 |
+
import math
|
10 |
+
from PIL import Image
|
11 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
12 |
+
from transformers.image_utils import ImageInput
|
13 |
+
from transformers.utils import TensorType
|
14 |
+
|
15 |
+
|
16 |
+
class Blip3ImageProcessor(BaseImageProcessor):
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
do_resize: bool = True,
|
21 |
+
resize_mode: str = "squash",
|
22 |
+
interpolation_mode: str = "bicubic",
|
23 |
+
size: Union[Tuple[int, int], List[int]] = None,
|
24 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
25 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
26 |
+
**kwargs,
|
27 |
+
) -> None:
|
28 |
+
super().__init__(**kwargs)
|
29 |
+
self.do_resize = do_resize
|
30 |
+
self.resize_mode = resize_mode
|
31 |
+
self.interpolation_mode = interpolation_mode
|
32 |
+
self.size = size if size is not None else (378, 378)
|
33 |
+
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
34 |
+
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
35 |
+
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def resize(cls, image_size, resize_mode, interpolation='bicubic', fill_color=0):
|
39 |
+
interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
|
40 |
+
if resize_mode == 'longest':
|
41 |
+
transforms = [
|
42 |
+
ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
|
43 |
+
CenterCropOrPad(image_size, fill=fill_color)
|
44 |
+
]
|
45 |
+
elif resize_mode == 'squash':
|
46 |
+
if isinstance(image_size, int):
|
47 |
+
image_size = (image_size, image_size)
|
48 |
+
transforms = [
|
49 |
+
Resize(image_size, interpolation=interpolation_mode),
|
50 |
+
]
|
51 |
+
else:
|
52 |
+
assert resize_mode == 'shortest'
|
53 |
+
if not isinstance(image_size, (tuple, list)):
|
54 |
+
image_size = (image_size, image_size)
|
55 |
+
if image_size[0] == image_size[1]:
|
56 |
+
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
|
57 |
+
transforms = [
|
58 |
+
Resize(image_size[0], interpolation=interpolation_mode)
|
59 |
+
]
|
60 |
+
else:
|
61 |
+
# resize shortest edge to matching target dim for non-square target
|
62 |
+
transforms = [ResizeKeepRatio(image_size)]
|
63 |
+
transforms += [CenterCrop(image_size)]
|
64 |
+
return transforms
|
65 |
+
|
66 |
+
@classmethod
|
67 |
+
def convert_rgb(cls, image):
|
68 |
+
return image.convert("RGB")
|
69 |
+
|
70 |
+
|
71 |
+
def _preprocess(self,
|
72 |
+
images: ImageInput
|
73 |
+
) -> torch.Tensor:
|
74 |
+
transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
|
75 |
+
transforms.extend([
|
76 |
+
self.convert_rgb,
|
77 |
+
ToTensor(),
|
78 |
+
Normalize(mean=self.image_mean, std=self.image_std)
|
79 |
+
])
|
80 |
+
composed_transforms = Compose(transforms)
|
81 |
+
images_tensor = composed_transforms(images)
|
82 |
+
return images_tensor
|
83 |
+
|
84 |
+
def preprocess(self,
|
85 |
+
images: ImageInput,
|
86 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
87 |
+
**kwargs) -> BatchFeature:
|
88 |
+
if 'image_aspect_ratio' in kwargs:
|
89 |
+
image_aspect_ratio = kwargs['image_aspect_ratio']
|
90 |
+
else:
|
91 |
+
image_aspect_ratio = 'pad'
|
92 |
+
new_images = []
|
93 |
+
if image_aspect_ratio == 'pad':
|
94 |
+
for image in images:
|
95 |
+
image = self._preprocess(image)
|
96 |
+
new_images.append(image)
|
97 |
+
else:
|
98 |
+
if isinstance(self.size, (tuple, list)):
|
99 |
+
base_img_size = self.size[0]
|
100 |
+
else:
|
101 |
+
raise ValueError("size should be list or tuple")
|
102 |
+
for image in images:
|
103 |
+
image = process_anyres_image(image, self._preprocess, self.size,
|
104 |
+
[
|
105 |
+
[base_img_size,base_img_size*2],
|
106 |
+
[base_img_size*2,base_img_size],
|
107 |
+
[base_img_size*2,base_img_size*2],
|
108 |
+
[base_img_size*3,base_img_size],
|
109 |
+
[base_img_size,base_img_size*3]
|
110 |
+
])
|
111 |
+
new_images.append(image)
|
112 |
+
|
113 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
114 |
+
new_images = torch.stack(new_images, dim=0)
|
115 |
+
if image_aspect_ratio == 'pad':
|
116 |
+
new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0).unsqueeze(0)}, tensor_type=return_tensors)
|
117 |
+
else:
|
118 |
+
new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0)}, tensor_type=return_tensors)
|
119 |
+
return new_images
|
120 |
+
# def preprocess(self,
|
121 |
+
# images: ImageInput,
|
122 |
+
# return_tensors: Optional[Union[str, TensorType]] = None,
|
123 |
+
# **kwargs) -> BatchFeature:
|
124 |
+
# transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
|
125 |
+
# transforms.extend([
|
126 |
+
# self.convert_rgb,
|
127 |
+
# ToTensor(),
|
128 |
+
# Normalize(mean=self.image_mean, std=self.image_std)
|
129 |
+
# ])
|
130 |
+
# composed_transforms = Compose(transforms)
|
131 |
+
# images_tensor = composed_transforms(images).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
132 |
+
# encoded_outputs = BatchFeature(data={"pixel_values": images_tensor}, tensor_type=return_tensors)
|
133 |
+
# return encoded_outputs
|
134 |
+
|
135 |
+
|
136 |
+
class ResizeKeepRatio:
|
137 |
+
""" Resize and Keep Ratio
|
138 |
+
|
139 |
+
Copy & paste from `timm`
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
size,
|
145 |
+
longest=0.,
|
146 |
+
interpolation=InterpolationMode.BICUBIC,
|
147 |
+
random_scale_prob=0.,
|
148 |
+
random_scale_range=(0.85, 1.05),
|
149 |
+
random_aspect_prob=0.,
|
150 |
+
random_aspect_range=(0.9, 1.11)
|
151 |
+
):
|
152 |
+
if isinstance(size, (list, tuple)):
|
153 |
+
self.size = tuple(size)
|
154 |
+
else:
|
155 |
+
self.size = (size, size)
|
156 |
+
self.interpolation = interpolation
|
157 |
+
self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
|
158 |
+
self.random_scale_prob = random_scale_prob
|
159 |
+
self.random_scale_range = random_scale_range
|
160 |
+
self.random_aspect_prob = random_aspect_prob
|
161 |
+
self.random_aspect_range = random_aspect_range
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def get_params(
|
165 |
+
img,
|
166 |
+
target_size,
|
167 |
+
longest,
|
168 |
+
random_scale_prob=0.,
|
169 |
+
random_scale_range=(0.85, 1.05),
|
170 |
+
random_aspect_prob=0.,
|
171 |
+
random_aspect_range=(0.9, 1.11)
|
172 |
+
):
|
173 |
+
"""Get parameters
|
174 |
+
"""
|
175 |
+
source_size = img.size[::-1] # h, w
|
176 |
+
h, w = source_size
|
177 |
+
target_h, target_w = target_size
|
178 |
+
ratio_h = h / target_h
|
179 |
+
ratio_w = w / target_w
|
180 |
+
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
|
181 |
+
if random_scale_prob > 0 and random.random() < random_scale_prob:
|
182 |
+
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
|
183 |
+
ratio_factor = (ratio_factor, ratio_factor)
|
184 |
+
else:
|
185 |
+
ratio_factor = (1., 1.)
|
186 |
+
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
|
187 |
+
aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
|
188 |
+
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
|
189 |
+
size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
|
190 |
+
return size
|
191 |
+
|
192 |
+
def __call__(self, img):
|
193 |
+
"""
|
194 |
+
Args:
|
195 |
+
img (PIL Image): Image to be cropped and resized.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
|
199 |
+
"""
|
200 |
+
size = self.get_params(
|
201 |
+
img, self.size, self.longest,
|
202 |
+
self.random_scale_prob, self.random_scale_range,
|
203 |
+
self.random_aspect_prob, self.random_aspect_range
|
204 |
+
)
|
205 |
+
img = F.resize(img, size, self.interpolation)
|
206 |
+
return img
|
207 |
+
|
208 |
+
def __repr__(self):
|
209 |
+
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
|
210 |
+
format_string += f', interpolation={self.interpolation})'
|
211 |
+
format_string += f', longest={self.longest:.3f})'
|
212 |
+
return format_string
|
213 |
+
|
214 |
+
def _setup_size(size, error_msg):
|
215 |
+
if isinstance(size, numbers.Number):
|
216 |
+
return int(size), int(size)
|
217 |
+
|
218 |
+
if isinstance(size, Sequence) and len(size) == 1:
|
219 |
+
return size[0], size[0]
|
220 |
+
|
221 |
+
if len(size) != 2:
|
222 |
+
raise ValueError(error_msg)
|
223 |
+
|
224 |
+
return size
|
225 |
+
|
226 |
+
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
|
227 |
+
"""Center crops and/or pads the given image.
|
228 |
+
If the image is torch Tensor, it is expected
|
229 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
230 |
+
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
img (PIL Image or Tensor): Image to be cropped.
|
234 |
+
output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
|
235 |
+
it is used for both directions.
|
236 |
+
fill (int, Tuple[int]): Padding color
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
PIL Image or Tensor: Cropped image.
|
240 |
+
"""
|
241 |
+
if isinstance(output_size, numbers.Number):
|
242 |
+
output_size = (int(output_size), int(output_size))
|
243 |
+
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
|
244 |
+
output_size = (output_size[0], output_size[0])
|
245 |
+
|
246 |
+
_, image_height, image_width = F.get_dimensions(img)
|
247 |
+
crop_height, crop_width = output_size
|
248 |
+
|
249 |
+
if crop_width > image_width or crop_height > image_height:
|
250 |
+
padding_ltrb = [
|
251 |
+
(crop_width - image_width) // 2 if crop_width > image_width else 0,
|
252 |
+
(crop_height - image_height) // 2 if crop_height > image_height else 0,
|
253 |
+
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
|
254 |
+
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
|
255 |
+
]
|
256 |
+
img = F.pad(img, padding_ltrb, fill=fill)
|
257 |
+
_, image_height, image_width = F.get_dimensions(img)
|
258 |
+
if crop_width == image_width and crop_height == image_height:
|
259 |
+
return img
|
260 |
+
|
261 |
+
crop_top = int(round((image_height - crop_height) / 2.0))
|
262 |
+
crop_left = int(round((image_width - crop_width) / 2.0))
|
263 |
+
return F.crop(img, crop_top, crop_left, crop_height, crop_width)
|
264 |
+
|
265 |
+
class CenterCropOrPad(torch.nn.Module):
|
266 |
+
"""Crops the given image at the center.
|
267 |
+
If the image is torch Tensor, it is expected
|
268 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
269 |
+
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
273 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
274 |
+
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
|
275 |
+
"""
|
276 |
+
|
277 |
+
def __init__(self, size, fill=0):
|
278 |
+
super().__init__()
|
279 |
+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
|
280 |
+
self.fill = fill
|
281 |
+
|
282 |
+
def forward(self, img):
|
283 |
+
"""
|
284 |
+
Args:
|
285 |
+
img (PIL Image or Tensor): Image to be cropped.
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
PIL Image or Tensor: Cropped image.
|
289 |
+
"""
|
290 |
+
return center_crop_or_pad(img, self.size, fill=self.fill)
|
291 |
+
|
292 |
+
def __repr__(self) -> str:
|
293 |
+
return f"{self.__class__.__name__}(size={self.size})"
|
294 |
+
|
295 |
+
def process_anyres_image(image, processor, processor_size, grid_pinpoints):
|
296 |
+
"""
|
297 |
+
Process an image with variable resolutions.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
image (PIL.Image.Image): The input image to be processed.
|
301 |
+
processor: The image processor object.
|
302 |
+
processor_size (tuple, list): The size of the image processor.
|
303 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
torch.Tensor: A tensor containing the processed image patches.
|
307 |
+
"""
|
308 |
+
# FIXME: determine grid_pinpoints from image sizes.
|
309 |
+
if type(grid_pinpoints) is list:
|
310 |
+
possible_resolutions = grid_pinpoints
|
311 |
+
else:
|
312 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
313 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
314 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
315 |
+
|
316 |
+
# processor_size = processor.transforms[0].size
|
317 |
+
patches = divide_to_patches(image_padded, processor_size[0])
|
318 |
+
|
319 |
+
image_original_resize = image.resize((processor_size[0], processor_size[0]))
|
320 |
+
|
321 |
+
image_patches = [image_original_resize] + patches
|
322 |
+
image_patches = [processor(image_patch)
|
323 |
+
for image_patch in image_patches]
|
324 |
+
return torch.stack(image_patches, dim=0)
|
325 |
+
|
326 |
+
|
327 |
+
def select_best_resolution(original_size, possible_resolutions):
|
328 |
+
"""
|
329 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
330 |
+
|
331 |
+
Args:
|
332 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
333 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
334 |
+
|
335 |
+
Returns:
|
336 |
+
tuple: The best fit resolution in the format (width, height).
|
337 |
+
"""
|
338 |
+
original_width, original_height = original_size
|
339 |
+
best_fit = None
|
340 |
+
max_effective_resolution = 0
|
341 |
+
min_wasted_resolution = float('inf')
|
342 |
+
|
343 |
+
for width, height in possible_resolutions:
|
344 |
+
scale = min(width / original_width, height / original_height)
|
345 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
346 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
347 |
+
wasted_resolution = (width * height) - effective_resolution
|
348 |
+
|
349 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
350 |
+
max_effective_resolution = effective_resolution
|
351 |
+
min_wasted_resolution = wasted_resolution
|
352 |
+
best_fit = (width, height)
|
353 |
+
|
354 |
+
return best_fit
|
355 |
+
|
356 |
+
def resize_and_pad_image(image, target_resolution):
|
357 |
+
"""
|
358 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
359 |
+
|
360 |
+
Args:
|
361 |
+
image (PIL.Image.Image): The input image.
|
362 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
PIL.Image.Image: The resized and padded image.
|
366 |
+
"""
|
367 |
+
original_width, original_height = image.size
|
368 |
+
target_width, target_height = target_resolution
|
369 |
+
|
370 |
+
scale_w = target_width / original_width
|
371 |
+
scale_h = target_height / original_height
|
372 |
+
|
373 |
+
if scale_w < scale_h:
|
374 |
+
new_width = target_width
|
375 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
376 |
+
else:
|
377 |
+
new_height = target_height
|
378 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
379 |
+
|
380 |
+
# Resize the image
|
381 |
+
resized_image = image.resize((new_width, new_height))
|
382 |
+
|
383 |
+
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
|
384 |
+
paste_x = (target_width - new_width) // 2
|
385 |
+
paste_y = (target_height - new_height) // 2
|
386 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
387 |
+
|
388 |
+
return new_image
|
389 |
+
|
390 |
+
def divide_to_patches(image, patch_size):
|
391 |
+
"""
|
392 |
+
Divides an image into patches of a specified size.
|
393 |
+
|
394 |
+
Args:
|
395 |
+
image (PIL.Image.Image): The input image.
|
396 |
+
patch_size (int): The size of each patch.
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
400 |
+
"""
|
401 |
+
patches = []
|
402 |
+
width, height = image.size
|
403 |
+
for i in range(0, height, patch_size):
|
404 |
+
for j in range(0, width, patch_size):
|
405 |
+
box = (j, i, j + patch_size, i + patch_size)
|
406 |
+
patch = image.crop(box)
|
407 |
+
patches.append(patch)
|
408 |
+
|
409 |
+
return patches
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2e6acb1fba540ae862c8ec5a3f0898a25f97df07029235aa164be537e13e664b
|
3 |
+
size 4977054880
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:edf7152efe2adbf112f09395c1a443d53b439cba0342285dae7ab6225281dcaa
|
3 |
+
size 4983112128
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e1afa3c02b194f2a1cd36f296ef7690c42f356b9b48e98644011b59336b0699a
|
3 |
+
size 4983112168
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b69d6172f961180def49586fe73b71c2bd2e4ba968564f276486e86030a1da36
|
3 |
+
size 3414256548
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 18357448972
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"vlm.lang_model.lm_head.additional_fc.bias": "model-00004-of-00004.safetensors",
|
7 |
+
"vlm.lang_model.lm_head.additional_fc.weight": "model-00004-of-00004.safetensors",
|
8 |
+
"vlm.lang_model.lm_head.bias": "model-00004-of-00004.safetensors",
|
9 |
+
"vlm.lang_model.lm_head.weight": "model-00004-of-00004.safetensors",
|
10 |
+
"vlm.lang_model.model.embed_tokens.additional_embedding.weight": "model-00001-of-00004.safetensors",
|
11 |
+
"vlm.lang_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
12 |
+
"vlm.lang_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
13 |
+
"vlm.lang_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
14 |
+
"vlm.lang_model.model.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
15 |
+
"vlm.lang_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
16 |
+
"vlm.lang_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
17 |
+
"vlm.lang_model.model.layers.0.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
18 |
+
"vlm.lang_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
19 |
+
"vlm.lang_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
20 |
+
"vlm.lang_model.model.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
21 |
+
"vlm.lang_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
22 |
+
"vlm.lang_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
23 |
+
"vlm.lang_model.model.layers.1.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
24 |
+
"vlm.lang_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
25 |
+
"vlm.lang_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
26 |
+
"vlm.lang_model.model.layers.10.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
27 |
+
"vlm.lang_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
28 |
+
"vlm.lang_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
29 |
+
"vlm.lang_model.model.layers.10.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
30 |
+
"vlm.lang_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
31 |
+
"vlm.lang_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
32 |
+
"vlm.lang_model.model.layers.11.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
33 |
+
"vlm.lang_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
34 |
+
"vlm.lang_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
35 |
+
"vlm.lang_model.model.layers.11.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
36 |
+
"vlm.lang_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
37 |
+
"vlm.lang_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
38 |
+
"vlm.lang_model.model.layers.12.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
39 |
+
"vlm.lang_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
40 |
+
"vlm.lang_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
41 |
+
"vlm.lang_model.model.layers.12.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
42 |
+
"vlm.lang_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
43 |
+
"vlm.lang_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
44 |
+
"vlm.lang_model.model.layers.13.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
45 |
+
"vlm.lang_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
46 |
+
"vlm.lang_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
47 |
+
"vlm.lang_model.model.layers.13.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
48 |
+
"vlm.lang_model.model.layers.14.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
49 |
+
"vlm.lang_model.model.layers.14.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
50 |
+
"vlm.lang_model.model.layers.14.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
51 |
+
"vlm.lang_model.model.layers.14.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
52 |
+
"vlm.lang_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
53 |
+
"vlm.lang_model.model.layers.14.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
54 |
+
"vlm.lang_model.model.layers.15.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
55 |
+
"vlm.lang_model.model.layers.15.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
56 |
+
"vlm.lang_model.model.layers.15.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
57 |
+
"vlm.lang_model.model.layers.15.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
58 |
+
"vlm.lang_model.model.layers.15.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
59 |
+
"vlm.lang_model.model.layers.15.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
60 |
+
"vlm.lang_model.model.layers.16.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
61 |
+
"vlm.lang_model.model.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
62 |
+
"vlm.lang_model.model.layers.16.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
63 |
+
"vlm.lang_model.model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
64 |
+
"vlm.lang_model.model.layers.16.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
65 |
+
"vlm.lang_model.model.layers.16.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
66 |
+
"vlm.lang_model.model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
67 |
+
"vlm.lang_model.model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
68 |
+
"vlm.lang_model.model.layers.17.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
69 |
+
"vlm.lang_model.model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
70 |
+
"vlm.lang_model.model.layers.17.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
71 |
+
"vlm.lang_model.model.layers.17.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
72 |
+
"vlm.lang_model.model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
73 |
+
"vlm.lang_model.model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
74 |
+
"vlm.lang_model.model.layers.18.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
75 |
+
"vlm.lang_model.model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
76 |
+
"vlm.lang_model.model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
77 |
+
"vlm.lang_model.model.layers.18.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
78 |
+
"vlm.lang_model.model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
79 |
+
"vlm.lang_model.model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
80 |
+
"vlm.lang_model.model.layers.19.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
81 |
+
"vlm.lang_model.model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
82 |
+
"vlm.lang_model.model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
83 |
+
"vlm.lang_model.model.layers.19.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
84 |
+
"vlm.lang_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
85 |
+
"vlm.lang_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
86 |
+
"vlm.lang_model.model.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
87 |
+
"vlm.lang_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
88 |
+
"vlm.lang_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
89 |
+
"vlm.lang_model.model.layers.2.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
90 |
+
"vlm.lang_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
91 |
+
"vlm.lang_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
92 |
+
"vlm.lang_model.model.layers.20.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
93 |
+
"vlm.lang_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
94 |
+
"vlm.lang_model.model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
95 |
+
"vlm.lang_model.model.layers.20.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
96 |
+
"vlm.lang_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
97 |
+
"vlm.lang_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
98 |
+
"vlm.lang_model.model.layers.21.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
99 |
+
"vlm.lang_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
100 |
+
"vlm.lang_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
101 |
+
"vlm.lang_model.model.layers.21.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
102 |
+
"vlm.lang_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
103 |
+
"vlm.lang_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
104 |
+
"vlm.lang_model.model.layers.22.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
105 |
+
"vlm.lang_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
106 |
+
"vlm.lang_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
107 |
+
"vlm.lang_model.model.layers.22.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
108 |
+
"vlm.lang_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
109 |
+
"vlm.lang_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
110 |
+
"vlm.lang_model.model.layers.23.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
111 |
+
"vlm.lang_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
112 |
+
"vlm.lang_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
113 |
+
"vlm.lang_model.model.layers.23.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
114 |
+
"vlm.lang_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
115 |
+
"vlm.lang_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
116 |
+
"vlm.lang_model.model.layers.24.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
117 |
+
"vlm.lang_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
118 |
+
"vlm.lang_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
119 |
+
"vlm.lang_model.model.layers.24.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
120 |
+
"vlm.lang_model.model.layers.25.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
121 |
+
"vlm.lang_model.model.layers.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
122 |
+
"vlm.lang_model.model.layers.25.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
123 |
+
"vlm.lang_model.model.layers.25.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
124 |
+
"vlm.lang_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
125 |
+
"vlm.lang_model.model.layers.25.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
126 |
+
"vlm.lang_model.model.layers.26.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
127 |
+
"vlm.lang_model.model.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
128 |
+
"vlm.lang_model.model.layers.26.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
129 |
+
"vlm.lang_model.model.layers.26.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
130 |
+
"vlm.lang_model.model.layers.26.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
131 |
+
"vlm.lang_model.model.layers.26.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
132 |
+
"vlm.lang_model.model.layers.27.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
133 |
+
"vlm.lang_model.model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
134 |
+
"vlm.lang_model.model.layers.27.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
135 |
+
"vlm.lang_model.model.layers.27.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
136 |
+
"vlm.lang_model.model.layers.27.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
137 |
+
"vlm.lang_model.model.layers.27.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
138 |
+
"vlm.lang_model.model.layers.28.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
139 |
+
"vlm.lang_model.model.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
140 |
+
"vlm.lang_model.model.layers.28.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
141 |
+
"vlm.lang_model.model.layers.28.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
142 |
+
"vlm.lang_model.model.layers.28.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
143 |
+
"vlm.lang_model.model.layers.28.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
144 |
+
"vlm.lang_model.model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
145 |
+
"vlm.lang_model.model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
146 |
+
"vlm.lang_model.model.layers.29.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
147 |
+
"vlm.lang_model.model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
148 |
+
"vlm.lang_model.model.layers.29.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
149 |
+
"vlm.lang_model.model.layers.29.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
150 |
+
"vlm.lang_model.model.layers.3.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
151 |
+
"vlm.lang_model.model.layers.3.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
152 |
+
"vlm.lang_model.model.layers.3.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
153 |
+
"vlm.lang_model.model.layers.3.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
154 |
+
"vlm.lang_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
155 |
+
"vlm.lang_model.model.layers.3.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
156 |
+
"vlm.lang_model.model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
157 |
+
"vlm.lang_model.model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
158 |
+
"vlm.lang_model.model.layers.30.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
159 |
+
"vlm.lang_model.model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
160 |
+
"vlm.lang_model.model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
161 |
+
"vlm.lang_model.model.layers.30.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
162 |
+
"vlm.lang_model.model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
163 |
+
"vlm.lang_model.model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
164 |
+
"vlm.lang_model.model.layers.31.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
165 |
+
"vlm.lang_model.model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
166 |
+
"vlm.lang_model.model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
167 |
+
"vlm.lang_model.model.layers.31.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
168 |
+
"vlm.lang_model.model.layers.4.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
169 |
+
"vlm.lang_model.model.layers.4.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
170 |
+
"vlm.lang_model.model.layers.4.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
171 |
+
"vlm.lang_model.model.layers.4.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
172 |
+
"vlm.lang_model.model.layers.4.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
173 |
+
"vlm.lang_model.model.layers.4.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
174 |
+
"vlm.lang_model.model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
175 |
+
"vlm.lang_model.model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
176 |
+
"vlm.lang_model.model.layers.5.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
177 |
+
"vlm.lang_model.model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
178 |
+
"vlm.lang_model.model.layers.5.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
179 |
+
"vlm.lang_model.model.layers.5.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
180 |
+
"vlm.lang_model.model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
181 |
+
"vlm.lang_model.model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
182 |
+
"vlm.lang_model.model.layers.6.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
183 |
+
"vlm.lang_model.model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
184 |
+
"vlm.lang_model.model.layers.6.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
185 |
+
"vlm.lang_model.model.layers.6.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
186 |
+
"vlm.lang_model.model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
187 |
+
"vlm.lang_model.model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
188 |
+
"vlm.lang_model.model.layers.7.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
189 |
+
"vlm.lang_model.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
190 |
+
"vlm.lang_model.model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
191 |
+
"vlm.lang_model.model.layers.7.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
192 |
+
"vlm.lang_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
193 |
+
"vlm.lang_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
194 |
+
"vlm.lang_model.model.layers.8.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
195 |
+
"vlm.lang_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
196 |
+
"vlm.lang_model.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
197 |
+
"vlm.lang_model.model.layers.8.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
198 |
+
"vlm.lang_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
199 |
+
"vlm.lang_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
200 |
+
"vlm.lang_model.model.layers.9.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
201 |
+
"vlm.lang_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
202 |
+
"vlm.lang_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
203 |
+
"vlm.lang_model.model.layers.9.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
204 |
+
"vlm.lang_model.model.norm.weight": "model-00004-of-00004.safetensors",
|
205 |
+
"vlm.vision_encoder.class_embedding": "model-00001-of-00004.safetensors",
|
206 |
+
"vlm.vision_encoder.conv1.weight": "model-00001-of-00004.safetensors",
|
207 |
+
"vlm.vision_encoder.ln_post.bias": "model-00001-of-00004.safetensors",
|
208 |
+
"vlm.vision_encoder.ln_post.weight": "model-00001-of-00004.safetensors",
|
209 |
+
"vlm.vision_encoder.ln_pre.bias": "model-00001-of-00004.safetensors",
|
210 |
+
"vlm.vision_encoder.ln_pre.weight": "model-00001-of-00004.safetensors",
|
211 |
+
"vlm.vision_encoder.positional_embedding": "model-00001-of-00004.safetensors",
|
212 |
+
"vlm.vision_encoder.proj": "model-00001-of-00004.safetensors",
|
213 |
+
"vlm.vision_encoder.transformer.resblocks.0.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
214 |
+
"vlm.vision_encoder.transformer.resblocks.0.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
215 |
+
"vlm.vision_encoder.transformer.resblocks.0.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
216 |
+
"vlm.vision_encoder.transformer.resblocks.0.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
217 |
+
"vlm.vision_encoder.transformer.resblocks.0.ln_1.bias": "model-00001-of-00004.safetensors",
|
218 |
+
"vlm.vision_encoder.transformer.resblocks.0.ln_1.weight": "model-00001-of-00004.safetensors",
|
219 |
+
"vlm.vision_encoder.transformer.resblocks.0.ln_2.bias": "model-00001-of-00004.safetensors",
|
220 |
+
"vlm.vision_encoder.transformer.resblocks.0.ln_2.weight": "model-00001-of-00004.safetensors",
|
221 |
+
"vlm.vision_encoder.transformer.resblocks.0.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
222 |
+
"vlm.vision_encoder.transformer.resblocks.0.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
223 |
+
"vlm.vision_encoder.transformer.resblocks.0.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
224 |
+
"vlm.vision_encoder.transformer.resblocks.0.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
225 |
+
"vlm.vision_encoder.transformer.resblocks.1.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
226 |
+
"vlm.vision_encoder.transformer.resblocks.1.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
227 |
+
"vlm.vision_encoder.transformer.resblocks.1.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
228 |
+
"vlm.vision_encoder.transformer.resblocks.1.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
229 |
+
"vlm.vision_encoder.transformer.resblocks.1.ln_1.bias": "model-00001-of-00004.safetensors",
|
230 |
+
"vlm.vision_encoder.transformer.resblocks.1.ln_1.weight": "model-00001-of-00004.safetensors",
|
231 |
+
"vlm.vision_encoder.transformer.resblocks.1.ln_2.bias": "model-00001-of-00004.safetensors",
|
232 |
+
"vlm.vision_encoder.transformer.resblocks.1.ln_2.weight": "model-00001-of-00004.safetensors",
|
233 |
+
"vlm.vision_encoder.transformer.resblocks.1.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
234 |
+
"vlm.vision_encoder.transformer.resblocks.1.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
235 |
+
"vlm.vision_encoder.transformer.resblocks.1.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
236 |
+
"vlm.vision_encoder.transformer.resblocks.1.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
237 |
+
"vlm.vision_encoder.transformer.resblocks.10.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
238 |
+
"vlm.vision_encoder.transformer.resblocks.10.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
239 |
+
"vlm.vision_encoder.transformer.resblocks.10.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
240 |
+
"vlm.vision_encoder.transformer.resblocks.10.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
241 |
+
"vlm.vision_encoder.transformer.resblocks.10.ln_1.bias": "model-00001-of-00004.safetensors",
|
242 |
+
"vlm.vision_encoder.transformer.resblocks.10.ln_1.weight": "model-00001-of-00004.safetensors",
|
243 |
+
"vlm.vision_encoder.transformer.resblocks.10.ln_2.bias": "model-00001-of-00004.safetensors",
|
244 |
+
"vlm.vision_encoder.transformer.resblocks.10.ln_2.weight": "model-00001-of-00004.safetensors",
|
245 |
+
"vlm.vision_encoder.transformer.resblocks.10.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
246 |
+
"vlm.vision_encoder.transformer.resblocks.10.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
247 |
+
"vlm.vision_encoder.transformer.resblocks.10.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
248 |
+
"vlm.vision_encoder.transformer.resblocks.10.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
249 |
+
"vlm.vision_encoder.transformer.resblocks.11.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
250 |
+
"vlm.vision_encoder.transformer.resblocks.11.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
251 |
+
"vlm.vision_encoder.transformer.resblocks.11.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
252 |
+
"vlm.vision_encoder.transformer.resblocks.11.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
253 |
+
"vlm.vision_encoder.transformer.resblocks.11.ln_1.bias": "model-00001-of-00004.safetensors",
|
254 |
+
"vlm.vision_encoder.transformer.resblocks.11.ln_1.weight": "model-00001-of-00004.safetensors",
|
255 |
+
"vlm.vision_encoder.transformer.resblocks.11.ln_2.bias": "model-00001-of-00004.safetensors",
|
256 |
+
"vlm.vision_encoder.transformer.resblocks.11.ln_2.weight": "model-00001-of-00004.safetensors",
|
257 |
+
"vlm.vision_encoder.transformer.resblocks.11.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
258 |
+
"vlm.vision_encoder.transformer.resblocks.11.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
259 |
+
"vlm.vision_encoder.transformer.resblocks.11.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
260 |
+
"vlm.vision_encoder.transformer.resblocks.11.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
261 |
+
"vlm.vision_encoder.transformer.resblocks.12.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
262 |
+
"vlm.vision_encoder.transformer.resblocks.12.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
263 |
+
"vlm.vision_encoder.transformer.resblocks.12.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
264 |
+
"vlm.vision_encoder.transformer.resblocks.12.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
265 |
+
"vlm.vision_encoder.transformer.resblocks.12.ln_1.bias": "model-00001-of-00004.safetensors",
|
266 |
+
"vlm.vision_encoder.transformer.resblocks.12.ln_1.weight": "model-00001-of-00004.safetensors",
|
267 |
+
"vlm.vision_encoder.transformer.resblocks.12.ln_2.bias": "model-00001-of-00004.safetensors",
|
268 |
+
"vlm.vision_encoder.transformer.resblocks.12.ln_2.weight": "model-00001-of-00004.safetensors",
|
269 |
+
"vlm.vision_encoder.transformer.resblocks.12.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
270 |
+
"vlm.vision_encoder.transformer.resblocks.12.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
271 |
+
"vlm.vision_encoder.transformer.resblocks.12.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
272 |
+
"vlm.vision_encoder.transformer.resblocks.12.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
273 |
+
"vlm.vision_encoder.transformer.resblocks.13.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
274 |
+
"vlm.vision_encoder.transformer.resblocks.13.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
275 |
+
"vlm.vision_encoder.transformer.resblocks.13.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
276 |
+
"vlm.vision_encoder.transformer.resblocks.13.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
277 |
+
"vlm.vision_encoder.transformer.resblocks.13.ln_1.bias": "model-00001-of-00004.safetensors",
|
278 |
+
"vlm.vision_encoder.transformer.resblocks.13.ln_1.weight": "model-00001-of-00004.safetensors",
|
279 |
+
"vlm.vision_encoder.transformer.resblocks.13.ln_2.bias": "model-00001-of-00004.safetensors",
|
280 |
+
"vlm.vision_encoder.transformer.resblocks.13.ln_2.weight": "model-00001-of-00004.safetensors",
|
281 |
+
"vlm.vision_encoder.transformer.resblocks.13.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
282 |
+
"vlm.vision_encoder.transformer.resblocks.13.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
283 |
+
"vlm.vision_encoder.transformer.resblocks.13.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
284 |
+
"vlm.vision_encoder.transformer.resblocks.13.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
285 |
+
"vlm.vision_encoder.transformer.resblocks.14.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
286 |
+
"vlm.vision_encoder.transformer.resblocks.14.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
287 |
+
"vlm.vision_encoder.transformer.resblocks.14.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
288 |
+
"vlm.vision_encoder.transformer.resblocks.14.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
289 |
+
"vlm.vision_encoder.transformer.resblocks.14.ln_1.bias": "model-00001-of-00004.safetensors",
|
290 |
+
"vlm.vision_encoder.transformer.resblocks.14.ln_1.weight": "model-00001-of-00004.safetensors",
|
291 |
+
"vlm.vision_encoder.transformer.resblocks.14.ln_2.bias": "model-00001-of-00004.safetensors",
|
292 |
+
"vlm.vision_encoder.transformer.resblocks.14.ln_2.weight": "model-00001-of-00004.safetensors",
|
293 |
+
"vlm.vision_encoder.transformer.resblocks.14.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
294 |
+
"vlm.vision_encoder.transformer.resblocks.14.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
295 |
+
"vlm.vision_encoder.transformer.resblocks.14.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
296 |
+
"vlm.vision_encoder.transformer.resblocks.14.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
297 |
+
"vlm.vision_encoder.transformer.resblocks.15.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
298 |
+
"vlm.vision_encoder.transformer.resblocks.15.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
299 |
+
"vlm.vision_encoder.transformer.resblocks.15.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
300 |
+
"vlm.vision_encoder.transformer.resblocks.15.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
301 |
+
"vlm.vision_encoder.transformer.resblocks.15.ln_1.bias": "model-00001-of-00004.safetensors",
|
302 |
+
"vlm.vision_encoder.transformer.resblocks.15.ln_1.weight": "model-00001-of-00004.safetensors",
|
303 |
+
"vlm.vision_encoder.transformer.resblocks.15.ln_2.bias": "model-00001-of-00004.safetensors",
|
304 |
+
"vlm.vision_encoder.transformer.resblocks.15.ln_2.weight": "model-00001-of-00004.safetensors",
|
305 |
+
"vlm.vision_encoder.transformer.resblocks.15.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
306 |
+
"vlm.vision_encoder.transformer.resblocks.15.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
307 |
+
"vlm.vision_encoder.transformer.resblocks.15.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
308 |
+
"vlm.vision_encoder.transformer.resblocks.15.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
309 |
+
"vlm.vision_encoder.transformer.resblocks.16.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
310 |
+
"vlm.vision_encoder.transformer.resblocks.16.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
311 |
+
"vlm.vision_encoder.transformer.resblocks.16.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
312 |
+
"vlm.vision_encoder.transformer.resblocks.16.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
313 |
+
"vlm.vision_encoder.transformer.resblocks.16.ln_1.bias": "model-00001-of-00004.safetensors",
|
314 |
+
"vlm.vision_encoder.transformer.resblocks.16.ln_1.weight": "model-00001-of-00004.safetensors",
|
315 |
+
"vlm.vision_encoder.transformer.resblocks.16.ln_2.bias": "model-00001-of-00004.safetensors",
|
316 |
+
"vlm.vision_encoder.transformer.resblocks.16.ln_2.weight": "model-00001-of-00004.safetensors",
|
317 |
+
"vlm.vision_encoder.transformer.resblocks.16.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
318 |
+
"vlm.vision_encoder.transformer.resblocks.16.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
319 |
+
"vlm.vision_encoder.transformer.resblocks.16.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
320 |
+
"vlm.vision_encoder.transformer.resblocks.16.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
321 |
+
"vlm.vision_encoder.transformer.resblocks.17.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
322 |
+
"vlm.vision_encoder.transformer.resblocks.17.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
323 |
+
"vlm.vision_encoder.transformer.resblocks.17.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
324 |
+
"vlm.vision_encoder.transformer.resblocks.17.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
325 |
+
"vlm.vision_encoder.transformer.resblocks.17.ln_1.bias": "model-00001-of-00004.safetensors",
|
326 |
+
"vlm.vision_encoder.transformer.resblocks.17.ln_1.weight": "model-00001-of-00004.safetensors",
|
327 |
+
"vlm.vision_encoder.transformer.resblocks.17.ln_2.bias": "model-00001-of-00004.safetensors",
|
328 |
+
"vlm.vision_encoder.transformer.resblocks.17.ln_2.weight": "model-00001-of-00004.safetensors",
|
329 |
+
"vlm.vision_encoder.transformer.resblocks.17.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
330 |
+
"vlm.vision_encoder.transformer.resblocks.17.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
331 |
+
"vlm.vision_encoder.transformer.resblocks.17.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
332 |
+
"vlm.vision_encoder.transformer.resblocks.17.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
333 |
+
"vlm.vision_encoder.transformer.resblocks.18.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
334 |
+
"vlm.vision_encoder.transformer.resblocks.18.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
335 |
+
"vlm.vision_encoder.transformer.resblocks.18.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
336 |
+
"vlm.vision_encoder.transformer.resblocks.18.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
337 |
+
"vlm.vision_encoder.transformer.resblocks.18.ln_1.bias": "model-00001-of-00004.safetensors",
|
338 |
+
"vlm.vision_encoder.transformer.resblocks.18.ln_1.weight": "model-00001-of-00004.safetensors",
|
339 |
+
"vlm.vision_encoder.transformer.resblocks.18.ln_2.bias": "model-00001-of-00004.safetensors",
|
340 |
+
"vlm.vision_encoder.transformer.resblocks.18.ln_2.weight": "model-00001-of-00004.safetensors",
|
341 |
+
"vlm.vision_encoder.transformer.resblocks.18.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
342 |
+
"vlm.vision_encoder.transformer.resblocks.18.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
343 |
+
"vlm.vision_encoder.transformer.resblocks.18.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
344 |
+
"vlm.vision_encoder.transformer.resblocks.18.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
345 |
+
"vlm.vision_encoder.transformer.resblocks.19.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
346 |
+
"vlm.vision_encoder.transformer.resblocks.19.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
347 |
+
"vlm.vision_encoder.transformer.resblocks.19.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
348 |
+
"vlm.vision_encoder.transformer.resblocks.19.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
349 |
+
"vlm.vision_encoder.transformer.resblocks.19.ln_1.bias": "model-00001-of-00004.safetensors",
|
350 |
+
"vlm.vision_encoder.transformer.resblocks.19.ln_1.weight": "model-00001-of-00004.safetensors",
|
351 |
+
"vlm.vision_encoder.transformer.resblocks.19.ln_2.bias": "model-00001-of-00004.safetensors",
|
352 |
+
"vlm.vision_encoder.transformer.resblocks.19.ln_2.weight": "model-00001-of-00004.safetensors",
|
353 |
+
"vlm.vision_encoder.transformer.resblocks.19.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
354 |
+
"vlm.vision_encoder.transformer.resblocks.19.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
355 |
+
"vlm.vision_encoder.transformer.resblocks.19.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
356 |
+
"vlm.vision_encoder.transformer.resblocks.19.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
357 |
+
"vlm.vision_encoder.transformer.resblocks.2.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
358 |
+
"vlm.vision_encoder.transformer.resblocks.2.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
359 |
+
"vlm.vision_encoder.transformer.resblocks.2.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
360 |
+
"vlm.vision_encoder.transformer.resblocks.2.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
361 |
+
"vlm.vision_encoder.transformer.resblocks.2.ln_1.bias": "model-00001-of-00004.safetensors",
|
362 |
+
"vlm.vision_encoder.transformer.resblocks.2.ln_1.weight": "model-00001-of-00004.safetensors",
|
363 |
+
"vlm.vision_encoder.transformer.resblocks.2.ln_2.bias": "model-00001-of-00004.safetensors",
|
364 |
+
"vlm.vision_encoder.transformer.resblocks.2.ln_2.weight": "model-00001-of-00004.safetensors",
|
365 |
+
"vlm.vision_encoder.transformer.resblocks.2.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
366 |
+
"vlm.vision_encoder.transformer.resblocks.2.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
367 |
+
"vlm.vision_encoder.transformer.resblocks.2.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
368 |
+
"vlm.vision_encoder.transformer.resblocks.2.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
369 |
+
"vlm.vision_encoder.transformer.resblocks.20.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
370 |
+
"vlm.vision_encoder.transformer.resblocks.20.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
371 |
+
"vlm.vision_encoder.transformer.resblocks.20.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
372 |
+
"vlm.vision_encoder.transformer.resblocks.20.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
373 |
+
"vlm.vision_encoder.transformer.resblocks.20.ln_1.bias": "model-00001-of-00004.safetensors",
|
374 |
+
"vlm.vision_encoder.transformer.resblocks.20.ln_1.weight": "model-00001-of-00004.safetensors",
|
375 |
+
"vlm.vision_encoder.transformer.resblocks.20.ln_2.bias": "model-00001-of-00004.safetensors",
|
376 |
+
"vlm.vision_encoder.transformer.resblocks.20.ln_2.weight": "model-00001-of-00004.safetensors",
|
377 |
+
"vlm.vision_encoder.transformer.resblocks.20.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
378 |
+
"vlm.vision_encoder.transformer.resblocks.20.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
379 |
+
"vlm.vision_encoder.transformer.resblocks.20.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
380 |
+
"vlm.vision_encoder.transformer.resblocks.20.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
381 |
+
"vlm.vision_encoder.transformer.resblocks.21.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
382 |
+
"vlm.vision_encoder.transformer.resblocks.21.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
383 |
+
"vlm.vision_encoder.transformer.resblocks.21.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
384 |
+
"vlm.vision_encoder.transformer.resblocks.21.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
385 |
+
"vlm.vision_encoder.transformer.resblocks.21.ln_1.bias": "model-00001-of-00004.safetensors",
|
386 |
+
"vlm.vision_encoder.transformer.resblocks.21.ln_1.weight": "model-00001-of-00004.safetensors",
|
387 |
+
"vlm.vision_encoder.transformer.resblocks.21.ln_2.bias": "model-00001-of-00004.safetensors",
|
388 |
+
"vlm.vision_encoder.transformer.resblocks.21.ln_2.weight": "model-00001-of-00004.safetensors",
|
389 |
+
"vlm.vision_encoder.transformer.resblocks.21.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
390 |
+
"vlm.vision_encoder.transformer.resblocks.21.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
391 |
+
"vlm.vision_encoder.transformer.resblocks.21.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
392 |
+
"vlm.vision_encoder.transformer.resblocks.21.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
393 |
+
"vlm.vision_encoder.transformer.resblocks.22.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
394 |
+
"vlm.vision_encoder.transformer.resblocks.22.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
395 |
+
"vlm.vision_encoder.transformer.resblocks.22.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
396 |
+
"vlm.vision_encoder.transformer.resblocks.22.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
397 |
+
"vlm.vision_encoder.transformer.resblocks.22.ln_1.bias": "model-00001-of-00004.safetensors",
|
398 |
+
"vlm.vision_encoder.transformer.resblocks.22.ln_1.weight": "model-00001-of-00004.safetensors",
|
399 |
+
"vlm.vision_encoder.transformer.resblocks.22.ln_2.bias": "model-00001-of-00004.safetensors",
|
400 |
+
"vlm.vision_encoder.transformer.resblocks.22.ln_2.weight": "model-00001-of-00004.safetensors",
|
401 |
+
"vlm.vision_encoder.transformer.resblocks.22.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
402 |
+
"vlm.vision_encoder.transformer.resblocks.22.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
403 |
+
"vlm.vision_encoder.transformer.resblocks.22.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
404 |
+
"vlm.vision_encoder.transformer.resblocks.22.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
405 |
+
"vlm.vision_encoder.transformer.resblocks.23.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
406 |
+
"vlm.vision_encoder.transformer.resblocks.23.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
407 |
+
"vlm.vision_encoder.transformer.resblocks.23.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
408 |
+
"vlm.vision_encoder.transformer.resblocks.23.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
409 |
+
"vlm.vision_encoder.transformer.resblocks.23.ln_1.bias": "model-00001-of-00004.safetensors",
|
410 |
+
"vlm.vision_encoder.transformer.resblocks.23.ln_1.weight": "model-00001-of-00004.safetensors",
|
411 |
+
"vlm.vision_encoder.transformer.resblocks.23.ln_2.bias": "model-00001-of-00004.safetensors",
|
412 |
+
"vlm.vision_encoder.transformer.resblocks.23.ln_2.weight": "model-00001-of-00004.safetensors",
|
413 |
+
"vlm.vision_encoder.transformer.resblocks.23.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
414 |
+
"vlm.vision_encoder.transformer.resblocks.23.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
415 |
+
"vlm.vision_encoder.transformer.resblocks.23.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
416 |
+
"vlm.vision_encoder.transformer.resblocks.23.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
417 |
+
"vlm.vision_encoder.transformer.resblocks.24.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
418 |
+
"vlm.vision_encoder.transformer.resblocks.24.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
419 |
+
"vlm.vision_encoder.transformer.resblocks.24.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
420 |
+
"vlm.vision_encoder.transformer.resblocks.24.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
421 |
+
"vlm.vision_encoder.transformer.resblocks.24.ln_1.bias": "model-00001-of-00004.safetensors",
|
422 |
+
"vlm.vision_encoder.transformer.resblocks.24.ln_1.weight": "model-00001-of-00004.safetensors",
|
423 |
+
"vlm.vision_encoder.transformer.resblocks.24.ln_2.bias": "model-00001-of-00004.safetensors",
|
424 |
+
"vlm.vision_encoder.transformer.resblocks.24.ln_2.weight": "model-00001-of-00004.safetensors",
|
425 |
+
"vlm.vision_encoder.transformer.resblocks.24.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
426 |
+
"vlm.vision_encoder.transformer.resblocks.24.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
427 |
+
"vlm.vision_encoder.transformer.resblocks.24.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
428 |
+
"vlm.vision_encoder.transformer.resblocks.24.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
429 |
+
"vlm.vision_encoder.transformer.resblocks.25.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
430 |
+
"vlm.vision_encoder.transformer.resblocks.25.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
431 |
+
"vlm.vision_encoder.transformer.resblocks.25.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
432 |
+
"vlm.vision_encoder.transformer.resblocks.25.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
433 |
+
"vlm.vision_encoder.transformer.resblocks.25.ln_1.bias": "model-00001-of-00004.safetensors",
|
434 |
+
"vlm.vision_encoder.transformer.resblocks.25.ln_1.weight": "model-00001-of-00004.safetensors",
|
435 |
+
"vlm.vision_encoder.transformer.resblocks.25.ln_2.bias": "model-00001-of-00004.safetensors",
|
436 |
+
"vlm.vision_encoder.transformer.resblocks.25.ln_2.weight": "model-00001-of-00004.safetensors",
|
437 |
+
"vlm.vision_encoder.transformer.resblocks.25.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
438 |
+
"vlm.vision_encoder.transformer.resblocks.25.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
439 |
+
"vlm.vision_encoder.transformer.resblocks.25.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
440 |
+
"vlm.vision_encoder.transformer.resblocks.25.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
441 |
+
"vlm.vision_encoder.transformer.resblocks.26.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
442 |
+
"vlm.vision_encoder.transformer.resblocks.26.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
443 |
+
"vlm.vision_encoder.transformer.resblocks.26.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
444 |
+
"vlm.vision_encoder.transformer.resblocks.26.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
445 |
+
"vlm.vision_encoder.transformer.resblocks.26.ln_1.bias": "model-00001-of-00004.safetensors",
|
446 |
+
"vlm.vision_encoder.transformer.resblocks.26.ln_1.weight": "model-00001-of-00004.safetensors",
|
447 |
+
"vlm.vision_encoder.transformer.resblocks.26.ln_2.bias": "model-00001-of-00004.safetensors",
|
448 |
+
"vlm.vision_encoder.transformer.resblocks.26.ln_2.weight": "model-00001-of-00004.safetensors",
|
449 |
+
"vlm.vision_encoder.transformer.resblocks.26.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
450 |
+
"vlm.vision_encoder.transformer.resblocks.26.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
451 |
+
"vlm.vision_encoder.transformer.resblocks.26.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
452 |
+
"vlm.vision_encoder.transformer.resblocks.26.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
453 |
+
"vlm.vision_encoder.transformer.resblocks.27.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
454 |
+
"vlm.vision_encoder.transformer.resblocks.27.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
455 |
+
"vlm.vision_encoder.transformer.resblocks.27.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
456 |
+
"vlm.vision_encoder.transformer.resblocks.27.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
457 |
+
"vlm.vision_encoder.transformer.resblocks.27.ln_1.bias": "model-00001-of-00004.safetensors",
|
458 |
+
"vlm.vision_encoder.transformer.resblocks.27.ln_1.weight": "model-00001-of-00004.safetensors",
|
459 |
+
"vlm.vision_encoder.transformer.resblocks.27.ln_2.bias": "model-00001-of-00004.safetensors",
|
460 |
+
"vlm.vision_encoder.transformer.resblocks.27.ln_2.weight": "model-00001-of-00004.safetensors",
|
461 |
+
"vlm.vision_encoder.transformer.resblocks.27.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
462 |
+
"vlm.vision_encoder.transformer.resblocks.27.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
463 |
+
"vlm.vision_encoder.transformer.resblocks.27.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
464 |
+
"vlm.vision_encoder.transformer.resblocks.27.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
465 |
+
"vlm.vision_encoder.transformer.resblocks.28.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
466 |
+
"vlm.vision_encoder.transformer.resblocks.28.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
467 |
+
"vlm.vision_encoder.transformer.resblocks.28.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
468 |
+
"vlm.vision_encoder.transformer.resblocks.28.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
469 |
+
"vlm.vision_encoder.transformer.resblocks.28.ln_1.bias": "model-00001-of-00004.safetensors",
|
470 |
+
"vlm.vision_encoder.transformer.resblocks.28.ln_1.weight": "model-00001-of-00004.safetensors",
|
471 |
+
"vlm.vision_encoder.transformer.resblocks.28.ln_2.bias": "model-00001-of-00004.safetensors",
|
472 |
+
"vlm.vision_encoder.transformer.resblocks.28.ln_2.weight": "model-00001-of-00004.safetensors",
|
473 |
+
"vlm.vision_encoder.transformer.resblocks.28.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
474 |
+
"vlm.vision_encoder.transformer.resblocks.28.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
475 |
+
"vlm.vision_encoder.transformer.resblocks.28.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
476 |
+
"vlm.vision_encoder.transformer.resblocks.28.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
477 |
+
"vlm.vision_encoder.transformer.resblocks.29.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
478 |
+
"vlm.vision_encoder.transformer.resblocks.29.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
479 |
+
"vlm.vision_encoder.transformer.resblocks.29.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
480 |
+
"vlm.vision_encoder.transformer.resblocks.29.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
481 |
+
"vlm.vision_encoder.transformer.resblocks.29.ln_1.bias": "model-00001-of-00004.safetensors",
|
482 |
+
"vlm.vision_encoder.transformer.resblocks.29.ln_1.weight": "model-00001-of-00004.safetensors",
|
483 |
+
"vlm.vision_encoder.transformer.resblocks.29.ln_2.bias": "model-00001-of-00004.safetensors",
|
484 |
+
"vlm.vision_encoder.transformer.resblocks.29.ln_2.weight": "model-00001-of-00004.safetensors",
|
485 |
+
"vlm.vision_encoder.transformer.resblocks.29.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
486 |
+
"vlm.vision_encoder.transformer.resblocks.29.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
487 |
+
"vlm.vision_encoder.transformer.resblocks.29.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
488 |
+
"vlm.vision_encoder.transformer.resblocks.29.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
489 |
+
"vlm.vision_encoder.transformer.resblocks.3.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
490 |
+
"vlm.vision_encoder.transformer.resblocks.3.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
491 |
+
"vlm.vision_encoder.transformer.resblocks.3.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
492 |
+
"vlm.vision_encoder.transformer.resblocks.3.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
493 |
+
"vlm.vision_encoder.transformer.resblocks.3.ln_1.bias": "model-00001-of-00004.safetensors",
|
494 |
+
"vlm.vision_encoder.transformer.resblocks.3.ln_1.weight": "model-00001-of-00004.safetensors",
|
495 |
+
"vlm.vision_encoder.transformer.resblocks.3.ln_2.bias": "model-00001-of-00004.safetensors",
|
496 |
+
"vlm.vision_encoder.transformer.resblocks.3.ln_2.weight": "model-00001-of-00004.safetensors",
|
497 |
+
"vlm.vision_encoder.transformer.resblocks.3.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
498 |
+
"vlm.vision_encoder.transformer.resblocks.3.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
499 |
+
"vlm.vision_encoder.transformer.resblocks.3.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
500 |
+
"vlm.vision_encoder.transformer.resblocks.3.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
501 |
+
"vlm.vision_encoder.transformer.resblocks.30.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
502 |
+
"vlm.vision_encoder.transformer.resblocks.30.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
503 |
+
"vlm.vision_encoder.transformer.resblocks.30.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
504 |
+
"vlm.vision_encoder.transformer.resblocks.30.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
505 |
+
"vlm.vision_encoder.transformer.resblocks.30.ln_1.bias": "model-00001-of-00004.safetensors",
|
506 |
+
"vlm.vision_encoder.transformer.resblocks.30.ln_1.weight": "model-00001-of-00004.safetensors",
|
507 |
+
"vlm.vision_encoder.transformer.resblocks.30.ln_2.bias": "model-00001-of-00004.safetensors",
|
508 |
+
"vlm.vision_encoder.transformer.resblocks.30.ln_2.weight": "model-00001-of-00004.safetensors",
|
509 |
+
"vlm.vision_encoder.transformer.resblocks.30.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
510 |
+
"vlm.vision_encoder.transformer.resblocks.30.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
511 |
+
"vlm.vision_encoder.transformer.resblocks.30.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
512 |
+
"vlm.vision_encoder.transformer.resblocks.30.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
513 |
+
"vlm.vision_encoder.transformer.resblocks.31.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
514 |
+
"vlm.vision_encoder.transformer.resblocks.31.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
515 |
+
"vlm.vision_encoder.transformer.resblocks.31.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
516 |
+
"vlm.vision_encoder.transformer.resblocks.31.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
517 |
+
"vlm.vision_encoder.transformer.resblocks.31.ln_1.bias": "model-00001-of-00004.safetensors",
|
518 |
+
"vlm.vision_encoder.transformer.resblocks.31.ln_1.weight": "model-00001-of-00004.safetensors",
|
519 |
+
"vlm.vision_encoder.transformer.resblocks.31.ln_2.bias": "model-00001-of-00004.safetensors",
|
520 |
+
"vlm.vision_encoder.transformer.resblocks.31.ln_2.weight": "model-00001-of-00004.safetensors",
|
521 |
+
"vlm.vision_encoder.transformer.resblocks.31.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
522 |
+
"vlm.vision_encoder.transformer.resblocks.31.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
523 |
+
"vlm.vision_encoder.transformer.resblocks.31.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
524 |
+
"vlm.vision_encoder.transformer.resblocks.31.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
525 |
+
"vlm.vision_encoder.transformer.resblocks.4.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
526 |
+
"vlm.vision_encoder.transformer.resblocks.4.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
527 |
+
"vlm.vision_encoder.transformer.resblocks.4.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
528 |
+
"vlm.vision_encoder.transformer.resblocks.4.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
529 |
+
"vlm.vision_encoder.transformer.resblocks.4.ln_1.bias": "model-00001-of-00004.safetensors",
|
530 |
+
"vlm.vision_encoder.transformer.resblocks.4.ln_1.weight": "model-00001-of-00004.safetensors",
|
531 |
+
"vlm.vision_encoder.transformer.resblocks.4.ln_2.bias": "model-00001-of-00004.safetensors",
|
532 |
+
"vlm.vision_encoder.transformer.resblocks.4.ln_2.weight": "model-00001-of-00004.safetensors",
|
533 |
+
"vlm.vision_encoder.transformer.resblocks.4.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
534 |
+
"vlm.vision_encoder.transformer.resblocks.4.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
535 |
+
"vlm.vision_encoder.transformer.resblocks.4.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
536 |
+
"vlm.vision_encoder.transformer.resblocks.4.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
537 |
+
"vlm.vision_encoder.transformer.resblocks.5.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
538 |
+
"vlm.vision_encoder.transformer.resblocks.5.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
539 |
+
"vlm.vision_encoder.transformer.resblocks.5.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
540 |
+
"vlm.vision_encoder.transformer.resblocks.5.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
541 |
+
"vlm.vision_encoder.transformer.resblocks.5.ln_1.bias": "model-00001-of-00004.safetensors",
|
542 |
+
"vlm.vision_encoder.transformer.resblocks.5.ln_1.weight": "model-00001-of-00004.safetensors",
|
543 |
+
"vlm.vision_encoder.transformer.resblocks.5.ln_2.bias": "model-00001-of-00004.safetensors",
|
544 |
+
"vlm.vision_encoder.transformer.resblocks.5.ln_2.weight": "model-00001-of-00004.safetensors",
|
545 |
+
"vlm.vision_encoder.transformer.resblocks.5.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
546 |
+
"vlm.vision_encoder.transformer.resblocks.5.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
547 |
+
"vlm.vision_encoder.transformer.resblocks.5.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
548 |
+
"vlm.vision_encoder.transformer.resblocks.5.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
549 |
+
"vlm.vision_encoder.transformer.resblocks.6.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
550 |
+
"vlm.vision_encoder.transformer.resblocks.6.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
551 |
+
"vlm.vision_encoder.transformer.resblocks.6.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
552 |
+
"vlm.vision_encoder.transformer.resblocks.6.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
553 |
+
"vlm.vision_encoder.transformer.resblocks.6.ln_1.bias": "model-00001-of-00004.safetensors",
|
554 |
+
"vlm.vision_encoder.transformer.resblocks.6.ln_1.weight": "model-00001-of-00004.safetensors",
|
555 |
+
"vlm.vision_encoder.transformer.resblocks.6.ln_2.bias": "model-00001-of-00004.safetensors",
|
556 |
+
"vlm.vision_encoder.transformer.resblocks.6.ln_2.weight": "model-00001-of-00004.safetensors",
|
557 |
+
"vlm.vision_encoder.transformer.resblocks.6.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
558 |
+
"vlm.vision_encoder.transformer.resblocks.6.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
559 |
+
"vlm.vision_encoder.transformer.resblocks.6.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
560 |
+
"vlm.vision_encoder.transformer.resblocks.6.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
561 |
+
"vlm.vision_encoder.transformer.resblocks.7.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
562 |
+
"vlm.vision_encoder.transformer.resblocks.7.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
563 |
+
"vlm.vision_encoder.transformer.resblocks.7.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
564 |
+
"vlm.vision_encoder.transformer.resblocks.7.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
565 |
+
"vlm.vision_encoder.transformer.resblocks.7.ln_1.bias": "model-00001-of-00004.safetensors",
|
566 |
+
"vlm.vision_encoder.transformer.resblocks.7.ln_1.weight": "model-00001-of-00004.safetensors",
|
567 |
+
"vlm.vision_encoder.transformer.resblocks.7.ln_2.bias": "model-00001-of-00004.safetensors",
|
568 |
+
"vlm.vision_encoder.transformer.resblocks.7.ln_2.weight": "model-00001-of-00004.safetensors",
|
569 |
+
"vlm.vision_encoder.transformer.resblocks.7.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
570 |
+
"vlm.vision_encoder.transformer.resblocks.7.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
571 |
+
"vlm.vision_encoder.transformer.resblocks.7.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
572 |
+
"vlm.vision_encoder.transformer.resblocks.7.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
573 |
+
"vlm.vision_encoder.transformer.resblocks.8.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
574 |
+
"vlm.vision_encoder.transformer.resblocks.8.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
575 |
+
"vlm.vision_encoder.transformer.resblocks.8.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
576 |
+
"vlm.vision_encoder.transformer.resblocks.8.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
577 |
+
"vlm.vision_encoder.transformer.resblocks.8.ln_1.bias": "model-00001-of-00004.safetensors",
|
578 |
+
"vlm.vision_encoder.transformer.resblocks.8.ln_1.weight": "model-00001-of-00004.safetensors",
|
579 |
+
"vlm.vision_encoder.transformer.resblocks.8.ln_2.bias": "model-00001-of-00004.safetensors",
|
580 |
+
"vlm.vision_encoder.transformer.resblocks.8.ln_2.weight": "model-00001-of-00004.safetensors",
|
581 |
+
"vlm.vision_encoder.transformer.resblocks.8.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
582 |
+
"vlm.vision_encoder.transformer.resblocks.8.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
583 |
+
"vlm.vision_encoder.transformer.resblocks.8.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
584 |
+
"vlm.vision_encoder.transformer.resblocks.8.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
585 |
+
"vlm.vision_encoder.transformer.resblocks.9.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
586 |
+
"vlm.vision_encoder.transformer.resblocks.9.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
587 |
+
"vlm.vision_encoder.transformer.resblocks.9.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
588 |
+
"vlm.vision_encoder.transformer.resblocks.9.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
589 |
+
"vlm.vision_encoder.transformer.resblocks.9.ln_1.bias": "model-00001-of-00004.safetensors",
|
590 |
+
"vlm.vision_encoder.transformer.resblocks.9.ln_1.weight": "model-00001-of-00004.safetensors",
|
591 |
+
"vlm.vision_encoder.transformer.resblocks.9.ln_2.bias": "model-00001-of-00004.safetensors",
|
592 |
+
"vlm.vision_encoder.transformer.resblocks.9.ln_2.weight": "model-00001-of-00004.safetensors",
|
593 |
+
"vlm.vision_encoder.transformer.resblocks.9.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
594 |
+
"vlm.vision_encoder.transformer.resblocks.9.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
595 |
+
"vlm.vision_encoder.transformer.resblocks.9.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
596 |
+
"vlm.vision_encoder.transformer.resblocks.9.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
597 |
+
"vlm.vision_tokenizer.latents": "model-00001-of-00004.safetensors",
|
598 |
+
"vlm.vision_tokenizer.layers.0.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
599 |
+
"vlm.vision_tokenizer.layers.0.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
600 |
+
"vlm.vision_tokenizer.layers.0.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
601 |
+
"vlm.vision_tokenizer.layers.0.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
602 |
+
"vlm.vision_tokenizer.layers.0.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
603 |
+
"vlm.vision_tokenizer.layers.0.0.to_out.weight": "model-00001-of-00004.safetensors",
|
604 |
+
"vlm.vision_tokenizer.layers.0.0.to_q.weight": "model-00001-of-00004.safetensors",
|
605 |
+
"vlm.vision_tokenizer.layers.0.1.0.bias": "model-00001-of-00004.safetensors",
|
606 |
+
"vlm.vision_tokenizer.layers.0.1.0.weight": "model-00001-of-00004.safetensors",
|
607 |
+
"vlm.vision_tokenizer.layers.0.1.1.weight": "model-00001-of-00004.safetensors",
|
608 |
+
"vlm.vision_tokenizer.layers.0.1.3.weight": "model-00001-of-00004.safetensors",
|
609 |
+
"vlm.vision_tokenizer.layers.1.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
610 |
+
"vlm.vision_tokenizer.layers.1.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
611 |
+
"vlm.vision_tokenizer.layers.1.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
612 |
+
"vlm.vision_tokenizer.layers.1.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
613 |
+
"vlm.vision_tokenizer.layers.1.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
614 |
+
"vlm.vision_tokenizer.layers.1.0.to_out.weight": "model-00001-of-00004.safetensors",
|
615 |
+
"vlm.vision_tokenizer.layers.1.0.to_q.weight": "model-00001-of-00004.safetensors",
|
616 |
+
"vlm.vision_tokenizer.layers.1.1.0.bias": "model-00001-of-00004.safetensors",
|
617 |
+
"vlm.vision_tokenizer.layers.1.1.0.weight": "model-00001-of-00004.safetensors",
|
618 |
+
"vlm.vision_tokenizer.layers.1.1.1.weight": "model-00001-of-00004.safetensors",
|
619 |
+
"vlm.vision_tokenizer.layers.1.1.3.weight": "model-00001-of-00004.safetensors",
|
620 |
+
"vlm.vision_tokenizer.layers.2.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
621 |
+
"vlm.vision_tokenizer.layers.2.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
622 |
+
"vlm.vision_tokenizer.layers.2.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
623 |
+
"vlm.vision_tokenizer.layers.2.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
624 |
+
"vlm.vision_tokenizer.layers.2.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
625 |
+
"vlm.vision_tokenizer.layers.2.0.to_out.weight": "model-00001-of-00004.safetensors",
|
626 |
+
"vlm.vision_tokenizer.layers.2.0.to_q.weight": "model-00001-of-00004.safetensors",
|
627 |
+
"vlm.vision_tokenizer.layers.2.1.0.bias": "model-00001-of-00004.safetensors",
|
628 |
+
"vlm.vision_tokenizer.layers.2.1.0.weight": "model-00001-of-00004.safetensors",
|
629 |
+
"vlm.vision_tokenizer.layers.2.1.1.weight": "model-00001-of-00004.safetensors",
|
630 |
+
"vlm.vision_tokenizer.layers.2.1.3.weight": "model-00001-of-00004.safetensors",
|
631 |
+
"vlm.vision_tokenizer.layers.3.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
632 |
+
"vlm.vision_tokenizer.layers.3.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
633 |
+
"vlm.vision_tokenizer.layers.3.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
634 |
+
"vlm.vision_tokenizer.layers.3.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
635 |
+
"vlm.vision_tokenizer.layers.3.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
636 |
+
"vlm.vision_tokenizer.layers.3.0.to_out.weight": "model-00001-of-00004.safetensors",
|
637 |
+
"vlm.vision_tokenizer.layers.3.0.to_q.weight": "model-00001-of-00004.safetensors",
|
638 |
+
"vlm.vision_tokenizer.layers.3.1.0.bias": "model-00001-of-00004.safetensors",
|
639 |
+
"vlm.vision_tokenizer.layers.3.1.0.weight": "model-00001-of-00004.safetensors",
|
640 |
+
"vlm.vision_tokenizer.layers.3.1.1.weight": "model-00001-of-00004.safetensors",
|
641 |
+
"vlm.vision_tokenizer.layers.3.1.3.weight": "model-00001-of-00004.safetensors",
|
642 |
+
"vlm.vision_tokenizer.layers.4.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
643 |
+
"vlm.vision_tokenizer.layers.4.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
644 |
+
"vlm.vision_tokenizer.layers.4.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
645 |
+
"vlm.vision_tokenizer.layers.4.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
646 |
+
"vlm.vision_tokenizer.layers.4.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
647 |
+
"vlm.vision_tokenizer.layers.4.0.to_out.weight": "model-00001-of-00004.safetensors",
|
648 |
+
"vlm.vision_tokenizer.layers.4.0.to_q.weight": "model-00001-of-00004.safetensors",
|
649 |
+
"vlm.vision_tokenizer.layers.4.1.0.bias": "model-00001-of-00004.safetensors",
|
650 |
+
"vlm.vision_tokenizer.layers.4.1.0.weight": "model-00001-of-00004.safetensors",
|
651 |
+
"vlm.vision_tokenizer.layers.4.1.1.weight": "model-00001-of-00004.safetensors",
|
652 |
+
"vlm.vision_tokenizer.layers.4.1.3.weight": "model-00001-of-00004.safetensors",
|
653 |
+
"vlm.vision_tokenizer.layers.5.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
654 |
+
"vlm.vision_tokenizer.layers.5.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
655 |
+
"vlm.vision_tokenizer.layers.5.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
656 |
+
"vlm.vision_tokenizer.layers.5.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
657 |
+
"vlm.vision_tokenizer.layers.5.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
658 |
+
"vlm.vision_tokenizer.layers.5.0.to_out.weight": "model-00001-of-00004.safetensors",
|
659 |
+
"vlm.vision_tokenizer.layers.5.0.to_q.weight": "model-00001-of-00004.safetensors",
|
660 |
+
"vlm.vision_tokenizer.layers.5.1.0.bias": "model-00001-of-00004.safetensors",
|
661 |
+
"vlm.vision_tokenizer.layers.5.1.0.weight": "model-00001-of-00004.safetensors",
|
662 |
+
"vlm.vision_tokenizer.layers.5.1.1.weight": "model-00001-of-00004.safetensors",
|
663 |
+
"vlm.vision_tokenizer.layers.5.1.3.weight": "model-00001-of-00004.safetensors",
|
664 |
+
"vlm.vision_tokenizer.norm.bias": "model-00001-of-00004.safetensors",
|
665 |
+
"vlm.vision_tokenizer.norm.weight": "model-00001-of-00004.safetensors",
|
666 |
+
"vlm.vision_tokenizer.projection.bias": "model-00001-of-00004.safetensors",
|
667 |
+
"vlm.vision_tokenizer.projection.weight": "model-00001-of-00004.safetensors",
|
668 |
+
"vlm.vision_tokenizer.text_projection.0.bias": "model-00001-of-00004.safetensors",
|
669 |
+
"vlm.vision_tokenizer.text_projection.0.weight": "model-00001-of-00004.safetensors",
|
670 |
+
"vlm.vision_tokenizer.text_projection.2.bias": "model-00001-of-00004.safetensors",
|
671 |
+
"vlm.vision_tokenizer.text_projection.2.weight": "model-00001-of-00004.safetensors"
|
672 |
+
}
|
673 |
+
}
|
modeling_blip_3.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel, AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import open_clip
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
from .utils import check_embedding_fns
|
6 |
+
from .vlm import InstructPerceiverResampler, KosmosInstruct
|
7 |
+
from .configuration_blip_3 import Blip3VisionEncoderConfig, Blip3VisionTokenizerConfig, Blip3Config
|
8 |
+
|
9 |
+
class Blip3VisionEncoder(PreTrainedModel):
|
10 |
+
main_input_name = "pixel_values"
|
11 |
+
config_class = Blip3VisionEncoderConfig
|
12 |
+
|
13 |
+
def __init__(self, config: Blip3VisionEncoderConfig):
|
14 |
+
super().__init__(config)
|
15 |
+
if config.model_name != 'ViT-H-14-378-quickgelu':
|
16 |
+
raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
|
17 |
+
self.model, _, _ = open_clip.create_model_and_transforms(
|
18 |
+
model_name = config.model_name,
|
19 |
+
force_image_size=config.force_image_size
|
20 |
+
)
|
21 |
+
|
22 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
23 |
+
# assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
|
24 |
+
return self.model.encode_image(pixel_values)
|
25 |
+
|
26 |
+
|
27 |
+
# vision tokenizer
|
28 |
+
class Blip3VisionTokenizer(PreTrainedModel):
|
29 |
+
config_class = Blip3VisionTokenizerConfig
|
30 |
+
def __init__(self, config: Blip3VisionTokenizerConfig):
|
31 |
+
super().__init__(config)
|
32 |
+
self.model = InstructPerceiverResampler(
|
33 |
+
dim_llm=config.lang_embedding_dim,
|
34 |
+
dim=config.vis_feature_dim,
|
35 |
+
dim_inner=config.lang_embedding_dim,
|
36 |
+
num_latents=config.num_vis_tokens,
|
37 |
+
repeat_latents=config.repeat_latents
|
38 |
+
)
|
39 |
+
|
40 |
+
def forward(self,
|
41 |
+
vision_features: torch.Tensor,
|
42 |
+
vision_attn_masks: torch.Tensor):
|
43 |
+
return self.model(vision_features, vision_attn_masks)
|
44 |
+
|
45 |
+
# Blip3 model
|
46 |
+
class Blip3ModelForConditionalGeneration(PreTrainedModel):
|
47 |
+
config_class = Blip3Config
|
48 |
+
|
49 |
+
def __init__(self, config: Blip3Config):
|
50 |
+
super().__init__(config)
|
51 |
+
|
52 |
+
# vision encoder initialization
|
53 |
+
vision_encoder = Blip3VisionEncoder(config.vision_encoder_config).model
|
54 |
+
vision_encoder.visual.output_tokens = True
|
55 |
+
vision_encoder = vision_encoder.visual
|
56 |
+
|
57 |
+
# language model initialization
|
58 |
+
language_model = AutoModelForCausalLM.from_config(config.text_config)
|
59 |
+
check_embedding_fns(language_model)
|
60 |
+
# Update _tied_weights_keys using the base model used.
|
61 |
+
if language_model._tied_weights_keys is not None:
|
62 |
+
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
63 |
+
|
64 |
+
# vision tokenizer initialization
|
65 |
+
if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
|
66 |
+
overwrite = language_model.get_input_embeddings().weight.shape[1]
|
67 |
+
config.vision_tokenizer_config.lang_embedding_dim = overwrite
|
68 |
+
print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
|
69 |
+
|
70 |
+
vision_tokenizer = Blip3VisionTokenizer(config.vision_tokenizer_config).model
|
71 |
+
|
72 |
+
self.vlm = KosmosInstruct(
|
73 |
+
vision_encoder=vision_encoder,
|
74 |
+
vision_tokenizer=vision_tokenizer,
|
75 |
+
lang_model=language_model,
|
76 |
+
initial_tokenizer_len = config.text_config.initial_tokenizer_len,
|
77 |
+
pad_token_id = config.text_config.pad_token_id,
|
78 |
+
image_aspect_ratio = config.vision_encoder_config.image_aspect_ratio,
|
79 |
+
anyres_patch_sampling = config.vision_encoder_config.anyres_patch_sampling
|
80 |
+
)
|
81 |
+
# Initialize weights and apply final processing
|
82 |
+
self.post_init()
|
83 |
+
|
84 |
+
@torch.no_grad()
|
85 |
+
def generate(
|
86 |
+
self,
|
87 |
+
pixel_values: torch.FloatTensor,
|
88 |
+
input_ids: Optional[torch.LongTensor] = None,
|
89 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
90 |
+
**generate_kwargs,
|
91 |
+
) -> torch.LongTensor:
|
92 |
+
self.vlm = self.vlm.eval()
|
93 |
+
return self.vlm.generate(
|
94 |
+
vision_x = pixel_values,
|
95 |
+
lang_x = input_ids,
|
96 |
+
attention_mask = attention_mask,
|
97 |
+
**generate_kwargs)
|
98 |
+
|
99 |
+
def update_special_tokens(self, tokenizer):
|
100 |
+
tokenizer.add_special_tokens(
|
101 |
+
{"additional_special_tokens": list(self.vlm.special_tokens.values())}
|
102 |
+
)
|
103 |
+
self.vlm.lang_model.config.vocab_size = len(tokenizer)
|
104 |
+
self.vlm.set_special_token_ids(
|
105 |
+
{
|
106 |
+
v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
|
107 |
+
}
|
108 |
+
)
|
109 |
+
return tokenizer
|
110 |
+
|
preprocessor_config.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoImageProcessor": "image_processing_blip_3.Blip3ImageProcessor"
|
4 |
+
},
|
5 |
+
"do_resize": true,
|
6 |
+
"image_mean": [
|
7 |
+
0.48145466,
|
8 |
+
0.4578275,
|
9 |
+
0.40821073
|
10 |
+
],
|
11 |
+
"image_processor_type": "Blip3ImageProcessor",
|
12 |
+
"image_std": [
|
13 |
+
0.26862954,
|
14 |
+
0.26130258,
|
15 |
+
0.27577711
|
16 |
+
],
|
17 |
+
"interpolation_mode": "bicubic",
|
18 |
+
"resize_mode": "squash",
|
19 |
+
"size": [
|
20 |
+
378,
|
21 |
+
378
|
22 |
+
]
|
23 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "<|endoftext|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "<pad>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"unk_token": {
|
24 |
+
"content": "<unk>",
|
25 |
+
"lstrip": false,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
}
|
30 |
+
}
|
test_samples/images/1074.jpg
ADDED
test_samples/images/1148.jpg
ADDED
test_samples/images/152.jpg
ADDED
test_samples/images/1614.jpg
ADDED
test_samples/images/26302.jpg
ADDED
test_samples/images/45711.jpg
ADDED
test_samples/test.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"question": "Can you explain this meme?",
|
4 |
+
"image_path": "./test_samples/images/152.jpg"
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"question": "In the food web, what are the predators?",
|
8 |
+
"image_path": "./test_samples/images/26302.jpg"
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"question": "What is the role of Mouse?",
|
12 |
+
"image_path": "./test_samples/images/26302.jpg"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"question": "How is this image taken?",
|
16 |
+
"image_path": "./test_samples/images/1614.jpg"
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"question": "Can you identify the season in which the picture was taken?",
|
20 |
+
"image_path": "./test_samples/images/1074.jpg"
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"question": "What can be the relationship between the two persons in this image?",
|
24 |
+
"image_path": "./test_samples/images/1148.jpg"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"question": "What is this meeting about?",
|
28 |
+
"image_path": "./test_samples/images/45711.jpg"
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"question": "How many things are discussed in the meeting?",
|
32 |
+
"image_path": "./test_samples/images/45711.jpg"
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"question": "What is the 2nd agenda?",
|
36 |
+
"image_path": "./test_samples/images/45711.jpg"
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"question": "When is the next meeting held?",
|
40 |
+
"image_path": "./test_samples/images/45711.jpg"
|
41 |
+
}
|
42 |
+
]
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
3 |
+
size 499723
|
tokenizer_config.json
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<unk>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<s>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "</s>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": true,
|
26 |
+
"single_word": false,
|
27 |
+
"special": false
|
28 |
+
},
|
29 |
+
"32000": {
|
30 |
+
"content": "<|endoftext|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
},
|
37 |
+
"32001": {
|
38 |
+
"content": "<|assistant|>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": true,
|
42 |
+
"single_word": false,
|
43 |
+
"special": true
|
44 |
+
},
|
45 |
+
"32002": {
|
46 |
+
"content": "<|placeholder1|>",
|
47 |
+
"lstrip": false,
|
48 |
+
"normalized": false,
|
49 |
+
"rstrip": true,
|
50 |
+
"single_word": false,
|
51 |
+
"special": true
|
52 |
+
},
|
53 |
+
"32003": {
|
54 |
+
"content": "<|placeholder2|>",
|
55 |
+
"lstrip": false,
|
56 |
+
"normalized": false,
|
57 |
+
"rstrip": true,
|
58 |
+
"single_word": false,
|
59 |
+
"special": true
|
60 |
+
},
|
61 |
+
"32004": {
|
62 |
+
"content": "<|placeholder3|>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": true,
|
66 |
+
"single_word": false,
|
67 |
+
"special": true
|
68 |
+
},
|
69 |
+
"32005": {
|
70 |
+
"content": "<|placeholder4|>",
|
71 |
+
"lstrip": false,
|
72 |
+
"normalized": false,
|
73 |
+
"rstrip": true,
|
74 |
+
"single_word": false,
|
75 |
+
"special": true
|
76 |
+
},
|
77 |
+
"32006": {
|
78 |
+
"content": "<|system|>",
|
79 |
+
"lstrip": false,
|
80 |
+
"normalized": false,
|
81 |
+
"rstrip": true,
|
82 |
+
"single_word": false,
|
83 |
+
"special": true
|
84 |
+
},
|
85 |
+
"32007": {
|
86 |
+
"content": "<|end|>",
|
87 |
+
"lstrip": false,
|
88 |
+
"normalized": false,
|
89 |
+
"rstrip": true,
|
90 |
+
"single_word": false,
|
91 |
+
"special": true
|
92 |
+
},
|
93 |
+
"32008": {
|
94 |
+
"content": "<|placeholder5|>",
|
95 |
+
"lstrip": false,
|
96 |
+
"normalized": false,
|
97 |
+
"rstrip": true,
|
98 |
+
"single_word": false,
|
99 |
+
"special": true
|
100 |
+
},
|
101 |
+
"32009": {
|
102 |
+
"content": "<|placeholder6|>",
|
103 |
+
"lstrip": false,
|
104 |
+
"normalized": false,
|
105 |
+
"rstrip": true,
|
106 |
+
"single_word": false,
|
107 |
+
"special": true
|
108 |
+
},
|
109 |
+
"32010": {
|
110 |
+
"content": "<|user|>",
|
111 |
+
"lstrip": false,
|
112 |
+
"normalized": false,
|
113 |
+
"rstrip": true,
|
114 |
+
"single_word": false,
|
115 |
+
"special": true
|
116 |
+
},
|
117 |
+
"32011": {
|
118 |
+
"content": "<pad>",
|
119 |
+
"lstrip": false,
|
120 |
+
"normalized": false,
|
121 |
+
"rstrip": false,
|
122 |
+
"single_word": false,
|
123 |
+
"special": true
|
124 |
+
}
|
125 |
+
},
|
126 |
+
"bos_token": "<s>",
|
127 |
+
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
128 |
+
"clean_up_tokenization_spaces": false,
|
129 |
+
"eos_token": "<|endoftext|>",
|
130 |
+
"model_max_length": 4096,
|
131 |
+
"pad_token": "<pad>",
|
132 |
+
"padding_side": "left",
|
133 |
+
"sp_model_kwargs": {},
|
134 |
+
"tokenizer_class": "LlamaTokenizer",
|
135 |
+
"unk_token": "<unk>",
|
136 |
+
"use_default_system_prompt": false
|
137 |
+
}
|
utils.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import ast
|
3 |
+
import math
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
def has_fn(model, fn_name):
|
8 |
+
"""Check if model has a function fn_name"""
|
9 |
+
return callable(getattr(model, fn_name, None))
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
def num_params(module, filter_to_trainable=False):
|
15 |
+
"""Returns the number of parameters in the module, or optionally only the trainable parameters"""
|
16 |
+
if filter_to_trainable:
|
17 |
+
return sum(p.numel() for p in module.parameters() if p.requires_grad)
|
18 |
+
else:
|
19 |
+
return sum(p.numel() for p in module.parameters())
|
20 |
+
|
21 |
+
def hasattr_recursive(obj, att):
|
22 |
+
"""
|
23 |
+
Check if obj has nested attribute
|
24 |
+
Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
|
25 |
+
"""
|
26 |
+
if att == "":
|
27 |
+
return True
|
28 |
+
i = att.find(".")
|
29 |
+
if i < 0:
|
30 |
+
return hasattr(obj, att)
|
31 |
+
else:
|
32 |
+
try:
|
33 |
+
return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
34 |
+
except:
|
35 |
+
return False
|
36 |
+
|
37 |
+
def getattr_recursive(obj, att):
|
38 |
+
"""
|
39 |
+
Return nested attribute of obj
|
40 |
+
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
|
41 |
+
"""
|
42 |
+
if att == "":
|
43 |
+
return obj
|
44 |
+
i = att.find(".")
|
45 |
+
if i < 0:
|
46 |
+
return getattr(obj, att)
|
47 |
+
else:
|
48 |
+
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
49 |
+
|
50 |
+
|
51 |
+
def setattr_recursive(obj, att, val):
|
52 |
+
"""
|
53 |
+
Set nested attribute of obj
|
54 |
+
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
|
55 |
+
"""
|
56 |
+
if "." in att:
|
57 |
+
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
|
58 |
+
setattr(obj, att.split(".")[-1], val)
|
59 |
+
|
60 |
+
|
61 |
+
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
|
62 |
+
"""
|
63 |
+
Stack a list of tensors with padding on one side
|
64 |
+
Args:
|
65 |
+
list_of_tensors (list[torch.Tensor]): List of tensors to stack
|
66 |
+
padding_value (int, optional): Value to pad with. Defaults to 0.
|
67 |
+
padding_side (str, optional): Side to pad on. Defaults to "right".
|
68 |
+
Returns:
|
69 |
+
torch.Tensor: Stacked tensors
|
70 |
+
"""
|
71 |
+
max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
|
72 |
+
padded_tensors = []
|
73 |
+
for tensor in list_of_tensors:
|
74 |
+
num_tokens = tensor.size(0)
|
75 |
+
if len(tensor.size()) == 1:
|
76 |
+
padding = torch.full(
|
77 |
+
(max_tokens - num_tokens,),
|
78 |
+
padding_value,
|
79 |
+
dtype=tensor.dtype,
|
80 |
+
device=tensor.device,
|
81 |
+
)
|
82 |
+
else:
|
83 |
+
padding = torch.full(
|
84 |
+
(max_tokens - num_tokens, tensor.size(1)),
|
85 |
+
padding_value,
|
86 |
+
dtype=tensor.dtype,
|
87 |
+
device=tensor.device,
|
88 |
+
)
|
89 |
+
padded_tensor = (
|
90 |
+
torch.cat((tensor, padding), dim=0)
|
91 |
+
if padding_side == "right"
|
92 |
+
else torch.cat((padding, tensor), dim=0)
|
93 |
+
)
|
94 |
+
padded_tensors.append(padded_tensor)
|
95 |
+
return torch.stack(padded_tensors)
|
96 |
+
|
97 |
+
|
98 |
+
def check_embedding_fns(lang_model):
|
99 |
+
"""Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
|
100 |
+
if not has_fn(lang_model, "get_input_embeddings"):
|
101 |
+
if hasattr_recursive(lang_model, "transformer.wte"): # MPT
|
102 |
+
lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
|
103 |
+
elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
|
104 |
+
lang_model.get_input_embeddings = lambda: lang_model.decoder.embed_tokens
|
105 |
+
else:
|
106 |
+
raise ValueError(
|
107 |
+
"We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
|
108 |
+
)
|
109 |
+
|
110 |
+
if not has_fn(lang_model, "set_input_embeddings"):
|
111 |
+
if hasattr_recursive(lang_model, "transformer.wte"): # MPT
|
112 |
+
lang_model.set_input_embeddings = lambda x: setattr_recursive(
|
113 |
+
lang_model, "transformer.wte", x
|
114 |
+
)
|
115 |
+
elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
|
116 |
+
lang_model.set_input_embeddings = lambda x: setattr_recursive(
|
117 |
+
lang_model, "model.decoder.embed_tokens", x
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
raise ValueError(
|
121 |
+
"We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
|
122 |
+
)
|
123 |
+
|
124 |
+
if not has_fn(lang_model, "get_output_embeddings"):
|
125 |
+
if hasattr_recursive(lang_model, "lm_head"):
|
126 |
+
lang_model.get_output_embeddings = lambda: lang_model.lm_head
|
127 |
+
else:
|
128 |
+
raise ValueError(
|
129 |
+
"We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
|
130 |
+
)
|
131 |
+
|
132 |
+
if not has_fn(lang_model, "set_output_embeddings"):
|
133 |
+
if hasattr_recursive(lang_model, "lm_head"):
|
134 |
+
lang_model.set_output_embeddings = lambda x: setattr_recursive(
|
135 |
+
lang_model, "lm_head", x
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
raise ValueError(
|
139 |
+
"We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
def has_fn(model, fn_name):
|
144 |
+
"""Check if model has a function fn_name"""
|
145 |
+
return callable(getattr(model, fn_name, None))
|
146 |
+
|
147 |
+
|
148 |
+
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
149 |
+
#
|
150 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
151 |
+
# you may not use this file except in compliance with the License.
|
152 |
+
# You may obtain a copy of the License at
|
153 |
+
#
|
154 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
155 |
+
#
|
156 |
+
# Unless required by applicable law or agreed to in writing, software
|
157 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
158 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
159 |
+
# See the License for the specific language governing permissions and
|
160 |
+
# limitations under the License.
|
161 |
+
|
162 |
+
def unpad_image(tensor, original_size, keep_original_shape=False):
|
163 |
+
"""
|
164 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
168 |
+
original_size (tuple): The original size of the image (height, width).
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
torch.Tensor: The unpadded image tensor.
|
172 |
+
"""
|
173 |
+
original_width, original_height = original_size
|
174 |
+
current_height, current_width = tensor.shape[1:]
|
175 |
+
|
176 |
+
original_aspect_ratio = original_width / original_height
|
177 |
+
current_aspect_ratio = current_width / current_height
|
178 |
+
|
179 |
+
if original_aspect_ratio > current_aspect_ratio:
|
180 |
+
scale_factor = current_width / original_width
|
181 |
+
new_height = int(original_height * scale_factor)
|
182 |
+
padding = (current_height - new_height) // 2
|
183 |
+
if keep_original_shape:
|
184 |
+
attention_mask = torch.ones((current_height, current_width), device=tensor.device)
|
185 |
+
attention_mask[:padding, :] = 0
|
186 |
+
attention_mask[current_height - padding:, :] = 0
|
187 |
+
return tensor, attention_mask
|
188 |
+
else:
|
189 |
+
unpadded_tensor = tensor[:, padding:current_height - padding, :]
|
190 |
+
return unpadded_tensor, None
|
191 |
+
else:
|
192 |
+
scale_factor = current_height / original_height
|
193 |
+
new_width = int(original_width * scale_factor)
|
194 |
+
padding = (current_width - new_width) // 2
|
195 |
+
if keep_original_shape:
|
196 |
+
attention_mask = torch.ones((current_height, current_width), device=tensor.device)
|
197 |
+
attention_mask[:, :padding] = 0
|
198 |
+
attention_mask[:, current_width - padding:] = 0
|
199 |
+
return tensor, attention_mask
|
200 |
+
else:
|
201 |
+
unpadded_tensor = tensor[:, :, padding:current_width - padding]
|
202 |
+
return unpadded_tensor, None
|
203 |
+
|
204 |
+
|
205 |
+
def select_best_resolution(original_size, possible_resolutions):
|
206 |
+
"""
|
207 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
211 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
tuple: The best fit resolution in the format (width, height).
|
215 |
+
"""
|
216 |
+
original_width, original_height = original_size
|
217 |
+
best_fit = None
|
218 |
+
max_effective_resolution = 0
|
219 |
+
min_wasted_resolution = float('inf')
|
220 |
+
|
221 |
+
for width, height in possible_resolutions:
|
222 |
+
scale = min(width / original_width, height / original_height)
|
223 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
224 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
225 |
+
wasted_resolution = (width * height) - effective_resolution
|
226 |
+
|
227 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
228 |
+
max_effective_resolution = effective_resolution
|
229 |
+
min_wasted_resolution = wasted_resolution
|
230 |
+
best_fit = (width, height)
|
231 |
+
|
232 |
+
return best_fit
|
233 |
+
|
234 |
+
|
235 |
+
def resize_and_pad_image(image, target_resolution):
|
236 |
+
"""
|
237 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
image (PIL.Image.Image): The input image.
|
241 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
PIL.Image.Image: The resized and padded image.
|
245 |
+
"""
|
246 |
+
original_width, original_height = image.size
|
247 |
+
target_width, target_height = target_resolution
|
248 |
+
|
249 |
+
scale_w = target_width / original_width
|
250 |
+
scale_h = target_height / original_height
|
251 |
+
|
252 |
+
if scale_w < scale_h:
|
253 |
+
new_width = target_width
|
254 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
255 |
+
else:
|
256 |
+
new_height = target_height
|
257 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
258 |
+
|
259 |
+
# Resize the image
|
260 |
+
resized_image = image.resize((new_width, new_height))
|
261 |
+
|
262 |
+
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
|
263 |
+
paste_x = (target_width - new_width) // 2
|
264 |
+
paste_y = (target_height - new_height) // 2
|
265 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
266 |
+
|
267 |
+
return new_image
|
268 |
+
|
269 |
+
|
270 |
+
def divide_to_patches(image, patch_size):
|
271 |
+
"""
|
272 |
+
Divides an image into patches of a specified size.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
image (PIL.Image.Image): The input image.
|
276 |
+
patch_size (int): The size of each patch.
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
280 |
+
"""
|
281 |
+
patches = []
|
282 |
+
width, height = image.size
|
283 |
+
for i in range(0, height, patch_size):
|
284 |
+
for j in range(0, width, patch_size):
|
285 |
+
box = (j, i, j + patch_size, i + patch_size)
|
286 |
+
patch = image.crop(box)
|
287 |
+
patches.append(patch)
|
288 |
+
|
289 |
+
return patches
|
290 |
+
|
291 |
+
|
292 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
293 |
+
"""
|
294 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
298 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
299 |
+
patch_size (int): The size of each image patch.
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
303 |
+
"""
|
304 |
+
if type(grid_pinpoints) is list:
|
305 |
+
possible_resolutions = grid_pinpoints
|
306 |
+
else:
|
307 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
308 |
+
width, height = select_best_resolution(image_size, possible_resolutions)
|
309 |
+
return width // patch_size, height // patch_size
|
310 |
+
|
311 |
+
|
312 |
+
def process_anyres_image(image, processor, grid_pinpoints):
|
313 |
+
"""
|
314 |
+
Process an image with variable resolutions.
|
315 |
+
|
316 |
+
Args:
|
317 |
+
image (PIL.Image.Image): The input image to be processed.
|
318 |
+
processor: The image processor object.
|
319 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
torch.Tensor: A tensor containing the processed image patches.
|
323 |
+
"""
|
324 |
+
# FIXME: determine grid_pinpoints from image sizes.
|
325 |
+
if type(grid_pinpoints) is list:
|
326 |
+
possible_resolutions = grid_pinpoints
|
327 |
+
else:
|
328 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
329 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
330 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
331 |
+
|
332 |
+
processor_size = processor.transforms[0].size
|
333 |
+
patches = divide_to_patches(image_padded, processor_size[0])
|
334 |
+
|
335 |
+
image_original_resize = image.resize((processor_size[0], processor_size[0]))
|
336 |
+
|
337 |
+
image_patches = [image_original_resize] + patches
|
338 |
+
image_patches = [processor(image_patch)
|
339 |
+
for image_patch in image_patches]
|
340 |
+
return torch.stack(image_patches, dim=0)
|
341 |
+
|
342 |
+
|
343 |
+
def expand2square(pil_img, background_color):
|
344 |
+
width, height = pil_img.size
|
345 |
+
if width == height:
|
346 |
+
return pil_img
|
347 |
+
elif width > height:
|
348 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
349 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
350 |
+
return result
|
351 |
+
else:
|
352 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
353 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
354 |
+
return result
|
355 |
+
|
356 |
+
|
357 |
+
def process_images(images, image_processor, model_cfg):
|
358 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
359 |
+
new_images = []
|
360 |
+
if image_aspect_ratio == 'pad':
|
361 |
+
for image in images:
|
362 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
|
363 |
+
image = image_processor(image)
|
364 |
+
new_images.append(image)
|
365 |
+
elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
|
366 |
+
base_img_size = image_processor.transforms[0].size[0]
|
367 |
+
for image in images:
|
368 |
+
image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
|
369 |
+
[base_img_size*2,base_img_size],
|
370 |
+
[base_img_size*2,base_img_size*2],
|
371 |
+
[base_img_size*3,base_img_size],
|
372 |
+
[base_img_size,base_img_size*3]])
|
373 |
+
|
374 |
+
# Debug any res inference by only using 672x672.
|
375 |
+
# image = process_anyres_image(image, image_processor, [[base_img_size*2,base_img_size*2]])
|
376 |
+
new_images.append(image)
|
377 |
+
else:
|
378 |
+
return image_processor(images)
|
379 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
380 |
+
new_images = torch.stack(new_images, dim=0)
|
381 |
+
return new_images
|
382 |
+
|
383 |
+
|
vlm.py
ADDED
@@ -0,0 +1,1531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from torch import einsum, nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
from einops_exts import rearrange_many
|
6 |
+
from einops import rearrange
|
7 |
+
from typing import List, Optional, Tuple, Union
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from transformers import CLIPVisionModel
|
12 |
+
import transformers
|
13 |
+
|
14 |
+
from .utils import num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
|
15 |
+
|
16 |
+
|
17 |
+
class VisionTokenizer(nn.Module):
|
18 |
+
def __init__(self, dim_media, num_tokens_per_media):
|
19 |
+
super().__init__()
|
20 |
+
self.dim_media = dim_media
|
21 |
+
self.num_tokens_per_media = num_tokens_per_media
|
22 |
+
|
23 |
+
class PerceiverAttention(nn.Module):
|
24 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
25 |
+
super().__init__()
|
26 |
+
self.scale = dim_head**-0.5
|
27 |
+
self.heads = heads
|
28 |
+
inner_dim = dim_head * heads
|
29 |
+
|
30 |
+
self.norm_media = nn.LayerNorm(dim)
|
31 |
+
self.norm_latents = nn.LayerNorm(dim)
|
32 |
+
|
33 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
34 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
35 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
36 |
+
|
37 |
+
def forward(self, x, latents, vision_attn_masks=None):
|
38 |
+
"""
|
39 |
+
Args:
|
40 |
+
x (torch.Tensor): image features
|
41 |
+
shape (b, T, n1, D)
|
42 |
+
latent (torch.Tensor): latent features
|
43 |
+
shape (b, T, n2, D)
|
44 |
+
"""
|
45 |
+
x = self.norm_media(x)
|
46 |
+
latents = self.norm_latents(latents)
|
47 |
+
|
48 |
+
h = self.heads
|
49 |
+
|
50 |
+
q = self.to_q(latents)
|
51 |
+
kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
|
52 |
+
if vision_attn_masks is not None:
|
53 |
+
vision_attn_masks = torch.cat((vision_attn_masks,
|
54 |
+
torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
|
55 |
+
dim=-1)
|
56 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
57 |
+
q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
|
58 |
+
q = q * self.scale
|
59 |
+
|
60 |
+
# attention
|
61 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
62 |
+
# Apply vision attention mask here.
|
63 |
+
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
|
64 |
+
if vision_attn_masks is not None:
|
65 |
+
attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
|
66 |
+
vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
|
67 |
+
attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
|
68 |
+
sim += attn_bias
|
69 |
+
|
70 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
71 |
+
attn = sim.softmax(dim=-1)
|
72 |
+
|
73 |
+
|
74 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
75 |
+
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
|
76 |
+
return self.to_out(out)
|
77 |
+
|
78 |
+
|
79 |
+
def FeedForward(dim, mult=4):
|
80 |
+
inner_dim = int(dim * mult)
|
81 |
+
return nn.Sequential(
|
82 |
+
nn.LayerNorm(dim),
|
83 |
+
nn.Linear(dim, inner_dim, bias=False),
|
84 |
+
nn.GELU(),
|
85 |
+
nn.Linear(inner_dim, dim, bias=False),
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
class InstructPerceiverResampler(VisionTokenizer):
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
*,
|
93 |
+
dim_llm,
|
94 |
+
dim,
|
95 |
+
dim_inner=None,
|
96 |
+
depth=6,
|
97 |
+
dim_head=96,
|
98 |
+
heads=16,
|
99 |
+
num_latents=64,
|
100 |
+
repeat_latents=False,
|
101 |
+
max_num_media=None,
|
102 |
+
max_num_frames=None,
|
103 |
+
ff_mult=4,
|
104 |
+
):
|
105 |
+
"""
|
106 |
+
Perceiver module which takes in image features and outputs image tokens.
|
107 |
+
Args:
|
108 |
+
dim (int): dimension of the incoming image features
|
109 |
+
dim_inner (int, optional): final dimension to project the incoming image features to;
|
110 |
+
also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
|
111 |
+
depth (int, optional): number of layers. Defaults to 6.
|
112 |
+
dim_head (int, optional): dimension of each head. Defaults to 64.
|
113 |
+
heads (int, optional): number of heads. Defaults to 8.
|
114 |
+
num_latents (int, optional): number of latent tokens to use in the Perceiver;
|
115 |
+
also corresponds to number of tokens per sequence to output. Defaults to 64.
|
116 |
+
max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
|
117 |
+
and keep positional embeddings for. If None, no positional embeddings are used.
|
118 |
+
max_num_frames (int, optional): maximum number of frames to input into the Perceiver
|
119 |
+
and keep positional embeddings for. If None, no positional embeddings are used.
|
120 |
+
ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
|
121 |
+
"""
|
122 |
+
if dim_inner is not None:
|
123 |
+
projection = nn.Linear(dim, dim_inner)
|
124 |
+
else:
|
125 |
+
projection = None
|
126 |
+
dim_inner = dim
|
127 |
+
super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
|
128 |
+
self.projection = projection
|
129 |
+
|
130 |
+
# Text embedding projection.
|
131 |
+
# self.text_projection = nn.Linear(dim_llm, dim)
|
132 |
+
modules = [nn.Linear(dim_llm, dim)]
|
133 |
+
for _ in range(1, 2):
|
134 |
+
modules.append(nn.GELU())
|
135 |
+
modules.append(nn.Linear(dim, dim))
|
136 |
+
self.text_projection = nn.Sequential(*modules)
|
137 |
+
|
138 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
139 |
+
self.repeat_latents = repeat_latents
|
140 |
+
# positional embeddings
|
141 |
+
self.frame_embs = (
|
142 |
+
nn.Parameter(torch.randn(max_num_frames, dim))
|
143 |
+
if exists(max_num_frames)
|
144 |
+
else None
|
145 |
+
)
|
146 |
+
self.media_time_embs = (
|
147 |
+
nn.Parameter(torch.randn(max_num_media, 1, dim))
|
148 |
+
if exists(max_num_media)
|
149 |
+
else None
|
150 |
+
)
|
151 |
+
|
152 |
+
self.layers = nn.ModuleList([])
|
153 |
+
for _ in range(depth):
|
154 |
+
self.layers.append(
|
155 |
+
nn.ModuleList(
|
156 |
+
[
|
157 |
+
PerceiverAttention(
|
158 |
+
dim=dim, dim_head=dim_head, heads=heads
|
159 |
+
),
|
160 |
+
FeedForward(dim=dim, mult=ff_mult),
|
161 |
+
]
|
162 |
+
)
|
163 |
+
)
|
164 |
+
|
165 |
+
self.norm = nn.LayerNorm(dim)
|
166 |
+
# TODO: write a new forward function that takes in text input and append them to the query tokens.
|
167 |
+
def forward(self, x, text_embeds=None):
|
168 |
+
"""
|
169 |
+
Args:
|
170 |
+
x (torch.Tensor): image features
|
171 |
+
shape (b, T, F, v, D)
|
172 |
+
Returns:
|
173 |
+
shape (b, T, n, D) where n is self.num_latents
|
174 |
+
"""
|
175 |
+
b, T, F, v = x.shape[:4]
|
176 |
+
|
177 |
+
# frame and media time embeddings
|
178 |
+
if exists(self.frame_embs):
|
179 |
+
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
|
180 |
+
x = x + frame_embs
|
181 |
+
x = rearrange(
|
182 |
+
x, "b T F v d -> b T (F v) d"
|
183 |
+
) # flatten the frame and spatial dimensions
|
184 |
+
if exists(self.media_time_embs):
|
185 |
+
x = x + self.media_time_embs[:T]
|
186 |
+
|
187 |
+
# blocks
|
188 |
+
# FIXME: extending query tokens proportional to the vision sequence length. Hard-coded as dfn5b token_len=729.
|
189 |
+
if self.repeat_latents:
|
190 |
+
r = v // 729 # Repeat the query tokens for r times.
|
191 |
+
latents = repeat(self.latents, "n d -> (n repeat) d", repeat=r)
|
192 |
+
else:
|
193 |
+
latents = self.latents
|
194 |
+
latents = repeat(latents, "n d -> b T n d", b=b, T=T)
|
195 |
+
# latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
|
196 |
+
# Append text embedding.
|
197 |
+
if exists(text_embeds):
|
198 |
+
text_embeds = self.text_projection(text_embeds)
|
199 |
+
text_embeds = text_embeds[:, None, :, :]
|
200 |
+
latents = torch.cat((latents, text_embeds), dim=2) # FIXME: check latents shape.
|
201 |
+
|
202 |
+
for attn, ff in self.layers:
|
203 |
+
latents = attn(x, latents) + latents
|
204 |
+
latents = ff(latents) + latents
|
205 |
+
|
206 |
+
# Truncate latents to only keep query tokens.
|
207 |
+
if exists(text_embeds):
|
208 |
+
latents = latents[:, :, :self.latents.shape[0], :]
|
209 |
+
|
210 |
+
if exists(self.projection):
|
211 |
+
return self.projection(self.norm(latents))
|
212 |
+
else:
|
213 |
+
return self.norm(latents)
|
214 |
+
|
215 |
+
class DecoupledEmbedding(nn.Embedding):
|
216 |
+
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
|
217 |
+
"""
|
218 |
+
Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the
|
219 |
+
regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
|
220 |
+
then it will create `num_additional_embeddings` additional parameters that are always trained. If
|
221 |
+
`num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
|
222 |
+
"""
|
223 |
+
|
224 |
+
def __init__(
|
225 |
+
self,
|
226 |
+
max_original_id: int,
|
227 |
+
num_additional_embeddings: int = 0,
|
228 |
+
_weight: torch.Tensor = None,
|
229 |
+
num_original_embeddings: int = None,
|
230 |
+
embedding_dim: int = None,
|
231 |
+
partially_freeze=True,
|
232 |
+
device=None,
|
233 |
+
dtype=None,
|
234 |
+
pad_token_id=None,
|
235 |
+
) -> None:
|
236 |
+
"""
|
237 |
+
Args:
|
238 |
+
max_original_id (`int`):
|
239 |
+
The largest token id that should be embedded using the regular embedding (regular `weight`).
|
240 |
+
This is usually len(tokenizer) - 1 before additional tokens are added.
|
241 |
+
Note that this may not equal self.weight.shape[0]
|
242 |
+
num_additional_embeddings (`int`):
|
243 |
+
Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
|
244 |
+
_weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
|
245 |
+
If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
|
246 |
+
num_original_embeddings (`int`):
|
247 |
+
self.weight.shape[0]
|
248 |
+
embedding_dim (`int`):
|
249 |
+
The size of each embedding vector
|
250 |
+
partially_freeze: (`bool`, *optional*, defaults to `True`):
|
251 |
+
If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
|
252 |
+
padding_idx (`int`, *optional*):
|
253 |
+
The padding index (needs to be less than num_embeddings)
|
254 |
+
|
255 |
+
Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
|
256 |
+
`max_norm` or `norm_type`. We are not supporting these.
|
257 |
+
"""
|
258 |
+
# validate args
|
259 |
+
if pad_token_id is not None and pad_token_id > max_original_id:
|
260 |
+
raise ValueError(
|
261 |
+
f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
|
262 |
+
+ "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
|
263 |
+
)
|
264 |
+
if _weight is not None:
|
265 |
+
assert (num_original_embeddings is None) or (
|
266 |
+
_weight.shape[0] == num_original_embeddings
|
267 |
+
), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
|
268 |
+
assert (embedding_dim is None) or (
|
269 |
+
_weight.shape[1] == embedding_dim
|
270 |
+
), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
|
271 |
+
num_original_embeddings = _weight.shape[0]
|
272 |
+
embedding_dim = _weight.shape[1]
|
273 |
+
else:
|
274 |
+
assert (
|
275 |
+
num_original_embeddings is not None
|
276 |
+
), "num_original_embeddings must be provided if _weight is not provided"
|
277 |
+
assert (
|
278 |
+
embedding_dim is not None
|
279 |
+
), "embedding_dim must be provided if _weight is not provided"
|
280 |
+
|
281 |
+
super().__init__(
|
282 |
+
num_embeddings=num_original_embeddings,
|
283 |
+
embedding_dim=embedding_dim,
|
284 |
+
device=device,
|
285 |
+
dtype=dtype,
|
286 |
+
padding_idx=pad_token_id,
|
287 |
+
_weight=_weight,
|
288 |
+
)
|
289 |
+
self.max_original_id = max_original_id
|
290 |
+
self.padding_idx = pad_token_id
|
291 |
+
self.num_additional_embeddings = num_additional_embeddings
|
292 |
+
if self.num_additional_embeddings > 0:
|
293 |
+
self.additional_embedding = nn.Embedding(
|
294 |
+
num_embeddings=self.num_additional_embeddings,
|
295 |
+
embedding_dim=embedding_dim,
|
296 |
+
device=device,
|
297 |
+
dtype=dtype,
|
298 |
+
)
|
299 |
+
self.set_requires_grad(
|
300 |
+
require_regular_grad=not partially_freeze, require_additional_grad=True
|
301 |
+
)
|
302 |
+
|
303 |
+
def set_requires_grad(self, require_regular_grad, require_additional_grad):
|
304 |
+
"""
|
305 |
+
Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
|
306 |
+
"""
|
307 |
+
self.weight.requires_grad_(require_regular_grad)
|
308 |
+
self.additional_embedding.requires_grad_(require_additional_grad)
|
309 |
+
|
310 |
+
def forward(self, input_ids):
|
311 |
+
"""
|
312 |
+
we have 2 embeddings, with different indices - one pretrained self.weight and another
|
313 |
+
self.additional_embedding.weight that is being trained.
|
314 |
+
|
315 |
+
in order to make a lookup of the input ids, we:
|
316 |
+
1. find out the indices of the entries belonging to the 2nd embedding
|
317 |
+
2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
|
318 |
+
embedding starts from 0 and not num_embeddings
|
319 |
+
3. perform the 2nd embedding lookup
|
320 |
+
4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
|
321 |
+
5. perform the 1st embedding lookup
|
322 |
+
6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
|
323 |
+
|
324 |
+
note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
|
325 |
+
then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
|
326 |
+
i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
|
327 |
+
usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
|
328 |
+
measure.
|
329 |
+
|
330 |
+
"""
|
331 |
+
if self.num_additional_embeddings == 0:
|
332 |
+
return F.embedding(input_ids, self.weight)
|
333 |
+
|
334 |
+
# Clone so that we don't modify the original input_ids later on
|
335 |
+
input_ids = input_ids.clone()
|
336 |
+
additional_vocab_indices = torch.where(input_ids > self.max_original_id)
|
337 |
+
input_ids_additional_vocab = input_ids[additional_vocab_indices]
|
338 |
+
additional_embeddings = self.additional_embedding(
|
339 |
+
input_ids_additional_vocab - self.max_original_id - 1
|
340 |
+
)
|
341 |
+
|
342 |
+
# for successful lookup replace input_ids with 0, the results of these will be discarded anyway
|
343 |
+
input_ids[additional_vocab_indices] = 0
|
344 |
+
full_vector = F.embedding(input_ids, self.weight)
|
345 |
+
|
346 |
+
# overwrite the records with high indices
|
347 |
+
full_vector[additional_vocab_indices] = additional_embeddings
|
348 |
+
|
349 |
+
return full_vector
|
350 |
+
|
351 |
+
def extra_repr(self) -> str:
|
352 |
+
return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
|
353 |
+
self.max_original_id + 1,
|
354 |
+
self.num_additional_embeddings,
|
355 |
+
self.embedding_dim,
|
356 |
+
(not self.weight.requires_grad),
|
357 |
+
)
|
358 |
+
|
359 |
+
|
360 |
+
class DecoupledLinear(nn.Linear):
|
361 |
+
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
|
362 |
+
"""
|
363 |
+
Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the
|
364 |
+
regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
|
365 |
+
then it will create `additional_out_features * in_features` additional parameters that are always trained. If
|
366 |
+
`additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(
|
370 |
+
self,
|
371 |
+
max_original_id: int,
|
372 |
+
additional_out_features: int = 0,
|
373 |
+
_weight: torch.Tensor = None,
|
374 |
+
_bias: torch.Tensor = None,
|
375 |
+
in_features: int = None,
|
376 |
+
original_out_features: int = None,
|
377 |
+
bias: bool = True,
|
378 |
+
partially_freeze: bool = True,
|
379 |
+
device=None,
|
380 |
+
dtype=None,
|
381 |
+
) -> None:
|
382 |
+
"""
|
383 |
+
Args:
|
384 |
+
max_original_id (`int`): The largest token id that should be extracted from the regular weight.
|
385 |
+
This is usually len(tokenizer) - 1 before additional tokens are added.
|
386 |
+
Note that this may not equal original_out_features - 1
|
387 |
+
_weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
|
388 |
+
If provided, this sets the `in_features` and `original_out_features` parameters.
|
389 |
+
_bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
|
390 |
+
in_features: int. Input hidden size.
|
391 |
+
original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
|
392 |
+
additional_out_features: int. Number of additional trainable dimensions.
|
393 |
+
bias: bool. Whether to include a bias term.
|
394 |
+
partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
|
395 |
+
"""
|
396 |
+
# argument validation
|
397 |
+
if _weight is not None:
|
398 |
+
assert (_weight.shape[0] == original_out_features) or (
|
399 |
+
original_out_features is None
|
400 |
+
), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
|
401 |
+
assert (_weight.shape[1] == in_features) or (
|
402 |
+
in_features is None
|
403 |
+
), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
|
404 |
+
in_features = _weight.shape[1]
|
405 |
+
original_out_features = _weight.shape[0]
|
406 |
+
else:
|
407 |
+
assert (
|
408 |
+
in_features is not None
|
409 |
+
), "in_features must be provided if _weight is not provided"
|
410 |
+
assert (
|
411 |
+
original_out_features is not None
|
412 |
+
), "original_out_features must be provided if _weight is not provided"
|
413 |
+
|
414 |
+
if _bias is not None:
|
415 |
+
assert bias is True, "bias must be True if _bias is provided"
|
416 |
+
|
417 |
+
# initialize original linear
|
418 |
+
super().__init__(
|
419 |
+
in_features,
|
420 |
+
original_out_features,
|
421 |
+
bias,
|
422 |
+
device,
|
423 |
+
dtype)
|
424 |
+
|
425 |
+
# set weight and bias manually
|
426 |
+
if _weight is not None:
|
427 |
+
self.weight = nn.Parameter(_weight)
|
428 |
+
if _bias is not None:
|
429 |
+
self.bias = nn.Parameter(_bias)
|
430 |
+
|
431 |
+
self.in_features = in_features
|
432 |
+
self.original_out_features = original_out_features
|
433 |
+
self.max_original_id = max_original_id
|
434 |
+
|
435 |
+
# initialize additional linear
|
436 |
+
self.additional_out_features = additional_out_features
|
437 |
+
self.has_bias = bias
|
438 |
+
if additional_out_features > 0:
|
439 |
+
self.additional_fc = nn.Linear(
|
440 |
+
in_features=in_features,
|
441 |
+
out_features=additional_out_features,
|
442 |
+
bias=self.has_bias,
|
443 |
+
device=device,
|
444 |
+
dtype=dtype,
|
445 |
+
)
|
446 |
+
self.set_requires_grad(
|
447 |
+
require_regular_grad=not partially_freeze, require_additional_grad=True
|
448 |
+
)
|
449 |
+
|
450 |
+
def set_requires_grad(self, require_regular_grad, require_additional_grad):
|
451 |
+
"""
|
452 |
+
Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
|
453 |
+
"""
|
454 |
+
self.weight.requires_grad_(require_regular_grad)
|
455 |
+
if self.has_bias:
|
456 |
+
self.bias.requires_grad_(require_regular_grad)
|
457 |
+
self.additional_fc.requires_grad_(require_additional_grad)
|
458 |
+
|
459 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
460 |
+
output = F.linear(input, self.weight, self.bias)
|
461 |
+
output = output[..., : self.max_original_id + 1]
|
462 |
+
|
463 |
+
if self.additional_out_features > 0:
|
464 |
+
additional_features = F.linear(
|
465 |
+
input, self.additional_fc.weight, self.additional_fc.bias
|
466 |
+
)
|
467 |
+
output = torch.cat((output, additional_features), -1)
|
468 |
+
return output
|
469 |
+
|
470 |
+
def extra_repr(self) -> str:
|
471 |
+
"""Overwriting `nn.Linear.extra_repr` to include new parameters."""
|
472 |
+
return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
|
473 |
+
self.in_features,
|
474 |
+
self.max_original_id + 1,
|
475 |
+
self.additional_out_features,
|
476 |
+
self.bias is not None,
|
477 |
+
(not self.weight.requires_grad or not self.bias.requires_grad),
|
478 |
+
)
|
479 |
+
|
480 |
+
class VLM(nn.Module):
|
481 |
+
"""
|
482 |
+
Generic vision-language model (VLM) class.
|
483 |
+
A VLM consists of four components:
|
484 |
+
1. A vision encoder that extracts features from pixels, e.g. CLIP
|
485 |
+
input: (B, T_img, F, C, H, W)
|
486 |
+
output: (B, T_img, F, v, d)
|
487 |
+
2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
|
488 |
+
input: (B, T_img, F, v, d)
|
489 |
+
output: (B, T_img, n, d)
|
490 |
+
3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
|
491 |
+
4. A language model
|
492 |
+
"""
|
493 |
+
|
494 |
+
def __init__(
|
495 |
+
self,
|
496 |
+
vision_encoder: nn.Module,
|
497 |
+
vision_tokenizer: nn.Module,
|
498 |
+
lang_model: nn.Module,
|
499 |
+
initial_tokenizer_len: int,
|
500 |
+
pad_token_id: int,
|
501 |
+
gradient_checkpointing: bool = False,
|
502 |
+
):
|
503 |
+
"""
|
504 |
+
Args:
|
505 |
+
vision_encoder (nn.Module): e.g. CLIP
|
506 |
+
vision_tokenizer (nn.Module): e.g. PerceiverResampler
|
507 |
+
lang_model (nn.Module): e.g. MPT
|
508 |
+
initial_tokenizer_len (int): size of the original tokenizer vocab
|
509 |
+
pad_token_id (int): id of the pad token
|
510 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
|
511 |
+
"""
|
512 |
+
super().__init__()
|
513 |
+
|
514 |
+
# save dimension information
|
515 |
+
self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
|
516 |
+
if hasattr(lang_model.config, "d_model"):
|
517 |
+
self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
|
518 |
+
else:
|
519 |
+
self.lang_hidden_dim = lang_model.config.hidden_size
|
520 |
+
self.vis_embedding_dim = vision_tokenizer.dim_media
|
521 |
+
self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
|
522 |
+
|
523 |
+
# core components
|
524 |
+
self.vision_encoder = vision_encoder
|
525 |
+
self.vision_tokenizer = vision_tokenizer
|
526 |
+
self.lang_model = lang_model
|
527 |
+
|
528 |
+
# lm embeddings
|
529 |
+
self.pad_token_id = pad_token_id
|
530 |
+
self.initial_tokenizer_len = initial_tokenizer_len
|
531 |
+
input_embeds = DecoupledEmbedding(
|
532 |
+
max_original_id=initial_tokenizer_len - 1,
|
533 |
+
num_additional_embeddings=len(self.special_tokens),
|
534 |
+
_weight=self.lang_model.get_input_embeddings().weight,
|
535 |
+
pad_token_id=self.pad_token_id,
|
536 |
+
)
|
537 |
+
if hasattr(input_embeds, "additional_embedding"):
|
538 |
+
input_embeds.additional_embedding.weight.data.normal_(
|
539 |
+
mean=0.0,
|
540 |
+
std=self.lang_model.config.initializer_range
|
541 |
+
if hasattr(self.lang_model.config, "initializer_range")
|
542 |
+
else 0.02,
|
543 |
+
)
|
544 |
+
self.lang_model.set_input_embeddings(input_embeds)
|
545 |
+
|
546 |
+
out_embeds = DecoupledLinear(
|
547 |
+
max_original_id=initial_tokenizer_len - 1,
|
548 |
+
additional_out_features=len(self.special_tokens),
|
549 |
+
_weight=self.lang_model.get_output_embeddings().weight,
|
550 |
+
_bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
|
551 |
+
)
|
552 |
+
if hasattr(out_embeds, "additional_fc"):
|
553 |
+
out_embeds.additional_fc.weight.data.normal_(
|
554 |
+
mean=0.0,
|
555 |
+
std=self.lang_model.config.initializer_range
|
556 |
+
if hasattr(self.lang_model.config, "initializer_range")
|
557 |
+
else 0.02,
|
558 |
+
)
|
559 |
+
self.lang_model.set_output_embeddings(out_embeds)
|
560 |
+
|
561 |
+
# gradient checkpointing
|
562 |
+
self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
|
563 |
+
|
564 |
+
def forward(
|
565 |
+
self,
|
566 |
+
vision_x: Optional[torch.Tensor],
|
567 |
+
lang_x: torch.Tensor,
|
568 |
+
attention_mask: Optional[torch.Tensor] = None,
|
569 |
+
labels: Optional[torch.Tensor] = None,
|
570 |
+
past_key_values: Optional[
|
571 |
+
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
|
572 |
+
] = None,
|
573 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
574 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
575 |
+
use_cache: Optional[bool] = False,
|
576 |
+
**kwargs,
|
577 |
+
):
|
578 |
+
"""
|
579 |
+
Args:
|
580 |
+
vision_x: Vision input
|
581 |
+
shape (B, T_img, F, C, H, W) with F=1
|
582 |
+
only F = 1 is supported (single-frame videos)
|
583 |
+
if T_img > the number of media tokens in the corresponding input_ids (lang_x),
|
584 |
+
only the first number of media tokens in lang_x are used
|
585 |
+
lang_x: Language input ids, with media tokens denoting where
|
586 |
+
visual media should be inserted.
|
587 |
+
shape (B, T_txt)
|
588 |
+
attention_mask: Attention mask. Defaults to None.
|
589 |
+
labels: Labels. Defaults to None.
|
590 |
+
shape (B, T_txt)
|
591 |
+
past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
|
592 |
+
list of length = number of decoder layers in the LM
|
593 |
+
exact implementation depends on LM, see Hugging Face docs
|
594 |
+
past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
|
595 |
+
shape (B, T_txt)
|
596 |
+
past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
|
597 |
+
use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
|
598 |
+
If True, includes key_values, media_locations, and vision_tokens in the output.
|
599 |
+
"""
|
600 |
+
assert not (past_vision_tokens is None) ^ (
|
601 |
+
past_media_locations is None
|
602 |
+
), "past_vision_tokens and past_media_locations must both be None or both be not None"
|
603 |
+
|
604 |
+
# convert pixels to vision tokens
|
605 |
+
if vision_x is not None:
|
606 |
+
vision_features = self._encode_vision_x(vision_x=vision_x)
|
607 |
+
vision_tokens = self.vision_tokenizer(vision_features)
|
608 |
+
else:
|
609 |
+
vision_tokens = None
|
610 |
+
|
611 |
+
# fuse the vision and language tokens
|
612 |
+
new_inputs = self._prepare_inputs_for_forward(
|
613 |
+
vision_tokens=vision_tokens,
|
614 |
+
lang_x=lang_x,
|
615 |
+
attention_mask=attention_mask,
|
616 |
+
labels=labels,
|
617 |
+
past_key_values=past_key_values,
|
618 |
+
past_media_locations=past_media_locations,
|
619 |
+
padding_side="right",
|
620 |
+
past_vision_tokens=past_vision_tokens,
|
621 |
+
)
|
622 |
+
output = self.lang_model(
|
623 |
+
**new_inputs,
|
624 |
+
use_cache=use_cache,
|
625 |
+
past_key_values=past_key_values,
|
626 |
+
**kwargs,
|
627 |
+
)
|
628 |
+
|
629 |
+
# postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
|
630 |
+
# or to add the past_vision_tokens and past_media_locations to the output
|
631 |
+
output = self._postprocess_outputs_from_forward(
|
632 |
+
output=output,
|
633 |
+
lang_x=lang_x,
|
634 |
+
vision_tokens=vision_tokens,
|
635 |
+
use_cache=use_cache,
|
636 |
+
past_vision_tokens=past_vision_tokens,
|
637 |
+
past_media_locations=past_media_locations,
|
638 |
+
)
|
639 |
+
|
640 |
+
# postforward hooks
|
641 |
+
self._post_forward_hook()
|
642 |
+
return output
|
643 |
+
|
644 |
+
def _encode_vision_x_anyres(self, samples, device):
|
645 |
+
image_raw = samples["image"] # list of patch list in of shape [1, N_patch, C, H, W]
|
646 |
+
image_sizes = samples["image_size"]
|
647 |
+
|
648 |
+
# concate list of patches into one big patch for any res encoding.
|
649 |
+
images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
|
650 |
+
image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
|
651 |
+
image = image.to(device)
|
652 |
+
|
653 |
+
with torch.no_grad():
|
654 |
+
if self.vision_encoder.__class__.__name__ == "TimmModel":
|
655 |
+
image_embeds = self.vision_encoder.trunk.forward_features(image)
|
656 |
+
elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
|
657 |
+
image_embeds = self.vision_encoder(image).last_hidden_state
|
658 |
+
else:
|
659 |
+
image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
|
660 |
+
|
661 |
+
if isinstance(self.vision_encoder, CLIPVisionModel):
|
662 |
+
base_img_size = self.vision_encoder.config.image_size
|
663 |
+
else:
|
664 |
+
base_img_size = self.vision_encoder.image_size[0]
|
665 |
+
|
666 |
+
if self.vision_encoder.__class__.__name__ == "TimmModel":
|
667 |
+
grid_size = self.vision_encoder.trunk.patch_embed.grid_size
|
668 |
+
elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
|
669 |
+
grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
|
670 |
+
grid_size = (grid_size_base, grid_size_base)
|
671 |
+
else:
|
672 |
+
grid_size = self.vision_encoder.grid_size
|
673 |
+
height, width = grid_size
|
674 |
+
|
675 |
+
if not image_embeds.shape[1] == height * width:
|
676 |
+
assert image_embeds.shape[1] == height * width + 1 # For vision encoders that has [CLS] token.
|
677 |
+
image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
|
678 |
+
n_vis_token_per_patch = image_embeds.shape[1]
|
679 |
+
|
680 |
+
# Split encoded patches and merge patch features
|
681 |
+
# 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
|
682 |
+
split_sizes = [image.shape[0] for image in images]
|
683 |
+
image_embeds = torch.split(image_embeds, split_sizes, dim=0)
|
684 |
+
# 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
|
685 |
+
new_image_embeds = []
|
686 |
+
patch_attn_masks = []
|
687 |
+
max_n_img_token = -1
|
688 |
+
for idx, patch_embeds in enumerate(image_embeds):
|
689 |
+
if patch_embeds.shape[0] > 1:
|
690 |
+
# 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
|
691 |
+
base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for th base patch embeds (of the resized entire image).
|
692 |
+
patch_embeds = patch_embeds[1:]
|
693 |
+
|
694 |
+
assert height * width == base_patch_embeds.shape[0]
|
695 |
+
|
696 |
+
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
|
697 |
+
[[base_img_size,base_img_size*2],
|
698 |
+
[base_img_size*2,base_img_size],
|
699 |
+
[base_img_size*2,base_img_size*2],
|
700 |
+
[base_img_size*3,base_img_size],
|
701 |
+
[base_img_size,base_img_size*3]],
|
702 |
+
base_img_size) # Hardcoded grid_pinpoints.
|
703 |
+
patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
|
704 |
+
|
705 |
+
patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
|
706 |
+
patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
|
707 |
+
# TODO: add an option that return masked patch_embeds instead of trimmed.
|
708 |
+
patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
|
709 |
+
if hasattr(self, 'image_newline'):
|
710 |
+
patch_embeds = torch.cat((
|
711 |
+
patch_embeds,
|
712 |
+
self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
|
713 |
+
), dim=-1)
|
714 |
+
if self.anyres_patch_sampling:
|
715 |
+
patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
|
716 |
+
patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
|
717 |
+
assert patch_attn_mask is not None
|
718 |
+
patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
|
719 |
+
patch_attn_mask = patch_attn_mask.flatten(0, 1)
|
720 |
+
patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
|
721 |
+
patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
|
722 |
+
else:
|
723 |
+
patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
|
724 |
+
patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
|
725 |
+
else:
|
726 |
+
patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
|
727 |
+
patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
|
728 |
+
if hasattr(self, 'image_newline'):
|
729 |
+
patch_embeds = torch.cat((
|
730 |
+
patch_embeds,
|
731 |
+
self.image_newline[None]
|
732 |
+
), dim=0)
|
733 |
+
if not self.anyres_patch_sampling:
|
734 |
+
max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
|
735 |
+
|
736 |
+
new_image_embeds.append(patch_embeds)
|
737 |
+
patch_attn_masks.append(patch_attn_mask)
|
738 |
+
|
739 |
+
if self.anyres_patch_sampling:
|
740 |
+
# Return individual patches for independent token downsampling.
|
741 |
+
return new_image_embeds, patch_attn_masks
|
742 |
+
|
743 |
+
# 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
|
744 |
+
image_embeds = []
|
745 |
+
image_atts = []
|
746 |
+
for image_embed in new_image_embeds:
|
747 |
+
n_img_token = image_embed.shape[0]
|
748 |
+
img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
|
749 |
+
if n_img_token < max_n_img_token:
|
750 |
+
padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
|
751 |
+
padded_embed[:n_img_token, :] = image_embed
|
752 |
+
img_attn[n_img_token:] = 0 # Mask out the padded entries.
|
753 |
+
else:
|
754 |
+
padded_embed = image_embed
|
755 |
+
image_embeds.append(padded_embed)
|
756 |
+
image_atts.append(img_attn)
|
757 |
+
image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
|
758 |
+
image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest, C_dim]
|
759 |
+
# TODO: reshape image_embeds and image_atts to "b T F v d"
|
760 |
+
image_embeds = image_embeds[:, None, None, :, :]
|
761 |
+
# image_atts = image_atts[:, None, None, :, :]
|
762 |
+
|
763 |
+
return image_embeds, image_atts
|
764 |
+
|
765 |
+
def _encode_vision_x(self, vision_x: torch.Tensor):
|
766 |
+
"""
|
767 |
+
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
768 |
+
Args:
|
769 |
+
vision_x: Vision input
|
770 |
+
shape (B, T_img, F, C, H, W)
|
771 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
772 |
+
Currently only F=1 is supported (single-frame videos)
|
773 |
+
|
774 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
775 |
+
"""
|
776 |
+
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
777 |
+
b, T, F = vision_x.shape[:3]
|
778 |
+
|
779 |
+
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
780 |
+
with torch.no_grad():
|
781 |
+
if self.vision_encoder.__class__.__name__ == "TimmModel":
|
782 |
+
vision_x = self.vision_encoder.trunk.forward_features(vision_x)
|
783 |
+
elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
|
784 |
+
vision_x = self.vision_encoder(vision_x).last_hidden_state
|
785 |
+
else:
|
786 |
+
vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
|
787 |
+
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
788 |
+
return vision_x
|
789 |
+
|
790 |
+
def _concat_vision_cache(
|
791 |
+
self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
|
792 |
+
):
|
793 |
+
"""
|
794 |
+
Helper function to include the past vision tokens and past media locations in the output.
|
795 |
+
"""
|
796 |
+
if use_cache:
|
797 |
+
if past_media_locations is not None and past_vision_tokens is not None:
|
798 |
+
if vision_tokens is not None:
|
799 |
+
updated_vision_tokens = torch.cat(
|
800 |
+
[
|
801 |
+
past_vision_tokens,
|
802 |
+
vision_tokens,
|
803 |
+
],
|
804 |
+
dim=1,
|
805 |
+
)
|
806 |
+
else:
|
807 |
+
updated_vision_tokens = past_vision_tokens
|
808 |
+
updated_media_locations = torch.cat(
|
809 |
+
[
|
810 |
+
past_media_locations,
|
811 |
+
lang_x == self.media_token_id,
|
812 |
+
],
|
813 |
+
dim=1,
|
814 |
+
)
|
815 |
+
else:
|
816 |
+
updated_vision_tokens = vision_tokens
|
817 |
+
updated_media_locations = lang_x == self.media_token_id
|
818 |
+
|
819 |
+
else:
|
820 |
+
updated_vision_tokens = None
|
821 |
+
updated_media_locations = None
|
822 |
+
|
823 |
+
return updated_vision_tokens, updated_media_locations
|
824 |
+
|
825 |
+
def generate(
|
826 |
+
self,
|
827 |
+
vision_x: torch.Tensor,
|
828 |
+
lang_x: torch.Tensor,
|
829 |
+
attention_mask: torch.Tensor = None,
|
830 |
+
past_key_values: Optional[
|
831 |
+
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
|
832 |
+
] = None,
|
833 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
834 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
835 |
+
**kwargs,
|
836 |
+
):
|
837 |
+
"""
|
838 |
+
Generate text conditioned on vision and language inputs.
|
839 |
+
Args:
|
840 |
+
vision_x (torch.Tensor): Vision input
|
841 |
+
shape (B, T_img, F, C, H, W)
|
842 |
+
see documentation for forward
|
843 |
+
lang_x (torch.Tensor): Language input
|
844 |
+
shape (B, T_txt)
|
845 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
846 |
+
**kwargs: see generate documentation in Hugging Face CausalLM models.
|
847 |
+
Returns:
|
848 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
849 |
+
"""
|
850 |
+
num_beams = kwargs.pop("num_beams", 1)
|
851 |
+
|
852 |
+
# convert pixels to vision tokens
|
853 |
+
if vision_x is not None:
|
854 |
+
vision_features = self._encode_vision_x(vision_x=vision_x)
|
855 |
+
vision_tokens = self.vision_tokenizer(vision_features)
|
856 |
+
else:
|
857 |
+
vision_tokens = None
|
858 |
+
|
859 |
+
# fuse the vision and language tokens
|
860 |
+
# for xattn, vision_x and media_location are repeat_interleaved s.t.
|
861 |
+
# the total batch size is B * num_beams
|
862 |
+
new_inputs = self._prepare_inputs_for_forward(
|
863 |
+
vision_tokens=vision_tokens,
|
864 |
+
lang_x=lang_x,
|
865 |
+
attention_mask=attention_mask,
|
866 |
+
past_key_values=past_key_values,
|
867 |
+
past_media_locations=past_media_locations,
|
868 |
+
past_vision_tokens=past_vision_tokens,
|
869 |
+
padding_side="left",
|
870 |
+
num_beams=num_beams,
|
871 |
+
)
|
872 |
+
output = self.lang_model.generate(
|
873 |
+
**new_inputs,
|
874 |
+
past_key_values=past_key_values,
|
875 |
+
num_beams=num_beams,
|
876 |
+
use_cache=True,
|
877 |
+
**kwargs,
|
878 |
+
)
|
879 |
+
self._post_forward_hook()
|
880 |
+
return output
|
881 |
+
|
882 |
+
@property
|
883 |
+
def num_trainable_params(self):
|
884 |
+
"""Print the number of trainable parameters"""
|
885 |
+
return num_params(self, filter_to_trainable=True)
|
886 |
+
|
887 |
+
def set_trainable(self):
|
888 |
+
"""
|
889 |
+
Freeze appropriate parameters in the model.
|
890 |
+
"""
|
891 |
+
raise NotImplementedError
|
892 |
+
|
893 |
+
def group_params_by_weight_decay(self):
|
894 |
+
"""
|
895 |
+
Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
|
896 |
+
"""
|
897 |
+
params_with_wd, params_without_wd = [], []
|
898 |
+
for n, p in self.named_parameters():
|
899 |
+
if p.requires_grad:
|
900 |
+
if self._should_apply_weight_decay(n):
|
901 |
+
params_with_wd.append(p)
|
902 |
+
else:
|
903 |
+
params_without_wd.append(p)
|
904 |
+
return params_with_wd, params_without_wd
|
905 |
+
|
906 |
+
def _should_apply_weight_decay(self, parameter_name):
|
907 |
+
"""
|
908 |
+
Return whether weight decay should be applied to a parameter.
|
909 |
+
"""
|
910 |
+
raise NotImplementedError
|
911 |
+
|
912 |
+
@property
|
913 |
+
def special_tokens(self):
|
914 |
+
"""
|
915 |
+
Returns a dict mapping from the attribute name of a special token to its string format,
|
916 |
+
e.g. "media_token": "<image>"
|
917 |
+
"""
|
918 |
+
assert (
|
919 |
+
"media_token" in self._special_tokens
|
920 |
+
), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
|
921 |
+
return self._special_tokens
|
922 |
+
|
923 |
+
@property
|
924 |
+
def special_token_ids(self):
|
925 |
+
"""
|
926 |
+
Returns a list of the special token ids
|
927 |
+
"""
|
928 |
+
return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
|
929 |
+
|
930 |
+
def set_special_token_ids(self, string_to_ids):
|
931 |
+
"""
|
932 |
+
Args:
|
933 |
+
string_to_ids (dict): mapping from token string to id
|
934 |
+
"""
|
935 |
+
assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
|
936 |
+
for att_name, token_str in self.special_tokens.items():
|
937 |
+
token_id = string_to_ids[token_str]
|
938 |
+
setattr(self, f"{att_name}_id", token_id)
|
939 |
+
setattr(self.lang_model, f"{att_name}_id", token_id)
|
940 |
+
|
941 |
+
def init_gradient_checkpointing(self):
|
942 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
943 |
+
checkpoint_wrapper,
|
944 |
+
CheckpointWrapper,
|
945 |
+
CheckpointImpl,
|
946 |
+
apply_activation_checkpointing,
|
947 |
+
)
|
948 |
+
from functools import partial
|
949 |
+
|
950 |
+
non_reentrant_wrapper = partial(
|
951 |
+
checkpoint_wrapper,
|
952 |
+
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
953 |
+
)
|
954 |
+
apply_activation_checkpointing(
|
955 |
+
self,
|
956 |
+
checkpoint_wrapper_fn=non_reentrant_wrapper,
|
957 |
+
check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
|
958 |
+
and not isinstance(m, CheckpointWrapper),
|
959 |
+
)
|
960 |
+
|
961 |
+
@dataclass
|
962 |
+
class VLMOutputWithPast(CausalLMOutputWithPast):
|
963 |
+
"""
|
964 |
+
VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
|
965 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
966 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
967 |
+
"""
|
968 |
+
|
969 |
+
past_media_locations: Optional[torch.Tensor] = None
|
970 |
+
past_vision_tokens: Optional[torch.Tensor] = None
|
971 |
+
|
972 |
+
|
973 |
+
def exists(val):
|
974 |
+
return val is not None
|
975 |
+
|
976 |
+
|
977 |
+
def FeedForward(dim, mult=4):
|
978 |
+
inner_dim = int(dim * mult)
|
979 |
+
return nn.Sequential(
|
980 |
+
nn.LayerNorm(dim),
|
981 |
+
nn.Linear(dim, inner_dim, bias=False),
|
982 |
+
nn.GELU(),
|
983 |
+
nn.Linear(inner_dim, dim, bias=False),
|
984 |
+
)
|
985 |
+
|
986 |
+
class VLMWithLanguageStream(VLM):
|
987 |
+
"""
|
988 |
+
VLM that fuses modalities by inserting vision tokens directly into the language stream.
|
989 |
+
"""
|
990 |
+
|
991 |
+
def __init__(
|
992 |
+
self,
|
993 |
+
vision_encoder: nn.Module,
|
994 |
+
vision_tokenizer: nn.Module,
|
995 |
+
lang_model: nn.Module,
|
996 |
+
initial_tokenizer_len: int,
|
997 |
+
pad_token_id: int,
|
998 |
+
decoder_layers_attr_name: str = None,
|
999 |
+
gradient_checkpointing: bool = False,
|
1000 |
+
):
|
1001 |
+
super().__init__(
|
1002 |
+
vision_encoder=vision_encoder,
|
1003 |
+
vision_tokenizer=vision_tokenizer,
|
1004 |
+
lang_model=lang_model,
|
1005 |
+
initial_tokenizer_len=initial_tokenizer_len,
|
1006 |
+
pad_token_id=pad_token_id,
|
1007 |
+
gradient_checkpointing=gradient_checkpointing,
|
1008 |
+
)
|
1009 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
1010 |
+
if decoder_layers_attr_name is not None:
|
1011 |
+
for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
|
1012 |
+
block._use_gradient_checkpointing = gradient_checkpointing
|
1013 |
+
|
1014 |
+
def _prepare_inputs_for_forward(
|
1015 |
+
self,
|
1016 |
+
vision_tokens: torch.Tensor,
|
1017 |
+
lang_x: torch.Tensor,
|
1018 |
+
attention_mask: torch.Tensor,
|
1019 |
+
labels: torch.Tensor = None,
|
1020 |
+
past_key_values=None,
|
1021 |
+
vision_attention_mask: Optional[torch.Tensor] = None,
|
1022 |
+
past_media_locations: torch.Tensor = None,
|
1023 |
+
past_vision_tokens: torch.Tensor = None,
|
1024 |
+
padding_side: str = "left",
|
1025 |
+
num_beams: int = 1,
|
1026 |
+
):
|
1027 |
+
"""
|
1028 |
+
Insert the vision tokens directly into the language stream/
|
1029 |
+
This requires us to modify the input_ids, attention_mask, and labels.
|
1030 |
+
"""
|
1031 |
+
if past_key_values is not None:
|
1032 |
+
past_len = past_key_values[0][0].shape[2]
|
1033 |
+
assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
|
1034 |
+
"Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
|
1035 |
+
+ "Check that you've expanded the attention mask to account for past image tokens."
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
if vision_tokens is None:
|
1039 |
+
return {
|
1040 |
+
"input_ids": lang_x,
|
1041 |
+
"attention_mask": attention_mask,
|
1042 |
+
"labels": labels,
|
1043 |
+
}
|
1044 |
+
|
1045 |
+
# get the language embeddings
|
1046 |
+
lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
|
1047 |
+
|
1048 |
+
# build up the multimodal embeddings
|
1049 |
+
B = lang_x.shape[0]
|
1050 |
+
has_labels = labels is not None
|
1051 |
+
multimodal_embeds = []
|
1052 |
+
multimodal_attention_mask = []
|
1053 |
+
multimodal_labels = [] if has_labels else None
|
1054 |
+
for i in range(B):
|
1055 |
+
# get index of <image> tokens in lang_x[i]
|
1056 |
+
image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
|
1057 |
+
|
1058 |
+
if len(image_token_idxs) == 0:
|
1059 |
+
multimodal_embeds.append(lang_embeds[i].clone())
|
1060 |
+
multimodal_attention_mask.append(attention_mask[i].clone())
|
1061 |
+
if has_labels:
|
1062 |
+
multimodal_labels.append(labels[i].clone())
|
1063 |
+
continue
|
1064 |
+
|
1065 |
+
# since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
|
1066 |
+
for j, img_idx in enumerate(image_token_idxs):
|
1067 |
+
image_token_idxs[j] += (self.num_tokens_per_vis - 1) * j # FIXME: different offset for any resolution encoding when has multiple images.
|
1068 |
+
|
1069 |
+
# loop through the image_token_idxs and insert the vision tokens
|
1070 |
+
new_embed = lang_embeds[i].clone()
|
1071 |
+
new_attention_mask = (
|
1072 |
+
attention_mask[i].clone() if attention_mask is not None else None
|
1073 |
+
)
|
1074 |
+
if has_labels:
|
1075 |
+
new_label = labels[i].clone()
|
1076 |
+
|
1077 |
+
for img_num, img_idx in enumerate(image_token_idxs):
|
1078 |
+
if img_num > 0:
|
1079 |
+
# FIXME: hardcoded as such to avoid assertion error, but this only works for single image samples.
|
1080 |
+
break
|
1081 |
+
# Get vision token attention mask for padded llava-style any resolution image tokens.
|
1082 |
+
if self.image_aspect_ratio =='anyres':
|
1083 |
+
num_vis_tokens = vision_tokens[i][img_num].shape[0]
|
1084 |
+
if vision_attention_mask is not None:
|
1085 |
+
vis_attention_mask = vision_attention_mask[i]
|
1086 |
+
else:
|
1087 |
+
vis_attention_mask = torch.ones(
|
1088 |
+
num_vis_tokens, dtype=torch.long
|
1089 |
+
).to(attention_mask.device)
|
1090 |
+
else:
|
1091 |
+
assert (
|
1092 |
+
vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
|
1093 |
+
), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
|
1094 |
+
vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
|
1095 |
+
# By default, vision tokens are not padded.
|
1096 |
+
num_vis_tokens = self.num_tokens_per_vis
|
1097 |
+
vis_attention_mask = torch.ones(
|
1098 |
+
num_vis_tokens, dtype=torch.long
|
1099 |
+
).to(attention_mask.device)
|
1100 |
+
|
1101 |
+
|
1102 |
+
new_embed = torch.cat(
|
1103 |
+
(
|
1104 |
+
new_embed[:img_idx],
|
1105 |
+
vision_tokens[i][img_num],
|
1106 |
+
new_embed[img_idx + 1 :],
|
1107 |
+
),
|
1108 |
+
dim=0,
|
1109 |
+
)
|
1110 |
+
new_attention_mask = torch.cat(
|
1111 |
+
(
|
1112 |
+
new_attention_mask[:img_idx],
|
1113 |
+
vis_attention_mask,
|
1114 |
+
new_attention_mask[img_idx + 1 :],
|
1115 |
+
),
|
1116 |
+
dim=0,
|
1117 |
+
)
|
1118 |
+
if has_labels:
|
1119 |
+
new_label = torch.cat(
|
1120 |
+
(
|
1121 |
+
new_label[:img_idx],
|
1122 |
+
torch.ones(num_vis_tokens, dtype=torch.long).to(
|
1123 |
+
labels.device
|
1124 |
+
)
|
1125 |
+
* -100,
|
1126 |
+
new_label[img_idx + 1 :],
|
1127 |
+
),
|
1128 |
+
dim=0,
|
1129 |
+
)
|
1130 |
+
multimodal_embeds.append(new_embed)
|
1131 |
+
multimodal_attention_mask.append(new_attention_mask)
|
1132 |
+
if has_labels:
|
1133 |
+
multimodal_labels.append(new_label)
|
1134 |
+
|
1135 |
+
# stack
|
1136 |
+
multimodal_embeds = stack_with_padding(
|
1137 |
+
multimodal_embeds,
|
1138 |
+
padding_value=self.pad_token_id,
|
1139 |
+
padding_side=padding_side,
|
1140 |
+
)
|
1141 |
+
multimodal_attention_mask = stack_with_padding(
|
1142 |
+
multimodal_attention_mask,
|
1143 |
+
padding_value=0,
|
1144 |
+
padding_side=padding_side,
|
1145 |
+
)
|
1146 |
+
if has_labels:
|
1147 |
+
multimodal_labels = stack_with_padding(
|
1148 |
+
multimodal_labels,
|
1149 |
+
padding_value=-100,
|
1150 |
+
padding_side=padding_side,
|
1151 |
+
)
|
1152 |
+
|
1153 |
+
return {
|
1154 |
+
"inputs_embeds": multimodal_embeds,
|
1155 |
+
"attention_mask": multimodal_attention_mask,
|
1156 |
+
"labels": multimodal_labels,
|
1157 |
+
}
|
1158 |
+
|
1159 |
+
def _postprocess_outputs_from_forward(
|
1160 |
+
self,
|
1161 |
+
output: CausalLMOutputWithPast,
|
1162 |
+
lang_x: torch.Tensor,
|
1163 |
+
vision_tokens: torch.Tensor,
|
1164 |
+
past_vision_tokens: torch.Tensor,
|
1165 |
+
past_media_locations: torch.Tensor,
|
1166 |
+
use_cache: bool = False,
|
1167 |
+
):
|
1168 |
+
# Include the past vision tokens and past media locations in the output
|
1169 |
+
updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
|
1170 |
+
lang_x=lang_x,
|
1171 |
+
vision_tokens=vision_tokens,
|
1172 |
+
past_vision_tokens=past_vision_tokens,
|
1173 |
+
past_media_locations=past_media_locations,
|
1174 |
+
use_cache=use_cache,
|
1175 |
+
)
|
1176 |
+
|
1177 |
+
# return logits that are the same shape as the original input_ids
|
1178 |
+
logits = output.logits
|
1179 |
+
batch_logits = []
|
1180 |
+
B, T_txt = lang_x.shape
|
1181 |
+
for i in range(B):
|
1182 |
+
sequence_logits = []
|
1183 |
+
logits_j = 0
|
1184 |
+
for j in range(T_txt):
|
1185 |
+
if lang_x[i, j] != self.media_token_id:
|
1186 |
+
sequence_logits.append(logits[i, logits_j])
|
1187 |
+
logits_j += 1
|
1188 |
+
else:
|
1189 |
+
# append the logit for the first image token, then skip over the rest
|
1190 |
+
# note: the model actually learns to predict <im_patch>, not <image>
|
1191 |
+
sequence_logits.append(logits[i, logits_j])
|
1192 |
+
logits_j += self.num_tokens_per_vis
|
1193 |
+
sequence_logits = torch.stack(sequence_logits, dim=0) # (B, vocab_size)
|
1194 |
+
batch_logits.append(sequence_logits)
|
1195 |
+
|
1196 |
+
batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
|
1197 |
+
# The final logits shape should be the same as the original input_ids shape
|
1198 |
+
assert batch_logits.shape[:2] == (B, T_txt)
|
1199 |
+
|
1200 |
+
# assemble the output
|
1201 |
+
output = VLMOutputWithPast(
|
1202 |
+
loss=output.loss,
|
1203 |
+
logits=batch_logits,
|
1204 |
+
past_key_values=output.past_key_values,
|
1205 |
+
hidden_states=output.hidden_states,
|
1206 |
+
attentions=output.attentions,
|
1207 |
+
past_media_locations=updated_media_locations,
|
1208 |
+
past_vision_tokens=updated_vision_tokens,
|
1209 |
+
)
|
1210 |
+
|
1211 |
+
return output
|
1212 |
+
|
1213 |
+
def _post_forward_hook(self):
|
1214 |
+
pass
|
1215 |
+
|
1216 |
+
|
1217 |
+
@property
|
1218 |
+
def num_params_per_module(self):
|
1219 |
+
"""Print the number of parameters per module in the model"""
|
1220 |
+
return "\n".join(
|
1221 |
+
[
|
1222 |
+
f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
|
1223 |
+
f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
|
1224 |
+
f"Language model: {num_params(self.lang_model):,} parameters",
|
1225 |
+
]
|
1226 |
+
)
|
1227 |
+
|
1228 |
+
@property
|
1229 |
+
def num_trainable_params_per_module(self):
|
1230 |
+
"""Print the number of trainable parameters per module in the model"""
|
1231 |
+
return "\n".join(
|
1232 |
+
[
|
1233 |
+
f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
|
1234 |
+
f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
|
1235 |
+
f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
|
1236 |
+
]
|
1237 |
+
)
|
1238 |
+
|
1239 |
+
|
1240 |
+
class KosmosInstruct(VLMWithLanguageStream):
|
1241 |
+
def __init__(
|
1242 |
+
self,
|
1243 |
+
vision_encoder: nn.Module,
|
1244 |
+
vision_tokenizer: nn.Module,
|
1245 |
+
lang_model: nn.Module,
|
1246 |
+
initial_tokenizer_len: int,
|
1247 |
+
pad_token_id: int,
|
1248 |
+
decoder_layers_attr_name: str = None,
|
1249 |
+
gradient_checkpointing: bool = False,
|
1250 |
+
image_aspect_ratio: str = 'pad',
|
1251 |
+
anyres_patch_sampling: bool = False
|
1252 |
+
):
|
1253 |
+
"""
|
1254 |
+
Args:
|
1255 |
+
vision_encoder (nn.Module): HF CLIPModel
|
1256 |
+
lang_encoder (nn.Module): HF causal language model
|
1257 |
+
vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder
|
1258 |
+
initial_tokenizer_len (int): size of the tokenizer vocab
|
1259 |
+
padding_token_id (int): id of the padding token. None if no padding token; then a padding token
|
1260 |
+
will be inserted into self.special_tokens, which factory.py fills after creating new tokens
|
1261 |
+
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
|
1262 |
+
gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
|
1263 |
+
"""
|
1264 |
+
self._special_tokens = {
|
1265 |
+
"media_token": "<image>",
|
1266 |
+
"image_placeholder_token": "<image placeholder>",
|
1267 |
+
"end_of_trunk_token": "<|endofchunk|>",
|
1268 |
+
}
|
1269 |
+
lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
|
1270 |
+
super().__init__(
|
1271 |
+
vision_encoder=vision_encoder,
|
1272 |
+
vision_tokenizer=vision_tokenizer,
|
1273 |
+
lang_model=lang_model,
|
1274 |
+
initial_tokenizer_len=initial_tokenizer_len,
|
1275 |
+
gradient_checkpointing=gradient_checkpointing,
|
1276 |
+
decoder_layers_attr_name=decoder_layers_attr_name,
|
1277 |
+
pad_token_id=pad_token_id,
|
1278 |
+
)
|
1279 |
+
self.image_aspect_ratio = image_aspect_ratio
|
1280 |
+
self.anyres_patch_sampling = anyres_patch_sampling
|
1281 |
+
|
1282 |
+
def set_trainable(self):
|
1283 |
+
"""
|
1284 |
+
Unfreeze everything except the vision_encoder
|
1285 |
+
"""
|
1286 |
+
self.requires_grad_(True)
|
1287 |
+
self.vision_encoder.requires_grad_(False)
|
1288 |
+
|
1289 |
+
def _should_apply_weight_decay(self, parameter_name):
|
1290 |
+
"""
|
1291 |
+
Kosmos applies 0.01 weight deacy to everything
|
1292 |
+
"""
|
1293 |
+
return True
|
1294 |
+
|
1295 |
+
def forward(
|
1296 |
+
self,
|
1297 |
+
vision_x: Optional[torch.Tensor],
|
1298 |
+
lang_x: torch.Tensor,
|
1299 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1300 |
+
labels: Optional[torch.Tensor] = None,
|
1301 |
+
image_size: Optional[Tuple] = None,
|
1302 |
+
past_key_values: Optional[
|
1303 |
+
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
|
1304 |
+
] = None,
|
1305 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
1306 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
1307 |
+
use_cache: Optional[bool] = False,
|
1308 |
+
**kwargs,
|
1309 |
+
):
|
1310 |
+
"""
|
1311 |
+
Args:
|
1312 |
+
vision_x: Vision input
|
1313 |
+
shape (B, T_img, F, C, H, W) with F=1
|
1314 |
+
only F = 1 is supported (single-frame videos)
|
1315 |
+
if T_img > the number of media tokens in the corresponding input_ids (lang_x),
|
1316 |
+
only the first number of media tokens in lang_x are used
|
1317 |
+
lang_x: Language input ids, with media tokens denoting where
|
1318 |
+
visual media should be inserted.
|
1319 |
+
shape (B, T_txt)
|
1320 |
+
attention_mask: Attention mask. Defaults to None.
|
1321 |
+
labels: Labels. Defaults to None.
|
1322 |
+
shape (B, T_txt)
|
1323 |
+
past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
|
1324 |
+
list of length = number of decoder layers in the LM
|
1325 |
+
exact implementation depends on LM, see Hugging Face docs
|
1326 |
+
past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
|
1327 |
+
shape (B, T_txt)
|
1328 |
+
past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
|
1329 |
+
use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
|
1330 |
+
If True, includes key_values, media_locations, and vision_tokens in the output.
|
1331 |
+
"""
|
1332 |
+
assert not (past_vision_tokens is None) ^ (
|
1333 |
+
past_media_locations is None
|
1334 |
+
), "past_vision_tokens and past_media_locations must both be None or both be not None"
|
1335 |
+
|
1336 |
+
# convert pixels to vision tokens
|
1337 |
+
vision_attention_mask = None
|
1338 |
+
if vision_x is not None:
|
1339 |
+
if self.image_aspect_ratio == 'anyres':
|
1340 |
+
input_dict = dict(image=vision_x, image_size=image_size)
|
1341 |
+
vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
|
1342 |
+
else:
|
1343 |
+
vision_features = self._encode_vision_x(vision_x=vision_x)
|
1344 |
+
vision_attn_masks = None
|
1345 |
+
if self.anyres_patch_sampling:
|
1346 |
+
split_sizes = [feature.shape[0] for feature in vision_features]
|
1347 |
+
vision_features = torch.cat(vision_features, dim=0)
|
1348 |
+
vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
|
1349 |
+
vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
|
1350 |
+
# Prepare text embeds for instruction-aware image query sampling.
|
1351 |
+
# FIXME: for debugging purposed, truncating text input to vision tokenizer to be 256 at max.
|
1352 |
+
lang_x_truncated = lang_x[:, :256]
|
1353 |
+
text_embeds = self.lang_model.get_input_embeddings()(lang_x_truncated)
|
1354 |
+
# TODO: repeat text_embeds to match the number of patches for each image patch group.
|
1355 |
+
if self.anyres_patch_sampling:
|
1356 |
+
repeated_text_embeds = []
|
1357 |
+
for i, np in enumerate(split_sizes):
|
1358 |
+
repeated_text_embeds.append(text_embeds[i].repeat(np, 1, 1))
|
1359 |
+
text_embeds = torch.cat(repeated_text_embeds, dim=0)
|
1360 |
+
vision_tokens = self.vision_tokenizer(vision_features, text_embeds)
|
1361 |
+
|
1362 |
+
# Post-processing: Split the batches into groups of patches and concatenate them together.
|
1363 |
+
if self.anyres_patch_sampling:
|
1364 |
+
vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
|
1365 |
+
max_n_vis_token = max([vis.shape[0]*vis.shape[-2] for vis in vision_token_groups])
|
1366 |
+
# Padding.
|
1367 |
+
padded_vision_tokens = []
|
1368 |
+
padded_attn_masks = []
|
1369 |
+
for patch_vis_tokens in vision_token_groups:
|
1370 |
+
patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
|
1371 |
+
n_vis_token = patch_vis_tokens.shape[0]
|
1372 |
+
patch_attn = torch.ones((max_n_vis_token), dtype=torch.long, device=patch_vis_tokens.device)
|
1373 |
+
if n_vis_token < max_n_vis_token:
|
1374 |
+
padded_vis_token = torch.zeros((max_n_vis_token, patch_vis_tokens.shape[-1]),
|
1375 |
+
dtype=patch_vis_tokens.dtype, device=patch_vis_tokens.device)
|
1376 |
+
padded_vis_token[:n_vis_token, :] = patch_vis_tokens
|
1377 |
+
patch_attn[n_vis_token:] = 0
|
1378 |
+
else:
|
1379 |
+
padded_vis_token = patch_vis_tokens
|
1380 |
+
padded_vision_tokens.append(padded_vis_token)
|
1381 |
+
padded_attn_masks.append(patch_attn)
|
1382 |
+
vision_tokens = torch.stack(padded_vision_tokens, dim=0)
|
1383 |
+
vision_attention_mask = torch.stack(padded_attn_masks, dim=0)
|
1384 |
+
vision_tokens = vision_tokens[:, None, :, :]
|
1385 |
+
else:
|
1386 |
+
vision_tokens = None
|
1387 |
+
|
1388 |
+
# fuse the vision and language tokens
|
1389 |
+
new_inputs = self._prepare_inputs_for_forward(
|
1390 |
+
vision_tokens=vision_tokens,
|
1391 |
+
lang_x=lang_x,
|
1392 |
+
attention_mask=attention_mask,
|
1393 |
+
vision_attention_mask=vision_attention_mask,
|
1394 |
+
labels=labels,
|
1395 |
+
past_key_values=past_key_values,
|
1396 |
+
past_media_locations=past_media_locations,
|
1397 |
+
padding_side="right",
|
1398 |
+
past_vision_tokens=past_vision_tokens,
|
1399 |
+
)
|
1400 |
+
output = self.lang_model(
|
1401 |
+
**new_inputs,
|
1402 |
+
use_cache=use_cache,
|
1403 |
+
past_key_values=past_key_values,
|
1404 |
+
**kwargs,
|
1405 |
+
)
|
1406 |
+
|
1407 |
+
# postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
|
1408 |
+
# or to add the past_vision_tokens and past_media_locations to the output
|
1409 |
+
output = self._postprocess_outputs_from_forward(
|
1410 |
+
output=output,
|
1411 |
+
lang_x=lang_x,
|
1412 |
+
vision_tokens=vision_tokens,
|
1413 |
+
use_cache=use_cache,
|
1414 |
+
past_vision_tokens=past_vision_tokens,
|
1415 |
+
past_media_locations=past_media_locations,
|
1416 |
+
)
|
1417 |
+
|
1418 |
+
# postforward hooks
|
1419 |
+
self._post_forward_hook()
|
1420 |
+
return output
|
1421 |
+
|
1422 |
+
def generate(
|
1423 |
+
self,
|
1424 |
+
vision_x: torch.Tensor,
|
1425 |
+
lang_x: torch.Tensor,
|
1426 |
+
image_size: Optional[Tuple] = None,
|
1427 |
+
attention_mask: torch.Tensor = None,
|
1428 |
+
past_key_values: Optional[
|
1429 |
+
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
|
1430 |
+
] = None,
|
1431 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
1432 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
1433 |
+
**kwargs,
|
1434 |
+
):
|
1435 |
+
"""
|
1436 |
+
Generate text conditioned on vision and language inputs.
|
1437 |
+
Args:
|
1438 |
+
vision_x (torch.Tensor): Vision input
|
1439 |
+
shape (B, T_img, F, C, H, W)
|
1440 |
+
see documentation for forward
|
1441 |
+
lang_x (torch.Tensor): Language input
|
1442 |
+
shape (B, T_txt)
|
1443 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
1444 |
+
**kwargs: see generate documentation in Hugging Face CausalLM models.
|
1445 |
+
Returns:
|
1446 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
1447 |
+
"""
|
1448 |
+
num_beams = kwargs.pop("num_beams", 1)
|
1449 |
+
|
1450 |
+
# convert pixels to vision tokens
|
1451 |
+
vision_attention_mask = None
|
1452 |
+
if vision_x is not None:
|
1453 |
+
if self.image_aspect_ratio == 'anyres':
|
1454 |
+
input_dict = dict(image=vision_x, image_size=image_size)
|
1455 |
+
vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
|
1456 |
+
else:
|
1457 |
+
vision_features = self._encode_vision_x(vision_x=vision_x)
|
1458 |
+
vision_attn_masks = None
|
1459 |
+
if self.anyres_patch_sampling:
|
1460 |
+
split_sizes = [feature.shape[0] for feature in vision_features]
|
1461 |
+
vision_features = torch.cat(vision_features, dim=0)
|
1462 |
+
vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
|
1463 |
+
vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
|
1464 |
+
# Prepare text embeds for instruction-aware image query sampling.
|
1465 |
+
lang_x_truncated = lang_x[:, :256]
|
1466 |
+
text_embeds = self.lang_model.get_input_embeddings()(lang_x_truncated) # FIXME: check function calling.
|
1467 |
+
# Repeat text_embeds to match the number of patches for each image patch group.
|
1468 |
+
if self.anyres_patch_sampling:
|
1469 |
+
repeated_text_embeds = []
|
1470 |
+
for i, np in enumerate(split_sizes):
|
1471 |
+
repeated_text_embeds.append(text_embeds[i].repeat(np, 1, 1))
|
1472 |
+
text_embeds = torch.cat(repeated_text_embeds, dim=0)
|
1473 |
+
vision_tokens = self.vision_tokenizer(vision_features, text_embeds)
|
1474 |
+
|
1475 |
+
# Post-processing: Split the batches into groups of patches and concatenate them together.
|
1476 |
+
if self.anyres_patch_sampling:
|
1477 |
+
vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
|
1478 |
+
max_n_vis_token = max([vis.shape[0]*vis.shape[-2] for vis in vision_token_groups])
|
1479 |
+
# Padding.
|
1480 |
+
padded_vision_tokens = []
|
1481 |
+
padded_attn_masks = []
|
1482 |
+
for patch_vis_tokens in vision_token_groups:
|
1483 |
+
patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
|
1484 |
+
n_vis_token = patch_vis_tokens.shape[0]
|
1485 |
+
patch_attn = torch.ones((max_n_vis_token), dtype=torch.long, device=patch_vis_tokens.device)
|
1486 |
+
if n_vis_token < max_n_vis_token:
|
1487 |
+
padded_vis_token = torch.zeros((max_n_vis_token, patch_vis_tokens.shape[-1]),
|
1488 |
+
dtype=patch_vis_tokens.dtype, device=patch_vis_tokens.device)
|
1489 |
+
padded_vis_token[:n_vis_token, :] = patch_vis_tokens
|
1490 |
+
patch_attn[n_vis_token:] = 0
|
1491 |
+
else:
|
1492 |
+
padded_vis_token = patch_vis_tokens
|
1493 |
+
padded_vision_tokens.append(padded_vis_token)
|
1494 |
+
padded_attn_masks.append(patch_attn)
|
1495 |
+
vision_tokens = torch.stack(padded_vision_tokens, dim=0)
|
1496 |
+
vision_attention_mask = torch.stack(padded_attn_masks, dim=0)
|
1497 |
+
vision_tokens = vision_tokens[:, None, :, :]
|
1498 |
+
else:
|
1499 |
+
vision_tokens = None
|
1500 |
+
|
1501 |
+
# fuse the vision and language tokens
|
1502 |
+
# for xattn, vision_x and media_location are repeat_interleaved s.t.
|
1503 |
+
# the total batch size is B * num_beams
|
1504 |
+
new_inputs = self._prepare_inputs_for_forward(
|
1505 |
+
vision_tokens=vision_tokens,
|
1506 |
+
lang_x=lang_x,
|
1507 |
+
attention_mask=attention_mask,
|
1508 |
+
vision_attention_mask=vision_attention_mask,
|
1509 |
+
past_key_values=past_key_values,
|
1510 |
+
past_media_locations=past_media_locations,
|
1511 |
+
past_vision_tokens=past_vision_tokens,
|
1512 |
+
padding_side="left",
|
1513 |
+
num_beams=num_beams,
|
1514 |
+
)
|
1515 |
+
if transformers.__version__ == '4.41.0.dev0':
|
1516 |
+
output = self.lang_model.generate(
|
1517 |
+
**new_inputs,
|
1518 |
+
num_beams=num_beams,
|
1519 |
+
use_cache=True,
|
1520 |
+
**kwargs,
|
1521 |
+
)
|
1522 |
+
else:
|
1523 |
+
output = self.lang_model.generate(
|
1524 |
+
**new_inputs,
|
1525 |
+
past_key_values=past_key_values,
|
1526 |
+
num_beams=num_beams,
|
1527 |
+
use_cache=True,
|
1528 |
+
**kwargs,
|
1529 |
+
)
|
1530 |
+
self._post_forward_hook()
|
1531 |
+
return output
|