Spaces:
Sleeping
Sleeping
cao-lele
commited on
Commit
·
0724c4e
1
Parent(s):
9619f3f
initial commit
Browse files- .gitignore +167 -0
- Leaderboard.md +163 -0
- README.md +55 -0
- datasets/examples/000000061658.jpg +0 -0
- datasets/examples/000000338560.jpg +0 -0
- genception/evaluation.py +101 -0
- genception/example_script.sh +8 -0
- genception/experiment.py +373 -0
- genception/utils.py +25 -0
- requirements.txt +7 -0
.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#custom
|
2 |
+
mme_data/
|
3 |
+
|
4 |
+
|
5 |
+
# Byte-compiled / optimized / DLL files
|
6 |
+
__pycache__/
|
7 |
+
*.py[cod]
|
8 |
+
*$py.class
|
9 |
+
|
10 |
+
# C extensions
|
11 |
+
*.so
|
12 |
+
|
13 |
+
# Distribution / packaging
|
14 |
+
.Python
|
15 |
+
build/
|
16 |
+
develop-eggs/
|
17 |
+
dist/
|
18 |
+
downloads/
|
19 |
+
eggs/
|
20 |
+
.eggs/
|
21 |
+
lib/
|
22 |
+
lib64/
|
23 |
+
parts/
|
24 |
+
sdist/
|
25 |
+
var/
|
26 |
+
wheels/
|
27 |
+
share/python-wheels/
|
28 |
+
*.egg-info/
|
29 |
+
.installed.cfg
|
30 |
+
*.egg
|
31 |
+
MANIFEST
|
32 |
+
|
33 |
+
# PyInstaller
|
34 |
+
# Usually these files are written by a python script from a template
|
35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
36 |
+
*.manifest
|
37 |
+
*.spec
|
38 |
+
|
39 |
+
# Installer logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
# Unit test / coverage reports
|
44 |
+
htmlcov/
|
45 |
+
.tox/
|
46 |
+
.nox/
|
47 |
+
.coverage
|
48 |
+
.coverage.*
|
49 |
+
.cache
|
50 |
+
nosetests.xml
|
51 |
+
coverage.xml
|
52 |
+
*.cover
|
53 |
+
*.py,cover
|
54 |
+
.hypothesis/
|
55 |
+
.pytest_cache/
|
56 |
+
cover/
|
57 |
+
|
58 |
+
# Translations
|
59 |
+
*.mo
|
60 |
+
*.pot
|
61 |
+
|
62 |
+
# Django stuff:
|
63 |
+
*.log
|
64 |
+
local_settings.py
|
65 |
+
db.sqlite3
|
66 |
+
db.sqlite3-journal
|
67 |
+
|
68 |
+
# Flask stuff:
|
69 |
+
instance/
|
70 |
+
.webassets-cache
|
71 |
+
|
72 |
+
# Scrapy stuff:
|
73 |
+
.scrapy
|
74 |
+
|
75 |
+
# Sphinx documentation
|
76 |
+
docs/_build/
|
77 |
+
|
78 |
+
# PyBuilder
|
79 |
+
.pybuilder/
|
80 |
+
target/
|
81 |
+
|
82 |
+
# Jupyter Notebook
|
83 |
+
.ipynb_checkpoints
|
84 |
+
|
85 |
+
# IPython
|
86 |
+
profile_default/
|
87 |
+
ipython_config.py
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
# For a library or package, you might want to ignore these files since the code is
|
91 |
+
# intended to run in multiple environments; otherwise, check them in:
|
92 |
+
# .python-version
|
93 |
+
|
94 |
+
# pipenv
|
95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
98 |
+
# install all needed dependencies.
|
99 |
+
#Pipfile.lock
|
100 |
+
|
101 |
+
# poetry
|
102 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
103 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
104 |
+
# commonly ignored for libraries.
|
105 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
106 |
+
#poetry.lock
|
107 |
+
|
108 |
+
# pdm
|
109 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
110 |
+
#pdm.lock
|
111 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
112 |
+
# in version control.
|
113 |
+
# https://pdm.fming.dev/#use-with-ide
|
114 |
+
.pdm.toml
|
115 |
+
|
116 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
117 |
+
__pypackages__/
|
118 |
+
|
119 |
+
# Celery stuff
|
120 |
+
celerybeat-schedule
|
121 |
+
celerybeat.pid
|
122 |
+
|
123 |
+
# SageMath parsed files
|
124 |
+
*.sage.py
|
125 |
+
|
126 |
+
# Environments
|
127 |
+
.env
|
128 |
+
.venv
|
129 |
+
env/
|
130 |
+
venv/
|
131 |
+
ENV/
|
132 |
+
env.bak/
|
133 |
+
venv.bak/
|
134 |
+
|
135 |
+
# Spyder project settings
|
136 |
+
.spyderproject
|
137 |
+
.spyproject
|
138 |
+
|
139 |
+
# Rope project settings
|
140 |
+
.ropeproject
|
141 |
+
|
142 |
+
# mkdocs documentation
|
143 |
+
/site
|
144 |
+
|
145 |
+
# mypy
|
146 |
+
.mypy_cache/
|
147 |
+
.dmypy.json
|
148 |
+
dmypy.json
|
149 |
+
|
150 |
+
# Pyre type checker
|
151 |
+
.pyre/
|
152 |
+
|
153 |
+
# pytype static type analyzer
|
154 |
+
.pytype/
|
155 |
+
|
156 |
+
# Cython debug symbols
|
157 |
+
cython_debug/
|
158 |
+
|
159 |
+
# PyCharm
|
160 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
161 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
162 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
163 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
164 |
+
#.idea/
|
165 |
+
|
166 |
+
# other
|
167 |
+
.DS_Store
|
Leaderboard.md
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🔥🏅️GenCeption Leaderboard 🏅️🔥
|
2 |
+
|
3 |
+
Evaluated MLLMs: [ChatGPT-4V](https://cdn.openai.com/papers/GPTV_System_Card.pdf), [mPLUG-Owl2](https://arxiv.org/pdf/2311.04257.pdf), [LLaVA-13B](https://arxiv.org/pdf/2304.08485.pdf), [LLaVA-7B](https://arxiv.org/pdf/2304.08485.pdf)
|
4 |
+
|
5 |
+
<table>
|
6 |
+
<tr><th>Existence </th><th>Count</th></tr>
|
7 |
+
<tr><td>
|
8 |
+
|
9 |
+
| Model | GC@3|
|
10 |
+
|--|--|
|
11 |
+
| ChatGPT-4V|0.422 |
|
12 |
+
| mPLUG-Owl2|0.323 |
|
13 |
+
| LLaVA-7B|0.308 |
|
14 |
+
| LLaVA-13B|0.305 |
|
15 |
+
|
16 |
+
</td><td>
|
17 |
+
|
18 |
+
| Model | GC@3|
|
19 |
+
|--|--|
|
20 |
+
| ChatGPT-4V|0.404 |
|
21 |
+
| mPLUG-Owl2|0.299 |
|
22 |
+
| LLaVA-13B|0.294 |
|
23 |
+
| LLaVA-7B|0.353 |
|
24 |
+
|
25 |
+
</td></tr> </table>
|
26 |
+
|
27 |
+
|
28 |
+
<table>
|
29 |
+
<tr><th>Position </th><th>Color</th></tr>
|
30 |
+
<tr><td>
|
31 |
+
|
32 |
+
| Model | GC@3|
|
33 |
+
|--|--|
|
34 |
+
| ChatGPT-4V|0.408|
|
35 |
+
| mPLUG-Owl2|0.306 |
|
36 |
+
| LLaVA-7B|0.285 |
|
37 |
+
| LLaVA-13B|0.255 |
|
38 |
+
|
39 |
+
</td><td>
|
40 |
+
|
41 |
+
| Model | GC@3|
|
42 |
+
|--|--|
|
43 |
+
| ChatGPT-4V|0.403 |
|
44 |
+
| LLaVA-13B|0.300 |
|
45 |
+
| mPLUG-Owl2|0.290 |
|
46 |
+
| LLaVA-7B|0.284 |
|
47 |
+
|
48 |
+
</td></tr> </table>
|
49 |
+
|
50 |
+
|
51 |
+
<table>
|
52 |
+
<tr><th>Poster </th><th>Celebrity</th></tr>
|
53 |
+
<tr><td>
|
54 |
+
|
55 |
+
| Model | GC@3|
|
56 |
+
|--|--|
|
57 |
+
| ChatGPT-4V|0.324|
|
58 |
+
| mPLUG-Owl2|0.243 |
|
59 |
+
| LLaVA-13B|0.215 |
|
60 |
+
| LLaVA-7B|0.214 |
|
61 |
+
|
62 |
+
</td><td>
|
63 |
+
|
64 |
+
| Model | GC@3|
|
65 |
+
|--|--|
|
66 |
+
| ChatGPT-4V|0.332 |
|
67 |
+
| mPLUG-Owl2|0.232 |
|
68 |
+
| LLaVA-13B|0.206 |
|
69 |
+
| LLaVA-7B|0.188 |
|
70 |
+
|
71 |
+
</td></tr> </table>
|
72 |
+
|
73 |
+
|
74 |
+
<table>
|
75 |
+
<tr><th>Scene </th><th>Landmark</th></tr>
|
76 |
+
<tr><td>
|
77 |
+
|
78 |
+
| Model | GC@3|
|
79 |
+
|--|--|
|
80 |
+
| ChatGPT-4V|0.393|
|
81 |
+
| mPLUG-Owl2|0.299 |
|
82 |
+
| LLaVA-13B|0.277 |
|
83 |
+
| LLaVA-7B|0.266 |
|
84 |
+
|
85 |
+
</td><td>
|
86 |
+
|
87 |
+
| Model | GC@3|
|
88 |
+
|--|--|
|
89 |
+
| ChatGPT-4V|0.353 |
|
90 |
+
| mPLUG-Owl2|0.275 |
|
91 |
+
| LLaVA-7B|0.252 |
|
92 |
+
| LLaVA-13B|0.242 |
|
93 |
+
|
94 |
+
</td></tr> </table>
|
95 |
+
|
96 |
+
|
97 |
+
<table>
|
98 |
+
<tr><th>Artwork </th><th>Commonsense Reasoning</th></tr>
|
99 |
+
<tr><td>
|
100 |
+
|
101 |
+
| Model | GC@3|
|
102 |
+
|--|--|
|
103 |
+
| ChatGPT-4V|0.421|
|
104 |
+
| mPLUG-Owl2|0.252 |
|
105 |
+
| LLaVA-13B|0.212 |
|
106 |
+
| LLaVA-7B|0.210 |
|
107 |
+
|
108 |
+
</td><td>
|
109 |
+
|
110 |
+
| Model | GC@3|
|
111 |
+
|--|--|
|
112 |
+
| ChatGPT-4V|0.471 |
|
113 |
+
| mPLUG-Owl2|0.353 |
|
114 |
+
| LLaVA-13B|0.334 |
|
115 |
+
| LLaVA-7B|0.294 |
|
116 |
+
|
117 |
+
</td></tr> </table>
|
118 |
+
|
119 |
+
|
120 |
+
<table>
|
121 |
+
<tr><th>Code Reasoning </th><th>Numerical Calculation</th></tr>
|
122 |
+
<tr><td>
|
123 |
+
|
124 |
+
| Model | GC@3|
|
125 |
+
|--|--|
|
126 |
+
| ChatGPT-4V|0.193|
|
127 |
+
| mPLUG-Owl2|0.176 |
|
128 |
+
| LLaVA-13B|0.144 |
|
129 |
+
| LLaVA-7B|0.107 |
|
130 |
+
|
131 |
+
</td><td>
|
132 |
+
|
133 |
+
| Model | GC@3|
|
134 |
+
|--|--|
|
135 |
+
| ChatGPT-4V|0.240 |
|
136 |
+
| LLaVA-13B|0.195 |
|
137 |
+
| mPLUG-Owl2|0.192 |
|
138 |
+
| LLaVA-7B|0.155 |
|
139 |
+
|
140 |
+
</td></tr> </table>
|
141 |
+
|
142 |
+
|
143 |
+
<table>
|
144 |
+
<tr><th>Text Translation </th><th>OCR</th></tr>
|
145 |
+
<tr><td>
|
146 |
+
|
147 |
+
| Model | GC@3|
|
148 |
+
|--|--|
|
149 |
+
| ChatGPT-4V|0.157|
|
150 |
+
| LLaVA-13B|0.116 |
|
151 |
+
| LLaVA-7B|0.111 |
|
152 |
+
| mPLUG-Owl2|0.081 |
|
153 |
+
|
154 |
+
</td><td>
|
155 |
+
|
156 |
+
| Model | GC@3|
|
157 |
+
|--|--|
|
158 |
+
| ChatGPT-4V|0.393 |
|
159 |
+
| mPLUG-Owl2|0.276 |
|
160 |
+
| LLaVA-13B|0.239 |
|
161 |
+
| LLaVA-7B|0.222 |
|
162 |
+
|
163 |
+
</td></tr> </table>
|
README.md
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GenCeption: Evaluate Multimodal LLMs with Unlabeled Unimodal Data
|
2 |
+
|
3 |
+
<div>
|
4 |
+
<p align="center">
|
5 |
+
<a href="https://github.com/EQTPartners/GenCeption/blob/main/Leaderboard.md">🔥🏅️Leaderboard🏅️🔥</a> • 
|
6 |
+
<a href="#contribute">Contribute</a> • 
|
7 |
+
<a href="https://arxiv.org/abs/2402.14973">Paper</a> • 
|
8 |
+
<a href="#cite-this-work">Citation</a>
|
9 |
+
</p>
|
10 |
+
|
11 |
+
> GenCeption is an annotation-free MLLM (Multimodal Large Language Model) evaluation framework that merely requires unimodal data to assess inter-modality semantic coherence and inversely reflects the models' inclination to hallucinate.
|
12 |
+
|
13 |
+

|
14 |
+
|
15 |
+
GenCeption is inspired by a popular multi-player game [DrawCeption](https://wikipedia.org/wiki/drawception). Using the image modality as an example, the process begins with a seed image $\mathbf{X}^{(0)}$ from a unimodal image dataset for the first iteration ($t$=1). The MLLM creates a detailed description of the image, which is then used by an image generator to produce $\mathbf{X}^{(t)}$. After $T$ iterations, we calculate the GC@T score to measure the MLLM's performance on $\mathbf{X}^{(0)}$.
|
16 |
+
|
17 |
+
The GenCeption ranking on [MME](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation) benchmarking dataset (without using any label) shows a strong correlation with other sophisticated benchmarks such as [OpenCompass](https://rank.opencompass.org.cn/leaderboard-multimodal) and [HallusionBench](https://github.com/tianyi-lab/HallusionBench). Moreover, the negative correlation with MME scores suggests that GenCeption measures distinct aspects not covered by MME, using the same set of samples. For detailed experimental analysis, please read [our paper](https://arxiv.org/abs/2402.14973).
|
18 |
+
|
19 |
+
We demostrate a 5-iteration GenCeption procedure below run on a seed images to evaluate 4 VLLMs. Each iteration $t$ shows the generated image $\mathbf{X}^{(t)}$, the description $\mathbf{Q}^{(t)}$ of the preceding image $\mathbf{X}^{(t-1)}$, and the similarity score $s^{(t)}$ relative to $\mathbf{X}^{(0)}$. The GC@5 metric for each VLLM is also presented. Hallucinated elements within descriptions $\mathbf{Q}^{(1)}$ and $\mathbf{Q}^{(2)}$ as compared to the seed image are indicated with <span style="color:red"><u>red underlined</u></span>.
|
20 |
+
|
21 |
+

|
22 |
+
|
23 |
+
|
24 |
+
## Contribute
|
25 |
+
Please **create PR (Pull-Request)** to contribute your results to the [🔥🏅️**Leaderboard**🏅️🔥](https://github.com/EQTPartners/GenCeption/blob/main/Leaderboard.md). Start by creating your virtual environment:
|
26 |
+
|
27 |
+
```{bash}
|
28 |
+
conda create --name genception python=3.10 -y
|
29 |
+
conda activate genception
|
30 |
+
pip install -r requirements.txt
|
31 |
+
```
|
32 |
+
|
33 |
+
For example, if you want to evaluate mPLUG-Owl2 model, please follow the instructions in the [official mPLUG-OWL2 repository](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2#usage). Then run GenCeption by
|
34 |
+
|
35 |
+
```{bash}
|
36 |
+
bash example_script.sh # uses exemplary data in datasets/example/
|
37 |
+
```
|
38 |
+
|
39 |
+
This assumes that an OPENAI_API_KEY is set as an environment variable. The `model` argument to `experiment.py` in `example_script.sh` can be adjusted to `llava7b`, `llava13b`, `mPLUG`, or `gpt4v`. Please adapt accordingly for to evaluate your MLLM.
|
40 |
+
|
41 |
+
The MME dataset, of which the image modality was used in our paper, can be obtained as [described here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/blob/Evaluation/README.md#our-mllm-works).
|
42 |
+
|
43 |
+
## Cite This Work
|
44 |
+
```bibtex
|
45 |
+
@article{cao2023genception,
|
46 |
+
author = {Lele Cao and
|
47 |
+
Valentin Buchner and
|
48 |
+
Zineb Senane and
|
49 |
+
Fangkai Yang},
|
50 |
+
title = {{GenCeption}: Evaluate Multimodal LLMs with Unlabeled Unimodal Data},
|
51 |
+
year={2023},
|
52 |
+
journal={arXiv preprint arXiv:2402.14973},
|
53 |
+
primaryClass={cs.AI,cs.CL,cs.LG}
|
54 |
+
}
|
55 |
+
```
|
datasets/examples/000000061658.jpg
ADDED
![]() |
datasets/examples/000000338560.jpg
ADDED
![]() |
genception/evaluation.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import pickle
|
4 |
+
import numpy as np
|
5 |
+
import argparse
|
6 |
+
from genception.utils import find_files
|
7 |
+
|
8 |
+
|
9 |
+
def read_all_pkl(folder_path: str) -> dict:
|
10 |
+
"""
|
11 |
+
Read all the pickle files in the given folder path
|
12 |
+
|
13 |
+
Args:
|
14 |
+
folder_path: str: The path to the folder
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
dict: The dictionary containing the file path as key and the pickle file content as value
|
18 |
+
"""
|
19 |
+
result_dict = dict()
|
20 |
+
file_list = find_files(folder_path, {".pkl"})
|
21 |
+
for file_path in file_list:
|
22 |
+
with open(file_path, "rb") as file:
|
23 |
+
result_dict[file_path] = pickle.load(file)
|
24 |
+
return result_dict
|
25 |
+
|
26 |
+
|
27 |
+
def integrated_decay_area(scores: list[float]) -> float:
|
28 |
+
"""
|
29 |
+
Calculate the Integrated Decay Area (IDA) for the given scores
|
30 |
+
|
31 |
+
Args:
|
32 |
+
scores: list[float]: The list of scores
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
float: The IDA score
|
36 |
+
"""
|
37 |
+
total_area = 0
|
38 |
+
|
39 |
+
for i, score in enumerate(scores):
|
40 |
+
total_area += (i + 1) * score
|
41 |
+
|
42 |
+
max_possible_area = sum(range(1, len(scores) + 1))
|
43 |
+
ida = total_area / max_possible_area if max_possible_area else 0
|
44 |
+
return ida
|
45 |
+
|
46 |
+
|
47 |
+
def gc_score(folder_path: str, n_iter: int = None) -> tuple[float, list[float]]:
|
48 |
+
"""
|
49 |
+
Calculate the GC@T score for the given folder path
|
50 |
+
|
51 |
+
Args:
|
52 |
+
folder_path: str: The path to the folder
|
53 |
+
n_iter: int: The number of iterations to consider for GC@T score
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tuple[float, list[float]]: The GC@T score and the list of GC scores for each file
|
57 |
+
"""
|
58 |
+
test_data = read_all_pkl(folder_path)
|
59 |
+
all_gc_scores = []
|
60 |
+
for _, value in test_data.items():
|
61 |
+
sim_score = value["cosine_similarities"][1:]
|
62 |
+
if n_iter is None:
|
63 |
+
_gc = integrated_decay_area(sim_score)
|
64 |
+
else:
|
65 |
+
if len(value["cosine_similarities"]) >= n_iter:
|
66 |
+
_gc = integrated_decay_area(sim_score[:n_iter])
|
67 |
+
else:
|
68 |
+
continue
|
69 |
+
all_gc_scores.append(_gc)
|
70 |
+
return np.mean(all_gc_scores), all_gc_scores
|
71 |
+
|
72 |
+
|
73 |
+
def main():
|
74 |
+
parser = argparse.ArgumentParser()
|
75 |
+
parser.add_argument(
|
76 |
+
"--results_path",
|
77 |
+
type=str,
|
78 |
+
help="Path to the folder containing the pickle files",
|
79 |
+
required=True,
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--t",
|
83 |
+
type=int,
|
84 |
+
help="Number of iterations to consider for GC@T score",
|
85 |
+
required=True,
|
86 |
+
)
|
87 |
+
args = parser.parse_args()
|
88 |
+
|
89 |
+
# calculate GC@T score and save in results directory
|
90 |
+
gc, all_gc_scores = gc_score(args.results_path, args.t)
|
91 |
+
result = {
|
92 |
+
"GC Score": gc,
|
93 |
+
"All GC Scores": all_gc_scores,
|
94 |
+
}
|
95 |
+
results_path = os.path.join(args.results_path, f"GC@{str(args.t)}.json")
|
96 |
+
with open(results_path, "w") as file:
|
97 |
+
json.dump(result, file)
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == "__main__":
|
101 |
+
main()
|
genception/example_script.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# run experiment with gpt4v on examples dataset
|
2 |
+
python genception/experiment.py --model gpt4v --dataset datasets/examples
|
3 |
+
|
4 |
+
|
5 |
+
# Calculate GC@T evaluation metric
|
6 |
+
python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 1
|
7 |
+
python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 3
|
8 |
+
python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 5
|
genception/experiment.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import base64
|
4 |
+
import pickle
|
5 |
+
import requests
|
6 |
+
import argparse
|
7 |
+
import nltk
|
8 |
+
from nltk.tokenize import word_tokenize
|
9 |
+
from functools import partial
|
10 |
+
from transformers import ViTImageProcessor, ViTModel
|
11 |
+
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
12 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
13 |
+
from PIL import Image
|
14 |
+
import logging
|
15 |
+
from tqdm import tqdm
|
16 |
+
from openai import OpenAI
|
17 |
+
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
18 |
+
from mplug_owl2.conversation import conv_templates
|
19 |
+
from mplug_owl2.model.builder import load_pretrained_model
|
20 |
+
from mplug_owl2.mm_utils import (
|
21 |
+
process_images,
|
22 |
+
tokenizer_image_token,
|
23 |
+
get_model_name_from_path,
|
24 |
+
KeywordsStoppingCriteria,
|
25 |
+
)
|
26 |
+
from genception.utils import find_files
|
27 |
+
|
28 |
+
logging.basicConfig(level=logging.INFO)
|
29 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
30 |
+
api_key = client.api_key
|
31 |
+
nltk.download("punkt")
|
32 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
+
torch.backends.cudnn.enabled = False
|
34 |
+
|
35 |
+
# VIT model
|
36 |
+
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
|
37 |
+
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
|
38 |
+
|
39 |
+
|
40 |
+
def image_embedding(image_file: str) -> list[float]:
|
41 |
+
"""
|
42 |
+
Generates an image embedding using a vit model
|
43 |
+
|
44 |
+
Args:
|
45 |
+
image_file: str: The path to the image file
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
list[float]: The image embedding
|
49 |
+
"""
|
50 |
+
image = Image.open(image_file).convert("RGB")
|
51 |
+
inputs = vit_processor(images=image, return_tensors="pt")
|
52 |
+
outputs = vit_model(**inputs)
|
53 |
+
return outputs.last_hidden_state.tolist()[0][0]
|
54 |
+
|
55 |
+
|
56 |
+
def save_image_from_url(url: str, filename: str):
|
57 |
+
"""
|
58 |
+
Save an image from a given URL to a file
|
59 |
+
|
60 |
+
Args:
|
61 |
+
url: str: The URL of the image
|
62 |
+
filename: str: The name of the file to save the image to
|
63 |
+
"""
|
64 |
+
response = requests.get(url)
|
65 |
+
if response.status_code == 200:
|
66 |
+
with open(filename, "wb") as file:
|
67 |
+
file.write(response.content)
|
68 |
+
else:
|
69 |
+
logging.warning(
|
70 |
+
f"Failed to download image. Status code: {response.status_code}"
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
def find_image_files(folder_path: str) -> list[str]:
|
75 |
+
image_extensions = {".jpg", ".png"}
|
76 |
+
return find_files(folder_path, image_extensions)
|
77 |
+
|
78 |
+
|
79 |
+
def count_words(text):
|
80 |
+
words = word_tokenize(text)
|
81 |
+
return len(words)
|
82 |
+
|
83 |
+
|
84 |
+
def encode_image_os(image_path: str):
|
85 |
+
image = Image.open(image_path).convert("RGB")
|
86 |
+
return image
|
87 |
+
|
88 |
+
|
89 |
+
def encode_image_gpt4v(image_path: str):
|
90 |
+
with open(image_path, "rb") as image_file:
|
91 |
+
return base64.b64encode(image_file.read()).decode("utf-8")
|
92 |
+
|
93 |
+
|
94 |
+
def generate_xt(
|
95 |
+
image_desc: str, output_folder: str, i: int, file_name: str, file_extension: str
|
96 |
+
) -> str:
|
97 |
+
"""
|
98 |
+
Generate an image based on a description using dall-e and save it to a file
|
99 |
+
|
100 |
+
Args:
|
101 |
+
image_desc: str: The description of the image
|
102 |
+
output_folder: str: The path to the folder to save the image to
|
103 |
+
i: int: The iteration number
|
104 |
+
file_name: str: The name of the file
|
105 |
+
file_extension: str: The extension of the file
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
str: The path to the saved image file
|
109 |
+
"""
|
110 |
+
response = client.images.generate(
|
111 |
+
model="dall-e-3",
|
112 |
+
prompt="Generate an image that fully and precisely reflects this description: {}".format(
|
113 |
+
image_desc
|
114 |
+
),
|
115 |
+
size="1024x1024",
|
116 |
+
quality="standard",
|
117 |
+
n=1,
|
118 |
+
)
|
119 |
+
new_image_filename = os.path.join(
|
120 |
+
output_folder, f"{file_name}_{i}.{file_extension}"
|
121 |
+
)
|
122 |
+
save_image_from_url(response.data[0].url, new_image_filename)
|
123 |
+
return new_image_filename
|
124 |
+
|
125 |
+
|
126 |
+
def get_desc_mPLUG(image, image_processor, lmm_model, tokenizer, prompt):
|
127 |
+
"""
|
128 |
+
Given an image, generate a description using the mPLUG model
|
129 |
+
|
130 |
+
Args:
|
131 |
+
image: Image: The image to describe
|
132 |
+
image_processor: callable: The image processor
|
133 |
+
lmm_model: The language model
|
134 |
+
tokenizer: The tokenizer
|
135 |
+
prompt: str: The prompt for the model
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
str: The description of the image
|
139 |
+
"""
|
140 |
+
conv = conv_templates["mplug_owl2"].copy()
|
141 |
+
max_edge = max(image.size)
|
142 |
+
image = image.resize((max_edge, max_edge))
|
143 |
+
image_tensor = process_images([image], image_processor)
|
144 |
+
image_tensor = image_tensor.to(lmm_model.device, dtype=torch.float16)
|
145 |
+
|
146 |
+
inp = DEFAULT_IMAGE_TOKEN + prompt
|
147 |
+
conv.append_message(conv.roles[0], inp)
|
148 |
+
conv.append_message(conv.roles[1], None)
|
149 |
+
prompt = conv.get_prompt()
|
150 |
+
|
151 |
+
input_ids = (
|
152 |
+
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
153 |
+
.unsqueeze(0)
|
154 |
+
.to(lmm_model.device)
|
155 |
+
)
|
156 |
+
stop_str = conv.sep2
|
157 |
+
keywords = [stop_str]
|
158 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
159 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
160 |
+
|
161 |
+
temperature = 0.001
|
162 |
+
max_new_tokens = 512
|
163 |
+
|
164 |
+
with torch.inference_mode():
|
165 |
+
output_ids = lmm_model.generate(
|
166 |
+
input_ids,
|
167 |
+
images=image_tensor,
|
168 |
+
do_sample=True,
|
169 |
+
temperature=temperature,
|
170 |
+
max_new_tokens=max_new_tokens,
|
171 |
+
stopping_criteria=[stopping_criteria],
|
172 |
+
attention_mask=attention_mask,
|
173 |
+
)
|
174 |
+
|
175 |
+
image_desc = tokenizer.decode(
|
176 |
+
output_ids[0, input_ids.shape[1] :], skip_special_tokens=True
|
177 |
+
).strip()
|
178 |
+
return image_desc
|
179 |
+
|
180 |
+
|
181 |
+
def get_desc_llava(image, lmm_processor, lmm_model, prompt):
|
182 |
+
"""
|
183 |
+
Given an image, generate a description using the llava model
|
184 |
+
|
185 |
+
Args:
|
186 |
+
image: Image: The image to describe
|
187 |
+
lmm_processor: callable: The language model processor
|
188 |
+
lmm_model: The language model
|
189 |
+
prompt: str: The prompt for the model
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
str: The description of the image
|
193 |
+
"""
|
194 |
+
inputs = lmm_processor(text=prompt, images=image, return_tensors="pt").to(device)
|
195 |
+
outputs = lmm_model.generate(**inputs, max_new_tokens=512, do_sample=False)
|
196 |
+
answer = lmm_processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
197 |
+
image_desc = answer.split("ASSISTANT:")[1].strip()
|
198 |
+
return image_desc
|
199 |
+
|
200 |
+
|
201 |
+
def get_desc_gpt4v(image, prompt):
|
202 |
+
"""
|
203 |
+
Given an image, generate a description using the gpt-4-vision model
|
204 |
+
|
205 |
+
Args:
|
206 |
+
image: Image: The image to describe
|
207 |
+
prompt: str: The prompt for the model
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
str: The description of the image
|
211 |
+
"""
|
212 |
+
payload = {
|
213 |
+
"model": "gpt-4-vision-preview",
|
214 |
+
"messages": [
|
215 |
+
{
|
216 |
+
"role": "user",
|
217 |
+
"content": [
|
218 |
+
{
|
219 |
+
"type": "text",
|
220 |
+
"text": prompt,
|
221 |
+
},
|
222 |
+
{
|
223 |
+
"type": "image_url",
|
224 |
+
"image_url": {"url": f"data:image/jpeg;base64,{image}"},
|
225 |
+
},
|
226 |
+
],
|
227 |
+
}
|
228 |
+
],
|
229 |
+
"max_tokens": 512,
|
230 |
+
"temperature": 0,
|
231 |
+
}
|
232 |
+
|
233 |
+
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
234 |
+
|
235 |
+
response = requests.post(
|
236 |
+
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload
|
237 |
+
)
|
238 |
+
image_desc = response.json()["choices"][0]["message"]["content"]
|
239 |
+
return image_desc
|
240 |
+
|
241 |
+
|
242 |
+
def test_sample(
|
243 |
+
seed_image: str,
|
244 |
+
n_iteration: int,
|
245 |
+
output_folder: str,
|
246 |
+
get_desc_function: callable,
|
247 |
+
encode_image_function: callable,
|
248 |
+
):
|
249 |
+
"""
|
250 |
+
Iteratively generates T (n_iterations) descriptions and images based on the seed image
|
251 |
+
|
252 |
+
Args:
|
253 |
+
seed_image: str: The path to the seed image
|
254 |
+
n_iteration: int: The number of iterations to perform
|
255 |
+
output_folder: str: The path to the folder to save the results
|
256 |
+
get_desc_function: callable: The function to generate the description
|
257 |
+
encode_image_function: callable: The function to encode the image
|
258 |
+
"""
|
259 |
+
list_of_desc = []
|
260 |
+
list_of_image = []
|
261 |
+
list_of_image_embedding = [image_embedding(seed_image)]
|
262 |
+
list_of_cos_sim = [1.0]
|
263 |
+
|
264 |
+
current_image_path = seed_image
|
265 |
+
current_image_name = os.path.basename(current_image_path)
|
266 |
+
file_name, file_extension = current_image_name.split(".")
|
267 |
+
logging.debug(f"Image: {current_image_path}")
|
268 |
+
pkl_file = os.path.join(output_folder, f"{file_name}_result.pkl")
|
269 |
+
if os.path.exists(pkl_file):
|
270 |
+
logging.info("Results already exist, skipping")
|
271 |
+
return None
|
272 |
+
|
273 |
+
for i in range(n_iteration):
|
274 |
+
# Encode the current image and get the description
|
275 |
+
image = encode_image_function(current_image_path)
|
276 |
+
image_desc = get_desc_function(image)
|
277 |
+
list_of_desc.append(image_desc)
|
278 |
+
logging.debug(image_desc)
|
279 |
+
|
280 |
+
# generate X^t, append image and embedding
|
281 |
+
new_image_filename = generate_xt(
|
282 |
+
image_desc, output_folder, i, file_name, file_extension
|
283 |
+
)
|
284 |
+
list_of_image.append(new_image_filename)
|
285 |
+
list_of_image_embedding.append(image_embedding(new_image_filename))
|
286 |
+
|
287 |
+
# Calculate Cosine Sim to original image
|
288 |
+
similarity = cosine_similarity(
|
289 |
+
[list_of_image_embedding[0]], [list_of_image_embedding[-1]]
|
290 |
+
)[0][0]
|
291 |
+
list_of_cos_sim.append(similarity)
|
292 |
+
logging.info(f"({count_words(image_desc)}, {round(similarity,2)})")
|
293 |
+
|
294 |
+
# Save checkpoint to avoid losing results
|
295 |
+
data_to_save = {
|
296 |
+
"descriptions": list_of_desc,
|
297 |
+
"images": list_of_image,
|
298 |
+
"image_embeddings": list_of_image_embedding,
|
299 |
+
"cosine_similarities": list_of_cos_sim,
|
300 |
+
}
|
301 |
+
with open(pkl_file, "wb") as file:
|
302 |
+
pickle.dump(data_to_save, file)
|
303 |
+
|
304 |
+
# Update current_image_path for the next iteration
|
305 |
+
current_image_path = new_image_filename
|
306 |
+
|
307 |
+
return None
|
308 |
+
|
309 |
+
|
310 |
+
def main():
|
311 |
+
parser = argparse.ArgumentParser()
|
312 |
+
parser.add_argument("--dataset", type=str, default="mme_data/color")
|
313 |
+
parser.add_argument("--model", type=str, default="llava7b")
|
314 |
+
parser.add_argument("--n_iter", type=int, default=5)
|
315 |
+
args = parser.parse_args()
|
316 |
+
|
317 |
+
logging.info(args)
|
318 |
+
|
319 |
+
prompt = "Please write a clear, precise, detailed, and concise description of all elements in the image. Focus on accurately depicting various aspects, including but not limited to the colors, shapes, positions, styles, texts and the relationships between different objects and subjects in the image. Your description should be thorough enough to guide a professional in recreating this image solely based on your textual representation. Remember, only include descriptive texts that directly pertain to the contents of the image. You must complete the description using less than 500 words."
|
320 |
+
|
321 |
+
if "llava" in args.model:
|
322 |
+
lmm_model = LlavaForConditionalGeneration.from_pretrained(
|
323 |
+
f"llava-hf/llava-1.5-{args.model[5:]}-hf", load_in_8bit=True
|
324 |
+
)
|
325 |
+
lmm_processor = AutoProcessor.from_pretrained(
|
326 |
+
f"llava-hf/llava-1.5-{args.model[5:]}-hf"
|
327 |
+
)
|
328 |
+
prompt = f"<image>\nUSER: {prompt}\nASSISTANT:"
|
329 |
+
get_desc_function = partial(get_desc_llava, lmm_processor, lmm_model, prompt)
|
330 |
+
encode_image_function = encode_image_os
|
331 |
+
elif args.model == "mPLUG":
|
332 |
+
model_path = "MAGAer13/mplug-owl2-llama2-7b"
|
333 |
+
model_name = get_model_name_from_path(model_path)
|
334 |
+
tokenizer, lmm_model, image_processor, _ = load_pretrained_model(
|
335 |
+
model_path,
|
336 |
+
None,
|
337 |
+
model_name,
|
338 |
+
load_8bit=False,
|
339 |
+
load_4bit=False,
|
340 |
+
device=device,
|
341 |
+
)
|
342 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
343 |
+
tokenizer.pad_token = tokenizer.eos_token
|
344 |
+
get_desc_function = partial(
|
345 |
+
get_desc_mPLUG, image_processor, lmm_model, tokenizer, prompt
|
346 |
+
)
|
347 |
+
encode_image_function = encode_image_os
|
348 |
+
elif args.model == "gpt4v":
|
349 |
+
get_desc_function = partial(get_desc_gpt4v, prompt=prompt)
|
350 |
+
encode_image_function = encode_image_gpt4v
|
351 |
+
|
352 |
+
output_folder = os.path.join(args.dataset, f"results_{args.model}")
|
353 |
+
os.makedirs(output_folder, exist_ok=True)
|
354 |
+
|
355 |
+
logging.debug("Loaded model. Entered main loop.")
|
356 |
+
for img_file in tqdm(find_image_files(args.dataset)):
|
357 |
+
try:
|
358 |
+
logging.info(img_file)
|
359 |
+
test_sample(
|
360 |
+
seed_image=img_file,
|
361 |
+
n_iteration=args.n_iter,
|
362 |
+
output_folder=output_folder,
|
363 |
+
get_desc_function=get_desc_function,
|
364 |
+
encode_image_function=encode_image_function,
|
365 |
+
)
|
366 |
+
except Exception as e:
|
367 |
+
logging.warning("caught error:")
|
368 |
+
logging.warning(e)
|
369 |
+
continue
|
370 |
+
|
371 |
+
|
372 |
+
if __name__ == "__main__":
|
373 |
+
main()
|
genception/utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
def find_files(folder_path: str, file_extensions: dict) -> list[str]:
|
5 |
+
"""
|
6 |
+
Find all files with the given extensions in the given folder path
|
7 |
+
|
8 |
+
Args:
|
9 |
+
folder_path: str: The path to the folder
|
10 |
+
file_extensions: dict: The file extensions to look for
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
list[str]: The list of file paths
|
14 |
+
"""
|
15 |
+
file_paths = []
|
16 |
+
|
17 |
+
for file in os.listdir(folder_path):
|
18 |
+
if (
|
19 |
+
os.path.isfile(os.path.join(folder_path, file))
|
20 |
+
and os.path.splitext(file)[1].lower() in file_extensions
|
21 |
+
):
|
22 |
+
absolute_path = os.path.abspath(os.path.join(folder_path, file))
|
23 |
+
file_paths.append(absolute_path)
|
24 |
+
|
25 |
+
return file_paths
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers>=4.37.1
|
2 |
+
pillow
|
3 |
+
requests
|
4 |
+
scikit-learn
|
5 |
+
nltk
|
6 |
+
openai
|
7 |
+
sentencepiece
|