ga89tiy commited on
Commit
db6ee6a
1 Parent(s): b56b523

Initial model commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LLAVA_Biovil/.dockerignore +21 -0
  2. LLAVA_Biovil/.editorconfig +18 -0
  3. LLAVA_Biovil/.gitattributes +29 -0
  4. LLAVA_Biovil/.gitignore +35 -0
  5. LLAVA_Biovil/LICENSE +201 -0
  6. LLAVA_Biovil/README.md +410 -0
  7. LLAVA_Biovil/__init__.py +0 -0
  8. LLAVA_Biovil/biovil_t/__init__.py +0 -0
  9. LLAVA_Biovil/biovil_t/encoder.py +180 -0
  10. LLAVA_Biovil/biovil_t/model.py +130 -0
  11. LLAVA_Biovil/biovil_t/modules.py +85 -0
  12. LLAVA_Biovil/biovil_t/pretrained.py +85 -0
  13. LLAVA_Biovil/biovil_t/resnet.py +80 -0
  14. LLAVA_Biovil/biovil_t/transformer.py +266 -0
  15. LLAVA_Biovil/biovil_t/types.py +37 -0
  16. LLAVA_Biovil/cog.yaml +37 -0
  17. LLAVA_Biovil/install.md +6 -0
  18. LLAVA_Biovil/llava/__init__.py +1 -0
  19. LLAVA_Biovil/llava/constants.py +13 -0
  20. LLAVA_Biovil/llava/conversation.py +414 -0
  21. LLAVA_Biovil/llava/eval/__init__.py +0 -0
  22. LLAVA_Biovil/llava/eval/eval_gpt_review.py +113 -0
  23. LLAVA_Biovil/llava/eval/eval_gpt_review_bench.py +121 -0
  24. LLAVA_Biovil/llava/eval/eval_gpt_review_visual.py +118 -0
  25. LLAVA_Biovil/llava/eval/eval_pope.py +81 -0
  26. LLAVA_Biovil/llava/eval/eval_science_qa.py +114 -0
  27. LLAVA_Biovil/llava/eval/eval_science_qa_gpt4.py +104 -0
  28. LLAVA_Biovil/llava/eval/eval_science_qa_gpt4_requery.py +149 -0
  29. LLAVA_Biovil/llava/eval/eval_textvqa.py +65 -0
  30. LLAVA_Biovil/llava/eval/generate_webpage_data_from_table.py +111 -0
  31. LLAVA_Biovil/llava/eval/m4c_evaluator.py +334 -0
  32. LLAVA_Biovil/llava/eval/model_qa.py +85 -0
  33. LLAVA_Biovil/llava/eval/model_vqa.py +112 -0
  34. LLAVA_Biovil/llava/eval/model_vqa_loader.py +141 -0
  35. LLAVA_Biovil/llava/eval/model_vqa_mmbench.py +169 -0
  36. LLAVA_Biovil/llava/eval/model_vqa_qbench.py +120 -0
  37. LLAVA_Biovil/llava/eval/model_vqa_science.py +147 -0
  38. LLAVA_Biovil/llava/eval/qa_baseline_gpt35.py +74 -0
  39. LLAVA_Biovil/llava/eval/run_llava.py +155 -0
  40. LLAVA_Biovil/llava/eval/summarize_gpt_review.py +60 -0
  41. LLAVA_Biovil/llava/eval/webpage/figures/alpaca.png +0 -0
  42. LLAVA_Biovil/llava/eval/webpage/figures/bard.jpg +0 -0
  43. LLAVA_Biovil/llava/eval/webpage/figures/chatgpt.svg +1 -0
  44. LLAVA_Biovil/llava/eval/webpage/figures/llama.jpg +0 -0
  45. LLAVA_Biovil/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg +1 -0
  46. LLAVA_Biovil/llava/eval/webpage/figures/vicuna.jpeg +0 -0
  47. LLAVA_Biovil/llava/eval/webpage/index.html +162 -0
  48. LLAVA_Biovil/llava/eval/webpage/script.js +245 -0
  49. LLAVA_Biovil/llava/eval/webpage/styles.css +105 -0
  50. LLAVA_Biovil/llava/mm_utils.py +148 -0
