Upload 2398 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- OPERA/LICENSE +21 -0
- OPERA/README.md +185 -0
- OPERA/chair.py +484 -0
- OPERA/chair_eval.py +192 -0
- OPERA/dataset/README_1_STAGE.md +96 -0
- OPERA/dataset/README_2_STAGE.md +19 -0
- OPERA/dataset/convert_cc_sbu.py +20 -0
- OPERA/dataset/convert_laion.py +20 -0
- OPERA/dataset/download_cc_sbu.sh +6 -0
- OPERA/dataset/download_laion.sh +6 -0
- OPERA/demo.ipynb +0 -0
- OPERA/environment.yml +317 -0
- OPERA/eval_configs/instructblip_eval.yaml +23 -0
- OPERA/eval_configs/llava-1.5_eval.yaml +29 -0
- OPERA/eval_configs/minigpt4_eval.yaml +23 -0
- OPERA/eval_configs/minigpt4_llama2_eval.yaml +22 -0
- OPERA/eval_configs/shikra_eval.yaml +29 -0
- OPERA/gpt4v_eval.py +306 -0
- OPERA/log/chair_eval_results/instructblip/beam5.jsonl +0 -0
- OPERA/log/chair_eval_results/instructblip/greedy.jsonl +0 -0
- OPERA/log/chair_eval_results/instructblip/ours.jsonl +0 -0
- OPERA/log/chair_eval_results/llava-1.5/beam5.jsonl +0 -0
- OPERA/log/chair_eval_results/llava-1.5/greedy.jsonl +0 -0
- OPERA/log/chair_eval_results/llava-1.5/ours.jsonl +0 -0
- OPERA/log/chair_eval_results/minigpt4/beam5.jsonl +0 -0
- OPERA/log/chair_eval_results/minigpt4/greedy.jsonl +0 -0
- OPERA/log/chair_eval_results/minigpt4/ours.jsonl +0 -0
- OPERA/log/chair_eval_results/shikra/beam5.jsonl +0 -0
- OPERA/log/chair_eval_results/shikra/greedy.jsonl +0 -0
- OPERA/log/chair_eval_results/shikra/ours.jsonl +0 -0
- OPERA/minigpt4/__init__.py +31 -0
- OPERA/minigpt4/common/__init__.py +0 -0
- OPERA/minigpt4/common/config.py +468 -0
- OPERA/minigpt4/common/dist_utils.py +151 -0
- OPERA/minigpt4/common/gradcam.py +24 -0
- OPERA/minigpt4/common/logger.py +195 -0
- OPERA/minigpt4/common/optims.py +119 -0
- OPERA/minigpt4/common/registry.py +329 -0
- OPERA/minigpt4/common/utils.py +424 -0
- OPERA/minigpt4/configs/datasets/cc_sbu/align.yaml +5 -0
- OPERA/minigpt4/configs/datasets/cc_sbu/defaults.yaml +5 -0
- OPERA/minigpt4/configs/datasets/laion/defaults.yaml +5 -0
- OPERA/minigpt4/configs/default.yaml +5 -0
- OPERA/minigpt4/configs/models/blip2_instruct_vicuna13b.yaml +43 -0
- OPERA/minigpt4/configs/models/blip2_instruct_vicuna7b.yaml +43 -0
- OPERA/minigpt4/configs/models/llava-1.5_vicuna7b.yaml +40 -0
- OPERA/minigpt4/configs/models/minigpt4_llama2.yaml +29 -0
- OPERA/minigpt4/configs/models/minigpt4_vicuna0.yaml +32 -0
- OPERA/minigpt4/configs/models/shikra_vicuna7b.yaml +40 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
OPERA/teaser.png filter=lfs diff=lfs merge=lfs -text
|
OPERA/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 QidongHuang
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
OPERA/README.md
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# OPERA: Alleviating Hallucination in Multi-Modal Large Language Models via Over-Trust Penalty and Retrospection-Allocation (CVPR 2024 Highlight)
|
2 |
+
|
3 |
+
[![License: MIT](https://img.shields.io/badge/License-MIT-g.svg)](https://opensource.org/licenses/MIT)
|
4 |
+
[![Arxiv](https://img.shields.io/badge/arXiv-2311.17911-B21A1B)](https://arxiv.org/pdf/2311.17911.pdf)
|
5 |
+
[![Hugging Face Transformers](https://img.shields.io/badge/%F0%9F%A4%97-Transformers-blue)](https://github.com/huggingface/transformers)
|
6 |
+
[![GitHub Stars](https://img.shields.io/github/stars/shikiw/OPERA?style=social)](https://github.com/shikiw/OPERA/stargazers)
|
7 |
+
|
8 |
+
|
9 |
+
This repository provides the official PyTorch implementation of the following paper:
|
10 |
+
> [**OPERA: Alleviating Hallucination in Multi-Modal Large Language Models via Over-Trust Penalty and Retrospection-Allocation**](https://arxiv.org/pdf/2311.17911.pdf) <br>
|
11 |
+
> [Qidong Huang](https://shikiw.github.io/)<sup>1,2</sup>,
|
12 |
+
> [Xiaoyi Dong](https://scholar.google.com/citations?user=FscToE0AAAAJ&hl=en)<sup>2</sup>,
|
13 |
+
> [Pan Zhang](https://panzhang0212.github.io/)<sup>2</sup>,
|
14 |
+
> [Bin Wang](https://wangbindl.github.io/) <sup>2</sup>,
|
15 |
+
> [Conghui He](https://conghui.github.io/) <sup>2</sup>,
|
16 |
+
> [Jiaqi Wang](https://myownskyw7.github.io/)<sup>2</sup>,
|
17 |
+
> [Dahua Lin](http://dahua.site/)<sup>2</sup>,
|
18 |
+
> [Weiming Zhang](http://staff.ustc.edu.cn/~zhangwm/index.html)<sup>1</sup>,
|
19 |
+
> [Nenghai Yu](https://scholar.google.com/citations?user=7620QAMAAAAJ&hl=en)<sup>1</sup> <br>
|
20 |
+
> <sup>1</sup>University of Science and Technology of China, <sup>2</sup>Shanghai AI Laboratory <br>
|
21 |
+
|
22 |
+
|
23 |
+
## Overview
|
24 |
+
|
25 |
+
<p align="center"><img src="./teaser.png" alt="teaser" width="500px" /></p>
|
26 |
+
|
27 |
+
Hallucination, posed as a pervasive challenge of multimodal large language models (MLLMs), has significantly impeded their real-world usage that demands precise judgment. Existing methods mitigate this issue with either training with specific designed data or inferencing with external knowledge from other sources, incurring inevitable additional costs. In this paper, we present OPERA, a novel MLLM decoding method grounded in an Over-trust Penalty and a Retrospection-Allocation strategy, serving as a nearly free lunch to alleviate the hallucination issue without additional data, knowledge, or training. Our approach begins with an interesting observation that, most hallucinations are closely tied to the knowledge aggregation patterns manifested in the self-attention matrix, i.e., MLLMs tend to generate new tokens by focusing on a few summary tokens, but not all the previous tokens. Such partial overtrust inclination results in the neglecting of image tokens and describes the image content with hallucination. Based on the observation, OPERA introduces a penalty term on
|
28 |
+
the model logits during the beam-search decoding to mitigate the over-trust issue, along with a rollback strategy that retrospects the presence of summary tokens in the previously generated tokens, and re-allocate the token selection if necessary. With extensive experiments, OPERA shows significant hallucination-mitigating performance on different MLLMs and metrics, proving its effectiveness and generality.
|
29 |
+
|
30 |
+
## Setup
|
31 |
+
|
32 |
+
The main implementation of OPERA is in `transformers-4.29.2/src/transformers/generation/utils.py`.
|
33 |
+
|
34 |
+
So it is convenient to use OPERA decoding by just installing our modified `transformers` package.
|
35 |
+
```
|
36 |
+
conda env create -f environment.yml
|
37 |
+
conda activate opera
|
38 |
+
python -m pip install -e transformers-4.29.2
|
39 |
+
```
|
40 |
+
#### Note: to implement OPERA on other version of transformers, you can follow the steps as the follows:
|
41 |
+
- Find the file at `transformers-4.29.2/src/transformers/generation/utils.py`.
|
42 |
+
- Add the arguments in `transformers.generate` function [here](https://github.com/shikiw/OPERA/blob/aa968c7501f4d3d8362f4b3bcab855024f4da5f6/transformers-4.29.2/src/transformers/generation/utils.py#L1156-L1162).
|
43 |
+
- Add the code in `transformers.generate` function [here](https://github.com/shikiw/OPERA/blob/aa968c7501f4d3d8362f4b3bcab855024f4da5f6/transformers-4.29.2/src/transformers/generation/utils.py#L1619-L1665).
|
44 |
+
- Copy and paste the `opera_decoding` function [here](https://github.com/shikiw/OPERA/blob/aa968c7501f4d3d8362f4b3bcab855024f4da5f6/transformers-4.29.2/src/transformers/generation/utils.py#L3116-L3674).
|
45 |
+
|
46 |
+
## TL;DR
|
47 |
+
After setup the environment, you can directly use OPERA on your own MLLM model by:
|
48 |
+
```
|
49 |
+
# specify the location indexes of some input tokens
|
50 |
+
START_INDEX_of_IMAGE_TOKENS = <the location index of the first image token>
|
51 |
+
END_INDEX_of_IMAGE_TOKENS = <the location index of the last image token>
|
52 |
+
NUM_of_TOKENS_IN_THE_PROMPT = <the total number of tokens in the user prompt (including image tokens)>
|
53 |
+
|
54 |
+
key_position = {
|
55 |
+
"image_start": START_INDEX_of_IMAGE_TOKENS,
|
56 |
+
"image_end": END_INDEX_of_IMAGE_TOKENS,
|
57 |
+
"response_start": NUM_of_TOKENS_IN_THE_PROMPT,
|
58 |
+
}
|
59 |
+
|
60 |
+
# add some arguments in the generate function
|
61 |
+
outputs = MLLM_model.generate(
|
62 |
+
input_ids=input_ids,
|
63 |
+
inputs_embeds=inputs_embeds,
|
64 |
+
attention_mask=attention_mask,
|
65 |
+
do_sample=False,
|
66 |
+
num_beams=5,
|
67 |
+
max_new_tokens=512,
|
68 |
+
# opera
|
69 |
+
opera_decoding=True,
|
70 |
+
key_position=key_position,
|
71 |
+
scale_factor=50,
|
72 |
+
threshold=15,
|
73 |
+
num_attn_candidates=5,
|
74 |
+
penalty_weights=1,
|
75 |
+
)
|
76 |
+
# for a more efficient version, please use the setting below:
|
77 |
+
outputs = MLLM_model.generate(
|
78 |
+
input_ids=input_ids,
|
79 |
+
inputs_embeds=inputs_embeds,
|
80 |
+
attention_mask=attention_mask,
|
81 |
+
do_sample=False,
|
82 |
+
num_beams=5,
|
83 |
+
max_new_tokens=512,
|
84 |
+
# opera
|
85 |
+
opera_decoding=True,
|
86 |
+
key_position=key_position,
|
87 |
+
scale_factor=50,
|
88 |
+
threshold=25,
|
89 |
+
num_attn_candidates=1,
|
90 |
+
penalty_weights=1,
|
91 |
+
)
|
92 |
+
```
|
93 |
+
|
94 |
+
Please refer to `demo.ipynb` [here](https://github.com/shikiw/OPERA/blob/1e74d8b5d082579c81e0e77ef1cf4a44d20ab91e/demo.ipynb) for more details.
|
95 |
+
|
96 |
+
|
97 |
+
## Evaluation
|
98 |
+
|
99 |
+
The following evaluation requires for MSCOCO 2014 dataset. Please download [here](https://cocodataset.org/#home) and extract it in your data path.
|
100 |
+
|
101 |
+
Besides, it needs you to prepare the following checkpoints of 7B base models:
|
102 |
+
|
103 |
+
- Download [LLaVA-1.5 merged 7B model](https://huggingface.co/liuhaotian/llava-v1.5-7b) and specify it at [Line 14](https://github.com/shikiw/OPERA/blob/bf18aa9c409f28b31168b0f71ebf8457ae8063d5/eval_configs/llava-1.5_eval.yaml#L14) of `eval_configs/llava-1.5_eval.yaml`.
|
104 |
+
- Download [Vicuna 7B v1.1 model](https://github.com/lm-sys/FastChat) and specify it at [Line 25](https://github.com/shikiw/OPERA/blob/bf18aa9c409f28b31168b0f71ebf8457ae8063d5/minigpt4/configs/models/blip2_instruct_vicuna7b.yaml#L25) of `minigpt4/configs/models/blip2_instruct_vicuna7b.yaml`.
|
105 |
+
- Download [Vicuna 7B v0 model](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) and specify it at [Line 18](https://github.com/shikiw/OPERA/blob/bf18aa9c409f28b31168b0f71ebf8457ae8063d5/minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) of `minigpt4/configs/models/minigpt4_vicuna0.yaml`.
|
106 |
+
- Download [MiniGPT-4 7B pretrained weights](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) and specify it at [Line 8](https://github.com/shikiw/OPERA/blob/bf18aa9c409f28b31168b0f71ebf8457ae8063d5/eval_configs/minigpt4_eval.yaml#L8) of `eval_configs/minigpt4_eval.yaml`.
|
107 |
+
- Download [Shikra merged 7B model](https://github.com/shikras/shikra#checkpoint) and specify it at [Line 14](https://github.com/shikiw/OPERA/blob/bf18aa9c409f28b31168b0f71ebf8457ae8063d5/eval_configs/shikra_eval.yaml#L14) of `eval_configs/shikra_eval.yaml`.
|
108 |
+
|
109 |
+
### Arguments
|
110 |
+
|
111 |
+
| Argument | Example | Description |
|
112 |
+
| -------------------- | ------------------- | ------------- |
|
113 |
+
| `--model` | `llava-1.5` | Specify the MLLM model, this codebase supports `instructblip`, `minigpt4`, `llava-1.5`, `shikra`. |
|
114 |
+
| `--data-path` | `/path/to/dataset` | Path to the dataset file or folder, e.g., `COCO_2014/val2014/`. |
|
115 |
+
| `--pope-type` | `random` | Type for POPE evaluation, supports `random`, `popular`, `adversarial`. |
|
116 |
+
| `--scale_factor` | `50` | The scale factor to scale up the self-attention weights. Default: 50. |
|
117 |
+
| `--threshold` | `15` | The threshold for attending retrospection. Default: 15. |
|
118 |
+
| `--num_attn_candidates` | `5` | The number of candidates per beam. Default: 5. |
|
119 |
+
| `--penalty_weights`| `1` | The weight of penalty term in decoding. Default: 1. |
|
120 |
+
|
121 |
+
### POPE
|
122 |
+
```bash
|
123 |
+
python pope_eval.py --model MODEL_NAME --data_path /path/to/COCO --pope-type random --gpu-id GPU_IDs --beam 5 --scale_factor 50 --threshold 15 --num_attn_candidates 5 --penalty_weights 1
|
124 |
+
```
|
125 |
+
Result on `Random` split:
|
126 |
+
|
127 |
+
| Model | Accuracy | Precision | Recall | F1 score| Yes ratio |
|
128 |
+
| ----- | -------- | --------- | ------ | ------- | --------- |
|
129 |
+
| InstructBLIP 7B | 90.3 | 93.8 | 87.0 | 90.3 | 47.8 |
|
130 |
+
| MiniGPT-4 7B | 79.8 | 89.7 | 68.7 | 77.8 | 39.5 |
|
131 |
+
| LLaVA-1.5 7B | 89.4 | 90.4 | 88.8 | 89.6 | 50.6 |
|
132 |
+
|
133 |
+
Result on `Popular` split:
|
134 |
+
|
135 |
+
| Model | Accuracy | Precision | Recall | F1 score| Yes ratio |
|
136 |
+
| ----- | -------- | --------- | ------ | ------- | --------- |
|
137 |
+
| InstructBLIP 7B | 83.4 | 81.2 | 87.0 | 84.0 | 53.6 |
|
138 |
+
| MiniGPT-4 7B | 73.6 | 75.9 | 69.0 | 72.3 | 45.4 |
|
139 |
+
| LLaVA-1.5 7B | 86.0 | 84.1 | 88.8 | 86.4 | 52.8 |
|
140 |
+
|
141 |
+
Result on `Adversarial` split:
|
142 |
+
|
143 |
+
| Model | Accuracy | Precision | Recall | F1 score| Yes ratio |
|
144 |
+
| ----- | -------- | --------- | ------ | ------- | --------- |
|
145 |
+
| InstructBLIP 7B | 80.7 | 77.3 | 87.0 | 81.9 | 56.3 |
|
146 |
+
| MiniGPT-4 7B | 71.6 | 72.9 | 68.9 | 70.8 | 47.3 |
|
147 |
+
| LLaVA-1.5 7B | 79.1 | 74.4 | 88.8 | 81.0 | 59.7 |
|
148 |
+
|
149 |
+
### CHAIR
|
150 |
+
- Generate the MLLM's responses and save them in a jsonl file:
|
151 |
+
```bash
|
152 |
+
python chair_eval.py --model MODEL_NAME --data_path /path/to/COCO --gpu-id GPU_IDs --beam 5 --scale_factor 50 --threshold 15 --num_attn_candidates 5 --penalty_weights 1
|
153 |
+
```
|
154 |
+
Note: Please check out our released results in `log/chair_eval_results` for reproduction.
|
155 |
+
|
156 |
+
- Calculate CHAIR using the generated jsonl file:
|
157 |
+
```bash
|
158 |
+
python chair.py --cap_file /path/to/jsonl --image_id_key image_id --caption_key caption --coco_path /path/to/COCO/annotations_trainval2014/annotations/ --save_path /path/to/save/jsonl
|
159 |
+
```
|
160 |
+
|
161 |
+
### GPT-4V
|
162 |
+
The GPT-4V evaluation requires you to specify your API key in [Line 88](https://github.com/shikiw/OPERA/blob/559556048224d5c3eae995a21d529156fb150d5f/gpt4v_eval.py#L88) of `gpt4v_eval.py`.
|
163 |
+
```bash
|
164 |
+
python gpt4v_eval.py --model MODEL_NAME --data_path /path/to/COCO --gpu-id GPU_IDs --scale_factor 50 --threshold 15 --num_attn_candidates 5 --penalty_weights 1
|
165 |
+
```
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
## Acknowledgement
|
171 |
+
This repo is based on the MLLM codebase of [LAVIS](https://github.com/salesforce/LAVIS) and [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) and the CHAIR code of [Maxlinn](https://github.com/Maxlinn/CHAIR-metric-standalone). Thanks for their impressive works!
|
172 |
+
|
173 |
+
## Citation
|
174 |
+
If you find this work useful for your research, please cite [our paper](https://arxiv.org/pdf/2311.17911.pdf):
|
175 |
+
```
|
176 |
+
@inproceedings{huang2024opera,
|
177 |
+
title={Opera: Alleviating hallucination in multi-modal large language models via over-trust penalty and retrospection-allocation},
|
178 |
+
author={Huang, Qidong and Dong, Xiaoyi and Zhang, Pan and Wang, Bin and He, Conghui and Wang, Jiaqi and Lin, Dahua and Zhang, Weiming and Yu, Nenghai},
|
179 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
180 |
+
pages={13418--13427},
|
181 |
+
year={2024}
|
182 |
+
}
|
183 |
+
```
|
184 |
+
|
185 |
+
|
OPERA/chair.py
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Copied from: https://github.com/LisaAnne/Hallucination/blob/master/utils/chair.py
|
3 |
+
|
4 |
+
Modified by: Maxlinn
|
5 |
+
|
6 |
+
1. adapt calculation of CHAIR-i and CHAIR-s for Python3, supports for both json and jsonl file input.
|
7 |
+
2. integrate synonyms.txt to make the script standalone.
|
8 |
+
3. remove machine-translation based metrics BLEU-n, CIDEr, ROGUE
|
9 |
+
4. add new metric Recall, which represents the node words(i.e. lemmas of objects) coverage overall.
|
10 |
+
5. add pickle cache mechanism to make it fast for repetitive evaluations.
|
11 |
+
'''
|
12 |
+
|
13 |
+
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import nltk
|
17 |
+
import json
|
18 |
+
# from pattern.en import singularize
|
19 |
+
from nltk.corpus import wordnet
|
20 |
+
from nltk.stem import WordNetLemmatizer
|
21 |
+
import argparse
|
22 |
+
import tqdm
|
23 |
+
import pickle
|
24 |
+
from collections import defaultdict
|
25 |
+
|
26 |
+
|
27 |
+
# copied from: https://github.com/LisaAnne/Hallucination/blob/master/data/synonyms.txt
|
28 |
+
synonyms_txt = '''
|
29 |
+
person, girl, boy, man, woman, kid, child, chef, baker, people, adult, rider, children, baby, worker, passenger, sister, biker, policeman, cop, officer, lady, cowboy, bride, groom, male, female, guy, traveler, mother, father, gentleman, pitcher, player, skier, snowboarder, skater, skateboarder, person, woman, guy, foreigner, child, gentleman, caller, offender, coworker, trespasser, patient, politician, soldier, grandchild, serviceman, walker, drinker, doctor, bicyclist, thief, buyer, teenager, student, camper, driver, solider, hunter, shopper, villager
|
30 |
+
bicycle, bike, bicycle, bike, unicycle, minibike, trike
|
31 |
+
car, automobile, van, minivan, sedan, suv, hatchback, cab, jeep, coupe, taxicab, limo, taxi
|
32 |
+
motorcycle, scooter, motor bike, motor cycle, motorbike, scooter, moped
|
33 |
+
airplane, jetliner, plane, air plane, monoplane, aircraft, jet, jetliner, airbus, biplane, seaplane
|
34 |
+
bus, minibus, trolley
|
35 |
+
train, locomotive, tramway, caboose
|
36 |
+
truck, pickup, lorry, hauler, firetruck
|
37 |
+
boat, ship, liner, sailboat, motorboat, dinghy, powerboat, speedboat, canoe, skiff, yacht, kayak, catamaran, pontoon, houseboat, vessel, rowboat, trawler, ferryboat, watercraft, tugboat, schooner, barge, ferry, sailboard, paddleboat, lifeboat, freighter, steamboat, riverboat, battleship, steamship
|
38 |
+
traffic light, street light, traffic signal, stop light, streetlight, stoplight
|
39 |
+
fire hydrant, hydrant
|
40 |
+
stop sign
|
41 |
+
parking meter
|
42 |
+
bench, pew
|
43 |
+
bird, ostrich, owl, seagull, goose, duck, parakeet, falcon, robin, pelican, waterfowl, heron, hummingbird, mallard, finch, pigeon, sparrow, seabird, osprey, blackbird, fowl, shorebird, woodpecker, egret, chickadee, quail, bluebird, kingfisher, buzzard, willet, gull, swan, bluejay, flamingo, cormorant, parrot, loon, gosling, waterbird, pheasant, rooster, sandpiper, crow, raven, turkey, oriole, cowbird, warbler, magpie, peacock, cockatiel, lorikeet, puffin, vulture, condor, macaw, peafowl, cockatoo, songbird
|
44 |
+
cat, kitten, feline, tabby
|
45 |
+
dog, puppy, beagle, pup, chihuahua, schnauzer, dachshund, rottweiler, canine, pitbull, collie, pug, terrier, poodle, labrador, doggie, doberman, mutt, doggy, spaniel, bulldog, sheepdog, weimaraner, corgi, cocker, greyhound, retriever, brindle, hound, whippet, husky
|
46 |
+
horse, colt, pony, racehorse, stallion, equine, mare, foal, palomino, mustang, clydesdale, bronc, bronco
|
47 |
+
sheep, lamb, ram, lamb, goat, ewe
|
48 |
+
cow, cattle, oxen, ox, calf, cattle, holstein, heifer, buffalo, bull, zebu, bison
|
49 |
+
elephant
|
50 |
+
bear, panda
|
51 |
+
zebra
|
52 |
+
giraffe
|
53 |
+
backpack, knapsack
|
54 |
+
umbrella
|
55 |
+
handbag, wallet, purse, briefcase
|
56 |
+
tie, bow, bow tie
|
57 |
+
suitcase, suit case, luggage
|
58 |
+
frisbee
|
59 |
+
skis, ski
|
60 |
+
snowboard
|
61 |
+
sports ball, ball
|
62 |
+
kite
|
63 |
+
baseball bat
|
64 |
+
baseball glove
|
65 |
+
skateboard
|
66 |
+
surfboard, longboard, skimboard, shortboard, wakeboard
|
67 |
+
tennis racket, racket
|
68 |
+
bottle
|
69 |
+
wine glass
|
70 |
+
cup
|
71 |
+
fork
|
72 |
+
knife, pocketknife, knive
|
73 |
+
spoon
|
74 |
+
bowl, container
|
75 |
+
banana
|
76 |
+
apple
|
77 |
+
sandwich, burger, sub, cheeseburger, hamburger
|
78 |
+
orange
|
79 |
+
broccoli
|
80 |
+
carrot
|
81 |
+
hot dog
|
82 |
+
pizza
|
83 |
+
donut, doughnut, bagel
|
84 |
+
cake, cheesecake, cupcake, shortcake, coffeecake, pancake
|
85 |
+
chair, seat, stool
|
86 |
+
couch, sofa, recliner, futon, loveseat, settee, chesterfield
|
87 |
+
potted plant, houseplant
|
88 |
+
bed
|
89 |
+
dining table, table, desk
|
90 |
+
toilet, urinal, commode, toilet, lavatory, potty
|
91 |
+
tv, monitor, televison, television
|
92 |
+
laptop, computer, notebook, netbook, lenovo, macbook, laptop computer
|
93 |
+
mouse
|
94 |
+
remote
|
95 |
+
keyboard
|
96 |
+
cell phone, mobile phone, phone, cellphone, telephone, phon, smartphone, iPhone
|
97 |
+
microwave
|
98 |
+
oven, stovetop, stove, stove top oven
|
99 |
+
toaster
|
100 |
+
sink
|
101 |
+
refrigerator, fridge, fridge, freezer
|
102 |
+
book
|
103 |
+
clock
|
104 |
+
vase
|
105 |
+
scissors
|
106 |
+
teddy bear, teddybear
|
107 |
+
hair drier, hairdryer
|
108 |
+
toothbrush
|
109 |
+
'''
|
110 |
+
|
111 |
+
|
112 |
+
def combine_coco_captions(annotation_path):
|
113 |
+
|
114 |
+
if not os.path.exists('%s/captions_%s2014.json' %(annotation_path, 'val')):
|
115 |
+
raise Exception("Please download MSCOCO caption annotations for val set")
|
116 |
+
if not os.path.exists('%s/captions_%s2014.json' %(annotation_path, 'train')):
|
117 |
+
raise Exception("Please download MSCOCO caption annotations for train set")
|
118 |
+
|
119 |
+
val_caps = json.load(open('%s/captions_%s2014.json' %(annotation_path, 'val')))
|
120 |
+
train_caps = json.load(open('%s/captions_%s2014.json' %(annotation_path, 'train')))
|
121 |
+
all_caps = {'info': train_caps['info'],
|
122 |
+
'licenses': train_caps['licenses'],
|
123 |
+
'images': val_caps['images'] + train_caps['images'],
|
124 |
+
'annotations': val_caps['annotations'] + train_caps['annotations']}
|
125 |
+
|
126 |
+
return all_caps
|
127 |
+
|
128 |
+
def combine_coco_instances(annotation_path):
|
129 |
+
|
130 |
+
if not os.path.exists('%s/instances_%s2014.json' %(annotation_path, 'val')):
|
131 |
+
raise Exception("Please download MSCOCO instance annotations for val set")
|
132 |
+
if not os.path.exists('%s/instances_%s2014.json' %(annotation_path, 'train')):
|
133 |
+
raise Exception("Please download MSCOCO instance annotations for train set")
|
134 |
+
|
135 |
+
val_instances = json.load(open('%s/instances_%s2014.json' %(annotation_path, 'val')))
|
136 |
+
train_instances = json.load(open('%s/instances_%s2014.json' %(annotation_path, 'train')))
|
137 |
+
all_instances = {'info': train_instances['info'],
|
138 |
+
'licenses': train_instances['licenses'],
|
139 |
+
'type': train_instances['licenses'],
|
140 |
+
'categories': train_instances['categories'],
|
141 |
+
'images': train_instances['images'] + val_instances['images'],
|
142 |
+
'annotations': val_instances['annotations'] + train_instances['annotations']}
|
143 |
+
|
144 |
+
return all_instances
|
145 |
+
|
146 |
+
class CHAIR(object):
|
147 |
+
|
148 |
+
def __init__(self, coco_path):
|
149 |
+
|
150 |
+
self.imid_to_objects = defaultdict(list) # later become a dict of sets
|
151 |
+
|
152 |
+
self.coco_path = coco_path
|
153 |
+
|
154 |
+
#read in synonyms
|
155 |
+
synonyms = synonyms_txt.splitlines()
|
156 |
+
synonyms = [s.strip().split(', ') for s in synonyms]
|
157 |
+
self.mscoco_objects = [] #mscoco objects and *all* synonyms
|
158 |
+
self.inverse_synonym_dict = {}
|
159 |
+
for synonym in synonyms:
|
160 |
+
self.mscoco_objects.extend(synonym)
|
161 |
+
for s in synonym:
|
162 |
+
self.inverse_synonym_dict[s] = synonym[0]
|
163 |
+
|
164 |
+
#Some hard coded rules for implementing CHAIR metrics on MSCOCO
|
165 |
+
|
166 |
+
#common 'double words' in MSCOCO that should be treated as a single word
|
167 |
+
coco_double_words = ['motor bike', 'motor cycle', 'air plane', 'traffic light', 'street light', 'traffic signal', 'stop light', 'fire hydrant', 'stop sign', 'parking meter', 'suit case', 'sports ball', 'baseball bat', 'baseball glove', 'tennis racket', 'wine glass', 'hot dog', 'cell phone', 'mobile phone', 'teddy bear', 'hair drier', 'potted plant', 'bow tie', 'laptop computer', 'stove top oven', 'hot dog', 'teddy bear', 'home plate', 'train track']
|
168 |
+
|
169 |
+
#Hard code some rules for special cases in MSCOCO
|
170 |
+
#qualifiers like 'baby' or 'adult' animal will lead to a false fire for the MSCOCO object 'person'. 'baby bird' --> 'bird'.
|
171 |
+
animal_words = ['bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'animal', 'cub']
|
172 |
+
#qualifiers like 'passenger' vehicle will lead to a false fire for the MSCOCO object 'person'. 'passenger jet' --> 'jet'.
|
173 |
+
vehicle_words = ['jet', 'train']
|
174 |
+
|
175 |
+
#double_word_dict will map double words to the word they should be treated as in our analysis
|
176 |
+
|
177 |
+
self.double_word_dict = {}
|
178 |
+
for double_word in coco_double_words:
|
179 |
+
self.double_word_dict[double_word] = double_word
|
180 |
+
for animal_word in animal_words:
|
181 |
+
self.double_word_dict['baby %s' %animal_word] = animal_word
|
182 |
+
self.double_word_dict['adult %s' %animal_word] = animal_word
|
183 |
+
for vehicle_word in vehicle_words:
|
184 |
+
self.double_word_dict['passenger %s' %vehicle_word] = vehicle_word
|
185 |
+
self.double_word_dict['bow tie'] = 'tie'
|
186 |
+
self.double_word_dict['toilet seat'] = 'toilet'
|
187 |
+
self.double_word_dict['wine glas'] = 'wine glass'
|
188 |
+
|
189 |
+
self.get_annotations()
|
190 |
+
|
191 |
+
def _load_generated_captions_into_evaluator(self, cap_file, image_id_key, caption_key):
|
192 |
+
|
193 |
+
'''
|
194 |
+
Meant to save time so imid_to_objects does not always need to be recomputed.
|
195 |
+
'''
|
196 |
+
#Read in captions
|
197 |
+
self.caps, self.eval_imids = load_generated_captions(cap_file, image_id_key, caption_key)
|
198 |
+
assert len(self.caps) == len(self.eval_imids)
|
199 |
+
|
200 |
+
def get_wordnet_pos(self, tag):
|
201 |
+
if tag.startswith('J'):
|
202 |
+
return wordnet.ADJ
|
203 |
+
elif tag.startswith('V'):
|
204 |
+
return wordnet.VERB
|
205 |
+
elif tag.startswith('N'):
|
206 |
+
return wordnet.NOUN
|
207 |
+
elif tag.startswith('R'):
|
208 |
+
return wordnet.ADV
|
209 |
+
else:
|
210 |
+
return None
|
211 |
+
|
212 |
+
def caption_to_words(self, caption):
|
213 |
+
|
214 |
+
'''
|
215 |
+
Input: caption
|
216 |
+
Output: MSCOCO words in the caption
|
217 |
+
'''
|
218 |
+
|
219 |
+
#standard preprocessing
|
220 |
+
words = nltk.word_tokenize(caption.lower())
|
221 |
+
tagged_sent = nltk.pos_tag(words)
|
222 |
+
lemmas_sent = []
|
223 |
+
wnl = WordNetLemmatizer()
|
224 |
+
for tag in tagged_sent:
|
225 |
+
wordnet_pos = self.get_wordnet_pos(tag[1]) or wordnet.NOUN
|
226 |
+
lemmas_sent.append(wnl.lemmatize(tag[0], pos=wordnet_pos))
|
227 |
+
# words = [singularize(w) for w in words]
|
228 |
+
words = lemmas_sent
|
229 |
+
|
230 |
+
#replace double words
|
231 |
+
i = 0
|
232 |
+
double_words = []
|
233 |
+
idxs = []
|
234 |
+
while i < len(words):
|
235 |
+
idxs.append(i)
|
236 |
+
double_word = ' '.join(words[i:i+2])
|
237 |
+
if double_word in self.double_word_dict:
|
238 |
+
double_words.append(self.double_word_dict[double_word])
|
239 |
+
i += 2
|
240 |
+
else:
|
241 |
+
double_words.append(words[i])
|
242 |
+
i += 1
|
243 |
+
words = double_words
|
244 |
+
|
245 |
+
#toilet seat is not chair (sentences like "the seat of the toilet" will fire for "chair" if we do not include this line)
|
246 |
+
if ('toilet' in words) & ('seat' in words): words = [word for word in words if word != 'seat']
|
247 |
+
|
248 |
+
#get synonyms for all words in the caption
|
249 |
+
idxs = [idxs[idx] for idx, word in enumerate(words) \
|
250 |
+
if word in set(self.mscoco_objects)]
|
251 |
+
words = [word for word in words if word in set(self.mscoco_objects)]
|
252 |
+
node_words = []
|
253 |
+
for word in words:
|
254 |
+
node_words.append(self.inverse_synonym_dict[word])
|
255 |
+
#return all the MSCOCO objects in the caption
|
256 |
+
return words, node_words, idxs, double_words
|
257 |
+
|
258 |
+
def get_annotations_from_segments(self):
|
259 |
+
'''
|
260 |
+
Add objects taken from MSCOCO segmentation masks
|
261 |
+
'''
|
262 |
+
|
263 |
+
coco_segments = combine_coco_instances(self.coco_path )
|
264 |
+
segment_annotations = coco_segments['annotations']
|
265 |
+
|
266 |
+
#make dict linking object name to ids
|
267 |
+
id_to_name = {} #dict with id to synsets
|
268 |
+
for cat in coco_segments['categories']:
|
269 |
+
id_to_name[cat['id']] = cat['name']
|
270 |
+
|
271 |
+
for i, annotation in enumerate(segment_annotations):
|
272 |
+
sys.stdout.write("\rGetting annotations for %d/%d segmentation masks"
|
273 |
+
%(i, len(segment_annotations)))
|
274 |
+
imid = annotation['image_id']
|
275 |
+
|
276 |
+
node_word = self.inverse_synonym_dict[id_to_name[annotation['category_id']]]
|
277 |
+
self.imid_to_objects[imid].append(node_word)
|
278 |
+
print("\n")
|
279 |
+
|
280 |
+
def get_annotations_from_captions(self):
|
281 |
+
'''
|
282 |
+
Add objects taken from MSCOCO ground truth captions
|
283 |
+
'''
|
284 |
+
|
285 |
+
coco_caps = combine_coco_captions(self.coco_path)
|
286 |
+
caption_annotations = coco_caps['annotations']
|
287 |
+
|
288 |
+
for i, annotation in enumerate(caption_annotations):
|
289 |
+
sys.stdout.write('\rGetting annotations for %d/%d ground truth captions'
|
290 |
+
%(i, len(coco_caps['annotations'])))
|
291 |
+
imid = annotation['image_id']
|
292 |
+
|
293 |
+
_, node_words, _, _ = self.caption_to_words(annotation['caption'])
|
294 |
+
# note here is update, so call get_annotations_from_segments first
|
295 |
+
self.imid_to_objects[imid].extend(node_words)
|
296 |
+
print("\n")
|
297 |
+
|
298 |
+
|
299 |
+
def get_annotations(self):
|
300 |
+
|
301 |
+
'''
|
302 |
+
Get annotations from both segmentation and captions. Need both annotation types for CHAIR metric.
|
303 |
+
'''
|
304 |
+
|
305 |
+
self.get_annotations_from_segments()
|
306 |
+
self.get_annotations_from_captions()
|
307 |
+
# deduplicate
|
308 |
+
for imid in self.imid_to_objects:
|
309 |
+
self.imid_to_objects[imid] = set(self.imid_to_objects[imid])
|
310 |
+
|
311 |
+
def compute_chair(self, cap_file, image_id_key, caption_key):
|
312 |
+
'''
|
313 |
+
Given ground truth objects and generated captions, determine which sentences have hallucinated words.
|
314 |
+
'''
|
315 |
+
self._load_generated_captions_into_evaluator(cap_file, image_id_key, caption_key)
|
316 |
+
|
317 |
+
imid_to_objects = self.imid_to_objects
|
318 |
+
caps = self.caps
|
319 |
+
eval_imids = self.eval_imids
|
320 |
+
|
321 |
+
num_caps = 0.
|
322 |
+
num_hallucinated_caps = 0.
|
323 |
+
hallucinated_word_count = 0.
|
324 |
+
coco_word_count = 0.
|
325 |
+
len_caps = 0.
|
326 |
+
|
327 |
+
# :add:
|
328 |
+
num_recall_gt_objects = 0.
|
329 |
+
num_gt_objects = 0.
|
330 |
+
|
331 |
+
output = {'sentences': []}
|
332 |
+
|
333 |
+
for i in tqdm.trange(len(caps)):
|
334 |
+
cap :str = caps[i]
|
335 |
+
imid :int = eval_imids[i]
|
336 |
+
|
337 |
+
#get all words in the caption, as well as corresponding node word
|
338 |
+
# pos = cap.rfind('.')
|
339 |
+
# cap = cap[:pos+1]
|
340 |
+
words, node_words, idxs, raw_words = self.caption_to_words(cap)
|
341 |
+
|
342 |
+
gt_objects = imid_to_objects[imid]
|
343 |
+
cap_dict = {'image_id': imid,
|
344 |
+
'caption': cap,
|
345 |
+
'mscoco_hallucinated_words': [],
|
346 |
+
'mscoco_gt_words': list(gt_objects),
|
347 |
+
'mscoco_generated_words': list(node_words),
|
348 |
+
'hallucination_idxs': [],
|
349 |
+
'words': raw_words
|
350 |
+
}
|
351 |
+
|
352 |
+
# :add:
|
353 |
+
cap_dict['metrics'] = {'CHAIRs': 0,
|
354 |
+
'CHAIRi': 0,
|
355 |
+
'Recall': 0,
|
356 |
+
'Len': 0,
|
357 |
+
}
|
358 |
+
|
359 |
+
#count hallucinated words
|
360 |
+
coco_word_count += len(node_words)
|
361 |
+
hallucinated = False
|
362 |
+
|
363 |
+
# add
|
364 |
+
recall_gt_objects = set()
|
365 |
+
for word, node_word, idx in zip(words, node_words, idxs):
|
366 |
+
if node_word not in gt_objects:
|
367 |
+
hallucinated_word_count += 1
|
368 |
+
cap_dict['mscoco_hallucinated_words'].append((word, node_word))
|
369 |
+
cap_dict['hallucination_idxs'].append(idx)
|
370 |
+
hallucinated = True
|
371 |
+
else:
|
372 |
+
recall_gt_objects.add(node_word)
|
373 |
+
|
374 |
+
#count hallucinated caps
|
375 |
+
num_caps += 1
|
376 |
+
len_caps += len(raw_words)
|
377 |
+
if hallucinated:
|
378 |
+
num_hallucinated_caps += 1
|
379 |
+
|
380 |
+
# add
|
381 |
+
num_gt_objects += len(gt_objects)
|
382 |
+
num_recall_gt_objects += len(recall_gt_objects)
|
383 |
+
|
384 |
+
cap_dict['metrics']['CHAIRs'] = int(hallucinated)
|
385 |
+
cap_dict['metrics']['CHAIRi'] = 0.
|
386 |
+
cap_dict['metrics']['Recall'] = 0.
|
387 |
+
cap_dict['metrics']['Len'] = 0.
|
388 |
+
|
389 |
+
|
390 |
+
if len(words) > 0:
|
391 |
+
cap_dict['metrics']['CHAIRi'] = len(cap_dict['mscoco_hallucinated_words'])/float(len(words))
|
392 |
+
|
393 |
+
# add
|
394 |
+
if len(gt_objects) > 0:
|
395 |
+
cap_dict['metrics']['Recall'] = len(recall_gt_objects) / len(gt_objects)
|
396 |
+
|
397 |
+
output['sentences'].append(cap_dict)
|
398 |
+
|
399 |
+
chair_s = (num_hallucinated_caps/num_caps)
|
400 |
+
chair_i = (hallucinated_word_count/coco_word_count)
|
401 |
+
# add
|
402 |
+
recall = num_recall_gt_objects / num_gt_objects
|
403 |
+
avg_len = (0.01*len_caps/num_caps)
|
404 |
+
|
405 |
+
output['overall_metrics'] = {'CHAIRs': chair_s,
|
406 |
+
'CHAIRi': chair_i,
|
407 |
+
'Recall': recall,
|
408 |
+
'Len': avg_len,}
|
409 |
+
|
410 |
+
return output
|
411 |
+
|
412 |
+
def load_generated_captions(cap_file, image_id_key:str, caption_key:str):
|
413 |
+
#Read in captions
|
414 |
+
# it should be list of dict
|
415 |
+
ext = os.path.splitext(cap_file)[-1]
|
416 |
+
if ext == '.json':
|
417 |
+
caps = json.load(open(cap_file))
|
418 |
+
elif ext == '.jsonl':
|
419 |
+
caps = [json.loads(s) for s in open(cap_file)]
|
420 |
+
else:
|
421 |
+
raise ValueError(f'Unspported extension {ext} for cap_file: {cap_file}')
|
422 |
+
|
423 |
+
# list of int
|
424 |
+
imids = [obj[image_id_key] for obj in caps]
|
425 |
+
|
426 |
+
# list of str
|
427 |
+
caps = [obj[caption_key] for obj in caps]
|
428 |
+
|
429 |
+
return caps, imids
|
430 |
+
|
431 |
+
def save_hallucinated_words(cap_file, cap_dict):
|
432 |
+
with open(cap_file, 'w') as f:
|
433 |
+
json.dump(cap_dict, f, indent=2, ensure_ascii=False)
|
434 |
+
|
435 |
+
def print_metrics(hallucination_cap_dict, quiet=False):
|
436 |
+
sentence_metrics = hallucination_cap_dict['overall_metrics']
|
437 |
+
|
438 |
+
for k, v in sentence_metrics.items():
|
439 |
+
k_str = str(k).ljust(10)
|
440 |
+
v_str = f'{v * 100:.01f}'
|
441 |
+
print(k_str, v_str, sep=': ')
|
442 |
+
|
443 |
+
if __name__ == '__main__':
|
444 |
+
parser = argparse.ArgumentParser()
|
445 |
+
|
446 |
+
parser.add_argument("--cap_file", type=str, default='',
|
447 |
+
help="path towards json or jsonl saving image ids and their captions in list of dict.")
|
448 |
+
parser.add_argument("--image_id_key", type=str, default="image_id",
|
449 |
+
help="in each dict of cap_file, which key stores image id of coco.")
|
450 |
+
parser.add_argument("--caption_key", type=str, default="caption",
|
451 |
+
help="in each dict of cap_file, which key stores caption of the image.")
|
452 |
+
|
453 |
+
parser.add_argument("--cache", type=str, default="chair.pkl",
|
454 |
+
help="pre inited CHAIR evaluator object, for fast loading.")
|
455 |
+
parser.add_argument("--coco_path", type=str, default='coco_annotations',
|
456 |
+
help="only use for regenerating CHAIR evaluator object, will be ignored if uses cached evaluator.")
|
457 |
+
|
458 |
+
parser.add_argument("--save_path", type=str, default="",
|
459 |
+
help="saving CHAIR evaluate and results to json, useful for debugging the caption model.")
|
460 |
+
|
461 |
+
args = parser.parse_args()
|
462 |
+
|
463 |
+
if args.cache and os.path.exists(args.cache):
|
464 |
+
evaluator = pickle.load(open(args.cache, 'rb'))
|
465 |
+
print(f"loaded evaluator from cache: {args.cache}")
|
466 |
+
else:
|
467 |
+
print(f"cache not setted or not exist yet, building from scratch...")
|
468 |
+
evaluator = CHAIR(args.coco_path)
|
469 |
+
pickle.dump(evaluator, open(args.cache, 'wb'))
|
470 |
+
print(f"cached evaluator to: {args.cache}")
|
471 |
+
|
472 |
+
cap_dict = evaluator.compute_chair(args.cap_file, args.image_id_key, args.caption_key)
|
473 |
+
|
474 |
+
print_metrics(cap_dict)
|
475 |
+
|
476 |
+
if args.save_path:
|
477 |
+
save_hallucinated_words(args.save_path, cap_dict)
|
478 |
+
|
479 |
+
|
480 |
+
# CUDA_VISIBLE_DEVICES=5 python chair.py \
|
481 |
+
# --cap_file ../POPE-Adv/text_feat/chair-eval/instructblip/ours.jsonl \
|
482 |
+
# --image_id_key image_id --caption_key caption \
|
483 |
+
# --coco_path /mnt/petrelfs/share_data/wangjiaqi/mllm-data-alg/COCO_2014/ori/annotations_trainval2014/annotations/ \
|
484 |
+
# --save_path ../POPE-Adv/text_feat/chair-eval/instructblip/ours_outputs.json
|
OPERA/chair_eval.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.backends.cudnn as cudnn
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from torchvision import transforms
|
11 |
+
from torchvision.transforms.functional import InterpolationMode
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
|
14 |
+
from pope_loader import POPEDataSet
|
15 |
+
from minigpt4.common.dist_utils import get_rank
|
16 |
+
from minigpt4.models import load_preprocess
|
17 |
+
|
18 |
+
from minigpt4.common.config import Config
|
19 |
+
from minigpt4.common.dist_utils import get_rank
|
20 |
+
from minigpt4.common.registry import registry
|
21 |
+
|
22 |
+
# imports modules for registration
|
23 |
+
from minigpt4.datasets.builders import *
|
24 |
+
from minigpt4.models import *
|
25 |
+
from minigpt4.processors import *
|
26 |
+
from minigpt4.runners import *
|
27 |
+
from minigpt4.tasks import *
|
28 |
+
|
29 |
+
from PIL import Image
|
30 |
+
from torchvision.utils import save_image
|
31 |
+
|
32 |
+
import matplotlib.pyplot as plt
|
33 |
+
import matplotlib as mpl
|
34 |
+
import seaborn
|
35 |
+
import json
|
36 |
+
|
37 |
+
|
38 |
+
MODEL_EVAL_CONFIG_PATH = {
|
39 |
+
"minigpt4": "eval_configs/minigpt4_eval.yaml",
|
40 |
+
"instructblip": "eval_configs/instructblip_eval.yaml",
|
41 |
+
"lrv_instruct": "eval_configs/lrv_instruct_eval.yaml",
|
42 |
+
"shikra": "eval_configs/shikra_eval.yaml",
|
43 |
+
"llava-1.5": "eval_configs/llava-1.5_eval.yaml",
|
44 |
+
}
|
45 |
+
|
46 |
+
INSTRUCTION_TEMPLATE = {
|
47 |
+
"minigpt4": "###Human: <Img><ImageHere></Img> <question> ###Assistant:",
|
48 |
+
"instructblip": "<ImageHere><question>",
|
49 |
+
"lrv_instruct": "###Human: <Img><ImageHere></Img> <question> ###Assistant:",
|
50 |
+
"shikra": "USER: <im_start><ImageHere><im_end> <question> ASSISTANT:",
|
51 |
+
"llava-1.5": "USER: <ImageHere> <question> ASSISTANT:"
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
def setup_seeds(config):
|
56 |
+
seed = config.run_cfg.seed + get_rank()
|
57 |
+
|
58 |
+
random.seed(seed)
|
59 |
+
np.random.seed(seed)
|
60 |
+
torch.manual_seed(seed)
|
61 |
+
|
62 |
+
cudnn.benchmark = False
|
63 |
+
cudnn.deterministic = True
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
parser = argparse.ArgumentParser(description="POPE-Adv evaluation on LVLMs.")
|
70 |
+
parser.add_argument("--model", type=str, help="model")
|
71 |
+
parser.add_argument("--gpu-id", type=int, help="specify the gpu to load the model.")
|
72 |
+
parser.add_argument(
|
73 |
+
"--options",
|
74 |
+
nargs="+",
|
75 |
+
help="override some settings in the used config, the key-value pair "
|
76 |
+
"in xxx=yyy format will be merged into config file (deprecate), "
|
77 |
+
"change to --cfg-options instead.",
|
78 |
+
)
|
79 |
+
parser.add_argument("--data_path", type=str, default="COCO_2014/val2014/", help="data path")
|
80 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
81 |
+
parser.add_argument("--num_workers", type=int, default=2, help="num workers")
|
82 |
+
|
83 |
+
parser.add_argument("--beam", type=int)
|
84 |
+
parser.add_argument("--sample", action='store_true')
|
85 |
+
parser.add_argument("--scale_factor", type=float, default=50)
|
86 |
+
parser.add_argument("--threshold", type=int, default=15)
|
87 |
+
parser.add_argument("--num_attn_candidates", type=int, default=5)
|
88 |
+
parser.add_argument("--penalty_weights", type=float, default=1.0)
|
89 |
+
args = parser.parse_known_args()[0]
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
|
97 |
+
args.cfg_path = MODEL_EVAL_CONFIG_PATH[args.model]
|
98 |
+
cfg = Config(args)
|
99 |
+
setup_seeds(cfg)
|
100 |
+
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
101 |
+
|
102 |
+
# ========================================
|
103 |
+
# Model Initialization
|
104 |
+
# ========================================
|
105 |
+
print('Initializing Model')
|
106 |
+
|
107 |
+
model_config = cfg.model_cfg
|
108 |
+
model_config.device_8bit = args.gpu_id
|
109 |
+
model_cls = registry.get_model_class(model_config.arch)
|
110 |
+
model = model_cls.from_config(model_config).to(device)
|
111 |
+
model.eval()
|
112 |
+
processor_cfg = cfg.get_config().preprocess
|
113 |
+
processor_cfg.vis_processor.eval.do_normalize = False
|
114 |
+
vis_processors, txt_processors = load_preprocess(processor_cfg)
|
115 |
+
print(vis_processors["eval"].transform)
|
116 |
+
print("Done!")
|
117 |
+
|
118 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
119 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
120 |
+
norm = transforms.Normalize(mean, std)
|
121 |
+
|
122 |
+
|
123 |
+
img_files = os.listdir(args.data_path)
|
124 |
+
random.shuffle(img_files)
|
125 |
+
|
126 |
+
with open(args.data_path + '../annotations_trainval2014/annotations/instances_val2014.json', 'r') as f:
|
127 |
+
lines = f.readlines()
|
128 |
+
coco_anns = json.loads(lines[0])
|
129 |
+
|
130 |
+
img_dict = {}
|
131 |
+
|
132 |
+
categories = coco_anns["categories"]
|
133 |
+
category_names = [c["name"] for c in categories]
|
134 |
+
category_dict = {int(c["id"]): c["name"] for c in categories}
|
135 |
+
|
136 |
+
for img_info in coco_anns["images"]:
|
137 |
+
img_dict[img_info["id"]] = {"name": img_info["file_name"], "anns": []}
|
138 |
+
|
139 |
+
for ann_info in coco_anns["annotations"]:
|
140 |
+
img_dict[ann_info["image_id"]]["anns"].append(
|
141 |
+
category_dict[ann_info["category_id"]]
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
base_dir = "./log/" + args.model
|
146 |
+
if not os.path.exists(base_dir):
|
147 |
+
os.mkdir(base_dir)
|
148 |
+
|
149 |
+
|
150 |
+
for img_id in tqdm(range(len(img_files))):
|
151 |
+
if img_id == 500:
|
152 |
+
break
|
153 |
+
img_file = img_files[img_id]
|
154 |
+
img_id = int(img_file.split(".jpg")[0][-6:])
|
155 |
+
img_info = img_dict[img_id]
|
156 |
+
assert img_info["name"] == img_file
|
157 |
+
img_anns = set(img_info["anns"])
|
158 |
+
img_save = {}
|
159 |
+
img_save["image_id"] = img_id
|
160 |
+
|
161 |
+
image_path = args.data_path + img_file
|
162 |
+
raw_image = Image.open(image_path).convert("RGB")
|
163 |
+
image = vis_processors["eval"](raw_image).unsqueeze(0)
|
164 |
+
image = image.to(device)
|
165 |
+
|
166 |
+
qu = "Please describe this image in detail."
|
167 |
+
template = INSTRUCTION_TEMPLATE[args.model]
|
168 |
+
qu = template.replace("<question>", qu)
|
169 |
+
|
170 |
+
with torch.inference_mode():
|
171 |
+
with torch.no_grad():
|
172 |
+
out = model.generate(
|
173 |
+
{"image": norm(image), "prompt":qu},
|
174 |
+
use_nucleus_sampling=args.sample,
|
175 |
+
num_beams=args.beam,
|
176 |
+
max_new_tokens=512,
|
177 |
+
output_attentions=True,
|
178 |
+
opera_decoding=True,
|
179 |
+
scale_factor=args.scale_factor,
|
180 |
+
threshold=args.threshold,
|
181 |
+
num_attn_candidates=args.num_attn_candidates,
|
182 |
+
penalty_weights=args.penalty_weights,
|
183 |
+
)
|
184 |
+
img_save["caption"] = out[0]
|
185 |
+
|
186 |
+
# dump metric file
|
187 |
+
with open(os.path.join(base_dir, 'ours-s_{}-t_{}-num_can_{}-p_{}.jsonl'.format(args.scale_factor, args.threshold, args.num_attn_candidates, args.penalty_weights)), "a") as f:
|
188 |
+
json.dump(img_save, f)
|
189 |
+
f.write('\n')
|
190 |
+
|
191 |
+
|
192 |
+
|
OPERA/dataset/README_1_STAGE.md
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Download the filtered Conceptual Captions, SBU, LAION datasets
|
2 |
+
|
3 |
+
### Pre-training datasets download:
|
4 |
+
We use the filtered synthetic captions prepared by BLIP. For more details about the dataset, please refer to [BLIP](https://github.com/salesforce/BLIP).
|
5 |
+
|
6 |
+
It requires ~2.3T to store LAION and CC3M+CC12M+SBU datasets
|
7 |
+
|
8 |
+
Image source | Filtered synthetic caption by ViT-L
|
9 |
+
--- | :---:
|
10 |
+
CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
|
11 |
+
LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
|
12 |
+
|
13 |
+
This will download two json files
|
14 |
+
```
|
15 |
+
ccs_synthetic_filtered_large.json
|
16 |
+
laion_synthetic_filtered_large.json
|
17 |
+
```
|
18 |
+
|
19 |
+
## prepare the data step-by-step
|
20 |
+
|
21 |
+
|
22 |
+
### setup the dataset folder and move the annotation file to the data storage folder
|
23 |
+
```
|
24 |
+
export MINIGPT4_DATASET=/YOUR/PATH/FOR/LARGE/DATASET/
|
25 |
+
mkdir ${MINIGPT4_DATASET}/cc_sbu
|
26 |
+
mkdir ${MINIGPT4_DATASET}/laion
|
27 |
+
mv ccs_synthetic_filtered_large.json ${MINIGPT4_DATASET}/cc_sbu
|
28 |
+
mv laion_synthetic_filtered_large.json ${MINIGPT4_DATASET}/laion
|
29 |
+
```
|
30 |
+
|
31 |
+
### Convert the scripts to data storate folder
|
32 |
+
```
|
33 |
+
cp convert_cc_sbu.py ${MINIGPT4_DATASET}/cc_sbu
|
34 |
+
cp download_cc_sbu.sh ${MINIGPT4_DATASET}/cc_sbu
|
35 |
+
cp convert_laion.py ${MINIGPT4_DATASET}/laion
|
36 |
+
cp download_laion.sh ${MINIGPT4_DATASET}/laion
|
37 |
+
```
|
38 |
+
|
39 |
+
|
40 |
+
### Convert the laion and cc_sbu annotation file format to be img2dataset format
|
41 |
+
```
|
42 |
+
cd ${MINIGPT4_DATASET}/cc_sbu
|
43 |
+
python convert_cc_sbu.py
|
44 |
+
|
45 |
+
cd ${MINIGPT4_DATASET}/laion
|
46 |
+
python convert_laion.py
|
47 |
+
```
|
48 |
+
|
49 |
+
### Download the datasets with img2dataset
|
50 |
+
```
|
51 |
+
cd ${MINIGPT4_DATASET}/cc_sbu
|
52 |
+
sh download_cc_sbu.sh
|
53 |
+
cd ${MINIGPT4_DATASET}/laion
|
54 |
+
sh download_laion.sh
|
55 |
+
```
|
56 |
+
|
57 |
+
|
58 |
+
The final dataset structure
|
59 |
+
|
60 |
+
```
|
61 |
+
.
|
62 |
+
├── ${MINIGPT4_DATASET}
|
63 |
+
│ ├── cc_sbu
|
64 |
+
│ ├── convert_cc_sbu.py
|
65 |
+
│ ├── download_cc_sbu.sh
|
66 |
+
│ ├── ccs_synthetic_filtered_large.json
|
67 |
+
│ ├── ccs_synthetic_filtered_large.tsv
|
68 |
+
│ └── cc_sbu_dataset
|
69 |
+
│ ├── 00000.tar
|
70 |
+
│ ├── 00000.parquet
|
71 |
+
│ ...
|
72 |
+
│ ├── laion
|
73 |
+
│ ├── convert_laion.py
|
74 |
+
│ ├── download_laion.sh
|
75 |
+
│ ├── laion_synthetic_filtered_large.json
|
76 |
+
│ ├── laion_synthetic_filtered_large.tsv
|
77 |
+
│ └── laion_dataset
|
78 |
+
│ ├── 00000.tar
|
79 |
+
│ ├── 00000.parquet
|
80 |
+
│ ...
|
81 |
+
...
|
82 |
+
```
|
83 |
+
|
84 |
+
|
85 |
+
## Set up the dataset configuration files
|
86 |
+
|
87 |
+
Then, set up the LAION dataset loading path in
|
88 |
+
[here](../minigpt4/configs/datasets/laion/defaults.yaml#L5) at Line 5 as
|
89 |
+
${MINIGPT4_DATASET}/laion/laion_dataset/{00000..10488}.tar
|
90 |
+
|
91 |
+
and the Conceptual Captoin and SBU datasets loading path in
|
92 |
+
[here](../minigpt4/configs/datasets/cc_sbu/defaults.yaml#L5) at Line 5 as
|
93 |
+
${MINIGPT4_DATASET}/cc_sbu/cc_sbu_dataset/{00000..01255}.tar
|
94 |
+
|
95 |
+
|
96 |
+
|
OPERA/dataset/README_2_STAGE.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Second Stage Data Preparation
|
2 |
+
|
3 |
+
Our second stage dataset can be downloaded from
|
4 |
+
[here](https://drive.google.com/file/d/1nJXhoEcy3KTExr17I7BXqY5Y9Lx_-n-9/view?usp=share_link)
|
5 |
+
After extraction, you will get a data follder with the following structure:
|
6 |
+
|
7 |
+
```
|
8 |
+
cc_sbu_align
|
9 |
+
├── filter_cap.json
|
10 |
+
└── image
|
11 |
+
├── 2.jpg
|
12 |
+
├── 3.jpg
|
13 |
+
...
|
14 |
+
```
|
15 |
+
|
16 |
+
Put the folder to any path you want.
|
17 |
+
Then, set up the dataset path in the dataset config file
|
18 |
+
[here](../minigpt4/configs/datasets/cc_sbu/align.yaml#L5) at Line 5.
|
19 |
+
|
OPERA/dataset/convert_cc_sbu.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import csv
|
3 |
+
|
4 |
+
# specify input and output file paths
|
5 |
+
input_file = 'ccs_synthetic_filtered_large.json'
|
6 |
+
output_file = 'ccs_synthetic_filtered_large.tsv'
|
7 |
+
|
8 |
+
# load JSON data from input file
|
9 |
+
with open(input_file, 'r') as f:
|
10 |
+
data = json.load(f)
|
11 |
+
|
12 |
+
# extract header and data from JSON
|
13 |
+
header = data[0].keys()
|
14 |
+
rows = [x.values() for x in data]
|
15 |
+
|
16 |
+
# write data to TSV file
|
17 |
+
with open(output_file, 'w') as f:
|
18 |
+
writer = csv.writer(f, delimiter='\t')
|
19 |
+
writer.writerow(header)
|
20 |
+
writer.writerows(rows)
|
OPERA/dataset/convert_laion.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import csv
|
3 |
+
|
4 |
+
# specify input and output file paths
|
5 |
+
input_file = 'laion_synthetic_filtered_large.json'
|
6 |
+
output_file = 'laion_synthetic_filtered_large.tsv'
|
7 |
+
|
8 |
+
# load JSON data from input file
|
9 |
+
with open(input_file, 'r') as f:
|
10 |
+
data = json.load(f)
|
11 |
+
|
12 |
+
# extract header and data from JSON
|
13 |
+
header = data[0].keys()
|
14 |
+
rows = [x.values() for x in data]
|
15 |
+
|
16 |
+
# write data to TSV file
|
17 |
+
with open(output_file, 'w') as f:
|
18 |
+
writer = csv.writer(f, delimiter='\t')
|
19 |
+
writer.writerow(header)
|
20 |
+
writer.writerows(rows)
|
OPERA/dataset/download_cc_sbu.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
img2dataset --url_list ccs_synthetic_filtered_large.tsv --input_format "tsv"\
|
4 |
+
--url_col "url" --caption_col "caption" --output_format webdataset\
|
5 |
+
--output_folder cc_sbu_dataset --processes_count 16 --thread_count 128 --image_size 256 \
|
6 |
+
--enable_wandb True
|
OPERA/dataset/download_laion.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
img2dataset --url_list laion_synthetic_filtered_large.tsv --input_format "tsv"\
|
4 |
+
--url_col "url" --caption_col "caption" --output_format webdataset\
|
5 |
+
--output_folder laion_dataset --processes_count 16 --thread_count 128 --image_size 256 \
|
6 |
+
--enable_wandb True
|
OPERA/demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/environment.yml
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: opera
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1
|
7 |
+
- _openmp_mutex=5.1
|
8 |
+
- blas=1.0
|
9 |
+
- brotlipy=0.7.0
|
10 |
+
- bzip2=1.0.8
|
11 |
+
- ca-certificates=2023.01.10
|
12 |
+
- certifi=2022.12.7
|
13 |
+
- cffi=1.15.1
|
14 |
+
- charset-normalizer=2.0.4
|
15 |
+
- cryptography=39.0.1
|
16 |
+
- cudatoolkit=11.3.1
|
17 |
+
- ffmpeg=4.3
|
18 |
+
- freetype=2.12.1
|
19 |
+
- giflib=5.2.1
|
20 |
+
- gmp=6.2.1
|
21 |
+
- gnutls=3.6.15
|
22 |
+
- idna=3.4
|
23 |
+
- intel-openmp=2023.1.0
|
24 |
+
- jpeg=9e
|
25 |
+
- lame=3.100
|
26 |
+
- lcms2=2.12
|
27 |
+
- ld_impl_linux-64=2.38
|
28 |
+
- lerc=3.0
|
29 |
+
- libdeflate=1.17
|
30 |
+
- libffi=3.4.2
|
31 |
+
- libgcc-ng=11.2.0
|
32 |
+
- libgomp=11.2.0
|
33 |
+
- libiconv=1.16
|
34 |
+
- libidn2=2.3.2
|
35 |
+
- libpng=1.6.39
|
36 |
+
- libstdcxx-ng=11.2.0
|
37 |
+
- libtasn1=4.19.0
|
38 |
+
- libtiff=4.5.0
|
39 |
+
- libunistring=0.9.10
|
40 |
+
- libwebp=1.2.4
|
41 |
+
- libwebp-base=1.2.4
|
42 |
+
- lz4-c=1.9.4
|
43 |
+
- mkl=2023.1.0
|
44 |
+
- mkl-service=2.4.0
|
45 |
+
- mkl_fft=1.3.6
|
46 |
+
- mkl_random=1.2.2
|
47 |
+
- ncurses=6.4
|
48 |
+
- nettle=3.7.3
|
49 |
+
- numpy=1.24.3
|
50 |
+
- numpy-base=1.24.3
|
51 |
+
- openh264=2.1.1
|
52 |
+
- openssl=1.1.1t
|
53 |
+
- pillow=9.4.0
|
54 |
+
- pip=23.0.1
|
55 |
+
- pycparser=2.21
|
56 |
+
- pyopenssl=23.0.0
|
57 |
+
- pysocks=1.7.1
|
58 |
+
- python=3.9.16
|
59 |
+
- pytorch-mutex=1.0
|
60 |
+
- readline=8.2
|
61 |
+
- requests=2.29.0
|
62 |
+
- setuptools=66.0.0
|
63 |
+
- sqlite=3.41.2
|
64 |
+
- tbb=2021.8.0
|
65 |
+
- tk=8.6.12
|
66 |
+
- torchaudio=0.12.1
|
67 |
+
- typing_extensions=4.5.0
|
68 |
+
- urllib3=1.26.15
|
69 |
+
- wheel=0.38.4
|
70 |
+
- xz=5.4.2
|
71 |
+
- zlib=1.2.13
|
72 |
+
- zstd=1.5.5
|
73 |
+
- pip:
|
74 |
+
- accelerate==0.19.0
|
75 |
+
- aiofiles==23.1.0
|
76 |
+
- aiohttp==3.8.4
|
77 |
+
- aiosignal==1.3.1
|
78 |
+
- altair==4.2.2
|
79 |
+
- antlr4-python3-runtime==4.9.3
|
80 |
+
- anyio==3.6.2
|
81 |
+
- argon2-cffi==21.3.0
|
82 |
+
- argon2-cffi-bindings==21.2.0
|
83 |
+
- arrow==1.2.3
|
84 |
+
- asttokens==2.2.1
|
85 |
+
- async-lru==2.0.2
|
86 |
+
- async-timeout==4.0.2
|
87 |
+
- attrs==23.1.0
|
88 |
+
- babel==2.12.1
|
89 |
+
- backcall==0.2.0
|
90 |
+
- beautifulsoup4==4.12.2
|
91 |
+
- bertviz==1.4.0
|
92 |
+
- bitsandbytes==0.39.0
|
93 |
+
- bleach==6.0.0
|
94 |
+
- blinker==1.6.2
|
95 |
+
- blis==0.7.9
|
96 |
+
- boto3==1.28.63
|
97 |
+
- botocore==1.31.63
|
98 |
+
- braceexpand==0.1.7
|
99 |
+
- cachetools==5.3.0
|
100 |
+
- catalogue==2.0.8
|
101 |
+
- cfgv==3.3.1
|
102 |
+
- click==8.1.3
|
103 |
+
- cmake==3.26.3
|
104 |
+
- comm==0.1.3
|
105 |
+
- confection==0.0.4
|
106 |
+
- contexttimer==0.3.3
|
107 |
+
- contourpy==1.0.7
|
108 |
+
- curated-tokenizers==0.0.8
|
109 |
+
- curated-transformers==0.1.1
|
110 |
+
- cycler==0.11.0
|
111 |
+
- cymem==2.0.7
|
112 |
+
- debugpy==1.6.7
|
113 |
+
- decorator==5.1.1
|
114 |
+
- decord==0.6.0
|
115 |
+
- defusedxml==0.7.1
|
116 |
+
- distlib==0.3.6
|
117 |
+
- distro==1.8.0
|
118 |
+
- einops==0.6.1
|
119 |
+
- entrypoints==0.4
|
120 |
+
- et-xmlfile==1.1.0
|
121 |
+
- executing==1.2.0
|
122 |
+
- fairscale==0.4.4
|
123 |
+
- fastapi==0.95.2
|
124 |
+
- fastjsonschema==2.16.3
|
125 |
+
- ffmpy==0.3.0
|
126 |
+
- filelock==3.12.0
|
127 |
+
- fonttools==4.39.4
|
128 |
+
- fqdn==1.5.1
|
129 |
+
- frozenlist==1.3.3
|
130 |
+
- fsspec==2023.5.0
|
131 |
+
- ftfy==6.1.1
|
132 |
+
- gdown==4.7.1
|
133 |
+
- gitdb==4.0.10
|
134 |
+
- gitpython==3.1.31
|
135 |
+
- gradio==3.31.0
|
136 |
+
- gradio-client==0.2.5
|
137 |
+
- h11==0.14.0
|
138 |
+
- httpcore==0.17.1
|
139 |
+
- httpx==0.24.1
|
140 |
+
- huggingface-hub==0.14.1
|
141 |
+
- identify==2.5.24
|
142 |
+
- imageio==2.28.1
|
143 |
+
- importlib-metadata==6.6.0
|
144 |
+
- importlib-resources==5.12.0
|
145 |
+
- iopath==0.1.10
|
146 |
+
- ipykernel==6.23.1
|
147 |
+
- ipython==8.13.2
|
148 |
+
- ipython-genutils==0.2.0
|
149 |
+
- ipywidgets==8.0.6
|
150 |
+
- isoduration==20.11.0
|
151 |
+
- jedi==0.18.2
|
152 |
+
- jinja2==3.1.2
|
153 |
+
- jmespath==1.0.1
|
154 |
+
- joblib==1.2.0
|
155 |
+
- json5==0.9.14
|
156 |
+
- jsonpointer==2.3
|
157 |
+
- jsonschema==4.17.3
|
158 |
+
- jupyter==1.0.0
|
159 |
+
- jupyter-client==8.2.0
|
160 |
+
- jupyter-console==6.6.3
|
161 |
+
- jupyter-core==5.3.0
|
162 |
+
- jupyter-events==0.6.3
|
163 |
+
- jupyter-lsp==2.1.0
|
164 |
+
- jupyter-server==2.5.0
|
165 |
+
- jupyter-server-terminals==0.4.4
|
166 |
+
- jupyterlab==4.0.7
|
167 |
+
- jupyterlab-pygments==0.2.2
|
168 |
+
- jupyterlab-server==2.22.1
|
169 |
+
- jupyterlab-widgets==3.0.7
|
170 |
+
- kaggle==1.5.13
|
171 |
+
- kiwisolver==1.4.4
|
172 |
+
- langcodes==3.3.0
|
173 |
+
- lazy-loader==0.2
|
174 |
+
- linkify-it-py==2.0.2
|
175 |
+
- lit==16.0.5
|
176 |
+
- markdown-it-py==2.2.0
|
177 |
+
- markupsafe==2.1.2
|
178 |
+
- matplotlib==3.7.1
|
179 |
+
- matplotlib-inline==0.1.6
|
180 |
+
- mdit-py-plugins==0.3.3
|
181 |
+
- mdurl==0.1.2
|
182 |
+
- mistune==2.0.5
|
183 |
+
- mpmath==1.3.0
|
184 |
+
- multidict==6.0.4
|
185 |
+
- murmurhash==1.0.9
|
186 |
+
- nbclassic==1.0.0
|
187 |
+
- nbclient==0.7.4
|
188 |
+
- nbconvert==7.4.0
|
189 |
+
- nbformat==5.8.0
|
190 |
+
- nest-asyncio==1.5.6
|
191 |
+
- networkx==3.1
|
192 |
+
- nltk==3.8.1
|
193 |
+
- nodeenv==1.8.0
|
194 |
+
- notebook==7.0.6
|
195 |
+
- notebook-shim==0.2.3
|
196 |
+
- nvidia-cublas-cu11==11.10.3.66
|
197 |
+
- nvidia-cuda-cupti-cu11==11.7.101
|
198 |
+
- nvidia-cuda-nvrtc-cu11==11.7.99
|
199 |
+
- nvidia-cuda-runtime-cu11==11.7.99
|
200 |
+
- nvidia-cudnn-cu11==8.5.0.96
|
201 |
+
- nvidia-cufft-cu11==10.9.0.58
|
202 |
+
- nvidia-curand-cu11==10.2.10.91
|
203 |
+
- nvidia-cusolver-cu11==11.4.0.1
|
204 |
+
- nvidia-cusparse-cu11==11.7.4.91
|
205 |
+
- nvidia-nccl-cu11==2.14.3
|
206 |
+
- nvidia-nvtx-cu11==11.7.91
|
207 |
+
- omegaconf==2.3.0
|
208 |
+
- openai==0.28.1
|
209 |
+
- opencv-python==4.7.0.72
|
210 |
+
- opencv-python-headless==4.5.5.64
|
211 |
+
- opendatasets==0.1.22
|
212 |
+
- openpyxl==3.1.2
|
213 |
+
- orjson==3.8.12
|
214 |
+
- packaging==23.1
|
215 |
+
- pandas==2.0.1
|
216 |
+
- pandocfilters==1.5.0
|
217 |
+
- parso==0.8.3
|
218 |
+
- pathy==0.10.1
|
219 |
+
- pexpect==4.8.0
|
220 |
+
- pickleshare==0.7.5
|
221 |
+
- platformdirs==3.5.1
|
222 |
+
- plotly==5.14.1
|
223 |
+
- portalocker==2.7.0
|
224 |
+
- pre-commit==3.3.2
|
225 |
+
- preshed==3.0.8
|
226 |
+
- prometheus-client==0.16.0
|
227 |
+
- prompt-toolkit==3.0.38
|
228 |
+
- protobuf==3.20.3
|
229 |
+
- psutil==5.9.5
|
230 |
+
- ptyprocess==0.7.0
|
231 |
+
- pure-eval==0.2.2
|
232 |
+
- pyarrow==12.0.0
|
233 |
+
- pycocoevalcap==1.2
|
234 |
+
- pycocotools==2.0.6
|
235 |
+
- pydantic==1.10.7
|
236 |
+
- pydeck==0.8.1b0
|
237 |
+
- pydub==0.25.1
|
238 |
+
- pygments==2.15.1
|
239 |
+
- pympler==1.0.1
|
240 |
+
- pyparsing==3.0.9
|
241 |
+
- pyrsistent==0.19.3
|
242 |
+
- python-dateutil==2.8.2
|
243 |
+
- python-json-logger==2.0.7
|
244 |
+
- python-magic==0.4.27
|
245 |
+
- python-multipart==0.0.6
|
246 |
+
- python-slugify==8.0.1
|
247 |
+
- pytz==2023.3
|
248 |
+
- pywavelets==1.4.1
|
249 |
+
- pyyaml==6.0
|
250 |
+
- pyzmq==25.0.2
|
251 |
+
- qtconsole==5.4.3
|
252 |
+
- qtpy==2.3.1
|
253 |
+
- regex==2023.5.5
|
254 |
+
- rfc3339-validator==0.1.4
|
255 |
+
- rfc3986-validator==0.1.1
|
256 |
+
- rich==13.3.5
|
257 |
+
- s3transfer==0.7.0
|
258 |
+
- safetensors==0.3.1
|
259 |
+
- scikit-image==0.20.0
|
260 |
+
- scikit-learn==1.2.2
|
261 |
+
- scipy==1.9.1
|
262 |
+
- seaborn==0.13.0
|
263 |
+
- semantic-version==2.10.0
|
264 |
+
- send2trash==1.8.2
|
265 |
+
- sentencepiece==0.1.99
|
266 |
+
- six==1.16.0
|
267 |
+
- sklearn==0.0.post5
|
268 |
+
- smart-open==6.3.0
|
269 |
+
- smmap==5.0.0
|
270 |
+
- sniffio==1.3.0
|
271 |
+
- soupsieve==2.4.1
|
272 |
+
- spacy==3.7.0.dev0
|
273 |
+
- spacy-curated-transformers==0.2.0
|
274 |
+
- spacy-legacy==3.0.12
|
275 |
+
- spacy-loggers==1.0.4
|
276 |
+
- srsly==2.4.6
|
277 |
+
- stack-data==0.6.2
|
278 |
+
- starlette==0.27.0
|
279 |
+
- streamlit==1.22.0
|
280 |
+
- sympy==1.12
|
281 |
+
- tenacity==8.2.2
|
282 |
+
- terminado==0.17.1
|
283 |
+
- text-unidecode==1.3
|
284 |
+
- thinc==8.1.10
|
285 |
+
- threadpoolctl==3.1.0
|
286 |
+
- tifffile==2023.4.12
|
287 |
+
- timm==0.4.12
|
288 |
+
- tinycss2==1.2.1
|
289 |
+
- tokenizers==0.13.3
|
290 |
+
- toml==0.10.2
|
291 |
+
- tomli==2.0.1
|
292 |
+
- toolz==0.12.0
|
293 |
+
- torch==2.0.1
|
294 |
+
- torchvision==0.15.2
|
295 |
+
- tornado==6.3.2
|
296 |
+
- tqdm==4.65.0
|
297 |
+
- traitlets==5.9.0
|
298 |
+
- triton==2.0.0
|
299 |
+
- typer==0.7.0
|
300 |
+
- tzdata==2023.3
|
301 |
+
- tzlocal==5.0.1
|
302 |
+
- uc-micro-py==1.0.2
|
303 |
+
- uri-template==1.2.0
|
304 |
+
- uvicorn==0.22.0
|
305 |
+
- validators==0.20.0
|
306 |
+
- virtualenv==20.23.0
|
307 |
+
- wasabi==1.1.1
|
308 |
+
- watchdog==3.0.0
|
309 |
+
- wcwidth==0.2.6
|
310 |
+
- webcolors==1.13
|
311 |
+
- webdataset==0.2.48
|
312 |
+
- webencodings==0.5.1
|
313 |
+
- websocket-client==1.5.1
|
314 |
+
- websockets==11.0.3
|
315 |
+
- widgetsnbextension==4.0.7
|
316 |
+
- yarl==1.9.2
|
317 |
+
- zipp==3.15.0
|
OPERA/eval_configs/instructblip_eval.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: blip2_vicuna_instruct
|
3 |
+
model_type: vicuna7b
|
4 |
+
max_txt_len: 128
|
5 |
+
# end_sym: "###"
|
6 |
+
# low_resource: True
|
7 |
+
# prompt_template: '###Human: {} ###Assistant: '
|
8 |
+
# ckpt: '/mnt/petrelfs/share_data/huangqidong/lvlm/minigpt4/prerained_minigpt4_7b.pth'
|
9 |
+
|
10 |
+
|
11 |
+
datasets:
|
12 |
+
cc_sbu_align:
|
13 |
+
vis_processor:
|
14 |
+
train:
|
15 |
+
name: "blip2_image_eval"
|
16 |
+
image_size: 224
|
17 |
+
text_processor:
|
18 |
+
train:
|
19 |
+
name: "blip_caption"
|
20 |
+
|
21 |
+
run:
|
22 |
+
task: image_text_pretrain
|
23 |
+
seed: 42
|
OPERA/eval_configs/llava-1.5_eval.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: llava-1.5
|
3 |
+
model_type: vicuna7b
|
4 |
+
freeze_vit: True
|
5 |
+
freeze_backbone: True
|
6 |
+
tune_mm_mlp_adapter: False
|
7 |
+
freeze_mm_mlp_adapter: True
|
8 |
+
max_txt_len: 160
|
9 |
+
end_sym: "###"
|
10 |
+
low_resource: False
|
11 |
+
# prompt_path: "prompts/alignment.txt"
|
12 |
+
prompt_template: 'USER: {} ASSISTANT: '
|
13 |
+
system_message: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
|
14 |
+
merged_ckpt: "/mnt/petrelfs/share_data/huangqidong/lvlm/llava-v1.5-7b/"
|
15 |
+
|
16 |
+
|
17 |
+
datasets:
|
18 |
+
cc_sbu_align:
|
19 |
+
vis_processor:
|
20 |
+
train:
|
21 |
+
name: "clip_image_eval"
|
22 |
+
proc_type: "openai/clip-vit-large-patch14-336"
|
23 |
+
text_processor:
|
24 |
+
train:
|
25 |
+
name: "blip_caption"
|
26 |
+
|
27 |
+
run:
|
28 |
+
task: image_text_pretrain
|
29 |
+
seed: 42
|
OPERA/eval_configs/minigpt4_eval.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: mini_gpt4
|
3 |
+
model_type: pretrain_vicuna0
|
4 |
+
max_txt_len: 160
|
5 |
+
end_sym: "###"
|
6 |
+
low_resource: False
|
7 |
+
prompt_template: '###Human: {} ###Assistant: '
|
8 |
+
ckpt: '/mnt/petrelfs/share_data/huangqidong/lvlm/minigpt4/prerained_minigpt4_7b.pth'
|
9 |
+
|
10 |
+
|
11 |
+
datasets:
|
12 |
+
cc_sbu_align:
|
13 |
+
vis_processor:
|
14 |
+
train:
|
15 |
+
name: "blip2_image_eval"
|
16 |
+
image_size: 224
|
17 |
+
text_processor:
|
18 |
+
train:
|
19 |
+
name: "blip_caption"
|
20 |
+
|
21 |
+
run:
|
22 |
+
task: image_text_pretrain
|
23 |
+
seed: 42
|
OPERA/eval_configs/minigpt4_llama2_eval.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: mini_gpt4
|
3 |
+
model_type: pretrain_llama2
|
4 |
+
max_txt_len: 160
|
5 |
+
end_sym: "</s>"
|
6 |
+
low_resource: False
|
7 |
+
prompt_template: '[INST] {} [/INST] '
|
8 |
+
ckpt: '/path/to/checkpoint/'
|
9 |
+
|
10 |
+
|
11 |
+
datasets:
|
12 |
+
cc_sbu_align:
|
13 |
+
vis_processor:
|
14 |
+
train:
|
15 |
+
name: "blip2_image_eval"
|
16 |
+
image_size: 224
|
17 |
+
text_processor:
|
18 |
+
train:
|
19 |
+
name: "blip_caption"
|
20 |
+
|
21 |
+
run:
|
22 |
+
task: image_text_pretrain
|
OPERA/eval_configs/shikra_eval.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: shikra
|
3 |
+
model_type: vicuna7b
|
4 |
+
freeze_vit: True
|
5 |
+
freeze_backbone: True
|
6 |
+
tune_mm_mlp_adapter: False
|
7 |
+
freeze_mm_mlp_adapter: True
|
8 |
+
max_txt_len: 160
|
9 |
+
end_sym: "###"
|
10 |
+
low_resource: False
|
11 |
+
# prompt_path: "prompts/alignment.txt"
|
12 |
+
prompt_template: 'USER: {} ASSISTANT: '
|
13 |
+
system_message: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
|
14 |
+
merged_ckpt: '/mnt/petrelfs/share_data/zhaohanqing/merged_shikra_7B/'
|
15 |
+
|
16 |
+
|
17 |
+
datasets:
|
18 |
+
cc_sbu_align:
|
19 |
+
vis_processor:
|
20 |
+
train:
|
21 |
+
name: "clip_image_eval"
|
22 |
+
proc_type: "openai/clip-vit-large-patch14"
|
23 |
+
text_processor:
|
24 |
+
train:
|
25 |
+
name: "blip_caption"
|
26 |
+
|
27 |
+
run:
|
28 |
+
task: image_text_pretrain
|
29 |
+
seed: 42
|
OPERA/gpt4v_eval.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import requests
|
3 |
+
from PIL import Image
|
4 |
+
from io import BytesIO
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.backends.cudnn as cudnn
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
from torchvision import transforms
|
18 |
+
from torchvision.transforms.functional import InterpolationMode
|
19 |
+
from torchvision.utils import save_image
|
20 |
+
|
21 |
+
from pope_loader import POPEDataSet
|
22 |
+
from minigpt4.common.dist_utils import get_rank
|
23 |
+
from minigpt4.models import load_preprocess
|
24 |
+
|
25 |
+
from minigpt4.common.config import Config
|
26 |
+
from minigpt4.common.dist_utils import get_rank
|
27 |
+
from minigpt4.common.registry import registry
|
28 |
+
|
29 |
+
# imports modules for registration
|
30 |
+
from minigpt4.datasets.builders import *
|
31 |
+
from minigpt4.models import *
|
32 |
+
from minigpt4.processors import *
|
33 |
+
from minigpt4.runners import *
|
34 |
+
from minigpt4.tasks import *
|
35 |
+
|
36 |
+
# from PIL import Image
|
37 |
+
from torchvision.utils import save_image
|
38 |
+
|
39 |
+
import matplotlib.pyplot as plt
|
40 |
+
import matplotlib as mpl
|
41 |
+
import seaborn
|
42 |
+
import json
|
43 |
+
|
44 |
+
|
45 |
+
MODEL_EVAL_CONFIG_PATH = {
|
46 |
+
"minigpt4": "eval_configs/minigpt4_eval.yaml",
|
47 |
+
"instructblip": "eval_configs/instructblip_eval.yaml",
|
48 |
+
"lrv_instruct": "eval_configs/lrv_instruct_eval.yaml",
|
49 |
+
"shikra": "eval_configs/shikra_eval.yaml",
|
50 |
+
"llava-1.5": "eval_configs/llava-1.5_eval.yaml",
|
51 |
+
}
|
52 |
+
|
53 |
+
INSTRUCTION_TEMPLATE = {
|
54 |
+
"minigpt4": "###Human: <Img><ImageHere></Img> <question> ###Assistant:",
|
55 |
+
"instructblip": "<ImageHere><question>",
|
56 |
+
"lrv_instruct": "###Human: <Img><ImageHere></Img> <question> ###Assistant:",
|
57 |
+
"shikra": "USER: <im_start><ImageHere><im_end> <question> ASSISTANT:",
|
58 |
+
"llava-1.5": "USER: <ImageHere> <question> ASSISTANT:"
|
59 |
+
}
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
GPT_JUDGE_PROMPT = '''
|
65 |
+
You are required to score the performance of two AI assistants in describing a given image. You should pay extra attention to the hallucination, which refers to the part of descriptions that are inconsistent with the image content, such as claiming the existence of something not present in the image or describing incorrectly in terms of the counts, positions, or colors of objects in the image. Please rate the responses of the assistants on a scale of 1 to 10, where a higher score indicates better performance, according to the following criteria:
|
66 |
+
1: Accuracy: whether the response is accurate with respect to the image content. Responses with fewer hallucinationsshould be given higher scores.
|
67 |
+
2: Detailedness: whether the response is rich in necessary details. Note that hallucinated descriptions should not countas necessary details.
|
68 |
+
Please output the scores for each criterion, containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. Following the scores, please provide an explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.
|
69 |
+
|
70 |
+
[Assistant 1]
|
71 |
+
{}
|
72 |
+
[End of Assistant 1]
|
73 |
+
|
74 |
+
[Assistant 2]
|
75 |
+
{}
|
76 |
+
[End of Assistant 2]
|
77 |
+
|
78 |
+
Output format:
|
79 |
+
Accuracy: <Scores of the two answers>
|
80 |
+
Reason:
|
81 |
+
|
82 |
+
Detailedness: <Scores of the two answers>
|
83 |
+
Reason:
|
84 |
+
'''
|
85 |
+
|
86 |
+
|
87 |
+
# OpenAI API Key
|
88 |
+
API_KEY = "YOUR_API_KEY"
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
def setup_seeds(config):
|
93 |
+
seed = config.run_cfg.seed + get_rank()
|
94 |
+
|
95 |
+
random.seed(seed)
|
96 |
+
np.random.seed(seed)
|
97 |
+
torch.manual_seed(seed)
|
98 |
+
|
99 |
+
cudnn.benchmark = False
|
100 |
+
cudnn.deterministic = True
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
def call_api(prompt, image_path):
|
106 |
+
# Function to encode the image
|
107 |
+
def encode_image(image_path):
|
108 |
+
with open(image_path, "rb") as image_file:
|
109 |
+
return base64.b64encode(image_file.read()).decode('utf-8')
|
110 |
+
|
111 |
+
# Getting the base64 string
|
112 |
+
base64_image = encode_image(image_path)
|
113 |
+
|
114 |
+
headers = {
|
115 |
+
"Content-Type": "application/json",
|
116 |
+
"Authorization": f"Bearer {API_KEY}"
|
117 |
+
}
|
118 |
+
|
119 |
+
payload = {
|
120 |
+
"model": "gpt-4-vision-preview",
|
121 |
+
"messages": [
|
122 |
+
{
|
123 |
+
"role": "user",
|
124 |
+
"content": [
|
125 |
+
{
|
126 |
+
"type": "text",
|
127 |
+
"text": prompt
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"type": "image_url",
|
131 |
+
"image_url": {
|
132 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
133 |
+
}
|
134 |
+
}
|
135 |
+
]
|
136 |
+
}
|
137 |
+
],
|
138 |
+
"max_tokens": 300
|
139 |
+
}
|
140 |
+
|
141 |
+
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
|
142 |
+
|
143 |
+
print(response.json().keys())
|
144 |
+
return response.json()
|
145 |
+
|
146 |
+
|
147 |
+
def get_gpt4v_answer(prompt, image_path):
|
148 |
+
while 1:
|
149 |
+
try:
|
150 |
+
res = call_api(prompt, image_path)
|
151 |
+
if "choices" in res.keys():
|
152 |
+
return res["choices"][0]["message"]["content"]
|
153 |
+
else:
|
154 |
+
assert False
|
155 |
+
except Exception as e:
|
156 |
+
print("retry")
|
157 |
+
# pass
|
158 |
+
# return call_api(prompt, image_path)
|
159 |
+
|
160 |
+
|
161 |
+
parser = argparse.ArgumentParser(description="POPE-Adv evaluation on LVLMs.")
|
162 |
+
parser.add_argument("--model", type=str, help="model")
|
163 |
+
parser.add_argument("--gpu-id", type=int, help="specify the gpu to load the model.")
|
164 |
+
parser.add_argument(
|
165 |
+
"--options",
|
166 |
+
nargs="+",
|
167 |
+
help="override some settings in the used config, the key-value pair "
|
168 |
+
"in xxx=yyy format will be merged into config file (deprecate), "
|
169 |
+
"change to --cfg-options instead.",
|
170 |
+
)
|
171 |
+
parser.add_argument("--data_path", type=str, default="COCO_2014/val2014/", help="data path")
|
172 |
+
parser.add_argument("--batch_size", type=int, help="batch size")
|
173 |
+
parser.add_argument("--num_workers", type=int, default=2, help="num workers")
|
174 |
+
|
175 |
+
parser.add_argument("--scale_factor", type=float, default=50)
|
176 |
+
parser.add_argument("--threshold", type=int, default=15)
|
177 |
+
parser.add_argument("--num_attn_candidates", type=int, default=5)
|
178 |
+
parser.add_argument("--penalty_weights", type=float, default=1.0)
|
179 |
+
args = parser.parse_known_args()[0]
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
|
184 |
+
args.cfg_path = MODEL_EVAL_CONFIG_PATH[args.model]
|
185 |
+
cfg = Config(args)
|
186 |
+
setup_seeds(cfg)
|
187 |
+
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
188 |
+
|
189 |
+
# ========================================
|
190 |
+
# Model Initialization
|
191 |
+
# ========================================
|
192 |
+
print('Initializing Model')
|
193 |
+
|
194 |
+
model_config = cfg.model_cfg
|
195 |
+
model_config.device_8bit = args.gpu_id
|
196 |
+
model_cls = registry.get_model_class(model_config.arch)
|
197 |
+
model = model_cls.from_config(model_config).to(device)
|
198 |
+
model.eval()
|
199 |
+
processor_cfg = cfg.get_config().preprocess
|
200 |
+
processor_cfg.vis_processor.eval.do_normalize = False
|
201 |
+
vis_processors, txt_processors = load_preprocess(processor_cfg)
|
202 |
+
print(vis_processors["eval"].transform)
|
203 |
+
print("Done!")
|
204 |
+
|
205 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
206 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
207 |
+
norm = transforms.Normalize(mean, std)
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
img_files = os.listdir(args.data_path)
|
213 |
+
random.shuffle(img_files)
|
214 |
+
|
215 |
+
base_path = "log/gpt4v-eval"
|
216 |
+
if not os.path.exists(base_path + f"/{args.model}"):
|
217 |
+
os.mkdir(base_path + f"/{args.model}")
|
218 |
+
|
219 |
+
gpt_answer_records = {}
|
220 |
+
assistant_answer_records = {}
|
221 |
+
avg_hal_score_1 = 0
|
222 |
+
avg_hal_score_2 = 0
|
223 |
+
avg_det_score_1 = 0
|
224 |
+
avg_det_score_2 = 0
|
225 |
+
num_count = 0
|
226 |
+
|
227 |
+
for idx in range(50):
|
228 |
+
img = img_files[idx]
|
229 |
+
image_path = args.data_path + img
|
230 |
+
raw_image = Image.open(image_path)
|
231 |
+
raw_image = raw_image.convert("RGB")
|
232 |
+
image = vis_processors["eval"](raw_image).unsqueeze(0)
|
233 |
+
image = image.to(device)
|
234 |
+
qu = "Please describe this image in detail."
|
235 |
+
|
236 |
+
template = INSTRUCTION_TEMPLATE[args.model]
|
237 |
+
qu = template.replace("<question>", qu)
|
238 |
+
assistant_answer_records[str(img)] = {}
|
239 |
+
|
240 |
+
with torch.inference_mode():
|
241 |
+
with torch.no_grad():
|
242 |
+
out = model.generate(
|
243 |
+
{"image": norm(image), "prompt":qu},
|
244 |
+
use_nucleus_sampling=False,
|
245 |
+
num_beams=5,
|
246 |
+
max_new_tokens=512,
|
247 |
+
)
|
248 |
+
model_response_1 = out[0]
|
249 |
+
assistant_answer_records[str(img)]["assistant_1"] = model_response_1
|
250 |
+
print("Beam-5 output:")
|
251 |
+
print(model_response_1)
|
252 |
+
|
253 |
+
|
254 |
+
with torch.inference_mode():
|
255 |
+
with torch.no_grad():
|
256 |
+
out = model.generate(
|
257 |
+
{"image": norm(image), "prompt":qu},
|
258 |
+
use_nucleus_sampling=False,
|
259 |
+
num_beams=5,
|
260 |
+
max_new_tokens=512,
|
261 |
+
output_attentions=True,
|
262 |
+
opera_decoding=True,
|
263 |
+
scale_factor=args.scale_factor,
|
264 |
+
threshold=args.threshold,
|
265 |
+
num_attn_candidates=args.num_attn_candidates,
|
266 |
+
penalty_weights=args.penalty_weights,
|
267 |
+
)
|
268 |
+
model_response_2 = out[0]
|
269 |
+
assistant_answer_records[str(img)]["assistant_2"] = model_response_2
|
270 |
+
print("OPERA output:")
|
271 |
+
print(model_response_2)
|
272 |
+
|
273 |
+
# gpt-4v eval
|
274 |
+
prompt = GPT_JUDGE_PROMPT.format(model_response_1, model_response_2)
|
275 |
+
|
276 |
+
gpt_answer = get_gpt4v_answer(prompt, image_path)
|
277 |
+
print(gpt_answer)
|
278 |
+
gpt_answer_records[str(img)] = gpt_answer
|
279 |
+
print(gpt_answer.split("Accuracy: ")[-1].split("\n")[0].split(" "))
|
280 |
+
print(len(gpt_answer.split("Accuracy: ")[-1].split("\n")[0].split(" ")))
|
281 |
+
try:
|
282 |
+
hal_score_1, hal_score_2 = gpt_answer.split("Accuracy: ")[-1].split("\n")[0].split(" ")
|
283 |
+
det_score_1, det_score_2 = gpt_answer.split("Detailedness: ")[-1].split("\n")[0].split(" ")
|
284 |
+
except:
|
285 |
+
continue
|
286 |
+
avg_hal_score_1 += int(hal_score_1)
|
287 |
+
avg_hal_score_2 += int(hal_score_2)
|
288 |
+
avg_det_score_1 += int(det_score_1)
|
289 |
+
avg_det_score_2 += int(det_score_2)
|
290 |
+
num_count += 1
|
291 |
+
print("=========================================")
|
292 |
+
|
293 |
+
# dump metric file
|
294 |
+
with open(os.path.join(base_path + f"/{args.model}", 'answers.json'), "w") as f:
|
295 |
+
json.dump(assistant_answer_records, f)
|
296 |
+
|
297 |
+
# dump metric file
|
298 |
+
with open(os.path.join(base_path + f"/{args.model}", 'records.json'), "w") as f:
|
299 |
+
json.dump(gpt_answer_records, f)
|
300 |
+
|
301 |
+
avg_score = float(avg_hal_score_1) / num_count
|
302 |
+
avg_score = float(avg_hal_score_2) / num_count
|
303 |
+
avg_score = float(avg_det_score_1) / num_count
|
304 |
+
avg_score = float(avg_det_score_2) / num_count
|
305 |
+
print(f"The avg hal score for Assistant 1 and Assistent 2: {avg_hal_score_1}; {avg_hal_score_2}")
|
306 |
+
print(f"The avg det score for Assistant 1 and Assistent 2: {avg_det_score_1}; {avg_det_score_2}")
|
OPERA/log/chair_eval_results/instructblip/beam5.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/instructblip/greedy.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/instructblip/ours.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/llava-1.5/beam5.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/llava-1.5/greedy.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/llava-1.5/ours.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/minigpt4/beam5.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/minigpt4/greedy.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/minigpt4/ours.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/shikra/beam5.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/shikra/greedy.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/log/chair_eval_results/shikra/ours.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
OPERA/minigpt4/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
|
13 |
+
from minigpt4.common.registry import registry
|
14 |
+
|
15 |
+
from minigpt4.datasets.builders import *
|
16 |
+
from minigpt4.models import *
|
17 |
+
from minigpt4.processors import *
|
18 |
+
from minigpt4.tasks import *
|
19 |
+
|
20 |
+
|
21 |
+
root_dir = os.path.dirname(os.path.abspath(__file__))
|
22 |
+
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
|
23 |
+
|
24 |
+
registry.register_path("library_root", root_dir)
|
25 |
+
repo_root = os.path.join(root_dir, "..")
|
26 |
+
registry.register_path("repo_root", repo_root)
|
27 |
+
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
|
28 |
+
registry.register_path("cache_root", cache_root)
|
29 |
+
|
30 |
+
registry.register("MAX_INT", sys.maxsize)
|
31 |
+
registry.register("SPLIT_NAMES", ["train", "val", "test"])
|
OPERA/minigpt4/common/__init__.py
ADDED
File without changes
|
OPERA/minigpt4/common/config.py
ADDED
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import json
|
10 |
+
from typing import Dict
|
11 |
+
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from minigpt4.common.registry import registry
|
14 |
+
|
15 |
+
|
16 |
+
class Config:
|
17 |
+
def __init__(self, args):
|
18 |
+
self.config = {}
|
19 |
+
|
20 |
+
self.args = args
|
21 |
+
|
22 |
+
# Register the config and configuration for setup
|
23 |
+
registry.register("configuration", self)
|
24 |
+
|
25 |
+
user_config = self._build_opt_list(self.args.options)
|
26 |
+
|
27 |
+
config = OmegaConf.load(self.args.cfg_path)
|
28 |
+
|
29 |
+
runner_config = self.build_runner_config(config)
|
30 |
+
model_config = self.build_model_config(config, **user_config)
|
31 |
+
dataset_config = self.build_dataset_config(config)
|
32 |
+
|
33 |
+
# Validate the user-provided runner configuration
|
34 |
+
# model and dataset configuration are supposed to be validated by the respective classes
|
35 |
+
# [TODO] validate the model/dataset configuration
|
36 |
+
# self._validate_runner_config(runner_config)
|
37 |
+
|
38 |
+
# Override the default configuration with user options.
|
39 |
+
self.config = OmegaConf.merge(
|
40 |
+
runner_config, model_config, dataset_config, user_config
|
41 |
+
)
|
42 |
+
|
43 |
+
def _validate_runner_config(self, runner_config):
|
44 |
+
"""
|
45 |
+
This method validates the configuration, such that
|
46 |
+
1) all the user specified options are valid;
|
47 |
+
2) no type mismatches between the user specified options and the config.
|
48 |
+
"""
|
49 |
+
runner_config_validator = create_runner_config_validator()
|
50 |
+
runner_config_validator.validate(runner_config)
|
51 |
+
|
52 |
+
def _build_opt_list(self, opts):
|
53 |
+
opts_dot_list = self._convert_to_dot_list(opts)
|
54 |
+
return OmegaConf.from_dotlist(opts_dot_list)
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def build_model_config(config, **kwargs):
|
58 |
+
model = config.get("model", None)
|
59 |
+
assert model is not None, "Missing model configuration file."
|
60 |
+
|
61 |
+
model_cls = registry.get_model_class(model.arch)
|
62 |
+
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
|
63 |
+
|
64 |
+
model_type = kwargs.get("model.model_type", None)
|
65 |
+
if not model_type:
|
66 |
+
model_type = model.get("model_type", None)
|
67 |
+
# else use the model type selected by user.
|
68 |
+
|
69 |
+
assert model_type is not None, "Missing model_type."
|
70 |
+
|
71 |
+
model_config_path = model_cls.default_config_path(model_type=model_type)
|
72 |
+
|
73 |
+
model_config = OmegaConf.create()
|
74 |
+
# hierarchy override, customized config > default config
|
75 |
+
model_config = OmegaConf.merge(
|
76 |
+
model_config,
|
77 |
+
OmegaConf.load(model_config_path),
|
78 |
+
{"model": config["model"]},
|
79 |
+
)
|
80 |
+
|
81 |
+
return model_config
|
82 |
+
|
83 |
+
@staticmethod
|
84 |
+
def build_runner_config(config):
|
85 |
+
return {"run": config.run}
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
def build_dataset_config(config):
|
89 |
+
datasets = config.get("datasets", None)
|
90 |
+
if datasets is None:
|
91 |
+
raise KeyError(
|
92 |
+
"Expecting 'datasets' as the root key for dataset configuration."
|
93 |
+
)
|
94 |
+
|
95 |
+
dataset_config = OmegaConf.create()
|
96 |
+
|
97 |
+
for dataset_name in datasets:
|
98 |
+
builder_cls = registry.get_builder_class(dataset_name)
|
99 |
+
|
100 |
+
dataset_config_type = datasets[dataset_name].get("type", "default")
|
101 |
+
dataset_config_path = builder_cls.default_config_path(
|
102 |
+
type=dataset_config_type
|
103 |
+
)
|
104 |
+
|
105 |
+
# hierarchy override, customized config > default config
|
106 |
+
dataset_config = OmegaConf.merge(
|
107 |
+
dataset_config,
|
108 |
+
OmegaConf.load(dataset_config_path),
|
109 |
+
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
|
110 |
+
)
|
111 |
+
|
112 |
+
return dataset_config
|
113 |
+
|
114 |
+
def _convert_to_dot_list(self, opts):
|
115 |
+
if opts is None:
|
116 |
+
opts = []
|
117 |
+
|
118 |
+
if len(opts) == 0:
|
119 |
+
return opts
|
120 |
+
|
121 |
+
has_equal = opts[0].find("=") != -1
|
122 |
+
|
123 |
+
if has_equal:
|
124 |
+
return opts
|
125 |
+
|
126 |
+
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
|
127 |
+
|
128 |
+
def get_config(self):
|
129 |
+
return self.config
|
130 |
+
|
131 |
+
@property
|
132 |
+
def run_cfg(self):
|
133 |
+
return self.config.run
|
134 |
+
|
135 |
+
@property
|
136 |
+
def datasets_cfg(self):
|
137 |
+
return self.config.datasets
|
138 |
+
|
139 |
+
@property
|
140 |
+
def model_cfg(self):
|
141 |
+
return self.config.model
|
142 |
+
|
143 |
+
def pretty_print(self):
|
144 |
+
logging.info("\n===== Running Parameters =====")
|
145 |
+
logging.info(self._convert_node_to_json(self.config.run))
|
146 |
+
|
147 |
+
logging.info("\n====== Dataset Attributes ======")
|
148 |
+
datasets = self.config.datasets
|
149 |
+
|
150 |
+
for dataset in datasets:
|
151 |
+
if dataset in self.config.datasets:
|
152 |
+
logging.info(f"\n======== {dataset} =======")
|
153 |
+
dataset_config = self.config.datasets[dataset]
|
154 |
+
logging.info(self._convert_node_to_json(dataset_config))
|
155 |
+
else:
|
156 |
+
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
|
157 |
+
|
158 |
+
logging.info(f"\n====== Model Attributes ======")
|
159 |
+
logging.info(self._convert_node_to_json(self.config.model))
|
160 |
+
|
161 |
+
def _convert_node_to_json(self, node):
|
162 |
+
container = OmegaConf.to_container(node, resolve=True)
|
163 |
+
return json.dumps(container, indent=4, sort_keys=True)
|
164 |
+
|
165 |
+
def to_dict(self):
|
166 |
+
return OmegaConf.to_container(self.config)
|
167 |
+
|
168 |
+
|
169 |
+
def node_to_dict(node):
|
170 |
+
return OmegaConf.to_container(node)
|
171 |
+
|
172 |
+
|
173 |
+
class ConfigValidator:
|
174 |
+
"""
|
175 |
+
This is a preliminary implementation to centralize and validate the configuration.
|
176 |
+
May be altered in the future.
|
177 |
+
|
178 |
+
A helper class to validate configurations from yaml file.
|
179 |
+
|
180 |
+
This serves the following purposes:
|
181 |
+
1. Ensure all the options in the yaml are defined, raise error if not.
|
182 |
+
2. when type mismatches are found, the validator will raise an error.
|
183 |
+
3. a central place to store and display helpful messages for supported configurations.
|
184 |
+
|
185 |
+
"""
|
186 |
+
|
187 |
+
class _Argument:
|
188 |
+
def __init__(self, name, choices=None, type=None, help=None):
|
189 |
+
self.name = name
|
190 |
+
self.val = None
|
191 |
+
self.choices = choices
|
192 |
+
self.type = type
|
193 |
+
self.help = help
|
194 |
+
|
195 |
+
def __str__(self):
|
196 |
+
s = f"{self.name}={self.val}"
|
197 |
+
if self.type is not None:
|
198 |
+
s += f", ({self.type})"
|
199 |
+
if self.choices is not None:
|
200 |
+
s += f", choices: {self.choices}"
|
201 |
+
if self.help is not None:
|
202 |
+
s += f", ({self.help})"
|
203 |
+
return s
|
204 |
+
|
205 |
+
def __init__(self, description):
|
206 |
+
self.description = description
|
207 |
+
|
208 |
+
self.arguments = dict()
|
209 |
+
|
210 |
+
self.parsed_args = None
|
211 |
+
|
212 |
+
def __getitem__(self, key):
|
213 |
+
assert self.parsed_args is not None, "No arguments parsed yet."
|
214 |
+
|
215 |
+
return self.parsed_args[key]
|
216 |
+
|
217 |
+
def __str__(self) -> str:
|
218 |
+
return self.format_help()
|
219 |
+
|
220 |
+
def add_argument(self, *args, **kwargs):
|
221 |
+
"""
|
222 |
+
Assume the first argument is the name of the argument.
|
223 |
+
"""
|
224 |
+
self.arguments[args[0]] = self._Argument(*args, **kwargs)
|
225 |
+
|
226 |
+
def validate(self, config=None):
|
227 |
+
"""
|
228 |
+
Convert yaml config (dict-like) to list, required by argparse.
|
229 |
+
"""
|
230 |
+
for k, v in config.items():
|
231 |
+
assert (
|
232 |
+
k in self.arguments
|
233 |
+
), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
|
234 |
+
|
235 |
+
if self.arguments[k].type is not None:
|
236 |
+
try:
|
237 |
+
self.arguments[k].val = self.arguments[k].type(v)
|
238 |
+
except ValueError:
|
239 |
+
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
|
240 |
+
|
241 |
+
if self.arguments[k].choices is not None:
|
242 |
+
assert (
|
243 |
+
v in self.arguments[k].choices
|
244 |
+
), f"""{k} must be one of {self.arguments[k].choices}."""
|
245 |
+
|
246 |
+
return config
|
247 |
+
|
248 |
+
def format_arguments(self):
|
249 |
+
return str([f"{k}" for k in sorted(self.arguments.keys())])
|
250 |
+
|
251 |
+
def format_help(self):
|
252 |
+
# description + key-value pair string for each argument
|
253 |
+
help_msg = str(self.description)
|
254 |
+
return help_msg + ", available arguments: " + self.format_arguments()
|
255 |
+
|
256 |
+
def print_help(self):
|
257 |
+
# display help message
|
258 |
+
print(self.format_help())
|
259 |
+
|
260 |
+
|
261 |
+
def create_runner_config_validator():
|
262 |
+
validator = ConfigValidator(description="Runner configurations")
|
263 |
+
|
264 |
+
validator.add_argument(
|
265 |
+
"runner",
|
266 |
+
type=str,
|
267 |
+
choices=["runner_base", "runner_iter"],
|
268 |
+
help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
|
269 |
+
runner runs based on iters. Default: runner_base""",
|
270 |
+
)
|
271 |
+
# add argumetns for training dataset ratios
|
272 |
+
validator.add_argument(
|
273 |
+
"train_dataset_ratios",
|
274 |
+
type=Dict[str, float],
|
275 |
+
help="""Ratios of training dataset. This is used in iteration-based runner.
|
276 |
+
Do not support for epoch-based runner because how to define an epoch becomes tricky.
|
277 |
+
Default: None""",
|
278 |
+
)
|
279 |
+
validator.add_argument(
|
280 |
+
"max_iters",
|
281 |
+
type=float,
|
282 |
+
help="Maximum number of iterations to run.",
|
283 |
+
)
|
284 |
+
validator.add_argument(
|
285 |
+
"max_epoch",
|
286 |
+
type=int,
|
287 |
+
help="Maximum number of epochs to run.",
|
288 |
+
)
|
289 |
+
# add arguments for iters_per_inner_epoch
|
290 |
+
validator.add_argument(
|
291 |
+
"iters_per_inner_epoch",
|
292 |
+
type=float,
|
293 |
+
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
|
294 |
+
)
|
295 |
+
lr_scheds_choices = registry.list_lr_schedulers()
|
296 |
+
validator.add_argument(
|
297 |
+
"lr_sched",
|
298 |
+
type=str,
|
299 |
+
choices=lr_scheds_choices,
|
300 |
+
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
|
301 |
+
)
|
302 |
+
task_choices = registry.list_tasks()
|
303 |
+
validator.add_argument(
|
304 |
+
"task",
|
305 |
+
type=str,
|
306 |
+
choices=task_choices,
|
307 |
+
help="Task to use, from {}".format(task_choices),
|
308 |
+
)
|
309 |
+
# add arguments for init_lr
|
310 |
+
validator.add_argument(
|
311 |
+
"init_lr",
|
312 |
+
type=float,
|
313 |
+
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
|
314 |
+
)
|
315 |
+
# add arguments for min_lr
|
316 |
+
validator.add_argument(
|
317 |
+
"min_lr",
|
318 |
+
type=float,
|
319 |
+
help="Minimum learning rate (after decay).",
|
320 |
+
)
|
321 |
+
# add arguments for warmup_lr
|
322 |
+
validator.add_argument(
|
323 |
+
"warmup_lr",
|
324 |
+
type=float,
|
325 |
+
help="Starting learning rate for warmup.",
|
326 |
+
)
|
327 |
+
# add arguments for learning rate decay rate
|
328 |
+
validator.add_argument(
|
329 |
+
"lr_decay_rate",
|
330 |
+
type=float,
|
331 |
+
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
|
332 |
+
)
|
333 |
+
# add arguments for weight decay
|
334 |
+
validator.add_argument(
|
335 |
+
"weight_decay",
|
336 |
+
type=float,
|
337 |
+
help="Weight decay rate.",
|
338 |
+
)
|
339 |
+
# add arguments for training batch size
|
340 |
+
validator.add_argument(
|
341 |
+
"batch_size_train",
|
342 |
+
type=int,
|
343 |
+
help="Training batch size.",
|
344 |
+
)
|
345 |
+
# add arguments for evaluation batch size
|
346 |
+
validator.add_argument(
|
347 |
+
"batch_size_eval",
|
348 |
+
type=int,
|
349 |
+
help="Evaluation batch size, including validation and testing.",
|
350 |
+
)
|
351 |
+
# add arguments for number of workers for data loading
|
352 |
+
validator.add_argument(
|
353 |
+
"num_workers",
|
354 |
+
help="Number of workers for data loading.",
|
355 |
+
)
|
356 |
+
# add arguments for warm up steps
|
357 |
+
validator.add_argument(
|
358 |
+
"warmup_steps",
|
359 |
+
type=int,
|
360 |
+
help="Number of warmup steps. Required if a warmup schedule is used.",
|
361 |
+
)
|
362 |
+
# add arguments for random seed
|
363 |
+
validator.add_argument(
|
364 |
+
"seed",
|
365 |
+
type=int,
|
366 |
+
help="Random seed.",
|
367 |
+
)
|
368 |
+
# add arguments for output directory
|
369 |
+
validator.add_argument(
|
370 |
+
"output_dir",
|
371 |
+
type=str,
|
372 |
+
help="Output directory to save checkpoints and logs.",
|
373 |
+
)
|
374 |
+
# add arguments for whether only use evaluation
|
375 |
+
validator.add_argument(
|
376 |
+
"evaluate",
|
377 |
+
help="Whether to only evaluate the model. If true, training will not be performed.",
|
378 |
+
)
|
379 |
+
# add arguments for splits used for training, e.g. ["train", "val"]
|
380 |
+
validator.add_argument(
|
381 |
+
"train_splits",
|
382 |
+
type=list,
|
383 |
+
help="Splits to use for training.",
|
384 |
+
)
|
385 |
+
# add arguments for splits used for validation, e.g. ["val"]
|
386 |
+
validator.add_argument(
|
387 |
+
"valid_splits",
|
388 |
+
type=list,
|
389 |
+
help="Splits to use for validation. If not provided, will skip the validation.",
|
390 |
+
)
|
391 |
+
# add arguments for splits used for testing, e.g. ["test"]
|
392 |
+
validator.add_argument(
|
393 |
+
"test_splits",
|
394 |
+
type=list,
|
395 |
+
help="Splits to use for testing. If not provided, will skip the testing.",
|
396 |
+
)
|
397 |
+
# add arguments for accumulating gradient for iterations
|
398 |
+
validator.add_argument(
|
399 |
+
"accum_grad_iters",
|
400 |
+
type=int,
|
401 |
+
help="Number of iterations to accumulate gradient for.",
|
402 |
+
)
|
403 |
+
|
404 |
+
# ====== distributed training ======
|
405 |
+
validator.add_argument(
|
406 |
+
"device",
|
407 |
+
type=str,
|
408 |
+
choices=["cpu", "cuda"],
|
409 |
+
help="Device to use. Support 'cuda' or 'cpu' as for now.",
|
410 |
+
)
|
411 |
+
validator.add_argument(
|
412 |
+
"world_size",
|
413 |
+
type=int,
|
414 |
+
help="Number of processes participating in the job.",
|
415 |
+
)
|
416 |
+
validator.add_argument("dist_url", type=str)
|
417 |
+
validator.add_argument("distributed", type=bool)
|
418 |
+
# add arguments to opt using distributed sampler during evaluation or not
|
419 |
+
validator.add_argument(
|
420 |
+
"use_dist_eval_sampler",
|
421 |
+
type=bool,
|
422 |
+
help="Whether to use distributed sampler during evaluation or not.",
|
423 |
+
)
|
424 |
+
|
425 |
+
# ====== task specific ======
|
426 |
+
# generation task specific arguments
|
427 |
+
# add arguments for maximal length of text output
|
428 |
+
validator.add_argument(
|
429 |
+
"max_len",
|
430 |
+
type=int,
|
431 |
+
help="Maximal length of text output.",
|
432 |
+
)
|
433 |
+
# add arguments for minimal length of text output
|
434 |
+
validator.add_argument(
|
435 |
+
"min_len",
|
436 |
+
type=int,
|
437 |
+
help="Minimal length of text output.",
|
438 |
+
)
|
439 |
+
# add arguments number of beams
|
440 |
+
validator.add_argument(
|
441 |
+
"num_beams",
|
442 |
+
type=int,
|
443 |
+
help="Number of beams used for beam search.",
|
444 |
+
)
|
445 |
+
|
446 |
+
# vqa task specific arguments
|
447 |
+
# add arguments for number of answer candidates
|
448 |
+
validator.add_argument(
|
449 |
+
"num_ans_candidates",
|
450 |
+
type=int,
|
451 |
+
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
|
452 |
+
)
|
453 |
+
# add arguments for inference method
|
454 |
+
validator.add_argument(
|
455 |
+
"inference_method",
|
456 |
+
type=str,
|
457 |
+
choices=["genearte", "rank"],
|
458 |
+
help="""Inference method to use for question answering. If rank, requires a answer list.""",
|
459 |
+
)
|
460 |
+
|
461 |
+
# ====== model specific ======
|
462 |
+
validator.add_argument(
|
463 |
+
"k_test",
|
464 |
+
type=int,
|
465 |
+
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
|
466 |
+
)
|
467 |
+
|
468 |
+
return validator
|
OPERA/minigpt4/common/dist_utils.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import functools
|
10 |
+
import os
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
import timm.models.hub as timm_hub
|
15 |
+
|
16 |
+
|
17 |
+
def setup_for_distributed(is_master):
|
18 |
+
"""
|
19 |
+
This function disables printing when not in master process
|
20 |
+
"""
|
21 |
+
import builtins as __builtin__
|
22 |
+
|
23 |
+
builtin_print = __builtin__.print
|
24 |
+
|
25 |
+
def print(*args, **kwargs):
|
26 |
+
force = kwargs.pop("force", False)
|
27 |
+
if is_master or force:
|
28 |
+
builtin_print(*args, **kwargs)
|
29 |
+
|
30 |
+
__builtin__.print = print
|
31 |
+
|
32 |
+
|
33 |
+
def is_dist_avail_and_initialized():
|
34 |
+
if not dist.is_available():
|
35 |
+
return False
|
36 |
+
if not dist.is_initialized():
|
37 |
+
return False
|
38 |
+
return True
|
39 |
+
|
40 |
+
|
41 |
+
def get_world_size():
|
42 |
+
if not is_dist_avail_and_initialized():
|
43 |
+
return 1
|
44 |
+
return dist.get_world_size()
|
45 |
+
|
46 |
+
|
47 |
+
def get_rank():
|
48 |
+
if not is_dist_avail_and_initialized():
|
49 |
+
return 0
|
50 |
+
return dist.get_rank()
|
51 |
+
|
52 |
+
|
53 |
+
def is_main_process():
|
54 |
+
return get_rank() == 0
|
55 |
+
|
56 |
+
|
57 |
+
def init_distributed_mode(args):
|
58 |
+
if args.distributed is False:
|
59 |
+
print("Not using distributed mode")
|
60 |
+
return
|
61 |
+
elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
62 |
+
args.rank = int(os.environ["RANK"])
|
63 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
64 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
65 |
+
elif "SLURM_PROCID" in os.environ:
|
66 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
67 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
68 |
+
else:
|
69 |
+
print("Not using distributed mode")
|
70 |
+
args.distributed = False
|
71 |
+
return
|
72 |
+
|
73 |
+
args.distributed = True
|
74 |
+
|
75 |
+
torch.cuda.set_device(args.gpu)
|
76 |
+
args.dist_backend = "nccl"
|
77 |
+
print(
|
78 |
+
"| distributed init (rank {}, world {}): {}".format(
|
79 |
+
args.rank, args.world_size, args.dist_url
|
80 |
+
),
|
81 |
+
flush=True,
|
82 |
+
)
|
83 |
+
torch.distributed.init_process_group(
|
84 |
+
backend=args.dist_backend,
|
85 |
+
init_method=args.dist_url,
|
86 |
+
world_size=args.world_size,
|
87 |
+
rank=args.rank,
|
88 |
+
timeout=datetime.timedelta(
|
89 |
+
days=365
|
90 |
+
), # allow auto-downloading and de-compressing
|
91 |
+
)
|
92 |
+
torch.distributed.barrier()
|
93 |
+
setup_for_distributed(args.rank == 0)
|
94 |
+
|
95 |
+
|
96 |
+
def get_dist_info():
|
97 |
+
if torch.__version__ < "1.0":
|
98 |
+
initialized = dist._initialized
|
99 |
+
else:
|
100 |
+
initialized = dist.is_initialized()
|
101 |
+
if initialized:
|
102 |
+
rank = dist.get_rank()
|
103 |
+
world_size = dist.get_world_size()
|
104 |
+
else: # non-distributed training
|
105 |
+
rank = 0
|
106 |
+
world_size = 1
|
107 |
+
return rank, world_size
|
108 |
+
|
109 |
+
|
110 |
+
def main_process(func):
|
111 |
+
@functools.wraps(func)
|
112 |
+
def wrapper(*args, **kwargs):
|
113 |
+
rank, _ = get_dist_info()
|
114 |
+
if rank == 0:
|
115 |
+
return func(*args, **kwargs)
|
116 |
+
|
117 |
+
return wrapper
|
118 |
+
|
119 |
+
|
120 |
+
def download_cached_file(url, check_hash=True, progress=False):
|
121 |
+
"""
|
122 |
+
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
123 |
+
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
124 |
+
"""
|
125 |
+
|
126 |
+
def get_cached_file_path():
|
127 |
+
# a hack to sync the file path across processes
|
128 |
+
parts = torch.hub.urlparse(url)
|
129 |
+
filename = os.path.basename(parts.path)
|
130 |
+
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
131 |
+
|
132 |
+
return cached_file
|
133 |
+
|
134 |
+
if is_main_process():
|
135 |
+
timm_hub.download_cached_file(url, check_hash, progress)
|
136 |
+
|
137 |
+
if is_dist_avail_and_initialized():
|
138 |
+
dist.barrier()
|
139 |
+
|
140 |
+
return get_cached_file_path()
|
141 |
+
|
142 |
+
|
143 |
+
def all_reduce_mean(x):
|
144 |
+
world_size = get_world_size()
|
145 |
+
if world_size > 1:
|
146 |
+
x_reduce = torch.tensor(x).cuda()
|
147 |
+
dist.all_reduce(x_reduce)
|
148 |
+
x_reduce /= world_size
|
149 |
+
return x_reduce.item()
|
150 |
+
else:
|
151 |
+
return x
|
OPERA/minigpt4/common/gradcam.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from matplotlib import pyplot as plt
|
3 |
+
from scipy.ndimage import filters
|
4 |
+
from skimage import transform as skimage_transform
|
5 |
+
|
6 |
+
|
7 |
+
def getAttMap(img, attMap, blur=True, overlap=True):
|
8 |
+
attMap -= attMap.min()
|
9 |
+
if attMap.max() > 0:
|
10 |
+
attMap /= attMap.max()
|
11 |
+
attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
|
12 |
+
if blur:
|
13 |
+
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
|
14 |
+
attMap -= attMap.min()
|
15 |
+
attMap /= attMap.max()
|
16 |
+
cmap = plt.get_cmap("jet")
|
17 |
+
attMapV = cmap(attMap)
|
18 |
+
attMapV = np.delete(attMapV, 3, 2)
|
19 |
+
if overlap:
|
20 |
+
attMap = (
|
21 |
+
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
|
22 |
+
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
|
23 |
+
)
|
24 |
+
return attMap
|
OPERA/minigpt4/common/logger.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import logging
|
10 |
+
import time
|
11 |
+
from collections import defaultdict, deque
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
|
16 |
+
from minigpt4.common import dist_utils
|
17 |
+
|
18 |
+
|
19 |
+
class SmoothedValue(object):
|
20 |
+
"""Track a series of values and provide access to smoothed values over a
|
21 |
+
window or the global series average.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, window_size=20, fmt=None):
|
25 |
+
if fmt is None:
|
26 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
27 |
+
self.deque = deque(maxlen=window_size)
|
28 |
+
self.total = 0.0
|
29 |
+
self.count = 0
|
30 |
+
self.fmt = fmt
|
31 |
+
|
32 |
+
def update(self, value, n=1):
|
33 |
+
self.deque.append(value)
|
34 |
+
self.count += n
|
35 |
+
self.total += value * n
|
36 |
+
|
37 |
+
def synchronize_between_processes(self):
|
38 |
+
"""
|
39 |
+
Warning: does not synchronize the deque!
|
40 |
+
"""
|
41 |
+
if not dist_utils.is_dist_avail_and_initialized():
|
42 |
+
return
|
43 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
44 |
+
dist.barrier()
|
45 |
+
dist.all_reduce(t)
|
46 |
+
t = t.tolist()
|
47 |
+
self.count = int(t[0])
|
48 |
+
self.total = t[1]
|
49 |
+
|
50 |
+
@property
|
51 |
+
def median(self):
|
52 |
+
d = torch.tensor(list(self.deque))
|
53 |
+
return d.median().item()
|
54 |
+
|
55 |
+
@property
|
56 |
+
def avg(self):
|
57 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
58 |
+
return d.mean().item()
|
59 |
+
|
60 |
+
@property
|
61 |
+
def global_avg(self):
|
62 |
+
return self.total / self.count
|
63 |
+
|
64 |
+
@property
|
65 |
+
def max(self):
|
66 |
+
return max(self.deque)
|
67 |
+
|
68 |
+
@property
|
69 |
+
def value(self):
|
70 |
+
return self.deque[-1]
|
71 |
+
|
72 |
+
def __str__(self):
|
73 |
+
return self.fmt.format(
|
74 |
+
median=self.median,
|
75 |
+
avg=self.avg,
|
76 |
+
global_avg=self.global_avg,
|
77 |
+
max=self.max,
|
78 |
+
value=self.value,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
class MetricLogger(object):
|
83 |
+
def __init__(self, delimiter="\t"):
|
84 |
+
self.meters = defaultdict(SmoothedValue)
|
85 |
+
self.delimiter = delimiter
|
86 |
+
|
87 |
+
def update(self, **kwargs):
|
88 |
+
for k, v in kwargs.items():
|
89 |
+
if isinstance(v, torch.Tensor):
|
90 |
+
v = v.item()
|
91 |
+
assert isinstance(v, (float, int))
|
92 |
+
self.meters[k].update(v)
|
93 |
+
|
94 |
+
def __getattr__(self, attr):
|
95 |
+
if attr in self.meters:
|
96 |
+
return self.meters[attr]
|
97 |
+
if attr in self.__dict__:
|
98 |
+
return self.__dict__[attr]
|
99 |
+
raise AttributeError(
|
100 |
+
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
101 |
+
)
|
102 |
+
|
103 |
+
def __str__(self):
|
104 |
+
loss_str = []
|
105 |
+
for name, meter in self.meters.items():
|
106 |
+
loss_str.append("{}: {}".format(name, str(meter)))
|
107 |
+
return self.delimiter.join(loss_str)
|
108 |
+
|
109 |
+
def global_avg(self):
|
110 |
+
loss_str = []
|
111 |
+
for name, meter in self.meters.items():
|
112 |
+
loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
|
113 |
+
return self.delimiter.join(loss_str)
|
114 |
+
|
115 |
+
def synchronize_between_processes(self):
|
116 |
+
for meter in self.meters.values():
|
117 |
+
meter.synchronize_between_processes()
|
118 |
+
|
119 |
+
def add_meter(self, name, meter):
|
120 |
+
self.meters[name] = meter
|
121 |
+
|
122 |
+
def log_every(self, iterable, print_freq, header=None):
|
123 |
+
i = 0
|
124 |
+
if not header:
|
125 |
+
header = ""
|
126 |
+
start_time = time.time()
|
127 |
+
end = time.time()
|
128 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
129 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
130 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
131 |
+
log_msg = [
|
132 |
+
header,
|
133 |
+
"[{0" + space_fmt + "}/{1}]",
|
134 |
+
"eta: {eta}",
|
135 |
+
"{meters}",
|
136 |
+
"time: {time}",
|
137 |
+
"data: {data}",
|
138 |
+
]
|
139 |
+
if torch.cuda.is_available():
|
140 |
+
log_msg.append("max mem: {memory:.0f}")
|
141 |
+
log_msg = self.delimiter.join(log_msg)
|
142 |
+
MB = 1024.0 * 1024.0
|
143 |
+
for obj in iterable:
|
144 |
+
data_time.update(time.time() - end)
|
145 |
+
yield obj
|
146 |
+
iter_time.update(time.time() - end)
|
147 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
148 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
149 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
150 |
+
if torch.cuda.is_available():
|
151 |
+
print(
|
152 |
+
log_msg.format(
|
153 |
+
i,
|
154 |
+
len(iterable),
|
155 |
+
eta=eta_string,
|
156 |
+
meters=str(self),
|
157 |
+
time=str(iter_time),
|
158 |
+
data=str(data_time),
|
159 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
160 |
+
)
|
161 |
+
)
|
162 |
+
else:
|
163 |
+
print(
|
164 |
+
log_msg.format(
|
165 |
+
i,
|
166 |
+
len(iterable),
|
167 |
+
eta=eta_string,
|
168 |
+
meters=str(self),
|
169 |
+
time=str(iter_time),
|
170 |
+
data=str(data_time),
|
171 |
+
)
|
172 |
+
)
|
173 |
+
i += 1
|
174 |
+
end = time.time()
|
175 |
+
total_time = time.time() - start_time
|
176 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
177 |
+
print(
|
178 |
+
"{} Total time: {} ({:.4f} s / it)".format(
|
179 |
+
header, total_time_str, total_time / len(iterable)
|
180 |
+
)
|
181 |
+
)
|
182 |
+
|
183 |
+
|
184 |
+
class AttrDict(dict):
|
185 |
+
def __init__(self, *args, **kwargs):
|
186 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
187 |
+
self.__dict__ = self
|
188 |
+
|
189 |
+
|
190 |
+
def setup_logger():
|
191 |
+
logging.basicConfig(
|
192 |
+
level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
|
193 |
+
format="%(asctime)s [%(levelname)s] %(message)s",
|
194 |
+
handlers=[logging.StreamHandler()],
|
195 |
+
)
|
OPERA/minigpt4/common/optims.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
from minigpt4.common.registry import registry
|
11 |
+
|
12 |
+
|
13 |
+
@registry.register_lr_scheduler("linear_warmup_step_lr")
|
14 |
+
class LinearWarmupStepLRScheduler:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
optimizer,
|
18 |
+
max_epoch,
|
19 |
+
min_lr,
|
20 |
+
init_lr,
|
21 |
+
decay_rate=1,
|
22 |
+
warmup_start_lr=-1,
|
23 |
+
warmup_steps=0,
|
24 |
+
**kwargs
|
25 |
+
):
|
26 |
+
self.optimizer = optimizer
|
27 |
+
|
28 |
+
self.max_epoch = max_epoch
|
29 |
+
self.min_lr = min_lr
|
30 |
+
|
31 |
+
self.decay_rate = decay_rate
|
32 |
+
|
33 |
+
self.init_lr = init_lr
|
34 |
+
self.warmup_steps = warmup_steps
|
35 |
+
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
36 |
+
|
37 |
+
def step(self, cur_epoch, cur_step):
|
38 |
+
if cur_epoch == 0:
|
39 |
+
warmup_lr_schedule(
|
40 |
+
step=cur_step,
|
41 |
+
optimizer=self.optimizer,
|
42 |
+
max_step=self.warmup_steps,
|
43 |
+
init_lr=self.warmup_start_lr,
|
44 |
+
max_lr=self.init_lr,
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
step_lr_schedule(
|
48 |
+
epoch=cur_epoch,
|
49 |
+
optimizer=self.optimizer,
|
50 |
+
init_lr=self.init_lr,
|
51 |
+
min_lr=self.min_lr,
|
52 |
+
decay_rate=self.decay_rate,
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
|
57 |
+
class LinearWarmupCosineLRScheduler:
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
optimizer,
|
61 |
+
max_epoch,
|
62 |
+
iters_per_epoch,
|
63 |
+
min_lr,
|
64 |
+
init_lr,
|
65 |
+
warmup_steps=0,
|
66 |
+
warmup_start_lr=-1,
|
67 |
+
**kwargs
|
68 |
+
):
|
69 |
+
self.optimizer = optimizer
|
70 |
+
|
71 |
+
self.max_epoch = max_epoch
|
72 |
+
self.iters_per_epoch = iters_per_epoch
|
73 |
+
self.min_lr = min_lr
|
74 |
+
|
75 |
+
self.init_lr = init_lr
|
76 |
+
self.warmup_steps = warmup_steps
|
77 |
+
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
78 |
+
|
79 |
+
def step(self, cur_epoch, cur_step):
|
80 |
+
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
|
81 |
+
if total_cur_step < self.warmup_steps:
|
82 |
+
warmup_lr_schedule(
|
83 |
+
step=cur_step,
|
84 |
+
optimizer=self.optimizer,
|
85 |
+
max_step=self.warmup_steps,
|
86 |
+
init_lr=self.warmup_start_lr,
|
87 |
+
max_lr=self.init_lr,
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
cosine_lr_schedule(
|
91 |
+
epoch=total_cur_step,
|
92 |
+
optimizer=self.optimizer,
|
93 |
+
max_epoch=self.max_epoch * self.iters_per_epoch,
|
94 |
+
init_lr=self.init_lr,
|
95 |
+
min_lr=self.min_lr,
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
|
100 |
+
"""Decay the learning rate"""
|
101 |
+
lr = (init_lr - min_lr) * 0.5 * (
|
102 |
+
1.0 + math.cos(math.pi * epoch / max_epoch)
|
103 |
+
) + min_lr
|
104 |
+
for param_group in optimizer.param_groups:
|
105 |
+
param_group["lr"] = lr
|
106 |
+
|
107 |
+
|
108 |
+
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
|
109 |
+
"""Warmup the learning rate"""
|
110 |
+
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
|
111 |
+
for param_group in optimizer.param_groups:
|
112 |
+
param_group["lr"] = lr
|
113 |
+
|
114 |
+
|
115 |
+
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
|
116 |
+
"""Decay the learning rate"""
|
117 |
+
lr = max(min_lr, init_lr * (decay_rate**epoch))
|
118 |
+
for param_group in optimizer.param_groups:
|
119 |
+
param_group["lr"] = lr
|
OPERA/minigpt4/common/registry.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
+
class Registry:
|
10 |
+
mapping = {
|
11 |
+
"builder_name_mapping": {},
|
12 |
+
"task_name_mapping": {},
|
13 |
+
"processor_name_mapping": {},
|
14 |
+
"model_name_mapping": {},
|
15 |
+
"lr_scheduler_name_mapping": {},
|
16 |
+
"runner_name_mapping": {},
|
17 |
+
"state": {},
|
18 |
+
"paths": {},
|
19 |
+
}
|
20 |
+
|
21 |
+
@classmethod
|
22 |
+
def register_builder(cls, name):
|
23 |
+
r"""Register a dataset builder to registry with key 'name'
|
24 |
+
|
25 |
+
Args:
|
26 |
+
name: Key with which the builder will be registered.
|
27 |
+
|
28 |
+
Usage:
|
29 |
+
|
30 |
+
from minigpt4.common.registry import registry
|
31 |
+
from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
|
32 |
+
"""
|
33 |
+
|
34 |
+
def wrap(builder_cls):
|
35 |
+
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
36 |
+
|
37 |
+
assert issubclass(
|
38 |
+
builder_cls, BaseDatasetBuilder
|
39 |
+
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
|
40 |
+
builder_cls
|
41 |
+
)
|
42 |
+
if name in cls.mapping["builder_name_mapping"]:
|
43 |
+
raise KeyError(
|
44 |
+
"Name '{}' already registered for {}.".format(
|
45 |
+
name, cls.mapping["builder_name_mapping"][name]
|
46 |
+
)
|
47 |
+
)
|
48 |
+
cls.mapping["builder_name_mapping"][name] = builder_cls
|
49 |
+
return builder_cls
|
50 |
+
|
51 |
+
return wrap
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def register_task(cls, name):
|
55 |
+
r"""Register a task to registry with key 'name'
|
56 |
+
|
57 |
+
Args:
|
58 |
+
name: Key with which the task will be registered.
|
59 |
+
|
60 |
+
Usage:
|
61 |
+
|
62 |
+
from minigpt4.common.registry import registry
|
63 |
+
"""
|
64 |
+
|
65 |
+
def wrap(task_cls):
|
66 |
+
from minigpt4.tasks.base_task import BaseTask
|
67 |
+
|
68 |
+
assert issubclass(
|
69 |
+
task_cls, BaseTask
|
70 |
+
), "All tasks must inherit BaseTask class"
|
71 |
+
if name in cls.mapping["task_name_mapping"]:
|
72 |
+
raise KeyError(
|
73 |
+
"Name '{}' already registered for {}.".format(
|
74 |
+
name, cls.mapping["task_name_mapping"][name]
|
75 |
+
)
|
76 |
+
)
|
77 |
+
cls.mapping["task_name_mapping"][name] = task_cls
|
78 |
+
return task_cls
|
79 |
+
|
80 |
+
return wrap
|
81 |
+
|
82 |
+
@classmethod
|
83 |
+
def register_model(cls, name):
|
84 |
+
r"""Register a task to registry with key 'name'
|
85 |
+
|
86 |
+
Args:
|
87 |
+
name: Key with which the task will be registered.
|
88 |
+
|
89 |
+
Usage:
|
90 |
+
|
91 |
+
from minigpt4.common.registry import registry
|
92 |
+
"""
|
93 |
+
|
94 |
+
def wrap(model_cls):
|
95 |
+
from minigpt4.models import BaseModel
|
96 |
+
|
97 |
+
assert issubclass(
|
98 |
+
model_cls, BaseModel
|
99 |
+
), "All models must inherit BaseModel class"
|
100 |
+
if name in cls.mapping["model_name_mapping"]:
|
101 |
+
raise KeyError(
|
102 |
+
"Name '{}' already registered for {}.".format(
|
103 |
+
name, cls.mapping["model_name_mapping"][name]
|
104 |
+
)
|
105 |
+
)
|
106 |
+
cls.mapping["model_name_mapping"][name] = model_cls
|
107 |
+
return model_cls
|
108 |
+
|
109 |
+
return wrap
|
110 |
+
|
111 |
+
@classmethod
|
112 |
+
def register_processor(cls, name):
|
113 |
+
r"""Register a processor to registry with key 'name'
|
114 |
+
|
115 |
+
Args:
|
116 |
+
name: Key with which the task will be registered.
|
117 |
+
|
118 |
+
Usage:
|
119 |
+
|
120 |
+
from minigpt4.common.registry import registry
|
121 |
+
"""
|
122 |
+
|
123 |
+
def wrap(processor_cls):
|
124 |
+
from minigpt4.processors import BaseProcessor
|
125 |
+
|
126 |
+
assert issubclass(
|
127 |
+
processor_cls, BaseProcessor
|
128 |
+
), "All processors must inherit BaseProcessor class"
|
129 |
+
if name in cls.mapping["processor_name_mapping"]:
|
130 |
+
raise KeyError(
|
131 |
+
"Name '{}' already registered for {}.".format(
|
132 |
+
name, cls.mapping["processor_name_mapping"][name]
|
133 |
+
)
|
134 |
+
)
|
135 |
+
cls.mapping["processor_name_mapping"][name] = processor_cls
|
136 |
+
return processor_cls
|
137 |
+
|
138 |
+
return wrap
|
139 |
+
|
140 |
+
@classmethod
|
141 |
+
def register_lr_scheduler(cls, name):
|
142 |
+
r"""Register a model to registry with key 'name'
|
143 |
+
|
144 |
+
Args:
|
145 |
+
name: Key with which the task will be registered.
|
146 |
+
|
147 |
+
Usage:
|
148 |
+
|
149 |
+
from minigpt4.common.registry import registry
|
150 |
+
"""
|
151 |
+
|
152 |
+
def wrap(lr_sched_cls):
|
153 |
+
if name in cls.mapping["lr_scheduler_name_mapping"]:
|
154 |
+
raise KeyError(
|
155 |
+
"Name '{}' already registered for {}.".format(
|
156 |
+
name, cls.mapping["lr_scheduler_name_mapping"][name]
|
157 |
+
)
|
158 |
+
)
|
159 |
+
cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
|
160 |
+
return lr_sched_cls
|
161 |
+
|
162 |
+
return wrap
|
163 |
+
|
164 |
+
@classmethod
|
165 |
+
def register_runner(cls, name):
|
166 |
+
r"""Register a model to registry with key 'name'
|
167 |
+
|
168 |
+
Args:
|
169 |
+
name: Key with which the task will be registered.
|
170 |
+
|
171 |
+
Usage:
|
172 |
+
|
173 |
+
from minigpt4.common.registry import registry
|
174 |
+
"""
|
175 |
+
|
176 |
+
def wrap(runner_cls):
|
177 |
+
if name in cls.mapping["runner_name_mapping"]:
|
178 |
+
raise KeyError(
|
179 |
+
"Name '{}' already registered for {}.".format(
|
180 |
+
name, cls.mapping["runner_name_mapping"][name]
|
181 |
+
)
|
182 |
+
)
|
183 |
+
cls.mapping["runner_name_mapping"][name] = runner_cls
|
184 |
+
return runner_cls
|
185 |
+
|
186 |
+
return wrap
|
187 |
+
|
188 |
+
@classmethod
|
189 |
+
def register_path(cls, name, path):
|
190 |
+
r"""Register a path to registry with key 'name'
|
191 |
+
|
192 |
+
Args:
|
193 |
+
name: Key with which the path will be registered.
|
194 |
+
|
195 |
+
Usage:
|
196 |
+
|
197 |
+
from minigpt4.common.registry import registry
|
198 |
+
"""
|
199 |
+
assert isinstance(path, str), "All path must be str."
|
200 |
+
if name in cls.mapping["paths"]:
|
201 |
+
raise KeyError("Name '{}' already registered.".format(name))
|
202 |
+
cls.mapping["paths"][name] = path
|
203 |
+
|
204 |
+
@classmethod
|
205 |
+
def register(cls, name, obj):
|
206 |
+
r"""Register an item to registry with key 'name'
|
207 |
+
|
208 |
+
Args:
|
209 |
+
name: Key with which the item will be registered.
|
210 |
+
|
211 |
+
Usage::
|
212 |
+
|
213 |
+
from minigpt4.common.registry import registry
|
214 |
+
|
215 |
+
registry.register("config", {})
|
216 |
+
"""
|
217 |
+
path = name.split(".")
|
218 |
+
current = cls.mapping["state"]
|
219 |
+
|
220 |
+
for part in path[:-1]:
|
221 |
+
if part not in current:
|
222 |
+
current[part] = {}
|
223 |
+
current = current[part]
|
224 |
+
|
225 |
+
current[path[-1]] = obj
|
226 |
+
|
227 |
+
# @classmethod
|
228 |
+
# def get_trainer_class(cls, name):
|
229 |
+
# return cls.mapping["trainer_name_mapping"].get(name, None)
|
230 |
+
|
231 |
+
@classmethod
|
232 |
+
def get_builder_class(cls, name):
|
233 |
+
return cls.mapping["builder_name_mapping"].get(name, None)
|
234 |
+
|
235 |
+
@classmethod
|
236 |
+
def get_model_class(cls, name):
|
237 |
+
return cls.mapping["model_name_mapping"].get(name, None)
|
238 |
+
|
239 |
+
@classmethod
|
240 |
+
def get_task_class(cls, name):
|
241 |
+
return cls.mapping["task_name_mapping"].get(name, None)
|
242 |
+
|
243 |
+
@classmethod
|
244 |
+
def get_processor_class(cls, name):
|
245 |
+
return cls.mapping["processor_name_mapping"].get(name, None)
|
246 |
+
|
247 |
+
@classmethod
|
248 |
+
def get_lr_scheduler_class(cls, name):
|
249 |
+
return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
|
250 |
+
|
251 |
+
@classmethod
|
252 |
+
def get_runner_class(cls, name):
|
253 |
+
return cls.mapping["runner_name_mapping"].get(name, None)
|
254 |
+
|
255 |
+
@classmethod
|
256 |
+
def list_runners(cls):
|
257 |
+
return sorted(cls.mapping["runner_name_mapping"].keys())
|
258 |
+
|
259 |
+
@classmethod
|
260 |
+
def list_models(cls):
|
261 |
+
return sorted(cls.mapping["model_name_mapping"].keys())
|
262 |
+
|
263 |
+
@classmethod
|
264 |
+
def list_tasks(cls):
|
265 |
+
return sorted(cls.mapping["task_name_mapping"].keys())
|
266 |
+
|
267 |
+
@classmethod
|
268 |
+
def list_processors(cls):
|
269 |
+
return sorted(cls.mapping["processor_name_mapping"].keys())
|
270 |
+
|
271 |
+
@classmethod
|
272 |
+
def list_lr_schedulers(cls):
|
273 |
+
return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
|
274 |
+
|
275 |
+
@classmethod
|
276 |
+
def list_datasets(cls):
|
277 |
+
return sorted(cls.mapping["builder_name_mapping"].keys())
|
278 |
+
|
279 |
+
@classmethod
|
280 |
+
def get_path(cls, name):
|
281 |
+
return cls.mapping["paths"].get(name, None)
|
282 |
+
|
283 |
+
@classmethod
|
284 |
+
def get(cls, name, default=None, no_warning=False):
|
285 |
+
r"""Get an item from registry with key 'name'
|
286 |
+
|
287 |
+
Args:
|
288 |
+
name (string): Key whose value needs to be retrieved.
|
289 |
+
default: If passed and key is not in registry, default value will
|
290 |
+
be returned with a warning. Default: None
|
291 |
+
no_warning (bool): If passed as True, warning when key doesn't exist
|
292 |
+
will not be generated. Useful for MMF's
|
293 |
+
internal operations. Default: False
|
294 |
+
"""
|
295 |
+
original_name = name
|
296 |
+
name = name.split(".")
|
297 |
+
value = cls.mapping["state"]
|
298 |
+
for subname in name:
|
299 |
+
value = value.get(subname, default)
|
300 |
+
if value is default:
|
301 |
+
break
|
302 |
+
|
303 |
+
if (
|
304 |
+
"writer" in cls.mapping["state"]
|
305 |
+
and value == default
|
306 |
+
and no_warning is False
|
307 |
+
):
|
308 |
+
cls.mapping["state"]["writer"].warning(
|
309 |
+
"Key {} is not present in registry, returning default value "
|
310 |
+
"of {}".format(original_name, default)
|
311 |
+
)
|
312 |
+
return value
|
313 |
+
|
314 |
+
@classmethod
|
315 |
+
def unregister(cls, name):
|
316 |
+
r"""Remove an item from registry with key 'name'
|
317 |
+
|
318 |
+
Args:
|
319 |
+
name: Key which needs to be removed.
|
320 |
+
Usage::
|
321 |
+
|
322 |
+
from mmf.common.registry import registry
|
323 |
+
|
324 |
+
config = registry.unregister("config")
|
325 |
+
"""
|
326 |
+
return cls.mapping["state"].pop(name, None)
|
327 |
+
|
328 |
+
|
329 |
+
registry = Registry()
|
OPERA/minigpt4/common/utils.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import io
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import pickle
|
13 |
+
import re
|
14 |
+
import shutil
|
15 |
+
import urllib
|
16 |
+
import urllib.error
|
17 |
+
import urllib.request
|
18 |
+
from typing import Optional
|
19 |
+
from urllib.parse import urlparse
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import pandas as pd
|
23 |
+
import yaml
|
24 |
+
from iopath.common.download import download
|
25 |
+
from iopath.common.file_io import file_lock, g_pathmgr
|
26 |
+
from minigpt4.common.registry import registry
|
27 |
+
from torch.utils.model_zoo import tqdm
|
28 |
+
from torchvision.datasets.utils import (
|
29 |
+
check_integrity,
|
30 |
+
download_file_from_google_drive,
|
31 |
+
extract_archive,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def now():
|
36 |
+
from datetime import datetime
|
37 |
+
|
38 |
+
return datetime.now().strftime("%Y%m%d%H%M")[:-1]
|
39 |
+
|
40 |
+
|
41 |
+
def is_url(url_or_filename):
|
42 |
+
parsed = urlparse(url_or_filename)
|
43 |
+
return parsed.scheme in ("http", "https")
|
44 |
+
|
45 |
+
|
46 |
+
def get_cache_path(rel_path):
|
47 |
+
return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
|
48 |
+
|
49 |
+
|
50 |
+
def get_abs_path(rel_path):
|
51 |
+
return os.path.join(registry.get_path("library_root"), rel_path)
|
52 |
+
|
53 |
+
|
54 |
+
def load_json(filename):
|
55 |
+
with open(filename, "r") as f:
|
56 |
+
return json.load(f)
|
57 |
+
|
58 |
+
|
59 |
+
# The following are adapted from torchvision and vissl
|
60 |
+
# torchvision: https://github.com/pytorch/vision
|
61 |
+
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
|
62 |
+
|
63 |
+
|
64 |
+
def makedir(dir_path):
|
65 |
+
"""
|
66 |
+
Create the directory if it does not exist.
|
67 |
+
"""
|
68 |
+
is_success = False
|
69 |
+
try:
|
70 |
+
if not g_pathmgr.exists(dir_path):
|
71 |
+
g_pathmgr.mkdirs(dir_path)
|
72 |
+
is_success = True
|
73 |
+
except BaseException:
|
74 |
+
print(f"Error creating directory: {dir_path}")
|
75 |
+
return is_success
|
76 |
+
|
77 |
+
|
78 |
+
def get_redirected_url(url: str):
|
79 |
+
"""
|
80 |
+
Given a URL, returns the URL it redirects to or the
|
81 |
+
original URL in case of no indirection
|
82 |
+
"""
|
83 |
+
import requests
|
84 |
+
|
85 |
+
with requests.Session() as session:
|
86 |
+
with session.get(url, stream=True, allow_redirects=True) as response:
|
87 |
+
if response.history:
|
88 |
+
return response.url
|
89 |
+
else:
|
90 |
+
return url
|
91 |
+
|
92 |
+
|
93 |
+
def to_google_drive_download_url(view_url: str) -> str:
|
94 |
+
"""
|
95 |
+
Utility function to transform a view URL of google drive
|
96 |
+
to a download URL for google drive
|
97 |
+
Example input:
|
98 |
+
https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
|
99 |
+
Example output:
|
100 |
+
https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
|
101 |
+
"""
|
102 |
+
splits = view_url.split("/")
|
103 |
+
assert splits[-1] == "view"
|
104 |
+
file_id = splits[-2]
|
105 |
+
return f"https://drive.google.com/uc?export=download&id={file_id}"
|
106 |
+
|
107 |
+
|
108 |
+
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
|
109 |
+
"""
|
110 |
+
Download a file from google drive
|
111 |
+
Downloading an URL from google drive requires confirmation when
|
112 |
+
the file of the size is too big (google drive notifies that
|
113 |
+
anti-viral checks cannot be performed on such files)
|
114 |
+
"""
|
115 |
+
import requests
|
116 |
+
|
117 |
+
with requests.Session() as session:
|
118 |
+
|
119 |
+
# First get the confirmation token and append it to the URL
|
120 |
+
with session.get(url, stream=True, allow_redirects=True) as response:
|
121 |
+
for k, v in response.cookies.items():
|
122 |
+
if k.startswith("download_warning"):
|
123 |
+
url = url + "&confirm=" + v
|
124 |
+
|
125 |
+
# Then download the content of the file
|
126 |
+
with session.get(url, stream=True, verify=True) as response:
|
127 |
+
makedir(output_path)
|
128 |
+
path = os.path.join(output_path, output_file_name)
|
129 |
+
total_size = int(response.headers.get("Content-length", 0))
|
130 |
+
with open(path, "wb") as file:
|
131 |
+
from tqdm import tqdm
|
132 |
+
|
133 |
+
with tqdm(total=total_size) as progress_bar:
|
134 |
+
for block in response.iter_content(
|
135 |
+
chunk_size=io.DEFAULT_BUFFER_SIZE
|
136 |
+
):
|
137 |
+
file.write(block)
|
138 |
+
progress_bar.update(len(block))
|
139 |
+
|
140 |
+
|
141 |
+
def _get_google_drive_file_id(url: str) -> Optional[str]:
|
142 |
+
parts = urlparse(url)
|
143 |
+
|
144 |
+
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
|
145 |
+
return None
|
146 |
+
|
147 |
+
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
|
148 |
+
if match is None:
|
149 |
+
return None
|
150 |
+
|
151 |
+
return match.group("id")
|
152 |
+
|
153 |
+
|
154 |
+
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
|
155 |
+
with open(filename, "wb") as fh:
|
156 |
+
with urllib.request.urlopen(
|
157 |
+
urllib.request.Request(url, headers={"User-Agent": "vissl"})
|
158 |
+
) as response:
|
159 |
+
with tqdm(total=response.length) as pbar:
|
160 |
+
for chunk in iter(lambda: response.read(chunk_size), ""):
|
161 |
+
if not chunk:
|
162 |
+
break
|
163 |
+
pbar.update(chunk_size)
|
164 |
+
fh.write(chunk)
|
165 |
+
|
166 |
+
|
167 |
+
def download_url(
|
168 |
+
url: str,
|
169 |
+
root: str,
|
170 |
+
filename: Optional[str] = None,
|
171 |
+
md5: Optional[str] = None,
|
172 |
+
) -> None:
|
173 |
+
"""Download a file from a url and place it in root.
|
174 |
+
Args:
|
175 |
+
url (str): URL to download file from
|
176 |
+
root (str): Directory to place downloaded file in
|
177 |
+
filename (str, optional): Name to save the file under.
|
178 |
+
If None, use the basename of the URL.
|
179 |
+
md5 (str, optional): MD5 checksum of the download. If None, do not check
|
180 |
+
"""
|
181 |
+
root = os.path.expanduser(root)
|
182 |
+
if not filename:
|
183 |
+
filename = os.path.basename(url)
|
184 |
+
fpath = os.path.join(root, filename)
|
185 |
+
|
186 |
+
makedir(root)
|
187 |
+
|
188 |
+
# check if file is already present locally
|
189 |
+
if check_integrity(fpath, md5):
|
190 |
+
print("Using downloaded and verified file: " + fpath)
|
191 |
+
return
|
192 |
+
|
193 |
+
# expand redirect chain if needed
|
194 |
+
url = get_redirected_url(url)
|
195 |
+
|
196 |
+
# check if file is located on Google Drive
|
197 |
+
file_id = _get_google_drive_file_id(url)
|
198 |
+
if file_id is not None:
|
199 |
+
return download_file_from_google_drive(file_id, root, filename, md5)
|
200 |
+
|
201 |
+
# download the file
|
202 |
+
try:
|
203 |
+
print("Downloading " + url + " to " + fpath)
|
204 |
+
_urlretrieve(url, fpath)
|
205 |
+
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
|
206 |
+
if url[:5] == "https":
|
207 |
+
url = url.replace("https:", "http:")
|
208 |
+
print(
|
209 |
+
"Failed download. Trying https -> http instead."
|
210 |
+
" Downloading " + url + " to " + fpath
|
211 |
+
)
|
212 |
+
_urlretrieve(url, fpath)
|
213 |
+
else:
|
214 |
+
raise e
|
215 |
+
|
216 |
+
# check integrity of downloaded file
|
217 |
+
if not check_integrity(fpath, md5):
|
218 |
+
raise RuntimeError("File not found or corrupted.")
|
219 |
+
|
220 |
+
|
221 |
+
def download_and_extract_archive(
|
222 |
+
url: str,
|
223 |
+
download_root: str,
|
224 |
+
extract_root: Optional[str] = None,
|
225 |
+
filename: Optional[str] = None,
|
226 |
+
md5: Optional[str] = None,
|
227 |
+
remove_finished: bool = False,
|
228 |
+
) -> None:
|
229 |
+
download_root = os.path.expanduser(download_root)
|
230 |
+
if extract_root is None:
|
231 |
+
extract_root = download_root
|
232 |
+
if not filename:
|
233 |
+
filename = os.path.basename(url)
|
234 |
+
|
235 |
+
download_url(url, download_root, filename, md5)
|
236 |
+
|
237 |
+
archive = os.path.join(download_root, filename)
|
238 |
+
print("Extracting {} to {}".format(archive, extract_root))
|
239 |
+
extract_archive(archive, extract_root, remove_finished)
|
240 |
+
|
241 |
+
|
242 |
+
def cache_url(url: str, cache_dir: str) -> str:
|
243 |
+
"""
|
244 |
+
This implementation downloads the remote resource and caches it locally.
|
245 |
+
The resource will only be downloaded if not previously requested.
|
246 |
+
"""
|
247 |
+
parsed_url = urlparse(url)
|
248 |
+
dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
|
249 |
+
makedir(dirname)
|
250 |
+
filename = url.split("/")[-1]
|
251 |
+
cached = os.path.join(dirname, filename)
|
252 |
+
with file_lock(cached):
|
253 |
+
if not os.path.isfile(cached):
|
254 |
+
logging.info(f"Downloading {url} to {cached} ...")
|
255 |
+
cached = download(url, dirname, filename=filename)
|
256 |
+
logging.info(f"URL {url} cached in {cached}")
|
257 |
+
return cached
|
258 |
+
|
259 |
+
|
260 |
+
# TODO (prigoyal): convert this into RAII-style API
|
261 |
+
def create_file_symlink(file1, file2):
|
262 |
+
"""
|
263 |
+
Simply create the symlinks for a given file1 to file2.
|
264 |
+
Useful during model checkpointing to symlinks to the
|
265 |
+
latest successful checkpoint.
|
266 |
+
"""
|
267 |
+
try:
|
268 |
+
if g_pathmgr.exists(file2):
|
269 |
+
g_pathmgr.rm(file2)
|
270 |
+
g_pathmgr.symlink(file1, file2)
|
271 |
+
except Exception as e:
|
272 |
+
logging.info(f"Could NOT create symlink. Error: {e}")
|
273 |
+
|
274 |
+
|
275 |
+
def save_file(data, filename, append_to_json=True, verbose=True):
|
276 |
+
"""
|
277 |
+
Common i/o utility to handle saving data to various file formats.
|
278 |
+
Supported:
|
279 |
+
.pkl, .pickle, .npy, .json
|
280 |
+
Specifically for .json, users have the option to either append (default)
|
281 |
+
or rewrite by passing in Boolean value to append_to_json.
|
282 |
+
"""
|
283 |
+
if verbose:
|
284 |
+
logging.info(f"Saving data to file: {filename}")
|
285 |
+
file_ext = os.path.splitext(filename)[1]
|
286 |
+
if file_ext in [".pkl", ".pickle"]:
|
287 |
+
with g_pathmgr.open(filename, "wb") as fopen:
|
288 |
+
pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
|
289 |
+
elif file_ext == ".npy":
|
290 |
+
with g_pathmgr.open(filename, "wb") as fopen:
|
291 |
+
np.save(fopen, data)
|
292 |
+
elif file_ext == ".json":
|
293 |
+
if append_to_json:
|
294 |
+
with g_pathmgr.open(filename, "a") as fopen:
|
295 |
+
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
296 |
+
fopen.flush()
|
297 |
+
else:
|
298 |
+
with g_pathmgr.open(filename, "w") as fopen:
|
299 |
+
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
300 |
+
fopen.flush()
|
301 |
+
elif file_ext == ".yaml":
|
302 |
+
with g_pathmgr.open(filename, "w") as fopen:
|
303 |
+
dump = yaml.dump(data)
|
304 |
+
fopen.write(dump)
|
305 |
+
fopen.flush()
|
306 |
+
else:
|
307 |
+
raise Exception(f"Saving {file_ext} is not supported yet")
|
308 |
+
|
309 |
+
if verbose:
|
310 |
+
logging.info(f"Saved data to file: {filename}")
|
311 |
+
|
312 |
+
|
313 |
+
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
|
314 |
+
"""
|
315 |
+
Common i/o utility to handle loading data from various file formats.
|
316 |
+
Supported:
|
317 |
+
.pkl, .pickle, .npy, .json
|
318 |
+
For the npy files, we support reading the files in mmap_mode.
|
319 |
+
If the mmap_mode of reading is not successful, we load data without the
|
320 |
+
mmap_mode.
|
321 |
+
"""
|
322 |
+
if verbose:
|
323 |
+
logging.info(f"Loading data from file: {filename}")
|
324 |
+
|
325 |
+
file_ext = os.path.splitext(filename)[1]
|
326 |
+
if file_ext == ".txt":
|
327 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
328 |
+
data = fopen.readlines()
|
329 |
+
elif file_ext in [".pkl", ".pickle"]:
|
330 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
331 |
+
data = pickle.load(fopen, encoding="latin1")
|
332 |
+
elif file_ext == ".npy":
|
333 |
+
if mmap_mode:
|
334 |
+
try:
|
335 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
336 |
+
data = np.load(
|
337 |
+
fopen,
|
338 |
+
allow_pickle=allow_pickle,
|
339 |
+
encoding="latin1",
|
340 |
+
mmap_mode=mmap_mode,
|
341 |
+
)
|
342 |
+
except ValueError as e:
|
343 |
+
logging.info(
|
344 |
+
f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
|
345 |
+
)
|
346 |
+
data = np.load(
|
347 |
+
filename,
|
348 |
+
allow_pickle=allow_pickle,
|
349 |
+
encoding="latin1",
|
350 |
+
mmap_mode=mmap_mode,
|
351 |
+
)
|
352 |
+
logging.info("Successfully loaded without g_pathmgr")
|
353 |
+
except Exception:
|
354 |
+
logging.info("Could not mmap without g_pathmgr. Trying without mmap")
|
355 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
356 |
+
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
357 |
+
else:
|
358 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
359 |
+
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
360 |
+
elif file_ext == ".json":
|
361 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
362 |
+
data = json.load(fopen)
|
363 |
+
elif file_ext == ".yaml":
|
364 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
365 |
+
data = yaml.load(fopen, Loader=yaml.FullLoader)
|
366 |
+
elif file_ext == ".csv":
|
367 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
368 |
+
data = pd.read_csv(fopen)
|
369 |
+
else:
|
370 |
+
raise Exception(f"Reading from {file_ext} is not supported yet")
|
371 |
+
return data
|
372 |
+
|
373 |
+
|
374 |
+
def abspath(resource_path: str):
|
375 |
+
"""
|
376 |
+
Make a path absolute, but take into account prefixes like
|
377 |
+
"http://" or "manifold://"
|
378 |
+
"""
|
379 |
+
regex = re.compile(r"^\w+://")
|
380 |
+
if regex.match(resource_path) is None:
|
381 |
+
return os.path.abspath(resource_path)
|
382 |
+
else:
|
383 |
+
return resource_path
|
384 |
+
|
385 |
+
|
386 |
+
def makedir(dir_path):
|
387 |
+
"""
|
388 |
+
Create the directory if it does not exist.
|
389 |
+
"""
|
390 |
+
is_success = False
|
391 |
+
try:
|
392 |
+
if not g_pathmgr.exists(dir_path):
|
393 |
+
g_pathmgr.mkdirs(dir_path)
|
394 |
+
is_success = True
|
395 |
+
except BaseException:
|
396 |
+
logging.info(f"Error creating directory: {dir_path}")
|
397 |
+
return is_success
|
398 |
+
|
399 |
+
|
400 |
+
def is_url(input_url):
|
401 |
+
"""
|
402 |
+
Check if an input string is a url. look for http(s):// and ignoring the case
|
403 |
+
"""
|
404 |
+
is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
|
405 |
+
return is_url
|
406 |
+
|
407 |
+
|
408 |
+
def cleanup_dir(dir):
|
409 |
+
"""
|
410 |
+
Utility for deleting a directory. Useful for cleaning the storage space
|
411 |
+
that contains various training artifacts like checkpoints, data etc.
|
412 |
+
"""
|
413 |
+
if os.path.exists(dir):
|
414 |
+
logging.info(f"Deleting directory: {dir}")
|
415 |
+
shutil.rmtree(dir)
|
416 |
+
logging.info(f"Deleted contents of directory: {dir}")
|
417 |
+
|
418 |
+
|
419 |
+
def get_file_size(filename):
|
420 |
+
"""
|
421 |
+
Given a file, get the size of file in MB
|
422 |
+
"""
|
423 |
+
size_in_mb = os.path.getsize(filename) / float(1024**2)
|
424 |
+
return size_in_mb
|
OPERA/minigpt4/configs/datasets/cc_sbu/align.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets:
|
2 |
+
cc_sbu_align:
|
3 |
+
data_type: images
|
4 |
+
build_info:
|
5 |
+
storage: /path/to/cc_sbu_align/
|
OPERA/minigpt4/configs/datasets/cc_sbu/defaults.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets:
|
2 |
+
cc_sbu:
|
3 |
+
data_type: images
|
4 |
+
build_info:
|
5 |
+
storage: /path/to/cc_sbu_dataset/{00000..01255}.tar
|
OPERA/minigpt4/configs/datasets/laion/defaults.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets:
|
2 |
+
laion:
|
3 |
+
data_type: images
|
4 |
+
build_info:
|
5 |
+
storage: /path/to/laion_dataset/{00000..10488}.tar
|
OPERA/minigpt4/configs/default.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
env:
|
2 |
+
# For default users
|
3 |
+
# cache_root: "cache"
|
4 |
+
# For internal use with persistent storage
|
5 |
+
cache_root: "/export/home/.cache/minigpt4"
|
OPERA/minigpt4/configs/models/blip2_instruct_vicuna13b.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, salesforce.com, inc.
|
2 |
+
# All rights reserved.
|
3 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
4 |
+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
5 |
+
|
6 |
+
model:
|
7 |
+
arch: instruct_vicuna13b
|
8 |
+
load_finetuned: False
|
9 |
+
load_pretrained: True
|
10 |
+
|
11 |
+
pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna13b_trimmed.pth"
|
12 |
+
finetuned: ""
|
13 |
+
|
14 |
+
# vit encoder
|
15 |
+
image_size: 224
|
16 |
+
drop_path_rate: 0
|
17 |
+
use_grad_checkpoint: False
|
18 |
+
vit_precision: "fp16"
|
19 |
+
freeze_vit: True
|
20 |
+
|
21 |
+
# Q-Former
|
22 |
+
num_query_token: 32
|
23 |
+
|
24 |
+
# path to Vicuna checkpoint
|
25 |
+
llm_model: "/mnt/petrelfs/share_data/wangbin/mllm/minigpt4/vicuna-13b-v1-1"
|
26 |
+
|
27 |
+
# generation configs
|
28 |
+
prompt: ""
|
29 |
+
|
30 |
+
|
31 |
+
preprocess:
|
32 |
+
vis_processor:
|
33 |
+
train:
|
34 |
+
name: "blip2_image_train"
|
35 |
+
image_size: 224
|
36 |
+
eval:
|
37 |
+
name: "blip_image_eval"
|
38 |
+
image_size: 224
|
39 |
+
text_processor:
|
40 |
+
train:
|
41 |
+
name: "blip_caption"
|
42 |
+
eval:
|
43 |
+
name: "blip_caption"
|
OPERA/minigpt4/configs/models/blip2_instruct_vicuna7b.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, salesforce.com, inc.
|
2 |
+
# All rights reserved.
|
3 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
4 |
+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
5 |
+
|
6 |
+
model:
|
7 |
+
arch: instruct_vicuna7b
|
8 |
+
load_finetuned: False
|
9 |
+
load_pretrained: True
|
10 |
+
|
11 |
+
pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth"
|
12 |
+
finetuned: ""
|
13 |
+
|
14 |
+
# vit encoder
|
15 |
+
image_size: 224
|
16 |
+
drop_path_rate: 0
|
17 |
+
use_grad_checkpoint: False
|
18 |
+
vit_precision: "fp16"
|
19 |
+
freeze_vit: True
|
20 |
+
|
21 |
+
# Q-Former
|
22 |
+
num_query_token: 32
|
23 |
+
|
24 |
+
# path to Vicuna checkpoint
|
25 |
+
llm_model: "/mnt/petrelfs/share_data/wangbin/mllm/minigpt4/vicuna-7b-v1-1"
|
26 |
+
|
27 |
+
# generation configs
|
28 |
+
prompt: ""
|
29 |
+
|
30 |
+
|
31 |
+
preprocess:
|
32 |
+
vis_processor:
|
33 |
+
train:
|
34 |
+
name: "blip2_image_train"
|
35 |
+
image_size: 224
|
36 |
+
eval:
|
37 |
+
name: "blip_image_eval"
|
38 |
+
image_size: 224
|
39 |
+
text_processor:
|
40 |
+
train:
|
41 |
+
name: "blip_caption"
|
42 |
+
eval:
|
43 |
+
name: "blip_caption"
|
OPERA/minigpt4/configs/models/llava-1.5_vicuna7b.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: llava-1.5
|
3 |
+
version: 'v1.5'
|
4 |
+
|
5 |
+
# vit encoder
|
6 |
+
cache_dir: None
|
7 |
+
vit_model: "openai/clip-vit-large-patch14"
|
8 |
+
freeze_vit: True
|
9 |
+
|
10 |
+
# finetune config
|
11 |
+
freeze_backbone: False
|
12 |
+
tune_mm_mlp_adapter: False
|
13 |
+
freeze_mm_mlp_adapter: False
|
14 |
+
|
15 |
+
# model config
|
16 |
+
mm_vision_select_layer: -2
|
17 |
+
model_max_length: 2048
|
18 |
+
|
19 |
+
# data process config
|
20 |
+
image_token_len: 576
|
21 |
+
mm_use_im_start_end: True
|
22 |
+
|
23 |
+
# training config
|
24 |
+
bf16: False
|
25 |
+
fp16: True
|
26 |
+
|
27 |
+
|
28 |
+
preprocess:
|
29 |
+
vis_processor:
|
30 |
+
train:
|
31 |
+
name: "clip_image_train_336"
|
32 |
+
proc_type: "openai/clip-vit-large-patch14-336"
|
33 |
+
eval:
|
34 |
+
name: "clip_image_eval_336"
|
35 |
+
proc_type: "openai/clip-vit-large-patch14-336"
|
36 |
+
text_processor:
|
37 |
+
train:
|
38 |
+
name: "blip_caption"
|
39 |
+
eval:
|
40 |
+
name: "blip_caption"
|
OPERA/minigpt4/configs/models/minigpt4_llama2.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: mini_gpt4
|
3 |
+
|
4 |
+
# vit encoder
|
5 |
+
image_size: 224
|
6 |
+
drop_path_rate: 0
|
7 |
+
use_grad_checkpoint: False
|
8 |
+
vit_precision: "fp16"
|
9 |
+
freeze_vit: True
|
10 |
+
has_qformer: False
|
11 |
+
|
12 |
+
# generation configs
|
13 |
+
prompt: ""
|
14 |
+
|
15 |
+
llama_model: "/path/to/llama2/weight"
|
16 |
+
|
17 |
+
preprocess:
|
18 |
+
vis_processor:
|
19 |
+
train:
|
20 |
+
name: "blip2_image_train"
|
21 |
+
image_size: 224
|
22 |
+
eval:
|
23 |
+
name: "blip2_image_eval"
|
24 |
+
image_size: 224
|
25 |
+
text_processor:
|
26 |
+
train:
|
27 |
+
name: "blip_caption"
|
28 |
+
eval:
|
29 |
+
name: "blip_caption"
|
OPERA/minigpt4/configs/models/minigpt4_vicuna0.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: mini_gpt4
|
3 |
+
|
4 |
+
# vit encoder
|
5 |
+
image_size: 224
|
6 |
+
drop_path_rate: 0
|
7 |
+
use_grad_checkpoint: False
|
8 |
+
vit_precision: "fp16"
|
9 |
+
freeze_vit: True
|
10 |
+
freeze_qformer: True
|
11 |
+
|
12 |
+
# Q-Former
|
13 |
+
num_query_token: 32
|
14 |
+
|
15 |
+
# generation configs
|
16 |
+
prompt: ""
|
17 |
+
|
18 |
+
llama_model: "/mnt/petrelfs/share_data/wangbin/mllm/minigpt4/vicuna-7b-v0"
|
19 |
+
|
20 |
+
preprocess:
|
21 |
+
vis_processor:
|
22 |
+
train:
|
23 |
+
name: "blip2_image_train"
|
24 |
+
image_size: 224
|
25 |
+
eval:
|
26 |
+
name: "blip2_image_eval"
|
27 |
+
image_size: 224
|
28 |
+
text_processor:
|
29 |
+
train:
|
30 |
+
name: "blip_caption"
|
31 |
+
eval:
|
32 |
+
name: "blip_caption"
|
OPERA/minigpt4/configs/models/shikra_vicuna7b.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
arch: shikra
|
3 |
+
version: 'v1'
|
4 |
+
|
5 |
+
# vit encoder
|
6 |
+
cache_dir: None
|
7 |
+
vit_model: "openai/clip-vit-large-patch14"
|
8 |
+
freeze_vit: True
|
9 |
+
|
10 |
+
# finetune config
|
11 |
+
freeze_backbone: False
|
12 |
+
tune_mm_mlp_adapter: False
|
13 |
+
freeze_mm_mlp_adapter: False
|
14 |
+
|
15 |
+
# model config
|
16 |
+
mm_vision_select_layer: -2
|
17 |
+
model_max_length: 2048
|
18 |
+
|
19 |
+
# data process config
|
20 |
+
image_token_len: 256
|
21 |
+
mm_use_im_start_end: True
|
22 |
+
|
23 |
+
# training config
|
24 |
+
bf16: False
|
25 |
+
fp16: True
|
26 |
+
|
27 |
+
|
28 |
+
preprocess:
|
29 |
+
vis_processor:
|
30 |
+
train:
|
31 |
+
name: "clip_image_train"
|
32 |
+
proc_type: "openai/clip-vit-large-patch14"
|
33 |
+
eval:
|
34 |
+
name: "clip_image_eval"
|
35 |
+
proc_type: "openai/clip-vit-large-patch14"
|
36 |
+
text_processor:
|
37 |
+
train:
|
38 |
+
name: "blip_caption"
|
39 |
+
eval:
|
40 |
+
name: "blip_caption"
|