Upload 204 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- llava-1.7.0.dev0.dist-info/INSTALLER +1 -0
- llava-1.7.0.dev0.dist-info/LICENSE +201 -0
- llava-1.7.0.dev0.dist-info/METADATA +266 -0
- llava-1.7.0.dev0.dist-info/RECORD +204 -0
- llava-1.7.0.dev0.dist-info/REQUESTED +0 -0
- llava-1.7.0.dev0.dist-info/WHEEL +5 -0
- llava-1.7.0.dev0.dist-info/direct_url.json +1 -0
- llava-1.7.0.dev0.dist-info/top_level.txt +2 -0
- llava/__init__.py +1 -0
- llava/__pycache__/__init__.cpython-39.pyc +0 -0
- llava/__pycache__/constants.cpython-39.pyc +0 -0
- llava/__pycache__/conversation.cpython-39.pyc +0 -0
- llava/__pycache__/mm_utils.cpython-39.pyc +0 -0
- llava/__pycache__/utils.cpython-39.pyc +0 -0
- llava/constants.py +12 -0
- llava/conversation.py +577 -0
- llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc +0 -0
- llava/eval/__pycache__/model_vqa.cpython-39.pyc +0 -0
- llava/eval/evaluate_interleave.py +339 -0
- llava/eval/model_vqa.py +240 -0
- llava/mm_utils.py +395 -0
- llava/model/__init__.py +16 -0
- llava/model/__pycache__/__init__.cpython-39.pyc +0 -0
- llava/model/__pycache__/apply_delta.cpython-39.pyc +0 -0
- llava/model/__pycache__/builder.cpython-39.pyc +0 -0
- llava/model/__pycache__/consolidate.cpython-39.pyc +0 -0
- llava/model/__pycache__/llava_arch.cpython-39.pyc +0 -0
- llava/model/__pycache__/make_delta.cpython-39.pyc +0 -0
- llava/model/__pycache__/utils.cpython-39.pyc +0 -0
- llava/model/apply_delta.py +47 -0
- llava/model/builder.py +301 -0
- llava/model/consolidate.py +30 -0
- llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc +0 -0
- llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc +0 -0
- llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc +0 -0
- llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc +0 -0
- llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc +0 -0
- llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc +0 -0
- llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc +0 -0
- llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc +0 -0
- llava/model/language_model/llava_gemma.py +122 -0
- llava/model/language_model/llava_llama.py +156 -0
- llava/model/language_model/llava_mistral.py +127 -0
- llava/model/language_model/llava_mixtral.py +143 -0
- llava/model/language_model/llava_mpt.py +105 -0
- llava/model/language_model/llava_qwen.py +149 -0
- llava/model/language_model/llava_qwen_moe.py +149 -0
- llava/model/language_model/modeling_llama.py +1649 -0
- llava/model/llava_arch.py +509 -0
- llava/model/make_delta.py +52 -0
llava-1.7.0.dev0.dist-info/INSTALLER
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
pip
|
llava-1.7.0.dev0.dist-info/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
llava-1.7.0.dev0.dist-info/METADATA
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: llava
|
3 |
+
Version: 1.7.0.dev0
|
4 |
+
Summary: LLaVA OneVision: The Next Generation of LLaVA with Better Image and Video Understanding Capabilities
|
5 |
+
Project-URL: Homepage, https://llava-vl.github.io
|
6 |
+
Project-URL: Bug Tracker, https://github.com/haotian-liu/LLaVA/issues
|
7 |
+
Classifier: Programming Language :: Python :: 3
|
8 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
9 |
+
Requires-Python: >=3.8
|
10 |
+
Description-Content-Type: text/markdown
|
11 |
+
License-File: LICENSE
|
12 |
+
Provides-Extra: standalone
|
13 |
+
Requires-Dist: shortuuid ; extra == 'standalone'
|
14 |
+
Requires-Dist: httpx ==0.24.0 ; extra == 'standalone'
|
15 |
+
Requires-Dist: einops ; extra == 'standalone'
|
16 |
+
Requires-Dist: ftfy ; extra == 'standalone'
|
17 |
+
Provides-Extra: train
|
18 |
+
Requires-Dist: llava[standalone] ; extra == 'train'
|
19 |
+
Requires-Dist: numpy ==1.26.1 ; extra == 'train'
|
20 |
+
Requires-Dist: open-clip-torch ; extra == 'train'
|
21 |
+
Requires-Dist: fastapi ; extra == 'train'
|
22 |
+
Requires-Dist: gradio ==3.35.2 ; extra == 'train'
|
23 |
+
Requires-Dist: markdown2[all] ; extra == 'train'
|
24 |
+
Requires-Dist: numpy ; extra == 'train'
|
25 |
+
Requires-Dist: requests ; extra == 'train'
|
26 |
+
Requires-Dist: sentencepiece ; extra == 'train'
|
27 |
+
Requires-Dist: torch ==2.1.2 ; extra == 'train'
|
28 |
+
Requires-Dist: torchvision ==0.16.2 ; extra == 'train'
|
29 |
+
Requires-Dist: uvicorn ; extra == 'train'
|
30 |
+
Requires-Dist: wandb ; extra == 'train'
|
31 |
+
Requires-Dist: deepspeed ==0.14.2 ; extra == 'train'
|
32 |
+
Requires-Dist: peft ==0.4.0 ; extra == 'train'
|
33 |
+
Requires-Dist: accelerate >=0.29.1 ; extra == 'train'
|
34 |
+
Requires-Dist: tokenizers ~=0.15.2 ; extra == 'train'
|
35 |
+
Requires-Dist: transformers @ git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4 ; extra == 'train'
|
36 |
+
Requires-Dist: bitsandbytes ==0.41.0 ; extra == 'train'
|
37 |
+
Requires-Dist: scikit-learn ==1.2.2 ; extra == 'train'
|
38 |
+
Requires-Dist: sentencepiece ~=0.1.99 ; extra == 'train'
|
39 |
+
Requires-Dist: einops ==0.6.1 ; extra == 'train'
|
40 |
+
Requires-Dist: einops-exts ==0.0.4 ; extra == 'train'
|
41 |
+
Requires-Dist: gradio-client ==0.2.9 ; extra == 'train'
|
42 |
+
Requires-Dist: urllib3 <=2.0.0 ; extra == 'train'
|
43 |
+
Requires-Dist: datasets ==2.16.1 ; extra == 'train'
|
44 |
+
Requires-Dist: pydantic ==1.10.8 ; extra == 'train'
|
45 |
+
Requires-Dist: timm ; extra == 'train'
|
46 |
+
Requires-Dist: hf-transfer ; extra == 'train'
|
47 |
+
Requires-Dist: opencv-python ; extra == 'train'
|
48 |
+
Requires-Dist: av ; extra == 'train'
|
49 |
+
Requires-Dist: decord ; extra == 'train'
|
50 |
+
Requires-Dist: tyro ; extra == 'train'
|
51 |
+
Requires-Dist: scipy ; extra == 'train'
|
52 |
+
|
53 |
+
<p align="center" width="100%">
|
54 |
+
<img src="https://i.postimg.cc/pL17YtG4/WX20240508-220230-2x.png" width="80%" height="80%">
|
55 |
+
</p>
|
56 |
+
|
57 |
+
# LLaVA-NeXT: Open Large Multimodal Models
|
58 |
+
[![Static Badge](https://img.shields.io/badge/llava_onevision-paper-green)](https://arxiv.org/abs/2408.03326)
|
59 |
+
[![llava_next-blog](https://img.shields.io/badge/llava_next-blog-green)](https://llava-vl.github.io/blog/)
|
60 |
+
|
61 |
+
[![llava_onevision-demo](https://img.shields.io/badge/llava_onevision-demo-red)](https://llava-onevision.lmms-lab.com/)
|
62 |
+
[![llava_next-interleave_demo](https://img.shields.io/badge/llava_next-interleave_demo-red)](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo)
|
63 |
+
[![llava_next-video_demo](https://img.shields.io/badge/llava_next-video_demo-red)](https://huggingface.co/spaces/WildVision/vision-arena)
|
64 |
+
|
65 |
+
[![llava_onevision-checkpoints](https://img.shields.io/badge/llava_onevision-checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-onevision-66a259c3526e15166d6bba37)
|
66 |
+
[![llava_next-interleave_checkpoints](https://img.shields.io/badge/llava_next-interleave_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1)
|
67 |
+
[![llava_next-video_checkpoints](https://img.shields.io/badge/llava_next-video_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)
|
68 |
+
[![llava_next-image_checkpoints](https://img.shields.io/badge/llava_next-image_checkpoints-blue)](https://huggingface.co/lmms-lab)
|
69 |
+
|
70 |
+
## Release Notes
|
71 |
+
|
72 |
+
- [2024/08/06] 🔥 **🚀 [LLaVA-OneVision (OV)](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)!** The new LLaVA-OV models (0.5B/7B/72B) achieve new state-of-the-art performance across single-image, multi-image, and video benchmarks, sometimes rivaling top commercial models on 47 diverse benchmarks. 📄 Explore More:
|
73 |
+
* [[Paper]](https://arxiv.org/abs/2408.03326): In-depth insights, new emegerging scenarios, ie, strong video understadning through task transfer from images.
|
74 |
+
* [[LLaVA-OV Doc]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md): Model inference and evaluation guidance.
|
75 |
+
* [[Scripts]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/scripts/train): Start training models on your single-image/multi-image/video data.
|
76 |
+
|
77 |
+
- [2024/07/16] 🔥 **LLaVA-NeXT-Video** has been upgraded. The new 32B model achieves the best open-source performance on several video benchmarks, including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard). Please refer to [this page](docs/LLaVA-NeXT-Video_0716.md) for details, refer to [llava_next-video_demo](https://huggingface.co/spaces/WildVision/vision-arena) for demo.
|
78 |
+
|
79 |
+
|
80 |
+
- [2024/06/23] 🔥 **LLaVA-NeXT-Interleave** is released. We utilize image-text interleaved format to unify multi-image, video, and 3D tasks in one LLM and achieve **SoTA** performance on a wide range of benchmarks. Check out [paper](https://arxiv.org/pdf/2407.07895), [blog](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/), and [checkpoints](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1) to see new capabilities and improved performance! We have released 0.5b, 7b, and 7b-dpo models.
|
81 |
+
* An all-round LLM for multi-image, video, and 3D with strong performance \[[demo](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo)\]
|
82 |
+
* Construct interleave training data [**M4-Instruct**](https://huggingface.co/datasets/lmms-lab/M4-Instruct-Data)
|
83 |
+
* Construct multi-image benchmark [**LLaVA-Interleave Bench**](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Interleave-Bench)
|
84 |
+
|
85 |
+
|
86 |
+
- [2024/05/25] 🔥 Wondering "[What Else Influences Visual Instruction Tuning Beyond Data?](https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/)" Our new [blog](https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/) summarizes empirical explorations to ablate the various design choices in improving LMMs except instruct data itself. Meanwhile, open-source the recapioned high-quality data using LLaVA-NeXT-34B on [[COCO]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-118K) [[LCS]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-558K) [[CC3M]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-CC3M).
|
87 |
+
* Architectures (LMM & Vision Encoder)
|
88 |
+
* Visual Representations (Resolution & # Tokens)
|
89 |
+
* Training Strategies (High-quality data & Trainable modules)
|
90 |
+
|
91 |
+
- [2024/05/10] 🔥 **LLaVA-NeXT** (Stronger) models are released, with support of stronger LMM inlcuding LLama-3 (8B) and Qwen-1.5 (72B/110B) Check out [[blog](https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/)] and [[checkpoints](https://huggingface.co/lmms-lab)] to see improved performance!
|
92 |
+
- [2024/05/10] 🔥 **LLaVA-NeXT** (Video) is released. The image-only-trained LLaVA-NeXT model is surprisingly strong on video tasks with zero-shot modality transfer. DPO training with AI feedback on videos can yield significant improvement. [[Blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)], [[checkpoints](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)] and [[sglang](https://github.com/sgl-project/sglang)]
|
93 |
+
- [2024/01/30] 🔥 **LLaVA-NeXT** is out! With additional scaling to LLaVA-1.5, LLaVA-NeXT-34B outperforms Gemini Pro on some benchmarks. It can now process 4x more pixels and perform more tasks/applications than before. Check out the [blog post](https://llava-vl.github.io/blog/2024-01-30-llava-next/), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). Training/eval data and scripts coming soon.
|
94 |
+
<details>
|
95 |
+
<summary>More</summary>
|
96 |
+
|
97 |
+
- [2024/03/10] 🔥 Releasing **LMMs-Eval**, a highly efficient evaluation pipeline we used when developing LLaVA-NeXT. It supports the evaluation of LMMs on dozens of public datasets and allows new dataset onboarding, making the dev of new LMMs much faster. [[Blog](https://lmms-lab.github.io/lmms-eval-blog/lmms-eval-0.1/)] [[Codebase](https://github.com/EvolvingLMMs-Lab/lmms-eval)]
|
98 |
+
|
99 |
+
- [2023/11/10] [LLaVA-Plus](https://llava-vl.github.io/llava-plus/) is released: Learning to Use Tools for Creating Multimodal Agents, with LLaVA-Plus (LLaVA that Plug and Learn to Use Skills). [[Project Page](https://llava-vl.github.io/llava-plus/)] [[Demo](https://llavaplus.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Plus-Codebase)] [[Paper](https://arxiv.org/abs/2311.05437)]
|
100 |
+
- [2023/11/02] [LLaVA-Interactive](https://llava-vl.github.io/llava-interactive/) is released: Experience the future of human-AI multimodal interaction with an all-in-one demo for Image Chat, Segmentation, Generation and Editing. [[Project Page](https://llava-vl.github.io/llava-interactive/)] [[Demo](https://llavainteractive.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Interactive-Demo)] [[Paper](https://arxiv.org/abs/2311.00571)]
|
101 |
+
- [2023/10/26] 🔥 LLaVA-1.5 with LoRA achieves comparable performance as full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA.
|
102 |
+
- [2023/10/12] Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)]
|
103 |
+
- [2023/10/05] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
|
104 |
+
- [2023/09/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/)
|
105 |
+
- [2023/09/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
|
106 |
+
- [2023/11/06] Support **Intel** dGPU and CPU platforms. [More details here.](https://github.com/haotian-liu/LLaVA/tree/intel/docs/intel)
|
107 |
+
- [2023/10/12] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support!
|
108 |
+
- [2023/10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
|
109 |
+
- [2023/10/10] [Roboflow Deep Dive](https://blog.roboflow.com/first-impressions-with-llava-1-5/): First Impressions with LLaVA-1.5.
|
110 |
+
- [2023/09/20] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
|
111 |
+
<p align="center">
|
112 |
+
<img src="https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings/blob/main/images/mfm_evolution.jpeg?raw=true" width=50%/>
|
113 |
+
</p>
|
114 |
+
|
115 |
+
- [2023/07/19] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)!
|
116 |
+
- [2023/06/26] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out [[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)] [[Notes](https://arxiv.org/abs/2306.14895)] [[YouTube](https://youtu.be/mkI7EPD1vp8)] [[Bilibli](https://www.bilibili.com/video/BV1Ng4y1T7v3/)].
|
117 |
+
- [2023/06/11] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see documentations [here](./docs/LoRA.md).
|
118 |
+
- [2023/06/01] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med).
|
119 |
+
- [2023/05/06] We are releasing [LLaVA-Lighting-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details.
|
120 |
+
- [2023/05/02] 🔥 We are releasing LLaVA-Lighting! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details.
|
121 |
+
- [2023/04/27] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava).
|
122 |
+
- [2023/04/17] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/).
|
123 |
+
|
124 |
+
</details>
|
125 |
+
|
126 |
+
<!-- <a href="https://llava.hliu.cc/"><img src="assets/demo.gif" width="70%"></a> -->
|
127 |
+
|
128 |
+
**Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama-1/2 community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5, [Tongyi Qianwen RESEARCH LICENSE AGREEMENT](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE) and [Llama-3 Research License](https://llama.meta.com/llama3/license/)). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations.
|
129 |
+
|
130 |
+
## Models & Scripts
|
131 |
+
|
132 |
+
### Installation
|
133 |
+
|
134 |
+
#### 1. **Clone this repository and navigate to the LLaVA folder:**
|
135 |
+
```bash
|
136 |
+
git clone https://github.com/LLaVA-VL/LLaVA-NeXT
|
137 |
+
cd LLaVA-NeXT
|
138 |
+
```
|
139 |
+
|
140 |
+
#### 2. **Install the inference package:**
|
141 |
+
```bash
|
142 |
+
conda create -n llava python=3.10 -y
|
143 |
+
conda activate llava
|
144 |
+
pip install --upgrade pip # Enable PEP 660 support.
|
145 |
+
pip install -e ".[train]"
|
146 |
+
```
|
147 |
+
|
148 |
+
### Project Navigation
|
149 |
+
Please checkout the following page for more inference & evaluation details.
|
150 |
+
|
151 |
+
#### - **LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild**
|
152 |
+
- [LLaVA-NeXT-Image](./docs/LLaVA-NeXT.md): for image demo inference and evaluation of stronger LMMs using [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval).
|
153 |
+
|
154 |
+
|
155 |
+
#### - LLaVA-NeXT: A Strong Zero-shot Video Understanding Model
|
156 |
+
- [LLaVA-NeXT-Video](./docs/LLaVA-NeXT-Video.md): for video inference and evaluation scripts. We recommend to use [LMMs-video](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for evaluation.
|
157 |
+
|
158 |
+
#### - LLaVA-NeXT: Tackling Multi-image, Video, and 3D in Large Multimodal Models
|
159 |
+
- [LLaVA-NeXT-Interleave](./docs/LLaVA-NeXT-Interleave.md): for multi-image demo and evaluation scripts.
|
160 |
+
|
161 |
+
## SGLang for SpeedUp Inference and Deployment
|
162 |
+
|
163 |
+
We use [SGLang](https://github.com/sgl-project/sglang) to speed up inference and deployment of LLaVA-NeXT. You could make LLaVA-NeXT as a backend API service with SGLang.
|
164 |
+
|
165 |
+
**Prepare Environment**:
|
166 |
+
Following the instruction in the [sglang](https://github.com/sgl-project/sglang?tab=readme-ov-file#install)
|
167 |
+
|
168 |
+
### LLaVA-NeXT (Image)
|
169 |
+
|
170 |
+
Checkout the HTTP Post/Get and SRT usage at [sglang/examples/usage/llava](https://github.com/sgl-project/sglang/blob/main/examples/usage/llava)
|
171 |
+
|
172 |
+
### LLaVA-NeXT (Video)
|
173 |
+
|
174 |
+
**Launch and Run on (K) Nodes**:
|
175 |
+
- Go to sglang project
|
176 |
+
```
|
177 |
+
cd PATH_TO/sglang
|
178 |
+
```
|
179 |
+
- First node:
|
180 |
+
```sh
|
181 |
+
bash examples/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
|
182 |
+
(e.g. bash examples/usage/llava_video/srt_example_llava_v.sh K 0 examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 lmms-lab/LLaVA-NeXT-Video-7B-DPO 16)
|
183 |
+
```
|
184 |
+
- Second node:
|
185 |
+
```sh
|
186 |
+
bash examples/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
|
187 |
+
```
|
188 |
+
- The K node:
|
189 |
+
```sh
|
190 |
+
bash examples/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
|
191 |
+
```
|
192 |
+
|
193 |
+
|
194 |
+
## Citation
|
195 |
+
|
196 |
+
If you find it useful for your research and applications, please cite related papers/blogs using this BibTeX:
|
197 |
+
```bibtex
|
198 |
+
@article{li2024llava,
|
199 |
+
title={LLaVA-NeXT-Interleave: Tackling Multi-image, Video, and 3D in Large Multimodal Models},
|
200 |
+
author={Li, Feng and Zhang, Renrui and Zhang, Hao and Zhang, Yuanhan and Li, Bo and Li, Wei and Ma, Zejun and Li, Chunyuan},
|
201 |
+
journal={arXiv preprint arXiv:2407.07895},
|
202 |
+
year={2024}
|
203 |
+
}
|
204 |
+
|
205 |
+
@misc{li2024llavanext-ablations,
|
206 |
+
title={LLaVA-NeXT: What Else Influences Visual Instruction Tuning Beyond Data?},
|
207 |
+
url={https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/},
|
208 |
+
author={Li, Bo and Zhang, Hao and Zhang, Kaichen and Guo, Dong and Zhang, Yuanhan and Zhang, Renrui and Li, Feng and Liu, Ziwei and Li, Chunyuan},
|
209 |
+
month={May},
|
210 |
+
year={2024}
|
211 |
+
}
|
212 |
+
|
213 |
+
@misc{li2024llavanext-strong,
|
214 |
+
title={LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild},
|
215 |
+
url={https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/},
|
216 |
+
author={Li, Bo and Zhang, Kaichen and Zhang, Hao and Guo, Dong and Zhang, Renrui and Li, Feng and Zhang, Yuanhan and Liu, Ziwei and Li, Chunyuan},
|
217 |
+
month={May},
|
218 |
+
year={2024}
|
219 |
+
}
|
220 |
+
|
221 |
+
@misc{zhang2024llavanext-video,
|
222 |
+
title={LLaVA-NeXT: A Strong Zero-shot Video Understanding Model},
|
223 |
+
url={https://llava-vl.github.io/blog/2024-04-30-llava-next-video/},
|
224 |
+
author={Zhang, Yuanhan and Li, Bo and Liu, haotian and Lee, Yong jae and Gui, Liangke and Fu, Di and Feng, Jiashi and Liu, Ziwei and Li, Chunyuan},
|
225 |
+
month={April},
|
226 |
+
year={2024}
|
227 |
+
}
|
228 |
+
|
229 |
+
@misc{liu2024llavanext,
|
230 |
+
title={LLaVA-NeXT: Improved reasoning, OCR, and world knowledge},
|
231 |
+
url={https://llava-vl.github.io/blog/2024-01-30-llava-next/},
|
232 |
+
author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Li, Bo and Zhang, Yuanhan and Shen, Sheng and Lee, Yong Jae},
|
233 |
+
month={January},
|
234 |
+
year={2024}
|
235 |
+
}
|
236 |
+
|
237 |
+
@misc{liu2023improvedllava,
|
238 |
+
title={Improved Baselines with Visual Instruction Tuning},
|
239 |
+
author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae},
|
240 |
+
publisher={arXiv:2310.03744},
|
241 |
+
year={2023},
|
242 |
+
}
|
243 |
+
|
244 |
+
@misc{liu2023llava,
|
245 |
+
title={Visual Instruction Tuning},
|
246 |
+
author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
|
247 |
+
publisher={NeurIPS},
|
248 |
+
year={2023},
|
249 |
+
}
|
250 |
+
```
|
251 |
+
|
252 |
+
## Acknowledgement
|
253 |
+
|
254 |
+
- [Vicuna](https://github.com/lm-sys/FastChat): the codebase we built upon, and our base model Vicuna-13B that has the amazing language capabilities!
|
255 |
+
- The LLaVA-NeXT project is currently maintained by the team along with our contributors (listed alphabetically by the first names): [Bo Li](https://brianboli.com/), [Dong Guo](https://www.linkedin.com/in/dongguoset/), [Feng Li](https://scholar.google.com/citations?hl=zh-CN&user=ybRe9GcAAAAJ&view_op=list_works&sortby=pubdate), [Hao Zhang](https://scholar.google.com/citations?user=B8hPxMQAAAAJ&hl=en), [Kaichen Zhang](https://www.linkedin.com/in/kaichen-zhang-014b17219/?originalSubdomain=sg), [Renrui Zhang](https://zrrskywalker.github.io/), [Yuanhan Zhang](https://zhangyuanhan-ai.github.io/), led by [Chunyuan Li](https://chunyuan.li/) and with the guidance and help from [Haotian Liu](https://hliu.cc/).
|
256 |
+
- The `lmms-eval` framework and its core contributors, including Peiyuan Zhang, Fanyi Pu, Joshua Adrian Cahyono, and Kairui Hu, for their support on the evaluation side.
|
257 |
+
|
258 |
+
## Related Projects
|
259 |
+
|
260 |
+
- [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
261 |
+
- [LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day](https://github.com/microsoft/LLaVA-Med)
|
262 |
+
- [Otter: In-Context Multi-Modal Instruction Tuning](https://github.com/Luodian/Otter)
|
263 |
+
|
264 |
+
For future project ideas, please check out:
|
265 |
+
- [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
|
266 |
+
- [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect, segment, and generate anything by marrying [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment-Anything](https://github.com/facebookresearch/segment-anything).
|
llava-1.7.0.dev0.dist-info/RECORD
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
llava-1.7.0.dev0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
2 |
+
llava-1.7.0.dev0.dist-info/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
3 |
+
llava-1.7.0.dev0.dist-info/METADATA,sha256=lLd1vxRYxiY82Kqxic3pOgPA1GfwRclzyqko-u4mbl8,22760
|
4 |
+
llava-1.7.0.dev0.dist-info/RECORD,,
|
5 |
+
llava-1.7.0.dev0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6 |
+
llava-1.7.0.dev0.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
|
7 |
+
llava-1.7.0.dev0.dist-info/direct_url.json,sha256=hxcapmB6J2WrkkCuvIaCPRbl85Ju7DJ67xkANY1sWuc,138
|
8 |
+
llava-1.7.0.dev0.dist-info/top_level.txt,sha256=AlU_N7AUyx6Fn0VOZu0pGBgjbU0fGvKkHnnkCFJbIF4,10
|
9 |
+
llava/__init__.py,sha256=8fWfEdbl8Xc5O1CThmLAMnB2h1Dt-gQiLeIW1Uo-JhE,42
|
10 |
+
llava/__pycache__/__init__.cpython-39.pyc,,
|
11 |
+
llava/__pycache__/constants.cpython-39.pyc,,
|
12 |
+
llava/__pycache__/conversation.cpython-39.pyc,,
|
13 |
+
llava/__pycache__/mm_utils.cpython-39.pyc,,
|
14 |
+
llava/__pycache__/utils.cpython-39.pyc,,
|
15 |
+
llava/constants.py,sha256=bcZAgJAHgpyMey-SSv3llZjeJfC8xJ7IvIRwPIGrj-4,305
|
16 |
+
llava/conversation.py,sha256=k-L_tP6EcNYxkVH0PacaeuNAw9R7NmllE8oTPmHs3oM,22785
|
17 |
+
llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc,,
|
18 |
+
llava/eval/__pycache__/model_vqa.cpython-39.pyc,,
|
19 |
+
llava/eval/evaluate_interleave.py,sha256=i8jwOxkYCh2WwMmCQ6bMqeLYZ_YvZHDNv-g82PxaOoY,10989
|
20 |
+
llava/eval/model_vqa.py,sha256=sKUyodB4dGHy0j7oxF_Um72lgn48iPFkR9upYej8VWw,10704
|
21 |
+
llava/mm_utils.py,sha256=Gwvu67nQT2Urwj4Q7bvcK7Y_yOkzilX50GSj5UC2-DY,17417
|
22 |
+
llava/model/__init__.py,sha256=K1A5xgHwGb6vhX2FsA0kEcRK7RFlM419rJJ0--Ax_78,679
|
23 |
+
llava/model/__pycache__/__init__.cpython-39.pyc,,
|
24 |
+
llava/model/__pycache__/apply_delta.cpython-39.pyc,,
|
25 |
+
llava/model/__pycache__/builder.cpython-39.pyc,,
|
26 |
+
llava/model/__pycache__/consolidate.cpython-39.pyc,,
|
27 |
+
llava/model/__pycache__/llava_arch.cpython-39.pyc,,
|
28 |
+
llava/model/__pycache__/make_delta.cpython-39.pyc,,
|
29 |
+
llava/model/__pycache__/utils.cpython-39.pyc,,
|
30 |
+
llava/model/apply_delta.py,sha256=ZItbnApA9G_hAXShPAOe5STKUy4s5o9acJ_wseyTWrU,1979
|
31 |
+
llava/model/builder.py,sha256=ou9C95SNH6JWB8tBuYWSFcw71Twl7b9l3T3UBV3XN_8,17923
|
32 |
+
llava/model/consolidate.py,sha256=iYWg_Huv7GQcuUKXI6EV5uESjhij_qTH_XhUzchoXV0,945
|
33 |
+
llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc,,
|
34 |
+
llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc,,
|
35 |
+
llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc,,
|
36 |
+
llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc,,
|
37 |
+
llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc,,
|
38 |
+
llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc,,
|
39 |
+
llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc,,
|
40 |
+
llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc,,
|
41 |
+
llava/model/language_model/llava_gemma.py,sha256=800LF_ldzdpq9_yeYejbhpzsOD1EwZuS_G2Nkf6ejuU,4980
|
42 |
+
llava/model/language_model/llava_llama.py,sha256=X1m-xVknZb6caYTeF1iJT29WV0l-2n6Ud7iL8zr49C0,6322
|
43 |
+
llava/model/language_model/llava_mistral.py,sha256=IpV8-8NE693wMSYF7zzqBgCNGdpTMRYTU9RClAWKQL4,5189
|
44 |
+
llava/model/language_model/llava_mixtral.py,sha256=mA4kq2VjbYin7r_nad9VzEkj8cguVR9F35mu3zVYulc,5882
|
45 |
+
llava/model/language_model/llava_mpt.py,sha256=7FRWHZf6JWkrqJIgaosAP19p2vl-kNGNrhu4JkaYdPk,3836
|
46 |
+
llava/model/language_model/llava_qwen.py,sha256=ESNFowdoSykW9BnhSZAWgJMb_xwigOeA4G4AkrF2rh8,6204
|
47 |
+
llava/model/language_model/llava_qwen_moe.py,sha256=LOMS6d-BP8o2-SJfETnlZttCdMVZ6hGDfuhispjtQlo,6230
|
48 |
+
llava/model/language_model/modeling_llama.py,sha256=Qwd2vsz-vAbBl07zwR1IvCTCTwyRSoxJMwVcgbNNKJc,82886
|
49 |
+
llava/model/llava_arch.py,sha256=vO_gPr8uQZ8EOPWphnMeFGL05u7xRjZKFWgnU6AsUNc,28497
|
50 |
+
llava/model/make_delta.py,sha256=oUJNaT6ikdifV5fF9SPCllyucO8h0tjSniT8UFKzcWk,2303
|
51 |
+
llava/model/multimodal_encoder/__pycache__/builder.cpython-39.pyc,,
|
52 |
+
llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc,,
|
53 |
+
llava/model/multimodal_encoder/__pycache__/hf_vision.cpython-39.pyc,,
|
54 |
+
llava/model/multimodal_encoder/__pycache__/imagebind.cpython-39.pyc,,
|
55 |
+
llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-39.pyc,,
|
56 |
+
llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-39.pyc,,
|
57 |
+
llava/model/multimodal_encoder/builder.py,sha256=bYmGLnJHgBJXnkRrs9phaqHH6AlVsUNey6w-iZFuXn0,1922
|
58 |
+
llava/model/multimodal_encoder/clip_encoder.py,sha256=ofOgPYkJjXGnpk8SAtOGSXdw1DKYZOAeUwzn-5DouBE,7448
|
59 |
+
llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-39.pyc,,
|
60 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py,sha256=6mbC4b7gg9g4LcxJXEEZtAGMp_jwzl0enio8T6j6b3Y,792
|
61 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-39.pyc,,
|
62 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-39.pyc,,
|
63 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc,,
|
64 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-39.pyc,,
|
65 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc,,
|
66 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc,,
|
67 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-39.pyc,,
|
68 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-39.pyc,,
|
69 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc,,
|
70 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-39.pyc,,
|
71 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc,,
|
72 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-39.pyc,,
|
73 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc,,
|
74 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc,,
|
75 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-39.pyc,,
|
76 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-39.pyc,,
|
77 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-39.pyc,,
|
78 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py,sha256=PKjrqkcdpJK_MQmnTsZ2oxhxIHn9AlI-y2ap87BnR1Q,118
|
79 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py,sha256=XwR8sswHn0Eb8C9EOFz-Gn-lQm831cCCu2xbR__0XiI,23142
|
80 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py,sha256=4fbCuG0eUpVpSA_yE8_GUxCyRmojbXF2C9X3oSE32ns,24280
|
81 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py,sha256=CwD_HmdfQ1Tb-fLOr9KgseVP80nMNv6V4uWI6DDOBqg,2132
|
82 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py,sha256=uWSu0OTsXR8v5y2P6jwkdzzy2Ce1H__UEthHM0F7xR4,10350
|
83 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py,sha256=p-B34PgBg0JuutSriqp0Qc2VLJrkLf91fGmBRHiZOSg,5746
|
84 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py,sha256=nNKTAljan_PpGkTJ4niwV3xxI0g9C3704U6OUJh8P_k,17650
|
85 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py,sha256=PB0q6KsaQKwVRlX8R4qW8Cf4rzY5v5QcFiydMXE8rS0,7163
|
86 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py,sha256=g-kvzfUuPEMW0uU4e8NfwGCpJnL1kXdEZUWT0YqVoXQ,5570
|
87 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py,sha256=lWoSv_3xdMPmVv8EnQWpC7Kq4H8ihSYTHYk_Nr_jGA8,12211
|
88 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py,sha256=i8RDQ1Zv9cTXVBbW8RbbfaT0wGxjEFu-qq3DCXQBR-8,5399
|
89 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py,sha256=Eta_-wNrwv953zWVxXshCywCVOwK2jPRiOId9XcFyhk,4895
|
90 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py,sha256=u4Gur6i8rcWvdPZRZSNgDshmbkchDx5DvZlSGxvoXH8,7368
|
91 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py,sha256=fYdJYEVPviaTliRkSsrWdPmYbLGTM4a6QYlNN_3ZzHA,3514
|
92 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py,sha256=EYlChMZnX7vQvitTWM73iIhaZf4zv_OTSJ6L7ZTZ8go,26410
|
93 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py,sha256=aYJKAK5qw8Kge-a6wTTBOwb-wqaV9Gri5vuLMYq4E84,14964
|
94 |
+
llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py,sha256=7m3OHiUdnHkScpoRp_DjGLavrCcKf2te6oJshv27kzI,6219
|
95 |
+
llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-39.pyc,,
|
96 |
+
llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-39.pyc,,
|
97 |
+
llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-39.pyc,,
|
98 |
+
llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-39.pyc,,
|
99 |
+
llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py,sha256=FL-gQEpHBlYYYtzPcM1jg5HTPtyqSasNrou_aRtmghs,2890
|
100 |
+
llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py,sha256=kwNlbCc4cWz7cmQBZwS97as2taQf0RRFoZAXuNZdvjg,2215
|
101 |
+
llava/model/multimodal_encoder/eva_clip/eva_vit.py,sha256=mrgroKZHGFK2URbatEQKpID8zhKmVBHgRyKEc4D_bUI,34605
|
102 |
+
llava/model/multimodal_encoder/eva_clip/factory.py,sha256=iLoVP1ldKm0YvXX3uz4Wsb2dw1DElMZgUdIrxMS1e70,1829
|
103 |
+
llava/model/multimodal_encoder/hf_vision.py,sha256=Pw0y7SVYKiUIuBCP8uMySWRyIcpNBN1oUsjRBMVqSfM,4549
|
104 |
+
llava/model/multimodal_encoder/imagebind.py,sha256=MkaKOrpYr1Fj08QSzy-Y3awDmmB9Y5Y6KoVGJR52Xpg,2498
|
105 |
+
llava/model/multimodal_encoder/open_clip_encoder.py,sha256=0iFyD49NZFwTutR6Hq5upIybHbrzgPlPwQ8kgrRwZXQ,6812
|
106 |
+
llava/model/multimodal_encoder/siglip_encoder.py,sha256=LPGxELdEKwu5FO1vtvVmBmcjhMq48d1AzpJJAq_0yIk,26103
|
107 |
+
llava/model/multimodal_projector/__pycache__/builder.cpython-39.pyc,,
|
108 |
+
llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-39.pyc,,
|
109 |
+
llava/model/multimodal_projector/builder.py,sha256=acKSHT-As_qu2haXj-g6gRRzdq_BPNWwgP_ZDYEntUI,2192
|
110 |
+
llava/model/multimodal_projector/pooler_projector.py,sha256=zxAP1Ut-oJXG-L4xggh2FC4epc0nemgk1v8RnoKCxZ4,975
|
111 |
+
llava/model/multimodal_resampler/__pycache__/builder.cpython-39.pyc,,
|
112 |
+
llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-39.pyc,,
|
113 |
+
llava/model/multimodal_resampler/__pycache__/perceiver.cpython-39.pyc,,
|
114 |
+
llava/model/multimodal_resampler/__pycache__/qformer.cpython-39.pyc,,
|
115 |
+
llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-39.pyc,,
|
116 |
+
llava/model/multimodal_resampler/builder.py,sha256=qaSzq2lcRDkIFv_QRTXrkfn-OHfII--4LIHkrkIfwPg,1039
|
117 |
+
llava/model/multimodal_resampler/masked_drop.py,sha256=FNgUNkIw8JQaAv3lppL7vYtUOfoP2DArw0AuslXQ0TE,3061
|
118 |
+
llava/model/multimodal_resampler/perceiver.py,sha256=uOAntKuihMkBkAp5bIozKUApvXhvlCeocRNtUva-VqA,4995
|
119 |
+
llava/model/multimodal_resampler/qformer.py,sha256=d-A2JpouT-VjWb43BF4HXP_jaIM0o_NhFhVy_3Uawsc,50384
|
120 |
+
llava/model/multimodal_resampler/spatial_pool.py,sha256=hEAlKpbgzGjXeY365TZaI3MI2YAvle1Yfb5dKlAiQls,1775
|
121 |
+
llava/model/utils.py,sha256=KzkLVJjTHJqI9vg1umDp4-SkT4IbMcI_Uhp-4V4xkWk,947
|
122 |
+
llava/serve/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
123 |
+
llava/serve/__pycache__/__init__.cpython-39.pyc,,
|
124 |
+
llava/serve/__pycache__/cli.cpython-39.pyc,,
|
125 |
+
llava/serve/__pycache__/controller.cpython-39.pyc,,
|
126 |
+
llava/serve/__pycache__/gradio_multi_image.cpython-39.pyc,,
|
127 |
+
llava/serve/__pycache__/gradio_web_server.cpython-39.pyc,,
|
128 |
+
llava/serve/__pycache__/model_worker.cpython-39.pyc,,
|
129 |
+
llava/serve/__pycache__/register_worker.cpython-39.pyc,,
|
130 |
+
llava/serve/__pycache__/sglang_worker.cpython-39.pyc,,
|
131 |
+
llava/serve/__pycache__/test_message.cpython-39.pyc,,
|
132 |
+
llava/serve/cli.py,sha256=e-ALjf2zdr08UiqeW-DmBoGEHRWiO-I5ELpqjls30iE,4403
|
133 |
+
llava/serve/controller.py,sha256=zKmdDMoOyHltZGKQCzIgrUkXsget_3UkjgGNyq0xy7Y,10070
|
134 |
+
llava/serve/gradio_multi_image.py,sha256=mwVVe4l-7ry3umZx9CFGrwYKgPup4RLupMXrsRj1IZc,20029
|
135 |
+
llava/serve/gradio_web_server.py,sha256=t8xWJPNQ0zDOGPi4ju9NkA89kbOcPVMO-v6pNM7BZIs,19519
|
136 |
+
llava/serve/model_worker.py,sha256=SBzKdeQE0hhVM9bwxplVb8KqmUm9qhp1H74THX82MD0,11121
|
137 |
+
llava/serve/register_worker.py,sha256=Q7BnBGr0lcDdKaI-DHv_5IKK0KpHvtUTCBwFz5PspLo,760
|
138 |
+
llava/serve/sglang_worker.py,sha256=lYeIDVZlKho4YcLi82bUxP4ccFJCTpVcfcM_uvdH6wI,9221
|
139 |
+
llava/serve/test_message.py,sha256=ofJWbzm3oQz5UKU2tBSfV2ZzDZkGpMPDE9yrlvJXNAM,2048
|
140 |
+
llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc,,
|
141 |
+
llava/train/__pycache__/llava_trainer.cpython-39.pyc,,
|
142 |
+
llava/train/__pycache__/llava_trainer_eval.cpython-39.pyc,,
|
143 |
+
llava/train/__pycache__/train.cpython-39.pyc,,
|
144 |
+
llava/train/__pycache__/train_dpo.cpython-39.pyc,,
|
145 |
+
llava/train/__pycache__/train_mem.cpython-39.pyc,,
|
146 |
+
llava/train/llama_flash_attn_monkey_patch.py,sha256=CBkiqWIZXW68_2YJdtTPQqXBadPq15vHmliDGoqeW5c,4280
|
147 |
+
llava/train/llava_trainer.py,sha256=rGZCclj_T8ATDfR2JrNa1mLibH1z0kkVRQsSxvT_Rw8,27309
|
148 |
+
llava/train/llava_trainer_eval.py,sha256=bNGpwNtA1d20xQ5BxZa8O-ZnMHR7CqQ9VgAbYtt3mAQ,3515
|
149 |
+
llava/train/train.py,sha256=VLRZQV-LjafzV5wtQ1f_s_0Qm0mxE8ei1VSW-e_XMrU,78864
|
150 |
+
llava/train/train_dpo.py,sha256=oFGJghgJezMVms70YWd4KvFrsZDei4UbojjhMroNIOE,84440
|
151 |
+
llava/train/train_mem.py,sha256=C06MqpCqOVtTsewH8N67oUzmYIm0HY6-y3MuzhlE1wg,80
|
152 |
+
llava/utils.py,sha256=qixNPajlBGe9XqNWnYOQ6V6OTreQcBK6W8jkuxRjzBU,6533
|
153 |
+
trl/__init__.py,sha256=Od8x7-H_1H5LfnScvJTJxjWeDuHzKlnUuToL5RQswSA,1110
|
154 |
+
trl/__pycache__/__init__.cpython-39.pyc,,
|
155 |
+
trl/__pycache__/core.cpython-39.pyc,,
|
156 |
+
trl/__pycache__/import_utils.cpython-39.pyc,,
|
157 |
+
trl/core.py,sha256=TPuO3us2wqAXsQWm8v-lNtnVmYHiuJcvOJoZeSV29YI,12303
|
158 |
+
trl/environment/__init__.py,sha256=XM1ZiS_F7-r8P6Z20VNHh71Wnw-scMoujSU-lqEKGNc,78
|
159 |
+
trl/environment/__pycache__/__init__.cpython-39.pyc,,
|
160 |
+
trl/environment/__pycache__/base_environment.cpython-39.pyc,,
|
161 |
+
trl/environment/base_environment.py,sha256=pyrIOZJsl-Q6VAv2PRGaUbIDeCDp7jyc1mtibpPvHrA,17882
|
162 |
+
trl/extras/__init__.py,sha256=daKpM_o7XbZix98t_kxwLyMteb5EViUCH8MURZFEq_Q,684
|
163 |
+
trl/extras/__pycache__/__init__.cpython-39.pyc,,
|
164 |
+
trl/extras/__pycache__/best_of_n_sampler.cpython-39.pyc,,
|
165 |
+
trl/extras/__pycache__/dataset_formatting.cpython-39.pyc,,
|
166 |
+
trl/extras/best_of_n_sampler.py,sha256=RHA3RbnqifnpUh7HZrKdhcDNz9LVSpcYUj_A_jrC8Ro,5243
|
167 |
+
trl/extras/dataset_formatting.py,sha256=TVeUWfxA1q3oat3HJpMIA6olsUYwjpQzDReVLkeZ7NI,3726
|
168 |
+
trl/import_utils.py,sha256=kfnxR_z4CB1rM5JcBZtVxhsOcxHYmIXWzTgTMRGc-7U,3238
|
169 |
+
trl/models/__init__.py,sha256=xY9josSMMq7J0coDxBnhsMvIK3sJvfNIeOgseQZu6cE,1244
|
170 |
+
trl/models/__pycache__/__init__.cpython-39.pyc,,
|
171 |
+
trl/models/__pycache__/modeling_base.cpython-39.pyc,,
|
172 |
+
trl/models/__pycache__/modeling_sd_base.cpython-39.pyc,,
|
173 |
+
trl/models/__pycache__/modeling_value_head.cpython-39.pyc,,
|
174 |
+
trl/models/__pycache__/utils.cpython-39.pyc,,
|
175 |
+
trl/models/modeling_base.py,sha256=oMvYF2MnXqykCkDBBAdLDjowUB0PcL5LftpArsdquiM,28842
|
176 |
+
trl/models/modeling_sd_base.py,sha256=2OB-rShWUebUoCuVr27gla3DEpA_eX2W5UCVr6WJ2w0,28073
|
177 |
+
trl/models/modeling_value_head.py,sha256=wq9rqn8oPJMmgyNpgI5AWZSmT0JZb4RHO13r6jzExTo,18822
|
178 |
+
trl/models/utils.py,sha256=8kc1anjd4PPLWM5zce8eoXQox1uq6R-E_UwUF_b2YBk,3389
|
179 |
+
trl/trainer/__init__.py,sha256=9gamN5nkygFHBfF56JvC0sN67axqU6WuXXY9s1YteK8,1514
|
180 |
+
trl/trainer/__pycache__/__init__.cpython-39.pyc,,
|
181 |
+
trl/trainer/__pycache__/base.cpython-39.pyc,,
|
182 |
+
trl/trainer/__pycache__/ddpo_config.cpython-39.pyc,,
|
183 |
+
trl/trainer/__pycache__/ddpo_trainer.cpython-39.pyc,,
|
184 |
+
trl/trainer/__pycache__/dpo_trainer.cpython-39.pyc,,
|
185 |
+
trl/trainer/__pycache__/iterative_sft_trainer.cpython-39.pyc,,
|
186 |
+
trl/trainer/__pycache__/model_config.cpython-39.pyc,,
|
187 |
+
trl/trainer/__pycache__/ppo_config.cpython-39.pyc,,
|
188 |
+
trl/trainer/__pycache__/ppo_trainer.cpython-39.pyc,,
|
189 |
+
trl/trainer/__pycache__/reward_config.cpython-39.pyc,,
|
190 |
+
trl/trainer/__pycache__/reward_trainer.cpython-39.pyc,,
|
191 |
+
trl/trainer/__pycache__/sft_trainer.cpython-39.pyc,,
|
192 |
+
trl/trainer/__pycache__/utils.cpython-39.pyc,,
|
193 |
+
trl/trainer/base.py,sha256=PID37pjUqfbobelu9tFP9nwd_p9Rx_Cq7XRgEHhhWYE,1818
|
194 |
+
trl/trainer/ddpo_config.py,sha256=kwFUTMv85yjXICGVimcQfrPCu5Smz-Mz3c3erEA3SRU,4932
|
195 |
+
trl/trainer/ddpo_trainer.py,sha256=NTfQ5jiLuiKGp9ypH3mcxSZIj2cOZjzu3yg5THBXLAg,27023
|
196 |
+
trl/trainer/dpo_trainer.py,sha256=Zcc7ohWl83KVFcQLkC8qfBLT6zWpzK1jjLuqqGL4UBE,62580
|
197 |
+
trl/trainer/iterative_sft_trainer.py,sha256=_91ZH1o1IkWOanuqHSjhEGx_nElDJ_WiBmQwG0DWNsU,16489
|
198 |
+
trl/trainer/model_config.py,sha256=xlsz4478y8f11ZQZ-kwVsGc5bdzyIPTYi4pPdOSr2TU,2966
|
199 |
+
trl/trainer/ppo_config.py,sha256=IC9Y1K-6hQcipDr6jDywsBh4fToJ-3KsuSgOY4aJS-0,8317
|
200 |
+
trl/trainer/ppo_trainer.py,sha256=NmongqErhUrRckVrHpItl3J1ztV_exEAalqyTqxDA7g,63231
|
201 |
+
trl/trainer/reward_config.py,sha256=Q7IihMGMTMIBFGglv-IuJdSWpV6FSbhnlqrZcUaERVU,1661
|
202 |
+
trl/trainer/reward_trainer.py,sha256=93FBp9uus_FAQN560ehyP4yRLWyb9Y4OVysx8rpIACU,13603
|
203 |
+
trl/trainer/sft_trainer.py,sha256=AxrL8nkyO9Cfgd9C8MZLTR5ZUckEUfD5TWSHQWL2dTE,24691
|
204 |
+
trl/trainer/utils.py,sha256=d5W852wGU4mOErsLvqN4jGq4-Mzr_fFFAMY-stBFUYU,31955
|
llava-1.7.0.dev0.dist-info/REQUESTED
ADDED
File without changes
|
llava-1.7.0.dev0.dist-info/WHEEL
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Wheel-Version: 1.0
|
2 |
+
Generator: setuptools (73.0.1)
|
3 |
+
Root-Is-Purelib: true
|
4 |
+
Tag: py3-none-any
|
5 |
+
|
llava-1.7.0.dev0.dist-info/direct_url.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"url": "https://github.com/LLaVA-VL/LLaVA-NeXT.git", "vcs_info": {"commit_id": "e98849102929e1c6304b60b28cca541567b7b643", "vcs": "git"}}
|
llava-1.7.0.dev0.dist-info/top_level.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
llava
|
2 |
+
trl
|
llava/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import LlavaLlamaForCausalLM
|
llava/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (224 Bytes). View file
|
|
llava/__pycache__/constants.cpython-39.pyc
ADDED
Binary file (486 Bytes). View file
|
|
llava/__pycache__/conversation.cpython-39.pyc
ADDED
Binary file (14.4 kB). View file
|
|
llava/__pycache__/mm_utils.cpython-39.pyc
ADDED
Binary file (14.1 kB). View file
|
|
llava/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (6.51 kB). View file
|
|
llava/constants.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
llava/conversation.py
ADDED
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Any, Dict, Union, Tuple
|
4 |
+
import re
|
5 |
+
import base64
|
6 |
+
from io import BytesIO
|
7 |
+
from PIL import Image
|
8 |
+
from transformers import AutoTokenizer
|
9 |
+
|
10 |
+
|
11 |
+
class SeparatorStyle(Enum):
|
12 |
+
"""Different separator style."""
|
13 |
+
|
14 |
+
SINGLE = auto()
|
15 |
+
TWO = auto()
|
16 |
+
MPT = auto()
|
17 |
+
PLAIN = auto()
|
18 |
+
CHATML = auto()
|
19 |
+
LLAMA_2 = auto()
|
20 |
+
LLAMA_3 = auto()
|
21 |
+
QWEN = auto()
|
22 |
+
GEMMA = auto()
|
23 |
+
|
24 |
+
|
25 |
+
@dataclasses.dataclass
|
26 |
+
class Conversation:
|
27 |
+
"""A class that keeps all conversation history."""
|
28 |
+
|
29 |
+
system: str
|
30 |
+
roles: List[str]
|
31 |
+
messages: List[List[str]]
|
32 |
+
offset: int
|
33 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
34 |
+
sep: str = "###"
|
35 |
+
sep2: str = None
|
36 |
+
version: str = "Unknown"
|
37 |
+
|
38 |
+
tokenizer_id: str = ""
|
39 |
+
tokenizer: Any = None
|
40 |
+
# Stop criteria (the default one is EOS token)
|
41 |
+
stop_str: Union[str, List[str]] = None
|
42 |
+
# Stops generation if meeting any token in this list
|
43 |
+
stop_token_ids: List[int] = None
|
44 |
+
|
45 |
+
skip_next: bool = False
|
46 |
+
|
47 |
+
def get_prompt(self):
|
48 |
+
messages = self.messages
|
49 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
50 |
+
messages = self.messages.copy()
|
51 |
+
init_role, init_msg = messages[0].copy()
|
52 |
+
init_msg = init_msg[0]
|
53 |
+
if "mmtag" in self.version:
|
54 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
55 |
+
messages[0] = (init_role, init_msg)
|
56 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
57 |
+
messages.insert(1, (self.roles[1], "Received."))
|
58 |
+
elif not init_msg.startswith("<image>"):
|
59 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
60 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
61 |
+
else:
|
62 |
+
messages[0] = (init_role, init_msg)
|
63 |
+
|
64 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
65 |
+
ret = self.system + self.sep
|
66 |
+
for role, message in messages:
|
67 |
+
if message:
|
68 |
+
if type(message) is tuple:
|
69 |
+
message, _, _ = message
|
70 |
+
ret += role + ": " + message + self.sep
|
71 |
+
else:
|
72 |
+
ret += role + ":"
|
73 |
+
|
74 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
75 |
+
seps = [self.sep, self.sep2]
|
76 |
+
ret = self.system + seps[0]
|
77 |
+
for i, (role, message) in enumerate(messages):
|
78 |
+
if message:
|
79 |
+
if type(message) is tuple:
|
80 |
+
message, _, _ = message
|
81 |
+
ret += role + ": " + message + seps[i % 2]
|
82 |
+
else:
|
83 |
+
ret += role + ":"
|
84 |
+
|
85 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
86 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
87 |
+
for role, message in messages:
|
88 |
+
if message:
|
89 |
+
if type(message) is tuple:
|
90 |
+
message, images, _ = message
|
91 |
+
message = "<image>" * len(images) + message
|
92 |
+
ret += role + "\n" + message + self.sep + "\n"
|
93 |
+
else:
|
94 |
+
ret += role + "\n"
|
95 |
+
return ret
|
96 |
+
|
97 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
98 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
99 |
+
for role, message in messages:
|
100 |
+
if message:
|
101 |
+
if type(message) is tuple:
|
102 |
+
message, images = message
|
103 |
+
message = "<image>" * len(images) + message
|
104 |
+
chat_template_messages.append({"role": role, "content": message})
|
105 |
+
|
106 |
+
# print(chat_template_messages)
|
107 |
+
return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
|
108 |
+
# ret = "" if self.system == "" else self.system + self.sep + "\n"
|
109 |
+
# for role, message in messages:
|
110 |
+
# if message:
|
111 |
+
# if type(message) is tuple:
|
112 |
+
# message, images = message
|
113 |
+
# message = "<image>" * len(images) + message
|
114 |
+
# ret += role + "\n" + message + self.sep + "\n"
|
115 |
+
# else:
|
116 |
+
# ret += role + "\n"
|
117 |
+
# return ret
|
118 |
+
|
119 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
120 |
+
ret = self.system + self.sep
|
121 |
+
for role, message in messages:
|
122 |
+
if message:
|
123 |
+
if type(message) is tuple:
|
124 |
+
message, _, _ = message
|
125 |
+
ret += role + message + self.sep
|
126 |
+
else:
|
127 |
+
ret += role
|
128 |
+
|
129 |
+
elif self.sep_style == SeparatorStyle.GEMMA:
|
130 |
+
ret = ""
|
131 |
+
for i, (role, message) in enumerate(messages):
|
132 |
+
assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
|
133 |
+
if message:
|
134 |
+
if type(message) is tuple:
|
135 |
+
message, _, _ = message
|
136 |
+
ret += role + message + self.sep
|
137 |
+
else:
|
138 |
+
ret += role
|
139 |
+
|
140 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
141 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
142 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
143 |
+
ret = ""
|
144 |
+
|
145 |
+
for i, (role, message) in enumerate(messages):
|
146 |
+
if i == 0:
|
147 |
+
assert message, "first message should not be none"
|
148 |
+
assert role == self.roles[0], "first message should come from user"
|
149 |
+
if message:
|
150 |
+
if type(message) is tuple:
|
151 |
+
message, _, _ = message
|
152 |
+
if i == 0:
|
153 |
+
message = wrap_sys(self.system) + message
|
154 |
+
if i % 2 == 0:
|
155 |
+
message = wrap_inst(message)
|
156 |
+
ret += self.sep + message
|
157 |
+
else:
|
158 |
+
ret += " " + message + " " + self.sep2
|
159 |
+
else:
|
160 |
+
ret += ""
|
161 |
+
ret = ret.lstrip(self.sep)
|
162 |
+
|
163 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
164 |
+
seps = [self.sep, self.sep2]
|
165 |
+
ret = self.system
|
166 |
+
for i, (role, message) in enumerate(messages):
|
167 |
+
if message:
|
168 |
+
if type(message) is tuple:
|
169 |
+
message, _, _ = message
|
170 |
+
ret += message + seps[i % 2]
|
171 |
+
else:
|
172 |
+
ret += ""
|
173 |
+
else:
|
174 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
175 |
+
|
176 |
+
return ret
|
177 |
+
|
178 |
+
def append_message(self, role, message):
|
179 |
+
self.messages.append([role, message])
|
180 |
+
|
181 |
+
def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
|
182 |
+
if image_process_mode == "Pad":
|
183 |
+
|
184 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
185 |
+
width, height = pil_img.size
|
186 |
+
if width == height:
|
187 |
+
return pil_img
|
188 |
+
elif width > height:
|
189 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
190 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
191 |
+
return result
|
192 |
+
else:
|
193 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
194 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
195 |
+
return result
|
196 |
+
|
197 |
+
image = expand2square(image)
|
198 |
+
elif image_process_mode in ["Default", "Crop"]:
|
199 |
+
pass
|
200 |
+
elif image_process_mode == "Resize":
|
201 |
+
image = image.resize((336, 336))
|
202 |
+
else:
|
203 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
204 |
+
|
205 |
+
if type(image) is not Image.Image:
|
206 |
+
image = Image.open(image).convert("RGB")
|
207 |
+
|
208 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
209 |
+
aspect_ratio = max_hw / min_hw
|
210 |
+
max_len, min_len = 672, 448
|
211 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
212 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
213 |
+
W, H = image.size
|
214 |
+
if H > W:
|
215 |
+
H, W = longest_edge, shortest_edge
|
216 |
+
else:
|
217 |
+
H, W = shortest_edge, longest_edge
|
218 |
+
image = image.resize((W, H))
|
219 |
+
if return_pil:
|
220 |
+
return image
|
221 |
+
else:
|
222 |
+
buffered = BytesIO()
|
223 |
+
image.save(buffered, format=image_format)
|
224 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
225 |
+
return img_b64_str
|
226 |
+
|
227 |
+
def get_images(self, return_pil=False, return_path=False):
|
228 |
+
images = []
|
229 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
230 |
+
if i % 2 == 0:
|
231 |
+
if type(msg) is tuple:
|
232 |
+
msg, image, image_process_mode = msg
|
233 |
+
if type(image) != list:
|
234 |
+
image = [image]
|
235 |
+
for img in image:
|
236 |
+
if not return_path and self.is_image_file(img):
|
237 |
+
img = self.process_image(img, image_process_mode, return_pil=return_pil)
|
238 |
+
else:
|
239 |
+
images.append(img)
|
240 |
+
return images
|
241 |
+
|
242 |
+
def is_image_file(self, filename):
|
243 |
+
image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
|
244 |
+
return any(filename.lower().endswith(ext) for ext in image_extensions)
|
245 |
+
|
246 |
+
def is_video_file(self, filename):
|
247 |
+
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
|
248 |
+
return any(filename.lower().endswith(ext) for ext in video_extensions)
|
249 |
+
|
250 |
+
def to_gradio_chatbot(self):
|
251 |
+
ret = []
|
252 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
253 |
+
if i % 2 == 0:
|
254 |
+
if type(msg) is tuple:
|
255 |
+
msg, image, image_process_mode = msg
|
256 |
+
if type(image) != list:
|
257 |
+
image = [image]
|
258 |
+
if len(image) == 1:
|
259 |
+
msg = "<image>\n" + msg.replace("<image>", "").strip()
|
260 |
+
else:
|
261 |
+
msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
|
262 |
+
|
263 |
+
img_str_list = []
|
264 |
+
for img in image:
|
265 |
+
if self.is_image_file(img):
|
266 |
+
img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
|
267 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" style="max-width: 256px; max-height: 256px; width: auto; height: auto; object-fit: contain;"/>'
|
268 |
+
img_str_list.append(img_str)
|
269 |
+
elif self.is_video_file(img):
|
270 |
+
ret.append(((img,), None))
|
271 |
+
|
272 |
+
msg = msg.strip()
|
273 |
+
img_place_holder = ""
|
274 |
+
for img_str in img_str_list:
|
275 |
+
img_place_holder += f"{img_str}\n\n"
|
276 |
+
|
277 |
+
if len(img_str_list) > 0:
|
278 |
+
msg = f"{img_place_holder}\n\n{msg}"
|
279 |
+
|
280 |
+
if len(msg) > 0:
|
281 |
+
ret.append([msg, None])
|
282 |
+
else:
|
283 |
+
ret.append([msg, None])
|
284 |
+
else:
|
285 |
+
ret[-1][-1] = msg
|
286 |
+
return ret
|
287 |
+
|
288 |
+
def copy(self):
|
289 |
+
return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
|
290 |
+
|
291 |
+
def dict(self):
|
292 |
+
if len(self.get_images()) > 0:
|
293 |
+
return {
|
294 |
+
"system": self.system,
|
295 |
+
"roles": self.roles,
|
296 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
297 |
+
"offset": self.offset,
|
298 |
+
"sep": self.sep,
|
299 |
+
"sep2": self.sep2,
|
300 |
+
}
|
301 |
+
return {
|
302 |
+
"system": self.system,
|
303 |
+
"roles": self.roles,
|
304 |
+
"messages": self.messages,
|
305 |
+
"offset": self.offset,
|
306 |
+
"sep": self.sep,
|
307 |
+
"sep2": self.sep2,
|
308 |
+
}
|
309 |
+
|
310 |
+
|
311 |
+
conv_vicuna_v0 = Conversation(
|
312 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
313 |
+
roles=("Human", "Assistant"),
|
314 |
+
messages=[
|
315 |
+
["Human", "What are the key differences between renewable and non-renewable energy sources?"],
|
316 |
+
[
|
317 |
+
"Assistant",
|
318 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
319 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
320 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
321 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
322 |
+
"renewable and non-renewable energy sources:\n"
|
323 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
324 |
+
"energy sources are finite and will eventually run out.\n"
|
325 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
326 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
327 |
+
"and other negative effects.\n"
|
328 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
329 |
+
"have lower operational costs than non-renewable sources.\n"
|
330 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
331 |
+
"locations than non-renewable sources.\n"
|
332 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
333 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
334 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
335 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
|
336 |
+
],
|
337 |
+
],
|
338 |
+
offset=2,
|
339 |
+
sep_style=SeparatorStyle.SINGLE,
|
340 |
+
sep="###",
|
341 |
+
)
|
342 |
+
|
343 |
+
conv_vicuna_v1 = Conversation(
|
344 |
+
system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
345 |
+
roles=("USER", "ASSISTANT"),
|
346 |
+
version="v1",
|
347 |
+
messages=[],
|
348 |
+
offset=0,
|
349 |
+
sep_style=SeparatorStyle.TWO,
|
350 |
+
sep=" ",
|
351 |
+
sep2="</s>",
|
352 |
+
)
|
353 |
+
|
354 |
+
conv_llama_2 = Conversation(
|
355 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
356 |
+
|
357 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
358 |
+
roles=("USER", "ASSISTANT"),
|
359 |
+
version="llama_v2",
|
360 |
+
messages=[],
|
361 |
+
offset=0,
|
362 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
363 |
+
sep="<s>",
|
364 |
+
sep2="</s>",
|
365 |
+
)
|
366 |
+
|
367 |
+
conv_llava_llama_2 = Conversation(
|
368 |
+
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
369 |
+
roles=("USER", "ASSISTANT"),
|
370 |
+
version="llama_v2",
|
371 |
+
messages=[],
|
372 |
+
offset=0,
|
373 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
374 |
+
sep="<s>",
|
375 |
+
sep2="</s>",
|
376 |
+
)
|
377 |
+
|
378 |
+
conv_llava_llama_3 = Conversation(
|
379 |
+
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
380 |
+
roles=("user", "assistant"),
|
381 |
+
version="llama_v3",
|
382 |
+
messages=[],
|
383 |
+
offset=0,
|
384 |
+
sep="<|eot_id|>",
|
385 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
386 |
+
tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
|
387 |
+
tokenizer=AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct"),
|
388 |
+
stop_token_ids=[128009],
|
389 |
+
)
|
390 |
+
|
391 |
+
conv_mistral_instruct = Conversation(
|
392 |
+
system="",
|
393 |
+
roles=("USER", "ASSISTANT"),
|
394 |
+
version="llama_v2",
|
395 |
+
messages=[],
|
396 |
+
offset=0,
|
397 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
398 |
+
sep="",
|
399 |
+
sep2="</s>",
|
400 |
+
)
|
401 |
+
|
402 |
+
conv_llava_llama_2_simple = Conversation(
|
403 |
+
system="Answer the questions about the visual content that the user provides.",
|
404 |
+
roles=("USER", "ASSISTANT"),
|
405 |
+
version="llama_v2",
|
406 |
+
messages=[],
|
407 |
+
offset=0,
|
408 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
409 |
+
sep="<s>",
|
410 |
+
sep2="</s>",
|
411 |
+
)
|
412 |
+
|
413 |
+
conv_llava_llama_2_mmtag = Conversation(
|
414 |
+
system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
|
415 |
+
roles=("USER", "ASSISTANT"),
|
416 |
+
version="llama_v2_mmtag",
|
417 |
+
messages=[],
|
418 |
+
offset=0,
|
419 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
420 |
+
sep="<s>",
|
421 |
+
sep2="</s>",
|
422 |
+
)
|
423 |
+
|
424 |
+
conv_mpt = Conversation(
|
425 |
+
system="""<|im_start|>system
|
426 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
427 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
428 |
+
version="mpt",
|
429 |
+
messages=[],
|
430 |
+
offset=0,
|
431 |
+
sep_style=SeparatorStyle.MPT,
|
432 |
+
sep="<|im_end|>",
|
433 |
+
)
|
434 |
+
|
435 |
+
conv_qwen = Conversation(
|
436 |
+
system="""<|im_start|>system
|
437 |
+
You are a helpful assistant.""",
|
438 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
439 |
+
version="qwen",
|
440 |
+
messages=[],
|
441 |
+
offset=0,
|
442 |
+
sep_style=SeparatorStyle.CHATML,
|
443 |
+
sep="<|im_end|>",
|
444 |
+
)
|
445 |
+
|
446 |
+
conv_gemma_instruct = Conversation(system="", roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="<end_of_turn>\n")
|
447 |
+
|
448 |
+
conv_llava_plain = Conversation(
|
449 |
+
system="",
|
450 |
+
roles=("", ""),
|
451 |
+
messages=[],
|
452 |
+
offset=0,
|
453 |
+
sep_style=SeparatorStyle.PLAIN,
|
454 |
+
sep="\n",
|
455 |
+
)
|
456 |
+
|
457 |
+
conv_llava_v0 = Conversation(
|
458 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
459 |
+
roles=("Human", "Assistant"),
|
460 |
+
messages=[],
|
461 |
+
offset=0,
|
462 |
+
sep_style=SeparatorStyle.SINGLE,
|
463 |
+
sep="###",
|
464 |
+
)
|
465 |
+
|
466 |
+
conv_llava_v0_mmtag = Conversation(
|
467 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
468 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
469 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
470 |
+
roles=("Human", "Assistant"),
|
471 |
+
messages=[],
|
472 |
+
offset=0,
|
473 |
+
sep_style=SeparatorStyle.SINGLE,
|
474 |
+
sep="###",
|
475 |
+
version="v0_mmtag",
|
476 |
+
)
|
477 |
+
|
478 |
+
conv_llava_v1 = Conversation(
|
479 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
480 |
+
roles=("USER", "ASSISTANT"),
|
481 |
+
version="v1",
|
482 |
+
messages=[],
|
483 |
+
offset=0,
|
484 |
+
sep_style=SeparatorStyle.TWO,
|
485 |
+
sep=" ",
|
486 |
+
sep2="</s>",
|
487 |
+
)
|
488 |
+
|
489 |
+
conv_llava_v1_mmtag = Conversation(
|
490 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
491 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
492 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
493 |
+
roles=("USER", "ASSISTANT"),
|
494 |
+
messages=[],
|
495 |
+
offset=0,
|
496 |
+
sep_style=SeparatorStyle.TWO,
|
497 |
+
sep=" ",
|
498 |
+
sep2="</s>",
|
499 |
+
version="v1_mmtag",
|
500 |
+
)
|
501 |
+
|
502 |
+
conv_mistral_orca = Conversation(
|
503 |
+
system="""<|im_start|>system
|
504 |
+
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
|
505 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
506 |
+
version="mpt",
|
507 |
+
messages=[],
|
508 |
+
offset=0,
|
509 |
+
sep_style=SeparatorStyle.MPT,
|
510 |
+
sep="<|im_end|>",
|
511 |
+
)
|
512 |
+
|
513 |
+
conv_mistral_zephyr = Conversation(
|
514 |
+
system="""<|system|>
|
515 |
+
You are a helpful AI assistant.""",
|
516 |
+
roles=("<|user|>\n", "<|assistant|>\n"),
|
517 |
+
version="mpt",
|
518 |
+
messages=[],
|
519 |
+
offset=0,
|
520 |
+
sep_style=SeparatorStyle.MPT,
|
521 |
+
sep="</s>",
|
522 |
+
)
|
523 |
+
|
524 |
+
conv_mistral_direct = Conversation(
|
525 |
+
system="""<|im_start|>system
|
526 |
+
Answer the questions.""",
|
527 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
528 |
+
version="mpt",
|
529 |
+
messages=[],
|
530 |
+
offset=0,
|
531 |
+
sep_style=SeparatorStyle.MPT,
|
532 |
+
sep="<|im_end|>",
|
533 |
+
)
|
534 |
+
|
535 |
+
conv_chatml_direct = Conversation(
|
536 |
+
system="""<|im_start|>system
|
537 |
+
Answer the questions.""",
|
538 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
539 |
+
version="mpt",
|
540 |
+
messages=[],
|
541 |
+
offset=0,
|
542 |
+
sep_style=SeparatorStyle.MPT,
|
543 |
+
sep="<|im_end|>",
|
544 |
+
)
|
545 |
+
|
546 |
+
default_conversation = conv_vicuna_v0
|
547 |
+
conv_templates = {
|
548 |
+
"default": conv_vicuna_v0,
|
549 |
+
"v0": conv_vicuna_v0,
|
550 |
+
"v1": conv_vicuna_v1,
|
551 |
+
"vicuna_v1": conv_vicuna_v1,
|
552 |
+
"llama_2": conv_llama_2,
|
553 |
+
"mistral_instruct": conv_mistral_instruct,
|
554 |
+
"mistral_orca": conv_mistral_orca,
|
555 |
+
"mistral_zephyr": conv_mistral_zephyr,
|
556 |
+
"mistral_direct": conv_mistral_direct,
|
557 |
+
"plain": conv_llava_plain,
|
558 |
+
"v0_plain": conv_llava_plain,
|
559 |
+
"chatml_direct": conv_chatml_direct,
|
560 |
+
"llava_v0": conv_llava_v0,
|
561 |
+
"llava_v0_mmtag": conv_llava_v0_mmtag,
|
562 |
+
"llava_v1": conv_llava_v1,
|
563 |
+
"llava_v1_mmtag": conv_llava_v1_mmtag,
|
564 |
+
"llava_llama_2": conv_llava_llama_2,
|
565 |
+
"llava_llama_3": conv_llava_llama_3,
|
566 |
+
"llava_llama_2_simple": conv_llava_llama_2_simple,
|
567 |
+
"llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
|
568 |
+
"llava_mistral_instruct": conv_mistral_instruct,
|
569 |
+
"mpt": conv_mpt,
|
570 |
+
"qwen_1_5": conv_qwen,
|
571 |
+
"qwen_2": conv_qwen,
|
572 |
+
"gemma_instruct": conv_gemma_instruct,
|
573 |
+
}
|
574 |
+
|
575 |
+
|
576 |
+
if __name__ == "__main__":
|
577 |
+
print(default_conversation.get_prompt())
|
llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc
ADDED
Binary file (7.26 kB). View file
|
|
llava/eval/__pycache__/model_vqa.cpython-39.pyc
ADDED
Binary file (6.58 kB). View file
|
|
llava/eval/evaluate_interleave.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from rouge import Rouge
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import numpy as np
|
7 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
8 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
9 |
+
|
10 |
+
|
11 |
+
spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
|
12 |
+
image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
|
13 |
+
visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
|
14 |
+
visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
|
15 |
+
text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
|
16 |
+
multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]
|
17 |
+
|
18 |
+
puzzle = ["RAVEN"]
|
19 |
+
nlrv2 = ["NLVR2_Mantis"]
|
20 |
+
qbench = ["QBench"]
|
21 |
+
|
22 |
+
class Eval:
|
23 |
+
def __init__(self):
|
24 |
+
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
|
25 |
+
self.commaStrip = re.compile("(\d)(\,)(\d)")
|
26 |
+
self.punct = [
|
27 |
+
";",
|
28 |
+
r"/",
|
29 |
+
"[",
|
30 |
+
"]",
|
31 |
+
'"',
|
32 |
+
"{",
|
33 |
+
"}",
|
34 |
+
"(",
|
35 |
+
")",
|
36 |
+
"=",
|
37 |
+
"+",
|
38 |
+
"\\",
|
39 |
+
"_",
|
40 |
+
"-",
|
41 |
+
">",
|
42 |
+
"<",
|
43 |
+
"@",
|
44 |
+
"`",
|
45 |
+
",",
|
46 |
+
"?",
|
47 |
+
"!",
|
48 |
+
]
|
49 |
+
|
50 |
+
def processPunctuation(self, inText):
|
51 |
+
outText = inText
|
52 |
+
for p in self.punct:
|
53 |
+
if (p + " " in inText or " " + p in inText) or (
|
54 |
+
re.search(self.commaStrip, inText) != None
|
55 |
+
):
|
56 |
+
outText = outText.replace(p, "")
|
57 |
+
else:
|
58 |
+
outText = outText.replace(p, " ")
|
59 |
+
outText = self.periodStrip.sub("", outText, re.UNICODE)
|
60 |
+
return outText
|
61 |
+
|
62 |
+
def process(self, answer):
|
63 |
+
answer = answer.replace("\n", " ")
|
64 |
+
answer = answer.replace("\t", " ")
|
65 |
+
answer = answer.strip()
|
66 |
+
answer = self.processPunctuation(answer)
|
67 |
+
answer = answer.strip('\'')
|
68 |
+
answer = answer.strip('\"')
|
69 |
+
answer = answer.strip(')')
|
70 |
+
answer = answer.strip('(')
|
71 |
+
answer = answer.strip().lower()
|
72 |
+
return answer
|
73 |
+
|
74 |
+
def evaluate_rouge(self,preds):
|
75 |
+
rouge = Rouge()
|
76 |
+
acc = {'f': []}
|
77 |
+
eval_list = []
|
78 |
+
for i, res in enumerate(preds):
|
79 |
+
sample_id = res['sample_id']
|
80 |
+
# print(sample_id)
|
81 |
+
gt_ans = self.process(res["gt_response"])
|
82 |
+
pred_ans = self.process(res["pred_response"])
|
83 |
+
# assert gt_ans != ''
|
84 |
+
|
85 |
+
if gt_ans == '':
|
86 |
+
continue
|
87 |
+
|
88 |
+
if pred_ans == '':
|
89 |
+
s = 0
|
90 |
+
else:
|
91 |
+
if len(pred_ans) > 512:
|
92 |
+
pred_ans = pred_ans[0: 512]
|
93 |
+
s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
|
94 |
+
acc['f'].append(s)
|
95 |
+
eval_list.append({'id':str(sample_id),'score':str(round(s,3))})
|
96 |
+
results = {'Rouge-L f': np.mean(acc['f'])}
|
97 |
+
return results,eval_list
|
98 |
+
|
99 |
+
|
100 |
+
def judge_multi_choice(self,sample):
|
101 |
+
sample_id = sample['sample_id']
|
102 |
+
gt_ans = sample["gt_response"]
|
103 |
+
pred_ans = sample["pred_response"]
|
104 |
+
|
105 |
+
if ":" in pred_ans:
|
106 |
+
a_list = pred_ans.split(":")
|
107 |
+
a_list = [a.strip() for a in a_list ]
|
108 |
+
for a in a_list:
|
109 |
+
if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
|
110 |
+
pred_ans = a
|
111 |
+
|
112 |
+
if pred_ans == gt_ans:
|
113 |
+
return 1
|
114 |
+
else:
|
115 |
+
return 0
|
116 |
+
|
117 |
+
def process_sample(self,sample):
|
118 |
+
sample["gt_response"] = self.process(sample["gt_response"])
|
119 |
+
sample["pred_response"] = self.process(sample["pred_response"])
|
120 |
+
|
121 |
+
def evaluate_multichoice(self, preditions):
|
122 |
+
correct = 0
|
123 |
+
eval_list = []
|
124 |
+
for i, sample in enumerate(preditions):
|
125 |
+
self.process_sample(sample)
|
126 |
+
score = self.judge_multi_choice(sample)
|
127 |
+
sample_id = sample['sample_id']
|
128 |
+
sample['result'] = score
|
129 |
+
eval_list.append({'id':str(sample_id),'score':str(score)})
|
130 |
+
correct+=score
|
131 |
+
return {'Accuracy':correct/len(preditions)},eval_list
|
132 |
+
|
133 |
+
def evaluate_multi_choice_image(self,preditions):
|
134 |
+
correct = 0
|
135 |
+
eval_list = []
|
136 |
+
for i,sample in enumerate(preditions):
|
137 |
+
gt_ans = self.process(sample["gt_response"])
|
138 |
+
pred_ans = self.process(sample["pred_response"])
|
139 |
+
sample_id = sample['sample_id']
|
140 |
+
|
141 |
+
if ":" in pred_ans:
|
142 |
+
a_list = pred_ans.split(":")
|
143 |
+
a_list = [a.strip() for a in a_list ]
|
144 |
+
for a in a_list:
|
145 |
+
if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
|
146 |
+
pred_ans = a
|
147 |
+
|
148 |
+
if gt_ans == pred_ans:
|
149 |
+
score = 1
|
150 |
+
else:
|
151 |
+
score = 0
|
152 |
+
sample_id = sample['sample_id']
|
153 |
+
sample['result'] = score
|
154 |
+
eval_list.append({'id':str(sample_id),'score':str(score)})
|
155 |
+
correct+=score
|
156 |
+
return {'Accuracy':correct/len(preditions)},eval_list
|
157 |
+
|
158 |
+
|
159 |
+
if __name__ == "__main__":
|
160 |
+
parser = argparse.ArgumentParser()
|
161 |
+
parser.add_argument('--result-dir', type=str, required=True)
|
162 |
+
|
163 |
+
args = parser.parse_args()
|
164 |
+
|
165 |
+
result_file = os.path.join(args.result_dir, "result.jsonl")
|
166 |
+
|
167 |
+
if not os.path.exists(result_file):
|
168 |
+
print('No prediction file found')
|
169 |
+
exit(0)
|
170 |
+
with open(result_file, 'r') as f:
|
171 |
+
preds_all = [json.loads(line) for line in f]
|
172 |
+
|
173 |
+
preds_all_dict = dict()
|
174 |
+
for pred in preds_all:
|
175 |
+
if pred["dataset"] not in preds_all_dict:
|
176 |
+
preds_all_dict[pred["dataset"]] = list()
|
177 |
+
preds_all_dict[pred["dataset"]].append(pred)
|
178 |
+
|
179 |
+
image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
|
180 |
+
E = Eval()
|
181 |
+
|
182 |
+
eval_result_list = dict()
|
183 |
+
eval_result_list_detail = dict()
|
184 |
+
|
185 |
+
for dataset in preds_all_dict:
|
186 |
+
|
187 |
+
preds = preds_all_dict[dataset]
|
188 |
+
question_type = preds[0]["question_type"]
|
189 |
+
|
190 |
+
if question_type == 'open-ended':
|
191 |
+
eval_result, eval_list = E.evaluate_rouge(preds)
|
192 |
+
|
193 |
+
elif question_type == 'multi-choice' or dataset == 'nlrv2':
|
194 |
+
if dataset in image_choice_dataset_list:
|
195 |
+
eval_result, eval_list = E.evaluate_multi_choice_image(preds)
|
196 |
+
else:
|
197 |
+
eval_result, eval_list = E.evaluate_multichoice(preds)
|
198 |
+
|
199 |
+
else:
|
200 |
+
eval_result = 'Dataset not supported'
|
201 |
+
print('Dataset not supported')
|
202 |
+
exit(0)
|
203 |
+
|
204 |
+
print(dataset, end = ': ')
|
205 |
+
print(eval_result)
|
206 |
+
|
207 |
+
eval_result_list[dataset] = eval_result
|
208 |
+
eval_result_list_detail[dataset] = eval_list
|
209 |
+
|
210 |
+
os.makedirs(args.result_dir, exist_ok=True)
|
211 |
+
with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
|
212 |
+
json.dump(eval_result_list, f, indent=4)
|
213 |
+
|
214 |
+
with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f:
|
215 |
+
json.dump(eval_result_list_detail, f, indent=4)
|
216 |
+
|
217 |
+
|
218 |
+
eval_cat_list = dict()
|
219 |
+
print()
|
220 |
+
|
221 |
+
# spot_the_diff
|
222 |
+
score = 0
|
223 |
+
count = 0
|
224 |
+
for dataset in eval_result_list:
|
225 |
+
if dataset in spot_the_diff:
|
226 |
+
count += 1
|
227 |
+
score += list(eval_result_list[dataset].values())[0]
|
228 |
+
if count > 0:
|
229 |
+
score /= count
|
230 |
+
eval_cat_list["spot_the_diff"] = score
|
231 |
+
print("spot_the_diff", end = ': ')
|
232 |
+
print('{:.2f}'.format(100 * score))
|
233 |
+
|
234 |
+
# image_edit_instruct
|
235 |
+
score = 0
|
236 |
+
count = 0
|
237 |
+
for dataset in eval_result_list:
|
238 |
+
if dataset in image_edit_instruct:
|
239 |
+
count += 1
|
240 |
+
score += list(eval_result_list[dataset].values())[0]
|
241 |
+
if count > 0:
|
242 |
+
score /= count
|
243 |
+
eval_cat_list["image_edit_instruct"] = score
|
244 |
+
print("image_edit_instruct", end = ': ')
|
245 |
+
print('{:.2f}'.format(100 * score))
|
246 |
+
|
247 |
+
# visual_story_telling
|
248 |
+
score = 0
|
249 |
+
count = 0
|
250 |
+
for dataset in eval_result_list:
|
251 |
+
if dataset in visual_story_telling:
|
252 |
+
count += 1
|
253 |
+
score += list(eval_result_list[dataset].values())[0]
|
254 |
+
if count > 0:
|
255 |
+
score /= count
|
256 |
+
eval_cat_list["visual_story_telling"] = score
|
257 |
+
print("visual_story_telling", end = ': ')
|
258 |
+
print('{:.2f}'.format(100 * score))
|
259 |
+
|
260 |
+
# visual_cloze
|
261 |
+
score = 0
|
262 |
+
count = 0
|
263 |
+
for dataset in eval_result_list:
|
264 |
+
if dataset in visual_cloze:
|
265 |
+
count += 1
|
266 |
+
score += list(eval_result_list[dataset].values())[0]
|
267 |
+
if count > 0:
|
268 |
+
score /= count
|
269 |
+
eval_cat_list["visual_cloze"] = score
|
270 |
+
print("visual_cloze", end = ': ')
|
271 |
+
print('{:.2f}'.format(100 * score))
|
272 |
+
|
273 |
+
# text_rich_vqa
|
274 |
+
score = 0
|
275 |
+
count = 0
|
276 |
+
for dataset in eval_result_list:
|
277 |
+
if dataset in text_rich_vqa:
|
278 |
+
count += 1
|
279 |
+
score += list(eval_result_list[dataset].values())[0]
|
280 |
+
if count > 0:
|
281 |
+
score /= count
|
282 |
+
eval_cat_list["text_rich_vqa"] = score
|
283 |
+
print("text_rich_vqa", end = ': ')
|
284 |
+
print('{:.2f}'.format(100 * score))
|
285 |
+
|
286 |
+
# multi_image_vqa
|
287 |
+
score = 0
|
288 |
+
count = 0
|
289 |
+
for dataset in eval_result_list:
|
290 |
+
if dataset in multi_image_vqa:
|
291 |
+
count += 1
|
292 |
+
score += list(eval_result_list[dataset].values())[0]
|
293 |
+
if count > 0:
|
294 |
+
score /= count
|
295 |
+
eval_cat_list["multi_image_vqa"] = score
|
296 |
+
print("multi_image_vqa", end = ': ')
|
297 |
+
print('{:.2f}'.format(100 * score))
|
298 |
+
|
299 |
+
# puzzle
|
300 |
+
score = 0
|
301 |
+
count = 0
|
302 |
+
for dataset in eval_result_list:
|
303 |
+
if dataset in puzzle:
|
304 |
+
count += 1
|
305 |
+
score += list(eval_result_list[dataset].values())[0]
|
306 |
+
if count > 0:
|
307 |
+
score /= count
|
308 |
+
eval_cat_list["puzzle"] = score
|
309 |
+
print("puzzle", end = ': ')
|
310 |
+
print('{:.2f}'.format(100 * score))
|
311 |
+
|
312 |
+
# nlrv2
|
313 |
+
score = 0
|
314 |
+
count = 0
|
315 |
+
for dataset in eval_result_list:
|
316 |
+
if dataset in nlrv2:
|
317 |
+
count += 1
|
318 |
+
score += list(eval_result_list[dataset].values())[0]
|
319 |
+
if count > 0:
|
320 |
+
score /= count
|
321 |
+
eval_cat_list["nlrv2"] = score
|
322 |
+
print("nlrv2", end = ': ')
|
323 |
+
print('{:.2f}'.format(100 * score))
|
324 |
+
|
325 |
+
# qbench
|
326 |
+
score = 0
|
327 |
+
count = 0
|
328 |
+
for dataset in eval_result_list:
|
329 |
+
if dataset in qbench:
|
330 |
+
count += 1
|
331 |
+
score += list(eval_result_list[dataset].values())[0]
|
332 |
+
if count > 0:
|
333 |
+
score /= count
|
334 |
+
eval_cat_list["qbench"] = score
|
335 |
+
print("qbench", end = ': ')
|
336 |
+
print('{:.2f}'.format(100 * score))
|
337 |
+
|
338 |
+
with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f:
|
339 |
+
json.dump(eval_cat_list, f, indent=4)
|
llava/eval/model_vqa.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from tqdm import tqdm
|
6 |
+
import shortuuid
|
7 |
+
|
8 |
+
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
from llava.conversation import conv_templates, SeparatorStyle
|
10 |
+
from llava.model.builder import load_pretrained_model
|
11 |
+
from llava.utils import disable_torch_init
|
12 |
+
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
13 |
+
|
14 |
+
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
|
15 |
+
from typing import Dict, Optional, Sequence, List
|
16 |
+
import transformers
|
17 |
+
import re
|
18 |
+
|
19 |
+
from PIL import Image
|
20 |
+
import math
|
21 |
+
|
22 |
+
|
23 |
+
def split_list(lst, n):
|
24 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
25 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
26 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
27 |
+
|
28 |
+
|
29 |
+
def get_chunk(lst, n, k):
|
30 |
+
chunks = split_list(lst, n)
|
31 |
+
return chunks[k]
|
32 |
+
|
33 |
+
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
|
34 |
+
roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
|
35 |
+
|
36 |
+
im_start, im_end = tokenizer.additional_special_tokens_ids
|
37 |
+
nl_tokens = tokenizer("\n").input_ids
|
38 |
+
_system = tokenizer("system").input_ids + nl_tokens
|
39 |
+
_user = tokenizer("user").input_ids + nl_tokens
|
40 |
+
_assistant = tokenizer("assistant").input_ids + nl_tokens
|
41 |
+
|
42 |
+
# Apply prompt templates
|
43 |
+
input_ids, targets = [], []
|
44 |
+
|
45 |
+
source = sources
|
46 |
+
if roles[source[0]["from"]] != roles["human"]:
|
47 |
+
source = source[1:]
|
48 |
+
|
49 |
+
input_id, target = [], []
|
50 |
+
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
|
51 |
+
input_id += system
|
52 |
+
target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
|
53 |
+
assert len(input_id) == len(target)
|
54 |
+
for j, sentence in enumerate(source):
|
55 |
+
role = roles[sentence["from"]]
|
56 |
+
if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
|
57 |
+
num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
|
58 |
+
texts = sentence["value"].split('<image>')
|
59 |
+
_input_id = tokenizer(role).input_ids + nl_tokens
|
60 |
+
for i,text in enumerate(texts):
|
61 |
+
_input_id += tokenizer(text).input_ids
|
62 |
+
if i<len(texts)-1:
|
63 |
+
_input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
|
64 |
+
_input_id += [im_end] + nl_tokens
|
65 |
+
assert sum([i==IMAGE_TOKEN_INDEX for i in _input_id])==num_image
|
66 |
+
else:
|
67 |
+
if sentence["value"] is None:
|
68 |
+
_input_id = tokenizer(role).input_ids + nl_tokens
|
69 |
+
else:
|
70 |
+
_input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
|
71 |
+
input_id += _input_id
|
72 |
+
if role == "<|im_start|>user":
|
73 |
+
_target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
|
74 |
+
elif role == "<|im_start|>assistant":
|
75 |
+
_target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
|
76 |
+
else:
|
77 |
+
raise NotImplementedError
|
78 |
+
target += _target
|
79 |
+
|
80 |
+
input_ids.append(input_id)
|
81 |
+
targets.append(target)
|
82 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
83 |
+
targets = torch.tensor(targets, dtype=torch.long)
|
84 |
+
return input_ids
|
85 |
+
|
86 |
+
def eval_model(args):
|
87 |
+
|
88 |
+
# Model
|
89 |
+
disable_torch_init()
|
90 |
+
model_path = os.path.expanduser(args.model_path)
|
91 |
+
model_name = get_model_name_from_path(model_path)
|
92 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
|
93 |
+
|
94 |
+
# Data
|
95 |
+
with open(os.path.expanduser(args.question_file)) as f:
|
96 |
+
questions = json.load(f)
|
97 |
+
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
98 |
+
answers_file = os.path.expanduser(args.answers_file)
|
99 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
100 |
+
ans_file = open(answers_file, "w")
|
101 |
+
|
102 |
+
for line in tqdm(questions):
|
103 |
+
idx = line["sample_id"]
|
104 |
+
question_type = line["metadata"]["question_type"]
|
105 |
+
dataset_name = line["metadata"]["dataset"]
|
106 |
+
gt = line["conversations"][1]["value"]
|
107 |
+
|
108 |
+
image_files = line["image"]
|
109 |
+
qs = line["conversations"][0]["value"]
|
110 |
+
cur_prompt = args.extra_prompt + qs
|
111 |
+
|
112 |
+
args.conv_mode = "qwen_1_5"
|
113 |
+
|
114 |
+
conv = conv_templates[args.conv_mode].copy()
|
115 |
+
conv.append_message(conv.roles[0], qs)
|
116 |
+
conv.append_message(conv.roles[1], None)
|
117 |
+
prompt = conv.get_prompt()
|
118 |
+
|
119 |
+
input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
|
120 |
+
img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX)
|
121 |
+
|
122 |
+
image_tensors = []
|
123 |
+
for image_file in image_files:
|
124 |
+
image = Image.open(os.path.join(args.image_folder, image_file))
|
125 |
+
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
|
126 |
+
image_tensors.append(image_tensor.half().cuda())
|
127 |
+
# image_tensors = torch.cat(image_tensors, dim=0)
|
128 |
+
|
129 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
130 |
+
keywords = [stop_str]
|
131 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
132 |
+
|
133 |
+
with torch.inference_mode():
|
134 |
+
output_ids = model.generate(
|
135 |
+
input_ids,
|
136 |
+
images=image_tensors,
|
137 |
+
do_sample=True if args.temperature > 0 else False,
|
138 |
+
temperature=args.temperature,
|
139 |
+
top_p=args.top_p,
|
140 |
+
num_beams=args.num_beams,
|
141 |
+
# no_repeat_ngram_size=3,
|
142 |
+
max_new_tokens=1024,
|
143 |
+
use_cache=True)
|
144 |
+
|
145 |
+
|
146 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
147 |
+
outputs = outputs.strip()
|
148 |
+
if outputs.endswith(stop_str):
|
149 |
+
outputs = outputs[:-len(stop_str)]
|
150 |
+
outputs = outputs.strip()
|
151 |
+
|
152 |
+
ans_id = shortuuid.uuid()
|
153 |
+
ans_file.write(json.dumps({
|
154 |
+
"dataset": dataset_name,
|
155 |
+
"sample_id": idx,
|
156 |
+
"prompt": cur_prompt,
|
157 |
+
"pred_response": outputs,
|
158 |
+
"gt_response": gt,
|
159 |
+
"shortuuid": ans_id,
|
160 |
+
"model_id": model_name,
|
161 |
+
"question_type": question_type,
|
162 |
+
}) + "\n")
|
163 |
+
ans_file.flush()
|
164 |
+
|
165 |
+
if len(line["conversations"]) > 2:
|
166 |
+
|
167 |
+
for i in range(2, len(line["conversations"]), 2):
|
168 |
+
input_ids = torch.cat((input_ids, output_ids), dim=1)
|
169 |
+
|
170 |
+
gt = line["conversations"][i + 1]["value"]
|
171 |
+
qs = line["conversations"][i]["value"]
|
172 |
+
cur_prompt = args.extra_prompt + qs
|
173 |
+
|
174 |
+
args.conv_mode = "qwen_1_5"
|
175 |
+
|
176 |
+
conv = conv_templates[args.conv_mode].copy()
|
177 |
+
conv.append_message(conv.roles[0], qs)
|
178 |
+
conv.append_message(conv.roles[1], None)
|
179 |
+
prompt = conv.get_prompt()
|
180 |
+
|
181 |
+
input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
|
182 |
+
input_ids = torch.cat((input_ids, input_ids_new), dim=1)
|
183 |
+
img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX)
|
184 |
+
|
185 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
186 |
+
keywords = [stop_str]
|
187 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
188 |
+
|
189 |
+
with torch.inference_mode():
|
190 |
+
output_ids = model.generate(
|
191 |
+
input_ids,
|
192 |
+
images=image_tensors,
|
193 |
+
do_sample=True if args.temperature > 0 else False,
|
194 |
+
temperature=args.temperature,
|
195 |
+
top_p=args.top_p,
|
196 |
+
num_beams=args.num_beams,
|
197 |
+
# no_repeat_ngram_size=3,
|
198 |
+
max_new_tokens=1024,
|
199 |
+
use_cache=True)
|
200 |
+
|
201 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
202 |
+
outputs = outputs.strip()
|
203 |
+
if outputs.endswith(stop_str):
|
204 |
+
outputs = outputs[:-len(stop_str)]
|
205 |
+
outputs = outputs.strip()
|
206 |
+
|
207 |
+
ans_id = shortuuid.uuid()
|
208 |
+
ans_file.write(json.dumps({
|
209 |
+
"dataset": dataset_name,
|
210 |
+
"sample_id": idx,
|
211 |
+
"prompt": cur_prompt,
|
212 |
+
"pred_response": outputs,
|
213 |
+
"gt_response": gt,
|
214 |
+
"shortuuid": ans_id,
|
215 |
+
"model_id": model_name,
|
216 |
+
"question_type": question_type,
|
217 |
+
}) + "\n")
|
218 |
+
ans_file.flush()
|
219 |
+
|
220 |
+
|
221 |
+
ans_file.close()
|
222 |
+
|
223 |
+
if __name__ == "__main__":
|
224 |
+
parser = argparse.ArgumentParser()
|
225 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
226 |
+
parser.add_argument("--model-base", type=str, default=None)
|
227 |
+
parser.add_argument("--image-folder", type=str, default="")
|
228 |
+
parser.add_argument("--extra-prompt", type=str, default="")
|
229 |
+
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
230 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
231 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
232 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
233 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
234 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
235 |
+
parser.add_argument("--top_p", type=float, default=None)
|
236 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
237 |
+
parser.add_argument("--test_size", type=int, default=10000000)
|
238 |
+
args = parser.parse_args()
|
239 |
+
|
240 |
+
eval_model(args)
|
llava/mm_utils.py
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
import math
|
5 |
+
import ast
|
6 |
+
import re
|
7 |
+
import torch
|
8 |
+
from transformers import StoppingCriteria
|
9 |
+
from llava.constants import IMAGE_TOKEN_INDEX
|
10 |
+
|
11 |
+
|
12 |
+
def resize_and_center_crop(image, shortest_edge_length):
|
13 |
+
# Calculate new dimensions and resize
|
14 |
+
aspect_ratio = float(image.width) / float(image.height)
|
15 |
+
if aspect_ratio > 1:
|
16 |
+
new_width = int(shortest_edge_length * aspect_ratio)
|
17 |
+
new_height = shortest_edge_length
|
18 |
+
else:
|
19 |
+
new_width = shortest_edge_length
|
20 |
+
new_height = int(shortest_edge_length / aspect_ratio)
|
21 |
+
resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
|
22 |
+
|
23 |
+
# Calculate the position and perform the center crop
|
24 |
+
left = (new_width - shortest_edge_length) / 2
|
25 |
+
top = (new_height - shortest_edge_length) / 2
|
26 |
+
right = (new_width + shortest_edge_length) / 2
|
27 |
+
bottom = (new_height + shortest_edge_length) / 2
|
28 |
+
cropped_image = resized_image.crop((left, top, right, bottom))
|
29 |
+
|
30 |
+
return cropped_image
|
31 |
+
|
32 |
+
|
33 |
+
def auto_pad_images(image, grid_params):
|
34 |
+
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
|
35 |
+
assert len(grid_params) > 0, "Grid parameters should not be empty"
|
36 |
+
|
37 |
+
# Step 1: Calculate and find the closest aspect ratio
|
38 |
+
input_width, input_height = image.size
|
39 |
+
input_aspect_ratio = input_width / input_height
|
40 |
+
candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
|
41 |
+
closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
|
42 |
+
|
43 |
+
candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
|
44 |
+
|
45 |
+
target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
|
46 |
+
|
47 |
+
resize_width, resize_height = target_resolution
|
48 |
+
if input_width > input_height:
|
49 |
+
resize_height = int(resize_width / input_aspect_ratio)
|
50 |
+
else:
|
51 |
+
resize_width = int(resize_height * input_aspect_ratio)
|
52 |
+
resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS)
|
53 |
+
|
54 |
+
# Step 5: Pad the resized image if necessary to match the target resolution
|
55 |
+
pad_width = target_resolution[0] - resize_width
|
56 |
+
pad_height = target_resolution[1] - resize_height
|
57 |
+
padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
|
58 |
+
padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
|
59 |
+
|
60 |
+
return padded_image
|
61 |
+
|
62 |
+
|
63 |
+
def extract_patches(image, patch_size, overlap_ratio):
|
64 |
+
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
|
65 |
+
assert patch_size > 0, "Patch size should be greater than 0"
|
66 |
+
assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
|
67 |
+
|
68 |
+
W, H = image.size
|
69 |
+
patches = []
|
70 |
+
|
71 |
+
stride = int(patch_size * (1 - overlap_ratio))
|
72 |
+
|
73 |
+
num_patches_y = (H - patch_size) // stride + 1
|
74 |
+
num_patches_x = (W - patch_size) // stride + 1
|
75 |
+
|
76 |
+
y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
|
77 |
+
x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
|
78 |
+
|
79 |
+
for y in range(y_start, y_start + num_patches_y * stride, stride):
|
80 |
+
for x in range(x_start, x_start + num_patches_x * stride, stride):
|
81 |
+
patch = image.crop((x, y, x + patch_size, y + patch_size))
|
82 |
+
patches.append(patch)
|
83 |
+
|
84 |
+
return patches
|
85 |
+
|
86 |
+
|
87 |
+
def process_highres_image_crop_split(image, data_args, processor=None):
|
88 |
+
crop_resolution = data_args.image_crop_resolution
|
89 |
+
split_resolution = data_args.image_split_resolution
|
90 |
+
if processor is None:
|
91 |
+
processor = data_args.image_processor
|
92 |
+
image_crop = resize_and_center_crop(image, crop_resolution)
|
93 |
+
image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
|
94 |
+
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
|
95 |
+
return torch.stack(image_patches, dim=0)
|
96 |
+
|
97 |
+
|
98 |
+
def process_highres_image(image, processor, grid_pinpoints):
|
99 |
+
grid_params = [int(x) for x in grid_pinpoints.split(",")]
|
100 |
+
width_height = max(image.size)
|
101 |
+
fit_grid_params = [x for x in grid_params if x >= width_height]
|
102 |
+
if len(fit_grid_params) == 0:
|
103 |
+
select_size = max(grid_params)
|
104 |
+
else:
|
105 |
+
select_size = min(fit_grid_params)
|
106 |
+
# FIXME: always select the 448
|
107 |
+
select_size = max(grid_params)
|
108 |
+
image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
|
109 |
+
|
110 |
+
# FIXME: this seems to be a bug that it always resizes instead of padding
|
111 |
+
image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
|
112 |
+
image_padded = image_padded.resize((select_size, select_size))
|
113 |
+
image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
|
114 |
+
image_patches = [image_original_resize] + image_patches
|
115 |
+
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
|
116 |
+
return torch.stack(image_patches, dim=0)
|
117 |
+
|
118 |
+
|
119 |
+
def select_best_resolution(original_size, possible_resolutions):
|
120 |
+
"""
|
121 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
125 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
tuple: The best fit resolution in the format (width, height).
|
129 |
+
"""
|
130 |
+
original_width, original_height = original_size
|
131 |
+
best_fit = None
|
132 |
+
max_effective_resolution = 0
|
133 |
+
min_wasted_resolution = float("inf")
|
134 |
+
|
135 |
+
for width, height in possible_resolutions:
|
136 |
+
# Calculate the downscaled size to keep the aspect ratio
|
137 |
+
scale = min(width / original_width, height / original_height)
|
138 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
139 |
+
|
140 |
+
# Calculate effective and wasted resolutions
|
141 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
142 |
+
wasted_resolution = (width * height) - effective_resolution
|
143 |
+
|
144 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
145 |
+
max_effective_resolution = effective_resolution
|
146 |
+
min_wasted_resolution = wasted_resolution
|
147 |
+
best_fit = (width, height)
|
148 |
+
|
149 |
+
return best_fit
|
150 |
+
|
151 |
+
|
152 |
+
def resize_and_pad_image(image, target_resolution):
|
153 |
+
"""
|
154 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
image (PIL.Image.Image): The input image.
|
158 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
PIL.Image.Image: The resized and padded image.
|
162 |
+
"""
|
163 |
+
original_width, original_height = image.size
|
164 |
+
target_width, target_height = target_resolution
|
165 |
+
|
166 |
+
# Determine which dimension (width or height) to fill
|
167 |
+
scale_w = target_width / original_width
|
168 |
+
scale_h = target_height / original_height
|
169 |
+
|
170 |
+
if scale_w < scale_h:
|
171 |
+
# Width will be filled completely
|
172 |
+
new_width = target_width
|
173 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
174 |
+
else:
|
175 |
+
# Height will be filled completely
|
176 |
+
new_height = target_height
|
177 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
178 |
+
|
179 |
+
# Resize the image
|
180 |
+
resized_image = image.resize((new_width, new_height))
|
181 |
+
|
182 |
+
# Create a new image with the target size and paste the resized image onto it
|
183 |
+
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
|
184 |
+
paste_x = (target_width - new_width) // 2
|
185 |
+
paste_y = (target_height - new_height) // 2
|
186 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
187 |
+
|
188 |
+
return new_image
|
189 |
+
|
190 |
+
|
191 |
+
def divide_to_patches(image, patch_size):
|
192 |
+
"""
|
193 |
+
Divides an image into patches of a specified size.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
image (PIL.Image.Image): The input image.
|
197 |
+
patch_size (int): The size of each patch.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
201 |
+
"""
|
202 |
+
patches = []
|
203 |
+
width, height = image.size
|
204 |
+
for i in range(0, height, patch_size):
|
205 |
+
for j in range(0, width, patch_size):
|
206 |
+
box = (j, i, j + patch_size, i + patch_size)
|
207 |
+
patch = image.crop(box)
|
208 |
+
patches.append(patch)
|
209 |
+
|
210 |
+
return patches
|
211 |
+
|
212 |
+
|
213 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
214 |
+
"""
|
215 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
216 |
+
|
217 |
+
Args:
|
218 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
219 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
220 |
+
patch_size (int): The size of each image patch.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
224 |
+
"""
|
225 |
+
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
|
226 |
+
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
|
227 |
+
# Use regex to extract the range from the input string
|
228 |
+
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
|
229 |
+
range_start = tuple(map(int, matches[0]))
|
230 |
+
range_end = tuple(map(int, matches[-1]))
|
231 |
+
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
|
232 |
+
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
|
233 |
+
# Multiply all elements by patch_size
|
234 |
+
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
|
235 |
+
if type(grid_pinpoints) is list:
|
236 |
+
possible_resolutions = grid_pinpoints
|
237 |
+
else:
|
238 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
239 |
+
width, height = select_best_resolution(image_size, possible_resolutions)
|
240 |
+
return width // patch_size, height // patch_size
|
241 |
+
|
242 |
+
|
243 |
+
def process_anyres_image(image, processor, grid_pinpoints):
|
244 |
+
"""
|
245 |
+
Process an image with variable resolutions.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
image (PIL.Image.Image): The input image to be processed.
|
249 |
+
processor: The image processor object.
|
250 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
torch.Tensor: A tensor containing the processed image patches.
|
254 |
+
"""
|
255 |
+
# Convert grid_pinpoints from string to list
|
256 |
+
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
|
257 |
+
try:
|
258 |
+
patch_size = processor.size[0]
|
259 |
+
except Exception as e:
|
260 |
+
patch_size = processor.size["shortest_edge"]
|
261 |
+
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
|
262 |
+
# Use regex to extract the range from the input string
|
263 |
+
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
|
264 |
+
range_start = tuple(map(int, matches[0]))
|
265 |
+
range_end = tuple(map(int, matches[-1]))
|
266 |
+
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
|
267 |
+
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
|
268 |
+
# Multiply all elements by patch_size
|
269 |
+
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
|
270 |
+
|
271 |
+
if type(grid_pinpoints) is list:
|
272 |
+
possible_resolutions = grid_pinpoints
|
273 |
+
else:
|
274 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
275 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
276 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
277 |
+
|
278 |
+
patches = divide_to_patches(image_padded, processor.crop_size["height"])
|
279 |
+
|
280 |
+
# FIXME: this seems to be a bug that it resizes instead of pad.
|
281 |
+
# but to keep it consistent with previous, i will keep it as it is
|
282 |
+
# TODO: uncomment below to ablate with the padding
|
283 |
+
if isinstance(processor.size, dict):
|
284 |
+
shortest_edge = processor.size["shortest_edge"]
|
285 |
+
else:
|
286 |
+
shortest_edge = min(processor.size)
|
287 |
+
image_original_resize = image.resize((shortest_edge, shortest_edge))
|
288 |
+
# image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
|
289 |
+
# image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
|
290 |
+
|
291 |
+
image_patches = [image_original_resize] + patches
|
292 |
+
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
|
293 |
+
return torch.stack(image_patches, dim=0)
|
294 |
+
|
295 |
+
|
296 |
+
def load_image_from_base64(image):
|
297 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
298 |
+
|
299 |
+
|
300 |
+
def expand2square(pil_img, background_color):
|
301 |
+
width, height = pil_img.size
|
302 |
+
if width == height:
|
303 |
+
return pil_img
|
304 |
+
elif width > height:
|
305 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
306 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
307 |
+
return result
|
308 |
+
else:
|
309 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
310 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
311 |
+
return result
|
312 |
+
|
313 |
+
|
314 |
+
def process_images(images, image_processor, model_cfg):
|
315 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
316 |
+
new_images = []
|
317 |
+
if image_aspect_ratio == "highres":
|
318 |
+
for image in images:
|
319 |
+
image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
320 |
+
new_images.append(image)
|
321 |
+
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
|
322 |
+
for image in images:
|
323 |
+
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
324 |
+
new_images.append(image)
|
325 |
+
elif image_aspect_ratio == "crop_split":
|
326 |
+
for image in images:
|
327 |
+
image = process_highres_image_crop_split(image, model_cfg, image_processor)
|
328 |
+
new_images.append(image)
|
329 |
+
elif image_aspect_ratio == "pad":
|
330 |
+
for image in images:
|
331 |
+
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
|
332 |
+
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
333 |
+
new_images.append(image)
|
334 |
+
else:
|
335 |
+
return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
|
336 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
337 |
+
new_images = torch.stack(new_images, dim=0)
|
338 |
+
return new_images
|
339 |
+
|
340 |
+
|
341 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
342 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
343 |
+
|
344 |
+
def insert_separator(X, sep):
|
345 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
346 |
+
|
347 |
+
input_ids = []
|
348 |
+
offset = 0
|
349 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
350 |
+
offset = 1
|
351 |
+
input_ids.append(prompt_chunks[0][0])
|
352 |
+
|
353 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
354 |
+
input_ids.extend(x[offset:])
|
355 |
+
|
356 |
+
if return_tensors is not None:
|
357 |
+
if return_tensors == "pt":
|
358 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
359 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
360 |
+
return input_ids
|
361 |
+
|
362 |
+
|
363 |
+
def get_model_name_from_path(model_path):
|
364 |
+
model_path = model_path.strip("/")
|
365 |
+
model_paths = model_path.split("/")
|
366 |
+
if model_paths[-1].startswith("checkpoint-"):
|
367 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
368 |
+
else:
|
369 |
+
return model_paths[-1]
|
370 |
+
|
371 |
+
|
372 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
373 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
374 |
+
self.keywords = keywords
|
375 |
+
self.keyword_ids = []
|
376 |
+
for keyword in keywords:
|
377 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
378 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
379 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
380 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
381 |
+
self.tokenizer = tokenizer
|
382 |
+
self.start_len = input_ids.shape[1]
|
383 |
+
|
384 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
385 |
+
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
|
386 |
+
offset = min(output_ids.shape[1] - self.start_len, 3)
|
387 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
388 |
+
for keyword_id in self.keyword_ids:
|
389 |
+
if output_ids[0, -keyword_id.shape[0] :] == keyword_id:
|
390 |
+
return True
|
391 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
392 |
+
for keyword in self.keywords:
|
393 |
+
if keyword in outputs:
|
394 |
+
return True
|
395 |
+
return False
|
llava/model/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
AVAILABLE_MODELS = {
|
4 |
+
"llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
|
5 |
+
"llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
|
6 |
+
"llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
|
7 |
+
"llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig",
|
8 |
+
# "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
|
9 |
+
# Add other models as needed
|
10 |
+
}
|
11 |
+
|
12 |
+
for model_name, model_classes in AVAILABLE_MODELS.items():
|
13 |
+
try:
|
14 |
+
exec(f"from .language_model.{model_name} import {model_classes}")
|
15 |
+
except Exception as e:
|
16 |
+
print(f"Failed to import {model_name} from llava.language_model.{model_name}. Error: {e}")
|
llava/model/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (726 Bytes). View file
|
|
llava/model/__pycache__/apply_delta.cpython-39.pyc
ADDED
Binary file (1.73 kB). View file
|
|
llava/model/__pycache__/builder.cpython-39.pyc
ADDED
Binary file (7.77 kB). View file
|
|
llava/model/__pycache__/consolidate.cpython-39.pyc
ADDED
Binary file (1.08 kB). View file
|
|
llava/model/__pycache__/llava_arch.cpython-39.pyc
ADDED
Binary file (14.4 kB). View file
|
|
llava/model/__pycache__/make_delta.cpython-39.pyc
ADDED
Binary file (1.94 kB). View file
|
|
llava/model/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.04 kB). View file
|
|
llava/model/apply_delta.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
+
from llava import LlavaLlamaForCausalLM
|
12 |
+
|
13 |
+
|
14 |
+
def apply_delta(base_model_path, target_model_path, delta_path):
|
15 |
+
print("Loading base model")
|
16 |
+
base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading delta")
|
19 |
+
delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
20 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
|
21 |
+
|
22 |
+
print("Applying delta")
|
23 |
+
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data += base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
|
31 |
+
bparam = base.state_dict()[name]
|
32 |
+
param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
|
33 |
+
|
34 |
+
print("Saving target model")
|
35 |
+
delta.save_pretrained(target_model_path)
|
36 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
parser = argparse.ArgumentParser()
|
41 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
42 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
43 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
44 |
+
|
45 |
+
args = parser.parse_args()
|
46 |
+
|
47 |
+
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
llava/model/builder.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import os
|
17 |
+
import warnings
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
21 |
+
import torch
|
22 |
+
from llava.model import *
|
23 |
+
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
+
from llava.utils import rank0_print
|
25 |
+
|
26 |
+
|
27 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
|
28 |
+
kwargs["device_map"] = device_map
|
29 |
+
|
30 |
+
if load_8bit:
|
31 |
+
kwargs["load_in_8bit"] = True
|
32 |
+
elif load_4bit:
|
33 |
+
kwargs["load_in_4bit"] = True
|
34 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
|
35 |
+
else:
|
36 |
+
kwargs["torch_dtype"] = torch.float16
|
37 |
+
|
38 |
+
if customized_config is not None:
|
39 |
+
kwargs["config"] = customized_config
|
40 |
+
|
41 |
+
if "multimodal" in kwargs:
|
42 |
+
if kwargs["multimodal"] is True:
|
43 |
+
is_multimodal = True
|
44 |
+
kwargs.pop("multimodal")
|
45 |
+
else:
|
46 |
+
is_multimodal = False
|
47 |
+
|
48 |
+
if "llava" in model_name.lower() or is_multimodal:
|
49 |
+
# Load LLaVA model
|
50 |
+
if "lora" in model_name.lower() and model_base is None:
|
51 |
+
warnings.warn(
|
52 |
+
"There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
|
53 |
+
)
|
54 |
+
if "lora" in model_name.lower() and model_base is not None:
|
55 |
+
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
56 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
57 |
+
rank0_print("Loading LLaVA from base model...")
|
58 |
+
if "mixtral" in model_name.lower():
|
59 |
+
from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
|
60 |
+
|
61 |
+
lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
|
62 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
63 |
+
model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
64 |
+
elif "mistral" in model_name.lower():
|
65 |
+
from llava.model.language_model.llava_mistral import LlavaMistralConfig
|
66 |
+
|
67 |
+
lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
|
68 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
69 |
+
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
70 |
+
elif "gemma" in model_name.lower():
|
71 |
+
from llava.model.language_model.llava_gemma import LlavaGemmaConfig
|
72 |
+
|
73 |
+
lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
|
74 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
75 |
+
model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
76 |
+
else:
|
77 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
78 |
+
|
79 |
+
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
|
80 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
81 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
82 |
+
|
83 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
84 |
+
if model.lm_head.weight.shape[0] != token_num:
|
85 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
86 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
87 |
+
|
88 |
+
rank0_print("Loading additional LLaVA weights...")
|
89 |
+
if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
|
90 |
+
non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
|
91 |
+
else:
|
92 |
+
# this is probably from HF Hub
|
93 |
+
from huggingface_hub import hf_hub_download
|
94 |
+
|
95 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
96 |
+
cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
|
97 |
+
return torch.load(cache_file, map_location="cpu")
|
98 |
+
|
99 |
+
non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
|
100 |
+
non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
|
101 |
+
if any(k.startswith("model.model.") for k in non_lora_trainables):
|
102 |
+
non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
|
103 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
104 |
+
|
105 |
+
from peft import PeftModel
|
106 |
+
|
107 |
+
rank0_print("Loading LoRA weights...")
|
108 |
+
model = PeftModel.from_pretrained(model, model_path)
|
109 |
+
rank0_print("Merging LoRA weights...")
|
110 |
+
model = model.merge_and_unload()
|
111 |
+
rank0_print("Model is loaded...")
|
112 |
+
elif model_base is not None: # this may be mm projector only, loading projector with preset language mdoel
|
113 |
+
rank0_print(f"Loading LLaVA from base model {model_base}...")
|
114 |
+
if "mixtral" in model_name.lower():
|
115 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
116 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
117 |
+
model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
118 |
+
elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
|
119 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
120 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
121 |
+
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
122 |
+
elif "gemma" in model_name.lower():
|
123 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
124 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
125 |
+
model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
126 |
+
elif (
|
127 |
+
"wizardlm-2" in model_name.lower()
|
128 |
+
and "vicuna" in model_name.lower()
|
129 |
+
or "llama" in model_name.lower()
|
130 |
+
or "yi" in model_name.lower()
|
131 |
+
or "nous-hermes" in model_name.lower()
|
132 |
+
or "llava-v1.6-34b" in model_name.lower()
|
133 |
+
or "llava-v1.5" in model_name.lower()
|
134 |
+
):
|
135 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
136 |
+
|
137 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
138 |
+
if customized_config is None:
|
139 |
+
llava_cfg = LlavaConfig.from_pretrained(model_path)
|
140 |
+
if "v1.5" in model_name.lower():
|
141 |
+
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
|
142 |
+
else:
|
143 |
+
llava_cfg = customized_config
|
144 |
+
|
145 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
146 |
+
llava_cfg = LlavaConfig.from_pretrained(model_path)
|
147 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
|
148 |
+
else:
|
149 |
+
raise ValueError(f"Model {model_name} not supported")
|
150 |
+
|
151 |
+
mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
|
152 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
153 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
154 |
+
else:
|
155 |
+
rank0_print(f"Loaded LLaVA model: {model_path}")
|
156 |
+
if "mixtral" in model_name.lower():
|
157 |
+
from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
|
158 |
+
|
159 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
160 |
+
if customized_config is None:
|
161 |
+
llava_cfg = LlavaMixtralConfig.from_pretrained(model_path)
|
162 |
+
else:
|
163 |
+
llava_cfg = customized_config
|
164 |
+
|
165 |
+
if overwrite_config is not None:
|
166 |
+
rank0_print(f"Overwriting config with {overwrite_config}")
|
167 |
+
for k, v in overwrite_config.items():
|
168 |
+
setattr(llava_cfg, k, v)
|
169 |
+
|
170 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
171 |
+
model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
|
172 |
+
|
173 |
+
elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
|
174 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
175 |
+
model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
|
176 |
+
elif (
|
177 |
+
"wizardlm-2" in model_name.lower()
|
178 |
+
and "vicuna" in model_name.lower()
|
179 |
+
or "llama" in model_name.lower()
|
180 |
+
or "yi" in model_name.lower()
|
181 |
+
or "nous-hermes" in model_name.lower()
|
182 |
+
or "llava-v1.6-34b" in model_name.lower()
|
183 |
+
or "llava-v1.5" in model_name.lower()
|
184 |
+
):
|
185 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
186 |
+
|
187 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
188 |
+
if customized_config is None:
|
189 |
+
llava_cfg = LlavaConfig.from_pretrained(model_path)
|
190 |
+
if "v1.5" in model_name.lower():
|
191 |
+
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
|
192 |
+
else:
|
193 |
+
llava_cfg = customized_config
|
194 |
+
|
195 |
+
if overwrite_config is not None:
|
196 |
+
rank0_print(f"Overwriting config with {overwrite_config}")
|
197 |
+
for k, v in overwrite_config.items():
|
198 |
+
setattr(llava_cfg, k, v)
|
199 |
+
|
200 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
|
201 |
+
|
202 |
+
elif "qwen" in model_name.lower() or "quyen" in model_name.lower():
|
203 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
204 |
+
if "moe" in model_name.lower() or "A14B" in model_name.lower():
|
205 |
+
from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig
|
206 |
+
if overwrite_config is not None:
|
207 |
+
llava_cfg = LlavaQwenMoeConfig.from_pretrained(model_path)
|
208 |
+
rank0_print(f"Overwriting config with {overwrite_config}")
|
209 |
+
for k, v in overwrite_config.items():
|
210 |
+
setattr(llava_cfg, k, v)
|
211 |
+
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
|
212 |
+
else:
|
213 |
+
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
|
214 |
+
|
215 |
+
else:
|
216 |
+
from llava.model.language_model.llava_qwen import LlavaQwenConfig
|
217 |
+
if overwrite_config is not None:
|
218 |
+
llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
|
219 |
+
rank0_print(f"Overwriting config with {overwrite_config}")
|
220 |
+
for k, v in overwrite_config.items():
|
221 |
+
setattr(llava_cfg, k, v)
|
222 |
+
model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
|
223 |
+
else:
|
224 |
+
model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
|
225 |
+
|
226 |
+
elif "gemma" in model_name.lower():
|
227 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
228 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
229 |
+
model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
|
230 |
+
else:
|
231 |
+
try:
|
232 |
+
from llava.model.language_model.llava_llama import LlavaConfig
|
233 |
+
|
234 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
235 |
+
if customized_config is None:
|
236 |
+
llava_cfg = LlavaConfig.from_pretrained(model_path)
|
237 |
+
if "v1.5" in model_path.lower():
|
238 |
+
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
|
239 |
+
else:
|
240 |
+
llava_cfg = customized_config
|
241 |
+
|
242 |
+
if overwrite_config is not None:
|
243 |
+
rank0_print(f"Overwriting config with {overwrite_config}")
|
244 |
+
for k, v in overwrite_config.items():
|
245 |
+
setattr(llava_cfg, k, v)
|
246 |
+
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
|
247 |
+
except:
|
248 |
+
raise ValueError(f"Model {model_name} not supported")
|
249 |
+
|
250 |
+
else:
|
251 |
+
# Load language model
|
252 |
+
if model_base is not None:
|
253 |
+
# PEFT model
|
254 |
+
from peft import PeftModel
|
255 |
+
|
256 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
257 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
|
258 |
+
print(f"Loading LoRA weights from {model_path}")
|
259 |
+
model = PeftModel.from_pretrained(model, model_path)
|
260 |
+
print(f"Merging weights")
|
261 |
+
model = model.merge_and_unload()
|
262 |
+
print("Convert to FP16...")
|
263 |
+
model.to(torch.float16)
|
264 |
+
else:
|
265 |
+
use_fast = False
|
266 |
+
if "mpt" in model_name.lower().replace("prompt", ""):
|
267 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
268 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
|
269 |
+
else:
|
270 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
271 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
272 |
+
|
273 |
+
rank0_print(f"Model Class: {model.__class__.__name__}")
|
274 |
+
image_processor = None
|
275 |
+
|
276 |
+
if "llava" in model_name.lower() or is_multimodal:
|
277 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
278 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
279 |
+
if mm_use_im_patch_token:
|
280 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
281 |
+
if mm_use_im_start_end:
|
282 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
283 |
+
model.resize_token_embeddings(len(tokenizer))
|
284 |
+
|
285 |
+
vision_tower = model.get_vision_tower()
|
286 |
+
if not vision_tower.is_loaded:
|
287 |
+
vision_tower.load_model(device_map=device_map)
|
288 |
+
if device_map != "auto":
|
289 |
+
vision_tower.to(device="cuda", dtype=torch.float16)
|
290 |
+
image_processor = vision_tower.image_processor
|
291 |
+
|
292 |
+
if hasattr(model.config, "max_sequence_length"):
|
293 |
+
context_len = model.config.max_sequence_length
|
294 |
+
elif hasattr(model.config, "max_position_embeddings"):
|
295 |
+
context_len = model.config.max_position_embeddings
|
296 |
+
elif hasattr(model.config, "tokenizer_model_max_length"):
|
297 |
+
context_len = model.config.tokenizer_model_max_length
|
298 |
+
else:
|
299 |
+
context_len = 2048
|
300 |
+
|
301 |
+
return tokenizer, model, image_processor, context_len
|
llava/model/consolidate.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
+
from llava.model import *
|
11 |
+
from llava.model.utils import auto_upgrade
|
12 |
+
|
13 |
+
|
14 |
+
def consolidate_ckpt(src_path, dst_path):
|
15 |
+
print("Loading model")
|
16 |
+
auto_upgrade(src_path)
|
17 |
+
src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
18 |
+
src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
|
19 |
+
src_model.save_pretrained(dst_path)
|
20 |
+
src_tokenizer.save_pretrained(dst_path)
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
parser = argparse.ArgumentParser()
|
25 |
+
parser.add_argument("--src", type=str, required=True)
|
26 |
+
parser.add_argument("--dst", type=str, required=True)
|
27 |
+
|
28 |
+
args = parser.parse_args()
|
29 |
+
|
30 |
+
consolidate_ckpt(args.src, args.dst)
|
llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc
ADDED
Binary file (3.72 kB). View file
|
|
llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc
ADDED
Binary file (4.24 kB). View file
|
|
llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc
ADDED
Binary file (3.97 kB). View file
|
|
llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc
ADDED
Binary file (4.1 kB). View file
|
|
llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc
ADDED
Binary file (3.34 kB). View file
|
|
llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc
ADDED
Binary file (4.13 kB). View file
|
|
llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc
ADDED
Binary file (4.11 kB). View file
|
|
llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc
ADDED
Binary file (48.9 kB). View file
|
|
llava/model/language_model/llava_gemma.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class LlavaGemmaConfig(GemmaConfig):
|
31 |
+
model_type = "llava_gemma"
|
32 |
+
|
33 |
+
|
34 |
+
class LlavaGemmaModel(LlavaMetaModel, GemmaModel):
|
35 |
+
config_class = LlavaGemmaConfig
|
36 |
+
|
37 |
+
def __init__(self, config: GemmaConfig):
|
38 |
+
super(LlavaGemmaModel, self).__init__(config)
|
39 |
+
|
40 |
+
|
41 |
+
class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
|
42 |
+
config_class = LlavaGemmaConfig
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(GemmaForCausalLM, self).__init__(config)
|
46 |
+
self.model = LlavaGemmaModel(config)
|
47 |
+
|
48 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
49 |
+
|
50 |
+
# Initialize weights and apply final processing
|
51 |
+
self.post_init()
|
52 |
+
|
53 |
+
def get_model(self):
|
54 |
+
return self.model
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
input_ids: torch.LongTensor = None,
|
59 |
+
attention_mask: Optional[torch.Tensor] = None,
|
60 |
+
position_ids: Optional[torch.LongTensor] = None,
|
61 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
62 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
63 |
+
labels: Optional[torch.LongTensor] = None,
|
64 |
+
use_cache: Optional[bool] = None,
|
65 |
+
output_attentions: Optional[bool] = None,
|
66 |
+
output_hidden_states: Optional[bool] = None,
|
67 |
+
images: Optional[torch.FloatTensor] = None,
|
68 |
+
image_sizes: Optional[List[List[int]]] = None,
|
69 |
+
return_dict: Optional[bool] = None,
|
70 |
+
cache_position: Optional[torch.LongTensor] = None,
|
71 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
72 |
+
|
73 |
+
if inputs_embeds is None:
|
74 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
|
75 |
+
|
76 |
+
return super().forward(
|
77 |
+
input_ids=input_ids,
|
78 |
+
attention_mask=attention_mask,
|
79 |
+
position_ids=position_ids,
|
80 |
+
past_key_values=past_key_values,
|
81 |
+
inputs_embeds=inputs_embeds,
|
82 |
+
labels=labels,
|
83 |
+
use_cache=use_cache,
|
84 |
+
output_attentions=output_attentions,
|
85 |
+
output_hidden_states=output_hidden_states,
|
86 |
+
return_dict=return_dict,
|
87 |
+
cache_position=cache_position,
|
88 |
+
)
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def generate(
|
92 |
+
self,
|
93 |
+
inputs: Optional[torch.Tensor] = None,
|
94 |
+
images: Optional[torch.Tensor] = None,
|
95 |
+
image_sizes: Optional[torch.Tensor] = None,
|
96 |
+
**kwargs,
|
97 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
98 |
+
position_ids = kwargs.pop("position_ids", None)
|
99 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
100 |
+
if "inputs_embeds" in kwargs:
|
101 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
102 |
+
|
103 |
+
if images is not None:
|
104 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
|
105 |
+
else:
|
106 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
107 |
+
|
108 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
109 |
+
|
110 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
111 |
+
images = kwargs.pop("images", None)
|
112 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
113 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
114 |
+
if images is not None:
|
115 |
+
inputs["images"] = images
|
116 |
+
if image_sizes is not None:
|
117 |
+
inputs["image_sizes"] = image_sizes
|
118 |
+
return inputs
|
119 |
+
|
120 |
+
|
121 |
+
AutoConfig.register("llava_gemma", LlavaGemmaConfig)
|
122 |
+
AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
|
llava/model/language_model/llava_llama.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
|
22 |
+
|
23 |
+
from torch.nn import CrossEntropyLoss
|
24 |
+
|
25 |
+
|
26 |
+
# , LlamaModel, LlamaForCausalLM, GenerationConfig
|
27 |
+
# from .modeling_llama import LlamaModel, LlamaForCausalLM
|
28 |
+
from transformers import LlamaModel, LlamaForCausalLM
|
29 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
30 |
+
from transformers.generation.utils import GenerateOutput
|
31 |
+
|
32 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
33 |
+
|
34 |
+
|
35 |
+
class LlavaConfig(LlamaConfig):
|
36 |
+
model_type = "llava_llama"
|
37 |
+
temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
|
38 |
+
max_new_tokens: int = 1024
|
39 |
+
do_sample: bool = False
|
40 |
+
top_p: Optional[float] = None
|
41 |
+
# rope_scaling: Optional[dict] = {}
|
42 |
+
|
43 |
+
|
44 |
+
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
|
45 |
+
config_class = LlavaConfig
|
46 |
+
|
47 |
+
def __init__(self, config: LlamaConfig):
|
48 |
+
super(LlavaLlamaModel, self).__init__(config)
|
49 |
+
|
50 |
+
|
51 |
+
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
|
52 |
+
config_class = LlavaConfig
|
53 |
+
|
54 |
+
def __init__(self, config):
|
55 |
+
LlamaForCausalLM.__init__(self, config)
|
56 |
+
|
57 |
+
# configure default generation settings
|
58 |
+
config.model_type = "llava_llama"
|
59 |
+
# config.rope_scaling = None
|
60 |
+
|
61 |
+
self.model = LlavaLlamaModel(config)
|
62 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
63 |
+
# Initialize weights and apply final processing
|
64 |
+
self.post_init()
|
65 |
+
|
66 |
+
def get_model(self):
|
67 |
+
return self.model
|
68 |
+
|
69 |
+
def forward(
|
70 |
+
self,
|
71 |
+
input_ids: torch.LongTensor = None,
|
72 |
+
attention_mask: Optional[torch.Tensor] = None,
|
73 |
+
position_ids: Optional[torch.LongTensor] = None,
|
74 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
75 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
76 |
+
labels: Optional[torch.LongTensor] = None,
|
77 |
+
use_cache: Optional[bool] = None,
|
78 |
+
output_attentions: Optional[bool] = None,
|
79 |
+
output_hidden_states: Optional[bool] = None,
|
80 |
+
images: Optional[torch.FloatTensor] = None,
|
81 |
+
image_sizes: Optional[List[List[int]]] = None,
|
82 |
+
return_dict: Optional[bool] = None,
|
83 |
+
modalities: Optional[List[str]] = ["image"],
|
84 |
+
dpo_forward: Optional[bool] = None,
|
85 |
+
cache_position=None,
|
86 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
87 |
+
|
88 |
+
if inputs_embeds is None:
|
89 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
|
90 |
+
|
91 |
+
if dpo_forward:
|
92 |
+
outputs = self.model(
|
93 |
+
input_ids=input_ids,
|
94 |
+
attention_mask=attention_mask,
|
95 |
+
position_ids=position_ids,
|
96 |
+
past_key_values=past_key_values,
|
97 |
+
inputs_embeds=inputs_embeds,
|
98 |
+
use_cache=use_cache,
|
99 |
+
output_attentions=output_attentions,
|
100 |
+
output_hidden_states=output_hidden_states,
|
101 |
+
return_dict=return_dict,
|
102 |
+
)
|
103 |
+
|
104 |
+
hidden_states = outputs[0]
|
105 |
+
logits = self.lm_head(hidden_states)
|
106 |
+
return logits, labels
|
107 |
+
|
108 |
+
else:
|
109 |
+
return super().forward(
|
110 |
+
input_ids=input_ids,
|
111 |
+
attention_mask=attention_mask,
|
112 |
+
position_ids=position_ids,
|
113 |
+
past_key_values=past_key_values,
|
114 |
+
inputs_embeds=inputs_embeds,
|
115 |
+
labels=labels,
|
116 |
+
use_cache=use_cache,
|
117 |
+
output_attentions=output_attentions,
|
118 |
+
output_hidden_states=output_hidden_states,
|
119 |
+
return_dict=return_dict,
|
120 |
+
)
|
121 |
+
|
122 |
+
@torch.no_grad()
|
123 |
+
def generate(
|
124 |
+
self,
|
125 |
+
inputs: Optional[torch.Tensor] = None,
|
126 |
+
images: Optional[torch.Tensor] = None,
|
127 |
+
image_sizes: Optional[torch.Tensor] = None,
|
128 |
+
modalities: Optional[List[str]] = ["image"],
|
129 |
+
**kwargs,
|
130 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
131 |
+
modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities
|
132 |
+
position_ids = kwargs.pop("position_ids", None)
|
133 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
134 |
+
if "inputs_embeds" in kwargs:
|
135 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
136 |
+
|
137 |
+
if images is not None:
|
138 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
|
139 |
+
else:
|
140 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
141 |
+
|
142 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
143 |
+
|
144 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
145 |
+
images = kwargs.pop("images", None)
|
146 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
147 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
148 |
+
if images is not None:
|
149 |
+
inputs["images"] = images
|
150 |
+
if image_sizes is not None:
|
151 |
+
inputs["image_sizes"] = image_sizes
|
152 |
+
return inputs
|
153 |
+
|
154 |
+
|
155 |
+
AutoConfig.register("llava_llama", LlavaConfig)
|
156 |
+
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
|
llava/model/language_model/llava_mistral.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class LlavaMistralConfig(MistralConfig):
|
31 |
+
model_type = "llava_mistral"
|
32 |
+
temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
|
33 |
+
max_new_tokens: int = 1024
|
34 |
+
do_sample: bool = False
|
35 |
+
top_p: Optional[float] = None
|
36 |
+
|
37 |
+
|
38 |
+
class LlavaMistralModel(LlavaMetaModel, MistralModel):
|
39 |
+
config_class = LlavaMistralConfig
|
40 |
+
|
41 |
+
def __init__(self, config: MistralConfig):
|
42 |
+
super(LlavaMistralModel, self).__init__(config)
|
43 |
+
|
44 |
+
|
45 |
+
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
|
46 |
+
config_class = LlavaMistralConfig
|
47 |
+
|
48 |
+
def __init__(self, config):
|
49 |
+
super(MistralForCausalLM, self).__init__(config)
|
50 |
+
|
51 |
+
config.model_type = "llava_mistral"
|
52 |
+
config.rope_scaling = None
|
53 |
+
|
54 |
+
self.model = LlavaMistralModel(config)
|
55 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
56 |
+
# Initialize weights and apply final processing
|
57 |
+
self.post_init()
|
58 |
+
|
59 |
+
def get_model(self):
|
60 |
+
return self.model
|
61 |
+
|
62 |
+
def forward(
|
63 |
+
self,
|
64 |
+
input_ids: torch.LongTensor = None,
|
65 |
+
attention_mask: Optional[torch.Tensor] = None,
|
66 |
+
position_ids: Optional[torch.LongTensor] = None,
|
67 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
68 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
69 |
+
labels: Optional[torch.LongTensor] = None,
|
70 |
+
use_cache: Optional[bool] = None,
|
71 |
+
output_attentions: Optional[bool] = None,
|
72 |
+
output_hidden_states: Optional[bool] = None,
|
73 |
+
images: Optional[torch.FloatTensor] = None,
|
74 |
+
image_sizes: Optional[List[List[int]]] = None,
|
75 |
+
return_dict: Optional[bool] = None,
|
76 |
+
cache_position=None,
|
77 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
78 |
+
|
79 |
+
if inputs_embeds is None:
|
80 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
|
81 |
+
|
82 |
+
return super().forward(
|
83 |
+
input_ids=input_ids,
|
84 |
+
attention_mask=attention_mask,
|
85 |
+
position_ids=position_ids,
|
86 |
+
past_key_values=past_key_values,
|
87 |
+
inputs_embeds=inputs_embeds,
|
88 |
+
labels=labels,
|
89 |
+
use_cache=use_cache,
|
90 |
+
output_attentions=output_attentions,
|
91 |
+
output_hidden_states=output_hidden_states,
|
92 |
+
return_dict=return_dict,
|
93 |
+
)
|
94 |
+
|
95 |
+
@torch.no_grad()
|
96 |
+
def generate(
|
97 |
+
self,
|
98 |
+
inputs: Optional[torch.Tensor] = None,
|
99 |
+
images: Optional[torch.Tensor] = None,
|
100 |
+
image_sizes: Optional[torch.Tensor] = None,
|
101 |
+
**kwargs,
|
102 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
103 |
+
position_ids = kwargs.pop("position_ids", None)
|
104 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
105 |
+
if "inputs_embeds" in kwargs:
|
106 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
107 |
+
|
108 |
+
if images is not None:
|
109 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
|
110 |
+
else:
|
111 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
112 |
+
|
113 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
114 |
+
|
115 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
116 |
+
images = kwargs.pop("images", None)
|
117 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
118 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
119 |
+
if images is not None:
|
120 |
+
inputs["images"] = images
|
121 |
+
if image_sizes is not None:
|
122 |
+
inputs["image_sizes"] = image_sizes
|
123 |
+
return inputs
|
124 |
+
|
125 |
+
|
126 |
+
AutoConfig.register("llava_mistral", LlavaMistralConfig)
|
127 |
+
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
|
llava/model/language_model/llava_mixtral.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from torch.nn import CrossEntropyLoss
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
28 |
+
|
29 |
+
|
30 |
+
class LlavaMixtralConfig(MixtralConfig):
|
31 |
+
model_type = "llava_mixtral"
|
32 |
+
|
33 |
+
|
34 |
+
class LlavaMixtralModel(LlavaMetaModel, MixtralModel):
|
35 |
+
config_class = LlavaMixtralConfig
|
36 |
+
|
37 |
+
def __init__(self, config: MixtralConfig):
|
38 |
+
super(LlavaMixtralModel, self).__init__(config)
|
39 |
+
|
40 |
+
|
41 |
+
class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
|
42 |
+
config_class = LlavaMixtralConfig
|
43 |
+
|
44 |
+
def __init__(self, config):
|
45 |
+
super(MixtralForCausalLM, self).__init__(config)
|
46 |
+
|
47 |
+
config.model_type = "llava_mixtral"
|
48 |
+
config.rope_scaling = None
|
49 |
+
self.model = LlavaMixtralModel(config)
|
50 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
51 |
+
# Initialize weights and apply final processing
|
52 |
+
self.post_init()
|
53 |
+
|
54 |
+
def get_model(self):
|
55 |
+
return self.model
|
56 |
+
|
57 |
+
def forward(
|
58 |
+
self,
|
59 |
+
input_ids: torch.LongTensor = None,
|
60 |
+
attention_mask: Optional[torch.Tensor] = None,
|
61 |
+
position_ids: Optional[torch.LongTensor] = None,
|
62 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
63 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
64 |
+
labels: Optional[torch.LongTensor] = None,
|
65 |
+
use_cache: Optional[bool] = None,
|
66 |
+
output_attentions: Optional[bool] = None,
|
67 |
+
output_hidden_states: Optional[bool] = None,
|
68 |
+
images: Optional[torch.FloatTensor] = None,
|
69 |
+
image_sizes: Optional[List[List[int]]] = None,
|
70 |
+
return_dict: Optional[bool] = None,
|
71 |
+
modalities: Optional[List[str]] = ["image"],
|
72 |
+
dpo_forward: Optional[bool] = None,
|
73 |
+
cache_position=None,
|
74 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
75 |
+
|
76 |
+
if inputs_embeds is None:
|
77 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
|
78 |
+
|
79 |
+
if dpo_forward:
|
80 |
+
outputs = self.model(
|
81 |
+
input_ids=input_ids,
|
82 |
+
attention_mask=attention_mask,
|
83 |
+
position_ids=position_ids,
|
84 |
+
past_key_values=past_key_values,
|
85 |
+
inputs_embeds=inputs_embeds,
|
86 |
+
use_cache=use_cache,
|
87 |
+
output_attentions=output_attentions,
|
88 |
+
output_hidden_states=output_hidden_states,
|
89 |
+
return_dict=return_dict,
|
90 |
+
)
|
91 |
+
|
92 |
+
hidden_states = outputs[0]
|
93 |
+
logits = self.lm_head(hidden_states)
|
94 |
+
return logits, labels
|
95 |
+
|
96 |
+
else:
|
97 |
+
return super().forward(
|
98 |
+
input_ids=input_ids,
|
99 |
+
attention_mask=attention_mask,
|
100 |
+
position_ids=position_ids,
|
101 |
+
past_key_values=past_key_values,
|
102 |
+
inputs_embeds=inputs_embeds,
|
103 |
+
labels=labels,
|
104 |
+
use_cache=use_cache,
|
105 |
+
output_attentions=output_attentions,
|
106 |
+
output_hidden_states=output_hidden_states,
|
107 |
+
return_dict=return_dict,
|
108 |
+
)
|
109 |
+
|
110 |
+
@torch.no_grad()
|
111 |
+
def generate(
|
112 |
+
self,
|
113 |
+
inputs: Optional[torch.Tensor] = None,
|
114 |
+
images: Optional[torch.Tensor] = None,
|
115 |
+
image_sizes: Optional[torch.Tensor] = None,
|
116 |
+
modalities: Optional[List[str]] = ["image"],
|
117 |
+
**kwargs,
|
118 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
119 |
+
position_ids = kwargs.pop("position_ids", None)
|
120 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
121 |
+
if "inputs_embeds" in kwargs:
|
122 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
123 |
+
|
124 |
+
if images is not None:
|
125 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
|
126 |
+
else:
|
127 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
128 |
+
|
129 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
130 |
+
|
131 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
132 |
+
images = kwargs.pop("images", None)
|
133 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
134 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
135 |
+
if images is not None:
|
136 |
+
inputs["images"] = images
|
137 |
+
if image_sizes is not None:
|
138 |
+
inputs["image_sizes"] = image_sizes
|
139 |
+
return inputs
|
140 |
+
|
141 |
+
|
142 |
+
AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
|
143 |
+
AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
|
llava/model/language_model/llava_mpt.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig
|
21 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
22 |
+
|
23 |
+
|
24 |
+
class LlavaMptConfig(MptConfig):
|
25 |
+
model_type = "llava_mpt"
|
26 |
+
|
27 |
+
|
28 |
+
class LlavaMptModel(LlavaMetaModel, MptModel):
|
29 |
+
config_class = LlavaMptConfig
|
30 |
+
|
31 |
+
def __init__(self, config: MptConfig):
|
32 |
+
config.hidden_size = config.d_model
|
33 |
+
super(LlavaMptModel, self).__init__(config)
|
34 |
+
|
35 |
+
def embed_tokens(self, x):
|
36 |
+
return self.wte(x)
|
37 |
+
|
38 |
+
|
39 |
+
class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
|
40 |
+
config_class = LlavaMptConfig
|
41 |
+
supports_gradient_checkpointing = True
|
42 |
+
|
43 |
+
def __init__(self, config):
|
44 |
+
super(MptForCausalLM, self).__init__(config)
|
45 |
+
|
46 |
+
config.model_type = "llava_mpt"
|
47 |
+
config.rope_scaling = None
|
48 |
+
self.generation_config = GenerationConfig(
|
49 |
+
temperature=0.0,
|
50 |
+
max_new_tokens=1024,
|
51 |
+
do_sample=False,
|
52 |
+
top_p=None,
|
53 |
+
)
|
54 |
+
|
55 |
+
self.transformer = LlavaMptModel(config)
|
56 |
+
self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
57 |
+
|
58 |
+
# Initialize weights and apply final processing
|
59 |
+
self.post_init()
|
60 |
+
|
61 |
+
def get_model(self):
|
62 |
+
return self.transformer
|
63 |
+
|
64 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
65 |
+
if isinstance(module, LlavaMptModel):
|
66 |
+
module.gradient_checkpointing = value
|
67 |
+
|
68 |
+
def forward(
|
69 |
+
self,
|
70 |
+
input_ids: Optional[torch.LongTensor] = None,
|
71 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
72 |
+
attention_mask: Optional[torch.Tensor] = None,
|
73 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
74 |
+
labels: Optional[torch.Tensor] = None,
|
75 |
+
use_cache: Optional[bool] = None,
|
76 |
+
output_attentions: Optional[bool] = None,
|
77 |
+
output_hidden_states: Optional[bool] = None,
|
78 |
+
return_dict: Optional[bool] = None,
|
79 |
+
cache_position=None,
|
80 |
+
images=None,
|
81 |
+
):
|
82 |
+
|
83 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
|
84 |
+
|
85 |
+
return super().forward(
|
86 |
+
input_ids,
|
87 |
+
past_key_values=past_key_values,
|
88 |
+
attention_mask=attention_mask,
|
89 |
+
inputs_embeds=inputs_embeds,
|
90 |
+
labels=labels,
|
91 |
+
use_cache=use_cache,
|
92 |
+
output_attentions=output_attentions,
|
93 |
+
output_hidden_states=output_hidden_states,
|
94 |
+
return_dict=return_dict,
|
95 |
+
)
|
96 |
+
|
97 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
98 |
+
images = kwargs.pop("images", None)
|
99 |
+
_inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
100 |
+
_inputs["images"] = images
|
101 |
+
return _inputs
|
102 |
+
|
103 |
+
|
104 |
+
AutoConfig.register("llava_mpt", LlavaMptConfig)
|
105 |
+
AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
|
llava/model/language_model/llava_qwen.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Hao Zhang
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union, Dict
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.nn import CrossEntropyLoss
|
20 |
+
|
21 |
+
import transformers
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
28 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
29 |
+
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
|
30 |
+
|
31 |
+
# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
|
32 |
+
# from .qwen.configuration_qwen import QWenConfig
|
33 |
+
|
34 |
+
|
35 |
+
class LlavaQwenConfig(Qwen2Config):
|
36 |
+
model_type = "llava_qwen"
|
37 |
+
|
38 |
+
|
39 |
+
class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
|
40 |
+
config_class = LlavaQwenConfig
|
41 |
+
|
42 |
+
def __init__(self, config: Qwen2Config):
|
43 |
+
super(LlavaQwenModel, self).__init__(config)
|
44 |
+
|
45 |
+
|
46 |
+
class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
|
47 |
+
config_class = LlavaQwenConfig
|
48 |
+
|
49 |
+
def __init__(self, config):
|
50 |
+
# super(Qwen2ForCausalLM, self).__init__(config)
|
51 |
+
Qwen2ForCausalLM.__init__(self, config)
|
52 |
+
config.model_type = "llava_qwen"
|
53 |
+
config.rope_scaling = None
|
54 |
+
|
55 |
+
self.model = LlavaQwenModel(config)
|
56 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
57 |
+
# Initialize weights and apply final processing
|
58 |
+
self.post_init()
|
59 |
+
|
60 |
+
def get_model(self):
|
61 |
+
return self.model
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self,
|
65 |
+
input_ids: torch.LongTensor = None,
|
66 |
+
attention_mask: Optional[torch.Tensor] = None,
|
67 |
+
position_ids: Optional[torch.LongTensor] = None,
|
68 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
69 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
70 |
+
labels: Optional[torch.LongTensor] = None,
|
71 |
+
use_cache: Optional[bool] = None,
|
72 |
+
output_attentions: Optional[bool] = None,
|
73 |
+
output_hidden_states: Optional[bool] = None,
|
74 |
+
images: Optional[torch.FloatTensor] = None,
|
75 |
+
image_sizes: Optional[List[List[int]]] = None,
|
76 |
+
return_dict: Optional[bool] = None,
|
77 |
+
modalities: Optional[List[str]] = ["image"],
|
78 |
+
dpo_forward: Optional[bool] = False,
|
79 |
+
cache_position=None,
|
80 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
81 |
+
|
82 |
+
if inputs_embeds is None:
|
83 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
|
84 |
+
|
85 |
+
if dpo_forward:
|
86 |
+
outputs = self.model(
|
87 |
+
input_ids=input_ids,
|
88 |
+
attention_mask=attention_mask,
|
89 |
+
position_ids=position_ids,
|
90 |
+
past_key_values=past_key_values,
|
91 |
+
inputs_embeds=inputs_embeds,
|
92 |
+
use_cache=use_cache,
|
93 |
+
output_attentions=output_attentions,
|
94 |
+
output_hidden_states=output_hidden_states,
|
95 |
+
return_dict=return_dict,
|
96 |
+
)
|
97 |
+
|
98 |
+
hidden_states = outputs[0]
|
99 |
+
logits = self.lm_head(hidden_states)
|
100 |
+
return logits, labels
|
101 |
+
|
102 |
+
else:
|
103 |
+
return super().forward(
|
104 |
+
input_ids=input_ids,
|
105 |
+
attention_mask=attention_mask,
|
106 |
+
position_ids=position_ids,
|
107 |
+
past_key_values=past_key_values,
|
108 |
+
inputs_embeds=inputs_embeds,
|
109 |
+
labels=labels,
|
110 |
+
use_cache=use_cache,
|
111 |
+
output_attentions=output_attentions,
|
112 |
+
output_hidden_states=output_hidden_states,
|
113 |
+
return_dict=return_dict,
|
114 |
+
)
|
115 |
+
|
116 |
+
@torch.no_grad()
|
117 |
+
def generate(
|
118 |
+
self,
|
119 |
+
inputs: Optional[torch.Tensor] = None,
|
120 |
+
images: Optional[torch.Tensor] = None,
|
121 |
+
image_sizes: Optional[torch.Tensor] = None,
|
122 |
+
modalities: Optional[List[str]] = ["image"],
|
123 |
+
**kwargs,
|
124 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
125 |
+
position_ids = kwargs.pop("position_ids", None)
|
126 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
127 |
+
if "inputs_embeds" in kwargs:
|
128 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
129 |
+
|
130 |
+
if images is not None:
|
131 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
|
132 |
+
else:
|
133 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
134 |
+
|
135 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
136 |
+
|
137 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
138 |
+
images = kwargs.pop("images", None)
|
139 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
140 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
141 |
+
if images is not None:
|
142 |
+
inputs["images"] = images
|
143 |
+
if image_sizes is not None:
|
144 |
+
inputs["image_sizes"] = image_sizes
|
145 |
+
return inputs
|
146 |
+
|
147 |
+
|
148 |
+
AutoConfig.register("llava_qwen", LlavaQwenConfig)
|
149 |
+
AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)
|
llava/model/language_model/llava_qwen_moe.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Hao Zhang
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union, Dict
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.nn import CrossEntropyLoss
|
20 |
+
|
21 |
+
import transformers
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
23 |
+
|
24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
28 |
+
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
29 |
+
from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM
|
30 |
+
|
31 |
+
# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
|
32 |
+
# from .qwen.configuration_qwen import QWenConfig
|
33 |
+
|
34 |
+
|
35 |
+
class LlavaQwenMoeConfig(Qwen2MoeConfig):
|
36 |
+
model_type = "llava_qwen_moe"
|
37 |
+
|
38 |
+
|
39 |
+
class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel):
|
40 |
+
config_class = LlavaQwenMoeConfig
|
41 |
+
|
42 |
+
def __init__(self, config: Qwen2MoeConfig):
|
43 |
+
super(LlavaQwenMoeModel, self).__init__(config)
|
44 |
+
|
45 |
+
|
46 |
+
class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM):
|
47 |
+
config_class = LlavaQwenMoeConfig
|
48 |
+
|
49 |
+
def __init__(self, config):
|
50 |
+
# super(Qwen2MoeForCausalLM, self).__init__(config)
|
51 |
+
Qwen2MoeForCausalLM.__init__(self, config)
|
52 |
+
config.model_type = "llava_qwen_moe"
|
53 |
+
config.rope_scaling = None
|
54 |
+
|
55 |
+
self.model = LlavaQwenMoeModel(config)
|
56 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
57 |
+
# Initialize weights and apply final processing
|
58 |
+
self.post_init()
|
59 |
+
|
60 |
+
def get_model(self):
|
61 |
+
return self.model
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self,
|
65 |
+
input_ids: torch.LongTensor = None,
|
66 |
+
attention_mask: Optional[torch.Tensor] = None,
|
67 |
+
position_ids: Optional[torch.LongTensor] = None,
|
68 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
69 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
70 |
+
labels: Optional[torch.LongTensor] = None,
|
71 |
+
use_cache: Optional[bool] = None,
|
72 |
+
output_attentions: Optional[bool] = None,
|
73 |
+
output_hidden_states: Optional[bool] = None,
|
74 |
+
images: Optional[torch.FloatTensor] = None,
|
75 |
+
image_sizes: Optional[List[List[int]]] = None,
|
76 |
+
return_dict: Optional[bool] = None,
|
77 |
+
modalities: Optional[List[str]] = ["image"],
|
78 |
+
dpo_forward: Optional[bool] = False,
|
79 |
+
cache_position=None,
|
80 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
81 |
+
|
82 |
+
if inputs_embeds is None:
|
83 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
|
84 |
+
|
85 |
+
if dpo_forward:
|
86 |
+
outputs = self.model(
|
87 |
+
input_ids=input_ids,
|
88 |
+
attention_mask=attention_mask,
|
89 |
+
position_ids=position_ids,
|
90 |
+
past_key_values=past_key_values,
|
91 |
+
inputs_embeds=inputs_embeds,
|
92 |
+
use_cache=use_cache,
|
93 |
+
output_attentions=output_attentions,
|
94 |
+
output_hidden_states=output_hidden_states,
|
95 |
+
return_dict=return_dict,
|
96 |
+
)
|
97 |
+
|
98 |
+
hidden_states = outputs[0]
|
99 |
+
logits = self.lm_head(hidden_states)
|
100 |
+
return logits, labels
|
101 |
+
|
102 |
+
else:
|
103 |
+
return super().forward(
|
104 |
+
input_ids=input_ids,
|
105 |
+
attention_mask=attention_mask,
|
106 |
+
position_ids=position_ids,
|
107 |
+
past_key_values=past_key_values,
|
108 |
+
inputs_embeds=inputs_embeds,
|
109 |
+
labels=labels,
|
110 |
+
use_cache=use_cache,
|
111 |
+
output_attentions=output_attentions,
|
112 |
+
output_hidden_states=output_hidden_states,
|
113 |
+
return_dict=return_dict,
|
114 |
+
)
|
115 |
+
|
116 |
+
@torch.no_grad()
|
117 |
+
def generate(
|
118 |
+
self,
|
119 |
+
inputs: Optional[torch.Tensor] = None,
|
120 |
+
images: Optional[torch.Tensor] = None,
|
121 |
+
image_sizes: Optional[torch.Tensor] = None,
|
122 |
+
modalities: Optional[List[str]] = ["image"],
|
123 |
+
**kwargs,
|
124 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
125 |
+
position_ids = kwargs.pop("position_ids", None)
|
126 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
127 |
+
if "inputs_embeds" in kwargs:
|
128 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
129 |
+
|
130 |
+
if images is not None:
|
131 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
|
132 |
+
else:
|
133 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
134 |
+
|
135 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
136 |
+
|
137 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
138 |
+
images = kwargs.pop("images", None)
|
139 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
140 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
141 |
+
if images is not None:
|
142 |
+
inputs["images"] = images
|
143 |
+
if image_sizes is not None:
|
144 |
+
inputs["image_sizes"] = image_sizes
|
145 |
+
return inputs
|
146 |
+
|
147 |
+
|
148 |
+
AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig)
|
149 |
+
AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM)
|
llava/model/language_model/modeling_llama.py
ADDED
@@ -0,0 +1,1649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" PyTorch LLaMA model."""
|
21 |
+
import math
|
22 |
+
import warnings
|
23 |
+
from typing import List, Optional, Tuple, Union
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn.functional as F
|
27 |
+
import torch.utils.checkpoint
|
28 |
+
from torch import nn
|
29 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
30 |
+
|
31 |
+
from transformers.activations import ACT2FN
|
32 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
33 |
+
from transformers.modeling_outputs import (
|
34 |
+
BaseModelOutputWithPast,
|
35 |
+
CausalLMOutputWithPast,
|
36 |
+
QuestionAnsweringModelOutput,
|
37 |
+
SequenceClassifierOutputWithPast,
|
38 |
+
)
|
39 |
+
from transformers.modeling_utils import PreTrainedModel
|
40 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
41 |
+
from transformers.utils import (
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
is_flash_attn_2_available,
|
45 |
+
is_flash_attn_greater_or_equal_2_10,
|
46 |
+
logging,
|
47 |
+
replace_return_docstrings,
|
48 |
+
)
|
49 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
50 |
+
|
51 |
+
if is_flash_attn_2_available():
|
52 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
53 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
54 |
+
|
55 |
+
|
56 |
+
logger = logging.get_logger(__name__)
|
57 |
+
|
58 |
+
_CONFIG_FOR_DOC = "LlamaConfig"
|
59 |
+
|
60 |
+
|
61 |
+
def _get_unpad_data(attention_mask):
|
62 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
63 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
64 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
65 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
66 |
+
return (
|
67 |
+
indices,
|
68 |
+
cu_seqlens,
|
69 |
+
max_seqlen_in_batch,
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
class LlamaRMSNorm(nn.Module):
|
74 |
+
def __init__(self, hidden_size, eps=1e-6):
|
75 |
+
"""
|
76 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
77 |
+
"""
|
78 |
+
super().__init__()
|
79 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
80 |
+
self.variance_epsilon = eps
|
81 |
+
|
82 |
+
def forward(self, hidden_states):
|
83 |
+
input_dtype = hidden_states.dtype
|
84 |
+
hidden_states = hidden_states.to(torch.float32)
|
85 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
86 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
87 |
+
return self.weight * hidden_states.to(input_dtype)
|
88 |
+
|
89 |
+
|
90 |
+
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
91 |
+
|
92 |
+
|
93 |
+
class LlamaRotaryEmbedding(nn.Module):
|
94 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
95 |
+
super().__init__()
|
96 |
+
self.scaling_factor = scaling_factor
|
97 |
+
self.dim = dim
|
98 |
+
self.max_position_embeddings = max_position_embeddings
|
99 |
+
self.base = base
|
100 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
101 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
102 |
+
# For BC we register cos and sin cached
|
103 |
+
self.max_seq_len_cached = max_position_embeddings
|
104 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
105 |
+
t = t / self.scaling_factor
|
106 |
+
freqs = torch.outer(t, self.inv_freq)
|
107 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
108 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
109 |
+
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
|
110 |
+
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
|
111 |
+
|
112 |
+
@property
|
113 |
+
def sin_cached(self):
|
114 |
+
logger.warning_once("The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class")
|
115 |
+
return self._sin_cached
|
116 |
+
|
117 |
+
@property
|
118 |
+
def cos_cached(self):
|
119 |
+
logger.warning_once("The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class")
|
120 |
+
return self._cos_cached
|
121 |
+
|
122 |
+
@torch.no_grad()
|
123 |
+
def forward(self, x, position_ids, seq_len=None):
|
124 |
+
if seq_len is not None:
|
125 |
+
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
|
126 |
+
|
127 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
128 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
129 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
130 |
+
# Force float32 since bfloat16 loses precision on long contexts
|
131 |
+
# See https://github.com/huggingface/transformers/pull/29285
|
132 |
+
device_type = x.device.type
|
133 |
+
device_type = device_type if isinstance(device_type, str) else "cpu"
|
134 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
135 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
136 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
137 |
+
cos = emb.cos()
|
138 |
+
sin = emb.sin()
|
139 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
140 |
+
|
141 |
+
|
142 |
+
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
143 |
+
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
144 |
+
|
145 |
+
def forward(self, x, position_ids, seq_len=None):
|
146 |
+
# difference to the original RoPE: a scaling factor is aplied to the position ids
|
147 |
+
position_ids = position_ids.float() / self.scaling_factor
|
148 |
+
cos, sin = super().forward(x, position_ids, seq_len)
|
149 |
+
return cos, sin
|
150 |
+
|
151 |
+
|
152 |
+
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
153 |
+
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
154 |
+
|
155 |
+
def forward(self, x, position_ids, seq_len=None):
|
156 |
+
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
|
157 |
+
seq_len = torch.max(position_ids) + 1
|
158 |
+
if seq_len > self.max_position_embeddings:
|
159 |
+
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
|
160 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim))
|
161 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
|
162 |
+
|
163 |
+
cos, sin = super().forward(x, position_ids, seq_len)
|
164 |
+
return cos, sin
|
165 |
+
|
166 |
+
|
167 |
+
def rotate_half(x):
|
168 |
+
"""Rotates half the hidden dims of the input."""
|
169 |
+
x1 = x[..., : x.shape[-1] // 2]
|
170 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
171 |
+
return torch.cat((-x2, x1), dim=-1)
|
172 |
+
|
173 |
+
|
174 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
175 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
q (`torch.Tensor`): The query tensor.
|
179 |
+
k (`torch.Tensor`): The key tensor.
|
180 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
181 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
182 |
+
position_ids (`torch.Tensor`, *optional*):
|
183 |
+
Deprecated and unused.
|
184 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
185 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
186 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
187 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
188 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
189 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
190 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
191 |
+
Returns:
|
192 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
193 |
+
"""
|
194 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
195 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
196 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
197 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
198 |
+
return q_embed, k_embed
|
199 |
+
|
200 |
+
|
201 |
+
class LlamaMLP(nn.Module):
|
202 |
+
def __init__(self, config):
|
203 |
+
super().__init__()
|
204 |
+
self.config = config
|
205 |
+
self.hidden_size = config.hidden_size
|
206 |
+
self.intermediate_size = config.intermediate_size
|
207 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
208 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
209 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
210 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
211 |
+
|
212 |
+
def forward(self, x):
|
213 |
+
if self.config.pretraining_tp > 1:
|
214 |
+
slice = self.intermediate_size // self.config.pretraining_tp
|
215 |
+
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
216 |
+
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
217 |
+
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
218 |
+
|
219 |
+
gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
220 |
+
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
221 |
+
|
222 |
+
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
223 |
+
down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)]
|
224 |
+
down_proj = sum(down_proj)
|
225 |
+
else:
|
226 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
227 |
+
|
228 |
+
return down_proj
|
229 |
+
|
230 |
+
|
231 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
232 |
+
"""
|
233 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
234 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
235 |
+
"""
|
236 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
237 |
+
if n_rep == 1:
|
238 |
+
return hidden_states
|
239 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
240 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
241 |
+
|
242 |
+
|
243 |
+
class LlamaAttention(nn.Module):
|
244 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
245 |
+
|
246 |
+
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
|
247 |
+
super().__init__()
|
248 |
+
self.config = config
|
249 |
+
self.layer_idx = layer_idx
|
250 |
+
if layer_idx is None:
|
251 |
+
logger.warning_once(
|
252 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
253 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
254 |
+
"when creating this class."
|
255 |
+
)
|
256 |
+
|
257 |
+
self.attention_dropout = config.attention_dropout
|
258 |
+
self.hidden_size = config.hidden_size
|
259 |
+
self.num_heads = config.num_attention_heads
|
260 |
+
self.head_dim = self.hidden_size // self.num_heads
|
261 |
+
self.num_key_value_heads = config.num_key_value_heads
|
262 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
263 |
+
self.max_position_embeddings = config.max_position_embeddings
|
264 |
+
self.rope_theta = config.rope_theta
|
265 |
+
self.is_causal = True
|
266 |
+
|
267 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
268 |
+
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads}).")
|
269 |
+
|
270 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
271 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
272 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
273 |
+
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
|
274 |
+
self._init_rope()
|
275 |
+
|
276 |
+
def _init_rope(self):
|
277 |
+
if self.config.rope_scaling is None:
|
278 |
+
self.rotary_emb = LlamaRotaryEmbedding(
|
279 |
+
self.head_dim,
|
280 |
+
max_position_embeddings=self.max_position_embeddings,
|
281 |
+
base=self.rope_theta,
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
scaling_type = self.config.rope_scaling["type"]
|
285 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
286 |
+
if scaling_type == "linear":
|
287 |
+
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
288 |
+
self.head_dim,
|
289 |
+
max_position_embeddings=self.max_position_embeddings,
|
290 |
+
scaling_factor=scaling_factor,
|
291 |
+
base=self.rope_theta,
|
292 |
+
)
|
293 |
+
elif scaling_type == "dynamic":
|
294 |
+
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
295 |
+
self.head_dim,
|
296 |
+
max_position_embeddings=self.max_position_embeddings,
|
297 |
+
scaling_factor=scaling_factor,
|
298 |
+
base=self.rope_theta,
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
302 |
+
|
303 |
+
def forward(
|
304 |
+
self,
|
305 |
+
hidden_states: torch.Tensor,
|
306 |
+
attention_mask: Optional[torch.Tensor] = None,
|
307 |
+
position_ids: Optional[torch.LongTensor] = None,
|
308 |
+
past_key_value: Optional[Cache] = None,
|
309 |
+
output_attentions: bool = False,
|
310 |
+
use_cache: bool = False,
|
311 |
+
cache_position: Optional[torch.LongTensor] = None,
|
312 |
+
**kwargs,
|
313 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
314 |
+
bsz, q_len, _ = hidden_states.size()
|
315 |
+
|
316 |
+
if self.config.pretraining_tp > 1:
|
317 |
+
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
318 |
+
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
|
319 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
320 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
321 |
+
|
322 |
+
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
323 |
+
query_states = torch.cat(query_states, dim=-1)
|
324 |
+
|
325 |
+
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
326 |
+
key_states = torch.cat(key_states, dim=-1)
|
327 |
+
|
328 |
+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
329 |
+
value_states = torch.cat(value_states, dim=-1)
|
330 |
+
|
331 |
+
else:
|
332 |
+
query_states = self.q_proj(hidden_states)
|
333 |
+
key_states = self.k_proj(hidden_states)
|
334 |
+
value_states = self.v_proj(hidden_states)
|
335 |
+
|
336 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
337 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
338 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
339 |
+
|
340 |
+
past_key_value = getattr(self, "past_key_value", past_key_value)
|
341 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
342 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
343 |
+
|
344 |
+
if past_key_value is not None:
|
345 |
+
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
346 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
347 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
348 |
+
|
349 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
350 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
351 |
+
|
352 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
353 |
+
|
354 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
355 |
+
causal_mask = attention_mask
|
356 |
+
if cache_position is not None:
|
357 |
+
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
|
358 |
+
attn_weights = attn_weights + causal_mask
|
359 |
+
|
360 |
+
# upcast attention to fp32
|
361 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
362 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
363 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
364 |
+
|
365 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
366 |
+
raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
|
367 |
+
|
368 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
369 |
+
|
370 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
371 |
+
|
372 |
+
if self.config.pretraining_tp > 1:
|
373 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
374 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
375 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
376 |
+
else:
|
377 |
+
attn_output = self.o_proj(attn_output)
|
378 |
+
|
379 |
+
if not output_attentions:
|
380 |
+
attn_weights = None
|
381 |
+
|
382 |
+
return attn_output, attn_weights, past_key_value
|
383 |
+
|
384 |
+
|
385 |
+
class LlamaRingFlashAttention2(LlamaAttention):
|
386 |
+
"""
|
387 |
+
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
|
388 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
389 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
390 |
+
"""
|
391 |
+
|
392 |
+
def __init__(self, *args, **kwargs):
|
393 |
+
super().__init__(*args, **kwargs)
|
394 |
+
|
395 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
396 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
397 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
398 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
399 |
+
|
400 |
+
def forward(
|
401 |
+
self,
|
402 |
+
hidden_states: torch.Tensor,
|
403 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
404 |
+
position_ids: Optional[torch.LongTensor] = None,
|
405 |
+
past_key_value: Optional[Cache] = None,
|
406 |
+
output_attentions: bool = False,
|
407 |
+
use_cache: bool = False,
|
408 |
+
cache_position: Optional[torch.LongTensor] = None,
|
409 |
+
**kwargs,
|
410 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
411 |
+
output_attentions = False
|
412 |
+
|
413 |
+
bsz, q_len, _ = hidden_states.size()
|
414 |
+
|
415 |
+
query_states = self.q_proj(hidden_states)
|
416 |
+
key_states = self.k_proj(hidden_states)
|
417 |
+
value_states = self.v_proj(hidden_states)
|
418 |
+
|
419 |
+
# Flash attention requires the input to have the shape
|
420 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
421 |
+
# therefore we just need to keep the original shape
|
422 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
423 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
424 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
425 |
+
|
426 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
427 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
428 |
+
|
429 |
+
past_key_value = getattr(self, "past_key_value", past_key_value)
|
430 |
+
|
431 |
+
if past_key_value is not None:
|
432 |
+
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
433 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
434 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
435 |
+
|
436 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
437 |
+
# to be able to avoid many of these transpose/reshape/view.
|
438 |
+
query_states = query_states.transpose(1, 2)
|
439 |
+
key_states = key_states.transpose(1, 2)
|
440 |
+
value_states = value_states.transpose(1, 2)
|
441 |
+
|
442 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
443 |
+
|
444 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
445 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
446 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
447 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
448 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
449 |
+
|
450 |
+
input_dtype = query_states.dtype
|
451 |
+
if input_dtype == torch.float32:
|
452 |
+
if torch.is_autocast_enabled():
|
453 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
454 |
+
# Handle the case where the model is quantized
|
455 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
456 |
+
target_dtype = self.config._pre_quantization_dtype
|
457 |
+
else:
|
458 |
+
target_dtype = self.q_proj.weight.dtype
|
459 |
+
|
460 |
+
logger.warning_once(
|
461 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}."
|
462 |
+
)
|
463 |
+
|
464 |
+
query_states = query_states.to(target_dtype)
|
465 |
+
key_states = key_states.to(target_dtype)
|
466 |
+
value_states = value_states.to(target_dtype)
|
467 |
+
|
468 |
+
attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
|
469 |
+
|
470 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
471 |
+
attn_output = self.o_proj(attn_output)
|
472 |
+
|
473 |
+
if not output_attentions:
|
474 |
+
attn_weights = None
|
475 |
+
|
476 |
+
return attn_output, attn_weights, past_key_value
|
477 |
+
|
478 |
+
def _flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None):
|
479 |
+
"""
|
480 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
481 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
482 |
+
|
483 |
+
Args:
|
484 |
+
query_states (`torch.Tensor`):
|
485 |
+
Input query states to be passed to Flash Attention API
|
486 |
+
key_states (`torch.Tensor`):
|
487 |
+
Input key states to be passed to Flash Attention API
|
488 |
+
value_states (`torch.Tensor`):
|
489 |
+
Input value states to be passed to Flash Attention API
|
490 |
+
attention_mask (`torch.Tensor`):
|
491 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
492 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
493 |
+
dropout (`int`, *optional*):
|
494 |
+
Attention dropout
|
495 |
+
softmax_scale (`float`, *optional*):
|
496 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
497 |
+
"""
|
498 |
+
if not self._flash_attn_uses_top_left_mask:
|
499 |
+
causal = self.is_causal
|
500 |
+
else:
|
501 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
502 |
+
causal = self.is_causal and query_length != 1
|
503 |
+
|
504 |
+
# Contains at least one padding token in the sequence
|
505 |
+
if attention_mask is not None:
|
506 |
+
batch_size = query_states.shape[0]
|
507 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
|
508 |
+
|
509 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
510 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
511 |
+
|
512 |
+
attn_output_unpad = zigzag_ring_flash_attn_varlen_func(
|
513 |
+
query_states,
|
514 |
+
key_states,
|
515 |
+
value_states,
|
516 |
+
cu_seqlens_q=cu_seqlens_q,
|
517 |
+
cu_seqlens_k=cu_seqlens_k,
|
518 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
519 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
520 |
+
dropout_p=dropout,
|
521 |
+
softmax_scale=softmax_scale,
|
522 |
+
causal=causal,
|
523 |
+
)
|
524 |
+
|
525 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
526 |
+
else:
|
527 |
+
# pack qkv
|
528 |
+
# query_states: (batch_size, seqlen, nheads, headdim)
|
529 |
+
# qkv: (batch_size, seqlen, 3, nheads, headdim)
|
530 |
+
qkv = torch.stack([query_states, key_states, value_states], dim=2)
|
531 |
+
attn_output = zigzag_ring_flash_attn_qkvpacked_func(qkv, dropout, softmax_scale, causal=causal)
|
532 |
+
|
533 |
+
return attn_output
|
534 |
+
|
535 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
536 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
537 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
538 |
+
|
539 |
+
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
540 |
+
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
541 |
+
if query_length == kv_seq_len:
|
542 |
+
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)
|
543 |
+
cu_seqlens_q = cu_seqlens_k
|
544 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
545 |
+
indices_q = indices_k
|
546 |
+
elif query_length == 1:
|
547 |
+
max_seqlen_in_batch_q = 1
|
548 |
+
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device) # There is a memcpy here, that is very bad.
|
549 |
+
indices_q = cu_seqlens_q[:-1]
|
550 |
+
query_layer = query_layer.squeeze(1)
|
551 |
+
else:
|
552 |
+
# The -q_len: slice assumes left padding.
|
553 |
+
attention_mask = attention_mask[:, -query_length:]
|
554 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
555 |
+
|
556 |
+
return (
|
557 |
+
query_layer,
|
558 |
+
key_layer,
|
559 |
+
value_layer,
|
560 |
+
indices_q,
|
561 |
+
(cu_seqlens_q, cu_seqlens_k),
|
562 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
563 |
+
)
|
564 |
+
|
565 |
+
|
566 |
+
class LlamaFlashAttention2(LlamaAttention):
|
567 |
+
"""
|
568 |
+
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
|
569 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
570 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
571 |
+
"""
|
572 |
+
|
573 |
+
def __init__(self, *args, **kwargs):
|
574 |
+
super().__init__(*args, **kwargs)
|
575 |
+
|
576 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
577 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
578 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
579 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
580 |
+
|
581 |
+
def forward(
|
582 |
+
self,
|
583 |
+
hidden_states: torch.Tensor,
|
584 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
585 |
+
position_ids: Optional[torch.LongTensor] = None,
|
586 |
+
past_key_value: Optional[Cache] = None,
|
587 |
+
output_attentions: bool = False,
|
588 |
+
use_cache: bool = False,
|
589 |
+
cache_position: Optional[torch.LongTensor] = None,
|
590 |
+
**kwargs,
|
591 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
592 |
+
output_attentions = False
|
593 |
+
|
594 |
+
bsz, q_len, _ = hidden_states.size()
|
595 |
+
|
596 |
+
query_states = self.q_proj(hidden_states)
|
597 |
+
key_states = self.k_proj(hidden_states)
|
598 |
+
value_states = self.v_proj(hidden_states)
|
599 |
+
|
600 |
+
# Flash attention requires the input to have the shape
|
601 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
602 |
+
# therefore we just need to keep the original shape
|
603 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
604 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
605 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
606 |
+
|
607 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
608 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
609 |
+
|
610 |
+
past_key_value = getattr(self, "past_key_value", past_key_value)
|
611 |
+
|
612 |
+
if past_key_value is not None:
|
613 |
+
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
614 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
615 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
616 |
+
|
617 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
618 |
+
# to be able to avoid many of these transpose/reshape/view.
|
619 |
+
query_states = query_states.transpose(1, 2)
|
620 |
+
key_states = key_states.transpose(1, 2)
|
621 |
+
value_states = value_states.transpose(1, 2)
|
622 |
+
|
623 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
624 |
+
|
625 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
626 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
627 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
628 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
629 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
630 |
+
|
631 |
+
input_dtype = query_states.dtype
|
632 |
+
if input_dtype == torch.float32:
|
633 |
+
if torch.is_autocast_enabled():
|
634 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
635 |
+
# Handle the case where the model is quantized
|
636 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
637 |
+
target_dtype = self.config._pre_quantization_dtype
|
638 |
+
else:
|
639 |
+
target_dtype = self.q_proj.weight.dtype
|
640 |
+
|
641 |
+
logger.warning_once(
|
642 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}."
|
643 |
+
)
|
644 |
+
|
645 |
+
query_states = query_states.to(target_dtype)
|
646 |
+
key_states = key_states.to(target_dtype)
|
647 |
+
value_states = value_states.to(target_dtype)
|
648 |
+
|
649 |
+
attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
|
650 |
+
|
651 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
652 |
+
attn_output = self.o_proj(attn_output)
|
653 |
+
|
654 |
+
if not output_attentions:
|
655 |
+
attn_weights = None
|
656 |
+
|
657 |
+
return attn_output, attn_weights, past_key_value
|
658 |
+
|
659 |
+
def _flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None):
|
660 |
+
"""
|
661 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
662 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
663 |
+
|
664 |
+
Args:
|
665 |
+
query_states (`torch.Tensor`):
|
666 |
+
Input query states to be passed to Flash Attention API
|
667 |
+
key_states (`torch.Tensor`):
|
668 |
+
Input key states to be passed to Flash Attention API
|
669 |
+
value_states (`torch.Tensor`):
|
670 |
+
Input value states to be passed to Flash Attention API
|
671 |
+
attention_mask (`torch.Tensor`):
|
672 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
673 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
674 |
+
dropout (`int`, *optional*):
|
675 |
+
Attention dropout
|
676 |
+
softmax_scale (`float`, *optional*):
|
677 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
678 |
+
"""
|
679 |
+
if not self._flash_attn_uses_top_left_mask:
|
680 |
+
causal = self.is_causal
|
681 |
+
else:
|
682 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
683 |
+
causal = self.is_causal and query_length != 1
|
684 |
+
|
685 |
+
# Contains at least one padding token in the sequence
|
686 |
+
if attention_mask is not None:
|
687 |
+
batch_size = query_states.shape[0]
|
688 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
|
689 |
+
|
690 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
691 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
692 |
+
|
693 |
+
attn_output_unpad = flash_attn_varlen_func(
|
694 |
+
query_states,
|
695 |
+
key_states,
|
696 |
+
value_states,
|
697 |
+
cu_seqlens_q=cu_seqlens_q,
|
698 |
+
cu_seqlens_k=cu_seqlens_k,
|
699 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
700 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
701 |
+
dropout_p=dropout,
|
702 |
+
softmax_scale=softmax_scale,
|
703 |
+
causal=causal,
|
704 |
+
)
|
705 |
+
|
706 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
707 |
+
else:
|
708 |
+
attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal)
|
709 |
+
|
710 |
+
return attn_output
|
711 |
+
|
712 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
713 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
714 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
715 |
+
|
716 |
+
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
717 |
+
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
718 |
+
if query_length == kv_seq_len:
|
719 |
+
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)
|
720 |
+
cu_seqlens_q = cu_seqlens_k
|
721 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
722 |
+
indices_q = indices_k
|
723 |
+
elif query_length == 1:
|
724 |
+
max_seqlen_in_batch_q = 1
|
725 |
+
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device) # There is a memcpy here, that is very bad.
|
726 |
+
indices_q = cu_seqlens_q[:-1]
|
727 |
+
query_layer = query_layer.squeeze(1)
|
728 |
+
else:
|
729 |
+
# The -q_len: slice assumes left padding.
|
730 |
+
attention_mask = attention_mask[:, -query_length:]
|
731 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
732 |
+
|
733 |
+
return (
|
734 |
+
query_layer,
|
735 |
+
key_layer,
|
736 |
+
value_layer,
|
737 |
+
indices_q,
|
738 |
+
(cu_seqlens_q, cu_seqlens_k),
|
739 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
740 |
+
)
|
741 |
+
|
742 |
+
|
743 |
+
class LlamaSdpaAttention(LlamaAttention):
|
744 |
+
"""
|
745 |
+
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
746 |
+
`LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
747 |
+
SDPA API.
|
748 |
+
"""
|
749 |
+
|
750 |
+
# Adapted from LlamaAttention.forward
|
751 |
+
def forward(
|
752 |
+
self,
|
753 |
+
hidden_states: torch.Tensor,
|
754 |
+
attention_mask: Optional[torch.Tensor] = None,
|
755 |
+
position_ids: Optional[torch.LongTensor] = None,
|
756 |
+
past_key_value: Optional[Cache] = None,
|
757 |
+
output_attentions: bool = False,
|
758 |
+
use_cache: bool = False,
|
759 |
+
cache_position: Optional[torch.LongTensor] = None,
|
760 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
761 |
+
if output_attentions:
|
762 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
763 |
+
logger.warning_once(
|
764 |
+
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
765 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
766 |
+
)
|
767 |
+
return super().forward(
|
768 |
+
hidden_states=hidden_states,
|
769 |
+
attention_mask=attention_mask,
|
770 |
+
position_ids=position_ids,
|
771 |
+
past_key_value=past_key_value,
|
772 |
+
output_attentions=output_attentions,
|
773 |
+
use_cache=use_cache,
|
774 |
+
cache_position=cache_position,
|
775 |
+
)
|
776 |
+
|
777 |
+
bsz, q_len, _ = hidden_states.size()
|
778 |
+
|
779 |
+
query_states = self.q_proj(hidden_states)
|
780 |
+
key_states = self.k_proj(hidden_states)
|
781 |
+
value_states = self.v_proj(hidden_states)
|
782 |
+
|
783 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
784 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
785 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
786 |
+
|
787 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
788 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
789 |
+
|
790 |
+
# In case static cache is used, it is an instance attribute.
|
791 |
+
past_key_value = getattr(self, "past_key_value", past_key_value)
|
792 |
+
|
793 |
+
if past_key_value is not None:
|
794 |
+
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
795 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
796 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
797 |
+
|
798 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
799 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
800 |
+
|
801 |
+
causal_mask = attention_mask
|
802 |
+
if attention_mask is not None and cache_position is not None:
|
803 |
+
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
|
804 |
+
|
805 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
806 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
807 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
808 |
+
query_states = query_states.contiguous()
|
809 |
+
key_states = key_states.contiguous()
|
810 |
+
value_states = value_states.contiguous()
|
811 |
+
|
812 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
813 |
+
query_states,
|
814 |
+
key_states,
|
815 |
+
value_states,
|
816 |
+
attn_mask=causal_mask,
|
817 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
818 |
+
)
|
819 |
+
|
820 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
821 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
822 |
+
|
823 |
+
attn_output = self.o_proj(attn_output)
|
824 |
+
|
825 |
+
return attn_output, None, past_key_value
|
826 |
+
|
827 |
+
|
828 |
+
try:
|
829 |
+
from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func, zigzag_ring_flash_attn_varlen_func
|
830 |
+
except ImportError:
|
831 |
+
print("Please install the ring-flash-attn package")
|
832 |
+
|
833 |
+
LLAMA_ATTENTION_CLASSES = {
|
834 |
+
"eager": LlamaAttention,
|
835 |
+
"flash_attention_2": LlamaFlashAttention2,
|
836 |
+
"ring_flash_attention_2": LlamaRingFlashAttention2,
|
837 |
+
"sdpa": LlamaSdpaAttention,
|
838 |
+
}
|
839 |
+
|
840 |
+
|
841 |
+
class LlamaDecoderLayer(nn.Module):
|
842 |
+
def __init__(self, config: LlamaConfig, layer_idx: int):
|
843 |
+
super().__init__()
|
844 |
+
self.hidden_size = config.hidden_size
|
845 |
+
|
846 |
+
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
847 |
+
|
848 |
+
self.mlp = LlamaMLP(config)
|
849 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
850 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
851 |
+
|
852 |
+
def forward(
|
853 |
+
self,
|
854 |
+
hidden_states: torch.Tensor,
|
855 |
+
attention_mask: Optional[torch.Tensor] = None,
|
856 |
+
position_ids: Optional[torch.LongTensor] = None,
|
857 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
858 |
+
output_attentions: Optional[bool] = False,
|
859 |
+
use_cache: Optional[bool] = False,
|
860 |
+
cache_position: Optional[torch.LongTensor] = None,
|
861 |
+
**kwargs,
|
862 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
863 |
+
"""
|
864 |
+
Args:
|
865 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
866 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
867 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
868 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
869 |
+
output_attentions (`bool`, *optional*):
|
870 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
871 |
+
returned tensors for more detail.
|
872 |
+
use_cache (`bool`, *optional*):
|
873 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
874 |
+
(see `past_key_values`).
|
875 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
876 |
+
"""
|
877 |
+
if "padding_mask" in kwargs:
|
878 |
+
warnings.warn("Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`")
|
879 |
+
|
880 |
+
residual = hidden_states
|
881 |
+
|
882 |
+
hidden_states = self.input_layernorm(hidden_states)
|
883 |
+
|
884 |
+
# Self Attention
|
885 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
886 |
+
hidden_states=hidden_states,
|
887 |
+
attention_mask=attention_mask,
|
888 |
+
position_ids=position_ids,
|
889 |
+
past_key_value=past_key_value,
|
890 |
+
output_attentions=output_attentions,
|
891 |
+
use_cache=use_cache,
|
892 |
+
cache_position=cache_position,
|
893 |
+
**kwargs,
|
894 |
+
)
|
895 |
+
hidden_states = residual + hidden_states
|
896 |
+
|
897 |
+
# Fully Connected
|
898 |
+
residual = hidden_states
|
899 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
900 |
+
hidden_states = self.mlp(hidden_states)
|
901 |
+
hidden_states = residual + hidden_states
|
902 |
+
|
903 |
+
outputs = (hidden_states,)
|
904 |
+
|
905 |
+
if output_attentions:
|
906 |
+
outputs += (self_attn_weights,)
|
907 |
+
|
908 |
+
if use_cache:
|
909 |
+
outputs += (present_key_value,)
|
910 |
+
|
911 |
+
return outputs
|
912 |
+
|
913 |
+
|
914 |
+
LLAMA_START_DOCSTRING = r"""
|
915 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
916 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
917 |
+
etc.)
|
918 |
+
|
919 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
920 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
921 |
+
and behavior.
|
922 |
+
|
923 |
+
Parameters:
|
924 |
+
config ([`LlamaConfig`]):
|
925 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
926 |
+
load the weights associated with the model, only the configuration. Check out the
|
927 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
928 |
+
"""
|
929 |
+
|
930 |
+
|
931 |
+
@add_start_docstrings(
|
932 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
933 |
+
LLAMA_START_DOCSTRING,
|
934 |
+
)
|
935 |
+
class LlamaPreTrainedModel(PreTrainedModel):
|
936 |
+
config_class = LlamaConfig
|
937 |
+
base_model_prefix = "model"
|
938 |
+
supports_gradient_checkpointing = True
|
939 |
+
_no_split_modules = ["LlamaDecoderLayer"]
|
940 |
+
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
|
941 |
+
_supports_flash_attn_2 = True
|
942 |
+
_supports_sdpa = True
|
943 |
+
_supports_cache_class = True
|
944 |
+
|
945 |
+
def _init_weights(self, module):
|
946 |
+
std = self.config.initializer_range
|
947 |
+
if isinstance(module, nn.Linear):
|
948 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
949 |
+
if module.bias is not None:
|
950 |
+
module.bias.data.zero_()
|
951 |
+
elif isinstance(module, nn.Embedding):
|
952 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
953 |
+
if module.padding_idx is not None:
|
954 |
+
module.weight.data[module.padding_idx].zero_()
|
955 |
+
|
956 |
+
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
|
957 |
+
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
|
958 |
+
raise ValueError("`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers")
|
959 |
+
|
960 |
+
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
|
961 |
+
causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool)
|
962 |
+
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
963 |
+
|
964 |
+
for layer in self.model.layers:
|
965 |
+
device = layer.input_layernorm.weight.device
|
966 |
+
if hasattr(self.config, "_pre_quantization_dtype"):
|
967 |
+
dtype = self.config._pre_quantization_dtype
|
968 |
+
else:
|
969 |
+
dtype = layer.self_attn.o_proj.weight.dtype
|
970 |
+
layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype)
|
971 |
+
|
972 |
+
def _reset_cache(self):
|
973 |
+
for layer in self.model.layers:
|
974 |
+
layer.self_attn.past_key_value = None
|
975 |
+
|
976 |
+
|
977 |
+
LLAMA_INPUTS_DOCSTRING = r"""
|
978 |
+
Args:
|
979 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
980 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
981 |
+
it.
|
982 |
+
|
983 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
984 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
985 |
+
|
986 |
+
[What are input IDs?](../glossary#input-ids)
|
987 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
988 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
989 |
+
|
990 |
+
- 1 for tokens that are **not masked**,
|
991 |
+
- 0 for tokens that are **masked**.
|
992 |
+
|
993 |
+
[What are attention masks?](../glossary#attention-mask)
|
994 |
+
|
995 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
996 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
997 |
+
|
998 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
999 |
+
`past_key_values`).
|
1000 |
+
|
1001 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
1002 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
1003 |
+
information on the default strategy.
|
1004 |
+
|
1005 |
+
- 1 indicates the head is **not masked**,
|
1006 |
+
- 0 indicates the head is **masked**.
|
1007 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1008 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
1009 |
+
config.n_positions - 1]`.
|
1010 |
+
|
1011 |
+
[What are position IDs?](../glossary#position-ids)
|
1012 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
1013 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
1014 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
1015 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
1016 |
+
|
1017 |
+
Two formats are allowed:
|
1018 |
+
- a [`~cache_utils.Cache`] instance;
|
1019 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
1020 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
1021 |
+
cache format.
|
1022 |
+
|
1023 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
1024 |
+
legacy cache format will be returned.
|
1025 |
+
|
1026 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
1027 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
1028 |
+
of shape `(batch_size, sequence_length)`.
|
1029 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
1030 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
1031 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1032 |
+
model's internal embedding lookup matrix.
|
1033 |
+
use_cache (`bool`, *optional*):
|
1034 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1035 |
+
`past_key_values`).
|
1036 |
+
output_attentions (`bool`, *optional*):
|
1037 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
1038 |
+
tensors for more detail.
|
1039 |
+
output_hidden_states (`bool`, *optional*):
|
1040 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
1041 |
+
more detail.
|
1042 |
+
return_dict (`bool`, *optional*):
|
1043 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1044 |
+
"""
|
1045 |
+
|
1046 |
+
|
1047 |
+
@add_start_docstrings(
|
1048 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
1049 |
+
LLAMA_START_DOCSTRING,
|
1050 |
+
)
|
1051 |
+
class LlamaModel(LlamaPreTrainedModel):
|
1052 |
+
"""
|
1053 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
1054 |
+
|
1055 |
+
Args:
|
1056 |
+
config: LlamaConfig
|
1057 |
+
"""
|
1058 |
+
|
1059 |
+
def __init__(self, config: LlamaConfig):
|
1060 |
+
super().__init__(config)
|
1061 |
+
self.padding_idx = config.pad_token_id
|
1062 |
+
self.vocab_size = config.vocab_size
|
1063 |
+
|
1064 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
1065 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
1066 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
1067 |
+
self.gradient_checkpointing = False
|
1068 |
+
|
1069 |
+
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
|
1070 |
+
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
|
1071 |
+
causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool)
|
1072 |
+
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
1073 |
+
# Initialize weights and apply final processing
|
1074 |
+
self.post_init()
|
1075 |
+
|
1076 |
+
def get_input_embeddings(self):
|
1077 |
+
return self.embed_tokens
|
1078 |
+
|
1079 |
+
def set_input_embeddings(self, value):
|
1080 |
+
self.embed_tokens = value
|
1081 |
+
|
1082 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1083 |
+
def forward(
|
1084 |
+
self,
|
1085 |
+
input_ids: torch.LongTensor = None,
|
1086 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1087 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1088 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1089 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1090 |
+
use_cache: Optional[bool] = None,
|
1091 |
+
output_attentions: Optional[bool] = None,
|
1092 |
+
output_hidden_states: Optional[bool] = None,
|
1093 |
+
return_dict: Optional[bool] = None,
|
1094 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1095 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
1096 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1097 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1098 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1099 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1100 |
+
|
1101 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
1102 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one")
|
1103 |
+
|
1104 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
1105 |
+
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
|
1106 |
+
use_cache = False
|
1107 |
+
|
1108 |
+
if inputs_embeds is None:
|
1109 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1110 |
+
|
1111 |
+
past_seen_tokens = 0
|
1112 |
+
if use_cache: # kept for BC (cache positions)
|
1113 |
+
if not isinstance(past_key_values, StaticCache):
|
1114 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
1115 |
+
past_seen_tokens = past_key_values.get_seq_length()
|
1116 |
+
|
1117 |
+
if cache_position is None:
|
1118 |
+
if isinstance(past_key_values, StaticCache):
|
1119 |
+
raise ValueError("cache_position is a required argument when using StaticCache.")
|
1120 |
+
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
|
1121 |
+
|
1122 |
+
if position_ids is None:
|
1123 |
+
position_ids = cache_position.unsqueeze(0)
|
1124 |
+
|
1125 |
+
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
|
1126 |
+
|
1127 |
+
# embed positions
|
1128 |
+
hidden_states = inputs_embeds
|
1129 |
+
|
1130 |
+
# decoder layers
|
1131 |
+
all_hidden_states = () if output_hidden_states else None
|
1132 |
+
all_self_attns = () if output_attentions else None
|
1133 |
+
next_decoder_cache = None
|
1134 |
+
|
1135 |
+
for decoder_layer in self.layers:
|
1136 |
+
if output_hidden_states:
|
1137 |
+
all_hidden_states += (hidden_states,)
|
1138 |
+
|
1139 |
+
if self.gradient_checkpointing and self.training:
|
1140 |
+
layer_outputs = self._gradient_checkpointing_func(
|
1141 |
+
decoder_layer.__call__,
|
1142 |
+
hidden_states,
|
1143 |
+
causal_mask,
|
1144 |
+
position_ids,
|
1145 |
+
past_key_values,
|
1146 |
+
output_attentions,
|
1147 |
+
use_cache,
|
1148 |
+
cache_position,
|
1149 |
+
)
|
1150 |
+
else:
|
1151 |
+
layer_outputs = decoder_layer(
|
1152 |
+
hidden_states,
|
1153 |
+
attention_mask=causal_mask,
|
1154 |
+
position_ids=position_ids,
|
1155 |
+
past_key_value=past_key_values,
|
1156 |
+
output_attentions=output_attentions,
|
1157 |
+
use_cache=use_cache,
|
1158 |
+
cache_position=cache_position,
|
1159 |
+
)
|
1160 |
+
|
1161 |
+
hidden_states = layer_outputs[0]
|
1162 |
+
|
1163 |
+
if use_cache:
|
1164 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
1165 |
+
|
1166 |
+
if output_attentions:
|
1167 |
+
all_self_attns += (layer_outputs[1],)
|
1168 |
+
|
1169 |
+
hidden_states = self.norm(hidden_states)
|
1170 |
+
|
1171 |
+
# add hidden states from the last decoder layer
|
1172 |
+
if output_hidden_states:
|
1173 |
+
all_hidden_states += (hidden_states,)
|
1174 |
+
|
1175 |
+
next_cache = None
|
1176 |
+
if use_cache:
|
1177 |
+
next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
1178 |
+
if not return_dict:
|
1179 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
1180 |
+
return BaseModelOutputWithPast(
|
1181 |
+
last_hidden_state=hidden_states,
|
1182 |
+
past_key_values=next_cache,
|
1183 |
+
hidden_states=all_hidden_states,
|
1184 |
+
attentions=all_self_attns,
|
1185 |
+
)
|
1186 |
+
|
1187 |
+
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
1188 |
+
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
1189 |
+
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
1190 |
+
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
1191 |
+
def _update_causal_mask(self, attention_mask, input_tensor):
|
1192 |
+
if self.config._attn_implementation == "flash_attention_2":
|
1193 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
1194 |
+
return attention_mask
|
1195 |
+
return None
|
1196 |
+
|
1197 |
+
batch_size, seq_length = input_tensor.shape[:2]
|
1198 |
+
dtype = input_tensor.dtype
|
1199 |
+
device = input_tensor.device
|
1200 |
+
|
1201 |
+
# support going beyond cached `max_position_embedding`
|
1202 |
+
if seq_length > self.causal_mask.shape[-1]:
|
1203 |
+
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
|
1204 |
+
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
1205 |
+
|
1206 |
+
# We use the current dtype to avoid any overflows
|
1207 |
+
min_dtype = torch.finfo(dtype).min
|
1208 |
+
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
|
1209 |
+
|
1210 |
+
causal_mask = causal_mask.to(dtype=dtype, device=device)
|
1211 |
+
if attention_mask is not None and attention_mask.dim() == 2:
|
1212 |
+
mask_length = attention_mask.shape[-1]
|
1213 |
+
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
1214 |
+
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
1215 |
+
|
1216 |
+
if self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda":
|
1217 |
+
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
1218 |
+
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
1219 |
+
if not is_tracing and torch.any(attention_mask != 1):
|
1220 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
1221 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
1222 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
1223 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
1224 |
+
|
1225 |
+
return causal_mask
|
1226 |
+
|
1227 |
+
|
1228 |
+
class LlamaForCausalLM(LlamaPreTrainedModel):
|
1229 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1230 |
+
|
1231 |
+
def __init__(self, config):
|
1232 |
+
super().__init__(config)
|
1233 |
+
self.model = LlamaModel(config)
|
1234 |
+
self.vocab_size = config.vocab_size
|
1235 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1236 |
+
|
1237 |
+
# Initialize weights and apply final processing
|
1238 |
+
self.post_init()
|
1239 |
+
|
1240 |
+
def get_input_embeddings(self):
|
1241 |
+
return self.model.embed_tokens
|
1242 |
+
|
1243 |
+
def set_input_embeddings(self, value):
|
1244 |
+
self.model.embed_tokens = value
|
1245 |
+
|
1246 |
+
def get_output_embeddings(self):
|
1247 |
+
return self.lm_head
|
1248 |
+
|
1249 |
+
def set_output_embeddings(self, new_embeddings):
|
1250 |
+
self.lm_head = new_embeddings
|
1251 |
+
|
1252 |
+
def set_decoder(self, decoder):
|
1253 |
+
self.model = decoder
|
1254 |
+
|
1255 |
+
def get_decoder(self):
|
1256 |
+
return self.model
|
1257 |
+
|
1258 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1259 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1260 |
+
def forward(
|
1261 |
+
self,
|
1262 |
+
input_ids: torch.LongTensor = None,
|
1263 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1264 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1265 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1266 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1267 |
+
labels: Optional[torch.LongTensor] = None,
|
1268 |
+
use_cache: Optional[bool] = None,
|
1269 |
+
output_attentions: Optional[bool] = None,
|
1270 |
+
output_hidden_states: Optional[bool] = None,
|
1271 |
+
return_dict: Optional[bool] = None,
|
1272 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1273 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1274 |
+
r"""
|
1275 |
+
Args:
|
1276 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1277 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1278 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1279 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1280 |
+
|
1281 |
+
Returns:
|
1282 |
+
|
1283 |
+
Example:
|
1284 |
+
|
1285 |
+
```python
|
1286 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
1287 |
+
|
1288 |
+
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
|
1289 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
1290 |
+
|
1291 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
1292 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1293 |
+
|
1294 |
+
>>> # Generate
|
1295 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1296 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1297 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
1298 |
+
```"""
|
1299 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1300 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1301 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1302 |
+
|
1303 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1304 |
+
outputs = self.model(
|
1305 |
+
input_ids=input_ids,
|
1306 |
+
attention_mask=attention_mask,
|
1307 |
+
position_ids=position_ids,
|
1308 |
+
past_key_values=past_key_values,
|
1309 |
+
inputs_embeds=inputs_embeds,
|
1310 |
+
use_cache=use_cache,
|
1311 |
+
output_attentions=output_attentions,
|
1312 |
+
output_hidden_states=output_hidden_states,
|
1313 |
+
return_dict=return_dict,
|
1314 |
+
cache_position=cache_position,
|
1315 |
+
)
|
1316 |
+
|
1317 |
+
hidden_states = outputs[0]
|
1318 |
+
if self.config.pretraining_tp > 1:
|
1319 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
1320 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
1321 |
+
logits = torch.cat(logits, dim=-1)
|
1322 |
+
else:
|
1323 |
+
logits = self.lm_head(hidden_states)
|
1324 |
+
logits = logits.float()
|
1325 |
+
|
1326 |
+
loss = None
|
1327 |
+
if labels is not None:
|
1328 |
+
# Shift so that tokens < n predict n
|
1329 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1330 |
+
shift_labels = labels[..., 1:].contiguous()
|
1331 |
+
# Flatten the tokens
|
1332 |
+
loss_fct = CrossEntropyLoss()
|
1333 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1334 |
+
shift_labels = shift_labels.view(-1)
|
1335 |
+
# Enable model parallelism
|
1336 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1337 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1338 |
+
|
1339 |
+
if not return_dict:
|
1340 |
+
output = (logits,) + outputs[1:]
|
1341 |
+
return (loss,) + output if loss is not None else output
|
1342 |
+
|
1343 |
+
return CausalLMOutputWithPast(
|
1344 |
+
loss=loss,
|
1345 |
+
logits=logits,
|
1346 |
+
past_key_values=outputs.past_key_values,
|
1347 |
+
hidden_states=outputs.hidden_states,
|
1348 |
+
attentions=outputs.attentions,
|
1349 |
+
)
|
1350 |
+
|
1351 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
|
1352 |
+
past_length = 0
|
1353 |
+
if past_key_values is not None:
|
1354 |
+
if isinstance(past_key_values, Cache):
|
1355 |
+
cache_length = past_key_values.get_seq_length()
|
1356 |
+
past_length = past_key_values.seen_tokens
|
1357 |
+
max_cache_length = past_key_values.get_max_length()
|
1358 |
+
else:
|
1359 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
1360 |
+
max_cache_length = None
|
1361 |
+
|
1362 |
+
# Keep only the unprocessed tokens:
|
1363 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
1364 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
1365 |
+
# input)
|
1366 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
1367 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1368 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1369 |
+
# input_ids based on the past_length.
|
1370 |
+
elif past_length < input_ids.shape[1]:
|
1371 |
+
input_ids = input_ids[:, past_length:]
|
1372 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
1373 |
+
|
1374 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
1375 |
+
if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length:
|
1376 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
1377 |
+
|
1378 |
+
position_ids = kwargs.get("position_ids", None)
|
1379 |
+
if attention_mask is not None and position_ids is None:
|
1380 |
+
# create position_ids on the fly for batch generation
|
1381 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1382 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1383 |
+
if past_key_values:
|
1384 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1385 |
+
|
1386 |
+
if self.generation_config.cache_implementation == "static":
|
1387 |
+
# generation with static cache
|
1388 |
+
cache_position = kwargs.get("cache_position", None)
|
1389 |
+
if cache_position is None:
|
1390 |
+
past_length = 0
|
1391 |
+
else:
|
1392 |
+
past_length = cache_position[-1] + 1
|
1393 |
+
input_ids = input_ids[:, past_length:]
|
1394 |
+
position_ids = position_ids[:, past_length:]
|
1395 |
+
|
1396 |
+
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
|
1397 |
+
# same goes for position ids. Could also help with continued generation.
|
1398 |
+
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
1399 |
+
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
1400 |
+
position_ids = position_ids.contiguous() if position_ids is not None else None
|
1401 |
+
|
1402 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1403 |
+
if inputs_embeds is not None and past_key_values is None:
|
1404 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1405 |
+
else:
|
1406 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
1407 |
+
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
1408 |
+
# TODO: use `next_tokens` directly instead.
|
1409 |
+
model_inputs = {"input_ids": input_ids.contiguous()}
|
1410 |
+
|
1411 |
+
model_inputs.update(
|
1412 |
+
{
|
1413 |
+
"position_ids": position_ids,
|
1414 |
+
"cache_position": cache_position,
|
1415 |
+
"past_key_values": past_key_values,
|
1416 |
+
"use_cache": kwargs.get("use_cache"),
|
1417 |
+
"attention_mask": attention_mask,
|
1418 |
+
}
|
1419 |
+
)
|
1420 |
+
return model_inputs
|
1421 |
+
|
1422 |
+
@staticmethod
|
1423 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1424 |
+
reordered_past = ()
|
1425 |
+
for layer_past in past_key_values:
|
1426 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
|
1427 |
+
return reordered_past
|
1428 |
+
|
1429 |
+
|
1430 |
+
@add_start_docstrings(
|
1431 |
+
"""
|
1432 |
+
The LLaMa Model transformer with a sequence classification head on top (linear layer).
|
1433 |
+
|
1434 |
+
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1435 |
+
(e.g. GPT-2) do.
|
1436 |
+
|
1437 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1438 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1439 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1440 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1441 |
+
each row of the batch).
|
1442 |
+
""",
|
1443 |
+
LLAMA_START_DOCSTRING,
|
1444 |
+
)
|
1445 |
+
class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
1446 |
+
def __init__(self, config):
|
1447 |
+
super().__init__(config)
|
1448 |
+
self.num_labels = config.num_labels
|
1449 |
+
self.model = LlamaModel(config)
|
1450 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1451 |
+
|
1452 |
+
# Initialize weights and apply final processing
|
1453 |
+
self.post_init()
|
1454 |
+
|
1455 |
+
def get_input_embeddings(self):
|
1456 |
+
return self.model.embed_tokens
|
1457 |
+
|
1458 |
+
def set_input_embeddings(self, value):
|
1459 |
+
self.model.embed_tokens = value
|
1460 |
+
|
1461 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1462 |
+
def forward(
|
1463 |
+
self,
|
1464 |
+
input_ids: torch.LongTensor = None,
|
1465 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1466 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1467 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1468 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1469 |
+
labels: Optional[torch.LongTensor] = None,
|
1470 |
+
use_cache: Optional[bool] = None,
|
1471 |
+
output_attentions: Optional[bool] = None,
|
1472 |
+
output_hidden_states: Optional[bool] = None,
|
1473 |
+
return_dict: Optional[bool] = None,
|
1474 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1475 |
+
r"""
|
1476 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1477 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1478 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1479 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1480 |
+
"""
|
1481 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1482 |
+
|
1483 |
+
transformer_outputs = self.model(
|
1484 |
+
input_ids,
|
1485 |
+
attention_mask=attention_mask,
|
1486 |
+
position_ids=position_ids,
|
1487 |
+
past_key_values=past_key_values,
|
1488 |
+
inputs_embeds=inputs_embeds,
|
1489 |
+
use_cache=use_cache,
|
1490 |
+
output_attentions=output_attentions,
|
1491 |
+
output_hidden_states=output_hidden_states,
|
1492 |
+
return_dict=return_dict,
|
1493 |
+
)
|
1494 |
+
hidden_states = transformer_outputs[0]
|
1495 |
+
logits = self.score(hidden_states)
|
1496 |
+
|
1497 |
+
if input_ids is not None:
|
1498 |
+
batch_size = input_ids.shape[0]
|
1499 |
+
else:
|
1500 |
+
batch_size = inputs_embeds.shape[0]
|
1501 |
+
|
1502 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1503 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1504 |
+
if self.config.pad_token_id is None:
|
1505 |
+
sequence_lengths = -1
|
1506 |
+
else:
|
1507 |
+
if input_ids is not None:
|
1508 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1509 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1510 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1511 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1512 |
+
else:
|
1513 |
+
sequence_lengths = -1
|
1514 |
+
|
1515 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1516 |
+
|
1517 |
+
loss = None
|
1518 |
+
if labels is not None:
|
1519 |
+
labels = labels.to(logits.device)
|
1520 |
+
if self.config.problem_type is None:
|
1521 |
+
if self.num_labels == 1:
|
1522 |
+
self.config.problem_type = "regression"
|
1523 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1524 |
+
self.config.problem_type = "single_label_classification"
|
1525 |
+
else:
|
1526 |
+
self.config.problem_type = "multi_label_classification"
|
1527 |
+
|
1528 |
+
if self.config.problem_type == "regression":
|
1529 |
+
loss_fct = MSELoss()
|
1530 |
+
if self.num_labels == 1:
|
1531 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1532 |
+
else:
|
1533 |
+
loss = loss_fct(pooled_logits, labels)
|
1534 |
+
elif self.config.problem_type == "single_label_classification":
|
1535 |
+
loss_fct = CrossEntropyLoss()
|
1536 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1537 |
+
elif self.config.problem_type == "multi_label_classification":
|
1538 |
+
loss_fct = BCEWithLogitsLoss()
|
1539 |
+
loss = loss_fct(pooled_logits, labels)
|
1540 |
+
if not return_dict:
|
1541 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1542 |
+
return ((loss,) + output) if loss is not None else output
|
1543 |
+
|
1544 |
+
return SequenceClassifierOutputWithPast(
|
1545 |
+
loss=loss,
|
1546 |
+
logits=pooled_logits,
|
1547 |
+
past_key_values=transformer_outputs.past_key_values,
|
1548 |
+
hidden_states=transformer_outputs.hidden_states,
|
1549 |
+
attentions=transformer_outputs.attentions,
|
1550 |
+
)
|
1551 |
+
|
1552 |
+
|
1553 |
+
@add_start_docstrings(
|
1554 |
+
"""
|
1555 |
+
The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
|
1556 |
+
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1557 |
+
""",
|
1558 |
+
LLAMA_START_DOCSTRING,
|
1559 |
+
)
|
1560 |
+
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
|
1561 |
+
base_model_prefix = "transformer"
|
1562 |
+
|
1563 |
+
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
|
1564 |
+
def __init__(self, config):
|
1565 |
+
super().__init__(config)
|
1566 |
+
self.transformer = LlamaModel(config)
|
1567 |
+
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
1568 |
+
|
1569 |
+
# Initialize weights and apply final processing
|
1570 |
+
self.post_init()
|
1571 |
+
|
1572 |
+
def get_input_embeddings(self):
|
1573 |
+
return self.transformer.embed_tokens
|
1574 |
+
|
1575 |
+
def set_input_embeddings(self, value):
|
1576 |
+
self.transformer.embed_tokens = value
|
1577 |
+
|
1578 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1579 |
+
def forward(
|
1580 |
+
self,
|
1581 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1582 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1583 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1584 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1585 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1586 |
+
start_positions: Optional[torch.LongTensor] = None,
|
1587 |
+
end_positions: Optional[torch.LongTensor] = None,
|
1588 |
+
output_attentions: Optional[bool] = None,
|
1589 |
+
output_hidden_states: Optional[bool] = None,
|
1590 |
+
return_dict: Optional[bool] = None,
|
1591 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
1592 |
+
r"""
|
1593 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1594 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1595 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1596 |
+
are not taken into account for computing the loss.
|
1597 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1598 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1599 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1600 |
+
are not taken into account for computing the loss.
|
1601 |
+
"""
|
1602 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1603 |
+
|
1604 |
+
outputs = self.transformer(
|
1605 |
+
input_ids,
|
1606 |
+
attention_mask=attention_mask,
|
1607 |
+
position_ids=position_ids,
|
1608 |
+
past_key_values=past_key_values,
|
1609 |
+
inputs_embeds=inputs_embeds,
|
1610 |
+
output_attentions=output_attentions,
|
1611 |
+
output_hidden_states=output_hidden_states,
|
1612 |
+
return_dict=return_dict,
|
1613 |
+
)
|
1614 |
+
|
1615 |
+
sequence_output = outputs[0]
|
1616 |
+
|
1617 |
+
logits = self.qa_outputs(sequence_output)
|
1618 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1619 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
1620 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
1621 |
+
|
1622 |
+
total_loss = None
|
1623 |
+
if start_positions is not None and end_positions is not None:
|
1624 |
+
# If we are on multi-GPU, split add a dimension
|
1625 |
+
if len(start_positions.size()) > 1:
|
1626 |
+
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
1627 |
+
if len(end_positions.size()) > 1:
|
1628 |
+
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
1629 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1630 |
+
ignored_index = start_logits.size(1)
|
1631 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
1632 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
1633 |
+
|
1634 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1635 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1636 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1637 |
+
total_loss = (start_loss + end_loss) / 2
|
1638 |
+
|
1639 |
+
if not return_dict:
|
1640 |
+
output = (start_logits, end_logits) + outputs[2:]
|
1641 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1642 |
+
|
1643 |
+
return QuestionAnsweringModelOutput(
|
1644 |
+
loss=total_loss,
|
1645 |
+
start_logits=start_logits,
|
1646 |
+
end_logits=end_logits,
|
1647 |
+
hidden_states=outputs.hidden_states,
|
1648 |
+
attentions=outputs.attentions,
|
1649 |
+
)
|
llava/model/llava_arch.py
ADDED
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
|
18 |
+
import math
|
19 |
+
import re
|
20 |
+
import time
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from .multimodal_encoder.builder import build_vision_tower
|
24 |
+
from .multimodal_resampler.builder import build_vision_resampler
|
25 |
+
from .multimodal_projector.builder import build_vision_projector
|
26 |
+
|
27 |
+
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
28 |
+
|
29 |
+
from llava.mm_utils import get_anyres_image_grid_shape
|
30 |
+
from llava.utils import rank0_print, rank_print
|
31 |
+
import random
|
32 |
+
|
33 |
+
|
34 |
+
class LlavaMetaModel:
|
35 |
+
|
36 |
+
def __init__(self, config):
|
37 |
+
super(LlavaMetaModel, self).__init__(config)
|
38 |
+
|
39 |
+
if hasattr(config, "mm_vision_tower"):
|
40 |
+
delay_load = getattr(config, "delay_load", False)
|
41 |
+
self.vision_tower = build_vision_tower(config, delay_load=delay_load)
|
42 |
+
self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
|
43 |
+
self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
|
44 |
+
|
45 |
+
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
46 |
+
self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
|
47 |
+
|
48 |
+
def get_vision_tower(self):
|
49 |
+
vision_tower = getattr(self, "vision_tower", None)
|
50 |
+
if type(vision_tower) is list:
|
51 |
+
vision_tower = vision_tower[0]
|
52 |
+
return vision_tower
|
53 |
+
|
54 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
55 |
+
vision_tower = model_args.vision_tower
|
56 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
57 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
58 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
59 |
+
mm_patch_merge_type = model_args.mm_patch_merge_type
|
60 |
+
|
61 |
+
self.config.mm_vision_tower = vision_tower
|
62 |
+
self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
|
63 |
+
|
64 |
+
if self.get_vision_tower() is None:
|
65 |
+
vision_tower = build_vision_tower(model_args)
|
66 |
+
vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
|
67 |
+
for k, v in vision_resampler.config.items():
|
68 |
+
setattr(self.config, k, v)
|
69 |
+
|
70 |
+
if fsdp is not None and len(fsdp) > 0:
|
71 |
+
self.vision_tower = [vision_tower]
|
72 |
+
self.vision_resampler = [vision_resampler]
|
73 |
+
else:
|
74 |
+
self.vision_tower = vision_tower
|
75 |
+
self.vision_resampler = vision_resampler
|
76 |
+
else:
|
77 |
+
if fsdp is not None and len(fsdp) > 0:
|
78 |
+
vision_resampler = self.vision_resampler[0]
|
79 |
+
vision_tower = self.vision_tower[0]
|
80 |
+
else:
|
81 |
+
vision_resampler = self.vision_resampler
|
82 |
+
vision_tower = self.vision_tower
|
83 |
+
vision_tower.load_model()
|
84 |
+
|
85 |
+
# In case it is frozen by LoRA
|
86 |
+
for p in self.vision_resampler.parameters():
|
87 |
+
p.requires_grad = True
|
88 |
+
|
89 |
+
self.config.use_mm_proj = True
|
90 |
+
self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
|
91 |
+
self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
|
92 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
93 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
94 |
+
self.config.mm_patch_merge_type = mm_patch_merge_type
|
95 |
+
|
96 |
+
if getattr(self, "mm_projector", None) is None:
|
97 |
+
self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
|
98 |
+
|
99 |
+
if "unpad" in mm_patch_merge_type:
|
100 |
+
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
|
101 |
+
self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
|
102 |
+
else:
|
103 |
+
# In case it is frozen by LoRA
|
104 |
+
for p in self.mm_projector.parameters():
|
105 |
+
p.requires_grad = True
|
106 |
+
|
107 |
+
if pretrain_mm_mlp_adapter is not None:
|
108 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
|
109 |
+
|
110 |
+
def get_w(weights, keyword):
|
111 |
+
return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
|
112 |
+
|
113 |
+
incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
|
114 |
+
rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
|
115 |
+
incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
|
116 |
+
rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
|
117 |
+
|
118 |
+
|
119 |
+
def unpad_image(tensor, original_size):
|
120 |
+
"""
|
121 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
125 |
+
original_size (tuple): The original size of the image (height, width).
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
torch.Tensor: The unpadded image tensor.
|
129 |
+
"""
|
130 |
+
original_width, original_height = original_size
|
131 |
+
current_height, current_width = tensor.shape[1:]
|
132 |
+
|
133 |
+
# Compute aspect ratios
|
134 |
+
original_aspect_ratio = original_width / original_height
|
135 |
+
current_aspect_ratio = current_width / current_height
|
136 |
+
|
137 |
+
# Determine padding size and direction
|
138 |
+
if original_aspect_ratio > current_aspect_ratio:
|
139 |
+
# Padding was added to the height
|
140 |
+
scale_factor = current_width / original_width
|
141 |
+
new_height = int(original_height * scale_factor)
|
142 |
+
padding = (current_height - new_height) // 2
|
143 |
+
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
144 |
+
else:
|
145 |
+
# Padding was added to the width
|
146 |
+
scale_factor = current_height / original_height
|
147 |
+
new_width = int(original_width * scale_factor)
|
148 |
+
padding = (current_width - new_width) // 2
|
149 |
+
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
150 |
+
|
151 |
+
return unpadded_tensor
|
152 |
+
|
153 |
+
|
154 |
+
class LlavaMetaForCausalLM(ABC):
|
155 |
+
|
156 |
+
@abstractmethod
|
157 |
+
def get_model(self):
|
158 |
+
pass
|
159 |
+
|
160 |
+
def get_vision_tower(self):
|
161 |
+
return self.get_model().get_vision_tower()
|
162 |
+
|
163 |
+
def get_2dPool(self, image_feature):
|
164 |
+
height = width = self.get_vision_tower().num_patches_per_side
|
165 |
+
num_frames, num_tokens, num_dim = image_feature.shape
|
166 |
+
image_feature = image_feature.view(num_frames, height, width, -1)
|
167 |
+
image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
|
168 |
+
# image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
|
169 |
+
if self.config.mm_spatial_pool_mode == "average":
|
170 |
+
image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride)
|
171 |
+
elif self.config.mm_spatial_pool_mode == "max":
|
172 |
+
image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
|
173 |
+
elif self.config.mm_spatial_pool_mode == "bilinear":
|
174 |
+
height, weight = image_feature.shape[2:]
|
175 |
+
scaled_shape = [math.ceil(height / 2), math.ceil(weight / 2)]
|
176 |
+
image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
|
177 |
+
|
178 |
+
else:
|
179 |
+
raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
|
180 |
+
image_feature = image_feature.permute(0, 2, 3, 1)
|
181 |
+
image_feature = image_feature.view(num_frames, -1, num_dim)
|
182 |
+
return image_feature
|
183 |
+
|
184 |
+
def encode_images(self, images):
|
185 |
+
image_features = self.get_model().get_vision_tower()(images)
|
186 |
+
# image_features = self.get_model().vision_resampler(image_features, images=images)
|
187 |
+
image_features = self.get_model().mm_projector(image_features)
|
188 |
+
return image_features
|
189 |
+
|
190 |
+
def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
|
191 |
+
videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
|
192 |
+
per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096)
|
193 |
+
all_videos_or_images_features = []
|
194 |
+
|
195 |
+
for idx, feat in enumerate(per_videos_or_images_features):
|
196 |
+
feat = self.get_model().mm_projector(feat)
|
197 |
+
if idx in video_idx_in_batch:
|
198 |
+
feat = self.get_2dPool(feat)
|
199 |
+
all_videos_or_images_features.append(feat)
|
200 |
+
return all_videos_or_images_features
|
201 |
+
|
202 |
+
def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
|
203 |
+
vision_tower = self.get_vision_tower()
|
204 |
+
# rank_print(modalities)
|
205 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
206 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
207 |
+
|
208 |
+
if type(images) is list or images.ndim == 5:
|
209 |
+
if type(images) is list:
|
210 |
+
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
|
211 |
+
|
212 |
+
video_idx_in_batch = []
|
213 |
+
for _ in range(len(modalities)):
|
214 |
+
if modalities[_] == "video":
|
215 |
+
video_idx_in_batch.append(_)
|
216 |
+
|
217 |
+
images_list = []
|
218 |
+
for image in images:
|
219 |
+
if image.ndim == 4:
|
220 |
+
images_list.append(image)
|
221 |
+
else:
|
222 |
+
images_list.append(image.unsqueeze(0))
|
223 |
+
|
224 |
+
concat_images = torch.cat([image for image in images_list], dim=0)
|
225 |
+
split_sizes = [image.shape[0] for image in images_list]
|
226 |
+
encoded_image_features = self.encode_images(concat_images)
|
227 |
+
|
228 |
+
# This is a list, each element is [num_images, patch * patch, dim]
|
229 |
+
# rank_print(f"Concat images : {concat_images.shape}")
|
230 |
+
encoded_image_features = torch.split(encoded_image_features, split_sizes)
|
231 |
+
image_features = []
|
232 |
+
for idx, image_feat in enumerate(encoded_image_features):
|
233 |
+
if idx in video_idx_in_batch:
|
234 |
+
image_features.append(self.get_2dPool(image_feat))
|
235 |
+
else:
|
236 |
+
image_features.append(image_feat)
|
237 |
+
# image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
|
238 |
+
# rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
|
239 |
+
# image_features = torch.split(image_features, split_sizes, dim=0)
|
240 |
+
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
|
241 |
+
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
|
242 |
+
|
243 |
+
if mm_patch_merge_type == "flat":
|
244 |
+
image_features = [x.flatten(0, 1) for x in image_features]
|
245 |
+
|
246 |
+
elif mm_patch_merge_type.startswith("spatial"):
|
247 |
+
new_image_features = []
|
248 |
+
for image_idx, image_feature in enumerate(image_features):
|
249 |
+
# FIXME: now assume the image is square, and split to 2x2 patches
|
250 |
+
# num_patches = h * w, where h = w = sqrt(num_patches)
|
251 |
+
# currently image_feature is a tensor of shape (4, num_patches, hidden_size)
|
252 |
+
# we want to first unflatten it to (2, 2, h, w, hidden_size)
|
253 |
+
# rank0_print("At least we are reaching here")
|
254 |
+
if image_idx in video_idx_in_batch: # video operations
|
255 |
+
# rank0_print("Video")
|
256 |
+
if "unpad" in mm_patch_merge_type:
|
257 |
+
# image_feature = image_feature.permute(2, 0, 1).contiguous()
|
258 |
+
# image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
|
259 |
+
# image_feature = image_feature.permute(1, 2, 0).contiguous()
|
260 |
+
image_feature = image_feature.flatten(0, 1)
|
261 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
|
262 |
+
|
263 |
+
elif image_feature.shape[0] > 1: # multi patches and multi images operations
|
264 |
+
# rank0_print("Single-images")
|
265 |
+
base_image_feature = image_feature[0]
|
266 |
+
image_feature = image_feature[1:]
|
267 |
+
height = width = self.get_vision_tower().num_patches_per_side
|
268 |
+
assert height * width == base_image_feature.shape[0]
|
269 |
+
|
270 |
+
if "anyres_max" in image_aspect_ratio:
|
271 |
+
matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
|
272 |
+
if matched_anyres_max_num_patches:
|
273 |
+
max_num_patches = int(matched_anyres_max_num_patches.group(1))
|
274 |
+
|
275 |
+
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
|
276 |
+
if hasattr(self.get_vision_tower(), "image_size"):
|
277 |
+
vision_tower_image_size = self.get_vision_tower().image_size
|
278 |
+
else:
|
279 |
+
raise ValueError("vision_tower_image_size is not found in the vision tower.")
|
280 |
+
try:
|
281 |
+
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
|
282 |
+
except Exception as e:
|
283 |
+
rank0_print(f"Error: {e}")
|
284 |
+
num_patch_width, num_patch_height = 2, 2
|
285 |
+
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
|
286 |
+
else:
|
287 |
+
image_feature = image_feature.view(2, 2, height, width, -1)
|
288 |
+
|
289 |
+
if "maxpool2x2" in mm_patch_merge_type:
|
290 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
291 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
292 |
+
image_feature = nn.functional.max_pool2d(image_feature, 2)
|
293 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
294 |
+
elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
|
295 |
+
unit = image_feature.shape[2]
|
296 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
297 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
298 |
+
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
299 |
+
c, h, w = image_feature.shape
|
300 |
+
times = math.sqrt(h * w / (max_num_patches * unit**2))
|
301 |
+
if times > 1.1:
|
302 |
+
image_feature = image_feature[None]
|
303 |
+
image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
|
304 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
|
305 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
306 |
+
elif "unpad" in mm_patch_merge_type:
|
307 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
308 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
309 |
+
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
310 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
|
311 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
312 |
+
else:
|
313 |
+
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
|
314 |
+
image_feature = image_feature.flatten(0, 3)
|
315 |
+
if "nobase" in mm_patch_merge_type:
|
316 |
+
pass
|
317 |
+
else:
|
318 |
+
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
319 |
+
else: # single image operations
|
320 |
+
image_feature = image_feature[0]
|
321 |
+
if "unpad" in mm_patch_merge_type:
|
322 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
|
323 |
+
|
324 |
+
new_image_features.append(image_feature)
|
325 |
+
image_features = new_image_features
|
326 |
+
else:
|
327 |
+
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
|
328 |
+
else:
|
329 |
+
image_features = self.encode_images(images)
|
330 |
+
|
331 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
332 |
+
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
|
333 |
+
raise NotImplementedError
|
334 |
+
# rank_print(f"Total images : {len(image_features)}")
|
335 |
+
|
336 |
+
# Let's just add dummy tensors if they do not exist,
|
337 |
+
# it is a headache to deal with None all the time.
|
338 |
+
# But it is not ideal, and if you have a better idea,
|
339 |
+
# please open an issue / submit a PR, thanks.
|
340 |
+
_labels = labels
|
341 |
+
_position_ids = position_ids
|
342 |
+
_attention_mask = attention_mask
|
343 |
+
if attention_mask is None:
|
344 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
345 |
+
else:
|
346 |
+
attention_mask = attention_mask.bool()
|
347 |
+
if position_ids is None:
|
348 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
349 |
+
if labels is None:
|
350 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
351 |
+
|
352 |
+
# remove the padding using attention_mask -- FIXME
|
353 |
+
_input_ids = input_ids
|
354 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
355 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
356 |
+
|
357 |
+
new_input_embeds = []
|
358 |
+
new_labels = []
|
359 |
+
cur_image_idx = 0
|
360 |
+
# rank_print("Inserting Images embedding")
|
361 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
362 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
363 |
+
# rank0_print(num_images)
|
364 |
+
if num_images == 0:
|
365 |
+
cur_image_features = image_features[cur_image_idx]
|
366 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
367 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
368 |
+
new_input_embeds.append(cur_input_embeds)
|
369 |
+
new_labels.append(labels[batch_idx])
|
370 |
+
cur_image_idx += 1
|
371 |
+
continue
|
372 |
+
|
373 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
|
374 |
+
cur_input_ids_noim = []
|
375 |
+
cur_labels = labels[batch_idx]
|
376 |
+
cur_labels_noim = []
|
377 |
+
for i in range(len(image_token_indices) - 1):
|
378 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
379 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
380 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
381 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
382 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
383 |
+
cur_new_input_embeds = []
|
384 |
+
cur_new_labels = []
|
385 |
+
|
386 |
+
for i in range(num_images + 1):
|
387 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
388 |
+
cur_new_labels.append(cur_labels_noim[i])
|
389 |
+
if i < num_images:
|
390 |
+
try:
|
391 |
+
cur_image_features = image_features[cur_image_idx]
|
392 |
+
except IndexError:
|
393 |
+
cur_image_features = image_features[cur_image_idx - 1]
|
394 |
+
cur_image_idx += 1
|
395 |
+
cur_new_input_embeds.append(cur_image_features)
|
396 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
397 |
+
|
398 |
+
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
399 |
+
|
400 |
+
# import pdb; pdb.set_trace()
|
401 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
402 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
403 |
+
|
404 |
+
new_input_embeds.append(cur_new_input_embeds)
|
405 |
+
new_labels.append(cur_new_labels)
|
406 |
+
|
407 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
408 |
+
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
|
409 |
+
# rank_print("Finishing Inserting")
|
410 |
+
|
411 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
|
412 |
+
new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
|
413 |
+
# TODO: Hard code for control loss spike
|
414 |
+
# if tokenizer_model_max_length is not None:
|
415 |
+
# new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
|
416 |
+
# new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
|
417 |
+
|
418 |
+
# Combine them
|
419 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
420 |
+
batch_size = len(new_input_embeds)
|
421 |
+
|
422 |
+
new_input_embeds_padded = []
|
423 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
424 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
425 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
426 |
+
# rank0_print("Prepare pos id")
|
427 |
+
|
428 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
429 |
+
cur_len = cur_new_embed.shape[0]
|
430 |
+
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
|
431 |
+
new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
|
432 |
+
if cur_len > 0:
|
433 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
434 |
+
attention_mask[i, -cur_len:] = True
|
435 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
436 |
+
else:
|
437 |
+
new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
|
438 |
+
if cur_len > 0:
|
439 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
440 |
+
attention_mask[i, :cur_len] = True
|
441 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
442 |
+
|
443 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
444 |
+
# rank0_print("tokenizer padding")
|
445 |
+
|
446 |
+
if _labels is None:
|
447 |
+
new_labels = None
|
448 |
+
else:
|
449 |
+
new_labels = new_labels_padded
|
450 |
+
|
451 |
+
if _attention_mask is None:
|
452 |
+
attention_mask = None
|
453 |
+
else:
|
454 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
455 |
+
|
456 |
+
if _position_ids is None:
|
457 |
+
position_ids = None
|
458 |
+
if getattr(self.config, "use_pos_skipping", False) and self.training:
|
459 |
+
position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
|
460 |
+
split_position = random.randint(0, new_input_embeds.size(1))
|
461 |
+
left_add = random.randint(0, self.config.pos_skipping_range)
|
462 |
+
right_add = random.randint(left_add, self.config.pos_skipping_range)
|
463 |
+
position_ids[:, :split_position] += left_add
|
464 |
+
position_ids[:, split_position:] += right_add
|
465 |
+
# import pdb; pdb.set_trace()
|
466 |
+
# rank0_print("Finish preparing")
|
467 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
468 |
+
|
469 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
470 |
+
if model_args.mm_use_im_patch_token:
|
471 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
472 |
+
self.resize_token_embeddings(len(tokenizer))
|
473 |
+
|
474 |
+
if model_args.mm_use_im_start_end:
|
475 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
476 |
+
self.resize_token_embeddings(len(tokenizer))
|
477 |
+
|
478 |
+
if num_new_tokens > 0:
|
479 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
480 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
481 |
+
|
482 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
483 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
484 |
+
|
485 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
486 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
487 |
+
|
488 |
+
if model_args.tune_mm_mlp_adapter:
|
489 |
+
for p in self.get_input_embeddings().parameters():
|
490 |
+
p.requires_grad = True
|
491 |
+
for p in self.get_output_embeddings().parameters():
|
492 |
+
p.requires_grad = False
|
493 |
+
|
494 |
+
if model_args.pretrain_mm_mlp_adapter:
|
495 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
|
496 |
+
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
|
497 |
+
assert num_new_tokens == 2
|
498 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
499 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
500 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
501 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
502 |
+
else:
|
503 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
504 |
+
elif model_args.mm_use_im_patch_token:
|
505 |
+
if model_args.tune_mm_mlp_adapter:
|
506 |
+
for p in self.get_input_embeddings().parameters():
|
507 |
+
p.requires_grad = False
|
508 |
+
for p in self.get_output_embeddings().parameters():
|
509 |
+
p.requires_grad = False
|
llava/model/make_delta.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
+
from llava.model.utils import auto_upgrade
|
12 |
+
|
13 |
+
|
14 |
+
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
|
15 |
+
print("Loading base model")
|
16 |
+
base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading target model")
|
19 |
+
auto_upgrade(target_model_path)
|
20 |
+
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
21 |
+
|
22 |
+
print("Calculating delta")
|
23 |
+
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data -= base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
|
31 |
+
bparam = base.state_dict()[name]
|
32 |
+
param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
|
33 |
+
|
34 |
+
print("Saving delta")
|
35 |
+
if hub_repo_id:
|
36 |
+
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
|
37 |
+
else:
|
38 |
+
kwargs = {}
|
39 |
+
target.save_pretrained(delta_path, **kwargs)
|
40 |
+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
|
41 |
+
target_tokenizer.save_pretrained(delta_path, **kwargs)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
parser = argparse.ArgumentParser()
|
46 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
47 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
48 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
49 |
+
parser.add_argument("--hub-repo-id", type=str, default=None)
|
50 |
+
args = parser.parse_args()
|
51 |
+
|
52 |
+
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
|