LLAVA_Biovil/.dockerignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The .dockerignore file excludes files from the container build process.
2
+ #
3
+ # https://docs.docker.com/engine/reference/builder/#dockerignore-file
4
+
5
+ # Exclude Git files
6
+ .git
7
+ .github
8
+ .gitignore
9
+
10
+ # Exclude Python cache files
11
+ __pycache__
12
+ .mypy_cache
13
+ .pytest_cache
14
+ .ruff_cache
15
+
16
+ # Exclude Python virtual environment
17
+ /venv
18
+
19
+ # Exclude some weights
20
+ /openai
21
+ /liuhaotian
LLAVA_Biovil/.editorconfig ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ root = true
2
+
3
+ # Unix-style newlines with a newline ending every file
4
+ [*]
5
+ end_of_line = lf
6
+ insert_final_newline = true
7
+ trim_trailing_whitespace = true
8
+ charset = utf-8
9
+
10
+ # 4 space indentation
11
+ [*.{py,json}]
12
+ indent_style = space
13
+ indent_size = 4
14
+
15
+ # 2 space indentation
16
+ [*.{md,sh,yaml,yml}]
17
+ indent_style = space
18
+ indent_size = 2
LLAVA_Biovil/.gitattributes ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://git-scm.com/docs/gitattributes
2
+
3
+ # Set the default behavior, in case people don't have core.autocrlf set.
4
+ # https://git-scm.com/docs/gitattributes#_end_of_line_conversion
5
+ * text=auto
6
+
7
+ # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes
8
+ # Source files
9
+ # ============
10
+ *.pxd text diff=python
11
+ *.py text diff=python
12
+ *.py3 text diff=python
13
+ *.pyw text diff=python
14
+ *.pyx text diff=python
15
+ *.pyz text diff=python
16
+ *.pyi text diff=python
17
+
18
+ # Binary files
19
+ # ============
20
+ *.db binary
21
+ *.p binary
22
+ *.pkl binary
23
+ *.pickle binary
24
+ *.pyc binary export-ignore
25
+ *.pyo binary export-ignore
26
+ *.pyd binary
27
+
28
+ # Jupyter notebook
29
+ *.ipynb text eol=lf
LLAVA_Biovil/.gitignore ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__
3
+ *.pyc
4
+ *.egg-info
5
+ dist
6
+
7
+ # Log
8
+ *.log
9
+ *.log.*
10
+ *.json
11
+ *.jsonl
12
+
13
+ # Data
14
+ !**/alpaca-data-conversation.json
15
+
16
+ # Editor
17
+ ../.idea
18
+ *.swp
19
+
20
+ # Other
21
+ .DS_Store
22
+ wandb
23
+ output
24
+
25
+ checkpoints
26
+ ckpts*
27
+
28
+ .ipynb_checkpoints
29
+ *.ipynb
30
+
31
+ # DevContainer
32
+ !.devcontainer/*
33
+
34
+ # Demo
35
+ serve_images/
LLAVA_Biovil/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LLAVA_Biovil/README.md ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🌋 LLaVA: Large Language and Vision Assistant
2
+
3
+ *Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.*
4
+
5
+ [[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)]
6
+
7
+ 🤝Community Contributions: [[llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436)] [[Colab](https://github.com/camenduru/LLaVA-colab)] [[🤗Space](https://huggingface.co/spaces/badayvedat/LLaVA)] [[Replicate](https://replicate.com/yorickvp/llava-13b)] [[AutoGen](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb)] [[BakLLaVA (LLaVA with Mistral-7B)](https://github.com/SkunkworksAI/BakLLaVA)]
8
+
9
+ **Improved Baselines with Visual Instruction Tuning** [[Paper](https://arxiv.org/abs/2310.03744)] <br>
10
+ [Haotian Liu](https://hliu.cc), [Chunyuan Li](https://chunyuan.li/), [Yuheng Li](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/)
11
+
12
+ **Visual Instruction Tuning** (NeurIPS 2023, **Oral**) [[Paper](https://arxiv.org/abs/2304.08485)]<br>
13
+ [Haotian Liu*](https://hliu.cc), [Chunyuan Li*](https://chunyuan.li/), [Qingyang Wu](https://scholar.google.ca/citations?user=HDiw-TsAAAAJ&hl=en/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/) (*Equal Contribution)
14
+
15
+ <!--p align="center">
16
+ <a href="https://llava.hliu.cc/"><img src="images/llava_logo.png" width="50%"></a> <br>
17
+ Generated by <a href="https://gligen.github.io/">GLIGEN</a> via "a cute lava llama with glasses" and box prompt
18
+ </p-->
19
+
20
+
21
+ ## Release
22
+ - [11/10] [LLaVA-Plus](https://llava-vl.github.io/llava-plus/) is released: Learning to Use Tools for Creating Multimodal Agents, with LLaVA-Plus (LLaVA that Plug and Learn to Use Skills). [[Project Page](https://llava-vl.github.io/llava-plus/)] [[Demo](https://llavaplus.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Plus-Codebase)] [[Paper](https://arxiv.org/abs/2311.05437)]
23
+ - [11/6] Support **Intel** dGPU and CPU platforms. [More details here.](https://github.com/haotian-liu/LLaVA/tree/intel/docs/intel)
24
+ - [11/2] [LLaVA-Interactive](https://llava-vl.github.io/llava-interactive/) is released: Experience the future of human-AI multimodal interaction with an all-in-one demo for Image Chat, Segmentation, Generation and Editing. [[Project Page](https://llava-vl.github.io/llava-interactive/)] [[Demo](https://llavainteractive.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Interactive-Demo)] [[Paper](https://arxiv.org/abs/2311.00571)]
25
+ - [10/26] 🔥 LLaVA-1.5 with LoRA achieves comparable performance as full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA.
26
+ - [10/12] Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)]
27
+ - [10/12] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support!
28
+ - [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
29
+ - [10/10] [Roboflow Deep Dive](https://blog.roboflow.com/first-impressions-with-llava-1-5/): First Impressions with LLaVA-1.5.
30
+ - [10/5] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
31
+ - [9/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/)
32
+ - [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
33
+ - [9/20] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
34
+ <p align="center">
35
+ <img src="https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings/blob/main/images/mfm_evolution.jpeg?raw=true" width=50%/>
36
+ </p>
37
+
38
+ - [7/19] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)!
39
+ - [6/26] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out [[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)] [[Notes](https://arxiv.org/abs/2306.14895)] [[YouTube](https://youtu.be/mkI7EPD1vp8)] [[Bilibli](https://www.bilibili.com/video/BV1Ng4y1T7v3/)].
40
+ - [6/11] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see documentations [here](./docs/LoRA.md).
41
+ - [6/1] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med).
42
+ - [5/6] We are releasing [LLaVA-Lighting-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details.
43
+ - [5/2] 🔥 We are releasing LLaVA-Lighting! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details.
44
+ - [4/27] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava).
45
+ - [4/17] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/).
46
+
47
+ <!-- <a href="https://llava.hliu.cc/"><img src="assets/demo.gif" width="70%"></a> -->
48
+
49
+ [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
50
+ [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
51
+ **Usage and License Notices**: The data and checkpoint is intended and licensed for research use only. They are also restricted to uses that follow the license agreement of LLaMA, Vicuna and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use) and models trained using the dataset should not be used outside of research purposes.
52
+
53
+
54
+ ## Contents
55
+ - [Install](#install)
56
+ - [LLaVA Weights](#llava-weights)
57
+ - [Demo](#Demo)
58
+ - [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
59
+ - [Dataset](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)
60
+ - [Train](#train)
61
+ - [Evaluation](#evaluation)
62
+
63
+ ## Install
64
+
65
+ If you are not using Linux, do *NOT* proceed, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md).
66
+
67
+ 1. Clone this repository and navigate to LLaVA folder
68
+ ```bash
69
+ git clone https://github.com/haotian-liu/LLaVA.git
70
+ cd LLaVA
71
+ ```
72
+
73
+ 2. Install Package
74
+ ```Shell
75
+ conda create -n llava python=3.10 -y
76
+ conda activate llava
77
+ pip install --upgrade pip # enable PEP 660 support
78
+ pip install -e .
79
+ ```
80
+
81
+ 3. Install additional packages for training cases
82
+ ```
83
+ pip install -e ".[train]"
84
+ pip install flash-attn --no-build-isolation
85
+ ```
86
+
87
+ ### Upgrade to latest code base
88
+
89
+ ```Shell
90
+ git pull
91
+ pip install -e .
92
+ ```
93
+
94
+ ### Quick Start With HuggingFace
95
+
96
+ <details>
97
+ <summary>Example Code</summary>
98
+
99
+ ```Python
100
+ from LLAV.llava import load_pretrained_model
101
+ from LLAV.llava import get_model_name_from_path
102
+ from LLAV.llava import eval_model
103
+
104
+ model_path = "liuhaotian/llava-v1.5-7b"
105
+
106
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
107
+ model_path=model_path,
108
+ model_base=None,
109
+ model_name=get_model_name_from_path(model_path)
110
+ )
111
+ ```
112
+
113
+ Check out the details wth the `load_pretrained_model` function in `llava/model/builder.py`.
114
+
115
+ You can also use the `eval_model` function in `llava/eval/run_llava.py` to get the output easily. By doing so, you can use this code on Colab directly after downloading this repository.
116
+
117
+ ``` python
118
+ model_path = "liuhaotian/llava-v1.5-7b"
119
+ prompt = "What are the things I should be cautious about when I visit here?"
120
+ image_file = "https://llava-vl.github.io/static/images/view.jpg"
121
+
122
+ args = type('Args', (), {
123
+ "model_path": model_path,
124
+ "model_base": None,
125
+ "model_name": get_model_name_from_path(model_path),
126
+ "query": prompt,
127
+ "conv_mode": None,
128
+ "image_file": image_file,
129
+ "sep": ",",
130
+ })()
131
+
132
+ eval_model(args)
133
+ ```
134
+ </details>
135
+
136
+ ## LLaVA Weights
137
+ Please check out our [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) for all public LLaVA checkpoints, and the instructions of how to use the weights.
138
+
139
+ ## Demo
140
+
141
+ To run our demo, you need to prepare LLaVA checkpoints locally. Please follow the instructions [here](#llava-weights) to download the checkpoints.
142
+
143
+ ### Gradio Web UI
144
+
145
+ To launch a Gradio demo locally, please run the following commands one by one. If you plan to launch multiple model workers to compare between different checkpoints, you only need to launch the controller and the web server *ONCE*.
146
+
147
+ ```mermaid
148
+ flowchart BT
149
+ %% Declare Nodes
150
+ gws("Gradio (UI Server)")
151
+ c("Controller (API Server):<br/>PORT: 10000")
152
+ mw7b("Model Worker:<br/>llava-v1.5-7b<br/>PORT: 40000")
153
+ mw13b("Model Worker:<br/>llava-v1.5-13b<br/>PORT: 40001")
154
+
155
+ %% Declare Styles
156
+ classDef data fill:#3af,stroke:#48a,stroke-width:2px,color:#444
157
+ classDef success fill:#8f8,stroke:#0a0,stroke-width:2px,color:#444
158
+ classDef failure fill:#f88,stroke:#f00,stroke-width:2px,color:#444
159
+
160
+ %% Assign Styles
161
+ class id,od data;
162
+ class cimg,cs_s,scsim_s success;
163
+ class ncimg,cs_f,scsim_f failure;
164
+
165
+ subgraph Demo Connections
166
+ direction BT
167
+ c<-->gws
168
+
169
+ mw7b<-->c
170
+ mw13b<-->c
171
+ end
172
+ ```
173
+
174
+ #### Launch a controller
175
+ ```Shell
176
+ python -m llava.serve.controller --host 0.0.0.0 --port 10000
177
+ ```
178
+
179
+ #### Launch a gradio web server.
180
+ ```Shell
181
+ python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload
182
+ ```
183
+ You just launched the Gradio web interface. Now, you can open the web interface with the URL printed on the screen. You may notice that there is no model in the model list. Do not worry, as we have not launched any model worker yet. It will be automatically updated when you launch a model worker.
184
+
185
+ #### Launch a model worker
186
+
187
+ This is the actual *worker* that performs the inference on the GPU. Each worker is responsible for a single model specified in `--model-path`.
188
+
189
+ ```Shell
190
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b
191
+ ```
192
+ Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now, refresh your Gradio web UI, and you will see the model you just launched in the model list.
193
+
194
+ You can launch as many workers as you want, and compare between different model checkpoints in the same Gradio interface. Please keep the `--controller` the same, and modify the `--port` and `--worker` to a different port number for each worker.
195
+ ```Shell
196
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port <different from 40000, say 40001> --worker http://localhost:<change accordingly, i.e. 40001> --model-path <ckpt2>
197
+ ```
198
+
199
+ If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the `--device` flag: `--device mps`.
200
+
201
+ #### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB)
202
+
203
+ If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090, etc.), you may try running it with multiple GPUs. Our latest code base will automatically try to use multiple GPUs if you have more than one GPU. You can specify which GPUs to use with `CUDA_VISIBLE_DEVICES`. Below is an example of running with the first two GPUs.
204
+
205
+ ```Shell
206
+ CUDA_VISIBLE_DEVICES=0,1 python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b
207
+ ```
208
+
209
+ #### Launch a model worker (4-bit, 8-bit inference, quantized)
210
+
211
+ You can launch the model worker with quantized bits (4-bit, 8-bit), which allows you to run the inference with reduced GPU memory footprint, potentially allowing you to run on a GPU with as few as 12GB VRAM. Note that inference with quantized bits may not be as accurate as the full-precision model. Simply append `--load-4bit` or `--load-8bit` to the **model worker** command that you are executing. Below is an example of running with 4-bit quantization.
212
+
213
+ ```Shell
214
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b --load-4bit
215
+ ```
216
+
217
+ #### Launch a model worker (LoRA weights, unmerged)
218
+
219
+ You can launch the model worker with LoRA weights, without merging them with the base checkpoint, to save disk space. There will be additional loading time, while the inference speed is the same as the merged checkpoints. Unmerged LoRA checkpoints do not have `lora-merge` in the model name, and are usually much smaller (less than 1GB) than the merged checkpoints (13G for 7B, and 25G for 13B).
220
+
221
+ To load unmerged LoRA weights, you simply need to pass an additional argument `--model-base`, which is the base LLM that is used to train the LoRA weights. You can check the base LLM of each LoRA weights in the [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
222
+
223
+ ```Shell
224
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1-0719-336px-lora-vicuna-13b-v1.3 --model-base lmsys/vicuna-13b-v1.3
225
+ ```
226
+
227
+ ### CLI Inference
228
+
229
+ Chat about images using LLaVA without the need of Gradio interface. It also supports multiple GPUs, 4-bit and 8-bit quantized inference. With 4-bit quantization, for our LLaVA-1.5-7B, it uses less than 8GB VRAM on a single GPU.
230
+
231
+ ```Shell
232
+ python -m llava.serve.cli \
233
+ --model-path liuhaotian/llava-v1.5-7b \
234
+ --image-file "https://llava-vl.github.io/static/images/view.jpg" \
235
+ --load-4bit
236
+ ```
237
+
238
+ <img src="images/demo_cli.gif" width="70%">
239
+
240
+ ## Train
241
+
242
+ *Below is the latest training configuration for LLaVA v1.5. For legacy models, please refer to README of [this](https://github.com/haotian-liu/LLaVA/tree/v1.0.1) version for now. We'll add them in a separate doc later.*
243
+
244
+ LLaVA training consists of two stages: (1) feature alignment stage: use our 558K subset of the LAION-CC-SBU dataset to connect a *frozen pretrained* vision encoder to a *frozen LLM*; (2) visual instruction tuning stage: use 150K GPT-generated multimodal instruction-following data, plus around 515K VQA data from academic-oriented tasks, to teach the model to follow multimodal instructions.
245
+
246
+ LLaVA is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
247
+
248
+ ### Hyperparameters
249
+ We use a similar set of hyperparameters as Vicuna in finetuning. Both hyperparameters used in pretraining and finetuning are provided below.
250
+
251
+ 1. Pretraining
252
+
253
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
254
+ | --- | ---: | ---: | ---: | ---: | ---: |
255
+ | LLaVA-v1.5-13B | 256 | 1e-3 | 1 | 2048 | 0 |
256
+
257
+ 2. Finetuning
258
+
259
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
260
+ | --- | ---: | ---: | ---: | ---: | ---: |
261
+ | LLaVA-v1.5-13B | 128 | 2e-5 | 1 | 2048 | 0 |
262
+
263
+ ### Download Vicuna checkpoints (automatically)
264
+
265
+ Our base model Vicuna v1.5, which is an instruction-tuned chatbot, will be downloaded automatically when you run our provided training scripts. No action is needed.
266
+
267
+ ### Pretrain (feature alignment)
268
+
269
+ Please download the 558K subset of the LAION-CC-SBU dataset with BLIP captions we use in the paper [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain).
270
+
271
+ Pretrain takes around 5.5 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 3.5 hours for LLaVA-v1.5-7B.
272
+
273
+ Training script with DeepSpeed ZeRO-2: [`pretrain.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/pretrain.sh).
274
+
275
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
276
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
277
+
278
+ <details>
279
+ <summary>Pretrain takes around 20 hours for LLaVA-7B on 8x V100 (32G)</summary>
280
+
281
+ We provide training script with DeepSpeed [here](https://github.com/haotian-liu/LLaVA/blob/main/scripts/pretrain_xformers.sh).
282
+ Tips:
283
+ - If you are using V100 which is not supported by FlashAttention, you can use the [memory-efficient attention](https://arxiv.org/abs/2112.05682) implemented in [xFormers](https://github.com/facebookresearch/xformers). Install xformers and replace `llava/train/train_mem.py` above with [llava/train/train_xformers.py](LLAV/llava/train/train_xformers.py).
284
+ </details>
285
+
286
+ ### Visual Instruction Tuning
287
+
288
+ 1. Prepare data
289
+
290
+ Please download the annotation of the final mixture our instruction tuning data [llava_v1_5_mix665k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json), and download the images from constituting datasets:
291
+
292
+ - COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip)
293
+ - GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip)
294
+ - OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing), **we save all files as `.jpg`**
295
+ - TextVQA: [train_val_images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip)
296
+ - VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip)
297
+
298
+ After downloading all of them, organize the data as follows in `./playground/data`,
299
+
300
+ ```
301
+ ├── coco
302
+ │ └── train2017
303
+ ├── gqa
304
+ │ └── images
305
+ ├── ocr_vqa
306
+ │ └── images
307
+ ├── textvqa
308
+ │ └── train_images
309
+ └── vg
310
+ ├── VG_100K
311
+ └── VG_100K_2
312
+ ```
313
+
314
+ 2. Start training!
315
+
316
+ You may download our pretrained projectors in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). It is not recommended to use legacy projectors, as they may be trained with a different version of the codebase, and if any option is off, the model will not function/train as we expected.
317
+
318
+ Visual instruction tuning takes around 20 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 10 hours for LLaVA-v1.5-7B on 8x A100 (40G).
319
+
320
+ Training script with DeepSpeed ZeRO-3: [`finetune.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune.sh).
321
+
322
+ If you are do not have enough GPU memory:
323
+
324
+ - Use LoRA: [`finetune_lora.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune_lora.sh). We are able to fit 13B training in 8-A100-40G/8-A6000, and 7B training in 8-RTX3090. Make sure `per_device_train_batch_size*gradient_accumulation_steps` is the same as the provided script for best reproducibility.
325
+ - Replace `zero3.json` with `zero3_offload.json` which offloads some parameters to CPU RAM. This slows down the training speed.
326
+
327
+ If you are interested in finetuning LLaVA model to your own task/data, please check out [`Finetune_Custom_Data.md`](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md)。
328
+
329
+ New options to note:
330
+
331
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
332
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
333
+ - `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
334
+ - `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct). It makes the training sampler only sample a single modality (either image or language) during training, which we observe to speed up training by ~25%, and does not affect the final outcome.
335
+
336
+ ## Evaluation
337
+
338
+ In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs.
339
+
340
+ See [Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md).
341
+
342
+ ### GPT-assisted Evaluation
343
+
344
+ Our GPT-assisted evaluation pipeline for multimodal modeling is provided for a comprehensive understanding of the capabilities of vision-language models. Please see our paper for more details.
345
+
346
+ 1. Generate LLaVA responses
347
+
348
+ ```Shell
349
+ python model_vqa.py \
350
+ --model-path ./checkpoints/LLaVA-13B-v0 \
351
+ --question-file \
352
+ playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \
353
+ --image-folder \
354
+ /path/to/coco2014_val \
355
+ --answers-file \
356
+ /path/to/answer-file-our.jsonl
357
+ ```
358
+
359
+ 2. Evaluate the generated responses. In our case, [`answer-file-ref.jsonl`](./playground/data/coco2014_val_qa_eval/qa90_gpt4_answer.jsonl) is the response generated by text-only GPT-4 (0314), with the context captions/boxes provided.
360
+
361
+ ```Shell
362
+ OPENAI_API_KEY="sk-***********************************" python llava/eval/eval_gpt_review_visual.py \
363
+ --question playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \
364
+ --context llava/eval/table/caps_boxes_coco2014_val_80.jsonl \
365
+ --answer-list \
366
+ /path/to/answer-file-ref.jsonl \
367
+ /path/to/answer-file-our.jsonl \
368
+ --rule llava/eval/table/rule.json \
369
+ --output /path/to/review.json
370
+ ```
371
+
372
+ 3. Summarize the evaluation results
373
+
374
+ ```Shell
375
+ python summarize_gpt_review.py
376
+ ```
377
+
378
+ ## Citation
379
+
380
+ If you find LLaVA useful for your research and applications, please cite using this BibTeX:
381
+ ```bibtex
382
+
383
+ @misc{liu2023improvedllava,
384
+ title={Improved Baselines with Visual Instruction Tuning},
385
+ author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae},
386
+ publisher={arXiv:2310.03744},
387
+ year={2023},
388
+ }
389
+
390
+ @misc{liu2023llava,
391
+ title={Visual Instruction Tuning},
392
+ author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
393
+ publisher={arXiv:2304.08485},
394
+ year={2023},
395
+ }
396
+ ```
397
+
398
+ ## Acknowledgement
399
+
400
+ - [Vicuna](https://github.com/lm-sys/FastChat): the codebase we built upon, and our base model Vicuna-13B that has the amazing language capabilities!
401
+
402
+ ## Related Projects
403
+
404
+ - [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
405
+ - [LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day](https://github.com/microsoft/LLaVA-Med)
406
+ - [Otter: In-Context Multi-Modal Instruction Tuning](https://github.com/Luodian/Otter)
407
+
408
+ For future project ideas, please check out:
409
+ - [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
410
+ - [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect, segment, and generate anything by marrying [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment-Anything](https://github.com/facebookresearch/segment-anything).
LLAVA_Biovil/__init__.py ADDED
File without changes
LLAVA_Biovil/biovil_t/__init__.py ADDED
File without changes
LLAVA_Biovil/biovil_t/encoder.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # -------------------------------------------------------------------------------------------
5
+
6
+ from __future__ import annotations
7
+
8
+ from contextlib import contextmanager
9
+ from typing import Any, Generator, Optional, Sequence, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from health_multimodal.common.device import get_module_device
14
+ from timm.models.layers import trunc_normal_
15
+
16
+ from .resnet import resnet18, resnet50
17
+ from .transformer import VisionTransformerPooler
18
+ from .types import ImageEncoderType
19
+
20
+ DEFAULT_DILATION_VALUES_FOR_RESNET = (False, False, True)
21
+ ImageEncoderOutputType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
22
+
23
+
24
+ class ImageEncoder(nn.Module):
25
+ """Image encoder trunk module for the ``ImageModel`` class.
26
+
27
+ :param img_encoder_type : Type of image encoder model to use, either ``"resnet18_multi_image"`` or
28
+ ``"resnet50_multi_image"``.
29
+ """
30
+
31
+ def __init__(self, img_encoder_type: str):
32
+ super().__init__()
33
+ self.img_encoder_type = img_encoder_type
34
+ self.encoder = self._create_encoder()
35
+
36
+ def _create_encoder(self, **kwargs: Any) -> nn.Module:
37
+ if self.img_encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET18_MULTI_IMAGE]:
38
+ encoder_class = resnet18
39
+ elif self.img_encoder_type in [ImageEncoderType.RESNET50, ImageEncoderType.RESNET50_MULTI_IMAGE]:
40
+ encoder_class = resnet50
41
+ else:
42
+ supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
43
+ raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")
44
+
45
+ encoder = encoder_class(pretrained=True, **kwargs)
46
+
47
+ return encoder
48
+
49
+ def forward(self,
50
+ current_image: torch.Tensor,
51
+ return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
52
+ """Get image global and patch embeddings"""
53
+
54
+ patch_emb = self.encoder(current_image)
55
+ avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)
56
+ if return_patch_embeddings:
57
+ return patch_emb, avg_pooled_emb
58
+
59
+ return avg_pooled_emb
60
+
61
+ def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
62
+ """Workaround for enabling dilated convolutions after model initialization.
63
+
64
+ :param replace_stride_with_dilation: Replace the 2x2 standard convolution stride with a dilated convolution
65
+ in each layer in the last three blocks of ResNet architecture.
66
+ """
67
+ if self.img_encoder_type == ImageEncoderType.RESNET18:
68
+ # resnet18 uses BasicBlock implementation, which does not support dilated convolutions.
69
+ raise NotImplementedError("resnet18 does not support dilated convolutions")
70
+
71
+ if replace_stride_with_dilation is None:
72
+ replace_stride_with_dilation = DEFAULT_DILATION_VALUES_FOR_RESNET
73
+
74
+ device = next(self.encoder.parameters()).device
75
+ new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)
76
+
77
+ if self.encoder.training:
78
+ new_encoder.train()
79
+ else:
80
+ new_encoder.eval()
81
+
82
+ new_encoder.load_state_dict(self.encoder.state_dict())
83
+ self.encoder = new_encoder
84
+
85
+
86
+ class MultiImageEncoder(ImageEncoder):
87
+ """Multi-image encoder trunk module for the ``ImageModel`` class.
88
+ It can be used to encode multiple images into combined latent representation.
89
+ Currently it only supports two input images but can be extended to support more in future.
90
+
91
+ :param img_encoder_type: Type of image encoder model to use: either ``"resnet18"`` or ``"resnet50"``.
92
+ """
93
+
94
+ def __init__(self, img_encoder_type: str):
95
+ super().__init__(img_encoder_type)
96
+
97
+ output_dim = 256 # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff]
98
+ grid_shape = (14, 14) # Spatial dimensions of patch grid.
99
+
100
+ backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))
101
+
102
+ self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
103
+ kernel_size=1, stride=1, padding=0, bias=False)
104
+ self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=grid_shape)
105
+
106
+ # Missing image embedding
107
+ self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
108
+ trunc_normal_(self.missing_previous_emb, std=.02)
109
+
110
+ def forward(self, # type: ignore[override]
111
+ current_image: torch.Tensor,
112
+ previous_image: Optional[torch.Tensor] = None,
113
+ return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
114
+
115
+ batch_size = current_image.shape[0]
116
+
117
+ if previous_image is not None:
118
+ assert current_image.shape == previous_image.shape
119
+ x = torch.cat([current_image, previous_image], dim=0)
120
+ x = super().forward(x, return_patch_embeddings=True)[0]
121
+ x = self.backbone_to_vit(x)
122
+ patch_x, patch_x_previous = x[:batch_size], x[batch_size:]
123
+ diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_x_previous)
124
+ else:
125
+ x = super().forward(current_image, return_patch_embeddings=True)[0]
126
+ patch_x = self.backbone_to_vit(x)
127
+ B, _, W, H = patch_x.shape
128
+ diff_x = self.missing_previous_emb.repeat(B, 1, W, H)
129
+
130
+ patch_fused = torch.cat([patch_x, diff_x], dim=1)
131
+ avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)
132
+
133
+ if return_patch_embeddings:
134
+ return patch_fused, avg_pooled_emb
135
+
136
+ return avg_pooled_emb
137
+
138
+ def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
139
+ raise NotImplementedError
140
+
141
+
142
+ @torch.no_grad()
143
+ def get_encoder_output_dim(module: torch.nn.Module, device: torch.device) -> int:
144
+ """Calculate the output dimension of an encoder by making a single forward pass.
145
+
146
+ :param module: Encoder module.
147
+ :param device: Compute device to use.
148
+ """
149
+ # Target device
150
+ assert isinstance(device, torch.device)
151
+
152
+ x = torch.rand((1, 3, 448, 448)).to(device)
153
+
154
+ # Extract the number of output feature dimensions
155
+ with restore_training_mode(module):
156
+ module.eval()
157
+ representations = module(x)
158
+ return representations.shape[1]
159
+
160
+
161
+ @contextmanager
162
+ def restore_training_mode(module: nn.Module) -> Generator[None, None, None]:
163
+ """Restore the training mode of a module after some operation.
164
+
165
+ :param module: PyTorch module.
166
+ """
167
+ training_mode = module.training
168
+ yield
169
+ module.train(mode=training_mode)
170
+
171
+
172
+ def get_encoder_from_type(img_encoder_type: str) -> ImageEncoder:
173
+ """Returns the encoder class for the given encoder type.
174
+
175
+ :param img_encoder_type: Encoder type. {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}
176
+ """
177
+ if img_encoder_type in ImageEncoderType.get_members(multi_image_encoders_only=True):
178
+ return MultiImageEncoder(img_encoder_type=img_encoder_type)
179
+ else:
180
+ return ImageEncoder(img_encoder_type=img_encoder_type)
LLAVA_Biovil/biovil_t/model.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # -------------------------------------------------------------------------------------------
5
+
6
+ from __future__ import annotations
7
+
8
+ from abc import ABC, abstractmethod
9
+ from pathlib import Path
10
+ from typing import Any, Optional, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from health_multimodal.common.device import get_module_device
16
+
17
+ from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder
18
+ from .modules import MLP, MultiTaskModel
19
+ from .types import ImageModelOutput
20
+
21
+
22
+ class BaseImageModel(nn.Module, ABC):
23
+ """Abstract class for image models."""
24
+ @abstractmethod
25
+ def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput:
26
+ raise NotImplementedError
27
+
28
+ @abstractmethod
29
+ def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
30
+ raise NotImplementedError
31
+
32
+
33
+ class ImageModel(BaseImageModel):
34
+ """Image encoder module"""
35
+
36
+ def __init__(self,
37
+ img_encoder_type: str,
38
+ joint_feature_size: int,
39
+ freeze_encoder: bool = False,
40
+ pretrained_model_path: Optional[Union[str, Path]] = None,
41
+ **downstream_classifier_kwargs: Any):
42
+ super().__init__()
43
+
44
+ # Initiate encoder, projector, and classifier
45
+ self.encoder = get_encoder_from_type(img_encoder_type)
46
+ self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
47
+ self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
48
+ hidden_dim=joint_feature_size, use_1x1_convs=True)
49
+ self.downstream_classifier_kwargs = downstream_classifier_kwargs
50
+ self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None
51
+
52
+ # Initialise the mode of modules
53
+ self.freeze_encoder = freeze_encoder
54
+ self.train()
55
+
56
+ self.image_processor = None #TODO
57
+
58
+ if pretrained_model_path is not None:
59
+ if not isinstance(pretrained_model_path, (str, Path)):
60
+ raise TypeError(f"Expected a string or Path, got {type(pretrained_model_path)}")
61
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
62
+ # drop projector
63
+ # for k in list(state_dict.keys()):
64
+ # if k.startswith("projector"):
65
+ # state_dict.pop(k)
66
+
67
+ self.load_state_dict(state_dict, strict=False)
68
+
69
+
70
+ def train(self, mode: bool = True) -> Any:
71
+ """Switch the model between training and evaluation modes."""
72
+ super().train(mode=mode)
73
+ if self.freeze_encoder:
74
+ self.encoder.train(mode=False)
75
+ self.projector.train(mode=False)
76
+ return self
77
+
78
+ def forward(self, x: torch.Tensor) -> ImageModelOutput: # type: ignore[override]
79
+ with torch.set_grad_enabled(not self.freeze_encoder):
80
+ patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
81
+ return self.forward_post_encoder(patch_x, pooled_x)
82
+
83
+ def forward_post_encoder(self, patch_x: torch.Tensor, pooled_x: torch.Tensor) -> ImageModelOutput:
84
+ with torch.set_grad_enabled(not self.freeze_encoder):
85
+ projected_patch_embeddings = self.projector(patch_x)
86
+ projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))
87
+
88
+ logits = self.classifier(pooled_x) if self.classifier else None
89
+ return ImageModelOutput(img_embedding=pooled_x,
90
+ patch_embeddings=patch_x,
91
+ class_logits=logits,
92
+ projected_patch_embeddings=projected_patch_embeddings,
93
+ projected_global_embedding=projected_global_embedding)
94
+
95
+ def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
96
+ """Create the classification module for the downstream task."""
97
+ downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
98
+ return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)
99
+
100
+ @torch.no_grad()
101
+ def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
102
+ """Get patch-wise projected embeddings from the CNN model.
103
+
104
+ :param input_img: input tensor image [B, C, H, W].
105
+ :param normalize: If ``True``, the embeddings are L2-normalized.
106
+ :returns projected_embeddings: tensor of embeddings in shape [batch, n_patches_h, n_patches_w, feature_size].
107
+ """
108
+ assert not self.training, "This function is only implemented for evaluation mode"
109
+ outputs = self.forward(input_img)
110
+ projected_embeddings = outputs.projected_patch_embeddings.detach() # type: ignore
111
+ if normalize:
112
+ projected_embeddings = F.normalize(projected_embeddings, dim=1)
113
+ projected_embeddings = projected_embeddings.permute([0, 2, 3, 1]) # B D H W -> B H W D (D: Features)
114
+ return projected_embeddings
115
+
116
+
117
+ class MultiImageModel(ImageModel):
118
+ def __init__(self, **kwargs: Any) -> None:
119
+ super().__init__(**kwargs)
120
+ assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"
121
+
122
+ def forward(self, # type: ignore[override]
123
+ current_image: torch.Tensor,
124
+ previous_image: Optional[torch.Tensor] = None) -> ImageModelOutput:
125
+
126
+ with torch.set_grad_enabled(not self.freeze_encoder):
127
+ patch_x, pooled_x = self.encoder(current_image=current_image,
128
+ previous_image=previous_image,
129
+ return_patch_embeddings=True)
130
+ return self.forward_post_encoder(patch_x, pooled_x)
LLAVA_Biovil/biovil_t/modules.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # -------------------------------------------------------------------------------------------
5
+
6
+ from typing import Callable, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ class MLP(nn.Module):
13
+ """
14
+ Fully connected layers to map between image embeddings and projection space where pairs of images are compared.
15
+
16
+ :param input_dim: Input embedding feature size
17
+ :param hidden_dim: Hidden layer size in MLP
18
+ :param output_dim: Output projection size
19
+ :param use_1x1_convs: Use 1x1 conv kernels instead of 2D linear transformations for speed and memory efficiency.
20
+ """
21
+
22
+ def __init__(self,
23
+ input_dim: int,
24
+ output_dim: int,
25
+ hidden_dim: Optional[int] = None,
26
+ use_1x1_convs: bool = False) -> None:
27
+ super().__init__()
28
+
29
+ if use_1x1_convs:
30
+ linear_proj_1_args = {'in_channels': input_dim, 'out_channels': hidden_dim, 'kernel_size': 1, 'bias': False}
31
+ linear_proj_2_args = {'in_channels': hidden_dim, 'out_channels': output_dim, 'kernel_size': 1, 'bias': True}
32
+ normalisation_layer: Callable = nn.BatchNorm2d
33
+ projection_layer: Callable = nn.Conv2d
34
+ else:
35
+ linear_proj_1_args = {'in_features': input_dim, 'out_features': hidden_dim, 'bias': False}
36
+ linear_proj_2_args = {'in_features': hidden_dim, 'out_features': output_dim, 'bias': True}
37
+ normalisation_layer = nn.BatchNorm1d
38
+ projection_layer = nn.Linear
39
+
40
+ self.output_dim = output_dim
41
+ self.input_dim = input_dim
42
+ if hidden_dim is not None:
43
+ self.model = nn.Sequential(
44
+ projection_layer(**linear_proj_1_args),
45
+ normalisation_layer(hidden_dim),
46
+ nn.ReLU(inplace=True),
47
+ projection_layer(**linear_proj_2_args))
48
+ else:
49
+ self.model = nn.Linear(input_dim, output_dim) # type: ignore
50
+
51
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
52
+ """forward pass of the multi-layer perceptron"""
53
+ x = self.model(x)
54
+ return x
55
+
56
+
57
+ class MultiTaskModel(nn.Module):
58
+ """Torch module for multi-task classification heads. We create a separate classification head
59
+ for each task and perform a forward pass on each head independently in forward(). Classification
60
+ heads are instances of `MLP`.
61
+
62
+ :param input_dim: Number of dimensions of the input feature map.
63
+ :param classifier_hidden_dim: Number of dimensions of hidden features in the MLP.
64
+ :param num_classes: Number of output classes per task.
65
+ :param num_tasks: Number of classification tasks or heads required.
66
+ """
67
+
68
+ def __init__(self, input_dim: int, classifier_hidden_dim: Optional[int], num_classes: int, num_tasks: int):
69
+
70
+ super().__init__()
71
+
72
+ self.num_classes = num_classes
73
+ self.num_tasks = num_tasks
74
+
75
+ for task in range(num_tasks):
76
+ setattr(self, "fc_" + str(task), MLP(input_dim, output_dim=num_classes, hidden_dim=classifier_hidden_dim))
77
+
78
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
79
+ """Returns [batch_size, num_tasks, num_classes] tensor of logits."""
80
+ batch_size = x.shape[0]
81
+ out = torch.zeros((batch_size, self.num_classes, self.num_tasks), dtype=x.dtype, device=x.device)
82
+ for task in range(self.num_tasks):
83
+ classifier = getattr(self, "fc_" + str(task))
84
+ out[:, :, task] = classifier(x)
85
+ return out
LLAVA_Biovil/biovil_t/pretrained.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # -------------------------------------------------------------------------------------------
5
+
6
+ from __future__ import annotations
7
+
8
+ import tempfile
9
+ from pathlib import Path
10
+
11
+ from torchvision.datasets.utils import download_url
12
+
13
+ from .model import ImageModel
14
+ from .types import ImageEncoderType
15
+
16
+
17
+ JOINT_FEATURE_SIZE = 128
18
+
19
+ BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
20
+ BIOMED_VLP_BIOVIL_T = "microsoft/BiomedVLP-BioViL-T"
21
+ HF_URL = "https://huggingface.co"
22
+
23
+ CXR_BERT_COMMIT_TAG = "v1.1"
24
+ BIOVIL_T_COMMIT_TAG = "v1.0"
25
+
26
+ BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
27
+ BIOVIL_IMAGE_WEIGHTS_URL = f"{HF_URL}/{BIOMED_VLP_CXR_BERT_SPECIALIZED}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}" # noqa: E501
28
+ BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"
29
+
30
+ BIOVIL_T_IMAGE_WEIGHTS_NAME = "biovil_t_image_model_proj_size_128.pt"
31
+ BIOVIL_T_IMAGE_WEIGHTS_URL = f"{HF_URL}/{BIOMED_VLP_BIOVIL_T}/resolve/{BIOVIL_T_COMMIT_TAG}/{BIOVIL_T_IMAGE_WEIGHTS_NAME}" # noqa: E501
32
+ BIOVIL_T_IMAGE_WEIGHTS_MD5 = "a83080e2f23aa584a4f2b24c39b1bb64"
33
+
34
+
35
+ def _download_biovil_image_model_weights() -> Path:
36
+ """Download image model weights from Hugging Face.
37
+
38
+ More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
39
+ """
40
+ root_dir = tempfile.gettempdir()
41
+ download_url(
42
+ BIOVIL_IMAGE_WEIGHTS_URL,
43
+ root=root_dir,
44
+ filename=BIOVIL_IMAGE_WEIGHTS_NAME,
45
+ md5=BIOVIL_IMAGE_WEIGHTS_MD5,
46
+ )
47
+ return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)
48
+
49
+
50
+ def _download_biovil_t_image_model_weights() -> Path:
51
+ """Download image model weights from Hugging Face.
52
+
53
+ More information available at https://huggingface.co/microsoft/microsoft/BiomedVLP-BioViL-T.
54
+ """
55
+ root_dir = tempfile.gettempdir()
56
+ download_url(
57
+ BIOVIL_T_IMAGE_WEIGHTS_URL,
58
+ root=root_dir,
59
+ filename=BIOVIL_T_IMAGE_WEIGHTS_NAME,
60
+ md5=BIOVIL_T_IMAGE_WEIGHTS_MD5
61
+ )
62
+ return Path(root_dir, BIOVIL_T_IMAGE_WEIGHTS_NAME)
63
+
64
+
65
+ def get_biovil_image_encoder(pretrained: bool = True) -> ImageModel:
66
+ """Download weights from Hugging Face and instantiate the image model."""
67
+ resnet_checkpoint_path = _download_biovil_image_model_weights() if pretrained else None
68
+
69
+ image_model = ImageModel(
70
+ img_encoder_type=ImageEncoderType.RESNET50,
71
+ joint_feature_size=JOINT_FEATURE_SIZE,
72
+ pretrained_model_path=resnet_checkpoint_path,
73
+ )
74
+ return image_model
75
+
76
+
77
+ def get_biovil_t_image_encoder() -> ImageModel:
78
+ """Download weights from Hugging Face and instantiate the image model."""
79
+
80
+ biovilt_checkpoint_path = _download_biovil_t_image_model_weights()
81
+ model_type = ImageEncoderType.RESNET50_MULTI_IMAGE
82
+ image_model = ImageModel(img_encoder_type=model_type,
83
+ joint_feature_size=JOINT_FEATURE_SIZE,
84
+ pretrained_model_path=biovilt_checkpoint_path)
85
+ return image_model
LLAVA_Biovil/biovil_t/resnet.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # -------------------------------------------------------------------------------------------
5
+
6
+ from typing import Any, List, Tuple, Type, Union
7
+
8
+ import torch
9
+ from torch.hub import load_state_dict_from_url
10
+ from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
11
+
12
+ TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
13
+
14
+
15
+ class ResNetHIML(ResNet):
16
+ """Wrapper class of the original torchvision ResNet model.
17
+
18
+ The forward function is updated to return the penultimate layer
19
+ activations, which are required to obtain image patch embeddings.
20
+ """
21
+
22
+ def __init__(self, **kwargs: Any) -> None:
23
+ super().__init__(**kwargs)
24
+
25
+ def forward(self, x: torch.Tensor,
26
+ return_intermediate_layers: bool = False) -> Union[torch.Tensor, TypeSkipConnections]:
27
+ """ResNetHIML forward pass. Optionally returns intermediate layers using the
28
+ ``return_intermediate_layers`` argument.
29
+
30
+ :param return_intermediate_layers: If ``True``, return layers x0-x4 as a tuple,
31
+ otherwise return x4 only.
32
+ """
33
+
34
+ x0 = self.conv1(x)
35
+ x0 = self.bn1(x0)
36
+ x0 = self.relu(x0)
37
+ x0 = self.maxpool(x0)
38
+
39
+ x1 = self.layer1(x0)
40
+ x2 = self.layer2(x1)
41
+ x3 = self.layer3(x2)
42
+ x4 = self.layer4(x3)
43
+
44
+ if return_intermediate_layers:
45
+ return x0, x1, x2, x3, x4
46
+ else:
47
+ return x4
48
+
49
+
50
+ def _resnet(arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int],
51
+ pretrained: bool, progress: bool, **kwargs: Any) -> ResNetHIML:
52
+ """Instantiate a custom :class:`ResNet` model.
53
+
54
+ Adapted from :mod:`torchvision.models.resnet`.
55
+ """
56
+ model = ResNetHIML(block=block, layers=layers, **kwargs)
57
+ if pretrained:
58
+ state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet50-19c8e357.pth', progress=progress)
59
+ model.load_state_dict(state_dict)
60
+ return model
61
+
62
+
63
+ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
64
+ r"""ResNet-18 model from
65
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
66
+
67
+ :param pretrained: If ``True``, returns a model pre-trained on ImageNet.
68
+ :param progress: If ``True``, displays a progress bar of the download to ``stderr``.
69
+ """
70
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
71
+
72
+
73
+ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
74
+ r"""ResNet-50 model from
75
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
76
+
77
+ :param pretrained: If ``True``, returns a model pre-trained on ImageNet
78
+ :param progress: If ``True``, displays a progress bar of the download to ``stderr``.
79
+ """
80
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
LLAVA_Biovil/biovil_t/transformer.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # -------------------------------------------------------------------------------------------
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ from typing import Any, Callable, Optional, Set, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from timm.models.layers import DropPath, Mlp, trunc_normal_
14
+
15
+
16
+ def torch_int_div(tensor1, tensor2):
17
+ """
18
+ A function that performs integer division across different versions of PyTorch.
19
+ """
20
+ return torch.div(tensor1, tensor2, rounding_mode="floor")
21
+
22
+ @dataclass
23
+ class MultiHeadAttentionOutput:
24
+ mha_output: torch.Tensor
25
+ attention: Optional[torch.Tensor] = None
26
+
27
+
28
+ class VisionTransformerPooler(nn.Module):
29
+ """
30
+ :param input_dim: Input feature dimension (i.e., channels in old CNN terminology)
31
+ :param grid_shape: Shape of the grid of patches per image
32
+ :param num_heads: Number of self-attention heads within the MHA block
33
+ :param num_blocks: Number of blocks per attention layer
34
+ :param norm_layer: Normalisation layer
35
+
36
+ `self.type_embed`: Is used to characterise prior and current scans, and
37
+ create permutation variance across modalities/series.
38
+ """
39
+
40
+ def __init__(self,
41
+ input_dim: int,
42
+ grid_shape: Tuple[int, int],
43
+ num_heads: int = 8,
44
+ num_blocks: int = 3,
45
+ norm_layer: Any = partial(nn.LayerNorm, eps=1e-6)):
46
+ super().__init__()
47
+
48
+ block_kwargs = dict(dim=input_dim, num_heads=num_heads, mlp_ratio=1., drop=0.10, attn_drop=0.10,
49
+ drop_path=0.25, act_layer=nn.GELU, norm_layer=norm_layer)
50
+ self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
51
+ self.norm_post = norm_layer(input_dim)
52
+ self.grid_shape = grid_shape
53
+ self.num_patches = grid_shape[0] * grid_shape[1]
54
+ self.num_blocks = num_blocks
55
+
56
+ # Temporal positional embeddings
57
+ num_series: int = 2
58
+ self.type_embed = nn.Parameter(torch.zeros(num_series, 1, input_dim))
59
+ trunc_normal_(self.type_embed, std=.02)
60
+
61
+ # Positional embeddings 1 x L x C (L: Sequence length, C: Feature dimension)
62
+ self.pos_drop = nn.Dropout(p=0.10)
63
+ pos_embed_class = SinePositionEmbedding(embedding_dim=input_dim // 2, normalize=True)
64
+ pos_embed = pos_embed_class(mask=torch.ones([1, grid_shape[0], grid_shape[1]])) # 1 x L x C
65
+ self.register_buffer("pos_embed", pos_embed, persistent=False)
66
+
67
+ # Initialisation
68
+ self.apply(self._init_weights)
69
+
70
+ def no_weight_decay(self) -> Set[str]:
71
+ return {'type_embed'}
72
+
73
+ def forward(self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None) -> torch.Tensor:
74
+ B, C, H, W = current_image.shape
75
+ assert H == self.grid_shape[0] and W == self.grid_shape[1], "Input and grid shapes do not match"
76
+
77
+ # Flatten patch embeddings to have shape (B x L x C), L = H * W
78
+ if previous_image is not None:
79
+ assert previous_image.shape == current_image.shape, "current_image and previous_image shapes do not match"
80
+ previous_image = previous_image.view(B, C, H * W).transpose(1, 2)
81
+ current_image = current_image.view(B, C, H * W).transpose(1, 2)
82
+ pos_embed = self.pos_embed.repeat(B, 1, 1) # type: ignore
83
+
84
+ # Final token activations (B x 2L x C)
85
+ token_features = self.forward_after_reshape(x=current_image, pos_embed=pos_embed, x_previous=previous_image)
86
+
87
+ # Extract the patch features of current image
88
+ cur_img_token_id = 0
89
+ current_token_features = token_features[:, cur_img_token_id:self.num_patches+cur_img_token_id]
90
+ current_patch_features = current_token_features.transpose(1, 2).view(B, C, H, W)
91
+
92
+ return current_patch_features
93
+
94
+ def forward_after_reshape(self,
95
+ x: torch.Tensor,
96
+ pos_embed: torch.Tensor,
97
+ x_previous: Optional[torch.Tensor] = None) -> torch.Tensor:
98
+ B, L, _ = x.shape # Batch, Sequence length, Feature dimension
99
+
100
+ # Positional and type embeddings
101
+ type_embed = self.type_embed[0].expand(B, L, -1)
102
+ if x_previous is not None:
103
+ x = torch.cat((x, x_previous), dim=1)
104
+ pos_embed = torch.cat((pos_embed, pos_embed), dim=1)
105
+ prev_type_embed = self.type_embed[1].expand(B, L, -1)
106
+ type_embed = torch.cat((type_embed, prev_type_embed), dim=1)
107
+
108
+ # Add positional and type embeddings (used in query and key matching)
109
+ pos_and_type_embed = pos_embed + type_embed
110
+
111
+ # Positional dropout
112
+ x = self.pos_drop(x)
113
+
114
+ # Multihead attention followed by MLP
115
+ for block in self.blocks:
116
+ x = block(x=x, pos_and_type_embed=pos_and_type_embed)
117
+ x = self.norm_post(x)
118
+
119
+ return x
120
+
121
+ def _init_weights(self, m: nn.Module) -> None:
122
+ if isinstance(m, nn.Linear):
123
+ trunc_normal_(m.weight, std=.02)
124
+ if isinstance(m, nn.Linear) and m.bias is not None:
125
+ nn.init.constant_(m.bias, 0)
126
+ elif isinstance(m, nn.LayerNorm):
127
+ nn.init.constant_(m.bias, 0)
128
+ nn.init.constant_(m.weight, 1.0)
129
+
130
+
131
+ class MultiHeadAttentionLayer(nn.Module):
132
+ """
133
+ Multi-head self attention module
134
+
135
+ The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
136
+ - Defines a custom `MultiHeadAttentionLayer` which does not only apply `self-attention` but it can be
137
+ generalised to arbitrary (query, key, value) input tuples. This feature can be valuable to process
138
+ more than 2 scans at a time.
139
+ - `Self-attention` specific use-case can still be invoked by calling the `forward_as_mhsa` method.
140
+ """
141
+
142
+ def __init__(self,
143
+ dim: int,
144
+ num_heads: int = 8,
145
+ qkv_bias: bool = False,
146
+ attn_drop: float = 0.,
147
+ proj_drop: float = 0.) -> None:
148
+ super().__init__()
149
+ self.num_heads = num_heads
150
+ assert dim % num_heads == 0, f"The embedding dim ({dim}) must be divisible by the number of heads ({num_heads})"
151
+ head_dim = dim // num_heads
152
+ self.scale = head_dim ** -0.5
153
+ self.return_attention = False
154
+
155
+ self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
156
+ self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
157
+ self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
158
+
159
+ self.attn_drop = nn.Dropout(attn_drop)
160
+ self.proj = nn.Linear(dim, dim)
161
+ self.proj_drop = nn.Dropout(proj_drop)
162
+
163
+ def forward(self, k: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> MultiHeadAttentionOutput:
164
+ B, N, C = v.shape
165
+ assert C % self.num_heads == 0, \
166
+ f"The embedding dim ({C}) must be divisible by the number of heads ({self.num_heads})"
167
+
168
+ w_q = self.proj_q(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
169
+ w_k = self.proj_k(k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
170
+ w_v = self.proj_v(v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
171
+
172
+ attn = (w_q @ w_k.transpose(-2, -1)) * self.scale
173
+ attn = attn.softmax(dim=-1)
174
+ attn = self.attn_drop(attn)
175
+
176
+ o = (attn @ w_v).transpose(1, 2).reshape(B, N, C)
177
+ o = self.proj(o)
178
+ o = self.proj_drop(o)
179
+
180
+ attention_output = attn if self.return_attention else None
181
+
182
+ return MultiHeadAttentionOutput(mha_output=o, attention=attention_output)
183
+
184
+ def forward_as_mhsa(self, input: torch.Tensor) -> MultiHeadAttentionOutput:
185
+ return self(k=input, q=input, v=input)
186
+
187
+
188
+ class Block(nn.Module):
189
+ """
190
+ Encapsulates multi-layer perceptron and multi-head self attention modules into a block.
191
+
192
+ The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
193
+ - This implementation uses spatio-temporal positional embeddings instead of 2D positional embeddings only,
194
+ and they are taken into account within the forward pass of each ViT block.
195
+ - Utilises the custom defined `MultiHeadAttentionLayer` which does not apply `self-attention` only but can be
196
+ generalised to arbitrary (query, key, value) tuples. This can be valuable to process more than 2 scans.
197
+
198
+ Positional and type embeddings are handled in a similar fashion as DETR object localisation paper
199
+ https://alcinos.github.io/detr_page/, where a fixed set of sine/cos positional embeddings are used
200
+ in an additive manner to Q and K tensors.
201
+ """
202
+
203
+ def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 1., qkv_bias: bool = False, drop: float = 0.,
204
+ attn_drop: float = 0., drop_path: float = 0., act_layer: Callable = nn.GELU,
205
+ norm_layer: Callable = nn.LayerNorm) -> None:
206
+ super().__init__()
207
+ self.norm1 = norm_layer(dim)
208
+ self.attn = MultiHeadAttentionLayer(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
209
+ attn_drop=attn_drop, proj_drop=drop)
210
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
211
+ self.norm2 = norm_layer(dim)
212
+ mlp_hidden_dim = int(dim * mlp_ratio)
213
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
214
+
215
+ def with_pos_and_type_embed(self, tensor: torch.Tensor, emb: Optional[torch.Tensor]) -> torch.Tensor:
216
+ # Add positional embeddings to key and query tensors
217
+ return tensor if emb is None else tensor + emb
218
+
219
+ def forward(self, x: torch.Tensor, pos_and_type_embed: Optional[torch.Tensor]) -> torch.Tensor:
220
+ x_with_emb = self.with_pos_and_type_embed(self.norm1(x), emb=pos_and_type_embed)
221
+ x = x + self.drop_path(self.attn.forward_as_mhsa(x_with_emb).mha_output)
222
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
223
+
224
+ return x
225
+
226
+
227
+ class SinePositionEmbedding():
228
+ """
229
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
230
+ need paper, generalized to work on images.
231
+ """
232
+
233
+ def __init__(self,
234
+ embedding_dim: int = 64,
235
+ temperature: int = 10000,
236
+ normalize: bool = False,
237
+ scale: float = None) -> None:
238
+ super().__init__()
239
+ self.embedding_dim = embedding_dim
240
+ self.temperature = temperature
241
+ self.normalize = normalize
242
+ if scale is not None and normalize is False:
243
+ raise ValueError("normalize should be True if scale is passed")
244
+ if scale is None:
245
+ scale = 2 * math.pi
246
+ self.scale = scale
247
+
248
+ def __call__(self, mask: torch.Tensor) -> torch.Tensor:
249
+ assert mask is not None, "No pixel mask provided"
250
+ B, H, W = mask.shape
251
+ y_embed = mask.cumsum(1, dtype=torch.float32)
252
+ x_embed = mask.cumsum(2, dtype=torch.float32)
253
+ if self.normalize:
254
+ y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
255
+ x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
256
+
257
+ dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
258
+ dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
259
+
260
+ pos_x = x_embed[:, :, :, None] / dim_t
261
+ pos_y = y_embed[:, :, :, None] / dim_t
262
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
263
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
264
+ pos = torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)
265
+
266
+ return pos
LLAVA_Biovil/biovil_t/types.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # -------------------------------------------------------------------------------------------
5
+
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from enum import Enum, unique
11
+ from typing import List
12
+
13
+ import torch
14
+
15
+
16
+ @dataclass
17
+ class ImageModelOutput():
18
+ img_embedding: torch.Tensor
19
+ patch_embeddings: torch.Tensor
20
+ projected_global_embedding: torch.Tensor
21
+ class_logits: torch.Tensor
22
+ projected_patch_embeddings: torch.Tensor
23
+
24
+
25
+ @unique
26
+ class ImageEncoderType(str, Enum):
27
+ RESNET18 = "resnet18"
28
+ RESNET50 = "resnet50"
29
+ RESNET18_MULTI_IMAGE = "resnet18_multi_image"
30
+ RESNET50_MULTI_IMAGE = "resnet50_multi_image"
31
+
32
+ @classmethod
33
+ def get_members(cls, multi_image_encoders_only: bool) -> List[ImageEncoderType]:
34
+ if multi_image_encoders_only:
35
+ return [cls.RESNET18_MULTI_IMAGE, cls.RESNET50_MULTI_IMAGE]
36
+ else:
37
+ return [member for member in cls]
LLAVA_Biovil/cog.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for Cog ⚙️
2
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3
+
4
+ build:
5
+ gpu: true
6
+
7
+ python_version: "3.11"
8
+
9
+ python_packages:
10
+ - "torch==2.0.1"
11
+ - "accelerate==0.21.0"
12
+ - "bitsandbytes==0.41.0"
13
+ - "deepspeed==0.9.5"
14
+ - "einops-exts==0.0.4"
15
+ - "einops==0.6.1"
16
+ - "gradio==3.35.2"
17
+ - "gradio_client==0.2.9"
18
+ - "httpx==0.24.0"
19
+ - "markdown2==2.4.10"
20
+ - "numpy==1.26.0"
21
+ - "peft==0.4.0"
22
+ - "scikit-learn==1.2.2"
23
+ - "sentencepiece==0.1.99"
24
+ - "shortuuid==1.0.11"
25
+ - "timm==0.6.13"
26
+ - "tokenizers==0.13.3"
27
+ - "torch==2.0.1"
28
+ - "torchvision==0.15.2"
29
+ - "transformers==4.31.0"
30
+ - "wandb==0.15.12"
31
+ - "wavedrom==2.0.3.post3"
32
+ - "Pygments==2.16.1"
33
+ run:
34
+ - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
35
+
36
+ # predict.py defines how predictions are run on your model
37
+ predict: "predict.py:Predictor"
LLAVA_Biovil/install.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ step 1: clone Llava
2
+ step 2: git clone https://github.com/Dao-AILab/flash-attention.git
3
+ step 3: conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
4
+ step 4: pip install -e .
5
+ step 5: pip install -e ".[train]"
6
+ step 6: in flash attention folder, run: python setup.py install
LLAVA_Biovil/llava/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import LlavaLlamaForCausalLM
LLAVA_Biovil/llava/constants.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
13
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
LLAVA_Biovil/llava/conversation.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+
14
+
15
+ @dataclasses.dataclass
16
+ class Conversation:
17
+ """A class that keeps all conversation history."""
18
+ system: str
19
+ roles: List[str]
20
+ messages: List[List[str]]
21
+ offset: int
22
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
23
+ sep: str = "###"
24
+ sep2: str = None
25
+ version: str = "Unknown"
26
+
27
+ skip_next: bool = False
28
+
29
+ def get_prompt(self):
30
+ messages = self.messages
31
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
32
+ messages = self.messages.copy()
33
+ init_role, init_msg = messages[0].copy()
34
+ init_msg = init_msg[0].replace("<image>", "").strip()
35
+ if 'mmtag' in self.version:
36
+ messages[0] = (init_role, init_msg)
37
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
38
+ messages.insert(1, (self.roles[1], "Received."))
39
+ else:
40
+ messages[0] = (init_role, "<image>\n" + init_msg)
41
+
42
+ if self.sep_style == SeparatorStyle.SINGLE:
43
+ ret = self.system + self.sep
44
+ for role, message in messages:
45
+ if message:
46
+ if type(message) is tuple:
47
+ message, _, _ = message
48
+ ret += role + ": " + message + self.sep
49
+ else:
50
+ ret += role + ":"
51
+ elif self.sep_style == SeparatorStyle.TWO:
52
+ seps = [self.sep, self.sep2]
53
+ ret = self.system + seps[0]
54
+ for i, (role, message) in enumerate(messages):
55
+ if message:
56
+ if type(message) is tuple:
57
+ message, _, _ = message
58
+ ret += role + ": " + message + seps[i % 2]
59
+ else:
60
+ ret += role + ":"
61
+ elif self.sep_style == SeparatorStyle.MPT:
62
+ ret = self.system + self.sep
63
+ for role, message in messages:
64
+ if message:
65
+ if type(message) is tuple:
66
+ message, _, _ = message
67
+ ret += role + message + self.sep
68
+ else:
69
+ ret += role
70
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
71
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
72
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
73
+ ret = ""
74
+
75
+ for i, (role, message) in enumerate(messages):
76
+ if i == 0:
77
+ assert message, "first message should not be none"
78
+ assert role == self.roles[0], "first message should come from user"
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ if i == 0: message = wrap_sys(self.system) + message
83
+ if i % 2 == 0:
84
+ message = wrap_inst(message)
85
+ ret += self.sep + message
86
+ else:
87
+ ret += " " + message + " " + self.sep2
88
+ else:
89
+ ret += ""
90
+ ret = ret.lstrip(self.sep)
91
+ elif self.sep_style == SeparatorStyle.PLAIN:
92
+ seps = [self.sep, self.sep2]
93
+ ret = self.system
94
+ for i, (role, message) in enumerate(messages):
95
+ if message:
96
+ if type(message) is tuple:
97
+ message, _, _ = message
98
+ ret += message + seps[i % 2]
99
+ else:
100
+ ret += ""
101
+ else:
102
+ raise ValueError(f"Invalid style: {self.sep_style}")
103
+
104
+ return ret
105
+
106
+ def append_message(self, role, message):
107
+ self.messages.append([role, message])
108
+
109
+ def get_images(self, return_pil=False):
110
+ images = []
111
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
112
+ if i % 2 == 0:
113
+ if type(msg) is tuple:
114
+ import base64
115
+ from io import BytesIO
116
+ from PIL import Image
117
+ msg, image, image_process_mode = msg
118
+ if image_process_mode == "Pad":
119
+ def expand2square(pil_img, background_color=(122, 116, 104)):
120
+ width, height = pil_img.size
121
+ if width == height:
122
+ return pil_img
123
+ elif width > height:
124
+ result = Image.new(pil_img.mode, (width, width), background_color)
125
+ result.paste(pil_img, (0, (width - height) // 2))
126
+ return result
127
+ else:
128
+ result = Image.new(pil_img.mode, (height, height), background_color)
129
+ result.paste(pil_img, ((height - width) // 2, 0))
130
+ return result
131
+ image = expand2square(image)
132
+ elif image_process_mode in ["Default", "Crop"]:
133
+ pass
134
+ elif image_process_mode == "Resize":
135
+ image = image.resize((336, 336))
136
+ else:
137
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
138
+ max_hw, min_hw = max(image.size), min(image.size)
139
+ aspect_ratio = max_hw / min_hw
140
+ max_len, min_len = 800, 400
141
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
142
+ longest_edge = int(shortest_edge * aspect_ratio)
143
+ W, H = image.size
144
+ if longest_edge != max(image.size):
145
+ if H > W:
146
+ H, W = longest_edge, shortest_edge
147
+ else:
148
+ H, W = shortest_edge, longest_edge
149
+ image = image.resize((W, H))
150
+ if return_pil:
151
+ images.append(image)
152
+ else:
153
+ buffered = BytesIO()
154
+ image.save(buffered, format="PNG")
155
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
156
+ images.append(img_b64_str)
157
+ return images
158
+
159
+ def to_gradio_chatbot(self):
160
+ ret = []
161
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
162
+ if i % 2 == 0:
163
+ if type(msg) is tuple:
164
+ import base64
165
+ from io import BytesIO
166
+ msg, image, image_process_mode = msg
167
+ max_hw, min_hw = max(image.size), min(image.size)
168
+ aspect_ratio = max_hw / min_hw
169
+ max_len, min_len = 800, 400
170
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
171
+ longest_edge = int(shortest_edge * aspect_ratio)
172
+ W, H = image.size
173
+ if H > W:
174
+ H, W = longest_edge, shortest_edge
175
+ else:
176
+ H, W = shortest_edge, longest_edge
177
+ image = image.resize((W, H))
178
+ buffered = BytesIO()
179
+ image.save(buffered, format="JPEG")
180
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
181
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
182
+ msg = img_str + msg.replace('<image>', '').strip()
183
+ ret.append([msg, None])
184
+ else:
185
+ ret.append([msg, None])
186
+ else:
187
+ ret[-1][-1] = msg
188
+ return ret
189
+
190
+ def copy(self):
191
+ return Conversation(
192
+ system=self.system,
193
+ roles=self.roles,
194
+ messages=[[x, y] for x, y in self.messages],
195
+ offset=self.offset,
196
+ sep_style=self.sep_style,
197
+ sep=self.sep,
198
+ sep2=self.sep2,
199
+ version=self.version)
200
+
201
+ def dict(self):
202
+ if len(self.get_images()) > 0:
203
+ return {
204
+ "system": self.system,
205
+ "roles": self.roles,
206
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
207
+ "offset": self.offset,
208
+ "sep": self.sep,
209
+ "sep2": self.sep2,
210
+ }
211
+ return {
212
+ "system": self.system,
213
+ "roles": self.roles,
214
+ "messages": self.messages,
215
+ "offset": self.offset,
216
+ "sep": self.sep,
217
+ "sep2": self.sep2,
218
+ }
219
+
220
+
221
+ conv_vicuna_v0 = Conversation(
222
+ system="A chat between a curious human and an artificial intelligence assistant. "
223
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
224
+ roles=("Human", "Assistant"),
225
+ messages=(
226
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
227
+ ("Assistant",
228
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
229
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
230
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
231
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
232
+ "renewable and non-renewable energy sources:\n"
233
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
234
+ "energy sources are finite and will eventually run out.\n"
235
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
236
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
237
+ "and other negative effects.\n"
238
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
239
+ "have lower operational costs than non-renewable sources.\n"
240
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
241
+ "locations than non-renewable sources.\n"
242
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
243
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
244
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
245
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
246
+ ),
247
+ offset=2,
248
+ sep_style=SeparatorStyle.SINGLE,
249
+ sep="###",
250
+ )
251
+
252
+ conv_vicuna_v1 = Conversation(
253
+ # system="A chat between a curious user and an artificial intelligence assistant. "
254
+ # "The assistant gives helpful, detailed, and polite answers to the user's questions.",
255
+ system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
256
+ "The assistant gives professional, detailed, and polite answers to the user's questions.",
257
+ roles=("USER", "ASSISTANT"),
258
+ version="v1",
259
+ messages=[],
260
+ offset=0,
261
+ sep_style=SeparatorStyle.TWO,
262
+ sep=" ",
263
+ sep2="</s>",
264
+ )
265
+
266
+ conv_llava_med = Conversation(
267
+ system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
268
+ "The assistant gives professional, detailed, and polite answers to the user's questions.",
269
+ roles=("USER", "ASSISTANT"),
270
+ version="v1",
271
+ messages=[],
272
+ offset=2,
273
+ sep_style=SeparatorStyle.TWO,
274
+ sep="###",
275
+ sep2="</s>"
276
+ )
277
+
278
+ simple_conv_multimodal = Conversation(
279
+ system="You are LLaVA-Med, a large language and vision assistant trained by a group of researchers at Microsoft, based on the general domain LLaVA architecture."
280
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of medical and clinical tasks using natural language."
281
+ "Follow the instructions carefully and explain your answers in detail.",
282
+ roles=("Human", "Assistant"),
283
+ messages=(
284
+ ("Human", "Hi!"),
285
+ ("Assistant", "Hi there! How can I help you today?\n")
286
+ ),
287
+ offset=2,
288
+ sep_style=SeparatorStyle.SINGLE,
289
+ sep="###",
290
+ )
291
+
292
+ conv_llama_2 = Conversation(
293
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
294
+
295
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
296
+ roles=("USER", "ASSISTANT"),
297
+ version="llama_v2",
298
+ messages=(),
299
+ offset=0,
300
+ sep_style=SeparatorStyle.LLAMA_2,
301
+ sep="<s>",
302
+ sep2="</s>",
303
+ )
304
+
305
+ conv_llava_llama_2 = Conversation(
306
+ system="You are a helpful language and vision assistant. "
307
+ "You are able to understand the visual content that the user provides, "
308
+ "and assist the user with a variety of tasks using natural language.",
309
+ roles=("USER", "ASSISTANT"),
310
+ version="llama_v2",
311
+ messages=(),
312
+ offset=0,
313
+ sep_style=SeparatorStyle.LLAMA_2,
314
+ sep="<s>",
315
+ sep2="</s>",
316
+ )
317
+
318
+ conv_mpt = Conversation(
319
+ system="""<|im_start|>system
320
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
321
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
322
+ version="mpt",
323
+ messages=(),
324
+ offset=0,
325
+ sep_style=SeparatorStyle.MPT,
326
+ sep="<|im_end|>",
327
+ )
328
+
329
+ conv_llava_plain = Conversation(
330
+ system="",
331
+ roles=("", ""),
332
+ messages=(
333
+ ),
334
+ offset=0,
335
+ sep_style=SeparatorStyle.PLAIN,
336
+ sep="\n",
337
+ )
338
+
339
+ conv_llava_v0 = Conversation(
340
+ system="A chat between a curious human and an artificial intelligence assistant. "
341
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
342
+ roles=("Human", "Assistant"),
343
+ messages=(
344
+ ),
345
+ offset=0,
346
+ sep_style=SeparatorStyle.SINGLE,
347
+ sep="###",
348
+ )
349
+
350
+ conv_llava_v0_mmtag = Conversation(
351
+ system="A chat between a curious user and an artificial intelligence assistant. "
352
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
353
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
354
+ roles=("Human", "Assistant"),
355
+ messages=(
356
+ ),
357
+ offset=0,
358
+ sep_style=SeparatorStyle.SINGLE,
359
+ sep="###",
360
+ version="v0_mmtag",
361
+ )
362
+
363
+ conv_llava_v1 = Conversation(
364
+ # system="A chat between a curious human and an artificial intelligence assistant. "
365
+ # "The assistant gives helpful, detailed, and polite answers to the human's questions.",
366
+ system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
367
+ "The assistant gives professional, detailed, and polite answers to the user's questions.",
368
+ roles=("USER", "ASSISTANT"),
369
+ version="v1",
370
+ messages=(),
371
+ offset=0,
372
+ sep_style=SeparatorStyle.TWO,
373
+ sep=" ",
374
+ sep2="</s>",
375
+ )
376
+
377
+
378
+ conv_llava_v1_mmtag = Conversation(
379
+ system="A chat between a curious user and an artificial intelligence assistant. "
380
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
381
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
382
+ roles=("USER", "ASSISTANT"),
383
+ messages=(),
384
+ offset=0,
385
+ sep_style=SeparatorStyle.TWO,
386
+ sep=" ",
387
+ sep2="</s>",
388
+ version="v1_mmtag",
389
+ )
390
+
391
+ default_conversation = conv_vicuna_v1
392
+ conv_templates = {
393
+ "default": conv_vicuna_v0,
394
+ "v0": conv_vicuna_v0,
395
+ "v1": conv_vicuna_v1,
396
+ "llava_med": conv_llava_med,
397
+ "vicuna_v1": conv_vicuna_v1,
398
+ "llama_2": conv_llama_2,
399
+
400
+ "plain": conv_llava_plain,
401
+ "v0_plain": conv_llava_plain,
402
+ "llava_v0": conv_llava_v0,
403
+ "v0_mmtag": conv_llava_v0_mmtag,
404
+ "llava_v1": conv_llava_v1,
405
+ "v1_mmtag": conv_llava_v1_mmtag,
406
+ "llava_llama_2": conv_llava_llama_2,
407
+ "multimodal": simple_conv_multimodal,
408
+
409
+ "mpt": conv_mpt,
410
+ }
411
+
412
+
413
+ if __name__ == "__main__":
414
+ print(default_conversation.get_prompt())
LLAVA_Biovil/llava/eval/__init__.py ADDED
File without changes
LLAVA_Biovil/llava/eval/eval_gpt_review.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import openai
6
+ import tqdm
7
+ import ray
8
+ import time
9
+
10
+ NUM_SECONDS_TO_SLEEP = 3
11
+
12
+ @ray.remote(num_cpus=4)
13
+ def get_eval(content: str, max_tokens: int):
14
+ while True:
15
+ try:
16
+ response = openai.ChatCompletion.create(
17
+ model='gpt-4',
18
+ messages=[{
19
+ 'role': 'system',
20
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21
+ }, {
22
+ 'role': 'user',
23
+ 'content': content,
24
+ }],
25
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
26
+ max_tokens=max_tokens,
27
+ )
28
+ break
29
+ except openai.error.RateLimitError:
30
+ pass
31
+ except Exception as e:
32
+ print(e)
33
+ time.sleep(NUM_SECONDS_TO_SLEEP)
34
+
35
+ print('success!')
36
+ return response['choices'][0]['message']['content']
37
+
38
+
39
+ def parse_score(review):
40
+ try:
41
+ score_pair = review.split('\n')[0]
42
+ score_pair = score_pair.replace(',', ' ')
43
+ sp = score_pair.split(' ')
44
+ if len(sp) == 2:
45
+ return [float(sp[0]), float(sp[1])]
46
+ else:
47
+ print('error', review)
48
+ return [-1, -1]
49
+ except Exception as e:
50
+ print(e)
51
+ print('error', review)
52
+ return [-1, -1]
53
+
54
+
55
+ if __name__ == '__main__':
56
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
57
+ parser.add_argument('-q', '--question')
58
+ # parser.add_argument('-a', '--answer')
59
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
60
+ parser.add_argument('-r', '--rule')
61
+ parser.add_argument('-o', '--output')
62
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
63
+ args = parser.parse_args()
64
+
65
+ ray.init()
66
+
67
+ f_q = open(os.path.expanduser(args.question))
68
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
69
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
70
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71
+
72
+ review_file = open(f'{args.output}', 'w')
73
+
74
+ js_list = []
75
+ handles = []
76
+ idx = 0
77
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
78
+ # if idx == 1:
79
+ # break
80
+
81
+ ques = json.loads(ques_js)
82
+ ans1 = json.loads(ans1_js)
83
+ ans2 = json.loads(ans2_js)
84
+
85
+ category = json.loads(ques_js)['category']
86
+ if category in rule_dict:
87
+ rule = rule_dict[category]
88
+ else:
89
+ rule = rule_dict['default']
90
+ prompt = rule['prompt']
91
+ role = rule['role']
92
+ content = (f'[Question]\n{ques["text"]}\n\n'
93
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
94
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
95
+ f'[System]\n{prompt}\n\n')
96
+ js_list.append({
97
+ 'id': idx+1,
98
+ 'question_id': ques['question_id'],
99
+ 'answer1_id': ans1['answer_id'],
100
+ 'answer2_id': ans2['answer_id'],
101
+ 'category': category})
102
+ idx += 1
103
+ handles.append(get_eval.remote(content, args.max_tokens))
104
+ # To avoid the rate limit set by OpenAI
105
+ time.sleep(NUM_SECONDS_TO_SLEEP)
106
+
107
+ reviews = ray.get(handles)
108
+ for idx, review in enumerate(reviews):
109
+ scores = parse_score(review)
110
+ js_list[idx]['content'] = review
111
+ js_list[idx]['tuple'] = scores
112
+ review_file.write(json.dumps(js_list[idx]) + '\n')
113
+ review_file.close()
LLAVA_Biovil/llava/eval/eval_gpt_review_bench.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import openai
6
+ import time
7
+
8
+ NUM_SECONDS_TO_SLEEP = 0.5
9
+
10
+
11
+ def get_eval(content: str, max_tokens: int):
12
+ while True:
13
+ try:
14
+ response = openai.ChatCompletion.create(
15
+ model='gpt-4-0314',
16
+ messages=[{
17
+ 'role': 'system',
18
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19
+ }, {
20
+ 'role': 'user',
21
+ 'content': content,
22
+ }],
23
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
24
+ max_tokens=max_tokens,
25
+ )
26
+ break
27
+ except openai.error.RateLimitError:
28
+ pass
29
+ except Exception as e:
30
+ print(e)
31
+ time.sleep(NUM_SECONDS_TO_SLEEP)
32
+
33
+ return response['choices'][0]['message']['content']
34
+
35
+
36
+ def parse_score(review):
37
+ try:
38
+ score_pair = review.split('\n')[0]
39
+ score_pair = score_pair.replace(',', ' ')
40
+ sp = score_pair.split(' ')
41
+ if len(sp) == 2:
42
+ return [float(sp[0]), float(sp[1])]
43
+ else:
44
+ print('error', review)
45
+ return [-1, -1]
46
+ except Exception as e:
47
+ print(e)
48
+ print('error', review)
49
+ return [-1, -1]
50
+
51
+
52
+ if __name__ == '__main__':
53
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54
+ parser.add_argument('-q', '--question')
55
+ parser.add_argument('-c', '--context')
56
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57
+ parser.add_argument('-r', '--rule')
58
+ parser.add_argument('-o', '--output')
59
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60
+ args = parser.parse_args()
61
+
62
+ f_q = open(os.path.expanduser(args.question))
63
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66
+
67
+ if os.path.isfile(os.path.expanduser(args.output)):
68
+ cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69
+ else:
70
+ cur_reviews = []
71
+
72
+ review_file = open(f'{args.output}', 'a')
73
+
74
+ context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75
+ image_to_context = {context['image']: context for context in context_list}
76
+
77
+ handles = []
78
+ idx = 0
79
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80
+ ques = json.loads(ques_js)
81
+ ans1 = json.loads(ans1_js)
82
+ ans2 = json.loads(ans2_js)
83
+
84
+ inst = image_to_context[ques['image']]
85
+
86
+ if isinstance(inst['caption'], list):
87
+ cap_str = '\n'.join(inst['caption'])
88
+ else:
89
+ cap_str = inst['caption']
90
+
91
+ category = 'llava_bench_' + json.loads(ques_js)['category']
92
+ if category in rule_dict:
93
+ rule = rule_dict[category]
94
+ else:
95
+ assert False, f"Visual QA category not found in rule file: {category}."
96
+ prompt = rule['prompt']
97
+ role = rule['role']
98
+ content = (f'[Context]\n{cap_str}\n\n'
99
+ f'[Question]\n{ques["text"]}\n\n'
100
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
101
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
102
+ f'[System]\n{prompt}\n\n')
103
+ cur_js = {
104
+ 'id': idx+1,
105
+ 'question_id': ques['question_id'],
106
+ 'answer1_id': ans1.get('answer_id', ans1['question_id']),
107
+ 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
108
+ 'category': category
109
+ }
110
+ if idx >= len(cur_reviews):
111
+ review = get_eval(content, args.max_tokens)
112
+ scores = parse_score(review)
113
+ cur_js['content'] = review
114
+ cur_js['tuple'] = scores
115
+ review_file.write(json.dumps(cur_js) + '\n')
116
+ review_file.flush()
117
+ else:
118
+ print(f'Skipping {idx} as we already have it.')
119
+ idx += 1
120
+ print(idx)
121
+ review_file.close()
LLAVA_Biovil/llava/eval/eval_gpt_review_visual.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import openai
6
+ import time
7
+
8
+ NUM_SECONDS_TO_SLEEP = 0.5
9
+
10
+
11
+ def get_eval(content: str, max_tokens: int):
12
+ while True:
13
+ try:
14
+ response = openai.ChatCompletion.create(
15
+ model='gpt-4-0314',
16
+ messages=[{
17
+ 'role': 'system',
18
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19
+ }, {
20
+ 'role': 'user',
21
+ 'content': content,
22
+ }],
23
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
24
+ max_tokens=max_tokens,
25
+ )
26
+ break
27
+ except openai.error.RateLimitError:
28
+ pass
29
+ except Exception as e:
30
+ print(e)
31
+ time.sleep(NUM_SECONDS_TO_SLEEP)
32
+
33
+ return response['choices'][0]['message']['content']
34
+
35
+
36
+ def parse_score(review):
37
+ try:
38
+ score_pair = review.split('\n')[0]
39
+ score_pair = score_pair.replace(',', ' ')
40
+ sp = score_pair.split(' ')
41
+ if len(sp) == 2:
42
+ return [float(sp[0]), float(sp[1])]
43
+ else:
44
+ print('error', review)
45
+ return [-1, -1]
46
+ except Exception as e:
47
+ print(e)
48
+ print('error', review)
49
+ return [-1, -1]
50
+
51
+
52
+ if __name__ == '__main__':
53
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54
+ parser.add_argument('-q', '--question')
55
+ parser.add_argument('-c', '--context')
56
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57
+ parser.add_argument('-r', '--rule')
58
+ parser.add_argument('-o', '--output')
59
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60
+ args = parser.parse_args()
61
+
62
+ f_q = open(os.path.expanduser(args.question))
63
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66
+
67
+ if os.path.isfile(os.path.expanduser(args.output)):
68
+ cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69
+ else:
70
+ cur_reviews = []
71
+
72
+ review_file = open(f'{args.output}', 'a')
73
+
74
+ context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75
+ image_to_context = {context['image']: context for context in context_list}
76
+
77
+ handles = []
78
+ idx = 0
79
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80
+ ques = json.loads(ques_js)
81
+ ans1 = json.loads(ans1_js)
82
+ ans2 = json.loads(ans2_js)
83
+
84
+ inst = image_to_context[ques['image']]
85
+ cap_str = '\n'.join(inst['captions'])
86
+ box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
87
+
88
+ category = json.loads(ques_js)['category']
89
+ if category in rule_dict:
90
+ rule = rule_dict[category]
91
+ else:
92
+ assert False, f"Visual QA category not found in rule file: {category}."
93
+ prompt = rule['prompt']
94
+ role = rule['role']
95
+ content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
96
+ f'[Question]\n{ques["text"]}\n\n'
97
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
98
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
99
+ f'[System]\n{prompt}\n\n')
100
+ cur_js = {
101
+ 'id': idx+1,
102
+ 'question_id': ques['question_id'],
103
+ 'answer1_id': ans1.get('answer_id', ans1['question_id']),
104
+ 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
105
+ 'category': category
106
+ }
107
+ if idx >= len(cur_reviews):
108
+ review = get_eval(content, args.max_tokens)
109
+ scores = parse_score(review)
110
+ cur_js['content'] = review
111
+ cur_js['tuple'] = scores
112
+ review_file.write(json.dumps(cur_js) + '\n')
113
+ review_file.flush()
114
+ else:
115
+ print(f'Skipping {idx} as we already have it.')
116
+ idx += 1
117
+ print(idx)
118
+ review_file.close()
LLAVA_Biovil/llava/eval/eval_pope.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+
5
+ def eval_pope(answers, label_file):
6
+ label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7
+
8
+ for answer in answers:
9
+ text = answer['text']
10
+
11
+ # Only keep the first sentence
12
+ if text.find('.') != -1:
13
+ text = text.split('.')[0]
14
+
15
+ text = text.replace(',', '')
16
+ words = text.split(' ')
17
+ if 'No' in words or 'not' in words or 'no' in words:
18
+ answer['text'] = 'no'
19
+ else:
20
+ answer['text'] = 'yes'
21
+
22
+ for i in range(len(label_list)):
23
+ if label_list[i] == 'no':
24
+ label_list[i] = 0
25
+ else:
26
+ label_list[i] = 1
27
+
28
+ pred_list = []
29
+ for answer in answers:
30
+ if answer['text'] == 'no':
31
+ pred_list.append(0)
32
+ else:
33
+ pred_list.append(1)
34
+
35
+ pos = 1
36
+ neg = 0
37
+ yes_ratio = pred_list.count(1) / len(pred_list)
38
+
39
+ TP, TN, FP, FN = 0, 0, 0, 0
40
+ for pred, label in zip(pred_list, label_list):
41
+ if pred == pos and label == pos:
42
+ TP += 1
43
+ elif pred == pos and label == neg:
44
+ FP += 1
45
+ elif pred == neg and label == neg:
46
+ TN += 1
47
+ elif pred == neg and label == pos:
48
+ FN += 1
49
+
50
+ print('TP\tFP\tTN\tFN\t')
51
+ print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52
+
53
+ precision = float(TP) / float(TP + FP)
54
+ recall = float(TP) / float(TP + FN)
55
+ f1 = 2*precision*recall / (precision + recall)
56
+ acc = (TP + TN) / (TP + TN + FP + FN)
57
+ print('Accuracy: {}'.format(acc))
58
+ print('Precision: {}'.format(precision))
59
+ print('Recall: {}'.format(recall))
60
+ print('F1 score: {}'.format(f1))
61
+ print('Yes ratio: {}'.format(yes_ratio))
62
+ print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63
+
64
+ if __name__ == "__main__":
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--annotation-dir", type=str)
67
+ parser.add_argument("--question-file", type=str)
68
+ parser.add_argument("--result-file", type=str)
69
+ args = parser.parse_args()
70
+
71
+ questions = [json.loads(line) for line in open(args.question_file)]
72
+ questions = {question['question_id']: question for question in questions}
73
+ answers = [json.loads(q) for q in open(args.result_file)]
74
+ for file in os.listdir(args.annotation_dir):
75
+ assert file.startswith('coco_pope_')
76
+ assert file.endswith('.json')
77
+ category = file[10:-5]
78
+ cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79
+ print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80
+ eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81
+ print("====================================")
LLAVA_Biovil/llava/eval/eval_science_qa.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import random
6
+
7
+
8
+ def get_args():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--base-dir', type=str)
11
+ parser.add_argument('--result-file', type=str)
12
+ parser.add_argument('--output-file', type=str)
13
+ parser.add_argument('--output-result', type=str)
14
+ parser.add_argument('--split', type=str, default='test')
15
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
+ return parser.parse_args()
17
+
18
+
19
+ def convert_caps(results):
20
+ fakecaps = []
21
+ for result in results:
22
+ image_id = result['question_id']
23
+ caption = result['text']
24
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
25
+ return fakecaps
26
+
27
+
28
+ def get_pred_idx(prediction, choices, options):
29
+ """
30
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
31
+ """
32
+ if prediction in options[:len(choices)]:
33
+ return options.index(prediction)
34
+ else:
35
+ return -1
36
+ return random.choice(range(len(choices)))
37
+
38
+
39
+ if __name__ == "__main__":
40
+ args = get_args()
41
+
42
+ base_dir = args.base_dir
43
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
45
+ predictions = [json.loads(line) for line in open(args.result_file)]
46
+ predictions = {pred['question_id']: pred for pred in predictions}
47
+ split_problems = {idx: problems[idx] for idx in split_indices}
48
+
49
+ results = {'correct': [], 'incorrect': []}
50
+ sqa_results = {}
51
+ sqa_results['acc'] = None
52
+ sqa_results['correct'] = None
53
+ sqa_results['count'] = None
54
+ sqa_results['results'] = {}
55
+ sqa_results['outputs'] = {}
56
+
57
+ for prob_id, prob in split_problems.items():
58
+ if prob_id not in predictions:
59
+ pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60
+ pred_text = 'FAILED'
61
+ else:
62
+ pred = predictions[prob_id]
63
+ pred_text = pred['text']
64
+
65
+ if pred_text in args.options:
66
+ answer = pred_text
67
+ elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68
+ answer = pred_text[0]
69
+ else:
70
+ pattern = re.compile(r'The answer is ([A-Z]).')
71
+ res = pattern.findall(pred_text)
72
+ if len(res) == 1:
73
+ answer = res[0] # 'A', 'B', ...
74
+ else:
75
+ answer = "FAILED"
76
+
77
+ pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78
+
79
+ analysis = {
80
+ 'question_id': prob_id,
81
+ 'parsed_ans': answer,
82
+ 'ground_truth': args.options[prob['answer']],
83
+ 'question': pred['prompt'],
84
+ 'pred': pred_text,
85
+ 'is_multimodal': '<image>' in pred['prompt'],
86
+ }
87
+
88
+ sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89
+ sqa_results['outputs'][prob_id] = pred_text
90
+
91
+ if pred_idx == prob['answer']:
92
+ results['correct'].append(analysis)
93
+ else:
94
+ results['incorrect'].append(analysis)
95
+
96
+ correct = len(results['correct'])
97
+ total = len(results['correct']) + len(results['incorrect'])
98
+
99
+ ###### IMG ######
100
+ multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101
+ multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102
+ multimodal_total = multimodal_correct + multimodal_incorrect
103
+ ###### IMG ######
104
+
105
+ print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106
+
107
+ sqa_results['acc'] = correct / total * 100
108
+ sqa_results['correct'] = correct
109
+ sqa_results['count'] = total
110
+
111
+ with open(args.output_file, 'w') as f:
112
+ json.dump(results, f, indent=2)
113
+ with open(args.output_result, 'w') as f:
114
+ json.dump(sqa_results, f, indent=2)
LLAVA_Biovil/llava/eval/eval_science_qa_gpt4.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import random
6
+ from collections import defaultdict
7
+
8
+
9
+ def get_args():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--base-dir', type=str)
12
+ parser.add_argument('--gpt4-result', type=str)
13
+ parser.add_argument('--our-result', type=str)
14
+ parser.add_argument('--split', type=str, default='test')
15
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
+ return parser.parse_args()
17
+
18
+
19
+ def convert_caps(results):
20
+ fakecaps = []
21
+ for result in results:
22
+ image_id = result['question_id']
23
+ caption = result['text']
24
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
25
+ return fakecaps
26
+
27
+
28
+ def get_pred_idx(prediction, choices, options):
29
+ """
30
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
31
+ """
32
+ if prediction in options[:len(choices)]:
33
+ return options.index(prediction)
34
+ else:
35
+ return random.choice(range(len(choices)))
36
+
37
+
38
+ if __name__ == "__main__":
39
+ args = get_args()
40
+
41
+ base_dir = args.base_dir
42
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
44
+ our_predictions = [json.loads(line) for line in open(args.our_result)]
45
+ our_predictions = {pred['question_id']: pred for pred in our_predictions}
46
+ split_problems = {idx: problems[idx] for idx in split_indices}
47
+
48
+ gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49
+
50
+ results = defaultdict(lambda: 0)
51
+
52
+ for prob_id, prob in split_problems.items():
53
+ if prob_id not in our_predictions:
54
+ continue
55
+ if prob_id not in gpt4_predictions:
56
+ continue
57
+ our_pred = our_predictions[prob_id]['text']
58
+ gpt4_pred = gpt4_predictions[prob_id]
59
+
60
+ pattern = re.compile(r'The answer is ([A-Z]).')
61
+ our_res = pattern.findall(our_pred)
62
+ if len(our_res) == 1:
63
+ our_answer = our_res[0] # 'A', 'B', ...
64
+ else:
65
+ our_answer = "FAILED"
66
+ gpt4_res = pattern.findall(gpt4_pred)
67
+ if len(gpt4_res) == 1:
68
+ gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69
+ else:
70
+ gpt4_answer = "FAILED"
71
+
72
+ our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73
+ gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74
+
75
+ if gpt4_answer == 'FAILED':
76
+ results['gpt4_failed'] += 1
77
+ # continue
78
+ gpt4_pred_idx = our_pred_idx
79
+ # if our_pred_idx != prob['answer']:
80
+ # print(our_predictions[prob_id]['prompt'])
81
+ # print('-----------------')
82
+ # print(f'LECTURE: {prob["lecture"]}')
83
+ # print(f'SOLUTION: {prob["solution"]}')
84
+ # print('=====================')
85
+ else:
86
+ # continue
87
+ pass
88
+ # gpt4_pred_idx = our_pred_idx
89
+
90
+ if gpt4_pred_idx == prob['answer']:
91
+ results['correct'] += 1
92
+ else:
93
+ results['incorrect'] += 1
94
+
95
+
96
+ if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97
+ results['correct_upperbound'] += 1
98
+
99
+ correct = results['correct']
100
+ total = results['correct'] + results['incorrect']
101
+ print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102
+ print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103
+ print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104
+
LLAVA_Biovil/llava/eval/eval_science_qa_gpt4_requery.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import random
6
+ from collections import defaultdict
7
+
8
+
9
+ def get_args():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--base-dir', type=str)
12
+ parser.add_argument('--gpt4-result', type=str)
13
+ parser.add_argument('--requery-result', type=str)
14
+ parser.add_argument('--our-result', type=str)
15
+ parser.add_argument('--output-result', type=str)
16
+ parser.add_argument('--split', type=str, default='test')
17
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
18
+ return parser.parse_args()
19
+
20
+
21
+ def convert_caps(results):
22
+ fakecaps = []
23
+ for result in results:
24
+ image_id = result['question_id']
25
+ caption = result['text']
26
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
27
+ return fakecaps
28
+
29
+
30
+ def get_pred_idx(prediction, choices, options):
31
+ """
32
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
33
+ """
34
+ if prediction in options[:len(choices)]:
35
+ return options.index(prediction)
36
+ else:
37
+ return random.choice(range(len(choices)))
38
+
39
+
40
+ if __name__ == "__main__":
41
+ args = get_args()
42
+
43
+ base_dir = args.base_dir
44
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
45
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
46
+ our_predictions = [json.loads(line) for line in open(args.our_result)]
47
+ our_predictions = {pred['question_id']: pred for pred in our_predictions}
48
+ split_problems = {idx: problems[idx] for idx in split_indices}
49
+
50
+ requery_predictions = [json.loads(line) for line in open(args.requery_result)]
51
+ requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
52
+
53
+ gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
54
+
55
+ results = defaultdict(lambda: 0)
56
+
57
+ sqa_results = {}
58
+ sqa_results['acc'] = None
59
+ sqa_results['correct'] = None
60
+ sqa_results['count'] = None
61
+ sqa_results['results'] = {}
62
+ sqa_results['outputs'] = {}
63
+
64
+ for prob_id, prob in split_problems.items():
65
+ if prob_id not in our_predictions:
66
+ assert False
67
+ if prob_id not in gpt4_predictions:
68
+ assert False
69
+ our_pred = our_predictions[prob_id]['text']
70
+ gpt4_pred = gpt4_predictions[prob_id]
71
+ if prob_id not in requery_predictions:
72
+ results['missing_requery'] += 1
73
+ requery_pred = "MISSING"
74
+ else:
75
+ requery_pred = requery_predictions[prob_id]['text']
76
+
77
+ pattern = re.compile(r'The answer is ([A-Z]).')
78
+ our_res = pattern.findall(our_pred)
79
+ if len(our_res) == 1:
80
+ our_answer = our_res[0] # 'A', 'B', ...
81
+ else:
82
+ our_answer = "FAILED"
83
+
84
+ requery_res = pattern.findall(requery_pred)
85
+ if len(requery_res) == 1:
86
+ requery_answer = requery_res[0] # 'A', 'B', ...
87
+ else:
88
+ requery_answer = "FAILED"
89
+
90
+ gpt4_res = pattern.findall(gpt4_pred)
91
+ if len(gpt4_res) == 1:
92
+ gpt4_answer = gpt4_res[0] # 'A', 'B', ...
93
+ else:
94
+ gpt4_answer = "FAILED"
95
+
96
+ our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
97
+ gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
98
+ requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
99
+
100
+ results['total'] += 1
101
+
102
+ if gpt4_answer == 'FAILED':
103
+ results['gpt4_failed'] += 1
104
+ if gpt4_pred_idx == prob['answer']:
105
+ results['gpt4_correct'] += 1
106
+ if our_pred_idx == prob['answer']:
107
+ results['gpt4_ourvisual_correct'] += 1
108
+ elif gpt4_pred_idx == prob['answer']:
109
+ results['gpt4_correct'] += 1
110
+ results['gpt4_ourvisual_correct'] += 1
111
+
112
+ if our_pred_idx == prob['answer']:
113
+ results['our_correct'] += 1
114
+
115
+ if requery_answer == 'FAILED':
116
+ sqa_results['results'][prob_id] = our_pred_idx
117
+ if our_pred_idx == prob['answer']:
118
+ results['requery_correct'] += 1
119
+ else:
120
+ sqa_results['results'][prob_id] = requery_pred_idx
121
+ if requery_pred_idx == prob['answer']:
122
+ results['requery_correct'] += 1
123
+ else:
124
+ print(f"""
125
+ Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
126
+ Our ({our_answer}): {our_pred}
127
+ GPT-4 ({gpt4_answer}): {gpt4_pred}
128
+ Requery ({requery_answer}): {requery_pred}
129
+ print("=====================================")
130
+ """)
131
+
132
+ if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
133
+ results['correct_upperbound'] += 1
134
+
135
+ total = results['total']
136
+ print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
137
+ print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
138
+ print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
139
+ print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
140
+ print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
141
+ print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
142
+
143
+ sqa_results['acc'] = results["requery_correct"] / total * 100
144
+ sqa_results['correct'] = results["requery_correct"]
145
+ sqa_results['count'] = total
146
+
147
+ with open(args.output_result, 'w') as f:
148
+ json.dump(sqa_results, f, indent=2)
149
+
LLAVA_Biovil/llava/eval/eval_textvqa.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import json
4
+ import re
5
+
6
+ from LLAV.llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7
+
8
+
9
+ def get_args():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--annotation-file', type=str)
12
+ parser.add_argument('--result-file', type=str)
13
+ parser.add_argument('--result-dir', type=str)
14
+ return parser.parse_args()
15
+
16
+
17
+ def prompt_processor(prompt):
18
+ if prompt.startswith('OCR tokens: '):
19
+ pattern = r"Question: (.*?) Short answer:"
20
+ match = re.search(pattern, prompt, re.DOTALL)
21
+ question = match.group(1)
22
+ elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23
+ if prompt.startswith('Reference OCR token:'):
24
+ question = prompt.split('\n')[1]
25
+ else:
26
+ question = prompt.split('\n')[0]
27
+ elif len(prompt.split('\n')) == 2:
28
+ question = prompt.split('\n')[0]
29
+ else:
30
+ assert False
31
+
32
+ return question.lower()
33
+
34
+
35
+ def eval_single(annotation_file, result_file):
36
+ experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37
+ print(experiment_name)
38
+ annotations = json.load(open(annotation_file))['data']
39
+ annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40
+ results = [json.loads(line) for line in open(result_file)]
41
+
42
+ pred_list = []
43
+ for result in results:
44
+ annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45
+ pred_list.append({
46
+ "pred_answer": result['text'],
47
+ "gt_answers": annotation['answers'],
48
+ })
49
+
50
+ evaluator = TextVQAAccuracyEvaluator()
51
+ print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52
+
53
+
54
+ if __name__ == "__main__":
55
+ args = get_args()
56
+
57
+ if args.result_file is not None:
58
+ eval_single(args.annotation_file, args.result_file)
59
+
60
+ if args.result_dir is not None:
61
+ for result_file in sorted(os.listdir(args.result_dir)):
62
+ if not result_file.endswith('.jsonl'):
63
+ print(f'Skipping {result_file}')
64
+ continue
65
+ eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
LLAVA_Biovil/llava/eval/generate_webpage_data_from_table.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate json file for webpage."""
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ # models = ['llama', 'alpaca', 'gpt35', 'bard']
7
+ models = ['vicuna']
8
+
9
+
10
+ def read_jsonl(path: str, key: str=None):
11
+ data = []
12
+ with open(os.path.expanduser(path)) as f:
13
+ for line in f:
14
+ if not line:
15
+ continue
16
+ data.append(json.loads(line))
17
+ if key is not None:
18
+ data.sort(key=lambda x: x[key])
19
+ data = {item[key]: item for item in data}
20
+ return data
21
+
22
+
23
+ def trim_hanging_lines(s: str, n: int) -> str:
24
+ s = s.strip()
25
+ for _ in range(n):
26
+ s = s.split('\n', 1)[1].strip()
27
+ return s
28
+
29
+
30
+ if __name__ == '__main__':
31
+ questions = read_jsonl('table/question.jsonl', key='question_id')
32
+
33
+ # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34
+ # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35
+ # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36
+ # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37
+ vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38
+ ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39
+
40
+ review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41
+ # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42
+ # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43
+ # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44
+ # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45
+
46
+ records = []
47
+ for qid in questions.keys():
48
+ r = {
49
+ 'id': qid,
50
+ 'category': questions[qid]['category'],
51
+ 'question': questions[qid]['text'],
52
+ 'answers': {
53
+ # 'alpaca': alpaca_answers[qid]['text'],
54
+ # 'llama': llama_answers[qid]['text'],
55
+ # 'bard': bard_answers[qid]['text'],
56
+ # 'gpt35': gpt35_answers[qid]['text'],
57
+ 'vicuna': vicuna_answers[qid]['text'],
58
+ 'ours': ours_answers[qid]['text'],
59
+ },
60
+ 'evaluations': {
61
+ # 'alpaca': review_alpaca[qid]['text'],
62
+ # 'llama': review_llama[qid]['text'],
63
+ # 'bard': review_bard[qid]['text'],
64
+ 'vicuna': review_vicuna[qid]['content'],
65
+ # 'gpt35': review_gpt35[qid]['text'],
66
+ },
67
+ 'scores': {
68
+ 'vicuna': review_vicuna[qid]['tuple'],
69
+ # 'alpaca': review_alpaca[qid]['score'],
70
+ # 'llama': review_llama[qid]['score'],
71
+ # 'bard': review_bard[qid]['score'],
72
+ # 'gpt35': review_gpt35[qid]['score'],
73
+ },
74
+ }
75
+
76
+ # cleanup data
77
+ cleaned_evals = {}
78
+ for k, v in r['evaluations'].items():
79
+ v = v.strip()
80
+ lines = v.split('\n')
81
+ # trim the first line if it's a pair of numbers
82
+ if re.match(r'\d+[, ]+\d+', lines[0]):
83
+ lines = lines[1:]
84
+ v = '\n'.join(lines)
85
+ cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86
+
87
+ r['evaluations'] = cleaned_evals
88
+ records.append(r)
89
+
90
+ # Reorder the records, this is optional
91
+ for r in records:
92
+ if r['id'] <= 20:
93
+ r['id'] += 60
94
+ else:
95
+ r['id'] -= 20
96
+ for r in records:
97
+ if r['id'] <= 50:
98
+ r['id'] += 10
99
+ elif 50 < r['id'] <= 60:
100
+ r['id'] -= 50
101
+ for r in records:
102
+ if r['id'] == 7:
103
+ r['id'] = 1
104
+ elif r['id'] < 7:
105
+ r['id'] += 1
106
+
107
+ records.sort(key=lambda x: x['id'])
108
+
109
+ # Write to file
110
+ with open('webpage/data.json', 'w') as f:
111
+ json.dump({'questions': records, 'models': models}, f, indent=2)
LLAVA_Biovil/llava/eval/m4c_evaluator.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import re
3
+
4
+ from tqdm import tqdm
5
+
6
+
7
+ class EvalAIAnswerProcessor:
8
+ """
9
+ Processes an answer similar to Eval AI
10
+ copied from
11
+ https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
12
+ """
13
+
14
+ CONTRACTIONS = {
15
+ "aint": "ain't",
16
+ "arent": "aren't",
17
+ "cant": "can't",
18
+ "couldve": "could've",
19
+ "couldnt": "couldn't",
20
+ "couldn'tve": "couldn't've",
21
+ "couldnt've": "couldn't've",
22
+ "didnt": "didn't",
23
+ "doesnt": "doesn't",
24
+ "dont": "don't",
25
+ "hadnt": "hadn't",
26
+ "hadnt've": "hadn't've",
27
+ "hadn'tve": "hadn't've",
28
+ "hasnt": "hasn't",
29
+ "havent": "haven't",
30
+ "hed": "he'd",
31
+ "hed've": "he'd've",
32
+ "he'dve": "he'd've",
33
+ "hes": "he's",
34
+ "howd": "how'd",
35
+ "howll": "how'll",
36
+ "hows": "how's",
37
+ "Id've": "I'd've",
38
+ "I'dve": "I'd've",
39
+ "Im": "I'm",
40
+ "Ive": "I've",
41
+ "isnt": "isn't",
42
+ "itd": "it'd",
43
+ "itd've": "it'd've",
44
+ "it'dve": "it'd've",
45
+ "itll": "it'll",
46
+ "let's": "let's",
47
+ "maam": "ma'am",
48
+ "mightnt": "mightn't",
49
+ "mightnt've": "mightn't've",
50
+ "mightn'tve": "mightn't've",
51
+ "mightve": "might've",
52
+ "mustnt": "mustn't",
53
+ "mustve": "must've",
54
+ "neednt": "needn't",
55
+ "notve": "not've",
56
+ "oclock": "o'clock",
57
+ "oughtnt": "oughtn't",
58
+ "ow's'at": "'ow's'at",
59
+ "'ows'at": "'ow's'at",
60
+ "'ow'sat": "'ow's'at",
61
+ "shant": "shan't",
62
+ "shed've": "she'd've",
63
+ "she'dve": "she'd've",
64
+ "she's": "she's",
65
+ "shouldve": "should've",
66
+ "shouldnt": "shouldn't",
67
+ "shouldnt've": "shouldn't've",
68
+ "shouldn'tve": "shouldn't've",
69
+ "somebody'd": "somebodyd",
70
+ "somebodyd've": "somebody'd've",
71
+ "somebody'dve": "somebody'd've",
72
+ "somebodyll": "somebody'll",
73
+ "somebodys": "somebody's",
74
+ "someoned": "someone'd",
75
+ "someoned've": "someone'd've",
76
+ "someone'dve": "someone'd've",
77
+ "someonell": "someone'll",
78
+ "someones": "someone's",
79
+ "somethingd": "something'd",
80
+ "somethingd've": "something'd've",
81
+ "something'dve": "something'd've",
82
+ "somethingll": "something'll",
83
+ "thats": "that's",
84
+ "thered": "there'd",
85
+ "thered've": "there'd've",
86
+ "there'dve": "there'd've",
87
+ "therere": "there're",
88
+ "theres": "there's",
89
+ "theyd": "they'd",
90
+ "theyd've": "they'd've",
91
+ "they'dve": "they'd've",
92
+ "theyll": "they'll",
93
+ "theyre": "they're",
94
+ "theyve": "they've",
95
+ "twas": "'twas",
96
+ "wasnt": "wasn't",
97
+ "wed've": "we'd've",
98
+ "we'dve": "we'd've",
99
+ "weve": "we've",
100
+ "werent": "weren't",
101
+ "whatll": "what'll",
102
+ "whatre": "what're",
103
+ "whats": "what's",
104
+ "whatve": "what've",
105
+ "whens": "when's",
106
+ "whered": "where'd",
107
+ "wheres": "where's",
108
+ "whereve": "where've",
109
+ "whod": "who'd",
110
+ "whod've": "who'd've",
111
+ "who'dve": "who'd've",
112
+ "wholl": "who'll",
113
+ "whos": "who's",
114
+ "whove": "who've",
115
+ "whyll": "why'll",
116
+ "whyre": "why're",
117
+ "whys": "why's",
118
+ "wont": "won't",
119
+ "wouldve": "would've",
120
+ "wouldnt": "wouldn't",
121
+ "wouldnt've": "wouldn't've",
122
+ "wouldn'tve": "wouldn't've",
123
+ "yall": "y'all",
124
+ "yall'll": "y'all'll",
125
+ "y'allll": "y'all'll",
126
+ "yall'd've": "y'all'd've",
127
+ "y'alld've": "y'all'd've",
128
+ "y'all'dve": "y'all'd've",
129
+ "youd": "you'd",
130
+ "youd've": "you'd've",
131
+ "you'dve": "you'd've",
132
+ "youll": "you'll",
133
+ "youre": "you're",
134
+ "youve": "you've",
135
+ }
136
+
137
+ NUMBER_MAP = {
138
+ "none": "0",
139
+ "zero": "0",
140
+ "one": "1",
141
+ "two": "2",
142
+ "three": "3",
143
+ "four": "4",
144
+ "five": "5",
145
+ "six": "6",
146
+ "seven": "7",
147
+ "eight": "8",
148
+ "nine": "9",
149
+ "ten": "10",
150
+ }
151
+ ARTICLES = ["a", "an", "the"]
152
+ PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
153
+ COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
154
+ PUNCTUATIONS = [
155
+ ";",
156
+ r"/",
157
+ "[",
158
+ "]",
159
+ '"',
160
+ "{",
161
+ "}",
162
+ "(",
163
+ ")",
164
+ "=",
165
+ "+",
166
+ "\\",
167
+ "_",
168
+ "-",
169
+ ">",
170
+ "<",
171
+ "@",
172
+ "`",
173
+ ",",
174
+ "?",
175
+ "!",
176
+ ]
177
+
178
+ def __init__(self, *args, **kwargs):
179
+ pass
180
+
181
+ def word_tokenize(self, word):
182
+ word = word.lower()
183
+ word = word.replace(",", "").replace("?", "").replace("'s", " 's")
184
+ return word.strip()
185
+
186
+ def process_punctuation(self, in_text):
187
+ out_text = in_text
188
+ for p in self.PUNCTUATIONS:
189
+ if (p + " " in in_text or " " + p in in_text) or (
190
+ re.search(self.COMMA_STRIP, in_text) is not None
191
+ ):
192
+ out_text = out_text.replace(p, "")
193
+ else:
194
+ out_text = out_text.replace(p, " ")
195
+ out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
196
+ return out_text
197
+
198
+ def process_digit_article(self, in_text):
199
+ out_text = []
200
+ temp_text = in_text.lower().split()
201
+ for word in temp_text:
202
+ word = self.NUMBER_MAP.setdefault(word, word)
203
+ if word not in self.ARTICLES:
204
+ out_text.append(word)
205
+ else:
206
+ pass
207
+ for word_id, word in enumerate(out_text):
208
+ if word in self.CONTRACTIONS:
209
+ out_text[word_id] = self.CONTRACTIONS[word]
210
+ out_text = " ".join(out_text)
211
+ return out_text
212
+
213
+ def __call__(self, item):
214
+ item = self.word_tokenize(item)
215
+ item = item.replace("\n", " ").replace("\t", " ").strip()
216
+ item = self.process_punctuation(item)
217
+ item = self.process_digit_article(item)
218
+ return item
219
+
220
+
221
+ class TextVQAAccuracyEvaluator:
222
+ def __init__(self):
223
+ self.answer_processor = EvalAIAnswerProcessor()
224
+
225
+ def _compute_answer_scores(self, raw_answers):
226
+ """
227
+ compute the accuracy (soft score) of human answers
228
+ """
229
+ answers = [self.answer_processor(a) for a in raw_answers]
230
+ assert len(answers) == 10
231
+ gt_answers = list(enumerate(answers))
232
+ unique_answers = set(answers)
233
+ unique_answer_scores = {}
234
+
235
+ for unique_answer in unique_answers:
236
+ accs = []
237
+ for gt_answer in gt_answers:
238
+ other_answers = [item for item in gt_answers if item != gt_answer]
239
+ matching_answers = [
240
+ item for item in other_answers if item[1] == unique_answer
241
+ ]
242
+ acc = min(1, float(len(matching_answers)) / 3)
243
+ accs.append(acc)
244
+ unique_answer_scores[unique_answer] = sum(accs) / len(accs)
245
+
246
+ return unique_answer_scores
247
+
248
+ def eval_pred_list(self, pred_list):
249
+ pred_scores = []
250
+ for entry in tqdm(pred_list):
251
+ pred_answer = self.answer_processor(entry["pred_answer"])
252
+ unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
253
+ score = unique_answer_scores.get(pred_answer, 0.0)
254
+ pred_scores.append(score)
255
+
256
+ accuracy = sum(pred_scores) / len(pred_scores)
257
+ return accuracy
258
+
259
+
260
+ class STVQAAccuracyEvaluator:
261
+ def __init__(self):
262
+ self.answer_processor = EvalAIAnswerProcessor()
263
+
264
+ def eval_pred_list(self, pred_list):
265
+ pred_scores = []
266
+ for entry in pred_list:
267
+ pred_answer = self.answer_processor(entry["pred_answer"])
268
+ gts = [self.answer_processor(a) for a in entry["gt_answers"]]
269
+ score = 1.0 if pred_answer in gts else 0.0
270
+ pred_scores.append(score)
271
+
272
+ accuracy = sum(pred_scores) / len(pred_scores)
273
+ return accuracy
274
+
275
+
276
+ class STVQAANLSEvaluator:
277
+ def __init__(self):
278
+ import editdistance # install with `pip install editdistance`
279
+
280
+ self.get_edit_distance = editdistance.eval
281
+
282
+ def get_anls(self, s1, s2):
283
+ s1 = s1.lower().strip()
284
+ s2 = s2.lower().strip()
285
+ iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
286
+ anls = iou if iou >= 0.5 else 0.0
287
+ return anls
288
+
289
+ def eval_pred_list(self, pred_list):
290
+ pred_scores = []
291
+ for entry in pred_list:
292
+ anls = max(
293
+ self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
294
+ )
295
+ pred_scores.append(anls)
296
+
297
+ accuracy = sum(pred_scores) / len(pred_scores)
298
+ return accuracy
299
+
300
+
301
+ class TextCapsBleu4Evaluator:
302
+ def __init__(self):
303
+ # The following script requires Java 1.8.0 and pycocotools installed.
304
+ # The pycocoevalcap can be installed with pip as
305
+ # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
306
+ # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
307
+ # but has no python3 support yet.
308
+ try:
309
+ from pycocoevalcap.bleu.bleu import Bleu
310
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
311
+ except ModuleNotFoundError:
312
+ print(
313
+ "Please install pycocoevalcap module using "
314
+ "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
315
+ )
316
+ raise
317
+
318
+ self.tokenizer = PTBTokenizer()
319
+ self.scorer = Bleu(4)
320
+
321
+ def eval_pred_list(self, pred_list):
322
+ # Create reference and hypotheses captions.
323
+ gts = {}
324
+ res = {}
325
+ for idx, entry in enumerate(pred_list):
326
+ gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
327
+ res[idx] = [{"caption": entry["pred_answer"]}]
328
+
329
+ gts = self.tokenizer.tokenize(gts)
330
+ res = self.tokenizer.tokenize(res)
331
+ score, _ = self.scorer.compute_score(gts, res)
332
+
333
+ bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
334
+ return bleu4
LLAVA_Biovil/llava/eval/model_qa.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3
+ import torch
4
+ import os
5
+ import json
6
+ from tqdm import tqdm
7
+ import shortuuid
8
+
9
+ from LLAV.llava.conversation import default_conversation
10
+ from LLAV.llava.utils import disable_torch_init
11
+
12
+
13
+ # new stopping implementation
14
+ class KeywordsStoppingCriteria(StoppingCriteria):
15
+ def __init__(self, keywords, tokenizer, input_ids):
16
+ self.keywords = keywords
17
+ self.tokenizer = tokenizer
18
+ self.start_len = None
19
+ self.input_ids = input_ids
20
+
21
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
22
+ if self.start_len is None:
23
+ self.start_len = self.input_ids.shape[1]
24
+ else:
25
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
26
+ for keyword in self.keywords:
27
+ if keyword in outputs:
28
+ return True
29
+ return False
30
+
31
+
32
+ @torch.inference_mode()
33
+ def eval_model(model_name, questions_file, answers_file):
34
+ # Model
35
+ disable_torch_init()
36
+ model_name = os.path.expanduser(model_name)
37
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
38
+ model = AutoModelForCausalLM.from_pretrained(model_name,
39
+ torch_dtype=torch.float16).cuda()
40
+
41
+
42
+ ques_file = open(os.path.expanduser(questions_file), "r")
43
+ ans_file = open(os.path.expanduser(answers_file), "w")
44
+ for i, line in enumerate(tqdm(ques_file)):
45
+ idx = json.loads(line)["question_id"]
46
+ qs = json.loads(line)["text"]
47
+ cat = json.loads(line)["category"]
48
+ conv = default_conversation.copy()
49
+ conv.append_message(conv.roles[0], qs)
50
+ prompt = conv.get_prompt()
51
+ inputs = tokenizer([prompt])
52
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
53
+ stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
54
+ output_ids = model.generate(
55
+ input_ids,
56
+ do_sample=True,
57
+ use_cache=True,
58
+ temperature=0.7,
59
+ max_new_tokens=1024,
60
+ stopping_criteria=[stopping_criteria])
61
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
62
+ try:
63
+ index = outputs.index(conv.sep, len(prompt))
64
+ except ValueError:
65
+ outputs += conv.sep
66
+ index = outputs.index(conv.sep, len(prompt))
67
+
68
+ outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
69
+ ans_id = shortuuid.uuid()
70
+ ans_file.write(json.dumps({"question_id": idx,
71
+ "text": outputs,
72
+ "answer_id": ans_id,
73
+ "model_id": model_name,
74
+ "metadata": {}}) + "\n")
75
+ ans_file.flush()
76
+ ans_file.close()
77
+
78
+ if __name__ == "__main__":
79
+ parser = argparse.ArgumentParser()
80
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
81
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
82
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
83
+ args = parser.parse_args()
84
+
85
+ eval_model(args.model_name, args.question_file, args.answers_file)
LLAVA_Biovil/llava/eval/model_vqa.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from LLAV.llava.conversation import conv_templates, SeparatorStyle
10
+ from LLAV.llava.model.builder import load_pretrained_model
11
+ from LLAV.llava.utils import disable_torch_init
12
+ from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+
17
+
18
+ def split_list(lst, n):
19
+ """Split a list into n (roughly) equal-sized chunks"""
20
+ chunk_size = math.ceil(len(lst) / n) # integer division
21
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
+
23
+
24
+ def get_chunk(lst, n, k):
25
+ chunks = split_list(lst, n)
26
+ return chunks[k]
27
+
28
+
29
+ def eval_model(args):
30
+ # Model
31
+ disable_torch_init()
32
+ model_path = os.path.expanduser(args.model_path)
33
+ model_name = get_model_name_from_path(model_path)
34
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35
+
36
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38
+ answers_file = os.path.expanduser(args.answers_file)
39
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40
+ ans_file = open(answers_file, "w")
41
+ for line in tqdm(questions):
42
+ idx = line["question_id"]
43
+ image_file = line["image"]
44
+ qs = line["text"]
45
+ cur_prompt = qs
46
+ if model.config.mm_use_im_start_end:
47
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
48
+ else:
49
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
50
+
51
+ conv = conv_templates[args.conv_mode].copy()
52
+ conv.append_message(conv.roles[0], qs)
53
+ conv.append_message(conv.roles[1], None)
54
+ prompt = conv.get_prompt()
55
+
56
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
57
+
58
+ image = Image.open(os.path.join(args.image_folder, image_file))
59
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
60
+
61
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
62
+ keywords = [stop_str]
63
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
64
+
65
+ with torch.inference_mode():
66
+ output_ids = model.generate(
67
+ input_ids,
68
+ images=image_tensor.unsqueeze(0).half().cuda(),
69
+ do_sample=True if args.temperature > 0 else False,
70
+ temperature=args.temperature,
71
+ top_p=args.top_p,
72
+ num_beams=args.num_beams,
73
+ # no_repeat_ngram_size=3,
74
+ max_new_tokens=1024,
75
+ use_cache=True)
76
+
77
+ input_token_len = input_ids.shape[1]
78
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
79
+ if n_diff_input_output > 0:
80
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
81
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
82
+ outputs = outputs.strip()
83
+ if outputs.endswith(stop_str):
84
+ outputs = outputs[:-len(stop_str)]
85
+ outputs = outputs.strip()
86
+
87
+ ans_id = shortuuid.uuid()
88
+ ans_file.write(json.dumps({"question_id": idx,
89
+ "prompt": cur_prompt,
90
+ "text": outputs,
91
+ "answer_id": ans_id,
92
+ "model_id": model_name,
93
+ "metadata": {}}) + "\n")
94
+ ans_file.flush()
95
+ ans_file.close()
96
+
97
+ if __name__ == "__main__":
98
+ parser = argparse.ArgumentParser()
99
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
100
+ parser.add_argument("--model-base", type=str, default=None)
101
+ parser.add_argument("--image-folder", type=str, default="")
102
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
103
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
104
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
105
+ parser.add_argument("--num-chunks", type=int, default=1)
106
+ parser.add_argument("--chunk-idx", type=int, default=0)
107
+ parser.add_argument("--temperature", type=float, default=0.2)
108
+ parser.add_argument("--top_p", type=float, default=None)
109
+ parser.add_argument("--num_beams", type=int, default=1)
110
+ args = parser.parse_args()
111
+
112
+ eval_model(args)
LLAVA_Biovil/llava/eval/model_vqa_loader.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from LLAV.llava.conversation import conv_templates
10
+ from LLAV.llava.model.builder import load_pretrained_model
11
+ from LLAV.llava.utils import disable_torch_init
12
+ from LLAV.llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
+ from torch.utils.data import Dataset, DataLoader
14
+
15
+ from PIL import Image
16
+ import math
17
+
18
+
19
+ def split_list(lst, n):
20
+ """Split a list into n (roughly) equal-sized chunks"""
21
+ chunk_size = math.ceil(len(lst) / n) # integer division
22
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23
+
24
+
25
+ def get_chunk(lst, n, k):
26
+ chunks = split_list(lst, n)
27
+ return chunks[k]
28
+
29
+
30
+ # Custom dataset class
31
+ class CustomDataset(Dataset):
32
+ def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
33
+ self.questions = questions
34
+ self.image_folder = image_folder
35
+ self.tokenizer = tokenizer
36
+ self.image_processor = image_processor
37
+ self.model_config = model_config
38
+
39
+ def __getitem__(self, index):
40
+ line = self.questions[index]
41
+ image_file = line["image"]
42
+ qs = line["text"]
43
+ if self.model_config.mm_use_im_start_end:
44
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
45
+ else:
46
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
47
+
48
+ conv = conv_templates[args.conv_mode].copy()
49
+ conv.append_message(conv.roles[0], qs)
50
+ conv.append_message(conv.roles[1], None)
51
+ prompt = conv.get_prompt()
52
+
53
+ image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
54
+ image_tensor = process_images([image], self.image_processor, self.model_config)[0]
55
+
56
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
57
+
58
+ return input_ids, image_tensor
59
+
60
+ def __len__(self):
61
+ return len(self.questions)
62
+
63
+
64
+ # DataLoader
65
+ def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
66
+ assert batch_size == 1, "batch_size must be 1"
67
+ dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
68
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
69
+ return data_loader
70
+
71
+
72
+ def eval_model(args):
73
+ # Model
74
+ disable_torch_init()
75
+ model_path = os.path.expanduser(args.model_path)
76
+ model_name = get_model_name_from_path(model_path)
77
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
78
+
79
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
80
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
81
+ answers_file = os.path.expanduser(args.answers_file)
82
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
83
+ ans_file = open(answers_file, "w")
84
+
85
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
86
+ args.conv_mode = args.conv_mode + '_mmtag'
87
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
88
+
89
+ data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
90
+
91
+ for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
92
+ idx = line["question_id"]
93
+ cur_prompt = line["text"]
94
+
95
+ input_ids = input_ids.to(device='cuda', non_blocking=True)
96
+
97
+ with torch.inference_mode():
98
+ output_ids = model.generate(
99
+ input_ids,
100
+ images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
101
+ do_sample=True if args.temperature > 0 else False,
102
+ temperature=args.temperature,
103
+ top_p=args.top_p,
104
+ num_beams=args.num_beams,
105
+ max_new_tokens=args.max_new_tokens,
106
+ use_cache=True)
107
+
108
+ input_token_len = input_ids.shape[1]
109
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
110
+ if n_diff_input_output > 0:
111
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
112
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
113
+ outputs = outputs.strip()
114
+
115
+ ans_id = shortuuid.uuid()
116
+ ans_file.write(json.dumps({"question_id": idx,
117
+ "prompt": cur_prompt,
118
+ "text": outputs,
119
+ "answer_id": ans_id,
120
+ "model_id": model_name,
121
+ "metadata": {}}) + "\n")
122
+ # ans_file.flush()
123
+ ans_file.close()
124
+
125
+ if __name__ == "__main__":
126
+ parser = argparse.ArgumentParser()
127
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
128
+ parser.add_argument("--model-base", type=str, default=None)
129
+ parser.add_argument("--image-folder", type=str, default="")
130
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
131
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
132
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
133
+ parser.add_argument("--num-chunks", type=int, default=1)
134
+ parser.add_argument("--chunk-idx", type=int, default=0)
135
+ parser.add_argument("--temperature", type=float, default=0.2)
136
+ parser.add_argument("--top_p", type=float, default=None)
137
+ parser.add_argument("--num_beams", type=int, default=1)
138
+ parser.add_argument("--max_new_tokens", type=int, default=128)
139
+ args = parser.parse_args()
140
+
141
+ eval_model(args)
LLAVA_Biovil/llava/eval/model_vqa_mmbench.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+ import shortuuid
8
+
9
+ from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10
+ from LLAV.llava.conversation import conv_templates, SeparatorStyle
11
+ from LLAV.llava.model.builder import load_pretrained_model
12
+ from LLAV.llava.utils import disable_torch_init
13
+ from LLAV.llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
14
+
15
+ import math
16
+
17
+
18
+ all_options = ['A', 'B', 'C', 'D']
19
+
20
+
21
+ def split_list(lst, n):
22
+ """Split a list into n (roughly) equal-sized chunks"""
23
+ chunk_size = math.ceil(len(lst) / n) # integer division
24
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
25
+
26
+
27
+ def get_chunk(lst, n, k):
28
+ chunks = split_list(lst, n)
29
+ return chunks[k]
30
+
31
+
32
+ def is_none(value):
33
+ if value is None:
34
+ return True
35
+ if type(value) is float and math.isnan(value):
36
+ return True
37
+ if type(value) is str and value.lower() == 'nan':
38
+ return True
39
+ if type(value) is str and value.lower() == 'none':
40
+ return True
41
+ return False
42
+
43
+ def get_options(row, options):
44
+ parsed_options = []
45
+ for option in options:
46
+ option_value = row[option]
47
+ if is_none(option_value):
48
+ break
49
+ parsed_options.append(option_value)
50
+ return parsed_options
51
+
52
+
53
+ def eval_model(args):
54
+ # Model
55
+ disable_torch_init()
56
+ model_path = os.path.expanduser(args.model_path)
57
+ model_name = get_model_name_from_path(model_path)
58
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
59
+
60
+ questions = pd.read_table(os.path.expanduser(args.question_file))
61
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
62
+ answers_file = os.path.expanduser(args.answers_file)
63
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
64
+ ans_file = open(answers_file, "w")
65
+
66
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
67
+ args.conv_mode = args.conv_mode + '_mmtag'
68
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
69
+
70
+ for index, row in tqdm(questions.iterrows(), total=len(questions)):
71
+ options = get_options(row, all_options)
72
+ cur_option_char = all_options[:len(options)]
73
+
74
+ if args.all_rounds:
75
+ num_rounds = len(options)
76
+ else:
77
+ num_rounds = 1
78
+
79
+ for round_idx in range(num_rounds):
80
+ idx = row['index']
81
+ question = row['question']
82
+ hint = row['hint']
83
+ image = load_image_from_base64(row['image'])
84
+ if not is_none(hint):
85
+ question = hint + '\n' + question
86
+ for option_char, option in zip(all_options[:len(options)], options):
87
+ question = question + '\n' + option_char + '. ' + option
88
+ qs = cur_prompt = question
89
+ if model.config.mm_use_im_start_end:
90
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
91
+ else:
92
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
93
+
94
+ if args.single_pred_prompt:
95
+ if args.lang == 'cn':
96
+ qs = qs + '\n' + "请直接回答选项字母。"
97
+ else:
98
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
99
+
100
+ conv = conv_templates[args.conv_mode].copy()
101
+ conv.append_message(conv.roles[0], qs)
102
+ conv.append_message(conv.roles[1], None)
103
+ prompt = conv.get_prompt()
104
+
105
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
106
+
107
+ image_tensor = process_images([image], image_processor, model.config)[0]
108
+ # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
109
+
110
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
111
+
112
+ with torch.inference_mode():
113
+ output_ids = model.generate(
114
+ input_ids,
115
+ images=image_tensor.unsqueeze(0).half().cuda(),
116
+ do_sample=True if args.temperature > 0 else False,
117
+ temperature=args.temperature,
118
+ top_p=args.top_p,
119
+ num_beams=args.num_beams,
120
+ # no_repeat_ngram_size=3,
121
+ max_new_tokens=1024,
122
+ use_cache=True)
123
+
124
+ input_token_len = input_ids.shape[1]
125
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
126
+ if n_diff_input_output > 0:
127
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
128
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
129
+ outputs = outputs.strip()
130
+ if outputs.endswith(stop_str):
131
+ outputs = outputs[:-len(stop_str)]
132
+ outputs = outputs.strip()
133
+
134
+ ans_id = shortuuid.uuid()
135
+ ans_file.write(json.dumps({"question_id": idx,
136
+ "round_id": round_idx,
137
+ "prompt": cur_prompt,
138
+ "text": outputs,
139
+ "options": options,
140
+ "option_char": cur_option_char,
141
+ "answer_id": ans_id,
142
+ "model_id": model_name,
143
+ "metadata": {}}) + "\n")
144
+ ans_file.flush()
145
+
146
+ # rotate options
147
+ options = options[1:] + options[:1]
148
+ cur_option_char = cur_option_char[1:] + cur_option_char[:1]
149
+ ans_file.close()
150
+
151
+ if __name__ == "__main__":
152
+ parser = argparse.ArgumentParser()
153
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
154
+ parser.add_argument("--model-base", type=str, default=None)
155
+ parser.add_argument("--image-folder", type=str, default="")
156
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
157
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
158
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
159
+ parser.add_argument("--num-chunks", type=int, default=1)
160
+ parser.add_argument("--chunk-idx", type=int, default=0)
161
+ parser.add_argument("--temperature", type=float, default=0.2)
162
+ parser.add_argument("--top_p", type=float, default=None)
163
+ parser.add_argument("--num_beams", type=int, default=1)
164
+ parser.add_argument("--all-rounds", action="store_true")
165
+ parser.add_argument("--single-pred-prompt", action="store_true")
166
+ parser.add_argument("--lang", type=str, default="en")
167
+ args = parser.parse_args()
168
+
169
+ eval_model(args)
LLAVA_Biovil/llava/eval/model_vqa_qbench.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ from tqdm import tqdm
4
+ import json
5
+
6
+ from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
7
+ from LLAV.llava.conversation import conv_templates, SeparatorStyle
8
+ from LLAV.llava.model.builder import load_pretrained_model
9
+ from LLAV.llava.utils import disable_torch_init
10
+ from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
11
+
12
+ import requests
13
+ from PIL import Image
14
+ from io import BytesIO
15
+
16
+
17
+ def load_image(image_file):
18
+ if image_file.startswith('http') or image_file.startswith('https'):
19
+ response = requests.get(image_file)
20
+ image = Image.open(BytesIO(response.content)).convert('RGB')
21
+ else:
22
+ image = Image.open(image_file).convert('RGB')
23
+ return image
24
+
25
+
26
+ def eval_model(args):
27
+ # Model
28
+ disable_torch_init()
29
+
30
+ model_name = get_model_name_from_path(args.model_path)
31
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True)
32
+
33
+
34
+
35
+
36
+ with open(args.questions_file) as f:
37
+ llvqa_data = json.load(f)
38
+
39
+ for i, llddata in enumerate(tqdm(llvqa_data)):
40
+ filename = llddata["img_path"]
41
+ if args.lang == "en":
42
+ message = llddata["question"] + "\nChoose between one of the options as follows:\n"
43
+ elif args.lang == "zh":
44
+ message = llddata["question"] + "\在下列选项中选择一个:\n"
45
+ else:
46
+ raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.")
47
+ for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
48
+ message += f"{choice} {ans}\n"
49
+ qs = message
50
+
51
+ if model.config.mm_use_im_start_end:
52
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
53
+ else:
54
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
55
+
56
+ if 'llama-2' in model_name.lower():
57
+ conv_mode = "llava_llama_2"
58
+ elif "v1" in model_name.lower():
59
+ conv_mode = "llava_v1"
60
+ elif "mpt" in model_name.lower():
61
+ conv_mode = "mpt"
62
+ else:
63
+ conv_mode = "llava_v0"
64
+
65
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
66
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
67
+ else:
68
+ args.conv_mode = conv_mode
69
+
70
+ conv = conv_templates[args.conv_mode].copy()
71
+ conv.append_message(conv.roles[0], qs)
72
+ conv.append_message(conv.roles[1], None)
73
+ prompt = conv.get_prompt()
74
+
75
+ image = load_image(args.image_folder + filename)
76
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
77
+
78
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
79
+
80
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
81
+ keywords = [stop_str]
82
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
83
+
84
+
85
+ with torch.inference_mode():
86
+ output_ids = model.generate(
87
+ input_ids,
88
+ images=image_tensor,
89
+ num_beams=1,
90
+ do_sample=False,
91
+ temperature=0,
92
+ max_new_tokens=1024,
93
+ use_cache=True,
94
+ stopping_criteria=[stopping_criteria])
95
+
96
+ input_token_len = input_ids.shape[1]
97
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
98
+ if n_diff_input_output > 0:
99
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
100
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
101
+ outputs = outputs.strip()
102
+ if outputs.endswith(stop_str):
103
+ outputs = outputs[:-len(stop_str)]
104
+ outputs = outputs.strip()
105
+ llddata["response"] = outputs
106
+ with open(args.answers_file, "a") as wf:
107
+ json.dump(llddata, wf)
108
+
109
+ if __name__ == "__main__":
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument("--model-path", type=str, default="llava-v1.5")
112
+ parser.add_argument("--model-base", type=str, default=None)
113
+ parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa")
114
+ parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json")
115
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
116
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
117
+ parser.add_argument("--lang", type=str, default="en")
118
+ args = parser.parse_args()
119
+
120
+ eval_model(args)
LLAVA_Biovil/llava/eval/model_vqa_science.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from LLAV.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from LLAV.llava.conversation import conv_templates, SeparatorStyle
10
+ from LLAV.llava.model.builder import load_pretrained_model
11
+ from LLAV.llava.utils import disable_torch_init
12
+ from LLAV.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+
17
+
18
+ def split_list(lst, n):
19
+ """Split a list into n (roughly) equal-sized chunks"""
20
+ chunk_size = math.ceil(len(lst) / n) # integer division
21
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
+
23
+
24
+ def get_chunk(lst, n, k):
25
+ chunks = split_list(lst, n)
26
+ return chunks[k]
27
+
28
+
29
+ def eval_model(args):
30
+ # Model
31
+ disable_torch_init()
32
+ model_path = os.path.expanduser(args.model_path)
33
+ model_name = get_model_name_from_path(model_path)
34
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35
+
36
+ questions = json.load(open(os.path.expanduser(args.question_file), "r"))
37
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38
+ answers_file = os.path.expanduser(args.answers_file)
39
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40
+ ans_file = open(answers_file, "w")
41
+ for i, line in enumerate(tqdm(questions)):
42
+ idx = line["id"]
43
+ question = line['conversations'][0]
44
+ qs = question['value'].replace('<image>', '').strip()
45
+ cur_prompt = qs
46
+
47
+ if 'image' in line:
48
+ image_file = line["image"]
49
+ image = Image.open(os.path.join(args.image_folder, image_file))
50
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
51
+ images = image_tensor.unsqueeze(0).half().cuda()
52
+ if getattr(model.config, 'mm_use_im_start_end', False):
53
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
54
+ else:
55
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
56
+ cur_prompt = '<image>' + '\n' + cur_prompt
57
+ else:
58
+ images = None
59
+
60
+ if args.single_pred_prompt:
61
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
62
+ cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
63
+
64
+ conv = conv_templates[args.conv_mode].copy()
65
+ conv.append_message(conv.roles[0], qs)
66
+ conv.append_message(conv.roles[1], None)
67
+ prompt = conv.get_prompt()
68
+
69
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
70
+
71
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
72
+ keywords = [stop_str]
73
+ stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None
74
+
75
+ with torch.inference_mode():
76
+ output_ids = model.generate(
77
+ input_ids,
78
+ images=images,
79
+ do_sample=True if args.temperature > 0 else False,
80
+ temperature=args.temperature,
81
+ max_new_tokens=1024,
82
+ use_cache=True,
83
+ stopping_criteria=stopping_criteria,
84
+ )
85
+
86
+ input_token_len = input_ids.shape[1]
87
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
88
+ if n_diff_input_output > 0:
89
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
90
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
91
+ outputs = outputs.strip()
92
+ if outputs.endswith(stop_str):
93
+ outputs = outputs[:-len(stop_str)]
94
+ outputs = outputs.strip()
95
+
96
+ # prompt for answer
97
+ if args.answer_prompter:
98
+ outputs_reasoning = outputs
99
+ input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
100
+
101
+ with torch.inference_mode():
102
+ output_ids = model.generate(
103
+ input_ids,
104
+ images=images,
105
+ do_sample=True if args.temperature > 0 else False,
106
+ temperature=args.temperature,
107
+ max_new_tokens=64,
108
+ use_cache=True,
109
+ stopping_criteria=[stopping_criteria])
110
+
111
+ input_token_len = input_ids.shape[1]
112
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
113
+ if n_diff_input_output > 0:
114
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
115
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
116
+ outputs = outputs.strip()
117
+ if outputs.endswith(stop_str):
118
+ outputs = outputs[:-len(stop_str)]
119
+ outputs = outputs.strip()
120
+ outputs = outputs_reasoning + '\n The answer is ' + outputs
121
+
122
+ ans_id = shortuuid.uuid()
123
+ ans_file.write(json.dumps({"question_id": idx,
124
+ "prompt": cur_prompt,
125
+ "text": outputs,
126
+ "answer_id": ans_id,
127
+ "model_id": model_name,
128
+ "metadata": {}}) + "\n")
129
+ ans_file.flush()
130
+ ans_file.close()
131
+
132
+ if __name__ == "__main__":
133
+ parser = argparse.ArgumentParser()
134
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
135
+ parser.add_argument("--model-base", type=str, default=None)
136
+ parser.add_argument("--image-folder", type=str, default="")
137
+ parser.add_argument("--question-file", type=str, default="tables/question.json")
138
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
139
+ parser.add_argument("--conv-mode", type=str, default="llava_v0")
140
+ parser.add_argument("--num-chunks", type=int, default=1)
141
+ parser.add_argument("--chunk-idx", type=int, default=0)
142
+ parser.add_argument("--temperature", type=float, default=0.2)
143
+ parser.add_argument("--answer-prompter", action="store_true")
144
+ parser.add_argument("--single-pred-prompt", action="store_true")
145
+ args = parser.parse_args()
146
+
147
+ eval_model(args)
LLAVA_Biovil/llava/eval/qa_baseline_gpt35.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate answers with GPT-3.5"""
2
+ # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3
+ import argparse
4
+ import json
5
+ import os
6
+ import time
7
+ import concurrent.futures
8
+
9
+ import openai
10
+ import tqdm
11
+ import shortuuid
12
+
13
+ MODEL = 'gpt-3.5-turbo'
14
+ MODEL_ID = 'gpt-3.5-turbo:20230327'
15
+
16
+ def get_answer(question_id: int, question: str, max_tokens: int):
17
+ ans = {
18
+ 'answer_id': shortuuid.uuid(),
19
+ 'question_id': question_id,
20
+ 'model_id': MODEL_ID,
21
+ }
22
+ for _ in range(3):
23
+ try:
24
+ response = openai.ChatCompletion.create(
25
+ model=MODEL,
26
+ messages=[{
27
+ 'role': 'system',
28
+ 'content': 'You are a helpful assistant.'
29
+ }, {
30
+ 'role': 'user',
31
+ 'content': question,
32
+ }],
33
+ max_tokens=max_tokens,
34
+ )
35
+ ans['text'] = response['choices'][0]['message']['content']
36
+ return ans
37
+ except Exception as e:
38
+ print('[ERROR]', e)
39
+ ans['text'] = '#ERROR#'
40
+ time.sleep(1)
41
+ return ans
42
+
43
+
44
+ if __name__ == '__main__':
45
+ parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46
+ parser.add_argument('-q', '--question')
47
+ parser.add_argument('-o', '--output')
48
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49
+ args = parser.parse_args()
50
+
51
+ questions_dict = {}
52
+ with open(os.path.expanduser(args.question)) as f:
53
+ for line in f:
54
+ if not line:
55
+ continue
56
+ q = json.loads(line)
57
+ questions_dict[q['question_id']] = q['text']
58
+
59
+ answers = []
60
+
61
+ with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62
+ futures = []
63
+ for qid, question in questions_dict.items():
64
+ future = executor.submit(get_answer, qid, question, args.max_tokens)
65
+ futures.append(future)
66
+
67
+ for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68
+ answers.append(future.result())
69
+
70
+ answers.sort(key=lambda x: x['question_id'])
71
+
72
+ with open(os.path.expanduser(args.output), 'w') as f:
73
+ table = [json.dumps(ans) for ans in answers]
74
+ f.write('\n'.join(table))
LLAVA_Biovil/llava/eval/run_llava.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
+ from LLAV.llava.constants import (
5
+ IMAGE_TOKEN_INDEX,
6
+ DEFAULT_IMAGE_TOKEN,
7
+ DEFAULT_IM_START_TOKEN,
8
+ DEFAULT_IM_END_TOKEN,
9
+ IMAGE_PLACEHOLDER,
10
+ )
11
+ from LLAV.llava.conversation import conv_templates, SeparatorStyle
12
+ from LLAV.llava.model.builder import load_pretrained_model
13
+ from LLAV.llava.utils import disable_torch_init
14
+ from LLAV.llava.mm_utils import (
15
+ process_images,
16
+ tokenizer_image_token,
17
+ get_model_name_from_path,
18
+ KeywordsStoppingCriteria,
19
+ )
20
+
21
+ import requests
22
+ from PIL import Image
23
+ from io import BytesIO
24
+ import re
25
+
26
+
27
+ def image_parser(args):
28
+ out = args.image_file.split(args.sep)
29
+ return out
30
+
31
+
32
+ def load_image(image_file):
33
+ if image_file.startswith("http") or image_file.startswith("https"):
34
+ response = requests.get(image_file)
35
+ image = Image.open(BytesIO(response.content)).convert("RGB")
36
+ else:
37
+ image = Image.open(image_file).convert("RGB")
38
+ return image
39
+
40
+
41
+ def load_images(image_files):
42
+ out = []
43
+ for image_file in image_files:
44
+ image = load_image(image_file)
45
+ out.append(image)
46
+ return out
47
+
48
+
49
+ def eval_model(args):
50
+ # Model
51
+ disable_torch_init()
52
+
53
+ model_name = get_model_name_from_path(args.model_path)
54
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
55
+ args.model_path, args.model_base, model_name
56
+ )
57
+
58
+ qs = args.query
59
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
60
+ if IMAGE_PLACEHOLDER in qs:
61
+ if model.config.mm_use_im_start_end:
62
+ qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
63
+ else:
64
+ qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
65
+ else:
66
+ if model.config.mm_use_im_start_end:
67
+ qs = image_token_se + "\n" + qs
68
+ else:
69
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
70
+
71
+ if "llama-2" in model_name.lower():
72
+ conv_mode = "llava_llama_2"
73
+ elif "v1" in model_name.lower():
74
+ conv_mode = "llava_v1"
75
+ elif "mpt" in model_name.lower():
76
+ conv_mode = "mpt"
77
+ else:
78
+ conv_mode = "llava_v0"
79
+
80
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
81
+ print(
82
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
83
+ conv_mode, args.conv_mode, args.conv_mode
84
+ )
85
+ )
86
+ else:
87
+ args.conv_mode = conv_mode
88
+
89
+ conv = conv_templates[args.conv_mode].copy()
90
+ conv.append_message(conv.roles[0], qs)
91
+ conv.append_message(conv.roles[1], None)
92
+ prompt = conv.get_prompt()
93
+
94
+ image_files = image_parser(args)
95
+ images = load_images(image_files)
96
+ images_tensor = process_images(
97
+ images,
98
+ image_processor,
99
+ model.config
100
+ ).to(model.device, dtype=torch.float16)
101
+
102
+ input_ids = (
103
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
104
+ .unsqueeze(0)
105
+ .cuda()
106
+ )
107
+
108
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
109
+ keywords = [stop_str]
110
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
111
+
112
+ with torch.inference_mode():
113
+ output_ids = model.generate(
114
+ input_ids,
115
+ images=images_tensor,
116
+ do_sample=True if args.temperature > 0 else False,
117
+ temperature=args.temperature,
118
+ top_p=args.top_p,
119
+ num_beams=args.num_beams,
120
+ max_new_tokens=args.max_new_tokens,
121
+ use_cache=True,
122
+ stopping_criteria=[stopping_criteria],
123
+ )
124
+
125
+ input_token_len = input_ids.shape[1]
126
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
127
+ if n_diff_input_output > 0:
128
+ print(
129
+ f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
130
+ )
131
+ outputs = tokenizer.batch_decode(
132
+ output_ids[:, input_token_len:], skip_special_tokens=True
133
+ )[0]
134
+ outputs = outputs.strip()
135
+ if outputs.endswith(stop_str):
136
+ outputs = outputs[: -len(stop_str)]
137
+ outputs = outputs.strip()
138
+ print(outputs)
139
+
140
+
141
+ if __name__ == "__main__":
142
+ parser = argparse.ArgumentParser()
143
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
144
+ parser.add_argument("--model-base", type=str, default=None)
145
+ parser.add_argument("--image-file", type=str, required=True)
146
+ parser.add_argument("--query", type=str, required=True)
147
+ parser.add_argument("--conv-mode", type=str, default=None)
148
+ parser.add_argument("--sep", type=str, default=",")
149
+ parser.add_argument("--temperature", type=float, default=0.2)
150
+ parser.add_argument("--top_p", type=float, default=None)
151
+ parser.add_argument("--num_beams", type=int, default=1)
152
+ parser.add_argument("--max_new_tokens", type=int, default=512)
153
+ args = parser.parse_args()
154
+
155
+ eval_model(args)
LLAVA_Biovil/llava/eval/summarize_gpt_review.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from collections import defaultdict
4
+
5
+ import numpy as np
6
+
7
+ import argparse
8
+
9
+ def parse_args():
10
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11
+ parser.add_argument('-d', '--dir', default=None)
12
+ parser.add_argument('-v', '--version', default=None)
13
+ parser.add_argument('-s', '--select', nargs='*', default=None)
14
+ parser.add_argument('-f', '--files', nargs='*', default=[])
15
+ parser.add_argument('-i', '--ignore', nargs='*', default=[])
16
+ return parser.parse_args()
17
+
18
+
19
+ if __name__ == '__main__':
20
+ args = parse_args()
21
+
22
+ if args.ignore is not None:
23
+ args.ignore = [int(x) for x in args.ignore]
24
+
25
+ if len(args.files) > 0:
26
+ review_files = args.files
27
+ else:
28
+ review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29
+
30
+ for review_file in sorted(review_files):
31
+ config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32
+ if args.select is not None and any(x not in config for x in args.select):
33
+ continue
34
+ if '0613' in config:
35
+ version = '0613'
36
+ else:
37
+ version = '0314'
38
+ if args.version is not None and args.version != version:
39
+ continue
40
+ scores = defaultdict(list)
41
+ print(config)
42
+ with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43
+ for review_str in f:
44
+ review = json.loads(review_str)
45
+ if review['question_id'] in args.ignore:
46
+ continue
47
+ if 'category' in review:
48
+ scores[review['category']].append(review['tuple'])
49
+ scores['all'].append(review['tuple'])
50
+ else:
51
+ if 'tuple' in review:
52
+ scores['all'].append(review['tuple'])
53
+ else:
54
+ scores['all'].append(review['score'])
55
+ for k, v in sorted(scores.items()):
56
+ stats = np.asarray(v).mean(0).tolist()
57
+ stats = [round(x, 3) for x in stats]
58
+ # print(k, stats, round(stats[1]/stats[0]*100, 1))
59
+ print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60
+ print('=================================')
LLAVA_Biovil/llava/eval/webpage/figures/alpaca.png ADDED
LLAVA_Biovil/llava/eval/webpage/figures/bard.jpg ADDED
LLAVA_Biovil/llava/eval/webpage/figures/chatgpt.svg ADDED
LLAVA_Biovil/llava/eval/webpage/figures/llama.jpg ADDED
LLAVA_Biovil/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg ADDED
LLAVA_Biovil/llava/eval/webpage/figures/vicuna.jpeg ADDED
LLAVA_Biovil/llava/eval/webpage/index.html ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</title>
7
+ <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css">
8
+ <link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
9
+ <link rel="stylesheet" href="styles.css">
10
+ </head>
11
+
12
+ <body>
13
+ <nav class="navbar navbar-expand-lg navbar-dark bg-dark">
14
+ <a class="navbar-brand" href="#">🏔️ Vicuna Evaluation Examples</a>
15
+ <button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarNav" aria-controls="navbarNav" aria-expanded="false" aria-label="Toggle navigation">
16
+ <span class="navbar-toggler-icon"></span>
17
+ </button>
18
+ <div class="collapse navbar-collapse" id="navbarNav">
19
+ <ul class="navbar-nav mr-auto">
20
+ <li class="nav-item">
21
+ <a class="nav-link" href="https://chat.lmsys.org/">Demo</a>
22
+ </li>
23
+ <li class="nav-item">
24
+ <a class="nav-link" href="https://vicuna.lmsys.org">Blog</a>
25
+ </li>
26
+ <li class="nav-item">
27
+ <a class="nav-link" href="https://github.com/lm-sys/FastChat">Github</a>
28
+ </li>
29
+ </ul>
30
+ </div>
31
+ </nav>
32
+
33
+ <div class="container mt-5">
34
+ <h2 class="text-center mb-5">Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots</h2>
35
+
36
+ <!-- Selection -->
37
+ <div class="form-row">
38
+ <div class="form-group col-md-2">
39
+ <label for="category-select">Category</label>
40
+ <select class="form-control" id="category-select"></select>
41
+ </div>
42
+ <div class="form-group col-md-8">
43
+ <label for="question-select">Question</label>
44
+ <select class="form-control" id="question-select"></select>
45
+ </div>
46
+ <div class="form-group col-md-2">
47
+ <div class="col-md-2"><label>&nbsp;</label></div>
48
+ <div class="btn-group" role="group" aria-label="Left and Right Controller">
49
+ <button type="button" class="form-control btn btn-primary" id="prev-question"><i class="material-icons">keyboard_arrow_left</i></button>
50
+ <button type="button" class="form-control btn btn-primary" id="next-question"><i class="material-icons">keyboard_arrow_right</i></button>
51
+ </div>
52
+ </div>
53
+ </div>
54
+
55
+ <!-- "Battle" -->
56
+ <div class="row mb-4" style="justify-content: center;">
57
+ <div class="col" style="display: flex; justify-content: center; align-items: center;">
58
+ <label class="adjustable-font-size" id="other-score-label">*/10</label>
59
+ </div>
60
+ <div class="col">
61
+ <div class="vertical-flex-layout">
62
+ <img class="shadow figure-img img-fluid" src="" alt="other logo" width="150" id="other-model-figure">
63
+ </div>
64
+ </div>
65
+ <div class="col">
66
+ <div class="vertical-flex-layout">
67
+ <!-- from: https://fonts.google.com/icons?icon.query=battle&selected=Material+Symbols+Outlined:swords:FILL@0;wght@300;GRAD@0;opsz@48&icon.style=Outlined -->
68
+ <img class="figure-img img-fluid" src="figures/swords_FILL0_wght300_GRAD0_opsz48.svg" width="60" height="60">
69
+ </div>
70
+ </div>
71
+ <div class="col">
72
+ <div class="vertical-flex-layout">
73
+ <img class="shadow figure-img img-fluid" src="figures/vicuna.jpeg" alt="vicuna logo" width="150" id="our-model-figure">
74
+ </div>
75
+ </div>
76
+ <div class="col" style="display: flex; justify-content: center; align-items: center;">
77
+ <label class="adjustable-font-size" id="our-score-label">*/10</label>
78
+ </div>
79
+ </div>
80
+
81
+ <!-- Question Card -->
82
+ <div class="card mb-4">
83
+ <div class="card-body" id="selected-question"></div>
84
+ </div>
85
+
86
+ <!-- Answer Cards -->
87
+ <div class="row">
88
+ <div class="col-md-6">
89
+ <div class="card mb-4 expandable-card">
90
+ <div class="card-header" style="padding-bottom: 0.2rem" id="other-model-header-bg">
91
+ <div class="row">
92
+ <div class="col-md-5" style="align-items: center; display: flex;">
93
+ <label id="other-model-header">Assistant #1</label>
94
+ </div>
95
+ <div class="col-md-7">
96
+ <select class="form-control" id="model-select" style="height: fit-content; margin-top: -0.3rem;"></select>
97
+ </div>
98
+ </div>
99
+ </div>
100
+ <div class="card-body">
101
+ <div class="card-text-container">
102
+ <div class="card-text" id="other-model-answer"></div>
103
+ </div>
104
+ <div class="btn btn-primary expand-btn" style="display:flex;"></div>
105
+ </div>
106
+ </div>
107
+ </div>
108
+ <div class="col-md-6">
109
+ <div class="card mb-4 expandable-card">
110
+ <div class="card-header" id="our-model-header">
111
+ Assistant #2 (Vicuna, our model)
112
+ </div>
113
+ <div class="card-body">
114
+ <div class="card-text-container">
115
+ <div class="card-text" id="our-model-answer"></div>
116
+ </div>
117
+ <div class="btn btn-primary expand-btn" style="display:flex;"></div>
118
+ </div>
119
+ </div>
120
+ </div>
121
+ </div>
122
+
123
+ <!-- Evaluation -->
124
+ <div class="card expandable-card">
125
+ <div class="card-header" style="background-color: #c9c9f2;" id="evaluation-header">GPT-4 Evaluation</div>
126
+ <div class="card-body">
127
+ <div class="card-text-container">
128
+ <div class="card-text" id="evaluation-result"></div>
129
+ </div>
130
+ <div class="btn btn-primary expand-btn" style="display:flex;"></div>
131
+ </div>
132
+ </div>
133
+ </div>
134
+
135
+ <div class="container-fluid bg-light py-2">
136
+ <div class="text-center">
137
+ <small class="text-muted">This website is co-authored with <a href="https://openai.com" target="_blank">GPT-4</a>.</small>
138
+ </div>
139
+ </div>
140
+
141
+ <!-- Marked.js -->
142
+ <script src="https://cdn.jsdelivr.net/npm/marked@4.3.0/lib/marked.umd.min.js"></script>
143
+ <!-- Bootstrap and Popper.js JavaScript dependencies -->
144
+ <script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
145
+ <script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.11.6/dist/umd/popper.min.js"></script>
146
+ <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"></script>
147
+
148
+ <script src="script.js"></script>
149
+ <script>
150
+ // Fetch the JSON file
151
+ fetch('data.json')
152
+ .then(response => response.json())
153
+ .then(json_data => {
154
+ // Populate the models and questions.
155
+ populateModels(json_data.models);
156
+ populateQuestions(json_data.questions);
157
+ displayQuestion(currentQuestionIndex);
158
+ }).catch(error => console.error(error));
159
+ </script>
160
+ </body>
161
+
162
+ </html>
LLAVA_Biovil/llava/eval/webpage/script.js ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Description: Script for the evaluation webpage.
2
+
3
+ let currentQuestionIndex = 1;
4
+
5
+ // Store the model name mapping for later use.
6
+ modelNameMapping = {
7
+ "gpt35": "ChatGPT-3.5",
8
+ "gpt4": "GPT-4",
9
+ "alpaca": "Alpaca-13b",
10
+ "vicuna": "Vicuna-13b",
11
+ "llama": "LLaMA-13b",
12
+ "bard": "Bard",
13
+ };
14
+
15
+ modelFigureMapping = {
16
+ "vicuna": "figures/vicuna.jpeg",
17
+ // Image from: https://commons.wikimedia.org/wiki/File:ChatGPT_logo.svg
18
+ "gpt35": "figures/chatgpt.svg",
19
+ // Image from: https://www.reddit.com/r/logodesign/comments/1128aat/google_ai_bard_logo_design/
20
+ "bard": "figures/bard.jpg",
21
+ // Image from: https://crfm.stanford.edu/2023/03/13/alpaca.html
22
+ "alpaca": "figures/alpaca.png",
23
+ // Image adapted from https://commons.wikimedia.org/wiki/File:Llama_on_Machu_Picchu.jpg
24
+ "llama": "figures/llama.jpg",
25
+ }
26
+
27
+ // Store the question data in a mapping for later use.
28
+ questionMapping = {};
29
+ // Store the question ids in a mapping for later use.
30
+ categoryMapping = {};
31
+ // Store the number of questions for later use.
32
+ questionsCount = 0;
33
+
34
+
35
+ function text2Markdown(text) {
36
+ // Normalize the text for markdown rendering.
37
+ text = text.trim().replaceAll('\n\n', '\n').replaceAll('\n', '\n\n');
38
+ return marked.parse(text);
39
+ }
40
+
41
+ function capitalizeFirstChar(str) {
42
+ if (!str || str.length === 0) {
43
+ return str;
44
+ }
45
+ return str.charAt(0).toUpperCase() + str.slice(1);
46
+ }
47
+
48
+ function updateQuestionSelect(question_id) {
49
+ const select = document.getElementById('question-select');
50
+ // Clear the question select.
51
+ select.innerHTML = '';
52
+ // Populate the question select.
53
+ category = questionMapping[question_id].category;
54
+ categoryMapping[category].forEach(question_id => {
55
+ const question = questionMapping[question_id];
56
+ const option = document.createElement('option');
57
+ option.value = question_id;
58
+ option.textContent = 'Q' + question_id.toString() + ': ' + question.question;
59
+ select.appendChild(option);
60
+ });
61
+ select.value = question_id;
62
+ }
63
+
64
+ function updateModelSelect() {
65
+ const select = document.getElementById('model-select');
66
+ img_path = modelFigureMapping[select.value];
67
+ document.getElementById('other-model-figure').src = img_path;
68
+ }
69
+
70
+ function populateModels(models) {
71
+ const select = document.getElementById('model-select');
72
+ models.forEach(model => {
73
+ const option = document.createElement('option');
74
+ option.value = model;
75
+ option.textContent = modelNameMapping[model];
76
+ select.appendChild(option);
77
+ });
78
+ updateModelSelect();
79
+ }
80
+
81
+ function populateQuestions(questions) {
82
+ const category_select = document.getElementById('category-select');
83
+
84
+ questionsCount = questions.length;
85
+ questions.forEach(question => {
86
+ const option = document.createElement('option');
87
+ // Store the question data in a mapping for later use.
88
+ questionMapping[question.id] = {
89
+ category: question.category,
90
+ question: question.question,
91
+ answers: question.answers,
92
+ evaluations: question.evaluations,
93
+ scores: question.scores,
94
+ };
95
+ // Store the question id in the category mapping.
96
+ if (question.category in categoryMapping) {
97
+ categoryMapping[question.category].push(question.id);
98
+ } else {
99
+ categoryMapping[question.category] = [question.id];
100
+ const category_option = document.createElement('option');
101
+ category_option.value = question.category;
102
+ category_option.textContent = capitalizeFirstChar(question.category);
103
+ category_select.appendChild(category_option);
104
+ }
105
+ });
106
+ // Set the default category.
107
+ updateQuestionSelect(currentQuestionIndex);
108
+ }
109
+
110
+ function displayQuestion(index) {
111
+ const question = questionMapping[index].question;
112
+ document.getElementById('selected-question').innerHTML = text2Markdown('**Question:** ' + question);
113
+ displayAnswers(index);
114
+ }
115
+
116
+ function displayAnswers(index) {
117
+ const question = questionMapping[index];
118
+ const otherModel = document.getElementById('model-select').value;
119
+ // render the answers with markdown
120
+ document.getElementById('other-model-answer').innerHTML = text2Markdown(question.answers[otherModel]);
121
+ document.getElementById('our-model-answer').innerHTML = text2Markdown(question.answers.vicuna);
122
+
123
+ // Display evaluation
124
+ score = question.scores[otherModel];
125
+ score_text = modelNameMapping[otherModel] + " " + score[0] + "/10, Vicuna-13b " + score[1] + "/10";
126
+ document.getElementById('evaluation-header').textContent = "GPT-4 Evaluation" + " (Score: " + score_text + ")";
127
+ document.getElementById('evaluation-result').innerHTML = text2Markdown(question.evaluations[otherModel]);
128
+
129
+ // Update model names
130
+ let assistant1_title = "Assistant #1"; // (" + modelNameMapping[otherModel] + ")";
131
+ let assistant2_title = "Assistant #2 (Vicuna-13b, our model)";
132
+ // Update scores/labels.
133
+ let assistant1_score_label = score[0].toString() + '/10';
134
+ let assistant2_score_label = score[1].toString() + '/10';
135
+
136
+ const colorRed ='#fa9'; // '#eb978d';
137
+ // const colorGreen = '#c9f2c9';
138
+ const colorBlue = '#8ef'; // '#71dbf9';
139
+ const colorYellow = '#fe7'; // '#fada57';
140
+ let otherModelHeaderColor = '';
141
+ let ourModelHeaderColor = '';
142
+ // Update the winner.
143
+ if (score[0] == score[1]) {
144
+ assistant1_title = '🏆 ' + assistant1_title;
145
+ assistant1_score_label = '🏆 ' + assistant1_score_label;
146
+ assistant2_title = '🏆 ' + assistant2_title;
147
+ assistant2_score_label = '🏆 ' + assistant2_score_label;
148
+ otherModelHeaderColor = colorYellow;
149
+ ourModelHeaderColor = colorYellow;
150
+ } else if (score[0] > score[1]) {
151
+ assistant1_title = '🏆 ' + assistant1_title;
152
+ assistant1_score_label = '🏆 ' + assistant1_score_label;
153
+ otherModelHeaderColor = colorBlue;
154
+ ourModelHeaderColor = colorRed;
155
+ } else if (score[0] < score[1]) {
156
+ assistant2_title = '🏆 ' + assistant2_title;
157
+ assistant2_score_label = '🏆 ' + assistant2_score_label;
158
+ otherModelHeaderColor = colorRed;
159
+ ourModelHeaderColor = colorBlue;
160
+ }
161
+
162
+ document.getElementById('other-model-header-bg').style.backgroundColor = otherModelHeaderColor;
163
+ document.getElementById('our-model-header').style.backgroundColor = ourModelHeaderColor;
164
+
165
+ document.getElementById('other-model-header').textContent = assistant1_title;
166
+ document.getElementById('our-model-header').textContent = assistant2_title;
167
+
168
+ document.getElementById('other-score-label').textContent = assistant1_score_label;
169
+ document.getElementById('our-score-label').textContent = assistant2_score_label;
170
+
171
+ // Update expand buttons visibility for both cards after displaying answers
172
+ // Reset the expanded state and update expand buttons visibility for both cards after displaying answers
173
+ document.querySelectorAll('.expandable-card').forEach(card => {
174
+ card.classList.remove('expanded');
175
+ updateExpandButtonVisibility(card);
176
+ const expandBtn = card.querySelector('.expand-btn');
177
+ expandBtn.innerHTML = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_down</i> Show more'; // .textContent = 'Show more';
178
+ });
179
+ }
180
+
181
+ document.getElementById('question-select').addEventListener('change', e => {
182
+ currentQuestionIndex = parseInt(e.target.value);
183
+ displayQuestion(currentQuestionIndex);
184
+ });
185
+
186
+ document.getElementById('category-select').addEventListener('change', e => {
187
+ let currentCategory = e.target.value;
188
+ const questionIds = categoryMapping[currentCategory];
189
+ currentQuestionIndex = questionIds[0];
190
+ updateQuestionSelect(currentQuestionIndex);
191
+ displayQuestion(currentQuestionIndex);
192
+ });
193
+
194
+ // Update expand buttons whenever the model is changed
195
+ document.getElementById('model-select').addEventListener('change', () => {
196
+ displayAnswers(currentQuestionIndex);
197
+ document.querySelectorAll('.expandable-card').forEach(card => {
198
+ updateExpandButtonVisibility(card);
199
+ });
200
+ updateModelSelect();
201
+ });
202
+
203
+ function switchQuestionAndCategory() {
204
+ document.getElementById('question-select').value = currentQuestionIndex;
205
+ old_category = document.getElementById('category-select').value;
206
+ new_category = questionMapping[currentQuestionIndex].category;
207
+ if (old_category != new_category) {
208
+ document.getElementById('category-select').value = new_category;
209
+ updateQuestionSelect(currentQuestionIndex);
210
+ }
211
+ displayQuestion(currentQuestionIndex);
212
+ }
213
+
214
+ document.getElementById('prev-question').addEventListener('click', () => {
215
+ // Question index starts from 1.
216
+ currentQuestionIndex = Math.max(1, currentQuestionIndex - 1);
217
+ switchQuestionAndCategory();
218
+ });
219
+
220
+ document.getElementById('next-question').addEventListener('click', () => {
221
+ // Question index starts from 1.
222
+ currentQuestionIndex = Math.min(questionsCount, currentQuestionIndex + 1);
223
+ switchQuestionAndCategory();
224
+ });
225
+
226
+ function updateExpandButtonVisibility(card) {
227
+ const cardTextContainer = card.querySelector('.card-text-container');
228
+ const expandBtn = card.querySelector('.expand-btn');
229
+ if (cardTextContainer.scrollHeight > cardTextContainer.offsetHeight) {
230
+ expandBtn.style.display = 'flex';
231
+ } else {
232
+ expandBtn.style.display = 'none';
233
+ card.classList.add('expanded');
234
+ }
235
+ }
236
+
237
+ document.querySelectorAll('.expand-btn').forEach(btn => {
238
+ btn.addEventListener('click', e => {
239
+ const card = e.target.closest('.expandable-card');
240
+ card.classList.toggle('expanded');
241
+ const more = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_down</i> Show more';
242
+ const less = '<i class="material-icons" style="pointer-events: none">keyboard_arrow_up</i> Show less';
243
+ e.target.innerHTML = card.classList.contains('expanded') ? less : more;
244
+ });
245
+ });
LLAVA_Biovil/llava/eval/webpage/styles.css ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3
+ background-color: #f8f9fa;
4
+ }
5
+
6
+ .navbar-dark .navbar-nav .nav-link {
7
+ color: #f1cf68;
8
+ font-size: 1.1rem;
9
+ padding: 0.5rem 0.6rem;
10
+ }
11
+
12
+ .card-header {
13
+ font-weight: bold;
14
+ }
15
+
16
+ .card {
17
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
18
+ transition: 0.3s;
19
+ }
20
+
21
+ .card:hover {
22
+ box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
23
+ }
24
+
25
+ button {
26
+ transition: background-color 0.3s;
27
+ }
28
+
29
+ button:hover {
30
+ background-color: #007bff;
31
+ }
32
+
33
+ @media (max-width: 767px) {
34
+ .form-row .form-group {
35
+ margin-bottom: 10px;
36
+ }
37
+ }
38
+
39
+ /* Extra styles */
40
+
41
+ .expandable-card .card-text-container {
42
+ max-height: 200px;
43
+ overflow-y: hidden;
44
+ position: relative;
45
+ }
46
+
47
+ .expandable-card.expanded .card-text-container {
48
+ max-height: none;
49
+ }
50
+
51
+ .expand-btn {
52
+ position: relative;
53
+ display: none;
54
+ background-color: rgba(255, 255, 255, 0.8);
55
+ color: #510c75;
56
+ border-color: transparent;
57
+ }
58
+
59
+ .expand-btn:hover {
60
+ background-color: rgba(200, 200, 200, 0.8);
61
+ text-decoration: none;
62
+ border-color: transparent;
63
+ color: #510c75;
64
+ }
65
+
66
+ .expand-btn:focus {
67
+ outline: none;
68
+ text-decoration: none;
69
+ }
70
+
71
+ .expandable-card:not(.expanded) .card-text-container:after {
72
+ content: "";
73
+ position: absolute;
74
+ bottom: 0;
75
+ left: 0;
76
+ width: 100%;
77
+ height: 90px;
78
+ background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
79
+ }
80
+
81
+ .expandable-card:not(.expanded) .expand-btn {
82
+ margin-top: -40px;
83
+ }
84
+
85
+ .card-body {
86
+ padding-bottom: 5px;
87
+ }
88
+
89
+ .vertical-flex-layout {
90
+ justify-content: center;
91
+ align-items: center;
92
+ height: 100%;
93
+ display: flex;
94
+ flex-direction: column;
95
+ gap: 5px;
96
+ }
97
+
98
+ .figure-img {
99
+ max-width: 100%;
100
+ height: auto;
101
+ }
102
+
103
+ .adjustable-font-size {
104
+ font-size: calc(0.5rem + 2vw);
105
+ }
LLAVA_Biovil/llava/mm_utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import numpy as np
5
+
6
+ import torch
7
+ from transformers import StoppingCriteria
8
+ from llava.constants import IMAGE_TOKEN_INDEX
9
+
10
+
11
+ def load_image_from_base64(image):
12
+ return Image.open(BytesIO(base64.b64decode(image)))
13
+
14
+ def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
15
+ """Remap values in input so the output range is :math:`[0, 255]`.
16
+
17
+ Percentiles can be used to specify the range of values to remap.
18
+ This is useful to discard outliers in the input data.
19
+
20
+ :param array: Input array.
21
+ :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
22
+ Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
23
+ :returns: Array with ``0`` and ``255`` as minimum and maximum values.
24
+ """
25
+ array = array.astype(float)
26
+ if percentiles is not None:
27
+ len_percentiles = len(percentiles)
28
+ if len_percentiles != 2:
29
+ message = (
30
+ 'The value for percentiles should be a sequence of length 2,'
31
+ f' but has length {len_percentiles}'
32
+ )
33
+ raise ValueError(message)
34
+ a, b = percentiles
35
+ if a >= b:
36
+ raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
37
+ if a < 0 or b > 100:
38
+ raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
39
+ cutoff: np.ndarray = np.percentile(array, percentiles)
40
+ array = np.clip(array, *cutoff)
41
+ array -= array.min()
42
+ array /= array.max()
43
+ array *= 255
44
+ return array.astype(np.uint8)
45
+ def load_image_from_base64_biovil(image):
46
+ image = Image.open(BytesIO(base64.b64decode(image)))
47
+ image = remap_to_uint8(np.array(image))
48
+ return Image.fromarray(image).convert("L")
49
+
50
+ def expand2square(pil_img, background_color):
51
+ width, height = pil_img.size
52
+ if width == height:
53
+ return pil_img
54
+ elif width > height:
55
+ result = Image.new(pil_img.mode, (width, width), background_color)
56
+ result.paste(pil_img, (0, (width - height) // 2))
57
+ return result
58
+ else:
59
+ result = Image.new(pil_img.mode, (height, height), background_color)
60
+ result.paste(pil_img, ((height - width) // 2, 0))
61
+ return result
62
+
63
+ def process_images(images, image_processor, model_cfg):
64
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
65
+ new_images = []
66
+ if image_aspect_ratio == 'pad':
67
+ for image in images:
68
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
69
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
70
+ new_images.append(image)
71
+ else:
72
+ return image_processor(images, return_tensors='pt')['pixel_values']
73
+ if all(x.shape == new_images[0].shape for x in new_images):
74
+ new_images = torch.stack(new_images, dim=0)
75
+ return new_images
76
+
77
+ def process_image_biovil(images, image_processor):
78
+ new_images = []
79
+ for image in images:
80
+ image = image_processor(image)
81
+ new_images.append(image)
82
+
83
+ if all(x.shape == new_images[0].shape for x in new_images):
84
+ new_images = torch.stack(new_images, dim=0)
85
+ return new_images
86
+
87
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
88
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
89
+
90
+ def insert_separator(X, sep):
91
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
92
+
93
+ input_ids = []
94
+ offset = 0
95
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
96
+ offset = 1
97
+ input_ids.append(prompt_chunks[0][0])
98
+
99
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
100
+ input_ids.extend(x[offset:])
101
+
102
+ if return_tensors is not None:
103
+ if return_tensors == 'pt':
104
+ return torch.tensor(input_ids, dtype=torch.long)
105
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
106
+ return input_ids
107
+
108
+
109
+ def get_model_name_from_path(model_path):
110
+ model_path = model_path.strip("/")
111
+ model_paths = model_path.split("/")
112
+ if model_paths[-1].startswith('checkpoint-'):
113
+ return model_paths[-2] + "_" + model_paths[-1]
114
+ else:
115
+ return model_paths[-1]
116
+
117
+ class KeywordsStoppingCriteria(StoppingCriteria):
118
+ def __init__(self, keywords, tokenizer, input_ids):
119
+ self.keywords = keywords
120
+ self.keyword_ids = []
121
+ self.max_keyword_len = 0
122
+ for keyword in keywords:
123
+ cur_keyword_ids = tokenizer(keyword).input_ids
124
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
125
+ cur_keyword_ids = cur_keyword_ids[1:]
126
+ if len(cur_keyword_ids) > self.max_keyword_len:
127
+ self.max_keyword_len = len(cur_keyword_ids)
128
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
129
+ self.tokenizer = tokenizer
130
+ self.start_len = input_ids.shape[1]
131
+
132
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
133
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
134
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
135
+ for keyword_id in self.keyword_ids:
136
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
137
+ return True
138
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
139
+ for keyword in self.keywords:
140
+ if keyword in outputs:
141
+ return True
142
+ return False
143
+
144
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
145
+ outputs = []
146
+ for i in range(output_ids.shape[0]):
147
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
148
+ return all(outputs)