BiXie committed
Commit 252711e
Parent: ed6a121

Upload 204 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. llava-1.7.0.dev0.dist-info/INSTALLER +1 -0
  2. llava-1.7.0.dev0.dist-info/LICENSE +201 -0
  3. llava-1.7.0.dev0.dist-info/METADATA +266 -0
  4. llava-1.7.0.dev0.dist-info/RECORD +204 -0
  5. llava-1.7.0.dev0.dist-info/REQUESTED +0 -0
  6. llava-1.7.0.dev0.dist-info/WHEEL +5 -0
  7. llava-1.7.0.dev0.dist-info/direct_url.json +1 -0
  8. llava-1.7.0.dev0.dist-info/top_level.txt +2 -0
  9. llava/__init__.py +1 -0
  10. llava/__pycache__/__init__.cpython-39.pyc +0 -0
  11. llava/__pycache__/constants.cpython-39.pyc +0 -0
  12. llava/__pycache__/conversation.cpython-39.pyc +0 -0
  13. llava/__pycache__/mm_utils.cpython-39.pyc +0 -0
  14. llava/__pycache__/utils.cpython-39.pyc +0 -0
  15. llava/constants.py +12 -0
  16. llava/conversation.py +577 -0
  17. llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc +0 -0
  18. llava/eval/__pycache__/model_vqa.cpython-39.pyc +0 -0
  19. llava/eval/evaluate_interleave.py +339 -0
  20. llava/eval/model_vqa.py +240 -0
  21. llava/mm_utils.py +395 -0
  22. llava/model/__init__.py +16 -0
  23. llava/model/__pycache__/__init__.cpython-39.pyc +0 -0
  24. llava/model/__pycache__/apply_delta.cpython-39.pyc +0 -0
  25. llava/model/__pycache__/builder.cpython-39.pyc +0 -0
  26. llava/model/__pycache__/consolidate.cpython-39.pyc +0 -0
  27. llava/model/__pycache__/llava_arch.cpython-39.pyc +0 -0
  28. llava/model/__pycache__/make_delta.cpython-39.pyc +0 -0
  29. llava/model/__pycache__/utils.cpython-39.pyc +0 -0
  30. llava/model/apply_delta.py +47 -0
  31. llava/model/builder.py +301 -0
  32. llava/model/consolidate.py +30 -0
  33. llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc +0 -0
  34. llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc +0 -0
  35. llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc +0 -0
  36. llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc +0 -0
  37. llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc +0 -0
  38. llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc +0 -0
  39. llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc +0 -0
  40. llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc +0 -0
  41. llava/model/language_model/llava_gemma.py +122 -0
  42. llava/model/language_model/llava_llama.py +156 -0
  43. llava/model/language_model/llava_mistral.py +127 -0
  44. llava/model/language_model/llava_mixtral.py +143 -0
  45. llava/model/language_model/llava_mpt.py +105 -0
  46. llava/model/language_model/llava_qwen.py +149 -0
  47. llava/model/language_model/llava_qwen_moe.py +149 -0
  48. llava/model/language_model/modeling_llama.py +1649 -0
  49. llava/model/llava_arch.py +509 -0
  50. llava/model/make_delta.py +52 -0
llava-1.7.0.dev0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
+ pip
llava-1.7.0.dev0.dist-info/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
llava-1.7.0.dev0.dist-info/METADATA ADDED
@@ -0,0 +1,266 @@
1
+ Metadata-Version: 2.1
2
+ Name: llava
3
+ Version: 1.7.0.dev0
4
+ Summary: LLaVA OneVision: The Next Generation of LLaVA with Better Image and Video Understanding Capabilities
5
+ Project-URL: Homepage, https://llava-vl.github.io
6
+ Project-URL: Bug Tracker, https://github.com/haotian-liu/LLaVA/issues
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: Apache Software License
9
+ Requires-Python: >=3.8
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Provides-Extra: standalone
13
+ Requires-Dist: shortuuid ; extra == 'standalone'
14
+ Requires-Dist: httpx ==0.24.0 ; extra == 'standalone'
15
+ Requires-Dist: einops ; extra == 'standalone'
16
+ Requires-Dist: ftfy ; extra == 'standalone'
17
+ Provides-Extra: train
18
+ Requires-Dist: llava[standalone] ; extra == 'train'
19
+ Requires-Dist: numpy ==1.26.1 ; extra == 'train'
20
+ Requires-Dist: open-clip-torch ; extra == 'train'
21
+ Requires-Dist: fastapi ; extra == 'train'
22
+ Requires-Dist: gradio ==3.35.2 ; extra == 'train'
23
+ Requires-Dist: markdown2[all] ; extra == 'train'
24
+ Requires-Dist: numpy ; extra == 'train'
25
+ Requires-Dist: requests ; extra == 'train'
26
+ Requires-Dist: sentencepiece ; extra == 'train'
27
+ Requires-Dist: torch ==2.1.2 ; extra == 'train'
28
+ Requires-Dist: torchvision ==0.16.2 ; extra == 'train'
29
+ Requires-Dist: uvicorn ; extra == 'train'
30
+ Requires-Dist: wandb ; extra == 'train'
31
+ Requires-Dist: deepspeed ==0.14.2 ; extra == 'train'
32
+ Requires-Dist: peft ==0.4.0 ; extra == 'train'
33
+ Requires-Dist: accelerate >=0.29.1 ; extra == 'train'
34
+ Requires-Dist: tokenizers ~=0.15.2 ; extra == 'train'
35
+ Requires-Dist: transformers @ git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4 ; extra == 'train'
36
+ Requires-Dist: bitsandbytes ==0.41.0 ; extra == 'train'
37
+ Requires-Dist: scikit-learn ==1.2.2 ; extra == 'train'
38
+ Requires-Dist: sentencepiece ~=0.1.99 ; extra == 'train'
39
+ Requires-Dist: einops ==0.6.1 ; extra == 'train'
40
+ Requires-Dist: einops-exts ==0.0.4 ; extra == 'train'
41
+ Requires-Dist: gradio-client ==0.2.9 ; extra == 'train'
42
+ Requires-Dist: urllib3 <=2.0.0 ; extra == 'train'
43
+ Requires-Dist: datasets ==2.16.1 ; extra == 'train'
44
+ Requires-Dist: pydantic ==1.10.8 ; extra == 'train'
45
+ Requires-Dist: timm ; extra == 'train'
46
+ Requires-Dist: hf-transfer ; extra == 'train'
47
+ Requires-Dist: opencv-python ; extra == 'train'
48
+ Requires-Dist: av ; extra == 'train'
49
+ Requires-Dist: decord ; extra == 'train'
50
+ Requires-Dist: tyro ; extra == 'train'
51
+ Requires-Dist: scipy ; extra == 'train'
52
+
53
+ <p align="center" width="100%">
54
+ <img src="https://i.postimg.cc/pL17YtG4/WX20240508-220230-2x.png" width="80%" height="80%">
55
+ </p>
56
+
57
+ # LLaVA-NeXT: Open Large Multimodal Models
58
+ [![Static Badge](https://img.shields.io/badge/llava_onevision-paper-green)](https://arxiv.org/abs/2408.03326)
59
+ [![llava_next-blog](https://img.shields.io/badge/llava_next-blog-green)](https://llava-vl.github.io/blog/)
60
+
61
+ [![llava_onevision-demo](https://img.shields.io/badge/llava_onevision-demo-red)](https://llava-onevision.lmms-lab.com/)
62
+ [![llava_next-interleave_demo](https://img.shields.io/badge/llava_next-interleave_demo-red)](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo)
63
+ [![llava_next-video_demo](https://img.shields.io/badge/llava_next-video_demo-red)](https://huggingface.co/spaces/WildVision/vision-arena)
64
+
65
+ [![llava_onevision-checkpoints](https://img.shields.io/badge/llava_onevision-checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-onevision-66a259c3526e15166d6bba37)
66
+ [![llava_next-interleave_checkpoints](https://img.shields.io/badge/llava_next-interleave_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1)
67
+ [![llava_next-video_checkpoints](https://img.shields.io/badge/llava_next-video_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)
68
+ [![llava_next-image_checkpoints](https://img.shields.io/badge/llava_next-image_checkpoints-blue)](https://huggingface.co/lmms-lab)
69
+
70
+ ## Release Notes
71
+
72
+ - [2024/08/06] 🔥 **🚀 [LLaVA-OneVision (OV)](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)!** The new LLaVA-OV models (0.5B/7B/72B) achieve new state-of-the-art performance across single-image, multi-image, and video benchmarks, sometimes rivaling top commercial models on 47 diverse benchmarks. 📄 Explore More:
73
+ * [[Paper]](https://arxiv.org/abs/2408.03326): In-depth insights and new emerging scenarios, i.e., strong video understanding through task transfer from images.
74
+ * [[LLaVA-OV Doc]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md): Model inference and evaluation guidance.
75
+ * [[Scripts]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/scripts/train): Start training models on your single-image/multi-image/video data.
76
+
77
+ - [2024/07/16] 🔥 **LLaVA-NeXT-Video** has been upgraded. The new 32B model achieves the best open-source performance on several video benchmarks, including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard). Please refer to [this page](docs/LLaVA-NeXT-Video_0716.md) for details and to [llava_next-video_demo](https://huggingface.co/spaces/WildVision/vision-arena) for the demo.
78
+
79
+
80
+ - [2024/06/23] 🔥 **LLaVA-NeXT-Interleave** is released. We utilize an image-text interleaved format to unify multi-image, video, and 3D tasks in one LLM and achieve **SoTA** performance on a wide range of benchmarks. Check out the [paper](https://arxiv.org/pdf/2407.07895), [blog](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/), and [checkpoints](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1) to see new capabilities and improved performance! We have released 0.5B, 7B, and 7B-DPO models.
81
+ * An all-round LLM for multi-image, video, and 3D with strong performance \[[demo](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo)\]
82
+ * Construct interleave training data [**M4-Instruct**](https://huggingface.co/datasets/lmms-lab/M4-Instruct-Data)
83
+ * Construct multi-image benchmark [**LLaVA-Interleave Bench**](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Interleave-Bench)
84
+
85
+
86
+ - [2024/05/25] 🔥 Wondering "[What Else Influences Visual Instruction Tuning Beyond Data?](https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/)"? Our new [blog](https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/) summarizes empirical explorations that ablate the various design choices for improving LMMs beyond the instruction data itself. Meanwhile, we open-source the recaptioned high-quality data generated with LLaVA-NeXT-34B on [[COCO]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-118K) [[LCS]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-558K) [[CC3M]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-CC3M).
87
+ * Architectures (LMM & Vision Encoder)
88
+ * Visual Representations (Resolution & # Tokens)
89
+ * Training Strategies (High-quality data & Trainable modules)
90
+
91
+ - [2024/05/10] 🔥 **LLaVA-NeXT** (Stronger) models are released, with support for stronger LLMs including LLaMA-3 (8B) and Qwen-1.5 (72B/110B). Check out [[blog](https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/)] and [[checkpoints](https://huggingface.co/lmms-lab)] to see improved performance!
92
+ - [2024/05/10] 🔥 **LLaVA-NeXT** (Video) is released. The image-only-trained LLaVA-NeXT model is surprisingly strong on video tasks with zero-shot modality transfer. DPO training with AI feedback on videos can yield significant improvement. [[Blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)], [[checkpoints](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)] and [[sglang](https://github.com/sgl-project/sglang)]
93
+ - [2024/01/30] 🔥 **LLaVA-NeXT** is out! With additional scaling to LLaVA-1.5, LLaVA-NeXT-34B outperforms Gemini Pro on some benchmarks. It can now process 4x more pixels and perform more tasks/applications than before. Check out the [blog post](https://llava-vl.github.io/blog/2024-01-30-llava-next/), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). Training/eval data and scripts coming soon.
94
+ <details>
95
+ <summary>More</summary>
96
+
97
+ - [2024/03/10] 🔥 Releasing **LMMs-Eval**, a highly efficient evaluation pipeline we used when developing LLaVA-NeXT. It supports the evaluation of LMMs on dozens of public datasets and allows new dataset onboarding, making the dev of new LMMs much faster. [[Blog](https://lmms-lab.github.io/lmms-eval-blog/lmms-eval-0.1/)] [[Codebase](https://github.com/EvolvingLMMs-Lab/lmms-eval)]
98
+
99
+ - [2023/11/10] [LLaVA-Plus](https://llava-vl.github.io/llava-plus/) is released: Learning to Use Tools for Creating Multimodal Agents, with LLaVA-Plus (LLaVA that Plug and Learn to Use Skills). [[Project Page](https://llava-vl.github.io/llava-plus/)] [[Demo](https://llavaplus.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Plus-Codebase)] [[Paper](https://arxiv.org/abs/2311.05437)]
100
+ - [2023/11/02] [LLaVA-Interactive](https://llava-vl.github.io/llava-interactive/) is released: Experience the future of human-AI multimodal interaction with an all-in-one demo for Image Chat, Segmentation, Generation and Editing. [[Project Page](https://llava-vl.github.io/llava-interactive/)] [[Demo](https://llavainteractive.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Interactive-Demo)] [[Paper](https://arxiv.org/abs/2311.00571)]
101
+ - [2023/10/26] 🔥 LLaVA-1.5 with LoRA achieves performance comparable to full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA.
102
+ - [2023/10/12] Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)]
103
+ - [2023/10/05] 🔥 LLaVA-1.5 is out! It achieves SoTA on 11 benchmarks with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
104
+ - [2023/09/26] LLaVA is enhanced with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLaVA-RLHF]](https://llava-rlhf.github.io/)
105
+ - [2023/09/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
106
+ - [2023/11/06] Support **Intel** dGPU and CPU platforms. [More details here.](https://github.com/haotian-liu/LLaVA/tree/intel/docs/intel)
107
+ - [2023/10/12] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support!
108
+ - [2023/10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
109
+ - [2023/10/10] [Roboflow Deep Dive](https://blog.roboflow.com/first-impressions-with-llava-1-5/): First Impressions with LLaVA-1.5.
110
+ - [2023/09/20] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
111
+ <p align="center">
112
+ <img src="https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings/blob/main/images/mfm_evolution.jpeg?raw=true" width=50%/>
113
+ </p>
114
+
115
+ - [2023/07/19] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)!
116
+ - [2023/06/26] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out [[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)] [[Notes](https://arxiv.org/abs/2306.14895)] [[YouTube](https://youtu.be/mkI7EPD1vp8)] [[Bilibili](https://www.bilibili.com/video/BV1Ng4y1T7v3/)].
117
+ - [2023/06/11] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see the documentation [here](./docs/LoRA.md).
118
+ - [2023/06/01] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Check out the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med).
119
+ - [2023/05/06] We are releasing [LLaVA-Lightning-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details.
120
+ - [2023/05/02] 🔥 We are releasing LLaVA-Lightning! Train a lite multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details.
121
+ - [2023/04/27] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava).
122
+ - [2023/04/17] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Check out the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/).
123
+
124
+ </details>
125
+
126
+ <!-- <a href="https://llava.hliu.cc/"><img src="assets/demo.gif" width="70%"></a> -->
127
+
128
+ **Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama-1/2 community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5, [Tongyi Qianwen RESEARCH LICENSE AGREEMENT](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE) and [Llama-3 Research License](https://llama.meta.com/llama3/license/)). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations.
129
+
130
+ ## Models & Scripts
131
+
132
+ ### Installation
133
+
134
+ #### 1. **Clone this repository and navigate to the LLaVA folder:**
135
+ ```bash
136
+ git clone https://github.com/LLaVA-VL/LLaVA-NeXT
137
+ cd LLaVA-NeXT
138
+ ```
139
+
140
+ #### 2. **Install the inference package:**
141
+ ```bash
142
+ conda create -n llava python=3.10 -y
143
+ conda activate llava
144
+ pip install --upgrade pip # Enable PEP 660 support.
145
+ pip install -e ".[train]"
146
+ ```
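Once the package is installed, it can be driven directly from Python. The following is a minimal, hedged sketch (not part of the upstream README) of loading a checkpoint with the builder shipped in this wheel (`llava/model/builder.py`); the checkpoint id is a placeholder and the exact argument list may differ between LLaVA-NeXT revisions.

```python
# Minimal sketch, not from the upstream README: load a LLaVA checkpoint with the
# builder included in this wheel (llava/model/builder.py). The checkpoint id is a
# placeholder and the exact signature may differ between LLaVA-NeXT revisions.
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

model_path = "lmms-lab/llava-onevision-qwen2-7b-ov"  # placeholder checkpoint id
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,  # model_base: only needed for LoRA/delta checkpoints
    get_model_name_from_path(model_path),
)
print(type(model).__name__, context_len)
```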
147
+
148
+ ### Project Navigation
149
+ Please check out the following pages for more inference & evaluation details.
150
+
151
+ #### - **LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild**
152
+ - [LLaVA-NeXT-Image](./docs/LLaVA-NeXT.md): for image demo inference and evaluation of stronger LMMs using [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval).
153
+
154
+
155
+ #### - LLaVA-NeXT: A Strong Zero-shot Video Understanding Model
156
+ - [LLaVA-NeXT-Video](./docs/LLaVA-NeXT-Video.md): for video inference and evaluation scripts. We recommend using [LMMs-video](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for evaluation.
157
+
158
+ #### - LLaVA-NeXT: Tackling Multi-image, Video, and 3D in Large Multimodal Models
159
+ - [LLaVA-NeXT-Interleave](./docs/LLaVA-NeXT-Interleave.md): for multi-image demo and evaluation scripts.
160
+
161
+ ## SGLang for Speeding Up Inference and Deployment
162
+
163
+ We use [SGLang](https://github.com/sgl-project/sglang) to speed up inference and deployment of LLaVA-NeXT. You can serve LLaVA-NeXT as a backend API service with SGLang.
164
+
165
+ **Prepare Environment**:
166
+ Follow the installation instructions in the [sglang](https://github.com/sgl-project/sglang?tab=readme-ov-file#install) repository.
167
+
168
+ ### LLaVA-NeXT (Image)
169
+
170
+ Check out the HTTP POST/GET and SRT usage at [sglang/examples/usage/llava](https://github.com/sgl-project/sglang/blob/main/examples/usage/llava).
171
+
172
+ ### LLaVA-NeXT (Video)
173
+
174
+ **Launch and Run on (K) Nodes**:
175
+ - Go to sglang project
176
+ ```
177
+ cd PATH_TO/sglang
178
+ ```
179
+ - First node:
180
+ ```sh
181
+ bash examples/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
182
+ (e.g. bash examples/usage/llava_video/srt_example_llava_v.sh K 0 examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 lmms-lab/LLaVA-NeXT-Video-7B-DPO 16)
183
+ ```
184
+ - Second node:
185
+ ```sh
186
+ bash examples/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
187
+ ```
188
+ - The K-th node:
189
+ ```sh
190
+ bash examples/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
191
+ ```
192
+
193
+
194
+ ## Citation
195
+
196
+ If you find it useful for your research and applications, please cite related papers/blogs using this BibTeX:
197
+ ```bibtex
198
+ @article{li2024llava,
199
+ title={LLaVA-NeXT-Interleave: Tackling Multi-image, Video, and 3D in Large Multimodal Models},
200
+ author={Li, Feng and Zhang, Renrui and Zhang, Hao and Zhang, Yuanhan and Li, Bo and Li, Wei and Ma, Zejun and Li, Chunyuan},
201
+ journal={arXiv preprint arXiv:2407.07895},
202
+ year={2024}
203
+ }
204
+
205
+ @misc{li2024llavanext-ablations,
206
+ title={LLaVA-NeXT: What Else Influences Visual Instruction Tuning Beyond Data?},
207
+ url={https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/},
208
+ author={Li, Bo and Zhang, Hao and Zhang, Kaichen and Guo, Dong and Zhang, Yuanhan and Zhang, Renrui and Li, Feng and Liu, Ziwei and Li, Chunyuan},
209
+ month={May},
210
+ year={2024}
211
+ }
212
+
213
+ @misc{li2024llavanext-strong,
214
+ title={LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild},
215
+ url={https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/},
216
+ author={Li, Bo and Zhang, Kaichen and Zhang, Hao and Guo, Dong and Zhang, Renrui and Li, Feng and Zhang, Yuanhan and Liu, Ziwei and Li, Chunyuan},
217
+ month={May},
218
+ year={2024}
219
+ }
220
+
221
+ @misc{zhang2024llavanext-video,
222
+ title={LLaVA-NeXT: A Strong Zero-shot Video Understanding Model},
223
+ url={https://llava-vl.github.io/blog/2024-04-30-llava-next-video/},
224
+ author={Zhang, Yuanhan and Li, Bo and Liu, Haotian and Lee, Yong Jae and Gui, Liangke and Fu, Di and Feng, Jiashi and Liu, Ziwei and Li, Chunyuan},
225
+ month={April},
226
+ year={2024}
227
+ }
228
+
229
+ @misc{liu2024llavanext,
230
+ title={LLaVA-NeXT: Improved reasoning, OCR, and world knowledge},
231
+ url={https://llava-vl.github.io/blog/2024-01-30-llava-next/},
232
+ author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Li, Bo and Zhang, Yuanhan and Shen, Sheng and Lee, Yong Jae},
233
+ month={January},
234
+ year={2024}
235
+ }
236
+
237
+ @misc{liu2023improvedllava,
238
+ title={Improved Baselines with Visual Instruction Tuning},
239
+ author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae},
240
+ publisher={arXiv:2310.03744},
241
+ year={2023},
242
+ }
243
+
244
+ @misc{liu2023llava,
245
+ title={Visual Instruction Tuning},
246
+ author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
247
+ publisher={NeurIPS},
248
+ year={2023},
249
+ }
250
+ ```
251
+
252
+ ## Acknowledgement
253
+
254
+ - [Vicuna](https://github.com/lm-sys/FastChat): the codebase we built upon, and our base model Vicuna-13B that has the amazing language capabilities!
255
+ - The LLaVA-NeXT project is currently maintained by the team along with our contributors (listed alphabetically by the first names): [Bo Li](https://brianboli.com/), [Dong Guo](https://www.linkedin.com/in/dongguoset/), [Feng Li](https://scholar.google.com/citations?hl=zh-CN&user=ybRe9GcAAAAJ&view_op=list_works&sortby=pubdate), [Hao Zhang](https://scholar.google.com/citations?user=B8hPxMQAAAAJ&hl=en), [Kaichen Zhang](https://www.linkedin.com/in/kaichen-zhang-014b17219/?originalSubdomain=sg), [Renrui Zhang](https://zrrskywalker.github.io/), [Yuanhan Zhang](https://zhangyuanhan-ai.github.io/), led by [Chunyuan Li](https://chunyuan.li/) and with the guidance and help from [Haotian Liu](https://hliu.cc/).
256
+ - The `lmms-eval` framework and its core contributors, including Peiyuan Zhang, Fanyi Pu, Joshua Adrian Cahyono, and Kairui Hu, for their support on the evaluation side.
257
+
258
+ ## Related Projects
259
+
260
+ - [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
261
+ - [LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day](https://github.com/microsoft/LLaVA-Med)
262
+ - [Otter: In-Context Multi-Modal Instruction Tuning](https://github.com/Luodian/Otter)
263
+
264
+ For future project ideas, please check out:
265
+ - [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
266
+ - [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect, segment, and generate anything by marrying [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment-Anything](https://github.com/facebookresearch/segment-anything).
llava-1.7.0.dev0.dist-info/RECORD ADDED
@@ -0,0 +1,204 @@
1
+ llava-1.7.0.dev0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ llava-1.7.0.dev0.dist-info/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
3
+ llava-1.7.0.dev0.dist-info/METADATA,sha256=lLd1vxRYxiY82Kqxic3pOgPA1GfwRclzyqko-u4mbl8,22760
4
+ llava-1.7.0.dev0.dist-info/RECORD,,
5
+ llava-1.7.0.dev0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ llava-1.7.0.dev0.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
7
+ llava-1.7.0.dev0.dist-info/direct_url.json,sha256=hxcapmB6J2WrkkCuvIaCPRbl85Ju7DJ67xkANY1sWuc,138
8
+ llava-1.7.0.dev0.dist-info/top_level.txt,sha256=AlU_N7AUyx6Fn0VOZu0pGBgjbU0fGvKkHnnkCFJbIF4,10
9
+ llava/__init__.py,sha256=8fWfEdbl8Xc5O1CThmLAMnB2h1Dt-gQiLeIW1Uo-JhE,42
10
+ llava/__pycache__/__init__.cpython-39.pyc,,
11
+ llava/__pycache__/constants.cpython-39.pyc,,
12
+ llava/__pycache__/conversation.cpython-39.pyc,,
13
+ llava/__pycache__/mm_utils.cpython-39.pyc,,
14
+ llava/__pycache__/utils.cpython-39.pyc,,
15
+ llava/constants.py,sha256=bcZAgJAHgpyMey-SSv3llZjeJfC8xJ7IvIRwPIGrj-4,305
16
+ llava/conversation.py,sha256=k-L_tP6EcNYxkVH0PacaeuNAw9R7NmllE8oTPmHs3oM,22785
17
+ llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc,,
18
+ llava/eval/__pycache__/model_vqa.cpython-39.pyc,,
19
+ llava/eval/evaluate_interleave.py,sha256=i8jwOxkYCh2WwMmCQ6bMqeLYZ_YvZHDNv-g82PxaOoY,10989
20
+ llava/eval/model_vqa.py,sha256=sKUyodB4dGHy0j7oxF_Um72lgn48iPFkR9upYej8VWw,10704
21
+ llava/mm_utils.py,sha256=Gwvu67nQT2Urwj4Q7bvcK7Y_yOkzilX50GSj5UC2-DY,17417
22
+ llava/model/__init__.py,sha256=K1A5xgHwGb6vhX2FsA0kEcRK7RFlM419rJJ0--Ax_78,679
23
+ llava/model/__pycache__/__init__.cpython-39.pyc,,
24
+ llava/model/__pycache__/apply_delta.cpython-39.pyc,,
25
+ llava/model/__pycache__/builder.cpython-39.pyc,,
26
+ llava/model/__pycache__/consolidate.cpython-39.pyc,,
27
+ llava/model/__pycache__/llava_arch.cpython-39.pyc,,
28
+ llava/model/__pycache__/make_delta.cpython-39.pyc,,
29
+ llava/model/__pycache__/utils.cpython-39.pyc,,
30
+ llava/model/apply_delta.py,sha256=ZItbnApA9G_hAXShPAOe5STKUy4s5o9acJ_wseyTWrU,1979
31
+ llava/model/builder.py,sha256=ou9C95SNH6JWB8tBuYWSFcw71Twl7b9l3T3UBV3XN_8,17923
32
+ llava/model/consolidate.py,sha256=iYWg_Huv7GQcuUKXI6EV5uESjhij_qTH_XhUzchoXV0,945
33
+ llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc,,
34
+ llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc,,
35
+ llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc,,
36
+ llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc,,
37
+ llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc,,
38
+ llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc,,
39
+ llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc,,
40
+ llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc,,
41
+ llava/model/language_model/llava_gemma.py,sha256=800LF_ldzdpq9_yeYejbhpzsOD1EwZuS_G2Nkf6ejuU,4980
42
+ llava/model/language_model/llava_llama.py,sha256=X1m-xVknZb6caYTeF1iJT29WV0l-2n6Ud7iL8zr49C0,6322
43
+ llava/model/language_model/llava_mistral.py,sha256=IpV8-8NE693wMSYF7zzqBgCNGdpTMRYTU9RClAWKQL4,5189
44
+ llava/model/language_model/llava_mixtral.py,sha256=mA4kq2VjbYin7r_nad9VzEkj8cguVR9F35mu3zVYulc,5882
45
+ llava/model/language_model/llava_mpt.py,sha256=7FRWHZf6JWkrqJIgaosAP19p2vl-kNGNrhu4JkaYdPk,3836
46
+ llava/model/language_model/llava_qwen.py,sha256=ESNFowdoSykW9BnhSZAWgJMb_xwigOeA4G4AkrF2rh8,6204
47
+ llava/model/language_model/llava_qwen_moe.py,sha256=LOMS6d-BP8o2-SJfETnlZttCdMVZ6hGDfuhispjtQlo,6230
48
+ llava/model/language_model/modeling_llama.py,sha256=Qwd2vsz-vAbBl07zwR1IvCTCTwyRSoxJMwVcgbNNKJc,82886
49
+ llava/model/llava_arch.py,sha256=vO_gPr8uQZ8EOPWphnMeFGL05u7xRjZKFWgnU6AsUNc,28497
50
+ llava/model/make_delta.py,sha256=oUJNaT6ikdifV5fF9SPCllyucO8h0tjSniT8UFKzcWk,2303
51
+ llava/model/multimodal_encoder/__pycache__/builder.cpython-39.pyc,,
52
+ llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc,,
53
+ llava/model/multimodal_encoder/__pycache__/hf_vision.cpython-39.pyc,,
54
+ llava/model/multimodal_encoder/__pycache__/imagebind.cpython-39.pyc,,
55
+ llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-39.pyc,,
56
+ llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-39.pyc,,
57
+ llava/model/multimodal_encoder/builder.py,sha256=bYmGLnJHgBJXnkRrs9phaqHH6AlVsUNey6w-iZFuXn0,1922
58
+ llava/model/multimodal_encoder/clip_encoder.py,sha256=ofOgPYkJjXGnpk8SAtOGSXdw1DKYZOAeUwzn-5DouBE,7448
59
+ llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-39.pyc,,
60
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py,sha256=6mbC4b7gg9g4LcxJXEEZtAGMp_jwzl0enio8T6j6b3Y,792
61
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-39.pyc,,
62
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-39.pyc,,
63
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc,,
64
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-39.pyc,,
65
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc,,
66
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc,,
67
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-39.pyc,,
68
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-39.pyc,,
69
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc,,
70
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-39.pyc,,
71
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc,,
72
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-39.pyc,,
73
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc,,
74
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc,,
75
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-39.pyc,,
76
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-39.pyc,,
77
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-39.pyc,,
78
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py,sha256=PKjrqkcdpJK_MQmnTsZ2oxhxIHn9AlI-y2ap87BnR1Q,118
79
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py,sha256=XwR8sswHn0Eb8C9EOFz-Gn-lQm831cCCu2xbR__0XiI,23142
80
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py,sha256=4fbCuG0eUpVpSA_yE8_GUxCyRmojbXF2C9X3oSE32ns,24280
81
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py,sha256=CwD_HmdfQ1Tb-fLOr9KgseVP80nMNv6V4uWI6DDOBqg,2132
82
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py,sha256=uWSu0OTsXR8v5y2P6jwkdzzy2Ce1H__UEthHM0F7xR4,10350
83
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py,sha256=p-B34PgBg0JuutSriqp0Qc2VLJrkLf91fGmBRHiZOSg,5746
84
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py,sha256=nNKTAljan_PpGkTJ4niwV3xxI0g9C3704U6OUJh8P_k,17650
85
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py,sha256=PB0q6KsaQKwVRlX8R4qW8Cf4rzY5v5QcFiydMXE8rS0,7163
86
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py,sha256=g-kvzfUuPEMW0uU4e8NfwGCpJnL1kXdEZUWT0YqVoXQ,5570
87
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py,sha256=lWoSv_3xdMPmVv8EnQWpC7Kq4H8ihSYTHYk_Nr_jGA8,12211
88
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py,sha256=i8RDQ1Zv9cTXVBbW8RbbfaT0wGxjEFu-qq3DCXQBR-8,5399
89
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py,sha256=Eta_-wNrwv953zWVxXshCywCVOwK2jPRiOId9XcFyhk,4895
90
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py,sha256=u4Gur6i8rcWvdPZRZSNgDshmbkchDx5DvZlSGxvoXH8,7368
91
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py,sha256=fYdJYEVPviaTliRkSsrWdPmYbLGTM4a6QYlNN_3ZzHA,3514
92
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py,sha256=EYlChMZnX7vQvitTWM73iIhaZf4zv_OTSJ6L7ZTZ8go,26410
93
+ llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py,sha256=aYJKAK5qw8Kge-a6wTTBOwb-wqaV9Gri5vuLMYq4E84,14964
94
+ llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py,sha256=7m3OHiUdnHkScpoRp_DjGLavrCcKf2te6oJshv27kzI,6219
95
+ llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-39.pyc,,
96
+ llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-39.pyc,,
97
+ llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-39.pyc,,
98
+ llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-39.pyc,,
99
+ llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py,sha256=FL-gQEpHBlYYYtzPcM1jg5HTPtyqSasNrou_aRtmghs,2890
100
+ llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py,sha256=kwNlbCc4cWz7cmQBZwS97as2taQf0RRFoZAXuNZdvjg,2215
101
+ llava/model/multimodal_encoder/eva_clip/eva_vit.py,sha256=mrgroKZHGFK2URbatEQKpID8zhKmVBHgRyKEc4D_bUI,34605
102
+ llava/model/multimodal_encoder/eva_clip/factory.py,sha256=iLoVP1ldKm0YvXX3uz4Wsb2dw1DElMZgUdIrxMS1e70,1829
103
+ llava/model/multimodal_encoder/hf_vision.py,sha256=Pw0y7SVYKiUIuBCP8uMySWRyIcpNBN1oUsjRBMVqSfM,4549
104
+ llava/model/multimodal_encoder/imagebind.py,sha256=MkaKOrpYr1Fj08QSzy-Y3awDmmB9Y5Y6KoVGJR52Xpg,2498
105
+ llava/model/multimodal_encoder/open_clip_encoder.py,sha256=0iFyD49NZFwTutR6Hq5upIybHbrzgPlPwQ8kgrRwZXQ,6812
106
+ llava/model/multimodal_encoder/siglip_encoder.py,sha256=LPGxELdEKwu5FO1vtvVmBmcjhMq48d1AzpJJAq_0yIk,26103
107
+ llava/model/multimodal_projector/__pycache__/builder.cpython-39.pyc,,
108
+ llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-39.pyc,,
109
+ llava/model/multimodal_projector/builder.py,sha256=acKSHT-As_qu2haXj-g6gRRzdq_BPNWwgP_ZDYEntUI,2192
110
+ llava/model/multimodal_projector/pooler_projector.py,sha256=zxAP1Ut-oJXG-L4xggh2FC4epc0nemgk1v8RnoKCxZ4,975
111
+ llava/model/multimodal_resampler/__pycache__/builder.cpython-39.pyc,,
112
+ llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-39.pyc,,
113
+ llava/model/multimodal_resampler/__pycache__/perceiver.cpython-39.pyc,,
114
+ llava/model/multimodal_resampler/__pycache__/qformer.cpython-39.pyc,,
115
+ llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-39.pyc,,
116
+ llava/model/multimodal_resampler/builder.py,sha256=qaSzq2lcRDkIFv_QRTXrkfn-OHfII--4LIHkrkIfwPg,1039
117
+ llava/model/multimodal_resampler/masked_drop.py,sha256=FNgUNkIw8JQaAv3lppL7vYtUOfoP2DArw0AuslXQ0TE,3061
118
+ llava/model/multimodal_resampler/perceiver.py,sha256=uOAntKuihMkBkAp5bIozKUApvXhvlCeocRNtUva-VqA,4995
119
+ llava/model/multimodal_resampler/qformer.py,sha256=d-A2JpouT-VjWb43BF4HXP_jaIM0o_NhFhVy_3Uawsc,50384
120
+ llava/model/multimodal_resampler/spatial_pool.py,sha256=hEAlKpbgzGjXeY365TZaI3MI2YAvle1Yfb5dKlAiQls,1775
121
+ llava/model/utils.py,sha256=KzkLVJjTHJqI9vg1umDp4-SkT4IbMcI_Uhp-4V4xkWk,947
122
+ llava/serve/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
123
+ llava/serve/__pycache__/__init__.cpython-39.pyc,,
124
+ llava/serve/__pycache__/cli.cpython-39.pyc,,
125
+ llava/serve/__pycache__/controller.cpython-39.pyc,,
126
+ llava/serve/__pycache__/gradio_multi_image.cpython-39.pyc,,
127
+ llava/serve/__pycache__/gradio_web_server.cpython-39.pyc,,
128
+ llava/serve/__pycache__/model_worker.cpython-39.pyc,,
129
+ llava/serve/__pycache__/register_worker.cpython-39.pyc,,
130
+ llava/serve/__pycache__/sglang_worker.cpython-39.pyc,,
131
+ llava/serve/__pycache__/test_message.cpython-39.pyc,,
132
+ llava/serve/cli.py,sha256=e-ALjf2zdr08UiqeW-DmBoGEHRWiO-I5ELpqjls30iE,4403
133
+ llava/serve/controller.py,sha256=zKmdDMoOyHltZGKQCzIgrUkXsget_3UkjgGNyq0xy7Y,10070
134
+ llava/serve/gradio_multi_image.py,sha256=mwVVe4l-7ry3umZx9CFGrwYKgPup4RLupMXrsRj1IZc,20029
135
+ llava/serve/gradio_web_server.py,sha256=t8xWJPNQ0zDOGPi4ju9NkA89kbOcPVMO-v6pNM7BZIs,19519
136
+ llava/serve/model_worker.py,sha256=SBzKdeQE0hhVM9bwxplVb8KqmUm9qhp1H74THX82MD0,11121
137
+ llava/serve/register_worker.py,sha256=Q7BnBGr0lcDdKaI-DHv_5IKK0KpHvtUTCBwFz5PspLo,760
138
+ llava/serve/sglang_worker.py,sha256=lYeIDVZlKho4YcLi82bUxP4ccFJCTpVcfcM_uvdH6wI,9221
139
+ llava/serve/test_message.py,sha256=ofJWbzm3oQz5UKU2tBSfV2ZzDZkGpMPDE9yrlvJXNAM,2048
140
+ llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc,,
141
+ llava/train/__pycache__/llava_trainer.cpython-39.pyc,,
142
+ llava/train/__pycache__/llava_trainer_eval.cpython-39.pyc,,
143
+ llava/train/__pycache__/train.cpython-39.pyc,,
144
+ llava/train/__pycache__/train_dpo.cpython-39.pyc,,
145
+ llava/train/__pycache__/train_mem.cpython-39.pyc,,
146
+ llava/train/llama_flash_attn_monkey_patch.py,sha256=CBkiqWIZXW68_2YJdtTPQqXBadPq15vHmliDGoqeW5c,4280
147
+ llava/train/llava_trainer.py,sha256=rGZCclj_T8ATDfR2JrNa1mLibH1z0kkVRQsSxvT_Rw8,27309
148
+ llava/train/llava_trainer_eval.py,sha256=bNGpwNtA1d20xQ5BxZa8O-ZnMHR7CqQ9VgAbYtt3mAQ,3515
149
+ llava/train/train.py,sha256=VLRZQV-LjafzV5wtQ1f_s_0Qm0mxE8ei1VSW-e_XMrU,78864
150
+ llava/train/train_dpo.py,sha256=oFGJghgJezMVms70YWd4KvFrsZDei4UbojjhMroNIOE,84440
151
+ llava/train/train_mem.py,sha256=C06MqpCqOVtTsewH8N67oUzmYIm0HY6-y3MuzhlE1wg,80
152
+ llava/utils.py,sha256=qixNPajlBGe9XqNWnYOQ6V6OTreQcBK6W8jkuxRjzBU,6533
153
+ trl/__init__.py,sha256=Od8x7-H_1H5LfnScvJTJxjWeDuHzKlnUuToL5RQswSA,1110
154
+ trl/__pycache__/__init__.cpython-39.pyc,,
155
+ trl/__pycache__/core.cpython-39.pyc,,
156
+ trl/__pycache__/import_utils.cpython-39.pyc,,
157
+ trl/core.py,sha256=TPuO3us2wqAXsQWm8v-lNtnVmYHiuJcvOJoZeSV29YI,12303
158
+ trl/environment/__init__.py,sha256=XM1ZiS_F7-r8P6Z20VNHh71Wnw-scMoujSU-lqEKGNc,78
159
+ trl/environment/__pycache__/__init__.cpython-39.pyc,,
160
+ trl/environment/__pycache__/base_environment.cpython-39.pyc,,
161
+ trl/environment/base_environment.py,sha256=pyrIOZJsl-Q6VAv2PRGaUbIDeCDp7jyc1mtibpPvHrA,17882
162
+ trl/extras/__init__.py,sha256=daKpM_o7XbZix98t_kxwLyMteb5EViUCH8MURZFEq_Q,684
163
+ trl/extras/__pycache__/__init__.cpython-39.pyc,,
164
+ trl/extras/__pycache__/best_of_n_sampler.cpython-39.pyc,,
165
+ trl/extras/__pycache__/dataset_formatting.cpython-39.pyc,,
166
+ trl/extras/best_of_n_sampler.py,sha256=RHA3RbnqifnpUh7HZrKdhcDNz9LVSpcYUj_A_jrC8Ro,5243
167
+ trl/extras/dataset_formatting.py,sha256=TVeUWfxA1q3oat3HJpMIA6olsUYwjpQzDReVLkeZ7NI,3726
168
+ trl/import_utils.py,sha256=kfnxR_z4CB1rM5JcBZtVxhsOcxHYmIXWzTgTMRGc-7U,3238
169
+ trl/models/__init__.py,sha256=xY9josSMMq7J0coDxBnhsMvIK3sJvfNIeOgseQZu6cE,1244
170
+ trl/models/__pycache__/__init__.cpython-39.pyc,,
171
+ trl/models/__pycache__/modeling_base.cpython-39.pyc,,
172
+ trl/models/__pycache__/modeling_sd_base.cpython-39.pyc,,
173
+ trl/models/__pycache__/modeling_value_head.cpython-39.pyc,,
174
+ trl/models/__pycache__/utils.cpython-39.pyc,,
175
+ trl/models/modeling_base.py,sha256=oMvYF2MnXqykCkDBBAdLDjowUB0PcL5LftpArsdquiM,28842
176
+ trl/models/modeling_sd_base.py,sha256=2OB-rShWUebUoCuVr27gla3DEpA_eX2W5UCVr6WJ2w0,28073
177
+ trl/models/modeling_value_head.py,sha256=wq9rqn8oPJMmgyNpgI5AWZSmT0JZb4RHO13r6jzExTo,18822
178
+ trl/models/utils.py,sha256=8kc1anjd4PPLWM5zce8eoXQox1uq6R-E_UwUF_b2YBk,3389
179
+ trl/trainer/__init__.py,sha256=9gamN5nkygFHBfF56JvC0sN67axqU6WuXXY9s1YteK8,1514
180
+ trl/trainer/__pycache__/__init__.cpython-39.pyc,,
181
+ trl/trainer/__pycache__/base.cpython-39.pyc,,
182
+ trl/trainer/__pycache__/ddpo_config.cpython-39.pyc,,
183
+ trl/trainer/__pycache__/ddpo_trainer.cpython-39.pyc,,
184
+ trl/trainer/__pycache__/dpo_trainer.cpython-39.pyc,,
185
+ trl/trainer/__pycache__/iterative_sft_trainer.cpython-39.pyc,,
186
+ trl/trainer/__pycache__/model_config.cpython-39.pyc,,
187
+ trl/trainer/__pycache__/ppo_config.cpython-39.pyc,,
188
+ trl/trainer/__pycache__/ppo_trainer.cpython-39.pyc,,
189
+ trl/trainer/__pycache__/reward_config.cpython-39.pyc,,
190
+ trl/trainer/__pycache__/reward_trainer.cpython-39.pyc,,
191
+ trl/trainer/__pycache__/sft_trainer.cpython-39.pyc,,
192
+ trl/trainer/__pycache__/utils.cpython-39.pyc,,
193
+ trl/trainer/base.py,sha256=PID37pjUqfbobelu9tFP9nwd_p9Rx_Cq7XRgEHhhWYE,1818
194
+ trl/trainer/ddpo_config.py,sha256=kwFUTMv85yjXICGVimcQfrPCu5Smz-Mz3c3erEA3SRU,4932
195
+ trl/trainer/ddpo_trainer.py,sha256=NTfQ5jiLuiKGp9ypH3mcxSZIj2cOZjzu3yg5THBXLAg,27023
196
+ trl/trainer/dpo_trainer.py,sha256=Zcc7ohWl83KVFcQLkC8qfBLT6zWpzK1jjLuqqGL4UBE,62580
197
+ trl/trainer/iterative_sft_trainer.py,sha256=_91ZH1o1IkWOanuqHSjhEGx_nElDJ_WiBmQwG0DWNsU,16489
198
+ trl/trainer/model_config.py,sha256=xlsz4478y8f11ZQZ-kwVsGc5bdzyIPTYi4pPdOSr2TU,2966
199
+ trl/trainer/ppo_config.py,sha256=IC9Y1K-6hQcipDr6jDywsBh4fToJ-3KsuSgOY4aJS-0,8317
200
+ trl/trainer/ppo_trainer.py,sha256=NmongqErhUrRckVrHpItl3J1ztV_exEAalqyTqxDA7g,63231
201
+ trl/trainer/reward_config.py,sha256=Q7IihMGMTMIBFGglv-IuJdSWpV6FSbhnlqrZcUaERVU,1661
202
+ trl/trainer/reward_trainer.py,sha256=93FBp9uus_FAQN560ehyP4yRLWyb9Y4OVysx8rpIACU,13603
203
+ trl/trainer/sft_trainer.py,sha256=AxrL8nkyO9Cfgd9C8MZLTR5ZUckEUfD5TWSHQWL2dTE,24691
204
+ trl/trainer/utils.py,sha256=d5W852wGU4mOErsLvqN4jGq4-Mzr_fFFAMY-stBFUYU,31955
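For reference, the RECORD file above follows the standard installed-distribution layout: each row is `path,algorithm=digest,size`, where the digest is an urlsafe-base64-encoded SHA-256 hash with the trailing `=` padding stripped. A small sketch of how one entry could be verified against the installed file (the site-packages path below is a placeholder):

```python
# Minimal sketch: verify a single RECORD entry (path,hash,size) against the file
# on disk. RECORD hashes are urlsafe-base64-encoded SHA-256 digests with the
# trailing "=" padding stripped, per the wheel/installed-distributions spec.
import base64
import hashlib
from pathlib import Path

def check_record_entry(entry: str, site_packages: Path) -> bool:
    """Return True if the file named in a RECORD row matches its recorded hash and size."""
    path, hash_field, size = entry.rsplit(",", 2)
    data = (site_packages / path).read_bytes()
    algo, _, expected = hash_field.partition("=")
    digest = base64.urlsafe_b64encode(hashlib.new(algo, data).digest()).rstrip(b"=").decode()
    return digest == expected and (not size or int(size) == len(data))

# Entry copied from the RECORD above; replace the path with your real site-packages directory.
entry = "llava/constants.py,sha256=bcZAgJAHgpyMey-SSv3llZjeJfC8xJ7IvIRwPIGrj-4,305"
print(check_record_entry(entry, Path("/path/to/site-packages")))
```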
llava-1.7.0.dev0.dist-info/REQUESTED ADDED
File without changes
llava-1.7.0.dev0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (73.0.1)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
llava-1.7.0.dev0.dist-info/direct_url.json ADDED
@@ -0,0 +1 @@
+ {"url": "https://github.com/LLaVA-VL/LLaVA-NeXT.git", "vcs_info": {"commit_id": "e98849102929e1c6304b60b28cca541567b7b643", "vcs": "git"}}
llava-1.7.0.dev0.dist-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
+ llava
+ trl
llava/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import LlavaLlamaForCausalLM
llava/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (224 Bytes).
llava/__pycache__/constants.cpython-39.pyc ADDED
Binary file (486 Bytes).
llava/__pycache__/conversation.cpython-39.pyc ADDED
Binary file (14.4 kB).
llava/__pycache__/mm_utils.cpython-39.pyc ADDED
Binary file (14.1 kB).
llava/__pycache__/utils.cpython-39.pyc ADDED
Binary file (6.51 kB).
llava/constants.py ADDED
@@ -0,0 +1,12 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
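A minimal sketch of how these constants are typically combined; the splice_image_tokens helper and the tokenize callable below are placeholders (the repository's own version of this logic is tokenizer_image_token in llava/mm_utils.py). Each literal <image> marker in a prompt is replaced by the out-of-vocabulary index -200 so that image features can later be spliced into those positions.

```python
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX


def splice_image_tokens(prompt, tokenize):
    """Hypothetical helper; tokenize is any callable mapping text to a list of token ids."""
    chunks = [tokenize(chunk) for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
    ids = []
    for i, chunk in enumerate(chunks):
        ids.extend(chunk)
        if i < len(chunks) - 1:
            ids.append(IMAGE_TOKEN_INDEX)  # -200 placeholder, swapped for vision features downstream
    return ids
```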
llava/conversation.py ADDED
@@ -0,0 +1,577 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Any, Dict, Union, Tuple
4
+ import re
5
+ import base64
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class SeparatorStyle(Enum):
12
+ """Different separator style."""
13
+
14
+ SINGLE = auto()
15
+ TWO = auto()
16
+ MPT = auto()
17
+ PLAIN = auto()
18
+ CHATML = auto()
19
+ LLAMA_2 = auto()
20
+ LLAMA_3 = auto()
21
+ QWEN = auto()
22
+ GEMMA = auto()
23
+
24
+
25
+ @dataclasses.dataclass
26
+ class Conversation:
27
+ """A class that keeps all conversation history."""
28
+
29
+ system: str
30
+ roles: List[str]
31
+ messages: List[List[str]]
32
+ offset: int
33
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
34
+ sep: str = "###"
35
+ sep2: str = None
36
+ version: str = "Unknown"
37
+
38
+ tokenizer_id: str = ""
39
+ tokenizer: Any = None
40
+ # Stop criteria (the default one is EOS token)
41
+ stop_str: Union[str, List[str]] = None
42
+ # Stops generation if meeting any token in this list
43
+ stop_token_ids: List[int] = None
44
+
45
+ skip_next: bool = False
46
+
47
+ def get_prompt(self):
48
+ messages = self.messages
49
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
50
+ messages = self.messages.copy()
51
+ init_role, init_msg = messages[0].copy()
52
+ init_msg = init_msg[0]
53
+ if "mmtag" in self.version:
54
+ init_msg = init_msg.replace("<image>", "").strip()
55
+ messages[0] = (init_role, init_msg)
56
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
57
+ messages.insert(1, (self.roles[1], "Received."))
58
+ elif not init_msg.startswith("<image>"):
59
+ init_msg = init_msg.replace("<image>", "").strip()
60
+ messages[0] = (init_role, "<image>\n" + init_msg)
61
+ else:
62
+ messages[0] = (init_role, init_msg)
63
+
64
+ if self.sep_style == SeparatorStyle.SINGLE:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + ": " + message + self.sep
71
+ else:
72
+ ret += role + ":"
73
+
74
+ elif self.sep_style == SeparatorStyle.TWO:
75
+ seps = [self.sep, self.sep2]
76
+ ret = self.system + seps[0]
77
+ for i, (role, message) in enumerate(messages):
78
+ if message:
79
+ if type(message) is tuple:
80
+ message, _, _ = message
81
+ ret += role + ": " + message + seps[i % 2]
82
+ else:
83
+ ret += role + ":"
84
+
85
+ elif self.sep_style == SeparatorStyle.CHATML:
86
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
87
+ for role, message in messages:
88
+ if message:
89
+ if type(message) is tuple:
90
+ message, images, _ = message
91
+ message = "<image>" * len(images) + message
92
+ ret += role + "\n" + message + self.sep + "\n"
93
+ else:
94
+ ret += role + "\n"
95
+ return ret
96
+
97
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
98
+ chat_template_messages = [{"role": "system", "content": self.system}]
99
+ for role, message in messages:
100
+ if message:
101
+ if type(message) is tuple:
102
+ message, images = message
103
+ message = "<image>" * len(images) + message
104
+ chat_template_messages.append({"role": role, "content": message})
105
+
106
+ # print(chat_template_messages)
107
+ return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
108
+ # ret = "" if self.system == "" else self.system + self.sep + "\n"
109
+ # for role, message in messages:
110
+ # if message:
111
+ # if type(message) is tuple:
112
+ # message, images = message
113
+ # message = "<image>" * len(images) + message
114
+ # ret += role + "\n" + message + self.sep + "\n"
115
+ # else:
116
+ # ret += role + "\n"
117
+ # return ret
118
+
119
+ elif self.sep_style == SeparatorStyle.MPT:
120
+ ret = self.system + self.sep
121
+ for role, message in messages:
122
+ if message:
123
+ if type(message) is tuple:
124
+ message, _, _ = message
125
+ ret += role + message + self.sep
126
+ else:
127
+ ret += role
128
+
129
+ elif self.sep_style == SeparatorStyle.GEMMA:
130
+ ret = ""
131
+ for i, (role, message) in enumerate(messages):
132
+ assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
133
+ if message:
134
+ if type(message) is tuple:
135
+ message, _, _ = message
136
+ ret += role + message + self.sep
137
+ else:
138
+ ret += role
139
+
140
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
141
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
142
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
143
+ ret = ""
144
+
145
+ for i, (role, message) in enumerate(messages):
146
+ if i == 0:
147
+ assert message, "first message should not be none"
148
+ assert role == self.roles[0], "first message should come from user"
149
+ if message:
150
+ if type(message) is tuple:
151
+ message, _, _ = message
152
+ if i == 0:
153
+ message = wrap_sys(self.system) + message
154
+ if i % 2 == 0:
155
+ message = wrap_inst(message)
156
+ ret += self.sep + message
157
+ else:
158
+ ret += " " + message + " " + self.sep2
159
+ else:
160
+ ret += ""
161
+ ret = ret.lstrip(self.sep)
162
+
163
+ elif self.sep_style == SeparatorStyle.PLAIN:
164
+ seps = [self.sep, self.sep2]
165
+ ret = self.system
166
+ for i, (role, message) in enumerate(messages):
167
+ if message:
168
+ if type(message) is tuple:
169
+ message, _, _ = message
170
+ ret += message + seps[i % 2]
171
+ else:
172
+ ret += ""
173
+ else:
174
+ raise ValueError(f"Invalid style: {self.sep_style}")
175
+
176
+ return ret
177
+
178
+ def append_message(self, role, message):
179
+ self.messages.append([role, message])
180
+
181
+ def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
182
+ if image_process_mode == "Pad":
183
+
184
+ def expand2square(pil_img, background_color=(122, 116, 104)):
185
+ width, height = pil_img.size
186
+ if width == height:
187
+ return pil_img
188
+ elif width > height:
189
+ result = Image.new(pil_img.mode, (width, width), background_color)
190
+ result.paste(pil_img, (0, (width - height) // 2))
191
+ return result
192
+ else:
193
+ result = Image.new(pil_img.mode, (height, height), background_color)
194
+ result.paste(pil_img, ((height - width) // 2, 0))
195
+ return result
196
+
197
+ image = expand2square(image)
198
+ elif image_process_mode in ["Default", "Crop"]:
199
+ pass
200
+ elif image_process_mode == "Resize":
201
+ image = image.resize((336, 336))
202
+ else:
203
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
204
+
205
+ if type(image) is not Image.Image:
206
+ image = Image.open(image).convert("RGB")
207
+
208
+ max_hw, min_hw = max(image.size), min(image.size)
209
+ aspect_ratio = max_hw / min_hw
210
+ max_len, min_len = 672, 448
211
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
212
+ longest_edge = int(shortest_edge * aspect_ratio)
213
+ W, H = image.size
214
+ if H > W:
215
+ H, W = longest_edge, shortest_edge
216
+ else:
217
+ H, W = shortest_edge, longest_edge
218
+ image = image.resize((W, H))
219
+ if return_pil:
220
+ return image
221
+ else:
222
+ buffered = BytesIO()
223
+ image.save(buffered, format=image_format)
224
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
225
+ return img_b64_str
226
+
227
+ def get_images(self, return_pil=False, return_path=False):
228
+ images = []
229
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
230
+ if i % 2 == 0:
231
+ if type(msg) is tuple:
232
+ msg, image, image_process_mode = msg
233
+ if type(image) != list:
234
+ image = [image]
235
+ for img in image:
236
+ if not return_path and self.is_image_file(img):
237
+ img = self.process_image(img, image_process_mode, return_pil=return_pil)
238
+ else:
239
+ images.append(img)
240
+ return images
241
+
242
+ def is_image_file(self, filename):
243
+ image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
244
+ return any(filename.lower().endswith(ext) for ext in image_extensions)
245
+
246
+ def is_video_file(self, filename):
247
+ video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
248
+ return any(filename.lower().endswith(ext) for ext in video_extensions)
249
+
250
+ def to_gradio_chatbot(self):
251
+ ret = []
252
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
253
+ if i % 2 == 0:
254
+ if type(msg) is tuple:
255
+ msg, image, image_process_mode = msg
256
+ if type(image) != list:
257
+ image = [image]
258
+ if len(image) == 1:
259
+ msg = "<image>\n" + msg.replace("<image>", "").strip()
260
+ else:
261
+ msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
262
+
263
+ img_str_list = []
264
+ for img in image:
265
+ if self.is_image_file(img):
266
+ img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
267
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" style="max-width: 256px; max-height: 256px; width: auto; height: auto; object-fit: contain;"/>'
268
+ img_str_list.append(img_str)
269
+ elif self.is_video_file(img):
270
+ ret.append(((img,), None))
271
+
272
+ msg = msg.strip()
273
+ img_place_holder = ""
274
+ for img_str in img_str_list:
275
+ img_place_holder += f"{img_str}\n\n"
276
+
277
+ if len(img_str_list) > 0:
278
+ msg = f"{img_place_holder}\n\n{msg}"
279
+
280
+ if len(msg) > 0:
281
+ ret.append([msg, None])
282
+ else:
283
+ ret.append([msg, None])
284
+ else:
285
+ ret[-1][-1] = msg
286
+ return ret
287
+
288
+ def copy(self):
289
+ return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
290
+
291
+ def dict(self):
292
+ if len(self.get_images()) > 0:
293
+ return {
294
+ "system": self.system,
295
+ "roles": self.roles,
296
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
297
+ "offset": self.offset,
298
+ "sep": self.sep,
299
+ "sep2": self.sep2,
300
+ }
301
+ return {
302
+ "system": self.system,
303
+ "roles": self.roles,
304
+ "messages": self.messages,
305
+ "offset": self.offset,
306
+ "sep": self.sep,
307
+ "sep2": self.sep2,
308
+ }
309
+
310
+
311
+ conv_vicuna_v0 = Conversation(
312
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
313
+ roles=("Human", "Assistant"),
314
+ messages=[
315
+ ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
316
+ [
317
+ "Assistant",
318
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
319
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
320
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
321
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
322
+ "renewable and non-renewable energy sources:\n"
323
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
324
+ "energy sources are finite and will eventually run out.\n"
325
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
326
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
327
+ "and other negative effects.\n"
328
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
329
+ "have lower operational costs than non-renewable sources.\n"
330
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
331
+ "locations than non-renewable sources.\n"
332
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
333
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
334
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
335
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
336
+ ],
337
+ ],
338
+ offset=2,
339
+ sep_style=SeparatorStyle.SINGLE,
340
+ sep="###",
341
+ )
342
+
343
+ conv_vicuna_v1 = Conversation(
344
+ system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
345
+ roles=("USER", "ASSISTANT"),
346
+ version="v1",
347
+ messages=[],
348
+ offset=0,
349
+ sep_style=SeparatorStyle.TWO,
350
+ sep=" ",
351
+ sep2="</s>",
352
+ )
353
+
354
+ conv_llama_2 = Conversation(
355
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
356
+
357
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
358
+ roles=("USER", "ASSISTANT"),
359
+ version="llama_v2",
360
+ messages=[],
361
+ offset=0,
362
+ sep_style=SeparatorStyle.LLAMA_2,
363
+ sep="<s>",
364
+ sep2="</s>",
365
+ )
366
+
367
+ conv_llava_llama_2 = Conversation(
368
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
369
+ roles=("USER", "ASSISTANT"),
370
+ version="llama_v2",
371
+ messages=[],
372
+ offset=0,
373
+ sep_style=SeparatorStyle.LLAMA_2,
374
+ sep="<s>",
375
+ sep2="</s>",
376
+ )
377
+
378
+ conv_llava_llama_3 = Conversation(
379
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
380
+ roles=("user", "assistant"),
381
+ version="llama_v3",
382
+ messages=[],
383
+ offset=0,
384
+ sep="<|eot_id|>",
385
+ sep_style=SeparatorStyle.LLAMA_3,
386
+ tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
387
+ tokenizer=AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct"),
388
+ stop_token_ids=[128009],
389
+ )
390
+
391
+ conv_mistral_instruct = Conversation(
392
+ system="",
393
+ roles=("USER", "ASSISTANT"),
394
+ version="llama_v2",
395
+ messages=[],
396
+ offset=0,
397
+ sep_style=SeparatorStyle.LLAMA_2,
398
+ sep="",
399
+ sep2="</s>",
400
+ )
401
+
402
+ conv_llava_llama_2_simple = Conversation(
403
+ system="Answer the questions about the visual content that the user provides.",
404
+ roles=("USER", "ASSISTANT"),
405
+ version="llama_v2",
406
+ messages=[],
407
+ offset=0,
408
+ sep_style=SeparatorStyle.LLAMA_2,
409
+ sep="<s>",
410
+ sep2="</s>",
411
+ )
412
+
413
+ conv_llava_llama_2_mmtag = Conversation(
414
+ system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
415
+ roles=("USER", "ASSISTANT"),
416
+ version="llama_v2_mmtag",
417
+ messages=[],
418
+ offset=0,
419
+ sep_style=SeparatorStyle.LLAMA_2,
420
+ sep="<s>",
421
+ sep2="</s>",
422
+ )
423
+
424
+ conv_mpt = Conversation(
425
+ system="""<|im_start|>system
426
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
427
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
428
+ version="mpt",
429
+ messages=[],
430
+ offset=0,
431
+ sep_style=SeparatorStyle.MPT,
432
+ sep="<|im_end|>",
433
+ )
434
+
435
+ conv_qwen = Conversation(
436
+ system="""<|im_start|>system
437
+ You are a helpful assistant.""",
438
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
439
+ version="qwen",
440
+ messages=[],
441
+ offset=0,
442
+ sep_style=SeparatorStyle.CHATML,
443
+ sep="<|im_end|>",
444
+ )
445
+
446
+ conv_gemma_instruct = Conversation(system="", roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="<end_of_turn>\n")
447
+
448
+ conv_llava_plain = Conversation(
449
+ system="",
450
+ roles=("", ""),
451
+ messages=[],
452
+ offset=0,
453
+ sep_style=SeparatorStyle.PLAIN,
454
+ sep="\n",
455
+ )
456
+
457
+ conv_llava_v0 = Conversation(
458
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
459
+ roles=("Human", "Assistant"),
460
+ messages=[],
461
+ offset=0,
462
+ sep_style=SeparatorStyle.SINGLE,
463
+ sep="###",
464
+ )
465
+
466
+ conv_llava_v0_mmtag = Conversation(
467
+ system="A chat between a curious user and an artificial intelligence assistant. "
468
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
469
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
470
+ roles=("Human", "Assistant"),
471
+ messages=[],
472
+ offset=0,
473
+ sep_style=SeparatorStyle.SINGLE,
474
+ sep="###",
475
+ version="v0_mmtag",
476
+ )
477
+
478
+ conv_llava_v1 = Conversation(
479
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
480
+ roles=("USER", "ASSISTANT"),
481
+ version="v1",
482
+ messages=[],
483
+ offset=0,
484
+ sep_style=SeparatorStyle.TWO,
485
+ sep=" ",
486
+ sep2="</s>",
487
+ )
488
+
489
+ conv_llava_v1_mmtag = Conversation(
490
+ system="A chat between a curious user and an artificial intelligence assistant. "
491
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
492
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
493
+ roles=("USER", "ASSISTANT"),
494
+ messages=[],
495
+ offset=0,
496
+ sep_style=SeparatorStyle.TWO,
497
+ sep=" ",
498
+ sep2="</s>",
499
+ version="v1_mmtag",
500
+ )
501
+
502
+ conv_mistral_orca = Conversation(
503
+ system="""<|im_start|>system
504
+ You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
505
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
506
+ version="mpt",
507
+ messages=[],
508
+ offset=0,
509
+ sep_style=SeparatorStyle.MPT,
510
+ sep="<|im_end|>",
511
+ )
512
+
513
+ conv_mistral_zephyr = Conversation(
514
+ system="""<|system|>
515
+ You are a helpful AI assistant.""",
516
+ roles=("<|user|>\n", "<|assistant|>\n"),
517
+ version="mpt",
518
+ messages=[],
519
+ offset=0,
520
+ sep_style=SeparatorStyle.MPT,
521
+ sep="</s>",
522
+ )
523
+
524
+ conv_mistral_direct = Conversation(
525
+ system="""<|im_start|>system
526
+ Answer the questions.""",
527
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
528
+ version="mpt",
529
+ messages=[],
530
+ offset=0,
531
+ sep_style=SeparatorStyle.MPT,
532
+ sep="<|im_end|>",
533
+ )
534
+
535
+ conv_chatml_direct = Conversation(
536
+ system="""<|im_start|>system
537
+ Answer the questions.""",
538
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
539
+ version="mpt",
540
+ messages=[],
541
+ offset=0,
542
+ sep_style=SeparatorStyle.MPT,
543
+ sep="<|im_end|>",
544
+ )
545
+
546
+ default_conversation = conv_vicuna_v0
547
+ conv_templates = {
548
+ "default": conv_vicuna_v0,
549
+ "v0": conv_vicuna_v0,
550
+ "v1": conv_vicuna_v1,
551
+ "vicuna_v1": conv_vicuna_v1,
552
+ "llama_2": conv_llama_2,
553
+ "mistral_instruct": conv_mistral_instruct,
554
+ "mistral_orca": conv_mistral_orca,
555
+ "mistral_zephyr": conv_mistral_zephyr,
556
+ "mistral_direct": conv_mistral_direct,
557
+ "plain": conv_llava_plain,
558
+ "v0_plain": conv_llava_plain,
559
+ "chatml_direct": conv_chatml_direct,
560
+ "llava_v0": conv_llava_v0,
561
+ "llava_v0_mmtag": conv_llava_v0_mmtag,
562
+ "llava_v1": conv_llava_v1,
563
+ "llava_v1_mmtag": conv_llava_v1_mmtag,
564
+ "llava_llama_2": conv_llava_llama_2,
565
+ "llava_llama_3": conv_llava_llama_3,
566
+ "llava_llama_2_simple": conv_llava_llama_2_simple,
567
+ "llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
568
+ "llava_mistral_instruct": conv_mistral_instruct,
569
+ "mpt": conv_mpt,
570
+ "qwen_1_5": conv_qwen,
571
+ "qwen_2": conv_qwen,
572
+ "gemma_instruct": conv_gemma_instruct,
573
+ }
574
+
575
+
576
+ if __name__ == "__main__":
577
+ print(default_conversation.get_prompt())
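For orientation, a minimal usage sketch of these templates, mirroring how llava/eval/model_vqa.py builds its prompts; the question text is a placeholder. Note that importing this module instantiates the Meta-Llama-3-8B-Instruct tokenizer for conv_llava_llama_3, so it may require network access and gated-model credentials.

```python
from llava.conversation import conv_templates

conv = conv_templates["qwen_1_5"].copy()
conv.append_message(conv.roles[0], "<image>\nDescribe the differences between the two images.")
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open
prompt = conv.get_prompt()                # ChatML string ending with "<|im_start|>assistant\n"
```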
llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc ADDED
Binary file (7.26 kB).
llava/eval/__pycache__/model_vqa.cpython-39.pyc ADDED
Binary file (6.58 kB).
llava/eval/evaluate_interleave.py ADDED
@@ -0,0 +1,339 @@
1
+ import re
2
+ from rouge import Rouge
3
+ import argparse
4
+ import os
5
+ import json
6
+ import numpy as np
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+
10
+
11
+ spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
12
+ image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
13
+ visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
14
+ visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
15
+ text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
16
+ multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]
17
+
18
+ puzzle = ["RAVEN"]
19
+ nlrv2 = ["NLVR2_Mantis"]
20
+ qbench = ["QBench"]
21
+
22
+ class Eval:
23
+ def __init__(self):
24
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
25
+ self.commaStrip = re.compile("(\d)(\,)(\d)")
26
+ self.punct = [
27
+ ";",
28
+ r"/",
29
+ "[",
30
+ "]",
31
+ '"',
32
+ "{",
33
+ "}",
34
+ "(",
35
+ ")",
36
+ "=",
37
+ "+",
38
+ "\\",
39
+ "_",
40
+ "-",
41
+ ">",
42
+ "<",
43
+ "@",
44
+ "`",
45
+ ",",
46
+ "?",
47
+ "!",
48
+ ]
49
+
50
+ def processPunctuation(self, inText):
51
+ outText = inText
52
+ for p in self.punct:
53
+ if (p + " " in inText or " " + p in inText) or (
54
+ re.search(self.commaStrip, inText) != None
55
+ ):
56
+ outText = outText.replace(p, "")
57
+ else:
58
+ outText = outText.replace(p, " ")
59
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
60
+ return outText
61
+
62
+ def process(self, answer):
63
+ answer = answer.replace("\n", " ")
64
+ answer = answer.replace("\t", " ")
65
+ answer = answer.strip()
66
+ answer = self.processPunctuation(answer)
67
+ answer = answer.strip('\'')
68
+ answer = answer.strip('\"')
69
+ answer = answer.strip(')')
70
+ answer = answer.strip('(')
71
+ answer = answer.strip().lower()
72
+ return answer
73
+
74
+ def evaluate_rouge(self,preds):
75
+ rouge = Rouge()
76
+ acc = {'f': []}
77
+ eval_list = []
78
+ for i, res in enumerate(preds):
79
+ sample_id = res['sample_id']
80
+ # print(sample_id)
81
+ gt_ans = self.process(res["gt_response"])
82
+ pred_ans = self.process(res["pred_response"])
83
+ # assert gt_ans != ''
84
+
85
+ if gt_ans == '':
86
+ continue
87
+
88
+ if pred_ans == '':
89
+ s = 0
90
+ else:
91
+ if len(pred_ans) > 512:
92
+ pred_ans = pred_ans[0: 512]
93
+ s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
94
+ acc['f'].append(s)
95
+ eval_list.append({'id':str(sample_id),'score':str(round(s,3))})
96
+ results = {'Rouge-L f': np.mean(acc['f'])}
97
+ return results,eval_list
98
+
99
+
100
+ def judge_multi_choice(self,sample):
101
+ sample_id = sample['sample_id']
102
+ gt_ans = sample["gt_response"]
103
+ pred_ans = sample["pred_response"]
104
+
105
+ if ":" in pred_ans:
106
+ a_list = pred_ans.split(":")
107
+ a_list = [a.strip() for a in a_list ]
108
+ for a in a_list:
109
+ if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
110
+ pred_ans = a
111
+
112
+ if pred_ans == gt_ans:
113
+ return 1
114
+ else:
115
+ return 0
116
+
117
+ def process_sample(self,sample):
118
+ sample["gt_response"] = self.process(sample["gt_response"])
119
+ sample["pred_response"] = self.process(sample["pred_response"])
120
+
121
+ def evaluate_multichoice(self, preditions):
122
+ correct = 0
123
+ eval_list = []
124
+ for i, sample in enumerate(preditions):
125
+ self.process_sample(sample)
126
+ score = self.judge_multi_choice(sample)
127
+ sample_id = sample['sample_id']
128
+ sample['result'] = score
129
+ eval_list.append({'id':str(sample_id),'score':str(score)})
130
+ correct+=score
131
+ return {'Accuracy':correct/len(preditions)},eval_list
132
+
133
+ def evaluate_multi_choice_image(self,preditions):
134
+ correct = 0
135
+ eval_list = []
136
+ for i,sample in enumerate(preditions):
137
+ gt_ans = self.process(sample["gt_response"])
138
+ pred_ans = self.process(sample["pred_response"])
139
+ sample_id = sample['sample_id']
140
+
141
+ if ":" in pred_ans:
142
+ a_list = pred_ans.split(":")
143
+ a_list = [a.strip() for a in a_list ]
144
+ for a in a_list:
145
+ if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
146
+ pred_ans = a
147
+
148
+ if gt_ans == pred_ans:
149
+ score = 1
150
+ else:
151
+ score = 0
152
+ sample_id = sample['sample_id']
153
+ sample['result'] = score
154
+ eval_list.append({'id':str(sample_id),'score':str(score)})
155
+ correct+=score
156
+ return {'Accuracy':correct/len(preditions)},eval_list
157
+
158
+
159
+ if __name__ == "__main__":
160
+ parser = argparse.ArgumentParser()
161
+ parser.add_argument('--result-dir', type=str, required=True)
162
+
163
+ args = parser.parse_args()
164
+
165
+ result_file = os.path.join(args.result_dir, "result.jsonl")
166
+
167
+ if not os.path.exists(result_file):
168
+ print('No prediction file found')
169
+ exit(0)
170
+ with open(result_file, 'r') as f:
171
+ preds_all = [json.loads(line) for line in f]
172
+
173
+ preds_all_dict = dict()
174
+ for pred in preds_all:
175
+ if pred["dataset"] not in preds_all_dict:
176
+ preds_all_dict[pred["dataset"]] = list()
177
+ preds_all_dict[pred["dataset"]].append(pred)
178
+
179
+ image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
180
+ E = Eval()
181
+
182
+ eval_result_list = dict()
183
+ eval_result_list_detail = dict()
184
+
185
+ for dataset in preds_all_dict:
186
+
187
+ preds = preds_all_dict[dataset]
188
+ question_type = preds[0]["question_type"]
189
+
190
+ if question_type == 'open-ended':
191
+ eval_result, eval_list = E.evaluate_rouge(preds)
192
+
193
+ elif question_type == 'multi-choice' or dataset == 'nlrv2':
194
+ if dataset in image_choice_dataset_list:
195
+ eval_result, eval_list = E.evaluate_multi_choice_image(preds)
196
+ else:
197
+ eval_result, eval_list = E.evaluate_multichoice(preds)
198
+
199
+ else:
200
+ eval_result = 'Dataset not supported'
201
+ print('Dataset not supported')
202
+ exit(0)
203
+
204
+ print(dataset, end = ': ')
205
+ print(eval_result)
206
+
207
+ eval_result_list[dataset] = eval_result
208
+ eval_result_list_detail[dataset] = eval_list
209
+
210
+ os.makedirs(args.result_dir, exist_ok=True)
211
+ with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
212
+ json.dump(eval_result_list, f, indent=4)
213
+
214
+ with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f:
215
+ json.dump(eval_result_list_detail, f, indent=4)
216
+
217
+
218
+ eval_cat_list = dict()
219
+ print()
220
+
221
+ # spot_the_diff
222
+ score = 0
223
+ count = 0
224
+ for dataset in eval_result_list:
225
+ if dataset in spot_the_diff:
226
+ count += 1
227
+ score += list(eval_result_list[dataset].values())[0]
228
+ if count > 0:
229
+ score /= count
230
+ eval_cat_list["spot_the_diff"] = score
231
+ print("spot_the_diff", end = ': ')
232
+ print('{:.2f}'.format(100 * score))
233
+
234
+ # image_edit_instruct
235
+ score = 0
236
+ count = 0
237
+ for dataset in eval_result_list:
238
+ if dataset in image_edit_instruct:
239
+ count += 1
240
+ score += list(eval_result_list[dataset].values())[0]
241
+ if count > 0:
242
+ score /= count
243
+ eval_cat_list["image_edit_instruct"] = score
244
+ print("image_edit_instruct", end = ': ')
245
+ print('{:.2f}'.format(100 * score))
246
+
247
+ # visual_story_telling
248
+ score = 0
249
+ count = 0
250
+ for dataset in eval_result_list:
251
+ if dataset in visual_story_telling:
252
+ count += 1
253
+ score += list(eval_result_list[dataset].values())[0]
254
+ if count > 0:
255
+ score /= count
256
+ eval_cat_list["visual_story_telling"] = score
257
+ print("visual_story_telling", end = ': ')
258
+ print('{:.2f}'.format(100 * score))
259
+
260
+ # visual_cloze
261
+ score = 0
262
+ count = 0
263
+ for dataset in eval_result_list:
264
+ if dataset in visual_cloze:
265
+ count += 1
266
+ score += list(eval_result_list[dataset].values())[0]
267
+ if count > 0:
268
+ score /= count
269
+ eval_cat_list["visual_cloze"] = score
270
+ print("visual_cloze", end = ': ')
271
+ print('{:.2f}'.format(100 * score))
272
+
273
+ # text_rich_vqa
274
+ score = 0
275
+ count = 0
276
+ for dataset in eval_result_list:
277
+ if dataset in text_rich_vqa:
278
+ count += 1
279
+ score += list(eval_result_list[dataset].values())[0]
280
+ if count > 0:
281
+ score /= count
282
+ eval_cat_list["text_rich_vqa"] = score
283
+ print("text_rich_vqa", end = ': ')
284
+ print('{:.2f}'.format(100 * score))
285
+
286
+ # multi_image_vqa
287
+ score = 0
288
+ count = 0
289
+ for dataset in eval_result_list:
290
+ if dataset in multi_image_vqa:
291
+ count += 1
292
+ score += list(eval_result_list[dataset].values())[0]
293
+ if count > 0:
294
+ score /= count
295
+ eval_cat_list["multi_image_vqa"] = score
296
+ print("multi_image_vqa", end = ': ')
297
+ print('{:.2f}'.format(100 * score))
298
+
299
+ # puzzle
300
+ score = 0
301
+ count = 0
302
+ for dataset in eval_result_list:
303
+ if dataset in puzzle:
304
+ count += 1
305
+ score += list(eval_result_list[dataset].values())[0]
306
+ if count > 0:
307
+ score /= count
308
+ eval_cat_list["puzzle"] = score
309
+ print("puzzle", end = ': ')
310
+ print('{:.2f}'.format(100 * score))
311
+
312
+ # nlrv2
313
+ score = 0
314
+ count = 0
315
+ for dataset in eval_result_list:
316
+ if dataset in nlrv2:
317
+ count += 1
318
+ score += list(eval_result_list[dataset].values())[0]
319
+ if count > 0:
320
+ score /= count
321
+ eval_cat_list["nlrv2"] = score
322
+ print("nlrv2", end = ': ')
323
+ print('{:.2f}'.format(100 * score))
324
+
325
+ # qbench
326
+ score = 0
327
+ count = 0
328
+ for dataset in eval_result_list:
329
+ if dataset in qbench:
330
+ count += 1
331
+ score += list(eval_result_list[dataset].values())[0]
332
+ if count > 0:
333
+ score /= count
334
+ eval_cat_list["qbench"] = score
335
+ print("qbench", end = ': ')
336
+ print('{:.2f}'.format(100 * score))
337
+
338
+ with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f:
339
+ json.dump(eval_cat_list, f, indent=4)
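For reference, a sketch of the per-sample records this script expects in <result-dir>/result.jsonl; all field values below are invented, and the dataset name must match one of the category lists defined at the top of the file for the aggregated scores to include it.

```python
import json

record = {
    "dataset": "Spot-the-Diff",
    "sample_id": "0001",
    "question_type": "open-ended",   # or "multi-choice"
    "gt_response": "the red car has moved forward",
    "pred_response": "a red car moved",
}
with open("results/result.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")

# Then evaluate, e.g.: python -m llava.eval.evaluate_interleave --result-dir results
```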
llava/eval/model_vqa.py ADDED
@@ -0,0 +1,240 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from llava.conversation import conv_templates, SeparatorStyle
10
+ from llava.model.builder import load_pretrained_model
11
+ from llava.utils import disable_torch_init
12
+ from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
15
+ from typing import Dict, Optional, Sequence, List
16
+ import transformers
17
+ import re
18
+
19
+ from PIL import Image
20
+ import math
21
+
22
+
23
+ def split_list(lst, n):
24
+ """Split a list into n (roughly) equal-sized chunks"""
25
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division, so every item lands in one of the n chunks
26
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
27
+
28
+
29
+ def get_chunk(lst, n, k):
30
+ chunks = split_list(lst, n)
31
+ return chunks[k]
32
+
33
+ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
34
+ roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
35
+
36
+ im_start, im_end = tokenizer.additional_special_tokens_ids
37
+ nl_tokens = tokenizer("\n").input_ids
38
+ _system = tokenizer("system").input_ids + nl_tokens
39
+ _user = tokenizer("user").input_ids + nl_tokens
40
+ _assistant = tokenizer("assistant").input_ids + nl_tokens
41
+
42
+ # Apply prompt templates
43
+ input_ids, targets = [], []
44
+
45
+ source = sources
46
+ if roles[source[0]["from"]] != roles["human"]:
47
+ source = source[1:]
48
+
49
+ input_id, target = [], []
50
+ system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
51
+ input_id += system
52
+ target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
53
+ assert len(input_id) == len(target)
54
+ for j, sentence in enumerate(source):
55
+ role = roles[sentence["from"]]
56
+ if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
57
+ num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
58
+ texts = sentence["value"].split('<image>')
59
+ _input_id = tokenizer(role).input_ids + nl_tokens
60
+ for i,text in enumerate(texts):
61
+ _input_id += tokenizer(text).input_ids
62
+ if i<len(texts)-1:
63
+ _input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
64
+ _input_id += [im_end] + nl_tokens
65
+ assert sum([i==IMAGE_TOKEN_INDEX for i in _input_id])==num_image
66
+ else:
67
+ if sentence["value"] is None:
68
+ _input_id = tokenizer(role).input_ids + nl_tokens
69
+ else:
70
+ _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
71
+ input_id += _input_id
72
+ if role == "<|im_start|>user":
73
+ _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
74
+ elif role == "<|im_start|>assistant":
75
+ _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
76
+ else:
77
+ raise NotImplementedError
78
+ target += _target
79
+
80
+ input_ids.append(input_id)
81
+ targets.append(target)
82
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
83
+ targets = torch.tensor(targets, dtype=torch.long)
84
+ return input_ids
85
+
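For clarity, a sketch of how eval_model() below calls preprocess_qwen: a single human turn (optionally containing <image> markers) plus an empty assistant turn. The tokenizer instance and the question text are placeholders.

```python
source = [
    {"from": "human", "value": "<image>\nWhat changed between the two photos?"},
    {"from": "gpt", "value": None},
]
input_ids = preprocess_qwen(source, tokenizer, has_image=True)  # tensor of shape (1, seq_len)
```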
86
+ def eval_model(args):
87
+
88
+ # Model
89
+ disable_torch_init()
90
+ model_path = os.path.expanduser(args.model_path)
91
+ model_name = get_model_name_from_path(model_path)
92
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
93
+
94
+ # Data
95
+ with open(os.path.expanduser(args.question_file)) as f:
96
+ questions = json.load(f)
97
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
98
+ answers_file = os.path.expanduser(args.answers_file)
99
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
100
+ ans_file = open(answers_file, "w")
101
+
102
+ for line in tqdm(questions):
103
+ idx = line["sample_id"]
104
+ question_type = line["metadata"]["question_type"]
105
+ dataset_name = line["metadata"]["dataset"]
106
+ gt = line["conversations"][1]["value"]
107
+
108
+ image_files = line["image"]
109
+ qs = line["conversations"][0]["value"]
110
+ cur_prompt = args.extra_prompt + qs
111
+
112
+ args.conv_mode = "qwen_1_5"
113
+
114
+ conv = conv_templates[args.conv_mode].copy()
115
+ conv.append_message(conv.roles[0], qs)
116
+ conv.append_message(conv.roles[1], None)
117
+ prompt = conv.get_prompt()
118
+
119
+ input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
120
+ img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX)
121
+
122
+ image_tensors = []
123
+ for image_file in image_files:
124
+ image = Image.open(os.path.join(args.image_folder, image_file))
125
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
126
+ image_tensors.append(image_tensor.half().cuda())
127
+ # image_tensors = torch.cat(image_tensors, dim=0)
128
+
129
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
130
+ keywords = [stop_str]
131
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
132
+
133
+ with torch.inference_mode():
134
+ output_ids = model.generate(
135
+ input_ids,
136
+ images=image_tensors,
137
+ do_sample=True if args.temperature > 0 else False,
138
+ temperature=args.temperature,
139
+ top_p=args.top_p,
140
+ num_beams=args.num_beams,
141
+ # no_repeat_ngram_size=3,
142
+ max_new_tokens=1024,
143
+ use_cache=True)
144
+
145
+
146
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
147
+ outputs = outputs.strip()
148
+ if outputs.endswith(stop_str):
149
+ outputs = outputs[:-len(stop_str)]
150
+ outputs = outputs.strip()
151
+
152
+ ans_id = shortuuid.uuid()
153
+ ans_file.write(json.dumps({
154
+ "dataset": dataset_name,
155
+ "sample_id": idx,
156
+ "prompt": cur_prompt,
157
+ "pred_response": outputs,
158
+ "gt_response": gt,
159
+ "shortuuid": ans_id,
160
+ "model_id": model_name,
161
+ "question_type": question_type,
162
+ }) + "\n")
163
+ ans_file.flush()
164
+
165
+ if len(line["conversations"]) > 2:
166
+
167
+ for i in range(2, len(line["conversations"]), 2):
168
+ input_ids = torch.cat((input_ids, output_ids), dim=1)
169
+
170
+ gt = line["conversations"][i + 1]["value"]
171
+ qs = line["conversations"][i]["value"]
172
+ cur_prompt = args.extra_prompt + qs
173
+
174
+ args.conv_mode = "qwen_1_5"
175
+
176
+ conv = conv_templates[args.conv_mode].copy()
177
+ conv.append_message(conv.roles[0], qs)
178
+ conv.append_message(conv.roles[1], None)
179
+ prompt = conv.get_prompt()
180
+
181
+ input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
182
+ input_ids = torch.cat((input_ids, input_ids_new), dim=1)
183
+ img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX)
184
+
185
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
186
+ keywords = [stop_str]
187
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
188
+
189
+ with torch.inference_mode():
190
+ output_ids = model.generate(
191
+ input_ids,
192
+ images=image_tensors,
193
+ do_sample=True if args.temperature > 0 else False,
194
+ temperature=args.temperature,
195
+ top_p=args.top_p,
196
+ num_beams=args.num_beams,
197
+ # no_repeat_ngram_size=3,
198
+ max_new_tokens=1024,
199
+ use_cache=True)
200
+
201
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
202
+ outputs = outputs.strip()
203
+ if outputs.endswith(stop_str):
204
+ outputs = outputs[:-len(stop_str)]
205
+ outputs = outputs.strip()
206
+
207
+ ans_id = shortuuid.uuid()
208
+ ans_file.write(json.dumps({
209
+ "dataset": dataset_name,
210
+ "sample_id": idx,
211
+ "prompt": cur_prompt,
212
+ "pred_response": outputs,
213
+ "gt_response": gt,
214
+ "shortuuid": ans_id,
215
+ "model_id": model_name,
216
+ "question_type": question_type,
217
+ }) + "\n")
218
+ ans_file.flush()
219
+
220
+
221
+ ans_file.close()
222
+
223
+ if __name__ == "__main__":
224
+ parser = argparse.ArgumentParser()
225
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
226
+ parser.add_argument("--model-base", type=str, default=None)
227
+ parser.add_argument("--image-folder", type=str, default="")
228
+ parser.add_argument("--extra-prompt", type=str, default="")
229
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
230
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
231
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
232
+ parser.add_argument("--num-chunks", type=int, default=1)
233
+ parser.add_argument("--chunk-idx", type=int, default=0)
234
+ parser.add_argument("--temperature", type=float, default=0.2)
235
+ parser.add_argument("--top_p", type=float, default=None)
236
+ parser.add_argument("--num_beams", type=int, default=1)
237
+ parser.add_argument("--test_size", type=int, default=10000000)
238
+ args = parser.parse_args()
239
+
240
+ eval_model(args)
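A minimal sketch of driving eval_model() programmatically rather than through the CLI; every path below is a placeholder, and the attribute names mirror the argparse options defined above.

```python
from argparse import Namespace

from llava.eval.model_vqa import eval_model

args = Namespace(
    model_path="path/to/llava-interleave-checkpoint",  # placeholder
    model_base=None,
    image_folder="path/to/images",                     # placeholder
    extra_prompt="",
    question_file="path/to/questions.json",            # placeholder
    answers_file="results/result.jsonl",
    conv_mode="qwen_1_5",
    num_chunks=1,
    chunk_idx=0,
    temperature=0.2,
    top_p=None,
    num_beams=1,
    test_size=10000000,
)
eval_model(args)
```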
llava/mm_utils.py ADDED
@@ -0,0 +1,395 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import math
5
+ import ast
6
+ import re
7
+ import torch
8
+ from transformers import StoppingCriteria
9
+ from llava.constants import IMAGE_TOKEN_INDEX
10
+
11
+
12
+ def resize_and_center_crop(image, shortest_edge_length):
13
+ # Calculate new dimensions and resize
14
+ aspect_ratio = float(image.width) / float(image.height)
15
+ if aspect_ratio > 1:
16
+ new_width = int(shortest_edge_length * aspect_ratio)
17
+ new_height = shortest_edge_length
18
+ else:
19
+ new_width = shortest_edge_length
20
+ new_height = int(shortest_edge_length / aspect_ratio)
21
+ resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
22
+
23
+ # Calculate the position and perform the center crop
24
+ left = (new_width - shortest_edge_length) / 2
25
+ top = (new_height - shortest_edge_length) / 2
26
+ right = (new_width + shortest_edge_length) / 2
27
+ bottom = (new_height + shortest_edge_length) / 2
28
+ cropped_image = resized_image.crop((left, top, right, bottom))
29
+
30
+ return cropped_image
31
+
32
+
33
+ def auto_pad_images(image, grid_params):
34
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
35
+ assert len(grid_params) > 0, "Grid parameters should not be empty"
36
+
37
+ # Step 1: Calculate and find the closest aspect ratio
38
+ input_width, input_height = image.size
39
+ input_aspect_ratio = input_width / input_height
40
+ candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
41
+ closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
42
+
43
+ candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
44
+
45
+ target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
46
+
47
+ resize_width, resize_height = target_resolution
48
+ if input_width > input_height:
49
+ resize_height = int(resize_width / input_aspect_ratio)
50
+ else:
51
+ resize_width = int(resize_height * input_aspect_ratio)
52
+ resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS)
53
+
54
+ # Step 5: Pad the resized image if necessary to match the target resolution
55
+ pad_width = target_resolution[0] - resize_width
56
+ pad_height = target_resolution[1] - resize_height
57
+ padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
58
+ padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
59
+
60
+ return padded_image
61
+
62
+
63
+ def extract_patches(image, patch_size, overlap_ratio):
64
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
65
+ assert patch_size > 0, "Patch size should be greater than 0"
66
+ assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
67
+
68
+ W, H = image.size
69
+ patches = []
70
+
71
+ stride = int(patch_size * (1 - overlap_ratio))
72
+
73
+ num_patches_y = (H - patch_size) // stride + 1
74
+ num_patches_x = (W - patch_size) // stride + 1
75
+
76
+ y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
77
+ x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
78
+
79
+ for y in range(y_start, y_start + num_patches_y * stride, stride):
80
+ for x in range(x_start, x_start + num_patches_x * stride, stride):
81
+ patch = image.crop((x, y, x + patch_size, y + patch_size))
82
+ patches.append(patch)
83
+
84
+ return patches
85
+
86
+
87
+ def process_highres_image_crop_split(image, data_args, processor=None):
88
+ crop_resolution = data_args.image_crop_resolution
89
+ split_resolution = data_args.image_split_resolution
90
+ if processor is None:
91
+ processor = data_args.image_processor
92
+ image_crop = resize_and_center_crop(image, crop_resolution)
93
+ image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
94
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
95
+ return torch.stack(image_patches, dim=0)
96
+
97
+
98
+ def process_highres_image(image, processor, grid_pinpoints):
99
+ grid_params = [int(x) for x in grid_pinpoints.split(",")]
100
+ width_height = max(image.size)
101
+ fit_grid_params = [x for x in grid_params if x >= width_height]
102
+ if len(fit_grid_params) == 0:
103
+ select_size = max(grid_params)
104
+ else:
105
+ select_size = min(fit_grid_params)
106
+ # FIXME: always select the 448
107
+ select_size = max(grid_params)
108
+ image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
109
+
110
+ # FIXME: this seems to be a bug that it always resizes instead of padding
111
+ image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
112
+ image_padded = image_padded.resize((select_size, select_size))
113
+ image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
114
+ image_patches = [image_original_resize] + image_patches
115
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
116
+ return torch.stack(image_patches, dim=0)
117
+
118
+
119
+ def select_best_resolution(original_size, possible_resolutions):
120
+ """
121
+ Selects the best resolution from a list of possible resolutions based on the original size.
122
+
123
+ Args:
124
+ original_size (tuple): The original size of the image in the format (width, height).
125
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
126
+
127
+ Returns:
128
+ tuple: The best fit resolution in the format (width, height).
129
+ """
130
+ original_width, original_height = original_size
131
+ best_fit = None
132
+ max_effective_resolution = 0
133
+ min_wasted_resolution = float("inf")
134
+
135
+ for width, height in possible_resolutions:
136
+ # Calculate the downscaled size to keep the aspect ratio
137
+ scale = min(width / original_width, height / original_height)
138
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
139
+
140
+ # Calculate effective and wasted resolutions
141
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
142
+ wasted_resolution = (width * height) - effective_resolution
143
+
144
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
145
+ max_effective_resolution = effective_resolution
146
+ min_wasted_resolution = wasted_resolution
147
+ best_fit = (width, height)
148
+
149
+ return best_fit
150
+
151
+
152
+ def resize_and_pad_image(image, target_resolution):
153
+ """
154
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
155
+
156
+ Args:
157
+ image (PIL.Image.Image): The input image.
158
+ target_resolution (tuple): The target resolution (width, height) of the image.
159
+
160
+ Returns:
161
+ PIL.Image.Image: The resized and padded image.
162
+ """
163
+ original_width, original_height = image.size
164
+ target_width, target_height = target_resolution
165
+
166
+ # Determine which dimension (width or height) to fill
167
+ scale_w = target_width / original_width
168
+ scale_h = target_height / original_height
169
+
170
+ if scale_w < scale_h:
171
+ # Width will be filled completely
172
+ new_width = target_width
173
+ new_height = min(math.ceil(original_height * scale_w), target_height)
174
+ else:
175
+ # Height will be filled completely
176
+ new_height = target_height
177
+ new_width = min(math.ceil(original_width * scale_h), target_width)
178
+
179
+ # Resize the image
180
+ resized_image = image.resize((new_width, new_height))
181
+
182
+ # Create a new image with the target size and paste the resized image onto it
183
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
184
+ paste_x = (target_width - new_width) // 2
185
+ paste_y = (target_height - new_height) // 2
186
+ new_image.paste(resized_image, (paste_x, paste_y))
187
+
188
+ return new_image
189
+
190
+
191
+ def divide_to_patches(image, patch_size):
192
+ """
193
+ Divides an image into patches of a specified size.
194
+
195
+ Args:
196
+ image (PIL.Image.Image): The input image.
197
+ patch_size (int): The size of each patch.
198
+
199
+ Returns:
200
+ list: A list of PIL.Image.Image objects representing the patches.
201
+ """
202
+ patches = []
203
+ width, height = image.size
204
+ for i in range(0, height, patch_size):
205
+ for j in range(0, width, patch_size):
206
+ box = (j, i, j + patch_size, i + patch_size)
207
+ patch = image.crop(box)
208
+ patches.append(patch)
209
+
210
+ return patches
211
+
212
+
213
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
214
+ """
215
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
216
+
217
+ Args:
218
+ image_size (tuple): The size of the input image in the format (width, height).
219
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
220
+ patch_size (int): The size of each image patch.
221
+
222
+ Returns:
223
+ tuple: The shape of the image patch grid in the format (width, height).
224
+ """
225
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
226
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
227
+ # Use regex to extract the range from the input string
228
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
229
+ range_start = tuple(map(int, matches[0]))
230
+ range_end = tuple(map(int, matches[-1]))
231
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
232
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
233
+ # Multiply all elements by patch_size
234
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
235
+ if type(grid_pinpoints) is list:
236
+ possible_resolutions = grid_pinpoints
237
+ else:
238
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
239
+ width, height = select_best_resolution(image_size, possible_resolutions)
240
+ return width // patch_size, height // patch_size
241
+
242
+
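A hypothetical illustration of the resulting patch-grid shape for an anyres configuration; the pinpoints list and patch size are examples only.

```python
from llava.mm_utils import get_anyres_image_grid_shape

grid_w, grid_h = get_anyres_image_grid_shape(
    image_size=(1024, 768),
    grid_pinpoints="[(336, 672), (672, 336), (672, 672)]",
    patch_size=336,
)
# e.g. a best-fit resolution of (672, 672) yields a 2 x 2 grid of 336-pixel patches
```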
243
+ def process_anyres_image(image, processor, grid_pinpoints):
244
+ """
245
+ Process an image with variable resolutions.
246
+
247
+ Args:
248
+ image (PIL.Image.Image): The input image to be processed.
249
+ processor: The image processor object.
250
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
251
+
252
+ Returns:
253
+ torch.Tensor: A tensor containing the processed image patches.
254
+ """
255
+ # Convert grid_pinpoints from string to list
256
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
257
+ try:
258
+ patch_size = processor.size[0]
259
+ except Exception as e:
260
+ patch_size = processor.size["shortest_edge"]
261
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
262
+ # Use regex to extract the range from the input string
263
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
264
+ range_start = tuple(map(int, matches[0]))
265
+ range_end = tuple(map(int, matches[-1]))
266
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
267
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
268
+ # Multiply all elements by patch_size
269
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
270
+
271
+ if type(grid_pinpoints) is list:
272
+ possible_resolutions = grid_pinpoints
273
+ else:
274
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
275
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
276
+ image_padded = resize_and_pad_image(image, best_resolution)
277
+
278
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
279
+
280
+ # FIXME: this looks like a bug: the image is resized instead of padded,
281
+ # but it is kept as is for consistency with previous behavior.
282
+ # TODO: uncomment the lines below to ablate with the padded square image.
283
+ if isinstance(processor.size, dict):
284
+ shortest_edge = processor.size["shortest_edge"]
285
+ else:
286
+ shortest_edge = min(processor.size)
287
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
288
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
289
+ # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
290
+
291
+ image_patches = [image_original_resize] + patches
292
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
293
+ return torch.stack(image_patches, dim=0)
294
+
295
+
296
+ def load_image_from_base64(image):
297
+ return Image.open(BytesIO(base64.b64decode(image)))
298
+
299
+
300
+ def expand2square(pil_img, background_color):
301
+ width, height = pil_img.size
302
+ if width == height:
303
+ return pil_img
304
+ elif width > height:
305
+ result = Image.new(pil_img.mode, (width, width), background_color)
306
+ result.paste(pil_img, (0, (width - height) // 2))
307
+ return result
308
+ else:
309
+ result = Image.new(pil_img.mode, (height, height), background_color)
310
+ result.paste(pil_img, ((height - width) // 2, 0))
311
+ return result
312
+
313
+
314
+ def process_images(images, image_processor, model_cfg):
315
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
316
+ new_images = []
317
+ if image_aspect_ratio == "highres":
318
+ for image in images:
319
+ image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
320
+ new_images.append(image)
321
+ elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
322
+ for image in images:
323
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
324
+ new_images.append(image)
325
+ elif image_aspect_ratio == "crop_split":
326
+ for image in images:
327
+ image = process_highres_image_crop_split(image, model_cfg, image_processor)
328
+ new_images.append(image)
329
+ elif image_aspect_ratio == "pad":
330
+ for image in images:
331
+ image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
332
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
333
+ new_images.append(image)
334
+ else:
335
+ return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
336
+ if all(x.shape == new_images[0].shape for x in new_images):
337
+ new_images = torch.stack(new_images, dim=0)
338
+ return new_images
339
+
340
+
341
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
342
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
343
+
344
+ def insert_separator(X, sep):
345
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
346
+
347
+ input_ids = []
348
+ offset = 0
349
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
350
+ offset = 1
351
+ input_ids.append(prompt_chunks[0][0])
352
+
353
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
354
+ input_ids.extend(x[offset:])
355
+
356
+ if return_tensors is not None:
357
+ if return_tensors == "pt":
358
+ return torch.tensor(input_ids, dtype=torch.long)
359
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
360
+ return input_ids
361
+
362
+
363
+ def get_model_name_from_path(model_path):
364
+ model_path = model_path.strip("/")
365
+ model_paths = model_path.split("/")
366
+ if model_paths[-1].startswith("checkpoint-"):
367
+ return model_paths[-2] + "_" + model_paths[-1]
368
+ else:
369
+ return model_paths[-1]
370
+
371
+
372
+ class KeywordsStoppingCriteria(StoppingCriteria):
373
+ def __init__(self, keywords, tokenizer, input_ids):
374
+ self.keywords = keywords
375
+ self.keyword_ids = []
376
+ for keyword in keywords:
377
+ cur_keyword_ids = tokenizer(keyword).input_ids
378
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
379
+ cur_keyword_ids = cur_keyword_ids[1:]
380
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
381
+ self.tokenizer = tokenizer
382
+ self.start_len = input_ids.shape[1]
383
+
384
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
385
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
386
+ offset = min(output_ids.shape[1] - self.start_len, 3)
387
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
388
+ for keyword_id in self.keyword_ids:
389
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
390
+ return True
391
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
392
+ for keyword in self.keywords:
393
+ if keyword in outputs:
394
+ return True
395
+ return False
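A minimal usage sketch for the helpers above (assuming the llava package from this upload is installed; the image path, tokenizer checkpoint, and prompt template are placeholders -- the real template comes from llava/conversation.py and the tokenizer normally comes from load_pretrained_model in llava/model/builder.py below):

from PIL import Image
from transformers import AutoTokenizer

from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import expand2square, tokenizer_image_token, KeywordsStoppingCriteria

# Pad a rectangular image to a square on a grey background, as the "pad" branch of process_images does.
image = Image.open("example.jpg").convert("RGB")                     # placeholder path
square = expand2square(image, (127, 127, 127))

# Splice IMAGE_TOKEN_INDEX into the token stream at every "<image>" occurrence.
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", use_fast=False)  # placeholder tokenizer
prompt = "USER: <image>\nDescribe this picture. ASSISTANT:"          # placeholder template
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")

# Stop generation once the decoded tail contains any of the given keywords.
stopping_criteria = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids.unsqueeze(0))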
llava/model/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+
3
+ AVAILABLE_MODELS = {
4
+ "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
5
+ "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
6
+ "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
7
+ "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig",
8
+ # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
9
+ # Add other models as needed
10
+ }
11
+
12
+ for model_name, model_classes in AVAILABLE_MODELS.items():
13
+ try:
14
+ exec(f"from .language_model.{model_name} import {model_classes}")
15
+ except Exception as e:
16
+ print(f"Failed to import {model_name} from llava.language_model.{model_name}. Error: {e}")
llava/model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (726 Bytes).

llava/model/__pycache__/apply_delta.cpython-39.pyc ADDED
Binary file (1.73 kB).

llava/model/__pycache__/builder.cpython-39.pyc ADDED
Binary file (7.77 kB).

llava/model/__pycache__/consolidate.cpython-39.pyc ADDED
Binary file (1.08 kB).

llava/model/__pycache__/llava_arch.cpython-39.pyc ADDED
Binary file (14.4 kB).

llava/model/__pycache__/make_delta.cpython-39.pyc ADDED
Binary file (1.94 kB).

llava/model/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.04 kB).
 
llava/model/apply_delta.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+
6
+ import argparse
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+ from llava import LlavaLlamaForCausalLM
12
+
13
+
14
+ def apply_delta(base_model_path, target_model_path, delta_path):
15
+ print("Loading base model")
16
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31
+ bparam = base.state_dict()[name]
32
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
33
+
34
+ print("Saving target model")
35
+ delta.save_pretrained(target_model_path)
36
+ delta_tokenizer.save_pretrained(target_model_path)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument("--base-model-path", type=str, required=True)
42
+ parser.add_argument("--target-model-path", type=str, required=True)
43
+ parser.add_argument("--delta-path", type=str, required=True)
44
+
45
+ args = parser.parse_args()
46
+
47
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
llava/model/builder.py ADDED
@@ -0,0 +1,301 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from llava.model import *
23
+ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+ from llava.utils import rank0_print
25
+
26
+
27
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
28
+ kwargs["device_map"] = device_map
29
+
30
+ if load_8bit:
31
+ kwargs["load_in_8bit"] = True
32
+ elif load_4bit:
33
+ kwargs["load_in_4bit"] = True
34
+ kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
35
+ else:
36
+ kwargs["torch_dtype"] = torch.float16
37
+
38
+ if customized_config is not None:
39
+ kwargs["config"] = customized_config
40
+
41
+ if "multimodal" in kwargs:
42
+ if kwargs["multimodal"] is True:
43
+ is_multimodal = True
44
+ kwargs.pop("multimodal")
45
+ else:
46
+ is_multimodal = False
47
+
48
+ if "llava" in model_name.lower() or is_multimodal:
49
+ # Load LLaVA model
50
+ if "lora" in model_name.lower() and model_base is None:
51
+ warnings.warn(
52
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
53
+ )
54
+ if "lora" in model_name.lower() and model_base is not None:
55
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
56
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
57
+ rank0_print("Loading LLaVA from base model...")
58
+ if "mixtral" in model_name.lower():
59
+ from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
60
+
61
+ lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
62
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
63
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
64
+ elif "mistral" in model_name.lower():
65
+ from llava.model.language_model.llava_mistral import LlavaMistralConfig
66
+
67
+ lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
68
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
69
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
70
+ elif "gemma" in model_name.lower():
71
+ from llava.model.language_model.llava_gemma import LlavaGemmaConfig
72
+
73
+ lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
74
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
75
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
76
+ else:
77
+ from llava.model.language_model.llava_llama import LlavaConfig
78
+
79
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
80
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
81
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
82
+
83
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
84
+ if model.lm_head.weight.shape[0] != token_num:
85
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
86
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
87
+
88
+ rank0_print("Loading additional LLaVA weights...")
89
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
90
+ non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
91
+ else:
92
+ # this is probably from HF Hub
93
+ from huggingface_hub import hf_hub_download
94
+
95
+ def load_from_hf(repo_id, filename, subfolder=None):
96
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
97
+ return torch.load(cache_file, map_location="cpu")
98
+
99
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
100
+ non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
101
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
102
+ non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
103
+ model.load_state_dict(non_lora_trainables, strict=False)
104
+
105
+ from peft import PeftModel
106
+
107
+ rank0_print("Loading LoRA weights...")
108
+ model = PeftModel.from_pretrained(model, model_path)
109
+ rank0_print("Merging LoRA weights...")
110
+ model = model.merge_and_unload()
111
+ rank0_print("Model is loaded...")
112
+ elif model_base is not None: # this may be the mm projector only; load the projector with a preset language model
113
+ rank0_print(f"Loading LLaVA from base model {model_base}...")
114
+ if "mixtral" in model_name.lower():
115
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
116
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
117
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
118
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
119
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
120
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
121
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
122
+ elif "gemma" in model_name.lower():
123
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
124
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
125
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
126
+ elif (
127
+ "wizardlm-2" in model_name.lower()
128
+ and "vicuna" in model_name.lower()
129
+ or "llama" in model_name.lower()
130
+ or "yi" in model_name.lower()
131
+ or "nous-hermes" in model_name.lower()
132
+ or "llava-v1.6-34b" in model_name.lower()
133
+ or "llava-v1.5" in model_name.lower()
134
+ ):
135
+ from llava.model.language_model.llava_llama import LlavaConfig
136
+
137
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
138
+ if customized_config is None:
139
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
140
+ if "v1.5" in model_name.lower():
141
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
142
+ else:
143
+ llava_cfg = customized_config
144
+
145
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
146
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
147
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
148
+ else:
149
+ raise ValueError(f"Model {model_name} not supported")
150
+
151
+ mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
152
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
153
+ model.load_state_dict(mm_projector_weights, strict=False)
154
+ else:
155
+ rank0_print(f"Loaded LLaVA model: {model_path}")
156
+ if "mixtral" in model_name.lower():
157
+ from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
158
+
159
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
160
+ if customized_config is None:
161
+ llava_cfg = LlavaMixtralConfig.from_pretrained(model_path)
162
+ else:
163
+ llava_cfg = customized_config
164
+
165
+ if overwrite_config is not None:
166
+ rank0_print(f"Overwriting config with {overwrite_config}")
167
+ for k, v in overwrite_config.items():
168
+ setattr(llava_cfg, k, v)
169
+
170
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
171
+ model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
172
+
173
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
174
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
175
+ model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
176
+ elif (
177
+ "wizardlm-2" in model_name.lower()
178
+ and "vicuna" in model_name.lower()
179
+ or "llama" in model_name.lower()
180
+ or "yi" in model_name.lower()
181
+ or "nous-hermes" in model_name.lower()
182
+ or "llava-v1.6-34b" in model_name.lower()
183
+ or "llava-v1.5" in model_name.lower()
184
+ ):
185
+ from llava.model.language_model.llava_llama import LlavaConfig
186
+
187
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
188
+ if customized_config is None:
189
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
190
+ if "v1.5" in model_name.lower():
191
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
192
+ else:
193
+ llava_cfg = customized_config
194
+
195
+ if overwrite_config is not None:
196
+ rank0_print(f"Overwriting config with {overwrite_config}")
197
+ for k, v in overwrite_config.items():
198
+ setattr(llava_cfg, k, v)
199
+
200
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
201
+
202
+ elif "qwen" in model_name.lower() or "quyen" in model_name.lower():
203
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
204
+ if "moe" in model_name.lower() or "A14B" in model_name.lower():
205
+ from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig
206
+ if overwrite_config is not None:
207
+ llava_cfg = LlavaQwenMoeConfig.from_pretrained(model_path)
208
+ rank0_print(f"Overwriting config with {overwrite_config}")
209
+ for k, v in overwrite_config.items():
210
+ setattr(llava_cfg, k, v)
211
+ model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
212
+ else:
213
+ model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
214
+
215
+ else:
216
+ from llava.model.language_model.llava_qwen import LlavaQwenConfig
217
+ if overwrite_config is not None:
218
+ llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
219
+ rank0_print(f"Overwriting config with {overwrite_config}")
220
+ for k, v in overwrite_config.items():
221
+ setattr(llava_cfg, k, v)
222
+ model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
223
+ else:
224
+ model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
225
+
226
+ elif "gemma" in model_name.lower():
227
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
228
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
229
+ model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
230
+ else:
231
+ try:
232
+ from llava.model.language_model.llava_llama import LlavaConfig
233
+
234
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
235
+ if customized_config is None:
236
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
237
+ if "v1.5" in model_path.lower():
238
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
239
+ else:
240
+ llava_cfg = customized_config
241
+
242
+ if overwrite_config is not None:
243
+ rank0_print(f"Overwriting config with {overwrite_config}")
244
+ for k, v in overwrite_config.items():
245
+ setattr(llava_cfg, k, v)
246
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
247
+ except Exception as e:
248
+ raise ValueError(f"Model {model_name} not supported") from e
249
+
250
+ else:
251
+ # Load language model
252
+ if model_base is not None:
253
+ # PEFT model
254
+ from peft import PeftModel
255
+
256
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
257
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
258
+ print(f"Loading LoRA weights from {model_path}")
259
+ model = PeftModel.from_pretrained(model, model_path)
260
+ print(f"Merging weights")
261
+ model = model.merge_and_unload()
262
+ print("Convert to FP16...")
263
+ model.to(torch.float16)
264
+ else:
265
+ use_fast = False
266
+ if "mpt" in model_name.lower().replace("prompt", ""):
267
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
268
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
269
+ else:
270
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
271
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
272
+
273
+ rank0_print(f"Model Class: {model.__class__.__name__}")
274
+ image_processor = None
275
+
276
+ if "llava" in model_name.lower() or is_multimodal:
277
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
278
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
279
+ if mm_use_im_patch_token:
280
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
281
+ if mm_use_im_start_end:
282
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
283
+ model.resize_token_embeddings(len(tokenizer))
284
+
285
+ vision_tower = model.get_vision_tower()
286
+ if not vision_tower.is_loaded:
287
+ vision_tower.load_model(device_map=device_map)
288
+ if device_map != "auto":
289
+ vision_tower.to(device="cuda", dtype=torch.float16)
290
+ image_processor = vision_tower.image_processor
291
+
292
+ if hasattr(model.config, "max_sequence_length"):
293
+ context_len = model.config.max_sequence_length
294
+ elif hasattr(model.config, "max_position_embeddings"):
295
+ context_len = model.config.max_position_embeddings
296
+ elif hasattr(model.config, "tokenizer_model_max_length"):
297
+ context_len = model.config.tokenizer_model_max_length
298
+ else:
299
+ context_len = 2048
300
+
301
+ return tokenizer, model, image_processor, context_len
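A minimal usage sketch for load_pretrained_model (the checkpoint id is a placeholder; model_base is only needed for LoRA or projector-only checkpoints, and load_8bit/load_4bit switch on bitsandbytes quantization):

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "liuhaotian/llava-v1.5-7b"        # placeholder checkpoint id
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
)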
llava/model/consolidate.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+
6
+ import argparse
7
+
8
+ import torch
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava.model import *
11
+ from llava.model.utils import auto_upgrade
12
+
13
+
14
+ def consolidate_ckpt(src_path, dst_path):
15
+ print("Loading model")
16
+ auto_upgrade(src_path)
17
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
18
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
19
+ src_model.save_pretrained(dst_path)
20
+ src_tokenizer.save_pretrained(dst_path)
21
+
22
+
23
+ if __name__ == "__main__":
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--src", type=str, required=True)
26
+ parser.add_argument("--dst", type=str, required=True)
27
+
28
+ args = parser.parse_args()
29
+
30
+ consolidate_ckpt(args.src, args.dst)
llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc ADDED
Binary file (3.72 kB).

llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc ADDED
Binary file (4.24 kB).

llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc ADDED
Binary file (3.97 kB).

llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc ADDED
Binary file (4.1 kB).

llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc ADDED
Binary file (3.34 kB).

llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc ADDED
Binary file (4.13 kB).

llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc ADDED
Binary file (4.11 kB).

llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc ADDED
Binary file (48.9 kB).
 
llava/model/language_model/llava_gemma.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaGemmaConfig(GemmaConfig):
31
+ model_type = "llava_gemma"
32
+
33
+
34
+ class LlavaGemmaModel(LlavaMetaModel, GemmaModel):
35
+ config_class = LlavaGemmaConfig
36
+
37
+ def __init__(self, config: GemmaConfig):
38
+ super(LlavaGemmaModel, self).__init__(config)
39
+
40
+
41
+ class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaGemmaConfig
43
+
44
+ def __init__(self, config):
45
+ super(GemmaForCausalLM, self).__init__(config)
46
+ self.model = LlavaGemmaModel(config)
47
+
48
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.model
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ position_ids: Optional[torch.LongTensor] = None,
61
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ labels: Optional[torch.LongTensor] = None,
64
+ use_cache: Optional[bool] = None,
65
+ output_attentions: Optional[bool] = None,
66
+ output_hidden_states: Optional[bool] = None,
67
+ images: Optional[torch.FloatTensor] = None,
68
+ image_sizes: Optional[List[List[int]]] = None,
69
+ return_dict: Optional[bool] = None,
70
+ cache_position: Optional[torch.LongTensor] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
75
+
76
+ return super().forward(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ position_ids=position_ids,
80
+ past_key_values=past_key_values,
81
+ inputs_embeds=inputs_embeds,
82
+ labels=labels,
83
+ use_cache=use_cache,
84
+ output_attentions=output_attentions,
85
+ output_hidden_states=output_hidden_states,
86
+ return_dict=return_dict,
87
+ cache_position=cache_position,
88
+ )
89
+
90
+ @torch.no_grad()
91
+ def generate(
92
+ self,
93
+ inputs: Optional[torch.Tensor] = None,
94
+ images: Optional[torch.Tensor] = None,
95
+ image_sizes: Optional[torch.Tensor] = None,
96
+ **kwargs,
97
+ ) -> Union[GenerateOutput, torch.LongTensor]:
98
+ position_ids = kwargs.pop("position_ids", None)
99
+ attention_mask = kwargs.pop("attention_mask", None)
100
+ if "inputs_embeds" in kwargs:
101
+ raise NotImplementedError("`inputs_embeds` is not supported")
102
+
103
+ if images is not None:
104
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
105
+ else:
106
+ inputs_embeds = self.get_model().embed_tokens(inputs)
107
+
108
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
109
+
110
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
111
+ images = kwargs.pop("images", None)
112
+ image_sizes = kwargs.pop("image_sizes", None)
113
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
114
+ if images is not None:
115
+ inputs["images"] = images
116
+ if image_sizes is not None:
117
+ inputs["image_sizes"] = image_sizes
118
+ return inputs
119
+
120
+
121
+ AutoConfig.register("llava_gemma", LlavaGemmaConfig)
122
+ AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
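The two register() calls at the end of the file hook the new model_type into the transformers Auto* factories, so a checkpoint whose config.json declares "model_type": "llava_gemma" resolves to these classes without trust_remote_code; the same pattern is used by the other llava_* variants below. A minimal sketch (the checkpoint path is a placeholder):

from transformers import AutoConfig, AutoModelForCausalLM

import llava.model.language_model.llava_gemma  # noqa: F401 -- importing runs the register() calls above

config = AutoConfig.from_pretrained("path/to/llava-gemma-checkpoint")   # placeholder path
model = AutoModelForCausalLM.from_pretrained("path/to/llava-gemma-checkpoint", config=config)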
llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
22
+
23
+ from torch.nn import CrossEntropyLoss
24
+
25
+
26
+ # , LlamaModel, LlamaForCausalLM, GenerationConfig
27
+ # from .modeling_llama import LlamaModel, LlamaForCausalLM
28
+ from transformers import LlamaModel, LlamaForCausalLM
29
+ from transformers.modeling_outputs import CausalLMOutputWithPast
30
+ from transformers.generation.utils import GenerateOutput
31
+
32
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
33
+
34
+
35
+ class LlavaConfig(LlamaConfig):
36
+ model_type = "llava_llama"
37
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
38
+ max_new_tokens: int = 1024
39
+ do_sample: bool = False
40
+ top_p: Optional[float] = None
41
+ # rope_scaling: Optional[dict] = {}
42
+
43
+
44
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
45
+ config_class = LlavaConfig
46
+
47
+ def __init__(self, config: LlamaConfig):
48
+ super(LlavaLlamaModel, self).__init__(config)
49
+
50
+
51
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
52
+ config_class = LlavaConfig
53
+
54
+ def __init__(self, config):
55
+ LlamaForCausalLM.__init__(self, config)
56
+
57
+ # configure default generation settings
58
+ config.model_type = "llava_llama"
59
+ # config.rope_scaling = None
60
+
61
+ self.model = LlavaLlamaModel(config)
62
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
63
+ # Initialize weights and apply final processing
64
+ self.post_init()
65
+
66
+ def get_model(self):
67
+ return self.model
68
+
69
+ def forward(
70
+ self,
71
+ input_ids: torch.LongTensor = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ position_ids: Optional[torch.LongTensor] = None,
74
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
75
+ inputs_embeds: Optional[torch.FloatTensor] = None,
76
+ labels: Optional[torch.LongTensor] = None,
77
+ use_cache: Optional[bool] = None,
78
+ output_attentions: Optional[bool] = None,
79
+ output_hidden_states: Optional[bool] = None,
80
+ images: Optional[torch.FloatTensor] = None,
81
+ image_sizes: Optional[List[List[int]]] = None,
82
+ return_dict: Optional[bool] = None,
83
+ modalities: Optional[List[str]] = ["image"],
84
+ dpo_forward: Optional[bool] = None,
85
+ cache_position=None,
86
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
87
+
88
+ if inputs_embeds is None:
89
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
90
+
91
+ if dpo_forward:
92
+ outputs = self.model(
93
+ input_ids=input_ids,
94
+ attention_mask=attention_mask,
95
+ position_ids=position_ids,
96
+ past_key_values=past_key_values,
97
+ inputs_embeds=inputs_embeds,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict,
102
+ )
103
+
104
+ hidden_states = outputs[0]
105
+ logits = self.lm_head(hidden_states)
106
+ return logits, labels
107
+
108
+ else:
109
+ return super().forward(
110
+ input_ids=input_ids,
111
+ attention_mask=attention_mask,
112
+ position_ids=position_ids,
113
+ past_key_values=past_key_values,
114
+ inputs_embeds=inputs_embeds,
115
+ labels=labels,
116
+ use_cache=use_cache,
117
+ output_attentions=output_attentions,
118
+ output_hidden_states=output_hidden_states,
119
+ return_dict=return_dict,
120
+ )
121
+
122
+ @torch.no_grad()
123
+ def generate(
124
+ self,
125
+ inputs: Optional[torch.Tensor] = None,
126
+ images: Optional[torch.Tensor] = None,
127
+ image_sizes: Optional[torch.Tensor] = None,
128
+ modalities: Optional[List[str]] = ["image"],
129
+ **kwargs,
130
+ ) -> Union[GenerateOutput, torch.LongTensor]:
131
+ modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities
132
+ position_ids = kwargs.pop("position_ids", None)
133
+ attention_mask = kwargs.pop("attention_mask", None)
134
+ if "inputs_embeds" in kwargs:
135
+ raise NotImplementedError("`inputs_embeds` is not supported")
136
+
137
+ if images is not None:
138
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
139
+ else:
140
+ inputs_embeds = self.get_model().embed_tokens(inputs)
141
+
142
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
143
+
144
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
145
+ images = kwargs.pop("images", None)
146
+ image_sizes = kwargs.pop("image_sizes", None)
147
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
148
+ if images is not None:
149
+ inputs["images"] = images
150
+ if image_sizes is not None:
151
+ inputs["image_sizes"] = image_sizes
152
+ return inputs
153
+
154
+
155
+ AutoConfig.register("llava_llama", LlavaConfig)
156
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
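Continuing from the sketches above (a model/tokenizer/image_processor from builder.py and a tokenized prompt from mm_utils.py), a rough sketch of calling the wrapper's generate(): it consumes images/image_sizes/modalities itself, builds the merged inputs_embeds via prepare_inputs_labels_for_multimodal, and forwards only those embeddings to the underlying LlamaForCausalLM:

import torch

from llava.mm_utils import process_images

image_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)
input_ids = input_ids.unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image.size],   # original (width, height) of the raw image
        modalities=["image"],
        do_sample=False,
        max_new_tokens=128,
    )
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])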
llava/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaMistralConfig(MistralConfig):
31
+ model_type = "llava_mistral"
32
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
33
+ max_new_tokens: int = 1024
34
+ do_sample: bool = False
35
+ top_p: Optional[float] = None
36
+
37
+
38
+ class LlavaMistralModel(LlavaMetaModel, MistralModel):
39
+ config_class = LlavaMistralConfig
40
+
41
+ def __init__(self, config: MistralConfig):
42
+ super(LlavaMistralModel, self).__init__(config)
43
+
44
+
45
+ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
46
+ config_class = LlavaMistralConfig
47
+
48
+ def __init__(self, config):
49
+ super(MistralForCausalLM, self).__init__(config)
50
+
51
+ config.model_type = "llava_mistral"
52
+ config.rope_scaling = None
53
+
54
+ self.model = LlavaMistralModel(config)
55
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
56
+ # Initialize weights and apply final processing
57
+ self.post_init()
58
+
59
+ def get_model(self):
60
+ return self.model
61
+
62
+ def forward(
63
+ self,
64
+ input_ids: torch.LongTensor = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ position_ids: Optional[torch.LongTensor] = None,
67
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
68
+ inputs_embeds: Optional[torch.FloatTensor] = None,
69
+ labels: Optional[torch.LongTensor] = None,
70
+ use_cache: Optional[bool] = None,
71
+ output_attentions: Optional[bool] = None,
72
+ output_hidden_states: Optional[bool] = None,
73
+ images: Optional[torch.FloatTensor] = None,
74
+ image_sizes: Optional[List[List[int]]] = None,
75
+ return_dict: Optional[bool] = None,
76
+ cache_position=None,
77
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
78
+
79
+ if inputs_embeds is None:
80
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
81
+
82
+ return super().forward(
83
+ input_ids=input_ids,
84
+ attention_mask=attention_mask,
85
+ position_ids=position_ids,
86
+ past_key_values=past_key_values,
87
+ inputs_embeds=inputs_embeds,
88
+ labels=labels,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ output_hidden_states=output_hidden_states,
92
+ return_dict=return_dict,
93
+ )
94
+
95
+ @torch.no_grad()
96
+ def generate(
97
+ self,
98
+ inputs: Optional[torch.Tensor] = None,
99
+ images: Optional[torch.Tensor] = None,
100
+ image_sizes: Optional[torch.Tensor] = None,
101
+ **kwargs,
102
+ ) -> Union[GenerateOutput, torch.LongTensor]:
103
+ position_ids = kwargs.pop("position_ids", None)
104
+ attention_mask = kwargs.pop("attention_mask", None)
105
+ if "inputs_embeds" in kwargs:
106
+ raise NotImplementedError("`inputs_embeds` is not supported")
107
+
108
+ if images is not None:
109
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
110
+ else:
111
+ inputs_embeds = self.get_model().embed_tokens(inputs)
112
+
113
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
114
+
115
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
116
+ images = kwargs.pop("images", None)
117
+ image_sizes = kwargs.pop("image_sizes", None)
118
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
119
+ if images is not None:
120
+ inputs["images"] = images
121
+ if image_sizes is not None:
122
+ inputs["image_sizes"] = image_sizes
123
+ return inputs
124
+
125
+
126
+ AutoConfig.register("llava_mistral", LlavaMistralConfig)
127
+ AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
llava/model/language_model/llava_mixtral.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaMixtralConfig(MixtralConfig):
31
+ model_type = "llava_mixtral"
32
+
33
+
34
+ class LlavaMixtralModel(LlavaMetaModel, MixtralModel):
35
+ config_class = LlavaMixtralConfig
36
+
37
+ def __init__(self, config: MixtralConfig):
38
+ super(LlavaMixtralModel, self).__init__(config)
39
+
40
+
41
+ class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaMixtralConfig
43
+
44
+ def __init__(self, config):
45
+ super(MixtralForCausalLM, self).__init__(config)
46
+
47
+ config.model_type = "llava_mixtral"
48
+ config.rope_scaling = None
49
+ self.model = LlavaMixtralModel(config)
50
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ modalities: Optional[List[str]] = ["image"],
72
+ dpo_forward: Optional[bool] = None,
73
+ cache_position=None,
74
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
75
+
76
+ if inputs_embeds is None:
77
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
78
+
79
+ if dpo_forward:
80
+ outputs = self.model(
81
+ input_ids=input_ids,
82
+ attention_mask=attention_mask,
83
+ position_ids=position_ids,
84
+ past_key_values=past_key_values,
85
+ inputs_embeds=inputs_embeds,
86
+ use_cache=use_cache,
87
+ output_attentions=output_attentions,
88
+ output_hidden_states=output_hidden_states,
89
+ return_dict=return_dict,
90
+ )
91
+
92
+ hidden_states = outputs[0]
93
+ logits = self.lm_head(hidden_states)
94
+ return logits, labels
95
+
96
+ else:
97
+ return super().forward(
98
+ input_ids=input_ids,
99
+ attention_mask=attention_mask,
100
+ position_ids=position_ids,
101
+ past_key_values=past_key_values,
102
+ inputs_embeds=inputs_embeds,
103
+ labels=labels,
104
+ use_cache=use_cache,
105
+ output_attentions=output_attentions,
106
+ output_hidden_states=output_hidden_states,
107
+ return_dict=return_dict,
108
+ )
109
+
110
+ @torch.no_grad()
111
+ def generate(
112
+ self,
113
+ inputs: Optional[torch.Tensor] = None,
114
+ images: Optional[torch.Tensor] = None,
115
+ image_sizes: Optional[torch.Tensor] = None,
116
+ modalities: Optional[List[str]] = ["image"],
117
+ **kwargs,
118
+ ) -> Union[GenerateOutput, torch.LongTensor]:
119
+ position_ids = kwargs.pop("position_ids", None)
120
+ attention_mask = kwargs.pop("attention_mask", None)
121
+ if "inputs_embeds" in kwargs:
122
+ raise NotImplementedError("`inputs_embeds` is not supported")
123
+
124
+ if images is not None:
125
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
126
+ else:
127
+ inputs_embeds = self.get_model().embed_tokens(inputs)
128
+
129
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
130
+
131
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
132
+ images = kwargs.pop("images", None)
133
+ image_sizes = kwargs.pop("image_sizes", None)
134
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
135
+ if images is not None:
136
+ inputs["images"] = images
137
+ if image_sizes is not None:
138
+ inputs["image_sizes"] = image_sizes
139
+ return inputs
140
+
141
+
142
+ AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
143
+ AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig
21
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
22
+
23
+
24
+ class LlavaMptConfig(MptConfig):
25
+ model_type = "llava_mpt"
26
+
27
+
28
+ class LlavaMptModel(LlavaMetaModel, MptModel):
29
+ config_class = LlavaMptConfig
30
+
31
+ def __init__(self, config: MptConfig):
32
+ config.hidden_size = config.d_model
33
+ super(LlavaMptModel, self).__init__(config)
34
+
35
+ def embed_tokens(self, x):
36
+ return self.wte(x)
37
+
38
+
39
+ class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
40
+ config_class = LlavaMptConfig
41
+ supports_gradient_checkpointing = True
42
+
43
+ def __init__(self, config):
44
+ super(MptForCausalLM, self).__init__(config)
45
+
46
+ config.model_type = "llava_mpt"
47
+ config.rope_scaling = None
48
+ self.generation_config = GenerationConfig(
49
+ temperature=0.0,
50
+ max_new_tokens=1024,
51
+ do_sample=False,
52
+ top_p=None,
53
+ )
54
+
55
+ self.transformer = LlavaMptModel(config)
56
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+
58
+ # Initialize weights and apply final processing
59
+ self.post_init()
60
+
61
+ def get_model(self):
62
+ return self.transformer
63
+
64
+ def _set_gradient_checkpointing(self, module, value=False):
65
+ if isinstance(module, LlavaMptModel):
66
+ module.gradient_checkpointing = value
67
+
68
+ def forward(
69
+ self,
70
+ input_ids: Optional[torch.LongTensor] = None,
71
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ inputs_embeds: Optional[torch.Tensor] = None,
74
+ labels: Optional[torch.Tensor] = None,
75
+ use_cache: Optional[bool] = None,
76
+ output_attentions: Optional[bool] = None,
77
+ output_hidden_states: Optional[bool] = None,
78
+ return_dict: Optional[bool] = None,
79
+ cache_position=None,
80
+ images=None,
81
+ ):
82
+
83
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
84
+
85
+ return super().forward(
86
+ input_ids,
87
+ past_key_values=past_key_values,
88
+ attention_mask=attention_mask,
89
+ inputs_embeds=inputs_embeds,
90
+ labels=labels,
91
+ use_cache=use_cache,
92
+ output_attentions=output_attentions,
93
+ output_hidden_states=output_hidden_states,
94
+ return_dict=return_dict,
95
+ )
96
+
97
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
98
+ images = kwargs.pop("images", None)
99
+ _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
100
+ _inputs["images"] = images
101
+ return _inputs
102
+
103
+
104
+ AutoConfig.register("llava_mpt", LlavaMptConfig)
105
+ AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
llava/model/language_model/llava_qwen.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 Hao Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import CrossEntropyLoss
20
+
21
+ import transformers
22
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ # from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
28
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
30
+
31
+ # from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
32
+ # from .qwen.configuration_qwen import QWenConfig
33
+
34
+
35
+ class LlavaQwenConfig(Qwen2Config):
36
+ model_type = "llava_qwen"
37
+
38
+
39
+ class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
40
+ config_class = LlavaQwenConfig
41
+
42
+ def __init__(self, config: Qwen2Config):
43
+ super(LlavaQwenModel, self).__init__(config)
44
+
45
+
46
+ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
47
+ config_class = LlavaQwenConfig
48
+
49
+ def __init__(self, config):
50
+ # super(Qwen2ForCausalLM, self).__init__(config)
51
+ Qwen2ForCausalLM.__init__(self, config)
52
+ config.model_type = "llava_qwen"
53
+ config.rope_scaling = None
54
+
55
+ self.model = LlavaQwenModel(config)
56
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
59
+
60
+ def get_model(self):
61
+ return self.model
62
+
63
+ def forward(
64
+ self,
65
+ input_ids: torch.LongTensor = None,
66
+ attention_mask: Optional[torch.Tensor] = None,
67
+ position_ids: Optional[torch.LongTensor] = None,
68
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
69
+ inputs_embeds: Optional[torch.FloatTensor] = None,
70
+ labels: Optional[torch.LongTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ images: Optional[torch.FloatTensor] = None,
75
+ image_sizes: Optional[List[List[int]]] = None,
76
+ return_dict: Optional[bool] = None,
77
+ modalities: Optional[List[str]] = ["image"],
78
+ dpo_forward: Optional[bool] = False,
79
+ cache_position=None,
80
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
81
+
82
+ if inputs_embeds is None:
83
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
84
+
85
+ if dpo_forward:
86
+ outputs = self.model(
87
+ input_ids=input_ids,
88
+ attention_mask=attention_mask,
89
+ position_ids=position_ids,
90
+ past_key_values=past_key_values,
91
+ inputs_embeds=inputs_embeds,
92
+ use_cache=use_cache,
93
+ output_attentions=output_attentions,
94
+ output_hidden_states=output_hidden_states,
95
+ return_dict=return_dict,
96
+ )
97
+
98
+ hidden_states = outputs[0]
99
+ logits = self.lm_head(hidden_states)
100
+ return logits, labels
101
+
102
+ else:
103
+ return super().forward(
104
+ input_ids=input_ids,
105
+ attention_mask=attention_mask,
106
+ position_ids=position_ids,
107
+ past_key_values=past_key_values,
108
+ inputs_embeds=inputs_embeds,
109
+ labels=labels,
110
+ use_cache=use_cache,
111
+ output_attentions=output_attentions,
112
+ output_hidden_states=output_hidden_states,
113
+ return_dict=return_dict,
114
+ )
115
+
116
+ @torch.no_grad()
117
+ def generate(
118
+ self,
119
+ inputs: Optional[torch.Tensor] = None,
120
+ images: Optional[torch.Tensor] = None,
121
+ image_sizes: Optional[torch.Tensor] = None,
122
+ modalities: Optional[List[str]] = ["image"],
123
+ **kwargs,
124
+ ) -> Union[GenerateOutput, torch.LongTensor]:
125
+ position_ids = kwargs.pop("position_ids", None)
126
+ attention_mask = kwargs.pop("attention_mask", None)
127
+ if "inputs_embeds" in kwargs:
128
+ raise NotImplementedError("`inputs_embeds` is not supported")
129
+
130
+ if images is not None:
131
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
132
+ else:
133
+ inputs_embeds = self.get_model().embed_tokens(inputs)
134
+
135
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
136
+
137
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
138
+ images = kwargs.pop("images", None)
139
+ image_sizes = kwargs.pop("image_sizes", None)
140
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
141
+ if images is not None:
142
+ inputs["images"] = images
143
+ if image_sizes is not None:
144
+ inputs["image_sizes"] = image_sizes
145
+ return inputs
146
+
147
+
148
+ AutoConfig.register("llava_qwen", LlavaQwenConfig)
149
+ AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)
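A minimal generation sketch for the class above (illustrative only; it assumes a local `LlavaQwenForCausalLM` checkpoint and already-preprocessed pixel values of the shape the vision tower expects — the path, token ids, and image size are placeholders):

import torch
from llava.model.language_model.llava_qwen import LlavaQwenForCausalLM

model = LlavaQwenForCausalLM.from_pretrained("path/to/llava-qwen-checkpoint").eval()  # placeholder path
input_ids = torch.tensor([[1, 2, 3]])         # placeholder prompt ids (image token already inserted)
images = torch.randn(1, 3, 336, 336)          # placeholder preprocessed pixel values
with torch.inference_mode():
    output_ids = model.generate(input_ids, images=images, image_sizes=[[336, 336]], max_new_tokens=16)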
llava/model/language_model/llava_qwen_moe.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 Hao Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import CrossEntropyLoss
20
+
21
+ import transformers
22
+ from transformers import AutoConfig, AutoModelForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ # from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
28
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+ from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM
30
+
31
+ # from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
32
+ # from .qwen.configuration_qwen import QWenConfig
33
+
34
+
35
+ class LlavaQwenMoeConfig(Qwen2MoeConfig):
36
+ model_type = "llava_qwen_moe"
37
+
38
+
39
+ class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel):
40
+ config_class = LlavaQwenMoeConfig
41
+
42
+ def __init__(self, config: Qwen2MoeConfig):
43
+ super(LlavaQwenMoeModel, self).__init__(config)
44
+
45
+
46
+ class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM):
47
+ config_class = LlavaQwenMoeConfig
48
+
49
+ def __init__(self, config):
50
+ # super(Qwen2MoeForCausalLM, self).__init__(config)
51
+ Qwen2MoeForCausalLM.__init__(self, config)
52
+ config.model_type = "llava_qwen_moe"
53
+ config.rope_scaling = None
54
+
55
+ self.model = LlavaQwenMoeModel(config)
56
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
59
+
60
+ def get_model(self):
61
+ return self.model
62
+
63
+ def forward(
64
+ self,
65
+ input_ids: torch.LongTensor = None,
66
+ attention_mask: Optional[torch.Tensor] = None,
67
+ position_ids: Optional[torch.LongTensor] = None,
68
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
69
+ inputs_embeds: Optional[torch.FloatTensor] = None,
70
+ labels: Optional[torch.LongTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ images: Optional[torch.FloatTensor] = None,
75
+ image_sizes: Optional[List[List[int]]] = None,
76
+ return_dict: Optional[bool] = None,
77
+ modalities: Optional[List[str]] = ["image"],
78
+ dpo_forward: Optional[bool] = False,
79
+ cache_position=None,
80
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
81
+
82
+ if inputs_embeds is None:
83
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
84
+
85
+ if dpo_forward:
86
+ outputs = self.model(
87
+ input_ids=input_ids,
88
+ attention_mask=attention_mask,
89
+ position_ids=position_ids,
90
+ past_key_values=past_key_values,
91
+ inputs_embeds=inputs_embeds,
92
+ use_cache=use_cache,
93
+ output_attentions=output_attentions,
94
+ output_hidden_states=output_hidden_states,
95
+ return_dict=return_dict,
96
+ )
97
+
98
+ hidden_states = outputs[0]
99
+ logits = self.lm_head(hidden_states)
100
+ return logits, labels
101
+
102
+ else:
103
+ return super().forward(
104
+ input_ids=input_ids,
105
+ attention_mask=attention_mask,
106
+ position_ids=position_ids,
107
+ past_key_values=past_key_values,
108
+ inputs_embeds=inputs_embeds,
109
+ labels=labels,
110
+ use_cache=use_cache,
111
+ output_attentions=output_attentions,
112
+ output_hidden_states=output_hidden_states,
113
+ return_dict=return_dict,
114
+ )
115
+
116
+ @torch.no_grad()
117
+ def generate(
118
+ self,
119
+ inputs: Optional[torch.Tensor] = None,
120
+ images: Optional[torch.Tensor] = None,
121
+ image_sizes: Optional[torch.Tensor] = None,
122
+ modalities: Optional[List[str]] = ["image"],
123
+ **kwargs,
124
+ ) -> Union[GenerateOutput, torch.LongTensor]:
125
+ position_ids = kwargs.pop("position_ids", None)
126
+ attention_mask = kwargs.pop("attention_mask", None)
127
+ if "inputs_embeds" in kwargs:
128
+ raise NotImplementedError("`inputs_embeds` is not supported")
129
+
130
+ if images is not None:
131
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
132
+ else:
133
+ inputs_embeds = self.get_model().embed_tokens(inputs)
134
+
135
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
136
+
137
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
138
+ images = kwargs.pop("images", None)
139
+ image_sizes = kwargs.pop("image_sizes", None)
140
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
141
+ if images is not None:
142
+ inputs["images"] = images
143
+ if image_sizes is not None:
144
+ inputs["image_sizes"] = image_sizes
145
+ return inputs
146
+
147
+
148
+ AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig)
149
+ AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM)
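A small sanity-check sketch (not part of the file): once this module has been imported, the `llava_qwen_moe` model type resolves through the Transformers config registry.

import llava.model.language_model.llava_qwen_moe  # runs the register() calls above
from transformers import AutoConfig

cfg = AutoConfig.for_model("llava_qwen_moe")
assert type(cfg).__name__ == "LlavaQwenMoeConfig"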
llava/model/language_model/modeling_llama.py ADDED
@@ -0,0 +1,1649 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPast,
35
+ CausalLMOutputWithPast,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutputWithPast,
38
+ )
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
41
+ from transformers.utils import (
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10,
46
+ logging,
47
+ replace_return_docstrings,
48
+ )
49
+ from transformers.models.llama.configuration_llama import LlamaConfig
50
+
51
+ if is_flash_attn_2_available():
52
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
53
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CONFIG_FOR_DOC = "LlamaConfig"
59
+
60
+
61
+ def _get_unpad_data(attention_mask):
62
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
63
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
64
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
65
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
66
+ return (
67
+ indices,
68
+ cu_seqlens,
69
+ max_seqlen_in_batch,
70
+ )
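A tiny worked example of the helper above (a sketch, not part of this file; it relies on the module-level `torch`/`F` imports): for a padding mask with 2 and 3 valid tokens per row, it returns the flat indices of the non-padded positions, cumulative sequence lengths `[0, 2, 5]`, and the longest row length `3`.

mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)
# indices -> tensor([0, 1, 3, 4, 5]); cu_seqlens -> tensor([0, 2, 5], dtype=torch.int32); max_len -> 3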
71
+
72
+
73
+ class LlamaRMSNorm(nn.Module):
74
+ def __init__(self, hidden_size, eps=1e-6):
75
+ """
76
+ LlamaRMSNorm is equivalent to T5LayerNorm
77
+ """
78
+ super().__init__()
79
+ self.weight = nn.Parameter(torch.ones(hidden_size))
80
+ self.variance_epsilon = eps
81
+
82
+ def forward(self, hidden_states):
83
+ input_dtype = hidden_states.dtype
84
+ hidden_states = hidden_states.to(torch.float32)
85
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
86
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
87
+ return self.weight * hidden_states.to(input_dtype)
88
+
89
+
90
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
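A quick illustrative check of the normalization above (sketch only, assumes the module imports): a freshly constructed layer has unit weights, so the per-token RMS of its output is approximately 1 (up to `eps`).

norm = LlamaRMSNorm(hidden_size=8)
x = torch.randn(2, 3, 8)
y = norm(x)
print(y.pow(2).mean(-1).sqrt())   # every entry is close to 1.0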
91
+
92
+
93
+ class LlamaRotaryEmbedding(nn.Module):
94
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
95
+ super().__init__()
96
+ self.scaling_factor = scaling_factor
97
+ self.dim = dim
98
+ self.max_position_embeddings = max_position_embeddings
99
+ self.base = base
100
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
101
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
102
+ # For BC we register cos and sin cached
103
+ self.max_seq_len_cached = max_position_embeddings
104
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
105
+ t = t / self.scaling_factor
106
+ freqs = torch.outer(t, self.inv_freq)
107
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
108
+ emb = torch.cat((freqs, freqs), dim=-1)
109
+ self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
110
+ self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
111
+
112
+ @property
113
+ def sin_cached(self):
114
+ logger.warning_once("The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class")
115
+ return self._sin_cached
116
+
117
+ @property
118
+ def cos_cached(self):
119
+ logger.warning_once("The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class")
120
+ return self._cos_cached
121
+
122
+ @torch.no_grad()
123
+ def forward(self, x, position_ids, seq_len=None):
124
+ if seq_len is not None:
125
+ logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
126
+
127
+ # x: [bs, num_attention_heads, seq_len, head_size]
128
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
129
+ position_ids_expanded = position_ids[:, None, :].float()
130
+ # Force float32 since bfloat16 loses precision on long contexts
131
+ # See https://github.com/huggingface/transformers/pull/29285
132
+ device_type = x.device.type
133
+ device_type = device_type if isinstance(device_type, str) else "cpu"
134
+ with torch.autocast(device_type=device_type, enabled=False):
135
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
136
+ emb = torch.cat((freqs, freqs), dim=-1)
137
+ cos = emb.cos()
138
+ sin = emb.sin()
139
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
140
+
141
+
142
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
143
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
144
+
145
+ def forward(self, x, position_ids, seq_len=None):
146
+ # difference to the original RoPE: a scaling factor is applied to the position ids
147
+ position_ids = position_ids.float() / self.scaling_factor
148
+ cos, sin = super().forward(x, position_ids, seq_len)
149
+ return cos, sin
150
+
151
+
152
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
153
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
154
+
155
+ def forward(self, x, position_ids, seq_len=None):
156
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
157
+ seq_len = torch.max(position_ids) + 1
158
+ if seq_len > self.max_position_embeddings:
159
+ base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
160
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim))
161
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
162
+
163
+ cos, sin = super().forward(x, position_ids, seq_len)
164
+ return cos, sin
165
+
166
+
167
+ def rotate_half(x):
168
+ """Rotates half the hidden dims of the input."""
169
+ x1 = x[..., : x.shape[-1] // 2]
170
+ x2 = x[..., x.shape[-1] // 2 :]
171
+ return torch.cat((-x2, x1), dim=-1)
172
+
173
+
174
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
175
+ """Applies Rotary Position Embedding to the query and key tensors.
176
+
177
+ Args:
178
+ q (`torch.Tensor`): The query tensor.
179
+ k (`torch.Tensor`): The key tensor.
180
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
181
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
182
+ position_ids (`torch.Tensor`, *optional*):
183
+ Deprecated and unused.
184
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
185
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
186
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
187
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
188
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
189
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
190
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
191
+ Returns:
192
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
193
+ """
194
+ cos = cos.unsqueeze(unsqueeze_dim)
195
+ sin = sin.unsqueeze(unsqueeze_dim)
196
+ q_embed = (q * cos) + (rotate_half(q) * sin)
197
+ k_embed = (k * cos) + (rotate_half(k) * sin)
198
+ return q_embed, k_embed
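A shape/property sketch for the two helpers above (illustrative only; the sizes are arbitrary): RoPE rotates pairs of channels by a position-dependent angle, so it changes the direction of each query/key vector but never its norm.

rope = LlamaRotaryEmbedding(dim=64)
q = torch.randn(1, 8, 5, 64)               # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 5, 64)
position_ids = torch.arange(5)[None, :]    # (batch, seq_len)
cos, sin = rope(q, position_ids)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4))   # True: RoPE is a pure rotation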
199
+
200
+
201
+ class LlamaMLP(nn.Module):
202
+ def __init__(self, config):
203
+ super().__init__()
204
+ self.config = config
205
+ self.hidden_size = config.hidden_size
206
+ self.intermediate_size = config.intermediate_size
207
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
208
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
209
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
210
+ self.act_fn = ACT2FN[config.hidden_act]
211
+
212
+ def forward(self, x):
213
+ if self.config.pretraining_tp > 1:
214
+ slice = self.intermediate_size // self.config.pretraining_tp
215
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
216
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
217
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
218
+
219
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
220
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
221
+
222
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
223
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)]
224
+ down_proj = sum(down_proj)
225
+ else:
226
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
227
+
228
+ return down_proj
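For the common case (`pretraining_tp == 1`) the block above is the SwiGLU feed-forward, `down_proj(act(gate_proj(x)) * up_proj(x))`. A shape sketch (illustrative only; the tiny config values are placeholders):

cfg = LlamaConfig(hidden_size=64, intermediate_size=128, hidden_act="silu")
mlp = LlamaMLP(cfg)
out = mlp(torch.randn(2, 5, 64))
print(out.shape)                  # torch.Size([2, 5, 64])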
229
+
230
+
231
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
232
+ """
233
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
234
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
235
+ """
236
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
237
+ if n_rep == 1:
238
+ return hidden_states
239
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
240
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
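A quick shape example for `repeat_kv` (illustrative only): with 2 key/value heads repeated 4 times, the head axis grows from 2 to 8 while batch, sequence length, and head dim are unchanged.

kv = torch.randn(1, 2, 5, 64)     # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=4)
print(expanded.shape)             # torch.Size([1, 8, 5, 64])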
241
+
242
+
243
+ class LlamaAttention(nn.Module):
244
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
245
+
246
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
247
+ super().__init__()
248
+ self.config = config
249
+ self.layer_idx = layer_idx
250
+ if layer_idx is None:
251
+ logger.warning_once(
252
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
253
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
254
+ "when creating this class."
255
+ )
256
+
257
+ self.attention_dropout = config.attention_dropout
258
+ self.hidden_size = config.hidden_size
259
+ self.num_heads = config.num_attention_heads
260
+ self.head_dim = self.hidden_size // self.num_heads
261
+ self.num_key_value_heads = config.num_key_value_heads
262
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
263
+ self.max_position_embeddings = config.max_position_embeddings
264
+ self.rope_theta = config.rope_theta
265
+ self.is_causal = True
266
+
267
+ if (self.head_dim * self.num_heads) != self.hidden_size:
268
+ raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads}).")
269
+
270
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
271
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
272
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
273
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
274
+ self._init_rope()
275
+
276
+ def _init_rope(self):
277
+ if self.config.rope_scaling is None:
278
+ self.rotary_emb = LlamaRotaryEmbedding(
279
+ self.head_dim,
280
+ max_position_embeddings=self.max_position_embeddings,
281
+ base=self.rope_theta,
282
+ )
283
+ else:
284
+ scaling_type = self.config.rope_scaling["type"]
285
+ scaling_factor = self.config.rope_scaling["factor"]
286
+ if scaling_type == "linear":
287
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
288
+ self.head_dim,
289
+ max_position_embeddings=self.max_position_embeddings,
290
+ scaling_factor=scaling_factor,
291
+ base=self.rope_theta,
292
+ )
293
+ elif scaling_type == "dynamic":
294
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
295
+ self.head_dim,
296
+ max_position_embeddings=self.max_position_embeddings,
297
+ scaling_factor=scaling_factor,
298
+ base=self.rope_theta,
299
+ )
300
+ else:
301
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
302
+
303
+ def forward(
304
+ self,
305
+ hidden_states: torch.Tensor,
306
+ attention_mask: Optional[torch.Tensor] = None,
307
+ position_ids: Optional[torch.LongTensor] = None,
308
+ past_key_value: Optional[Cache] = None,
309
+ output_attentions: bool = False,
310
+ use_cache: bool = False,
311
+ cache_position: Optional[torch.LongTensor] = None,
312
+ **kwargs,
313
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
314
+ bsz, q_len, _ = hidden_states.size()
315
+
316
+ if self.config.pretraining_tp > 1:
317
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
318
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
319
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
320
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
321
+
322
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
323
+ query_states = torch.cat(query_states, dim=-1)
324
+
325
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
326
+ key_states = torch.cat(key_states, dim=-1)
327
+
328
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
329
+ value_states = torch.cat(value_states, dim=-1)
330
+
331
+ else:
332
+ query_states = self.q_proj(hidden_states)
333
+ key_states = self.k_proj(hidden_states)
334
+ value_states = self.v_proj(hidden_states)
335
+
336
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
337
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
338
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
339
+
340
+ past_key_value = getattr(self, "past_key_value", past_key_value)
341
+ cos, sin = self.rotary_emb(value_states, position_ids)
342
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
343
+
344
+ if past_key_value is not None:
345
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
346
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
347
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
348
+
349
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
350
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
351
+
352
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
353
+
354
+ if attention_mask is not None: # no matter the length, we just slice it
355
+ causal_mask = attention_mask
356
+ if cache_position is not None:
357
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
358
+ attn_weights = attn_weights + causal_mask
359
+
360
+ # upcast attention to fp32
361
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
362
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
363
+ attn_output = torch.matmul(attn_weights, value_states)
364
+
365
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
366
+ raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
367
+
368
+ attn_output = attn_output.transpose(1, 2).contiguous()
369
+
370
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
371
+
372
+ if self.config.pretraining_tp > 1:
373
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
374
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
375
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
376
+ else:
377
+ attn_output = self.o_proj(attn_output)
378
+
379
+ if not output_attentions:
380
+ attn_weights = None
381
+
382
+ return attn_output, attn_weights, past_key_value
383
+
384
+
385
+ class LlamaRingFlashAttention2(LlamaAttention):
386
+ """
387
+ Llama ring flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
388
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
389
+ flash attention and deal with padding tokens in case the input contains any of them.
390
+ """
391
+
392
+ def __init__(self, *args, **kwargs):
393
+ super().__init__(*args, **kwargs)
394
+
395
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
396
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
397
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
398
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states: torch.Tensor,
403
+ attention_mask: Optional[torch.LongTensor] = None,
404
+ position_ids: Optional[torch.LongTensor] = None,
405
+ past_key_value: Optional[Cache] = None,
406
+ output_attentions: bool = False,
407
+ use_cache: bool = False,
408
+ cache_position: Optional[torch.LongTensor] = None,
409
+ **kwargs,
410
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
411
+ output_attentions = False
412
+
413
+ bsz, q_len, _ = hidden_states.size()
414
+
415
+ query_states = self.q_proj(hidden_states)
416
+ key_states = self.k_proj(hidden_states)
417
+ value_states = self.v_proj(hidden_states)
418
+
419
+ # Flash attention requires the input to have the shape
420
+ # batch_size x seq_length x num_heads x head_dim
421
+ # therefore we just need to keep the original shape
422
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
423
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
424
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
425
+
426
+ cos, sin = self.rotary_emb(value_states, position_ids)
427
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
428
+
429
+ past_key_value = getattr(self, "past_key_value", past_key_value)
430
+
431
+ if past_key_value is not None:
432
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
433
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
434
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
435
+
436
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
437
+ # to be able to avoid many of these transpose/reshape/view.
438
+ query_states = query_states.transpose(1, 2)
439
+ key_states = key_states.transpose(1, 2)
440
+ value_states = value_states.transpose(1, 2)
441
+
442
+ dropout_rate = self.attention_dropout if self.training else 0.0
443
+
444
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
445
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
446
+ # cast them back to the correct dtype just to be sure everything works as expected.
447
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
448
+ # in fp32. (LlamaRMSNorm handles it correctly)
449
+
450
+ input_dtype = query_states.dtype
451
+ if input_dtype == torch.float32:
452
+ if torch.is_autocast_enabled():
453
+ target_dtype = torch.get_autocast_gpu_dtype()
454
+ # Handle the case where the model is quantized
455
+ elif hasattr(self.config, "_pre_quantization_dtype"):
456
+ target_dtype = self.config._pre_quantization_dtype
457
+ else:
458
+ target_dtype = self.q_proj.weight.dtype
459
+
460
+ logger.warning_once(
461
+ f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}."
462
+ )
463
+
464
+ query_states = query_states.to(target_dtype)
465
+ key_states = key_states.to(target_dtype)
466
+ value_states = value_states.to(target_dtype)
467
+
468
+ attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
469
+
470
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
471
+ attn_output = self.o_proj(attn_output)
472
+
473
+ if not output_attentions:
474
+ attn_weights = None
475
+
476
+ return attn_output, attn_weights, past_key_value
477
+
478
+ def _flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None):
479
+ """
480
+ Calls the zigzag ring Flash Attention kernels - if the input hidden states contain at least one padding token,
481
+ it first unpads the input, then computes the attention scores, and finally pads the attention output back to the original shape.
482
+
483
+ Args:
484
+ query_states (`torch.Tensor`):
485
+ Input query states to be passed to Flash Attention API
486
+ key_states (`torch.Tensor`):
487
+ Input key states to be passed to Flash Attention API
488
+ value_states (`torch.Tensor`):
489
+ Input value states to be passed to Flash Attention API
490
+ attention_mask (`torch.Tensor`):
491
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
492
+ position of padding tokens and 1 for the position of non-padding tokens.
493
+ dropout (`float`, *optional*):
494
+ Attention dropout
495
+ softmax_scale (`float`, *optional*):
496
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
497
+ """
498
+ if not self._flash_attn_uses_top_left_mask:
499
+ causal = self.is_causal
500
+ else:
501
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
502
+ causal = self.is_causal and query_length != 1
503
+
504
+ # Contains at least one padding token in the sequence
505
+ if attention_mask is not None:
506
+ batch_size = query_states.shape[0]
507
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
508
+
509
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
510
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
511
+
512
+ attn_output_unpad = zigzag_ring_flash_attn_varlen_func(
513
+ query_states,
514
+ key_states,
515
+ value_states,
516
+ cu_seqlens_q=cu_seqlens_q,
517
+ cu_seqlens_k=cu_seqlens_k,
518
+ max_seqlen_q=max_seqlen_in_batch_q,
519
+ max_seqlen_k=max_seqlen_in_batch_k,
520
+ dropout_p=dropout,
521
+ softmax_scale=softmax_scale,
522
+ causal=causal,
523
+ )
524
+
525
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
526
+ else:
527
+ # pack qkv
528
+ # query_states: (batch_size, seqlen, nheads, headdim)
529
+ # qkv: (batch_size, seqlen, 3, nheads, headdim)
530
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
531
+ attn_output = zigzag_ring_flash_attn_qkvpacked_func(qkv, dropout, softmax_scale, causal=causal)
532
+
533
+ return attn_output
534
+
535
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
536
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
537
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
538
+
539
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
540
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
541
+ if query_length == kv_seq_len:
542
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)
543
+ cu_seqlens_q = cu_seqlens_k
544
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
545
+ indices_q = indices_k
546
+ elif query_length == 1:
547
+ max_seqlen_in_batch_q = 1
548
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device) # There is a memcpy here, which is very bad.
549
+ indices_q = cu_seqlens_q[:-1]
550
+ query_layer = query_layer.squeeze(1)
551
+ else:
552
+ # The -q_len: slice assumes left padding.
553
+ attention_mask = attention_mask[:, -query_length:]
554
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
555
+
556
+ return (
557
+ query_layer,
558
+ key_layer,
559
+ value_layer,
560
+ indices_q,
561
+ (cu_seqlens_q, cu_seqlens_k),
562
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
563
+ )
564
+
565
+
566
+ class LlamaFlashAttention2(LlamaAttention):
567
+ """
568
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
569
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
570
+ flash attention and deal with padding tokens in case the input contains any of them.
571
+ """
572
+
573
+ def __init__(self, *args, **kwargs):
574
+ super().__init__(*args, **kwargs)
575
+
576
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
577
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
578
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
579
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
580
+
581
+ def forward(
582
+ self,
583
+ hidden_states: torch.Tensor,
584
+ attention_mask: Optional[torch.LongTensor] = None,
585
+ position_ids: Optional[torch.LongTensor] = None,
586
+ past_key_value: Optional[Cache] = None,
587
+ output_attentions: bool = False,
588
+ use_cache: bool = False,
589
+ cache_position: Optional[torch.LongTensor] = None,
590
+ **kwargs,
591
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
592
+ output_attentions = False
593
+
594
+ bsz, q_len, _ = hidden_states.size()
595
+
596
+ query_states = self.q_proj(hidden_states)
597
+ key_states = self.k_proj(hidden_states)
598
+ value_states = self.v_proj(hidden_states)
599
+
600
+ # Flash attention requires the input to have the shape
601
+ # batch_size x seq_length x num_heads x head_dim
602
+ # therefore we just need to keep the original shape
603
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
604
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
605
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
606
+
607
+ cos, sin = self.rotary_emb(value_states, position_ids)
608
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
609
+
610
+ past_key_value = getattr(self, "past_key_value", past_key_value)
611
+
612
+ if past_key_value is not None:
613
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
614
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
615
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
616
+
617
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
618
+ # to be able to avoid many of these transpose/reshape/view.
619
+ query_states = query_states.transpose(1, 2)
620
+ key_states = key_states.transpose(1, 2)
621
+ value_states = value_states.transpose(1, 2)
622
+
623
+ dropout_rate = self.attention_dropout if self.training else 0.0
624
+
625
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
626
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
627
+ # cast them back to the correct dtype just to be sure everything works as expected.
628
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
629
+ # in fp32. (LlamaRMSNorm handles it correctly)
630
+
631
+ input_dtype = query_states.dtype
632
+ if input_dtype == torch.float32:
633
+ if torch.is_autocast_enabled():
634
+ target_dtype = torch.get_autocast_gpu_dtype()
635
+ # Handle the case where the model is quantized
636
+ elif hasattr(self.config, "_pre_quantization_dtype"):
637
+ target_dtype = self.config._pre_quantization_dtype
638
+ else:
639
+ target_dtype = self.q_proj.weight.dtype
640
+
641
+ logger.warning_once(
642
+ f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}."
643
+ )
644
+
645
+ query_states = query_states.to(target_dtype)
646
+ key_states = key_states.to(target_dtype)
647
+ value_states = value_states.to(target_dtype)
648
+
649
+ attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
650
+
651
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
652
+ attn_output = self.o_proj(attn_output)
653
+
654
+ if not output_attentions:
655
+ attn_weights = None
656
+
657
+ return attn_output, attn_weights, past_key_value
658
+
659
+ def _flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None):
660
+ """
661
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
662
+ it first unpads the input, then computes the attention scores, and finally pads the attention output back to the original shape.
663
+
664
+ Args:
665
+ query_states (`torch.Tensor`):
666
+ Input query states to be passed to Flash Attention API
667
+ key_states (`torch.Tensor`):
668
+ Input key states to be passed to Flash Attention API
669
+ value_states (`torch.Tensor`):
670
+ Input value states to be passed to Flash Attention API
671
+ attention_mask (`torch.Tensor`):
672
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
673
+ position of padding tokens and 1 for the position of non-padding tokens.
674
+ dropout (`float`, *optional*):
675
+ Attention dropout
676
+ softmax_scale (`float`, *optional*):
677
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
678
+ """
679
+ if not self._flash_attn_uses_top_left_mask:
680
+ causal = self.is_causal
681
+ else:
682
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
683
+ causal = self.is_causal and query_length != 1
684
+
685
+ # Contains at least one padding token in the sequence
686
+ if attention_mask is not None:
687
+ batch_size = query_states.shape[0]
688
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
689
+
690
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
691
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
692
+
693
+ attn_output_unpad = flash_attn_varlen_func(
694
+ query_states,
695
+ key_states,
696
+ value_states,
697
+ cu_seqlens_q=cu_seqlens_q,
698
+ cu_seqlens_k=cu_seqlens_k,
699
+ max_seqlen_q=max_seqlen_in_batch_q,
700
+ max_seqlen_k=max_seqlen_in_batch_k,
701
+ dropout_p=dropout,
702
+ softmax_scale=softmax_scale,
703
+ causal=causal,
704
+ )
705
+
706
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
707
+ else:
708
+ attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal)
709
+
710
+ return attn_output
711
+
712
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
713
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
714
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
715
+
716
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
717
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
718
+ if query_length == kv_seq_len:
719
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)
720
+ cu_seqlens_q = cu_seqlens_k
721
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
722
+ indices_q = indices_k
723
+ elif query_length == 1:
724
+ max_seqlen_in_batch_q = 1
725
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device) # There is a memcpy here, which is very bad.
726
+ indices_q = cu_seqlens_q[:-1]
727
+ query_layer = query_layer.squeeze(1)
728
+ else:
729
+ # The -q_len: slice assumes left padding.
730
+ attention_mask = attention_mask[:, -query_length:]
731
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
732
+
733
+ return (
734
+ query_layer,
735
+ key_layer,
736
+ value_layer,
737
+ indices_q,
738
+ (cu_seqlens_q, cu_seqlens_k),
739
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
740
+ )
741
+
742
+
743
+ class LlamaSdpaAttention(LlamaAttention):
744
+ """
745
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
746
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
747
+ SDPA API.
748
+ """
749
+
750
+ # Adapted from LlamaAttention.forward
751
+ def forward(
752
+ self,
753
+ hidden_states: torch.Tensor,
754
+ attention_mask: Optional[torch.Tensor] = None,
755
+ position_ids: Optional[torch.LongTensor] = None,
756
+ past_key_value: Optional[Cache] = None,
757
+ output_attentions: bool = False,
758
+ use_cache: bool = False,
759
+ cache_position: Optional[torch.LongTensor] = None,
760
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
761
+ if output_attentions:
762
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
763
+ logger.warning_once(
764
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
765
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
766
+ )
767
+ return super().forward(
768
+ hidden_states=hidden_states,
769
+ attention_mask=attention_mask,
770
+ position_ids=position_ids,
771
+ past_key_value=past_key_value,
772
+ output_attentions=output_attentions,
773
+ use_cache=use_cache,
774
+ cache_position=cache_position,
775
+ )
776
+
777
+ bsz, q_len, _ = hidden_states.size()
778
+
779
+ query_states = self.q_proj(hidden_states)
780
+ key_states = self.k_proj(hidden_states)
781
+ value_states = self.v_proj(hidden_states)
782
+
783
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
784
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
785
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
786
+
787
+ cos, sin = self.rotary_emb(value_states, position_ids)
788
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
789
+
790
+ # In case static cache is used, it is an instance attribute.
791
+ past_key_value = getattr(self, "past_key_value", past_key_value)
792
+
793
+ if past_key_value is not None:
794
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
795
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
796
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
797
+
798
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
799
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
800
+
801
+ causal_mask = attention_mask
802
+ if attention_mask is not None and cache_position is not None:
803
+ causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
804
+
805
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
806
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
807
+ if query_states.device.type == "cuda" and causal_mask is not None:
808
+ query_states = query_states.contiguous()
809
+ key_states = key_states.contiguous()
810
+ value_states = value_states.contiguous()
811
+
812
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
813
+ query_states,
814
+ key_states,
815
+ value_states,
816
+ attn_mask=causal_mask,
817
+ dropout_p=self.attention_dropout if self.training else 0.0,
818
+ )
819
+
820
+ attn_output = attn_output.transpose(1, 2).contiguous()
821
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
822
+
823
+ attn_output = self.o_proj(attn_output)
824
+
825
+ return attn_output, None, past_key_value
826
+
827
+
828
+ try:
829
+ from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func, zigzag_ring_flash_attn_varlen_func
830
+ except ImportError:
831
+ print("Please install the ring-flash-attn package")
832
+
833
+ LLAMA_ATTENTION_CLASSES = {
834
+ "eager": LlamaAttention,
835
+ "flash_attention_2": LlamaFlashAttention2,
836
+ "ring_flash_attention_2": LlamaRingFlashAttention2,
837
+ "sdpa": LlamaSdpaAttention,
838
+ }
839
+
840
+
841
+ class LlamaDecoderLayer(nn.Module):
842
+ def __init__(self, config: LlamaConfig, layer_idx: int):
843
+ super().__init__()
844
+ self.hidden_size = config.hidden_size
845
+
846
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
847
+
848
+ self.mlp = LlamaMLP(config)
849
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
850
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
851
+
852
+ def forward(
853
+ self,
854
+ hidden_states: torch.Tensor,
855
+ attention_mask: Optional[torch.Tensor] = None,
856
+ position_ids: Optional[torch.LongTensor] = None,
857
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
858
+ output_attentions: Optional[bool] = False,
859
+ use_cache: Optional[bool] = False,
860
+ cache_position: Optional[torch.LongTensor] = None,
861
+ **kwargs,
862
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
863
+ """
864
+ Args:
865
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
866
+ attention_mask (`torch.FloatTensor`, *optional*):
867
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
868
+ query_sequence_length, key_sequence_length)` if default attention is used.
869
+ output_attentions (`bool`, *optional*):
870
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
871
+ returned tensors for more detail.
872
+ use_cache (`bool`, *optional*):
873
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
874
+ (see `past_key_values`).
875
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
876
+ """
877
+ if "padding_mask" in kwargs:
878
+ warnings.warn("Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`")
879
+
880
+ residual = hidden_states
881
+
882
+ hidden_states = self.input_layernorm(hidden_states)
883
+
884
+ # Self Attention
885
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
886
+ hidden_states=hidden_states,
887
+ attention_mask=attention_mask,
888
+ position_ids=position_ids,
889
+ past_key_value=past_key_value,
890
+ output_attentions=output_attentions,
891
+ use_cache=use_cache,
892
+ cache_position=cache_position,
893
+ **kwargs,
894
+ )
895
+ hidden_states = residual + hidden_states
896
+
897
+ # Fully Connected
898
+ residual = hidden_states
899
+ hidden_states = self.post_attention_layernorm(hidden_states)
900
+ hidden_states = self.mlp(hidden_states)
901
+ hidden_states = residual + hidden_states
902
+
903
+ outputs = (hidden_states,)
904
+
905
+ if output_attentions:
906
+ outputs += (self_attn_weights,)
907
+
908
+ if use_cache:
909
+ outputs += (present_key_value,)
910
+
911
+ return outputs
912
+
913
+
914
+ LLAMA_START_DOCSTRING = r"""
915
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
916
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
917
+ etc.)
918
+
919
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
920
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
921
+ and behavior.
922
+
923
+ Parameters:
924
+ config ([`LlamaConfig`]):
925
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
926
+ load the weights associated with the model, only the configuration. Check out the
927
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
928
+ """
929
+
930
+
931
+ @add_start_docstrings(
932
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
933
+ LLAMA_START_DOCSTRING,
934
+ )
935
+ class LlamaPreTrainedModel(PreTrainedModel):
936
+ config_class = LlamaConfig
937
+ base_model_prefix = "model"
938
+ supports_gradient_checkpointing = True
939
+ _no_split_modules = ["LlamaDecoderLayer"]
940
+ _skip_keys_device_placement = ["past_key_values", "causal_mask"]
941
+ _supports_flash_attn_2 = True
942
+ _supports_sdpa = True
943
+ _supports_cache_class = True
944
+
945
+ def _init_weights(self, module):
946
+ std = self.config.initializer_range
947
+ if isinstance(module, nn.Linear):
948
+ module.weight.data.normal_(mean=0.0, std=std)
949
+ if module.bias is not None:
950
+ module.bias.data.zero_()
951
+ elif isinstance(module, nn.Embedding):
952
+ module.weight.data.normal_(mean=0.0, std=std)
953
+ if module.padding_idx is not None:
954
+ module.weight.data[module.padding_idx].zero_()
955
+
956
+ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
957
+ if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
958
+ raise ValueError("`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; " "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers")
959
+
960
+ if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
961
+ causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool)
962
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
963
+
964
+ for layer in self.model.layers:
965
+ device = layer.input_layernorm.weight.device
966
+ if hasattr(self.config, "_pre_quantization_dtype"):
967
+ dtype = self.config._pre_quantization_dtype
968
+ else:
969
+ dtype = layer.self_attn.o_proj.weight.dtype
970
+ layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype)
971
+
972
+ def _reset_cache(self):
973
+ for layer in self.model.layers:
974
+ layer.self_attn.past_key_value = None
975
+
976
+
977
+ LLAMA_INPUTS_DOCSTRING = r"""
978
+ Args:
979
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
980
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
981
+ it.
982
+
983
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
984
+ [`PreTrainedTokenizer.__call__`] for details.
985
+
986
+ [What are input IDs?](../glossary#input-ids)
987
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
988
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
989
+
990
+ - 1 for tokens that are **not masked**,
991
+ - 0 for tokens that are **masked**.
992
+
993
+ [What are attention masks?](../glossary#attention-mask)
994
+
995
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
996
+ [`PreTrainedTokenizer.__call__`] for details.
997
+
998
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
999
+ `past_key_values`).
1000
+
1001
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1002
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1003
+ information on the default strategy.
1004
+
1005
+ - 1 indicates the head is **not masked**,
1006
+ - 0 indicates the head is **masked**.
1007
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1008
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1009
+ config.n_positions - 1]`.
1010
+
1011
+ [What are position IDs?](../glossary#position-ids)
1012
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1013
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1014
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1015
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1016
+
1017
+ Two formats are allowed:
1018
+ - a [`~cache_utils.Cache`] instance;
1019
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1020
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1021
+ cache format.
1022
+
1023
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1024
+ legacy cache format will be returned.
1025
+
1026
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1027
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1028
+ of shape `(batch_size, sequence_length)`.
1029
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1030
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1031
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1032
+ model's internal embedding lookup matrix.
1033
+ use_cache (`bool`, *optional*):
1034
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1035
+ `past_key_values`).
1036
+ output_attentions (`bool`, *optional*):
1037
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1038
+ tensors for more detail.
1039
+ output_hidden_states (`bool`, *optional*):
1040
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1041
+ more detail.
1042
+ return_dict (`bool`, *optional*):
1043
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1044
+ """
1045
+
1046
+
1047
+ @add_start_docstrings(
1048
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1049
+ LLAMA_START_DOCSTRING,
1050
+ )
1051
+ class LlamaModel(LlamaPreTrainedModel):
1052
+ """
1053
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
1054
+
1055
+ Args:
1056
+ config: LlamaConfig
1057
+ """
1058
+
1059
+ def __init__(self, config: LlamaConfig):
1060
+ super().__init__(config)
1061
+ self.padding_idx = config.pad_token_id
1062
+ self.vocab_size = config.vocab_size
1063
+
1064
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1065
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
1066
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1067
+ self.gradient_checkpointing = False
1068
+
1069
+ # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
1070
+ # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
1071
+ causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool)
1072
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
1073
+ # Initialize weights and apply final processing
1074
+ self.post_init()
1075
+
1076
+ def get_input_embeddings(self):
1077
+ return self.embed_tokens
1078
+
1079
+ def set_input_embeddings(self, value):
1080
+ self.embed_tokens = value
1081
+
1082
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1083
+ def forward(
1084
+ self,
1085
+ input_ids: torch.LongTensor = None,
1086
+ attention_mask: Optional[torch.Tensor] = None,
1087
+ position_ids: Optional[torch.LongTensor] = None,
1088
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1089
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1090
+ use_cache: Optional[bool] = None,
1091
+ output_attentions: Optional[bool] = None,
1092
+ output_hidden_states: Optional[bool] = None,
1093
+ return_dict: Optional[bool] = None,
1094
+ cache_position: Optional[torch.LongTensor] = None,
1095
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1096
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1097
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1098
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1099
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1100
+
1101
+ if (input_ids is None) ^ (inputs_embeds is not None):
1102
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one")
1103
+
1104
+ if self.gradient_checkpointing and self.training and use_cache:
1105
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
1106
+ use_cache = False
1107
+
1108
+ if inputs_embeds is None:
1109
+ inputs_embeds = self.embed_tokens(input_ids)
1110
+
1111
+ past_seen_tokens = 0
1112
+ if use_cache: # kept for BC (cache positions)
1113
+ if not isinstance(past_key_values, StaticCache):
1114
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1115
+ past_seen_tokens = past_key_values.get_seq_length()
1116
+
1117
+ if cache_position is None:
1118
+ if isinstance(past_key_values, StaticCache):
1119
+ raise ValueError("cache_position is a required argument when using StaticCache.")
1120
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
1121
+
1122
+ if position_ids is None:
1123
+ position_ids = cache_position.unsqueeze(0)
1124
+
1125
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
1126
+
1127
+ # embed positions
1128
+ hidden_states = inputs_embeds
1129
+
1130
+ # decoder layers
1131
+ all_hidden_states = () if output_hidden_states else None
1132
+ all_self_attns = () if output_attentions else None
1133
+ next_decoder_cache = None
1134
+
1135
+ for decoder_layer in self.layers:
1136
+ if output_hidden_states:
1137
+ all_hidden_states += (hidden_states,)
1138
+
1139
+ if self.gradient_checkpointing and self.training:
1140
+ layer_outputs = self._gradient_checkpointing_func(
1141
+ decoder_layer.__call__,
1142
+ hidden_states,
1143
+ causal_mask,
1144
+ position_ids,
1145
+ past_key_values,
1146
+ output_attentions,
1147
+ use_cache,
1148
+ cache_position,
1149
+ )
1150
+ else:
1151
+ layer_outputs = decoder_layer(
1152
+ hidden_states,
1153
+ attention_mask=causal_mask,
1154
+ position_ids=position_ids,
1155
+ past_key_value=past_key_values,
1156
+ output_attentions=output_attentions,
1157
+ use_cache=use_cache,
1158
+ cache_position=cache_position,
1159
+ )
1160
+
1161
+ hidden_states = layer_outputs[0]
1162
+
1163
+ if use_cache:
1164
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1165
+
1166
+ if output_attentions:
1167
+ all_self_attns += (layer_outputs[1],)
1168
+
1169
+ hidden_states = self.norm(hidden_states)
1170
+
1171
+ # add hidden states from the last decoder layer
1172
+ if output_hidden_states:
1173
+ all_hidden_states += (hidden_states,)
1174
+
1175
+ next_cache = None
1176
+ if use_cache:
1177
+ next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1178
+ if not return_dict:
1179
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1180
+ return BaseModelOutputWithPast(
1181
+ last_hidden_state=hidden_states,
1182
+ past_key_values=next_cache,
1183
+ hidden_states=all_hidden_states,
1184
+ attentions=all_self_attns,
1185
+ )
1186
+
1187
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1188
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1189
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1190
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1191
+ def _update_causal_mask(self, attention_mask, input_tensor):
1192
+ if self.config._attn_implementation == "flash_attention_2":
1193
+ if attention_mask is not None and 0.0 in attention_mask:
1194
+ return attention_mask
1195
+ return None
1196
+
1197
+ batch_size, seq_length = input_tensor.shape[:2]
1198
+ dtype = input_tensor.dtype
1199
+ device = input_tensor.device
1200
+
1201
+ # support going beyond cached `max_position_embedding`
1202
+ if seq_length > self.causal_mask.shape[-1]:
1203
+ causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
1204
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
1205
+
1206
+ # We use the current dtype to avoid any overflows
1207
+ min_dtype = torch.finfo(dtype).min
1208
+ causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
1209
+
1210
+ causal_mask = causal_mask.to(dtype=dtype, device=device)
1211
+ if attention_mask is not None and attention_mask.dim() == 2:
1212
+ mask_length = attention_mask.shape[-1]
1213
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1214
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1215
+
1216
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda":
1217
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1218
+ is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1219
+ if not is_tracing and torch.any(attention_mask != 1):
1220
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1221
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1222
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1223
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1224
+
1225
+ return causal_mask
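+ # The returned 4-D float mask holds 0.0 where attention is allowed and the dtype minimum where it is
+ # blocked; the static causal triangle and any 2-D padding mask are merged here so the attention
+ # classes can add it directly to the attention scores.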
1226
+
1227
+
1228
+ class LlamaForCausalLM(LlamaPreTrainedModel):
1229
+ _tied_weights_keys = ["lm_head.weight"]
1230
+
1231
+ def __init__(self, config):
1232
+ super().__init__(config)
1233
+ self.model = LlamaModel(config)
1234
+ self.vocab_size = config.vocab_size
1235
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1236
+
1237
+ # Initialize weights and apply final processing
1238
+ self.post_init()
1239
+
1240
+ def get_input_embeddings(self):
1241
+ return self.model.embed_tokens
1242
+
1243
+ def set_input_embeddings(self, value):
1244
+ self.model.embed_tokens = value
1245
+
1246
+ def get_output_embeddings(self):
1247
+ return self.lm_head
1248
+
1249
+ def set_output_embeddings(self, new_embeddings):
1250
+ self.lm_head = new_embeddings
1251
+
1252
+ def set_decoder(self, decoder):
1253
+ self.model = decoder
1254
+
1255
+ def get_decoder(self):
1256
+ return self.model
1257
+
1258
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1259
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1260
+ def forward(
1261
+ self,
1262
+ input_ids: torch.LongTensor = None,
1263
+ attention_mask: Optional[torch.Tensor] = None,
1264
+ position_ids: Optional[torch.LongTensor] = None,
1265
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1266
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1267
+ labels: Optional[torch.LongTensor] = None,
1268
+ use_cache: Optional[bool] = None,
1269
+ output_attentions: Optional[bool] = None,
1270
+ output_hidden_states: Optional[bool] = None,
1271
+ return_dict: Optional[bool] = None,
1272
+ cache_position: Optional[torch.LongTensor] = None,
1273
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1274
+ r"""
1275
+ Args:
1276
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1277
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1278
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1279
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1280
+
1281
+ Returns:
1282
+
1283
+ Example:
1284
+
1285
+ ```python
1286
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1287
+
1288
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1289
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1290
+
1291
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1292
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1293
+
1294
+ >>> # Generate
1295
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1296
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1297
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1298
+ ```"""
1299
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1300
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1301
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1302
+
1303
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1304
+ outputs = self.model(
1305
+ input_ids=input_ids,
1306
+ attention_mask=attention_mask,
1307
+ position_ids=position_ids,
1308
+ past_key_values=past_key_values,
1309
+ inputs_embeds=inputs_embeds,
1310
+ use_cache=use_cache,
1311
+ output_attentions=output_attentions,
1312
+ output_hidden_states=output_hidden_states,
1313
+ return_dict=return_dict,
1314
+ cache_position=cache_position,
1315
+ )
1316
+
1317
+ hidden_states = outputs[0]
1318
+ if self.config.pretraining_tp > 1:
1319
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1320
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1321
+ logits = torch.cat(logits, dim=-1)
1322
+ else:
1323
+ logits = self.lm_head(hidden_states)
1324
+ logits = logits.float()
1325
+
1326
+ loss = None
1327
+ if labels is not None:
1328
+ # Shift so that tokens < n predict n
1329
+ shift_logits = logits[..., :-1, :].contiguous()
1330
+ shift_labels = labels[..., 1:].contiguous()
1331
+ # Flatten the tokens
1332
+ loss_fct = CrossEntropyLoss()
1333
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1334
+ shift_labels = shift_labels.view(-1)
1335
+ # Enable model parallelism
1336
+ shift_labels = shift_labels.to(shift_logits.device)
1337
+ loss = loss_fct(shift_logits, shift_labels)
1338
+
1339
+ if not return_dict:
1340
+ output = (logits,) + outputs[1:]
1341
+ return (loss,) + output if loss is not None else output
1342
+
1343
+ return CausalLMOutputWithPast(
1344
+ loss=loss,
1345
+ logits=logits,
1346
+ past_key_values=outputs.past_key_values,
1347
+ hidden_states=outputs.hidden_states,
1348
+ attentions=outputs.attentions,
1349
+ )
1350
+
1351
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
1352
+ past_length = 0
1353
+ if past_key_values is not None:
1354
+ if isinstance(past_key_values, Cache):
1355
+ cache_length = past_key_values.get_seq_length()
1356
+ past_length = past_key_values.seen_tokens
1357
+ max_cache_length = past_key_values.get_max_length()
1358
+ else:
1359
+ cache_length = past_length = past_key_values[0][0].shape[2]
1360
+ max_cache_length = None
1361
+
1362
+ # Keep only the unprocessed tokens:
1363
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1364
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1365
+ # input)
1366
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1367
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1368
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1369
+ # input_ids based on the past_length.
1370
+ elif past_length < input_ids.shape[1]:
1371
+ input_ids = input_ids[:, past_length:]
1372
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1373
+
1374
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1375
+ if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length:
1376
+ attention_mask = attention_mask[:, -max_cache_length:]
1377
+
1378
+ position_ids = kwargs.get("position_ids", None)
1379
+ if attention_mask is not None and position_ids is None:
1380
+ # create position_ids on the fly for batch generation
1381
+ position_ids = attention_mask.long().cumsum(-1) - 1
1382
+ position_ids.masked_fill_(attention_mask == 0, 1)
1383
+ if past_key_values:
1384
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1385
+
1386
+ if self.generation_config.cache_implementation == "static":
1387
+ # generation with static cache
1388
+ cache_position = kwargs.get("cache_position", None)
1389
+ if cache_position is None:
1390
+ past_length = 0
1391
+ else:
1392
+ past_length = cache_position[-1] + 1
1393
+ input_ids = input_ids[:, past_length:]
1394
+ position_ids = position_ids[:, past_length:]
1395
+
1396
+ # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
1397
+ # same goes for position ids. Could also help with continued generation.
1398
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1399
+ cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
1400
+ position_ids = position_ids.contiguous() if position_ids is not None else None
1401
+
1402
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1403
+ if inputs_embeds is not None and past_key_values is None:
1404
+ model_inputs = {"inputs_embeds": inputs_embeds}
1405
+ else:
1406
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1407
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1408
+ # TODO: use `next_tokens` directly instead.
1409
+ model_inputs = {"input_ids": input_ids.contiguous()}
1410
+
1411
+ model_inputs.update(
1412
+ {
1413
+ "position_ids": position_ids,
1414
+ "cache_position": cache_position,
1415
+ "past_key_values": past_key_values,
1416
+ "use_cache": kwargs.get("use_cache"),
1417
+ "attention_mask": attention_mask,
1418
+ }
1419
+ )
1420
+ return model_inputs
1421
+
1422
+ @staticmethod
1423
+ def _reorder_cache(past_key_values, beam_idx):
1424
+ reordered_past = ()
1425
+ for layer_past in past_key_values:
1426
+ reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
1427
+ return reordered_past
1428
+
1429
+
1430
+ @add_start_docstrings(
1431
+ """
1432
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1433
+
1434
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1435
+ (e.g. GPT-2) do.
1436
+
1437
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1438
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1439
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1440
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1441
+ each row of the batch).
1442
+ """,
1443
+ LLAMA_START_DOCSTRING,
1444
+ )
1445
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1446
+ def __init__(self, config):
1447
+ super().__init__(config)
1448
+ self.num_labels = config.num_labels
1449
+ self.model = LlamaModel(config)
1450
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1451
+
1452
+ # Initialize weights and apply final processing
1453
+ self.post_init()
1454
+
1455
+ def get_input_embeddings(self):
1456
+ return self.model.embed_tokens
1457
+
1458
+ def set_input_embeddings(self, value):
1459
+ self.model.embed_tokens = value
1460
+
1461
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1462
+ def forward(
1463
+ self,
1464
+ input_ids: torch.LongTensor = None,
1465
+ attention_mask: Optional[torch.Tensor] = None,
1466
+ position_ids: Optional[torch.LongTensor] = None,
1467
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1468
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1469
+ labels: Optional[torch.LongTensor] = None,
1470
+ use_cache: Optional[bool] = None,
1471
+ output_attentions: Optional[bool] = None,
1472
+ output_hidden_states: Optional[bool] = None,
1473
+ return_dict: Optional[bool] = None,
1474
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1475
+ r"""
1476
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1477
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1478
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1479
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1480
+ """
1481
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1482
+
1483
+ transformer_outputs = self.model(
1484
+ input_ids,
1485
+ attention_mask=attention_mask,
1486
+ position_ids=position_ids,
1487
+ past_key_values=past_key_values,
1488
+ inputs_embeds=inputs_embeds,
1489
+ use_cache=use_cache,
1490
+ output_attentions=output_attentions,
1491
+ output_hidden_states=output_hidden_states,
1492
+ return_dict=return_dict,
1493
+ )
1494
+ hidden_states = transformer_outputs[0]
1495
+ logits = self.score(hidden_states)
1496
+
1497
+ if input_ids is not None:
1498
+ batch_size = input_ids.shape[0]
1499
+ else:
1500
+ batch_size = inputs_embeds.shape[0]
1501
+
1502
+ if self.config.pad_token_id is None and batch_size != 1:
1503
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1504
+ if self.config.pad_token_id is None:
1505
+ sequence_lengths = -1
1506
+ else:
1507
+ if input_ids is not None:
1508
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1509
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1510
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1511
+ sequence_lengths = sequence_lengths.to(logits.device)
1512
+ else:
1513
+ sequence_lengths = -1
1514
+
1515
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1516
+
1517
+ loss = None
1518
+ if labels is not None:
1519
+ labels = labels.to(logits.device)
1520
+ if self.config.problem_type is None:
1521
+ if self.num_labels == 1:
1522
+ self.config.problem_type = "regression"
1523
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1524
+ self.config.problem_type = "single_label_classification"
1525
+ else:
1526
+ self.config.problem_type = "multi_label_classification"
1527
+
1528
+ if self.config.problem_type == "regression":
1529
+ loss_fct = MSELoss()
1530
+ if self.num_labels == 1:
1531
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1532
+ else:
1533
+ loss = loss_fct(pooled_logits, labels)
1534
+ elif self.config.problem_type == "single_label_classification":
1535
+ loss_fct = CrossEntropyLoss()
1536
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1537
+ elif self.config.problem_type == "multi_label_classification":
1538
+ loss_fct = BCEWithLogitsLoss()
1539
+ loss = loss_fct(pooled_logits, labels)
1540
+ if not return_dict:
1541
+ output = (pooled_logits,) + transformer_outputs[1:]
1542
+ return ((loss,) + output) if loss is not None else output
1543
+
1544
+ return SequenceClassifierOutputWithPast(
1545
+ loss=loss,
1546
+ logits=pooled_logits,
1547
+ past_key_values=transformer_outputs.past_key_values,
1548
+ hidden_states=transformer_outputs.hidden_states,
1549
+ attentions=transformer_outputs.attentions,
1550
+ )
1551
+
1552
+
1553
+ @add_start_docstrings(
1554
+ """
1555
+ The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
1556
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1557
+ """,
1558
+ LLAMA_START_DOCSTRING,
1559
+ )
1560
+ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1561
+ base_model_prefix = "transformer"
1562
+
1563
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
1564
+ def __init__(self, config):
1565
+ super().__init__(config)
1566
+ self.transformer = LlamaModel(config)
1567
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1568
+
1569
+ # Initialize weights and apply final processing
1570
+ self.post_init()
1571
+
1572
+ def get_input_embeddings(self):
1573
+ return self.transformer.embed_tokens
1574
+
1575
+ def set_input_embeddings(self, value):
1576
+ self.transformer.embed_tokens = value
1577
+
1578
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1579
+ def forward(
1580
+ self,
1581
+ input_ids: Optional[torch.LongTensor] = None,
1582
+ attention_mask: Optional[torch.FloatTensor] = None,
1583
+ position_ids: Optional[torch.LongTensor] = None,
1584
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1585
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1586
+ start_positions: Optional[torch.LongTensor] = None,
1587
+ end_positions: Optional[torch.LongTensor] = None,
1588
+ output_attentions: Optional[bool] = None,
1589
+ output_hidden_states: Optional[bool] = None,
1590
+ return_dict: Optional[bool] = None,
1591
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1592
+ r"""
1593
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1594
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1595
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1596
+ are not taken into account for computing the loss.
1597
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1598
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1599
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1600
+ are not taken into account for computing the loss.
1601
+ """
1602
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1603
+
1604
+ outputs = self.transformer(
1605
+ input_ids,
1606
+ attention_mask=attention_mask,
1607
+ position_ids=position_ids,
1608
+ past_key_values=past_key_values,
1609
+ inputs_embeds=inputs_embeds,
1610
+ output_attentions=output_attentions,
1611
+ output_hidden_states=output_hidden_states,
1612
+ return_dict=return_dict,
1613
+ )
1614
+
1615
+ sequence_output = outputs[0]
1616
+
1617
+ logits = self.qa_outputs(sequence_output)
1618
+ start_logits, end_logits = logits.split(1, dim=-1)
1619
+ start_logits = start_logits.squeeze(-1).contiguous()
1620
+ end_logits = end_logits.squeeze(-1).contiguous()
1621
+
1622
+ total_loss = None
1623
+ if start_positions is not None and end_positions is not None:
1624
+ # If we are on multi-GPU, split add a dimension
1625
+ if len(start_positions.size()) > 1:
1626
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1627
+ if len(end_positions.size()) > 1:
1628
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1629
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1630
+ ignored_index = start_logits.size(1)
1631
+ start_positions = start_positions.clamp(0, ignored_index)
1632
+ end_positions = end_positions.clamp(0, ignored_index)
1633
+
1634
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1635
+ start_loss = loss_fct(start_logits, start_positions)
1636
+ end_loss = loss_fct(end_logits, end_positions)
1637
+ total_loss = (start_loss + end_loss) / 2
1638
+
1639
+ if not return_dict:
1640
+ output = (start_logits, end_logits) + outputs[2:]
1641
+ return ((total_loss,) + output) if total_loss is not None else output
1642
+
1643
+ return QuestionAnsweringModelOutput(
1644
+ loss=total_loss,
1645
+ start_logits=start_logits,
1646
+ end_logits=end_logits,
1647
+ hidden_states=outputs.hidden_states,
1648
+ attentions=outputs.attentions,
1649
+ )
llava/model/llava_arch.py ADDED
@@ -0,0 +1,509 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import math
19
+ import re
20
+ import time
21
+ import torch
22
+ import torch.nn as nn
23
+ from .multimodal_encoder.builder import build_vision_tower
24
+ from .multimodal_resampler.builder import build_vision_resampler
25
+ from .multimodal_projector.builder import build_vision_projector
26
+
27
+ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
28
+
29
+ from llava.mm_utils import get_anyres_image_grid_shape
30
+ from llava.utils import rank0_print, rank_print
31
+ import random
32
+
33
+
34
+ class LlavaMetaModel:
35
+
36
+ def __init__(self, config):
37
+ super(LlavaMetaModel, self).__init__(config)
38
+
39
+ if hasattr(config, "mm_vision_tower"):
40
+ delay_load = getattr(config, "delay_load", False)
41
+ self.vision_tower = build_vision_tower(config, delay_load=delay_load)
42
+ self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
43
+ self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
44
+
45
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
46
+ self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
47
+
48
+ def get_vision_tower(self):
49
+ vision_tower = getattr(self, "vision_tower", None)
50
+ if type(vision_tower) is list:
51
+ vision_tower = vision_tower[0]
52
+ return vision_tower
53
+
54
+ def initialize_vision_modules(self, model_args, fsdp=None):
55
+ vision_tower = model_args.vision_tower
56
+ mm_vision_select_layer = model_args.mm_vision_select_layer
57
+ mm_vision_select_feature = model_args.mm_vision_select_feature
58
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
59
+ mm_patch_merge_type = model_args.mm_patch_merge_type
60
+
61
+ self.config.mm_vision_tower = vision_tower
62
+ self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
63
+
64
+ if self.get_vision_tower() is None:
65
+ vision_tower = build_vision_tower(model_args)
66
+ vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
67
+ for k, v in vision_resampler.config.items():
68
+ setattr(self.config, k, v)
69
+
70
+ if fsdp is not None and len(fsdp) > 0:
71
+ self.vision_tower = [vision_tower]
72
+ self.vision_resampler = [vision_resampler]
73
+ else:
74
+ self.vision_tower = vision_tower
75
+ self.vision_resampler = vision_resampler
76
+ else:
77
+ if fsdp is not None and len(fsdp) > 0:
78
+ vision_resampler = self.vision_resampler[0]
79
+ vision_tower = self.vision_tower[0]
80
+ else:
81
+ vision_resampler = self.vision_resampler
82
+ vision_tower = self.vision_tower
83
+ vision_tower.load_model()
84
+
85
+ # In case it is frozen by LoRA
86
+ for p in self.vision_resampler.parameters():
87
+ p.requires_grad = True
88
+
89
+ self.config.use_mm_proj = True
90
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
91
+ self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
92
+ self.config.mm_vision_select_layer = mm_vision_select_layer
93
+ self.config.mm_vision_select_feature = mm_vision_select_feature
94
+ self.config.mm_patch_merge_type = mm_patch_merge_type
95
+
96
+ if getattr(self, "mm_projector", None) is None:
97
+ self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
98
+
99
+ if "unpad" in mm_patch_merge_type:
100
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
101
+ self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
102
+ else:
103
+ # In case it is frozen by LoRA
104
+ for p in self.mm_projector.parameters():
105
+ p.requires_grad = True
106
+
107
+ if pretrain_mm_mlp_adapter is not None:
108
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
109
+
110
+ def get_w(weights, keyword):
111
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
112
+
113
+ incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
114
+ rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
115
+ incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
116
+ rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
117
+
118
+
119
+ def unpad_image(tensor, original_size):
120
+ """
121
+ Unpads a PyTorch tensor of a padded and resized image.
122
+
123
+ Args:
124
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
125
+ original_size (tuple): The original size of the image (width, height).
126
+
127
+ Returns:
128
+ torch.Tensor: The unpadded image tensor.
129
+ """
130
+ original_width, original_height = original_size
131
+ current_height, current_width = tensor.shape[1:]
132
+
133
+ # Compute aspect ratios
134
+ original_aspect_ratio = original_width / original_height
135
+ current_aspect_ratio = current_width / current_height
136
+
137
+ # Determine padding size and direction
138
+ if original_aspect_ratio > current_aspect_ratio:
139
+ # Padding was added to the height
140
+ scale_factor = current_width / original_width
141
+ new_height = int(original_height * scale_factor)
142
+ padding = (current_height - new_height) // 2
143
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
144
+ else:
145
+ # Padding was added to the width
146
+ scale_factor = current_height / original_height
147
+ new_width = int(original_width * scale_factor)
148
+ padding = (current_width - new_width) // 2
149
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
150
+
151
+ return unpadded_tensor
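+ # Example (hypothetical numbers): for an original 1000x500 (width x height) image whose features were
+ # padded to a square 24x24 grid, the content spans only 12 rows, so unpad_image strips the 6 padded
+ # rows above and below and returns a 12x24 (H x W) grid that matches the original aspect ratio again.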
152
+
153
+
154
+ class LlavaMetaForCausalLM(ABC):
155
+
156
+ @abstractmethod
157
+ def get_model(self):
158
+ pass
159
+
160
+ def get_vision_tower(self):
161
+ return self.get_model().get_vision_tower()
162
+
163
+ def get_2dPool(self, image_feature):
164
+ height = width = self.get_vision_tower().num_patches_per_side
165
+ num_frames, num_tokens, num_dim = image_feature.shape
166
+ image_feature = image_feature.view(num_frames, height, width, -1)
167
+ image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
168
+ # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
169
+ if self.config.mm_spatial_pool_mode == "average":
170
+ image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride)
171
+ elif self.config.mm_spatial_pool_mode == "max":
172
+ image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
173
+ elif self.config.mm_spatial_pool_mode == "bilinear":
174
+ height, width = image_feature.shape[2:]
175
+ scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
176
+ image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
177
+
178
+ else:
179
+ raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
180
+ image_feature = image_feature.permute(0, 2, 3, 1)
181
+ image_feature = image_feature.view(num_frames, -1, num_dim)
182
+ return image_feature
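+ # get_2dPool reshapes each frame's patch tokens back to a (side x side) grid, pools spatially with
+ # `mm_spatial_pool_stride` (or a 2x bilinear downsample), and flattens again; e.g. with stride 2 a
+ # hypothetical 24x24 grid becomes 12x12, a 4x reduction in per-frame video tokens.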
183
+
184
+ def encode_images(self, images):
185
+ image_features = self.get_model().get_vision_tower()(images)
186
+ # image_features = self.get_model().vision_resampler(image_features, images=images)
187
+ image_features = self.get_model().mm_projector(image_features)
188
+ return image_features
189
+
190
+ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
191
+ videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
192
+ per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096)
193
+ all_videos_or_images_features = []
194
+
195
+ for idx, feat in enumerate(per_videos_or_images_features):
196
+ feat = self.get_model().mm_projector(feat)
197
+ if idx in video_idx_in_batch:
198
+ feat = self.get_2dPool(feat)
199
+ all_videos_or_images_features.append(feat)
200
+ return all_videos_or_images_features
201
+
202
+ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
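+ # High-level flow of the code below: encode all images/videos in one batched vision-tower pass,
+ # optionally pool video frames and merge anyres patches, split each text sequence at the
+ # IMAGE_TOKEN_INDEX placeholders, splice the visual features in at those positions (with IGNORE_INDEX
+ # labels), then re-pad the variable-length embedding sequences into a batch.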
203
+ vision_tower = self.get_vision_tower()
204
+ # rank_print(modalities)
205
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
206
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
207
+
208
+ if type(images) is list or images.ndim == 5:
209
+ if type(images) is list:
210
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
211
+
212
+ video_idx_in_batch = []
213
+ for _ in range(len(modalities)):
214
+ if modalities[_] == "video":
215
+ video_idx_in_batch.append(_)
216
+
217
+ images_list = []
218
+ for image in images:
219
+ if image.ndim == 4:
220
+ images_list.append(image)
221
+ else:
222
+ images_list.append(image.unsqueeze(0))
223
+
224
+ concat_images = torch.cat([image for image in images_list], dim=0)
225
+ split_sizes = [image.shape[0] for image in images_list]
226
+ encoded_image_features = self.encode_images(concat_images)
227
+
228
+ # This is a list, each element is [num_images, patch * patch, dim]
229
+ # rank_print(f"Concat images : {concat_images.shape}")
230
+ encoded_image_features = torch.split(encoded_image_features, split_sizes)
231
+ image_features = []
232
+ for idx, image_feat in enumerate(encoded_image_features):
233
+ if idx in video_idx_in_batch:
234
+ image_features.append(self.get_2dPool(image_feat))
235
+ else:
236
+ image_features.append(image_feat)
237
+ # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
238
+ # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
239
+ # image_features = torch.split(image_features, split_sizes, dim=0)
240
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
241
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
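+ # "flat" simply flattens every patch token; the "spatial*" variants below rebuild the anyres patch
+ # grid, optionally unpad it back to the original aspect ratio and append the learned image_newline
+ # token, while "maxpool2x2"/"nobase" further control how the grid is pooled and whether the base
+ # image feature is prepended.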
242
+
243
+ if mm_patch_merge_type == "flat":
244
+ image_features = [x.flatten(0, 1) for x in image_features]
245
+
246
+ elif mm_patch_merge_type.startswith("spatial"):
247
+ new_image_features = []
248
+ for image_idx, image_feature in enumerate(image_features):
249
+ # FIXME: now assume the image is square, and split to 2x2 patches
250
+ # num_patches = h * w, where h = w = sqrt(num_patches)
251
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
252
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
253
+ # rank0_print("At least we are reaching here")
254
+ if image_idx in video_idx_in_batch: # video operations
255
+ # rank0_print("Video")
256
+ if "unpad" in mm_patch_merge_type:
257
+ # image_feature = image_feature.permute(2, 0, 1).contiguous()
258
+ # image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
259
+ # image_feature = image_feature.permute(1, 2, 0).contiguous()
260
+ image_feature = image_feature.flatten(0, 1)
261
+ image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
262
+
263
+ elif image_feature.shape[0] > 1: # multi patches and multi images operations
264
+ # rank0_print("Single-images")
265
+ base_image_feature = image_feature[0]
266
+ image_feature = image_feature[1:]
267
+ height = width = self.get_vision_tower().num_patches_per_side
268
+ assert height * width == base_image_feature.shape[0]
269
+
270
+ if "anyres_max" in image_aspect_ratio:
271
+ matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
272
+ if matched_anyres_max_num_patches:
273
+ max_num_patches = int(matched_anyres_max_num_patches.group(1))
274
+
275
+ if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
276
+ if hasattr(self.get_vision_tower(), "image_size"):
277
+ vision_tower_image_size = self.get_vision_tower().image_size
278
+ else:
279
+ raise ValueError("vision_tower_image_size is not found in the vision tower.")
280
+ try:
281
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
282
+ except Exception as e:
283
+ rank0_print(f"Error: {e}")
284
+ num_patch_width, num_patch_height = 2, 2
285
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
286
+ else:
287
+ image_feature = image_feature.view(2, 2, height, width, -1)
288
+
289
+ if "maxpool2x2" in mm_patch_merge_type:
290
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
291
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
292
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
293
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
294
+ elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
295
+ unit = image_feature.shape[2]
296
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
297
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
298
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
299
+ c, h, w = image_feature.shape
300
+ times = math.sqrt(h * w / (max_num_patches * unit**2))
301
+ if times > 1.1:
302
+ image_feature = image_feature[None]
303
+ image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
304
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
305
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
306
+ elif "unpad" in mm_patch_merge_type:
307
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
308
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
309
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
310
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
311
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
312
+ else:
313
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
314
+ image_feature = image_feature.flatten(0, 3)
315
+ if "nobase" in mm_patch_merge_type:
316
+ pass
317
+ else:
318
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
319
+ else: # single image operations
320
+ image_feature = image_feature[0]
321
+ if "unpad" in mm_patch_merge_type:
322
+ image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
323
+
324
+ new_image_features.append(image_feature)
325
+ image_features = new_image_features
326
+ else:
327
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
328
+ else:
329
+ image_features = self.encode_images(images)
330
+
331
+ # TODO: image start / end is not implemented here to support pretraining.
332
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
333
+ raise NotImplementedError
334
+ # rank_print(f"Total images : {len(image_features)}")
335
+
336
+ # Let's just add dummy tensors if they do not exist,
337
+ # it is a headache to deal with None all the time.
338
+ # But it is not ideal, and if you have a better idea,
339
+ # please open an issue / submit a PR, thanks.
340
+ _labels = labels
341
+ _position_ids = position_ids
342
+ _attention_mask = attention_mask
343
+ if attention_mask is None:
344
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
345
+ else:
346
+ attention_mask = attention_mask.bool()
347
+ if position_ids is None:
348
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
349
+ if labels is None:
350
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
351
+
352
+ # remove the padding using attention_mask -- FIXME
353
+ _input_ids = input_ids
354
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
355
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
356
+
357
+ new_input_embeds = []
358
+ new_labels = []
359
+ cur_image_idx = 0
360
+ # rank_print("Inserting Images embedding")
361
+ for batch_idx, cur_input_ids in enumerate(input_ids):
362
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
363
+ # rank0_print(num_images)
364
+ if num_images == 0:
365
+ cur_image_features = image_features[cur_image_idx]
366
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
367
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
368
+ new_input_embeds.append(cur_input_embeds)
369
+ new_labels.append(labels[batch_idx])
370
+ cur_image_idx += 1
371
+ continue
372
+
373
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
374
+ cur_input_ids_noim = []
375
+ cur_labels = labels[batch_idx]
376
+ cur_labels_noim = []
377
+ for i in range(len(image_token_indices) - 1):
378
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
379
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
380
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
381
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
382
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
383
+ cur_new_input_embeds = []
384
+ cur_new_labels = []
385
+
386
+ for i in range(num_images + 1):
387
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
388
+ cur_new_labels.append(cur_labels_noim[i])
389
+ if i < num_images:
390
+ try:
391
+ cur_image_features = image_features[cur_image_idx]
392
+ except IndexError:
393
+ cur_image_features = image_features[cur_image_idx - 1]
394
+ cur_image_idx += 1
395
+ cur_new_input_embeds.append(cur_image_features)
396
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
397
+
398
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
399
+
400
+ # import pdb; pdb.set_trace()
401
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
402
+ cur_new_labels = torch.cat(cur_new_labels)
403
+
404
+ new_input_embeds.append(cur_new_input_embeds)
405
+ new_labels.append(cur_new_labels)
406
+
407
+ # Truncate sequences to max length as image embeddings can make the sequence longer
408
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
409
+ # rank_print("Finishing Inserting")
410
+
411
+ new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
412
+ new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
413
+ # TODO: hard-coded cap to control loss spikes
414
+ # if tokenizer_model_max_length is not None:
415
+ # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
416
+ # new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
417
+
418
+ # Combine them
419
+ max_len = max(x.shape[0] for x in new_input_embeds)
420
+ batch_size = len(new_input_embeds)
421
+
422
+ new_input_embeds_padded = []
423
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
424
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
425
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
426
+ # rank0_print("Prepare pos id")
427
+
428
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
429
+ cur_len = cur_new_embed.shape[0]
430
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
431
+ new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
432
+ if cur_len > 0:
433
+ new_labels_padded[i, -cur_len:] = cur_new_labels
434
+ attention_mask[i, -cur_len:] = True
435
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
436
+ else:
437
+ new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
438
+ if cur_len > 0:
439
+ new_labels_padded[i, :cur_len] = cur_new_labels
440
+ attention_mask[i, :cur_len] = True
441
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
442
+
443
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
444
+ # rank0_print("tokenizer padding")
445
+
446
+ if _labels is None:
447
+ new_labels = None
448
+ else:
449
+ new_labels = new_labels_padded
450
+
451
+ if _attention_mask is None:
452
+ attention_mask = None
453
+ else:
454
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
455
+
456
+ if _position_ids is None:
457
+ position_ids = None
458
+ if getattr(self.config, "use_pos_skipping", False) and self.training:
459
+ position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
460
+ split_position = random.randint(0, new_input_embeds.size(1))
461
+ left_add = random.randint(0, self.config.pos_skipping_range)
462
+ right_add = random.randint(left_add, self.config.pos_skipping_range)
463
+ position_ids[:, :split_position] += left_add
464
+ position_ids[:, split_position:] += right_add
465
+ # import pdb; pdb.set_trace()
466
+ # rank0_print("Finish preparing")
467
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
468
+
+    def initialize_vision_tokenizer(self, model_args, tokenizer):
+        if model_args.mm_use_im_patch_token:
+            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+            self.resize_token_embeddings(len(tokenizer))
+
+        if model_args.mm_use_im_start_end:
+            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+            self.resize_token_embeddings(len(tokenizer))
+
+            if num_new_tokens > 0:
+                input_embeddings = self.get_input_embeddings().weight.data
+                output_embeddings = self.get_output_embeddings().weight.data
+
+                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+                input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+            if model_args.tune_mm_mlp_adapter:
+                for p in self.get_input_embeddings().parameters():
+                    p.requires_grad = True
+                for p in self.get_output_embeddings().parameters():
+                    p.requires_grad = False
+
+            if model_args.pretrain_mm_mlp_adapter:
+                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
+                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
+                assert num_new_tokens == 2
+                if input_embeddings.shape == embed_tokens_weight.shape:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+                elif embed_tokens_weight.shape[0] == num_new_tokens:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
+                else:
+                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
+        elif model_args.mm_use_im_patch_token:
+            if model_args.tune_mm_mlp_adapter:
+                for p in self.get_input_embeddings().parameters():
+                    p.requires_grad = False
+                for p in self.get_output_embeddings().parameters():
+                    p.requires_grad = False
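
Note on the hunk above: prepare_inputs_labels_for_multimodal splits each sequence at the image placeholder token, embeds the text segments, splices the projected image features in between, masks those spans with IGNORE_INDEX, and then pads the batch to a common length. The minimal sketch below is not part of this commit; the helper name splice_image_features is hypothetical, and the constant values are assumptions meant to mirror llava/constants.py. It shows only the splice step in isolation.

    import torch

    IGNORE_INDEX = -100        # assumption: mirrors llava.constants.IGNORE_INDEX
    IMAGE_TOKEN_INDEX = -200   # assumption: mirrors llava.constants.IMAGE_TOKEN_INDEX

    def splice_image_features(input_ids, labels, image_features, embed_tokens):
        """Return (embeds, labels) for one sample, with image placeholders replaced by image features."""
        # boundaries of the text segments around each image placeholder
        image_positions = [-1] + torch.where(input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [input_ids.shape[0]]
        new_embeds, new_labels = [], []
        for i in range(len(image_positions) - 1):
            seg_ids = input_ids[image_positions[i] + 1 : image_positions[i + 1]]
            seg_labels = labels[image_positions[i] + 1 : image_positions[i + 1]]
            new_embeds.append(embed_tokens(seg_ids))   # embed the text segment
            new_labels.append(seg_labels)
            if i < len(image_features):
                feats = image_features[i]
                new_embeds.append(feats)               # insert image features in place of the placeholder
                new_labels.append(torch.full((feats.shape[0],), IGNORE_INDEX, dtype=labels.dtype))
        return torch.cat(new_embeds), torch.cat(new_labels)

    if __name__ == "__main__":
        torch.manual_seed(0)
        embed = torch.nn.Embedding(100, 8)
        # toy prompt: 3 prompt tokens, one image placeholder, 2 answer tokens
        ids = torch.tensor([5, 6, 7, IMAGE_TOKEN_INDEX, 8, 9])
        lbls = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 8, 9])
        img = [torch.randn(4, 8)]  # 4 "patch" embeddings for the single image
        embeds, labels = splice_image_features(ids, lbls, img, embed)
        print(embeds.shape, labels.shape)  # torch.Size([9, 8]) torch.Size([9])

Running the sketch yields a 9-step sequence: three text embeddings, four image embeddings with ignored labels, then two supervised answer tokens, which is exactly the per-sample structure the real method builds before padding.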
llava/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
+ """
2
+ Usage:
3
+ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+
6
+ import argparse
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+ from llava.model.utils import auto_upgrade
12
+
13
+
14
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
15
+ print("Loading base model")
16
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading target model")
19
+ auto_upgrade(target_model_path)
20
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21
+
22
+ print("Calculating delta")
23
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data -= base.state_dict()[name]
29
+ else:
30
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31
+ bparam = base.state_dict()[name]
32
+ param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
33
+
34
+ print("Saving delta")
35
+ if hub_repo_id:
36
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37
+ else:
38
+ kwargs = {}
39
+ target.save_pretrained(delta_path, **kwargs)
40
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--base-model-path", type=str, required=True)
47
+ parser.add_argument("--target-model-path", type=str, required=True)
48
+ parser.add_argument("--delta-path", type=str, required=True)
49
+ parser.add_argument("--hub-repo-id", type=str, default=None)
50
+ args = parser.parse_args()
51
+
52
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
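
Context for this file: make_delta stores target minus base for every tensor shared with the base model, differencing only the overlapping slice of resized tensors such as model.embed_tokens.weight and lm_head.weight (the abbreviated --base/--target/--delta flags in the module docstring resolve to the full argument names via argparse prefix matching). Reconstructing full weights is the element-wise inverse; the sketch below is illustrative only and is not the llava.model.apply_delta code shipped in this commit, and its function and path names are placeholders.

    # Illustrative inverse of make_delta: target = base + delta for shared tensors.
    import torch
    from transformers import AutoModelForCausalLM

    def apply_delta_sketch(base_model_path, delta_path, target_path):
        base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
        delta = AutoModelForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

        base_state = base.state_dict()
        for name, param in delta.state_dict().items():
            if name not in base_state:
                continue  # e.g. mm_projector weights exist only in the delta and are kept as-is
            bparam = base_state[name]
            if param.data.shape == bparam.shape:
                param.data += bparam
            else:
                # embed_tokens / lm_head may carry extra rows for added tokens;
                # make_delta only differenced the overlapping slice, so add back only that slice
                param.data[: bparam.shape[0], : bparam.shape[1]] += bparam

        delta.save_pretrained(target_path)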