cao-lele committed
Commit 0724c4e · 1 Parent(s): 9619f3f

initial commit

.gitignore ADDED
@@ -0,0 +1,167 @@
+ #custom
+ mme_data/
+
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # other
+ .DS_Store
Leaderboard.md ADDED
@@ -0,0 +1,163 @@
+ # 🔥🏅️ GenCeption Leaderboard 🏅️🔥
+
+ Evaluated MLLMs: [ChatGPT-4V](https://cdn.openai.com/papers/GPTV_System_Card.pdf), [mPLUG-Owl2](https://arxiv.org/pdf/2311.04257.pdf), [LLaVA-13B](https://arxiv.org/pdf/2304.08485.pdf), [LLaVA-7B](https://arxiv.org/pdf/2304.08485.pdf)
+
+ <table>
+ <tr><th>Existence</th><th>Count</th></tr>
+ <tr><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.422 |
+ | mPLUG-Owl2 | 0.323 |
+ | LLaVA-7B | 0.308 |
+ | LLaVA-13B | 0.305 |
+
+ </td><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.404 |
+ | mPLUG-Owl2 | 0.299 |
+ | LLaVA-13B | 0.294 |
+ | LLaVA-7B | 0.353 |
+
+ </td></tr> </table>
+
+
+ <table>
+ <tr><th>Position</th><th>Color</th></tr>
+ <tr><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.408 |
+ | mPLUG-Owl2 | 0.306 |
+ | LLaVA-7B | 0.285 |
+ | LLaVA-13B | 0.255 |
+
+ </td><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.403 |
+ | LLaVA-13B | 0.300 |
+ | mPLUG-Owl2 | 0.290 |
+ | LLaVA-7B | 0.284 |
+
+ </td></tr> </table>
+
+
+ <table>
+ <tr><th>Poster</th><th>Celebrity</th></tr>
+ <tr><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.324 |
+ | mPLUG-Owl2 | 0.243 |
+ | LLaVA-13B | 0.215 |
+ | LLaVA-7B | 0.214 |
+
+ </td><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.332 |
+ | mPLUG-Owl2 | 0.232 |
+ | LLaVA-13B | 0.206 |
+ | LLaVA-7B | 0.188 |
+
+ </td></tr> </table>
+
+
+ <table>
+ <tr><th>Scene</th><th>Landmark</th></tr>
+ <tr><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.393 |
+ | mPLUG-Owl2 | 0.299 |
+ | LLaVA-13B | 0.277 |
+ | LLaVA-7B | 0.266 |
+
+ </td><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.353 |
+ | mPLUG-Owl2 | 0.275 |
+ | LLaVA-7B | 0.252 |
+ | LLaVA-13B | 0.242 |
+
+ </td></tr> </table>
+
+
+ <table>
+ <tr><th>Artwork</th><th>Commonsense Reasoning</th></tr>
+ <tr><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.421 |
+ | mPLUG-Owl2 | 0.252 |
+ | LLaVA-13B | 0.212 |
+ | LLaVA-7B | 0.210 |
+
+ </td><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.471 |
+ | mPLUG-Owl2 | 0.353 |
+ | LLaVA-13B | 0.334 |
+ | LLaVA-7B | 0.294 |
+
+ </td></tr> </table>
+
+
+ <table>
+ <tr><th>Code Reasoning</th><th>Numerical Calculation</th></tr>
+ <tr><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.193 |
+ | mPLUG-Owl2 | 0.176 |
+ | LLaVA-13B | 0.144 |
+ | LLaVA-7B | 0.107 |
+
+ </td><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.240 |
+ | LLaVA-13B | 0.195 |
+ | mPLUG-Owl2 | 0.192 |
+ | LLaVA-7B | 0.155 |
+
+ </td></tr> </table>
+
+
+ <table>
+ <tr><th>Text Translation</th><th>OCR</th></tr>
+ <tr><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.157 |
+ | LLaVA-13B | 0.116 |
+ | LLaVA-7B | 0.111 |
+ | mPLUG-Owl2 | 0.081 |
+
+ </td><td>
+
+ | Model | GC@3 |
+ |--|--|
+ | ChatGPT-4V | 0.393 |
+ | mPLUG-Owl2 | 0.276 |
+ | LLaVA-13B | 0.239 |
+ | LLaVA-7B | 0.222 |
+
+ </td></tr> </table>
README.md ADDED
@@ -0,0 +1,55 @@
+ # GenCeption: Evaluate Multimodal LLMs with Unlabeled Unimodal Data
+
+ <div>
+ <p align="center">
+ <a href="https://github.com/EQTPartners/GenCeption/blob/main/Leaderboard.md">🔥🏅️Leaderboard🏅️🔥</a>&emsp;•&emsp;
+ <a href="#contribute">Contribute</a>&emsp;•&emsp;
+ <a href="https://arxiv.org/abs/2402.14973">Paper</a>&emsp;•&emsp;
+ <a href="#cite-this-work">Citation</a>
+ </p>
+
+ > GenCeption is an annotation-free MLLM (Multimodal Large Language Model) evaluation framework that requires only unimodal data. It assesses inter-modality semantic coherence and inversely reflects a model's inclination to hallucinate.
+
+ ![GenCeption Procedure](figures/genception-correlation.jpeg)
+
+ GenCeption is inspired by the popular multi-player game [DrawCeption](https://wikipedia.org/wiki/drawception). Using the image modality as an example, the process begins with a seed image $\mathbf{X}^{(0)}$ drawn from a unimodal image dataset for the first iteration ($t$=1). The MLLM writes a detailed description of the image, which an image generator then uses to produce $\mathbf{X}^{(t)}$. After $T$ iterations, we compute the GC@T score to measure the MLLM's performance on $\mathbf{X}^{(0)}$.
+
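+ Concretely, `genception/evaluation.py` computes GC@T as the similarity-weighted decay area of the sequence $s^{(1)},\dots,s^{(T)}$, where $s^{(t)}$ is the cosine similarity between the ViT embeddings of $\mathbf{X}^{(t)}$ and the seed image $\mathbf{X}^{(0)}$:
+
+ $$\mathrm{GC@T} = \frac{\sum_{t=1}^{T} t \cdot s^{(t)}}{\sum_{t=1}^{T} t}$$
+
+ Later iterations carry larger weights, so a model whose generations stay semantically close to the seed for longer obtains a higher score.
+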
+ The GenCeption ranking on the [MME](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation) benchmark dataset (without using any labels) shows a strong correlation with other sophisticated benchmarks such as [OpenCompass](https://rank.opencompass.org.cn/leaderboard-multimodal) and [HallusionBench](https://github.com/tianyi-lab/HallusionBench). Moreover, the negative correlation with MME scores suggests that GenCeption measures distinct aspects not covered by MME, even though it uses the same set of samples. For a detailed experimental analysis, please read [our paper](https://arxiv.org/abs/2402.14973).
+
+ We demonstrate below a 5-iteration GenCeption run on a seed image to evaluate 4 VLLMs. Each iteration $t$ shows the generated image $\mathbf{X}^{(t)}$, the description $\mathbf{Q}^{(t)}$ of the preceding image $\mathbf{X}^{(t-1)}$, and the similarity score $s^{(t)}$ relative to $\mathbf{X}^{(0)}$. The GC@5 metric for each VLLM is also presented. Hallucinated elements within descriptions $\mathbf{Q}^{(1)}$ and $\mathbf{Q}^{(2)}$, as compared to the seed image, are marked in <span style="color:red"><u>red underline</u></span>.
+
+ ![GenCeption Example](figures/existence-example.jpeg)
+
+
+ ## Contribute
+ Please **create a PR (Pull Request)** to contribute your results to the [🔥🏅️**Leaderboard**🏅️🔥](https://github.com/EQTPartners/GenCeption/blob/main/Leaderboard.md). Start by creating your virtual environment:
+
+ ```bash
+ conda create --name genception python=3.10 -y
+ conda activate genception
+ pip install -r requirements.txt
+ ```
+
+ For example, to evaluate the mPLUG-Owl2 model, first follow the instructions in the [official mPLUG-Owl2 repository](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2#usage). Then run GenCeption with:
+
+ ```bash
+ bash genception/example_script.sh  # uses the exemplary data in datasets/examples/
+ ```
+
+ This assumes that an `OPENAI_API_KEY` is set as an environment variable. The `model` argument to `experiment.py` in `example_script.sh` can be set to `llava7b`, `llava13b`, `mPLUG`, or `gpt4v`; please adapt it accordingly to evaluate your own MLLM.
+
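+ For instance, here is a minimal sketch that evaluates LLaVA-13B on the bundled example images and then computes its GC@3 (the output folder follows the `results_<model>` naming used by `experiment.py`):
+
+ ```bash
+ python genception/experiment.py --model llava13b --dataset datasets/examples --n_iter 5
+ python genception/evaluation.py --results_path datasets/examples/results_llava13b --t 3
+ ```
+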
+ The MME dataset, of which the image modality was used in our paper, can be obtained as [described here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/blob/Evaluation/README.md#our-mllm-works).
+
+ ## Cite This Work
+ ```bibtex
+ @article{cao2023genception,
+   author       = {Lele Cao and
+                   Valentin Buchner and
+                   Zineb Senane and
+                   Fangkai Yang},
+   title        = {{GenCeption}: Evaluate Multimodal LLMs with Unlabeled Unimodal Data},
+   year         = {2023},
+   journal      = {arXiv preprint arXiv:2402.14973},
+   primaryClass = {cs.AI,cs.CL,cs.LG}
+ }
+ ```
datasets/examples/000000061658.jpg ADDED
datasets/examples/000000338560.jpg ADDED
genception/evaluation.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ import json
+ import pickle
+ import numpy as np
+ import argparse
+ from genception.utils import find_files
+
+
+ def read_all_pkl(folder_path: str) -> dict:
+     """
+     Read all the pickle files in the given folder path
+
+     Args:
+         folder_path: str: The path to the folder
+
+     Returns:
+         dict: The dictionary containing the file path as key and the pickle file content as value
+     """
+     result_dict = dict()
+     file_list = find_files(folder_path, {".pkl"})
+     for file_path in file_list:
+         with open(file_path, "rb") as file:
+             result_dict[file_path] = pickle.load(file)
+     return result_dict
+
+
+ def integrated_decay_area(scores: list[float]) -> float:
+     """
+     Calculate the Integrated Decay Area (IDA) for the given scores
+
+     Args:
+         scores: list[float]: The list of scores
+
+     Returns:
+         float: The IDA score
+     """
+     total_area = 0
+
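+     # the score from iteration t is weighted by t, i.e. IDA = sum_t(t * s_t) / sum_t(t)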
+     for i, score in enumerate(scores):
+         total_area += (i + 1) * score
+
+     max_possible_area = sum(range(1, len(scores) + 1))
+     ida = total_area / max_possible_area if max_possible_area else 0
+     return ida
+
+
+ def gc_score(folder_path: str, n_iter: int = None) -> tuple[float, list[float]]:
+     """
+     Calculate the GC@T score for the given folder path
+
+     Args:
+         folder_path: str: The path to the folder
+         n_iter: int: The number of iterations to consider for GC@T score
+
+     Returns:
+         tuple[float, list[float]]: The GC@T score and the list of GC scores for each file
+     """
+     test_data = read_all_pkl(folder_path)
+     all_gc_scores = []
+     for _, value in test_data.items():
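+         # index 0 is the seed image's similarity to itself (1.0); keep only the iteration scores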
+         sim_score = value["cosine_similarities"][1:]
+         if n_iter is None:
+             _gc = integrated_decay_area(sim_score)
+         else:
+             if len(value["cosine_similarities"]) >= n_iter:
+                 _gc = integrated_decay_area(sim_score[:n_iter])
+             else:
+                 continue
+         all_gc_scores.append(_gc)
+     return np.mean(all_gc_scores), all_gc_scores
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--results_path",
+         type=str,
+         help="Path to the folder containing the pickle files",
+         required=True,
+     )
+     parser.add_argument(
+         "--t",
+         type=int,
+         help="Number of iterations to consider for GC@T score",
+         required=True,
+     )
+     args = parser.parse_args()
+
+     # calculate GC@T score and save in results directory
+     gc, all_gc_scores = gc_score(args.results_path, args.t)
+     result = {
+         "GC Score": gc,
+         "All GC Scores": all_gc_scores,
+     }
+     results_path = os.path.join(args.results_path, f"GC@{str(args.t)}.json")
+     with open(results_path, "w") as file:
+         json.dump(result, file)
+
+
+ if __name__ == "__main__":
+     main()
genception/example_script.sh ADDED
@@ -0,0 +1,8 @@
+ # run experiment with gpt4v on examples dataset
+ python genception/experiment.py --model gpt4v --dataset datasets/examples
+
+
+ # Calculate GC@T evaluation metric
+ python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 1
+ python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 3
+ python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 5
genception/experiment.py ADDED
@@ -0,0 +1,373 @@
+ import os
+ import torch
+ import base64
+ import pickle
+ import requests
+ import argparse
+ import nltk
+ from nltk.tokenize import word_tokenize
+ from functools import partial
+ from transformers import ViTImageProcessor, ViTModel
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
+ from sklearn.metrics.pairwise import cosine_similarity
+ from PIL import Image
+ import logging
+ from tqdm import tqdm
+ from openai import OpenAI
+ from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from mplug_owl2.conversation import conv_templates
+ from mplug_owl2.model.builder import load_pretrained_model
+ from mplug_owl2.mm_utils import (
+     process_images,
+     tokenizer_image_token,
+     get_model_name_from_path,
+     KeywordsStoppingCriteria,
+ )
+ from genception.utils import find_files
+
+ logging.basicConfig(level=logging.INFO)
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+ api_key = client.api_key
+ nltk.download("punkt")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch.backends.cudnn.enabled = False
+
+ # VIT model
+ vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+ vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
+
+
+ def image_embedding(image_file: str) -> list[float]:
+     """
+     Generates an image embedding using a vit model
+
+     Args:
+         image_file: str: The path to the image file
+
+     Returns:
+         list[float]: The image embedding
+     """
+     image = Image.open(image_file).convert("RGB")
+     inputs = vit_processor(images=image, return_tensors="pt")
+     outputs = vit_model(**inputs)
+ return outputs.last_hidden_state.tolist()[0][0]
54
+
55
+
56
+ def save_image_from_url(url: str, filename: str):
57
+ """
58
+ Save an image from a given URL to a file
59
+
60
+ Args:
61
+ url: str: The URL of the image
62
+ filename: str: The name of the file to save the image to
63
+ """
64
+ response = requests.get(url)
65
+ if response.status_code == 200:
66
+ with open(filename, "wb") as file:
67
+ file.write(response.content)
68
+ else:
69
+ logging.warning(
70
+ f"Failed to download image. Status code: {response.status_code}"
71
+ )
72
+
73
+
74
+ def find_image_files(folder_path: str) -> list[str]:
75
+ image_extensions = {".jpg", ".png"}
76
+ return find_files(folder_path, image_extensions)
77
+
78
+
79
+ def count_words(text):
80
+ words = word_tokenize(text)
81
+ return len(words)
82
+
83
+
84
+ def encode_image_os(image_path: str):
85
+ image = Image.open(image_path).convert("RGB")
86
+ return image
87
+
88
+
89
+ def encode_image_gpt4v(image_path: str):
90
+ with open(image_path, "rb") as image_file:
91
+ return base64.b64encode(image_file.read()).decode("utf-8")
92
+
93
+
94
+ def generate_xt(
95
+ image_desc: str, output_folder: str, i: int, file_name: str, file_extension: str
96
+ ) -> str:
97
+ """
98
+ Generate an image based on a description using dall-e and save it to a file
99
+
100
+ Args:
101
+ image_desc: str: The description of the image
102
+ output_folder: str: The path to the folder to save the image to
103
+ i: int: The iteration number
104
+ file_name: str: The name of the file
105
+ file_extension: str: The extension of the file
106
+
107
+ Returns:
108
+ str: The path to the saved image file
109
+ """
110
+ response = client.images.generate(
111
+ model="dall-e-3",
112
+ prompt="Generate an image that fully and precisely reflects this description: {}".format(
113
+ image_desc
114
+ ),
115
+ size="1024x1024",
116
+ quality="standard",
117
+ n=1,
118
+ )
119
+ new_image_filename = os.path.join(
120
+ output_folder, f"{file_name}_{i}.{file_extension}"
121
+ )
122
+ save_image_from_url(response.data[0].url, new_image_filename)
123
+ return new_image_filename
124
+
125
+
126
+ def get_desc_mPLUG(image, image_processor, lmm_model, tokenizer, prompt):
127
+ """
128
+ Given an image, generate a description using the mPLUG model
129
+
130
+ Args:
131
+ image: Image: The image to describe
132
+ image_processor: callable: The image processor
133
+ lmm_model: The language model
134
+ tokenizer: The tokenizer
135
+ prompt: str: The prompt for the model
136
+
137
+ Returns:
138
+ str: The description of the image
139
+ """
140
+ conv = conv_templates["mplug_owl2"].copy()
141
+ max_edge = max(image.size)
142
+ image = image.resize((max_edge, max_edge))
143
+ image_tensor = process_images([image], image_processor)
144
+ image_tensor = image_tensor.to(lmm_model.device, dtype=torch.float16)
145
+
146
+ inp = DEFAULT_IMAGE_TOKEN + prompt
147
+ conv.append_message(conv.roles[0], inp)
148
+ conv.append_message(conv.roles[1], None)
149
+ prompt = conv.get_prompt()
150
+
151
+ input_ids = (
152
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
153
+ .unsqueeze(0)
154
+ .to(lmm_model.device)
155
+ )
156
+ stop_str = conv.sep2
157
+ keywords = [stop_str]
158
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
159
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)
160
+
161
+ temperature = 0.001
162
+ max_new_tokens = 512
163
+
164
+ with torch.inference_mode():
165
+ output_ids = lmm_model.generate(
166
+ input_ids,
167
+ images=image_tensor,
168
+ do_sample=True,
169
+ temperature=temperature,
170
+ max_new_tokens=max_new_tokens,
171
+ stopping_criteria=[stopping_criteria],
172
+ attention_mask=attention_mask,
173
+ )
174
+
175
+ image_desc = tokenizer.decode(
176
+ output_ids[0, input_ids.shape[1] :], skip_special_tokens=True
177
+ ).strip()
178
+ return image_desc
179
+
180
+
181
+ def get_desc_llava(image, lmm_processor, lmm_model, prompt):
182
+ """
183
+ Given an image, generate a description using the llava model
184
+
185
+ Args:
186
+ image: Image: The image to describe
187
+ lmm_processor: callable: The language model processor
188
+ lmm_model: The language model
189
+ prompt: str: The prompt for the model
190
+
191
+ Returns:
192
+ str: The description of the image
193
+ """
194
+ inputs = lmm_processor(text=prompt, images=image, return_tensors="pt").to(device)
195
+ outputs = lmm_model.generate(**inputs, max_new_tokens=512, do_sample=False)
196
+ answer = lmm_processor.batch_decode(outputs, skip_special_tokens=True)[0]
197
+ image_desc = answer.split("ASSISTANT:")[1].strip()
198
+ return image_desc
199
+
200
+
201
+ def get_desc_gpt4v(image, prompt):
202
+ """
203
+ Given an image, generate a description using the gpt-4-vision model
204
+
205
+ Args:
206
+ image: Image: The image to describe
207
+ prompt: str: The prompt for the model
208
+
209
+ Returns:
210
+ str: The description of the image
211
+ """
212
+ payload = {
213
+ "model": "gpt-4-vision-preview",
214
+ "messages": [
215
+ {
216
+ "role": "user",
217
+ "content": [
218
+ {
219
+ "type": "text",
220
+ "text": prompt,
221
+ },
222
+ {
223
+ "type": "image_url",
224
+ "image_url": {"url": f"data:image/jpeg;base64,{image}"},
225
+ },
226
+ ],
227
+ }
228
+ ],
229
+ "max_tokens": 512,
230
+ "temperature": 0,
231
+ }
232
+
233
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
234
+
235
+ response = requests.post(
236
+ "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
237
+ )
238
+ image_desc = response.json()["choices"][0]["message"]["content"]
239
+ return image_desc
240
+
241
+
242
+ def test_sample(
243
+ seed_image: str,
244
+ n_iteration: int,
245
+ output_folder: str,
246
+ get_desc_function: callable,
247
+ encode_image_function: callable,
248
+ ):
249
+ """
250
+ Iteratively generates T (n_iterations) descriptions and images based on the seed image
251
+
252
+ Args:
253
+ seed_image: str: The path to the seed image
254
+ n_iteration: int: The number of iterations to perform
255
+ output_folder: str: The path to the folder to save the results
256
+ get_desc_function: callable: The function to generate the description
257
+ encode_image_function: callable: The function to encode the image
258
+ """
259
+ list_of_desc = []
260
+ list_of_image = []
261
+ list_of_image_embedding = [image_embedding(seed_image)]
262
+ list_of_cos_sim = [1.0]
263
+
264
+ current_image_path = seed_image
265
+ current_image_name = os.path.basename(current_image_path)
266
+ file_name, file_extension = current_image_name.split(".")
267
+ logging.debug(f"Image: {current_image_path}")
268
+ pkl_file = os.path.join(output_folder, f"{file_name}_result.pkl")
269
+ if os.path.exists(pkl_file):
270
+ logging.info("Results already exist, skipping")
271
+ return None
272
+
273
+ for i in range(n_iteration):
274
+ # Encode the current image and get the description
275
+ image = encode_image_function(current_image_path)
276
+ image_desc = get_desc_function(image)
277
+ list_of_desc.append(image_desc)
278
+ logging.debug(image_desc)
279
+
280
+ # generate X^t, append image and embedding
281
+ new_image_filename = generate_xt(
282
+ image_desc, output_folder, i, file_name, file_extension
283
+ )
284
+ list_of_image.append(new_image_filename)
285
+ list_of_image_embedding.append(image_embedding(new_image_filename))
286
+
287
+ # Calculate Cosine Sim to original image
288
+ similarity = cosine_similarity(
289
+ [list_of_image_embedding[0]], [list_of_image_embedding[-1]]
290
+ )[0][0]
291
+ list_of_cos_sim.append(similarity)
292
+ logging.info(f"({count_words(image_desc)}, {round(similarity,2)})")
293
+
294
+ # Save checkpoint to avoid losing results
295
+ data_to_save = {
296
+ "descriptions": list_of_desc,
297
+ "images": list_of_image,
298
+ "image_embeddings": list_of_image_embedding,
299
+ "cosine_similarities": list_of_cos_sim,
300
+ }
301
+ with open(pkl_file, "wb") as file:
302
+ pickle.dump(data_to_save, file)
303
+
304
+ # Update current_image_path for the next iteration
305
+ current_image_path = new_image_filename
306
+
307
+ return None
308
+
309
+
310
+ def main():
311
+ parser = argparse.ArgumentParser()
312
+ parser.add_argument("--dataset", type=str, default="mme_data/color")
313
+ parser.add_argument("--model", type=str, default="llava7b")
314
+ parser.add_argument("--n_iter", type=int, default=5)
315
+ args = parser.parse_args()
316
+
317
+ logging.info(args)
318
+
319
+ prompt = "Please write a clear, precise, detailed, and concise description of all elements in the image. Focus on accurately depicting various aspects, including but not limited to the colors, shapes, positions, styles, texts and the relationships between different objects and subjects in the image. Your description should be thorough enough to guide a professional in recreating this image solely based on your textual representation. Remember, only include descriptive texts that directly pertain to the contents of the image. You must complete the description using less than 500 words."
320
+
321
+ if "llava" in args.model:
322
+ lmm_model = LlavaForConditionalGeneration.from_pretrained(
323
+ f"llava-hf/llava-1.5-{args.model[5:]}-hf", load_in_8bit=True
324
+ )
325
+ lmm_processor = AutoProcessor.from_pretrained(
326
+ f"llava-hf/llava-1.5-{args.model[5:]}-hf"
327
+ )
328
+ prompt = f"<image>\nUSER: {prompt}\nASSISTANT:"
329
+ get_desc_function = partial(get_desc_llava, lmm_processor, lmm_model, prompt)
330
+ encode_image_function = encode_image_os
331
+ elif args.model == "mPLUG":
332
+ model_path = "MAGAer13/mplug-owl2-llama2-7b"
333
+ model_name = get_model_name_from_path(model_path)
334
+ tokenizer, lmm_model, image_processor, _ = load_pretrained_model(
335
+ model_path,
336
+ None,
337
+ model_name,
338
+ load_8bit=False,
339
+ load_4bit=False,
340
+ device=device,
341
+ )
342
+ tokenizer.pad_token_id = tokenizer.eos_token_id
343
+ tokenizer.pad_token = tokenizer.eos_token
344
+ get_desc_function = partial(
345
+ get_desc_mPLUG, image_processor, lmm_model, tokenizer, prompt
346
+ )
347
+ encode_image_function = encode_image_os
348
+ elif args.model == "gpt4v":
349
+ get_desc_function = partial(get_desc_gpt4v, prompt=prompt)
350
+ encode_image_function = encode_image_gpt4v
351
+
352
+ output_folder = os.path.join(args.dataset, f"results_{args.model}")
353
+ os.makedirs(output_folder, exist_ok=True)
354
+
355
+ logging.debug("Loaded model. Entered main loop.")
356
+ for img_file in tqdm(find_image_files(args.dataset)):
357
+ try:
358
+ logging.info(img_file)
359
+ test_sample(
360
+ seed_image=img_file,
361
+ n_iteration=args.n_iter,
362
+ output_folder=output_folder,
363
+ get_desc_function=get_desc_function,
364
+ encode_image_function=encode_image_function,
365
+ )
366
+ except Exception as e:
367
+ logging.warning("caught error:")
368
+ logging.warning(e)
369
+ continue
370
+
371
+
372
+ if __name__ == "__main__":
373
+ main()
genception/utils.py ADDED
@@ -0,0 +1,25 @@
+ import os
+
+
+ def find_files(folder_path: str, file_extensions: set) -> list[str]:
+     """
+     Find all files with the given extensions in the given folder path
+
+     Args:
+         folder_path: str: The path to the folder
+         file_extensions: set: The file extensions to look for
+
+     Returns:
+         list[str]: The list of file paths
+     """
+     file_paths = []
+
+     for file in os.listdir(folder_path):
+         if (
+             os.path.isfile(os.path.join(folder_path, file))
+             and os.path.splitext(file)[1].lower() in file_extensions
+         ):
+             absolute_path = os.path.abspath(os.path.join(folder_path, file))
+             file_paths.append(absolute_path)
+
+     return file_paths
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers>=4.37.1
+ pillow
+ requests
+ scikit-learn
+ nltk
+ openai
+ sentencepiece