fffiloni committed
Commit 2ada650
Parent: 2dd4459

Upload 164 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full set of changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. LICENSE.md +14 -0
  3. LICENSE_Lavis.md +14 -0
  4. datasets/training_datasets/video_text_data/video_instruct_100/download_script.py +94 -0
  5. demo_job.sh +21 -0
  6. environment.yml +331 -0
  7. eval_video.py +221 -0
  8. jobs_video/eval/choose_best_ckpt/choose_best_ckpt.py +14 -0
  9. jobs_video/eval/choose_best_ckpt/evalualtion_ckpt.sh +17 -0
  10. jobs_video/eval/llama2_evalualtion.sh +37 -0
  11. jobs_video/eval/mistral_evalualtion.sh +39 -0
  12. jobs_video/eval/submit_job.py +19 -0
  13. jobs_video/train/stage_2_llama2.sh +23 -0
  14. jobs_video/train/stage_2_mistral.sh +23 -0
  15. jobs_video/train/stage_3_llama2.sh +23 -0
  16. jobs_video/train/stage_3_mistral.sh +23 -0
  17. minigpt4/__init__.py +31 -0
  18. minigpt4/common/__init__.py +0 -0
  19. minigpt4/common/config.py +474 -0
  20. minigpt4/common/dist_utils.py +146 -0
  21. minigpt4/common/eval_utils.py +224 -0
  22. minigpt4/common/gradcam.py +24 -0
  23. minigpt4/common/logger.py +195 -0
  24. minigpt4/common/optims.py +119 -0
  25. minigpt4/common/registry.py +330 -0
  26. minigpt4/common/utils.py +424 -0
  27. minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py +89 -0
  28. minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py +1 -0
  29. minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py +192 -0
  30. minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py +73 -0
  31. minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py +1 -0
  32. minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py +179 -0
  33. minigpt4/common/vqa_tools/VQA/README.md +80 -0
  34. minigpt4/common/vqa_tools/__init__.py +8 -0
  35. minigpt4/common/vqa_tools/aokvqa/LICENSE +201 -0
  36. minigpt4/common/vqa_tools/aokvqa/README.md +207 -0
  37. minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py +45 -0
  38. minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py +26 -0
  39. minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py +50 -0
  40. minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py +51 -0
  41. minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py +62 -0
  42. minigpt4/common/vqa_tools/aokvqa/environment.yml +36 -0
  43. minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py +97 -0
  44. minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py +13 -0
  45. minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py +31 -0
  46. minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py +44 -0
  47. minigpt4/common/vqa_tools/aokvqa/gpt3/README.md +14 -0
  48. minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py +23 -0
  49. minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py +79 -0
  50. minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py +16 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ repo_imgs/sample_1.gif filter=lfs diff=lfs merge=lfs -text
+ repo_imgs/sample_2.gif filter=lfs diff=lfs merge=lfs -text
+ repo_imgs/sample_3.gif filter=lfs diff=lfs merge=lfs -text
LICENSE.md ADDED
@@ -0,0 +1,14 @@
+ BSD 3-Clause License
+
+ Copyright 2023 Deyao Zhu
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
LICENSE_Lavis.md ADDED
@@ -0,0 +1,14 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2022 Salesforce, Inc.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
datasets/training_datasets/video_text_data/video_instruct_100/download_script.py ADDED
@@ -0,0 +1,94 @@
+ import json
+ from tqdm import tqdm
+ from pytubefix import YouTube
+
+ import xml.etree.ElementTree as ET
+ import os
+
+ with open('VideoInstruct100K.json', 'r') as f:
+     data = json.load(f)
+
+ # Usage
+ existed_video_id = {}
+ for video_name in os.listdir('videos'):
+     video_id = video_name.split('.')[0]
+     existed_video_id[video_id] = True
+
+
+
+ def download_video_with_subtitles(video_id):
+     # Create a YouTube object.
+     yt = YouTube(f'https://www.youtube.com/watch?v={video_id}')
+
+     video_filename = f"{video_id}.mp4"
+     video_downloaded = False
+     try:
+         # Get the video stream with the highest resolution and download the video.
+         stream = yt.streams.get_highest_resolution()
+         stream.download(output_path='videos', filename=video_filename)
+         video_downloaded = True
+     except Exception as e:
+         print(f"Error downloading video {video_id}: {str(e)}")
+         video_downloaded = False
+     if not video_downloaded:
+         return False, False
+
+     # Get the video's available captions (subtitles).
+     captions = yt.captions.all()
+
+     # Download the captions if available in xml format.
+     caption_downloaded = False
+     for caption in captions:
+         caption_code = caption.code
+         # select only english captions
+         if 'en' in caption_code:
+             caption.download(title=f"{video_id}", output_path='subtitles_xml', srt=False)
+             caption_downloaded = True
+     return video_downloaded, caption_downloaded
+ def convert_xml_vtt(xml_path, vtt_path):
+     # Parse the XML subtitle file
+     tree = ET.parse(xml_path)
+     root = tree.getroot()
+
+     # Initialize a list to store VTT subtitle entries
+     vtt_subtitle = []
+
+     # Function to convert time in milliseconds to WebVTT format
+     def ms_to_vtt_time(milliseconds):
+         seconds, milliseconds = divmod(milliseconds, 1000)
+         minutes, seconds = divmod(seconds, 60)
+         return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
+
+     # Iterate through subtitle elements
+     toggle = True
+     for p in root.findall(".//p"):
+         if toggle:
+             start_time = int(p.get("t"))
+             subtitle_text = " ".join(s.text.strip() for s in p.findall(".//s"))
+             # duration = int(p.get("d")) if p.get("d") is not None else 0
+         if not toggle:
+             end_time = int(p.get("t"))
+             # Format and append the VTT entry to the list
+             vtt_subtitle.append(f"{ms_to_vtt_time(start_time)} --> {ms_to_vtt_time(end_time)}\n{subtitle_text}\n")
+         toggle = not toggle
+     # Join the VTT entries into a single string
+     vtt_content = "WEBVTT\n\n" + "\n".join(vtt_subtitle)
+
+     # Save the VTT content to a file
+     with open(vtt_path, "w", encoding="utf-8") as vtt_file:
+         vtt_file.write(vtt_content)
+ import os
+ os.makedirs('videos', exist_ok=True)
+ os.makedirs('subtitles_vtt', exist_ok=True)
+ os.makedirs('subtitles_xml', exist_ok=True)
+ for video_path in tqdm(data, desc='Downloading videos'):
+     video_id = video_path.split('/')[-1].split('.')[0]
+     if existed_video_id.get(video_id, False):
+         continue
+     video_downloaded, caption_downloaded = download_video_with_subtitles(video_id)
+     if caption_downloaded:
+         # convert xml to vtt
+         xml_file_path = f'subtitles_xml/{video_id} (a.en).xml'
+         convert_xml_vtt(xml_file_path, f'subtitles_vtt/{video_id}.vtt')
+
+
demo_job.sh ADDED
@@ -0,0 +1,21 @@
+ #!/bin/bash
+ #SBATCH --partition=batch
+ #SBATCH --job-name=video_demo_llama2
+ #SBATCH --output=video_demo_llama2.out
+ #SBATCH --error=video_demo_llama2.err
+ #SBATCH --time=0-10:30:00
+ #SBATCH --mem=100G
+ #SBATCH --gres=gpu:a100:1
+ #SBATCH --nodes=1
+
+ # Choose the model to test
+ # Mistral
+ # ckpt="checkpoints/video_mistral_checkpoint_last.pth"
+ # config="test_configs/mistral_test_config.yaml"
+
+ # Llama2
+ ckpt="checkpoints/video_llama_checkpoint_last.pth"
+ config="test_configs/llama2_test_config.yaml"
+
+
+ python minigpt4_video_demo.py --cfg-path $config --ckpt $ckpt
environment.yml ADDED
@@ -0,0 +1,331 @@
+ name: minigpt4_video_test_v100
+ channels:
+   - conda-forge
+ dependencies:
+   - _libgcc_mutex=0.1=conda_forge
+   - _openmp_mutex=4.5=2_gnu
+   - archspec=0.2.2=pyhd8ed1ab_0
+   - boltons=23.1.1=pyhd8ed1ab_0
+   - brotli-python=1.1.0=py39h3d6467e_1
+   - bzip2=1.0.8=hd590300_5
+   - c-ares=1.25.0=hd590300_0
+   - ca-certificates=2024.2.2=hbcca054_0
+   - certifi=2024.2.2=pyhd8ed1ab_0
+   - cffi=1.16.0=py39h7a31438_0
+   - charset-normalizer=3.3.2=pyhd8ed1ab_0
+   - colorama=0.4.6=pyhd8ed1ab_0
+   - conda=23.11.0=py39hf3d152e_1
+   - conda-libmamba-solver=23.12.0=pyhd8ed1ab_0
+   - conda-package-handling=2.2.0=pyh38be061_0
+   - conda-package-streaming=0.9.0=pyhd8ed1ab_0
+   - cudatoolkit=11.8.0=h4ba93d1_12
+   - cudatoolkit-dev=11.7.0=h1de0b5d_6
+   - distro=1.9.0=pyhd8ed1ab_0
+   - faiss=1.7.4=py39cuda112h460e57a_0_cuda
+   - fmt=10.1.1=h00ab1b0_1
+   - freetype=2.12.1=h267a509_2
+   - gmp=6.1.2=hf484d3e_1000
+   - gnutls=3.5.19=h2a4e5f8_1
+   - icu=73.2=h59595ed_0
+   - idna=3.6=pyhd8ed1ab_0
+   - jsonpatch=1.33=pyhd8ed1ab_0
+   - jsonpointer=2.4=py39hf3d152e_3
+   - keyutils=1.6.1=h166bdaf_0
+   - krb5=1.21.2=h659d440_0
+   - ld_impl_linux-64=2.40=h41732ed_0
+   - libarchive=3.7.2=h2aa1ff5_1
+   - libblas=3.9.0=20_linux64_openblas
+   - libcblas=3.9.0=20_linux64_openblas
+   - libcurl=8.5.0=hca28451_0
+   - libedit=3.1.20191231=he28a2e2_2
+   - libev=4.33=hd590300_2
+   - libfaiss=1.7.4=cuda112hb18a002_0_cuda
+   - libfaiss-avx2=1.7.4=cuda112h1234567_0_cuda
+   - libffi=3.4.2=h7f98852_5
+   - libgcc-ng=13.2.0=h807b86a_3
+   - libgfortran-ng=13.2.0=h69a702a_3
+   - libgfortran5=13.2.0=ha4646dd_3
+   - libgomp=13.2.0=h807b86a_3
+   - libiconv=1.17=hd590300_2
+   - liblapack=3.9.0=20_linux64_openblas
+   - libmamba=1.5.6=had39da4_0
+   - libmambapy=1.5.6=py39h10defb6_0
+   - libnghttp2=1.58.0=h47da74e_1
+   - libnsl=2.0.1=hd590300_0
+   - libopenblas=0.3.25=pthreads_h413a1c8_0
+   - libpng=1.6.39=h753d276_0
+   - libsolv=0.7.27=hfc55251_0
+   - libsqlite=3.44.2=h2797004_0
+   - libssh2=1.11.0=h0841786_0
+   - libstdcxx-ng=13.2.0=h7e041cc_3
+   - libuuid=2.38.1=h0b41bf4_0
+   - libxcrypt=4.4.36=hd590300_1
+   - libxml2=2.12.3=h232c23b_0
+   - libzlib=1.2.13=hd590300_5
+   - lz4-c=1.9.4=hcb278e6_0
+   - lzo=2.10=h516909a_1000
+   - menuinst=2.0.1=py39hf3d152e_0
+   - ncurses=6.4=h59595ed_2
+   - nettle=3.3=0
+   - numpy=1.26.3=py39h474f0d3_0
+   - openh264=1.8.0=hdbcaa40_1000
+   - openssl=3.2.1=hd590300_0
+   - packaging=23.2=pyhd8ed1ab_0
+   - pip=23.3.2=pyhd8ed1ab_0
+   - platformdirs=4.1.0=pyhd8ed1ab_0
+   - pluggy=1.3.0=pyhd8ed1ab_0
+   - pybind11-abi=4=hd8ed1ab_3
+   - pycosat=0.6.6=py39hd1e30aa_0
+   - pycparser=2.21=pyhd8ed1ab_0
+   - pysocks=1.7.1=pyha2e5f31_6
+   - python=3.9.18=h0755675_1_cpython
+   - python_abi=3.9=4_cp39
+   - readline=8.2=h8228510_1
+   - reproc=14.2.4.post0=hd590300_1
+   - reproc-cpp=14.2.4.post0=h59595ed_1
+   - requests=2.31.0=pyhd8ed1ab_0
+   - ruamel.yaml=0.18.5=py39hd1e30aa_0
+   - ruamel.yaml.clib=0.2.7=py39hd1e30aa_2
+   - tk=8.6.13=noxft_h4845f30_101
+   - tqdm=4.66.1=pyhd8ed1ab_0
+   - urllib3=2.1.0=pyhd8ed1ab_0
+   - wheel=0.42.0=pyhd8ed1ab_0
+   - x264=1!152.20180717=h14c3975_1001
+   - xz=5.2.6=h166bdaf_0
+   - yaml-cpp=0.8.0=h59595ed_0
+   - zlib=1.2.13=hd590300_5
+   - zstandard=0.22.0=py39h6e5214e_0
+   - zstd=1.5.5=hfc55251_0
+   - pip:
+     - accelerate==0.25.0
+     - aiofiles==23.2.1
+     - aiohttp==3.9.1
+     - aiosignal==1.3.1
+     - altair==5.2.0
+     - annotated-types==0.6.0
+     - antlr4-python3-runtime==4.9.3
+     - anyio==4.2.0
+     - appdirs==1.4.4
+     - asgiref==3.7.2
+     - async-timeout==4.0.3
+     - attrs==23.2.0
+     - backoff==2.2.1
+     - bcrypt==4.1.2
+     - beautifulsoup4==4.12.2
+     - bitarray==2.9.2
+     - bitsandbytes==0.42.0
+     - bleach==6.1.0
+     - blinker==1.7.0
+     - braceexpand==0.1.7
+     - build==1.0.3
+     - cachetools==5.3.2
+     - chardet==5.2.0
+     - chroma-hnswlib==0.7.3
+     - chromadb==0.4.22
+     - click==8.1.7
+     - cmake==3.25.0
+     - colbert-ai==0.2.18
+     - coloredlogs==15.0.1
+     - contourpy==1.2.0
+     - cycler==0.12.1
+     - datasets==2.17.0
+     - decorator==4.4.2
+     - decord==0.6.0
+     - deprecated==1.2.14
+     - dill==0.3.8
+     - docker-pycreds==0.4.0
+     - docopt==0.6.2
+     - einops==0.7.0
+     - exceptiongroup==1.2.0
+     - faiss-gpu==1.7.2
+     - fastapi==0.108.0
+     - ffmpeg==1.4
+     - ffmpeg-python==0.2.0
+     - ffmpy==0.3.1
+     - filelock==3.13.1
+     - flash-attn==2.5.4
+     - flask==3.0.2
+     - flatbuffers==23.5.26
+     - fonttools==4.47.0
+     - frozenlist==1.4.1
+     - fsspec==2023.10.0
+     - ftfy==6.1.3
+     - future==0.18.3
+     - gdown==4.7.1
+     - git-python==1.0.3
+     - gitdb==4.0.11
+     - gitpython==3.1.40
+     - google-auth==2.26.1
+     - googleapis-common-protos==1.62.0
+     - gradio
+     - gradio-client
+     - h11==0.14.0
+     - h5py==3.10.0
+     - httpcore==1.0.2
+     - httptools==0.6.1
+     - httpx==0.26.0
+     - huggingface-hub==0.21.1
+     - humanfriendly==10.0
+     - imageio==2.33.1
+     - imageio-ffmpeg==0.4.9
+     - importlib-metadata==6.11.0
+     - importlib-resources==6.1.1
+     - inquirerpy==0.3.4
+     - iopath==0.1.10
+     - itsdangerous==2.1.2
+     - jinja2==3.1.2
+     - joblib==1.3.2
+     - jsonschema==4.20.0
+     - jsonschema-specifications==2023.12.1
+     - kaggle==1.6.0
+     - kiwisolver==1.4.5
+     - kubernetes==29.0.0
+     - lazy-loader==0.3
+     - lit==15.0.7
+     - llvmlite==0.41.1
+     - markdown-it-py==3.0.0
+     - matplotlib==3.8.2
+     - mdurl==0.1.2
+     - mmh3==4.1.0
+     - monotonic==1.6
+     - more-itertools==10.1.0
+     - moviepy==1.0.3
+     - mpmath==1.3.0
+     - multidict==6.0.4
+     - multiprocess==0.70.16
+     - mutagen==1.47.0
+     - networkx==3.2.1
+     - ninja==1.11.1.1
+     - nltk==3.8.1
+     - numba==0.58.1
+     - nvidia-cublas-cu11==11.10.3.66
+     - nvidia-cublas-cu12==12.1.3.1
+     - nvidia-cuda-cupti-cu12==12.1.105
+     - nvidia-cuda-nvrtc-cu11==11.7.99
+     - nvidia-cuda-nvrtc-cu12==12.1.105
+     - nvidia-cuda-runtime-cu11==11.7.99
+     - nvidia-cuda-runtime-cu12==12.1.105
+     - nvidia-cudnn-cu11==8.5.0.96
+     - nvidia-cudnn-cu12==8.9.2.26
+     - nvidia-cufft-cu12==11.0.2.54
+     - nvidia-curand-cu12==10.3.2.106
+     - nvidia-cusolver-cu12==11.4.5.107
+     - nvidia-cusparse-cu12==12.1.0.106
+     - nvidia-nccl-cu12==2.18.1
+     - nvidia-nvjitlink-cu12==12.3.101
+     - nvidia-nvtx-cu12==12.1.105
+     - omegaconf==2.3.0
+     - onnxruntime==1.16.3
+     - openai==0.28.0
+     - openai-whisper==20231117
+     - opencv-python==4.7.0.72
+     - opentelemetry-api==1.22.0
+     - opentelemetry-exporter-otlp-proto-common==1.22.0
+     - opentelemetry-exporter-otlp-proto-grpc==1.22.0
+     - opentelemetry-instrumentation==0.43b0
+     - opentelemetry-instrumentation-asgi==0.43b0
+     - opentelemetry-instrumentation-fastapi==0.43b0
+     - opentelemetry-proto==1.22.0
+     - opentelemetry-sdk==1.22.0
+     - opentelemetry-semantic-conventions==0.43b0
+     - opentelemetry-util-http==0.43b0
+     - orjson==3.9.10
+     - overrides==7.4.0
+     - pandas==2.0.0
+     - pathtools==0.1.2
+     - peft==0.2.0
+     - pfzy==0.3.4
+     - pillow==10.2.0
+     - plotly==5.18.0
+     - portalocker==2.8.2
+     - posthog==3.3.0
+     - proglog==0.1.10
+     - progressbar2==4.3.2
+     - prompt-toolkit==3.0.43
+     - protobuf==4.25.1
+     - psutil==5.9.7
+     - pulsar-client==3.4.0
+     - pyarrow==15.0.0
+     - pyarrow-hotfix==0.6
+     - pyasn1==0.5.1
+     - pyasn1-modules==0.3.0
+     - pycocoevalcap==1.2
+     - pycocotools==2.0.6
+     - pycryptodomex==3.19.1
+     - pydantic==2.5.3
+     - pydantic-core==2.14.6
+     - pydub==0.25.1
+     - pygments==2.17.2
+     - pyparsing==3.1.1
+     - pypika==0.48.9
+     - pyproject-hooks==1.0.0
+     - pysrt==1.1.2
+     - python-dateutil==2.8.2
+     - python-dotenv==1.0.0
+     - python-multipart==0.0.6
+     - python-slugify==8.0.1
+     - python-utils==3.8.1
+     - pytubefix
+     - pytz==2023.3.post1
+     - pyyaml==6.0.1
+     - referencing==0.32.0
+     - regex==2023.12.25
+     - rich==13.7.0
+     - rouge==1.0.1
+     - rpds-py==0.16.2
+     - rsa==4.9
+     - safetensors==0.4.1
+     - scikit-image==0.22.0
+     - scikit-learn==1.3.2
+     - scipy==1.11.4
+     - seaborn==0.13.1
+     - semantic-version==2.10.0
+     - sentence-transformers==2.2.2
+     - sentencepiece==0.1.97
+     - sentry-sdk==1.39.1
+     - setproctitle==1.3.3
+     - setuptools==69.0.3
+     - shellingham==1.5.4
+     - six==1.16.0
+     - smmap==5.0.1
+     - sniffio==1.3.0
+     - soundfile==0.12.1
+     - soupsieve==2.5
+     - starlette==0.32.0.post1
+     - sympy==1.12
+     - tenacity==8.2.3
+     - text-unidecode==1.3
+     - threadpoolctl==3.2.0
+     - tifffile==2023.12.9
+     - tiktoken==0.5.2
+     - timm==0.6.13
+     - tokenizers==0.15.2
+     - tomli==2.0.1
+     - tomlkit==0.12.0
+     - toolz==0.12.0
+     - torch==2.0.1
+     - torchaudio==2.0.2
+     - torchvision==0.15.2
+     - transformers==4.37.2
+     - triton==2.0.0
+     - typer==0.9.0
+     - typing-extensions==4.9.0
+     - tzdata==2023.4
+     - ujson==5.9.0
+     - uvicorn==0.25.0
+     - uvloop==0.19.0
+     - visual-genome==1.1.1
+     - wandb==0.14.2
+     - watchfiles==0.21.0
+     - wcwidth==0.2.13
+     - webdataset==0.2.48
+     - webencodings==0.5.1
+     - websocket-client==1.7.0
+     - websockets
+     - webvtt-py==0.4.6
+     - wrapt==1.16.0
+     - xxhash==3.4.1
+     - yarl==1.9.4
+     - youtube-dl==2021.12.17
+     - yt-dlp
+     - zipp
eval_video.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ import json
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
+ from minigpt4.conversation.conversation import CONV_VISION
+ from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor,BlipCaptionProcessor
+ from minigpt4.datasets.datasets.video_datasets import VideoChatGPTEvalDataset,VideoChatGPTEval_consistancy,Video_validation_Dataset,TVQAEVAL,TVQAEVAL_Long
+
+ parser = eval_parser()
+ parser.add_argument("--dataset", type=str, default='msvd', help="dataset to evaluate")
+ parser.add_argument("--add_subtitles", action='store_true', help="whether to add subtitles to the video")
+ parser.add_argument("--name", type=str, default='3_datasets', help="evaluation name")
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size")
+ parser.add_argument("--start", type=int, default=0, help="start from video number")
+ parser.add_argument("--end", type=int, default=10000000, help="end at video number")
+ args = parser.parse_args()
+
+ print(args.ckpt)
+ print(args.name)
+ print(args.cfg_path)
+ if "test_configs/mistral_test_config.yaml" == args.cfg_path:
+     llm_name = "mistral"
+ else:
+     llm_name = "llama2"
+ print("using captions", args.add_subtitles)
+
+ model, vis_processor = init_model(args)
+ conv_temp = CONV_VISION.copy()
+ conv_temp.system = ""
+ if args.dataset == 'video_chatgpt_generic':
+     ann_path = "datasets/evaluation_datasets/videochatgpt_benchmark/generic_qa.json"
+     videos_path = "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos"
+     subtitles_path = "/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
+     videos_features_path = "/ibex/project/c2106/kirolos/videos_features/evaluation/benchmark/generic"
+     annotations_keys = ['Q','A','video_name']
+     data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path, subtitles_path, annotations_keys, videos_features_path, add_subtitles=args.add_subtitles, llm_name=llm_name)
+ elif args.dataset == 'video_chatgpt_temporal':
+     ann_path = "datasets/evaluation_datasets/videochatgpt_benchmark/temporal_qa.json"
+     videos_path = "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos"
+     subtitles_path = "/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
+     videos_features_path = "/ibex/project/c2106/kirolos/videos_features/evaluation/benchmark/temporal"
+     annotations_keys = ['Q','A','video_name']
+     data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path, subtitles_path, annotations_keys, videos_features_path, add_subtitles=args.add_subtitles, llm_name=llm_name)
+ elif args.dataset == 'video_chatgpt_consistency':
+     ann_path = "datasets/evaluation_datasets/videochatgpt_benchmark/consistency_qa.json"
+     videos_path = "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos"
+     subtitles_path = "/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
+     annotations_keys = [['Q1','Q2'],'A','video_name']
+     data = VideoChatGPTEval_consistancy(vis_processor, videos_path, ann_path, subtitles_path, annotations_keys, add_subtitles=args.add_subtitles, llm_name=llm_name)
+
+ elif args.dataset == 'msrvtt':
+     ann_path = "datasets/evaluation_datasets/msrvtt/val_qa_edited.json"
+     videos_path = "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/MSRVTT/videos/all"
+     subtitles_path = "/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
+     videos_features_path = "/ibex/project/c2106/kirolos/videos_features/evaluation/msrvtt"
+     annotations_keys = ['question','answer','video_id']
+     data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path, subtitles_path, annotations_keys, videos_features_path, add_subtitles=args.add_subtitles, llm_name=llm_name)
+
+ elif args.dataset == 'msvd':
+     ann_path = "datasets/evaluation_datasets/msvd/val_qa_edited.json"
+     videos_path = "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/MSVD-QA/videos"
+     subtitles_path = "/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
+     videos_features_path = "/ibex/project/c2106/kirolos/videos_features/evaluation/msvd"
+     annotations_keys = ['question','answer','video_id']
+     data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path, subtitles_path, annotations_keys, videos_features_path, add_subtitles=args.add_subtitles, llm_name=llm_name)
+ elif args.dataset == 'activitynet':
+     ann_path = "datasets/evaluation_datasets/activityNet/test_qa.json"
+     videos_path = "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Activity_net/Activity_net_videos"
+     subtitles_path = "/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles/"
+     videos_features_path = "/ibex/project/c2106/kirolos/videos_features/evaluation/activity_net"
+     annotations_keys = ['question','answer','video_id']
+     data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path, subtitles_path, annotations_keys, videos_features_path, add_subtitles=args.add_subtitles, llm_name=llm_name)
+ elif args.dataset == 'tgif':
+     ann_path = "datasets/evaluation_datasets/tgif/Test_frameqa_question.json"
+     videos_path = "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/TGIF/mp4s"
+     subtitles_path = "/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
+     videos_features_path = "/ibex/project/c2106/kirolos/videos_features/evaluation/tgif"
+     annotations_keys = ['question','answer','gif_name']
+     # annotations_keys=['question','description','gif_name']
+     data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path, subtitles_path, annotations_keys, videos_features_path, add_subtitles=False, llm_name=llm_name)
+ elif args.dataset == 'tvqa':
+     # TVQA dataset
+     ann_path = "datasets/evaluation_datasets/tvqa_short/tvqa_val.json"
+     videos_path = "/ibex/project/c2090/datasets/TVR_dataset/videos/video_files/frames_hq/"
+     subtitles_path = "/ibex/project/c2090/datasets/TVR_dataset/TVRetrieval/data/tvqa_preprocessed_subtitles.json"
+     videos_features_path = "/ibex/project/c2106/kirolos/videos_features/evaluation/tvqa"
+     data = TVQAEVAL(vis_processor, videos_path, ann_path, subtitles_path, videos_features_path, add_subtitles=args.add_subtitles, llm_name=llm_name)
+
+ eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
+
+ minigpt4_predict = []
+ sub = "subtitles" if args.add_subtitles else "no_subtitles"
+ if args.start == 0 and args.end == 10000000:
+     save_path = f'results/{args.name}_{args.dataset}_{sub}.json'
+ else:
+     print("start from video number", args.start)
+     print("end at video number", args.end)
+     save_path = f'results/{args.name}_{args.dataset}_{sub}_{args.start}_{args.end}.json'
+
+ os.makedirs("results", exist_ok=True)
+ c = 0
+ pred_result = {}
+ gt_result = {}
+ if args.dataset == 'video_chatgpt_consistency':
+     for images, texts_1, texts_2, gt_answers, lengths, videos_ids in tqdm(eval_dataloader, desc=f"Eval {args.dataset}"):
+         if args.start <= c < args.end:
+             texts_q1 = prepare_texts(texts_1, conv_temp, template='', lengths=lengths)  # warp the texts with conversation template
+             texts_q2 = prepare_texts(texts_2, conv_temp, template='', lengths=lengths)  # warp the texts with conversation template
+             models_answers_q1 = model.generate(images, texts_q1, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths, num_beams=1)
+             models_answers_q2 = model.generate(images, texts_q2, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths, num_beams=1)
+             for video_id, model_answer_q1, model_answer_q2, gt_answer, text_q1, text_q2 in zip(videos_ids, models_answers_q1, models_answers_q2, gt_answers, texts_q1, texts_q2):
+                 result = dict()
+                 result['video_name'] = video_id
+                 result['Q1'] = text_q1.split('\n')[-1].replace('[/INST]','')
+                 result['Q2'] = text_q2.split('\n')[-1].replace('[/INST]','')
+                 result['A'] = gt_answer
+                 result['pred1'] = model_answer_q1
+                 result['pred2'] = model_answer_q2
+                 pred_result[video_id] = [model_answer_q1, model_answer_q2]
+                 gt_result[video_id] = [gt_answer]
+                 minigpt4_predict.append(result)
+             # save results every 100 videos to avoid losing results
+             if c % 100 == 0:
+                 with open(save_path, 'w') as f:
+                     json.dump(minigpt4_predict, f)
+         if c >= args.end:
+             break
+         c += 1
+
+ elif args.dataset == 'tvr':
+     for images, texts, gt_answers, lengths, videos_ids in tqdm(eval_dataloader, desc=f"Eval {args.dataset}"):
+         if args.start <= c < args.end:
+             texts = prepare_texts(texts, conv_temp, template='', lengths=lengths)  # warp the texts with conversation template
+             models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths, num_beams=1)
+             for video_id, model_answer, gt_answer, text in zip(videos_ids, models_answers, gt_answers, texts):
+                 result = dict()
+                 result['video_name'] = video_id
+                 result['Q'] = text.split('\n')[-1].replace('[/INST]','')
+                 result['A'] = gt_answer
+                 result['pred'] = model_answer
+                 pred_result[video_id] = [model_answer]
+                 gt_result[video_id] = [gt_answer]
+                 minigpt4_predict.append(result)
+             # save results every 100 videos to avoid losing results
+             if c % 100 == 0:
+                 with open(save_path, 'w') as f:
+                     json.dump(minigpt4_predict, f)
+         if c >= args.end:
+             break
+         c += 1
+ elif args.dataset == 'ego_schema' or args.dataset == 'tvqa' or args.dataset == 'tvqa_long_videos':
+     for images, texts, gt_answers, lengths, videos_ids in tqdm(eval_dataloader, desc=f"Eval {args.dataset}"):
+         if args.start <= c < args.end:
+             texts = prepare_texts(texts, conv_temp, template='', lengths=lengths)  # warp the texts with conversation template
+             models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths, num_beams=1)
+             for video_id, model_answer, gt_answer, text in zip(videos_ids, models_answers, gt_answers, texts):
+                 result = dict()
+                 result['video_name'] = video_id
+                 if args.dataset == 'tvqa_long_videos':
+                     result['Q'] = text.split('\n\n')[1:]
+                 else:
+                     result['Q'] = text.split('\n')[1:]
+                 result['A'] = gt_answer
+                 result['pred'] = model_answer
+                 pred_result[video_id] = [model_answer]
+                 gt_result[video_id] = [gt_answer]
+                 minigpt4_predict.append(result)
+             # save results every 100 videos to avoid losing results
+             if c % 100 == 0:
+                 with open(save_path, 'w') as f:
+                     json.dump(minigpt4_predict, f)
+         if c >= args.end:
+             break
+         c += 1
+ else:
+     for images, texts, gt_answers, lengths, videos_ids in tqdm(eval_dataloader, desc=f"Eval {args.dataset}"):
+         if args.start <= c < args.end:
+             texts = prepare_texts(texts, conv_temp, template='', lengths=lengths)  # warp the texts with conversation template
+             models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths, num_beams=1)
+             for video_id, model_answer, gt_answer, text in zip(videos_ids, models_answers, gt_answers, texts):
+                 result = dict()
+                 result['video_name'] = video_id
+                 result['Q'] = text.split('\n')[-1].replace('[/INST]','')
+                 result['A'] = gt_answer
+                 result['pred'] = model_answer
+                 pred_result[video_id] = [model_answer]
+                 gt_result[video_id] = [gt_answer]
+                 minigpt4_predict.append(result)
+             # save results every 100 videos to avoid losing results
+             if c % 100 == 0:
+                 with open(save_path, 'w') as f:
+                     json.dump(minigpt4_predict, f)
+         if c >= args.end:
+             break
+         c += 1
+
+ with open(save_path, 'w') as f:
+     json.dump(minigpt4_predict, f)
+ print("saved results to", save_path)
+ # save results
+ # bleu_save_path = f'results/{args.name}_{args.dataset}_bleu.json'
+ # cider_save_path = f'results/{args.name}_{args.dataset}_cider.json'
+ # chatgpt_eval_save_path = f'results/{args.name}_{args.dataset}_chatgpt_eval.json'
+ # bleu_results=eval_bleu(minigpt4_predict)
+ # with open(bleu_save_path, 'w') as f:
+ #     json.dump(bleu_results, f)
+ # print("bleu_results",bleu_results)
+ # cider_results=eval_cider(pred_result,gt_result)
+ # with open(cider_save_path, 'w') as f:
+ #     json.dump(cider_results, f)
+ # print("mean_cider_scores:",cider_results['mean_cider_scores'])
+
+ # chatgpt_results=chat_gpt_eval(pred_result,gt_result)
+
+ # with open(chatgpt_eval_save_path, 'w') as f:
+ #     json.dump(chatgpt_results, f)
+ # print("avg_chatgpt_score",chatgpt_results['avg_chatgpt_score'])
+ # print(chatgpt_results)
+
+
jobs_video/eval/choose_best_ckpt/choose_best_ckpt.py ADDED
@@ -0,0 +1,14 @@
+ import os
+ import shutil
+ ckpt_dir = 'ckpt_dir'
+ print(f'number of ckpts: {len(os.listdir(ckpt_dir))}')
+ for ckpt in sorted(os.listdir(ckpt_dir)):
+     if not ckpt.endswith('.pth'):
+         continue
+     ckpt_path = os.path.join(ckpt_dir, ckpt)
+     job_name = "cmd_webvid_video_instruct_" + ckpt.split(".")[0]
+     # submit a job with this ckpt file
+     os.system(f'sbatch ./evalualtion_ckpt.sh {ckpt_path} {job_name}')
+     # print(f'sbatch ./evalualtion_ckpt.sh {ckpt_path} {job_name}')
+     # print(f'job {job_name} submitted')
+     # break
jobs_video/eval/choose_best_ckpt/evalualtion_ckpt.sh ADDED
@@ -0,0 +1,17 @@
+ #!/bin/bash
+ #SBATCH --partition=batch
+ #SBATCH --job-name=val%j
+ #SBATCH --output=val%j.out
+ #SBATCH --error=val%j.err
+ #SBATCH --time=0-10:00:00
+ #SBATCH --mem=100G
+ #SBATCH --gres=gpu:a100:1
+ #SBATCH --nodes=1
+ ## run the application:
+ NAME=$2 # Name of the experiment
+ DATASET="dataset_name" # available datasets: tvqa, msrvtt, msvd, activitynet, tgif, video_chatgpt_generic, video_chatgpt_temporal, video_chatgpt_consistency
+ BATCH_SIZE=2 # batch size
+ CKPT_PATH=$1 # path to the checkpoint
+ cfg_path="test_configs/mistral_test_config.yaml" # path to the config file
+ cd ../../../
+ python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles
jobs_video/eval/llama2_evalualtion.sh ADDED
@@ -0,0 +1,37 @@
+ #!/bin/bash
+ #SBATCH --partition=batch
+ #SBATCH --job-name=llama2_best%j
+ #SBATCH --output=llama2_best%j.out
+ #SBATCH --error=llama2_best%j.err
+ #SBATCH --time=0-23:00:00
+ #SBATCH --mem=100G
+ #SBATCH --gres=gpu:a100:1
+ #SBATCH --nodes=1
+ ## run the application:
+ NAME="llama2_best" # Name of the experiment
+ DATASET="tvqa" # available datasets: tvqa, msrvtt, msvd, activitynet, tgif, video_chatgpt_generic, video_chatgpt_temporal, video_chatgpt_consistency
+ BATCH_SIZE=8
+ CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth" # path to the checkpoint
+ cfg_path="test_configs/llama2_test_config.yaml" # path to the config file
+ # if the number of samples is large, you can specify the start and end index to evaluate on several machines
+ # pass the start and end index as arguments
+ start=$1 # start index
+ end=$2 # end index
+ # if start and end are not provided, then use the whole dataset
+ if [ -z "$start" ]
+ then
+     start=0
+ fi
+ if [ -z "$end" ]
+ then
+     end=10000000
+ fi
+ echo "Start: $start"
+ echo "End: $end"
+
+ cd ../../
+ # without subtitles
+ python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --start $start --end $end
+
+ # with subtitles
+ # python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles --start $start --end $end
jobs_video/eval/mistral_evalualtion.sh ADDED
@@ -0,0 +1,39 @@
+ #!/bin/bash
+ #SBATCH --partition=batch
+ #SBATCH --mail-user=kirolos.ataallah@kaust.edu.sa
+ #SBATCH --mail-type=ALL
+ #SBATCH --job-name=mistral_best%j
+ #SBATCH --output=mistral_best%j.out
+ #SBATCH --error=mistral_best%j.err
+ #SBATCH --time=0-23:00:00
+ #SBATCH --mem=100G
+ #SBATCH --gres=gpu:a100:1
+ #SBATCH --nodes=1
+ ## run the application:
+ NAME="mistral_best" # Name of the experiment
+ DATASET="tvqa" # available datasets: tvqa, msrvtt, msvd, activitynet, tgif, video_chatgpt_generic, video_chatgpt_temporal, video_chatgpt_consistency
+ BATCH_SIZE=4 # batch size on A100: 2 with subtitles, 4 without subtitles
+ CKPT_PATH="checkpoints/video_mistral_checkpoint_best.pth" # path to the checkpoint
+ cfg_path="test_configs/mistral_test_config.yaml" # path to the config file
+ # if the number of samples is large, you can specify the start and end index to evaluate on several machines
+ # pass the start and end index as arguments
+ start=$1 # start index
+ end=$2 # end index
+ # if start and end are not provided, then use the whole dataset
+ if [ -z "$start" ]
+ then
+     start=0
+ fi
+ if [ -z "$end" ]
+ then
+     end=10000000
+ fi
+ echo "Start: $start"
+ echo "End: $end"
+
+ cd ../../
+ # without subtitles
+ python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --start $start --end $end
+
+ # with subtitles
+ # python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles --start $start --end $end
jobs_video/eval/submit_job.py ADDED
@@ -0,0 +1,19 @@
+ import os
+ import shutil
+ import sys
+
+ start = 0
+ end = 7800
+ step = 800
+
+ # Mistral
+ for i in range(start, end, step):
+     cmd = f'sbatch ./mistral_evalualtion.sh {i} {i+step}'
+     # print(cmd)
+     os.system(cmd)
+
+ # Llama 2
+ # for i in range(start, end, step):
+ #     cmd = f'sbatch ./llama2_evalualtion.sh {i} {i+step}'
+ #     # print(cmd)
+ #     os.system(cmd)
jobs_video/train/stage_2_llama2.sh ADDED
@@ -0,0 +1,23 @@
+ #!/bin/bash
+ #SBATCH --partition=batch
+ #SBATCH --job-name=test
+ #SBATCH --output=test.out
+ #SBATCH --error=test.err
+ #SBATCH --time=23:00:00
+ #SBATCH --mem=110G
+ #SBATCH --gres=gpu:a100:4
+ #SBATCH --cpus-per-task=16
+ ## run the application:
+ job_name=test # Name of the experiment
+ cfg_path="train_configs/224_v2_llama2_video_stage_2.yaml" # path to the config file
+ number_of_gpus=1 # number of gpus
+ # cd ../../
+
+ read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
+ while :
+ do
+     PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
+     ss -lpn | grep -q ":$PORT " || break
+ done
+ echo "Port is $PORT"
+ torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
jobs_video/train/stage_2_mistral.sh ADDED
@@ -0,0 +1,23 @@
+ #!/bin/bash
+ #SBATCH --partition=batch
+ #SBATCH --job-name=test
+ #SBATCH --output=test.out
+ #SBATCH --error=test.err
+ #SBATCH --time=23:00:00
+ #SBATCH --mem=110G
+ #SBATCH --gres=gpu:a100:4
+ #SBATCH --cpus-per-task=16
+ ## run the application:
+ job_name=test # Name of the experiment
+ cfg_path="train_configs/224_v2_mistral_video_stage_2.yaml" # path to the config file
+ number_of_gpus=1 # number of gpus
+ # cd ../../
+
+ read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
+ while :
+ do
+     PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
+     ss -lpn | grep -q ":$PORT " || break
+ done
+ echo "Port is $PORT"
+ torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
jobs_video/train/stage_3_llama2.sh ADDED
@@ -0,0 +1,23 @@
+ #!/bin/bash
+ #SBATCH --partition=batch
+ #SBATCH --job-name=test
+ #SBATCH --output=test.out
+ #SBATCH --error=test.err
+ #SBATCH --time=23:00:00
+ #SBATCH --mem=110G
+ #SBATCH --gres=gpu:a100:4
+ #SBATCH --cpus-per-task=16
+ ## run the application:
+ job_name="test" # Name of the experiment
+ cfg_path="train_configs/224_v2_llama2_video_stage_3.yaml" # path to the config file
+ number_of_gpus=1 # number of gpus
+ # cd ../../
+
+ read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
+ while :
+ do
+     PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
+     ss -lpn | grep -q ":$PORT " || break
+ done
+ echo "Port is $PORT"
+ torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
jobs_video/train/stage_3_mistral.sh ADDED
@@ -0,0 +1,23 @@
+ #!/bin/bash
+ #SBATCH --partition=batch
+ #SBATCH --job-name=test
+ #SBATCH --output=test.out
+ #SBATCH --error=test.err
+ #SBATCH --time=23:00:00
+ #SBATCH --mem=110G
+ #SBATCH --gres=gpu:a100:4
+ #SBATCH --cpus-per-task=16
+ ## run the application:
+ job_name="test" # Name of the experiment
+ cfg_path="train_configs/224_v2_mistral_video_stage_3.yaml" # path to the config file
+ number_of_gpus=1 # number of gpus
+ # cd ../../
+
+ read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
+ while :
+ do
+     PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
+     ss -lpn | grep -q ":$PORT " || break
+ done
+ echo "Port is $PORT"
+ torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
minigpt4/__init__.py ADDED
@@ -0,0 +1,31 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import os
+ import sys
+
+ from omegaconf import OmegaConf
+
+ from minigpt4.common.registry import registry
+
+ from minigpt4.datasets.builders import *
+ from minigpt4.models import *
+ from minigpt4.processors import *
+ from minigpt4.tasks import *
+
+
+ root_dir = os.path.dirname(os.path.abspath(__file__))
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
+
+ registry.register_path("library_root", root_dir)
+ repo_root = os.path.join(root_dir, "..")
+ registry.register_path("repo_root", repo_root)
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
+ registry.register_path("cache_root", cache_root)
+
+ registry.register("MAX_INT", sys.maxsize)
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
minigpt4/common/__init__.py ADDED
File without changes
minigpt4/common/config.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from minigpt4.common.registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ print("--------------")
72
+ print("model arch",model.arch)
73
+ print("model cls",model_cls)
74
+
75
+ model_config_path = model_cls.default_config_path(model_type=model_type)
76
+
77
+ model_config = OmegaConf.create()
78
+ # hierarchy override, customized config > default config
79
+ model_config = OmegaConf.merge(
80
+ model_config,
81
+ OmegaConf.load(model_config_path),
82
+ {"model": config["model"]},
83
+ )
84
+
85
+ return model_config
86
+
87
+ @staticmethod
88
+ def build_runner_config(config):
89
+ return {"run": config.run}
90
+
91
+ @staticmethod
92
+ def build_dataset_config(config):
93
+ datasets = config.get("datasets", None)
94
+ if datasets is None:
95
+ raise KeyError(
96
+ "Expecting 'datasets' as the root key for dataset configuration."
97
+ )
98
+
99
+ dataset_config = OmegaConf.create()
100
+
101
+ for dataset_name in datasets:
102
+
103
+ print("dataset name", dataset_name)
104
+ builder_cls = registry.get_builder_class(dataset_name)
105
+
106
+ dataset_config_type = datasets[dataset_name].get("type", "default")
107
+ dataset_config_path = builder_cls.default_config_path(
108
+ type=dataset_config_type
109
+ )
110
+
111
+ # hierarchy override, customized config > default config
112
+ dataset_config = OmegaConf.merge(
113
+ dataset_config,
114
+ OmegaConf.load(dataset_config_path),
115
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
116
+ )
117
+
118
+ return dataset_config
119
+
120
+ def _convert_to_dot_list(self, opts):
121
+ if opts is None:
122
+ opts = []
123
+
124
+ if len(opts) == 0:
125
+ return opts
126
+
127
+ has_equal = opts[0].find("=") != -1
128
+
129
+ if has_equal:
130
+ return opts
131
+
132
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
133
+
134
+ def get_config(self):
135
+ return self.config
136
+
137
+ @property
138
+ def run_cfg(self):
139
+ return self.config.run
140
+
141
+ @property
142
+ def datasets_cfg(self):
143
+ return self.config.datasets
144
+
145
+ @property
146
+ def model_cfg(self):
147
+ return self.config.model
148
+
149
+ def pretty_print(self):
150
+ logging.info("\n===== Running Parameters =====")
151
+ logging.info(self._convert_node_to_json(self.config.run))
152
+
153
+ logging.info("\n====== Dataset Attributes ======")
154
+ datasets = self.config.datasets
155
+
156
+ for dataset in datasets:
157
+ if dataset in self.config.datasets:
158
+ logging.info(f"\n======== {dataset} =======")
159
+ dataset_config = self.config.datasets[dataset]
160
+ logging.info(self._convert_node_to_json(dataset_config))
161
+ else:
162
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
163
+
164
+ logging.info(f"\n====== Model Attributes ======")
165
+ logging.info(self._convert_node_to_json(self.config.model))
166
+
167
+ def _convert_node_to_json(self, node):
168
+ container = OmegaConf.to_container(node, resolve=True)
169
+ return json.dumps(container, indent=4, sort_keys=True)
170
+
171
+ def to_dict(self):
172
+ return OmegaConf.to_container(self.config)
173
+
174
+
175
+ def node_to_dict(node):
176
+ return OmegaConf.to_container(node)
177
+
178
+
179
+ class ConfigValidator:
180
+ """
181
+ This is a preliminary implementation to centralize and validate the configuration.
182
+ May be altered in the future.
183
+
184
+ A helper class to validate configurations from yaml file.
185
+
186
+ This serves the following purposes:
187
+ 1. Ensure all the options in the yaml are defined, raise error if not.
188
+ 2. when type mismatches are found, the validator will raise an error.
189
+ 3. a central place to store and display helpful messages for supported configurations.
190
+
191
+ """
192
+
193
+ class _Argument:
194
+ def __init__(self, name, choices=None, type=None, help=None):
195
+ self.name = name
196
+ self.val = None
197
+ self.choices = choices
198
+ self.type = type
199
+ self.help = help
200
+
201
+ def __str__(self):
202
+ s = f"{self.name}={self.val}"
203
+ if self.type is not None:
204
+ s += f", ({self.type})"
205
+ if self.choices is not None:
206
+ s += f", choices: {self.choices}"
207
+ if self.help is not None:
208
+ s += f", ({self.help})"
209
+ return s
210
+
211
+ def __init__(self, description):
212
+ self.description = description
213
+
214
+ self.arguments = dict()
215
+
216
+ self.parsed_args = None
217
+
218
+ def __getitem__(self, key):
219
+ assert self.parsed_args is not None, "No arguments parsed yet."
220
+
221
+ return self.parsed_args[key]
222
+
223
+ def __str__(self) -> str:
224
+ return self.format_help()
225
+
226
+ def add_argument(self, *args, **kwargs):
227
+ """
228
+ Assume the first argument is the name of the argument.
229
+ """
230
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
231
+
232
+ def validate(self, config=None):
233
+ """
234
+ Convert yaml config (dict-like) to list, required by argparse.
235
+ """
236
+ for k, v in config.items():
237
+ assert (
238
+ k in self.arguments
239
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
240
+
241
+ if self.arguments[k].type is not None:
242
+ try:
243
+ self.arguments[k].val = self.arguments[k].type(v)
244
+ except ValueError:
245
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
246
+
247
+ if self.arguments[k].choices is not None:
248
+ assert (
249
+ v in self.arguments[k].choices
250
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
251
+
252
+ return config
253
+
254
+ def format_arguments(self):
255
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
256
+
257
+ def format_help(self):
258
+ # description + key-value pair string for each argument
259
+ help_msg = str(self.description)
260
+ return help_msg + ", available arguments: " + self.format_arguments()
261
+
262
+ def print_help(self):
263
+ # display help message
264
+ print(self.format_help())
265
+
266
+
267
+ def create_runner_config_validator():
268
+ validator = ConfigValidator(description="Runner configurations")
269
+
270
+ validator.add_argument(
271
+ "runner",
272
+ type=str,
273
+ choices=["runner_base", "runner_iter"],
274
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
275
+ runner runs based on iters. Default: runner_base""",
276
+ )
277
+ # add argumetns for training dataset ratios
278
+ validator.add_argument(
279
+ "train_dataset_ratios",
280
+ type=Dict[str, float],
281
+ help="""Ratios of training dataset. This is used in iteration-based runner.
282
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
283
+ Default: None""",
284
+ )
285
+ validator.add_argument(
286
+ "max_iters",
287
+ type=float,
288
+ help="Maximum number of iterations to run.",
289
+ )
290
+ validator.add_argument(
291
+ "max_epoch",
292
+ type=int,
293
+ help="Maximum number of epochs to run.",
294
+ )
295
+ # add arguments for iters_per_inner_epoch
296
+ validator.add_argument(
297
+ "iters_per_inner_epoch",
298
+ type=float,
299
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
300
+ )
301
+ lr_scheds_choices = registry.list_lr_schedulers()
302
+ validator.add_argument(
303
+ "lr_sched",
304
+ type=str,
305
+ choices=lr_scheds_choices,
306
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
307
+ )
308
+ task_choices = registry.list_tasks()
309
+ validator.add_argument(
310
+ "task",
311
+ type=str,
312
+ choices=task_choices,
313
+ help="Task to use, from {}".format(task_choices),
314
+ )
315
+ # add arguments for init_lr
316
+ validator.add_argument(
317
+ "init_lr",
318
+ type=float,
319
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
320
+ )
321
+ # add arguments for min_lr
322
+ validator.add_argument(
323
+ "min_lr",
324
+ type=float,
325
+ help="Minimum learning rate (after decay).",
326
+ )
327
+ # add arguments for warmup_lr
328
+ validator.add_argument(
329
+ "warmup_lr",
330
+ type=float,
331
+ help="Starting learning rate for warmup.",
332
+ )
333
+ # add arguments for learning rate decay rate
334
+ validator.add_argument(
335
+ "lr_decay_rate",
336
+ type=float,
337
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
338
+ )
339
+ # add arguments for weight decay
340
+ validator.add_argument(
341
+ "weight_decay",
342
+ type=float,
343
+ help="Weight decay rate.",
344
+ )
345
+ # add arguments for training batch size
346
+ validator.add_argument(
347
+ "batch_size_train",
348
+ type=int,
349
+ help="Training batch size.",
350
+ )
351
+ # add arguments for evaluation batch size
352
+ validator.add_argument(
353
+ "batch_size_eval",
354
+ type=int,
355
+ help="Evaluation batch size, including validation and testing.",
356
+ )
357
+ # add arguments for number of workers for data loading
358
+ validator.add_argument(
359
+ "num_workers",
360
+ help="Number of workers for data loading.",
361
+ )
362
+ # add arguments for warm up steps
363
+ validator.add_argument(
364
+ "warmup_steps",
365
+ type=int,
366
+ help="Number of warmup steps. Required if a warmup schedule is used.",
367
+ )
368
+ # add arguments for random seed
369
+ validator.add_argument(
370
+ "seed",
371
+ type=int,
372
+ help="Random seed.",
373
+ )
374
+ # add arguments for output directory
375
+ validator.add_argument(
376
+ "output_dir",
377
+ type=str,
378
+ help="Output directory to save checkpoints and logs.",
379
+ )
380
+ # add argument for whether to only run evaluation
381
+ validator.add_argument(
382
+ "evaluate",
383
+ help="Whether to only evaluate the model. If true, training will not be performed.",
384
+ )
385
+ # add arguments for splits used for training, e.g. ["train", "val"]
386
+ validator.add_argument(
387
+ "train_splits",
388
+ type=list,
389
+ help="Splits to use for training.",
390
+ )
391
+ # add arguments for splits used for validation, e.g. ["val"]
392
+ validator.add_argument(
393
+ "valid_splits",
394
+ type=list,
395
+ help="Splits to use for validation. If not provided, will skip the validation.",
396
+ )
397
+ # add arguments for splits used for testing, e.g. ["test"]
398
+ validator.add_argument(
399
+ "test_splits",
400
+ type=list,
401
+ help="Splits to use for testing. If not provided, will skip the testing.",
402
+ )
403
+ # add arguments for accumulating gradient for iterations
404
+ validator.add_argument(
405
+ "accum_grad_iters",
406
+ type=int,
407
+ help="Number of iterations to accumulate gradient for.",
408
+ )
409
+
410
+ # ====== distributed training ======
411
+ validator.add_argument(
412
+ "device",
413
+ type=str,
414
+ choices=["cpu", "cuda"],
415
+ help="Device to use. Supports 'cuda' or 'cpu' for now.",
416
+ )
417
+ validator.add_argument(
418
+ "world_size",
419
+ type=int,
420
+ help="Number of processes participating in the job.",
421
+ )
422
+ validator.add_argument("dist_url", type=str)
423
+ validator.add_argument("distributed", type=bool)
424
+ # add argument for whether to use a distributed sampler during evaluation
425
+ validator.add_argument(
426
+ "use_dist_eval_sampler",
427
+ type=bool,
428
+ help="Whether to use distributed sampler during evaluation or not.",
429
+ )
430
+
431
+ # ====== task specific ======
432
+ # generation task specific arguments
433
+ # add arguments for maximal length of text output
434
+ validator.add_argument(
435
+ "max_len",
436
+ type=int,
437
+ help="Maximal length of text output.",
438
+ )
439
+ # add arguments for minimal length of text output
440
+ validator.add_argument(
441
+ "min_len",
442
+ type=int,
443
+ help="Minimal length of text output.",
444
+ )
445
+ # add arguments for number of beams
446
+ validator.add_argument(
447
+ "num_beams",
448
+ type=int,
449
+ help="Number of beams used for beam search.",
450
+ )
451
+
452
+ # vqa task specific arguments
453
+ # add arguments for number of answer candidates
454
+ validator.add_argument(
455
+ "num_ans_candidates",
456
+ type=int,
457
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
458
+ )
459
+ # add arguments for inference method
460
+ validator.add_argument(
461
+ "inference_method",
462
+ type=str,
463
+ choices=["generate", "rank"],
464
+ help="""Inference method to use for question answering. If "rank", an answer list is required.""",
465
+ )
466
+
467
+ # ====== model specific ======
468
+ validator.add_argument(
469
+ "k_test",
470
+ type=int,
471
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
472
+ )
473
+
474
+ return validator
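For orientation, a minimal sketch of how the validator built above might be exercised; the run-config keys are examples drawn from the arguments registered in `create_runner_config_validator`, and no real project config is implied:

# Illustrative only: validate a hand-written run config against the schema above.
run_cfg = {
    "runner": "runner_base",
    "max_epoch": 10,
    "init_lr": 1e-4,
    "min_lr": 1e-5,
    "batch_size_train": 4,
}
validator = create_runner_config_validator()
validator.validate(run_cfg)  # raises AssertionError on unknown keys or invalid choices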
minigpt4/common/dist_utils.py ADDED
@@ -0,0 +1,146 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if args.distributed is False:
59
+ print("Not using distributed mode")
60
+ args.rank = 0
61
+ return
62
+
63
+ if 'LOCAL_RANK' not in os.environ:
64
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
65
+
66
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
67
+ args.rank = int(os.environ["RANK"])
68
+ args.world_size = int(os.environ["WORLD_SIZE"])
69
+ args.gpu = int(os.environ["LOCAL_RANK"])
70
+ elif "SLURM_PROCID" in os.environ:
71
+ args.rank = int(os.environ["SLURM_PROCID"])
72
+ args.gpu = args.rank % torch.cuda.device_count()
73
+ else:
74
+ print("Not using distributed mode")
75
+ args.distributed = False
76
+ args.rank = 0
77
+ return
78
+
79
+ args.distributed = True
80
+
81
+ torch.cuda.set_device(args.gpu)
82
+ args.dist_backend = "nccl"
83
+ print(
84
+ "| distributed init (rank {}, world {}): {}".format(
85
+ args.rank, args.world_size, args.dist_url
86
+ ),
87
+ flush=True,
88
+ )
89
+ torch.distributed.init_process_group(
90
+ backend=args.dist_backend,
91
+ init_method=args.dist_url,
92
+ world_size=args.world_size,
93
+ rank=args.rank,
94
+ timeout=datetime.timedelta(
95
+ days=365
96
+ ), # allow auto-downloading and de-compressing
97
+ )
98
+ torch.distributed.barrier()
99
+ setup_for_distributed(args.rank == 0)
100
+
101
+
102
+ def get_dist_info():
103
+ if torch.__version__ < "1.0":
104
+ initialized = dist._initialized
105
+ else:
106
+ initialized = dist.is_initialized()
107
+ if initialized:
108
+ rank = dist.get_rank()
109
+ world_size = dist.get_world_size()
110
+ else: # non-distributed training
111
+ rank = 0
112
+ world_size = 1
113
+ return rank, world_size
114
+
115
+
116
+ def main_process(func):
117
+ @functools.wraps(func)
118
+ def wrapper(*args, **kwargs):
119
+ rank, _ = get_dist_info()
120
+ if rank == 0:
121
+ return func(*args, **kwargs)
122
+
123
+ return wrapper
124
+
125
+
126
+ def download_cached_file(url, check_hash=True, progress=False):
127
+ """
128
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
129
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
130
+ """
131
+
132
+ def get_cached_file_path():
133
+ # a hack to sync the file path across processes
134
+ parts = torch.hub.urlparse(url)
135
+ filename = os.path.basename(parts.path)
136
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
137
+
138
+ return cached_file
139
+
140
+ if is_main_process():
141
+ timm_hub.download_cached_file(url, check_hash, progress)
142
+
143
+ if is_dist_avail_and_initialized():
144
+ dist.barrier()
145
+
146
+ return get_cached_file_path()
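As a rough usage sketch (the checkpoint-saving function and its arguments below are illustrative, not part of the file):

# Illustrative only: restrict side effects such as checkpoint saving to rank 0.
@main_process
def save_checkpoint(state, path):
    torch.save(state, path)

if is_dist_avail_and_initialized():
    print(f"rank {get_rank()} / world size {get_world_size()}")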
minigpt4/common/eval_utils.py ADDED
@@ -0,0 +1,224 @@
1
+ import argparse
2
+ import numpy as np
3
+ from nltk.translate.bleu_score import sentence_bleu
4
+ import sys
5
+ sys.path.append('/home/ataallka/minigpt_video/minigpt_multi_img')
6
+ from minigpt4.common.registry import registry
7
+ from minigpt4.common.config import Config
8
+
9
+ # imports modules for registration
10
+ from minigpt4.datasets.builders import *
11
+ from minigpt4.models import *
12
+ from minigpt4.processors import *
13
+ # from minigpt4.runners import *
14
+ from minigpt4.tasks import *
15
+ from pycocoevalcap.cider.cider import Cider
16
+ import os
17
+ import openai
18
+ from tqdm import tqdm
19
+ import json
20
+ import ast
21
+ import time
22
+
23
+ def eval_parser():
24
+ parser = argparse.ArgumentParser(description="Demo")
25
+ parser.add_argument("--cfg-path", help="path to configuration file.",default="test_configs/llama2_test_config.yaml")
26
+ parser.add_argument("--ckpt", type=str,default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint")
27
+ parser.add_argument("--eval_opt", type=str, default='all', help="evaluation option to run.")
28
+ parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
29
+ parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
30
+ parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
31
+ parser.add_argument(
32
+ "--options",
33
+ nargs="+",
34
+ help="override some settings in the used config, the key-value pair "
35
+ "in xxx=yyy format will be merged into config file (deprecate), "
36
+ "change to --cfg-options instead.",
37
+ )
38
+ return parser
39
+
40
+
41
+ def prepare_texts(texts, conv_temp, template='<Img><ImageHere></Img>', lengths=None):
42
+ convs = [conv_temp.copy() for _ in range(len(texts))]
43
+ if lengths is None:
44
+ [conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for conv, text in zip(convs, texts)]
45
+ else:
46
+ templates = [template * length for length in lengths]
47
+ [conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for template, conv, text in zip(templates, convs, texts)]
48
+ [conv.append_message(conv.roles[1], None) for conv in convs]
49
+ texts = [conv.get_prompt() for conv in convs]
50
+ return texts
51
+
52
+
53
+ def init_model(args):
54
+ print('Initializing model')
55
+ cfg = Config(args)
56
+ cfg.model_cfg.ckpt = args.ckpt
57
+ cfg.model_cfg.lora_r = args.lora_r
58
+ cfg.model_cfg.lora_alpha = args.lora_alpha
59
+
60
+ model_config = cfg.model_cfg
61
+ model_config.low_resource = True
62
+ model_cls = registry.get_model_class(model_config.arch)
63
+ model = model_cls.from_config(model_config).to('cuda:0')
64
+
65
+ # import pudb; pudb.set_trace()
66
+ key = list(cfg.datasets_cfg.keys())[0]
67
+ vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
68
+ print(vis_processor_cfg)
69
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
70
+ print('Initialization Finished')
71
+ return model, vis_processor
72
+
73
+ def computeIoU(bbox1, bbox2):
74
+ x1, y1, x2, y2 = bbox1
75
+ x3, y3, x4, y4 = bbox2
76
+ intersection_x1 = max(x1, x3)
77
+ intersection_y1 = max(y1, y3)
78
+ intersection_x2 = min(x2, x4)
79
+ intersection_y2 = min(y2, y4)
80
+ intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
81
+ bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
82
+ bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
83
+ union_area = bbox1_area + bbox2_area - intersection_area
84
+ iou = intersection_area / union_area
85
+ return iou
86
+
87
+ def eval_bleu(results):
88
+ bleus1,bleus2,bleus3,bleus4 = [],[],[],[]
89
+ for result in tqdm (results,desc="bleu_eval"):
90
+ gt = result['gt']
91
+ pred = result['pred']
92
+ bleus1.append(sentence_bleu([gt.split()], pred.split(), weights=(1,0,0,0)))
93
+ bleus2.append(sentence_bleu([gt.split()], pred.split(), weights=(0.5,0.5,0,0)))
94
+ bleus3.append(sentence_bleu([gt.split()], pred.split(), weights=(0.33,0.33,0.33,0)))
95
+ bleus4.append(sentence_bleu([gt.split()], pred.split()))
96
+ # print(np.mean(bleus1),np.mean(bleus2),np.mean(bleus3),np.mean(bleus4),flush=True)
97
+ return {'bleu1':np.mean(bleus1),'bleu2':np.mean(bleus2),'bleu3':np.mean(bleus3),'bleu4':np.mean(bleus4)}
98
+
99
+ # Create a Cider object
100
+ cider_scorer = Cider()
101
+ def eval_cider(pred_result,gt_result):
102
+ # Compute CIDEr scores
103
+ mean_cider_scores, cider_scores = cider_scorer.compute_score(gt_result, pred_result)
104
+ cider_scores_dict={}
105
+ for score,pred_vid_id,gt_vid_id in tqdm(zip(cider_scores.tolist(),pred_result,gt_result),desc="cider_eval") :
106
+ assert pred_vid_id==gt_vid_id
107
+ cider_scores_dict[pred_vid_id] = score
108
+ return {'mean_cider_scores':mean_cider_scores,'cider_scores':cider_scores_dict}
109
+
110
+
111
+ openai.api_key_path = "/home/ataallka/chatgpt_api.txt"
112
+
113
+
114
+ def chat_gpt_eval(results,output_path):
115
+ trial=0
116
+ gpt_results=[]
117
+ avg_chatgpt_score=0
118
+ existed_files={}
119
+ # read previous results from output path
120
+ for file in os.listdir(output_path):
121
+ if file.endswith(".json"):
122
+ with open(f'{output_path}/{file}') as json_file:
123
+ data = json.load(json_file)
124
+ gpt_results.append(data[0])
125
+ avg_chatgpt_score+=float(data[0]['chatgpt_score'])
126
+ existed_files[data[0]['video_name']]=True
127
+ length_output_path=len(os.listdir(output_path))
128
+ while len (results)!= length_output_path:
129
+ for res in tqdm(results,desc="chatgpt_eval"):
130
+ if existed_files.get(res['video_name'],False):
131
+ continue
132
+ video_name=res['video_name']
133
+ sentence_1=res['A']
134
+ sentence_2=res['pred']
135
+ try:
136
+ # prompt=f"given these 2 sentences the first one is the ground truth text and the second sentence is the generated text ,give me a score from 0 to 1 to evaluate how much they are similar to each other, and have the same context and related to each other to evaluate the quality of this generated text.the output should be only the score float number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:"
137
+ prompt=f"Given these 2 sentences, where the first one is the ground truth description of a video and the second sentence is the generated text from a video summarization model, give it a score from 0 to 5 to evaluate the model's summarization performance. The output should be only the score number without any additional information.\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:"
138
+ response = openai.ChatCompletion.create(
139
+ model="gpt-3.5-turbo",
140
+ messages=[
141
+ {
142
+ "role": "user",
143
+ "content": prompt
144
+ }],
145
+ )
146
+ res['chatgpt_score']=response.choices[0].message['content']
147
+ out={'video_name':video_name,'chatgpt_score':response.choices[0].message['content']}
148
+ gpt_results.append(out)
149
+ # save each video result in a json file
150
+ with open(f'{output_path}/{video_name}.json', 'w') as f:
151
+ json.dump([out], f)
152
+ avg_chatgpt_score+=float(response.choices[0].message['content'])
153
+ except Exception as e:
154
+ print("chat gpt error",e)
155
+ print ("Finished chat gpt evaluation in trial",trial)
156
+ trial+=1
157
+ length_output_path=len(os.listdir(output_path))
158
+ return results,avg_chatgpt_score/len(results)
159
+ def GPT4_answer(question, answer,pred):
160
+ try:
161
+ # Compute the correctness score
162
+ completion = openai.ChatCompletion.create(
163
+ # model="gpt-3.5-turbo",
164
+ model='gpt-4',
165
+ messages=[
166
+ {
167
+ "role": "system",
168
+ "content":
169
+ "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
170
+ "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
171
+ "------"
172
+ "##INSTRUCTIONS: "
173
+ "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
174
+ "- Consider synonyms or paraphrases as valid matches.\n"
175
+ "- Evaluate the correctness of the prediction compared to the answer."
176
+ },
177
+ {
178
+ "role": "user",
179
+ "content":
180
+ "Please evaluate the following video-based question-answer pair:\n\n"
181
+ f"Question: {question}\n"
182
+ f"Correct Answer: {answer}\n"
183
+ f"Predicted Answer: {pred}\n\n"
184
+ "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
185
+ "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
186
+ "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
187
+ "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
188
+ }
189
+ ]
190
+ )
191
+ # Convert response to a Python dictionary.
192
+ response_message = completion["choices"][0]["message"]["content"]
193
+ response_dict = ast.literal_eval(response_message)
194
+ return response_dict
195
+ except Exception as e:
196
+ print(f"Error : {e}")
197
+ return None
198
+ def GPT4_evaluation(val_result):
199
+ scores=[]
200
+ yes_count=0
201
+ no_count=0
202
+ for res in val_result:
203
+ gpt_response=GPT4_answer(res['Q'],res['A'],res['pred'])
204
+ if gpt_response is None:
205
+ continue
206
+ try:
207
+ scores.append(float(gpt_response['score']))
208
+ if 'yes' in gpt_response['pred'].lower():
209
+ yes_count+=1
210
+ elif 'no' in gpt_response['pred'].lower():
211
+ no_count+=1
212
+ except:
213
+ continue
214
+ avg_score=sum(scores)/len(scores)
215
+ accuracy=(yes_count/(yes_count+no_count))*100
216
+ print(f"chatgpt score: {avg_score} accuracy: {accuracy}")
217
+ return avg_score,accuracy
218
+
219
+ # with open('results/ckpt_15_res89_res32_Video_validation_Dataset_subtitles.json','r') as f:
220
+ # results = json.load(f)
221
+ # t1=time.time()
222
+ # avg_score,accuracy=GPT4_evaluation(results)
223
+ # print(f"chatgpt score: {avg_score} accuracy: {accuracy}")
224
+ # print(f"Time taken: {time.time()-t1}")
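For reference, a minimal sketch of the input shapes the metric helpers above expect; the captions and video id are made up:

# Illustrative only: eval_bleu takes a list of {'gt': ..., 'pred': ...} dicts.
results = [{"gt": "a man rides a bike", "pred": "a person is riding a bicycle"}]
print(eval_bleu(results))

# eval_cider takes dicts mapping video ids to lists of caption strings.
pred = {"vid_1": ["a person is riding a bicycle"]}
gt = {"vid_1": ["a man rides a bike"]}
print(eval_cider(pred, gt))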
minigpt4/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
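A short sketch of how `getAttMap` might be called; the image and attention arrays are placeholders:

# Illustrative only: overlay an attention map on an RGB image in [0, 1].
img = np.random.rand(224, 224, 3)   # placeholder image, H x W x 3
att = np.random.rand(14, 14)        # placeholder attention map from a model
overlay = getAttMap(img, att, blur=True, overlap=True)
plt.imshow(overlay); plt.axis("off"); plt.show()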
minigpt4/common/logger.py ADDED
@@ -0,0 +1,195 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from minigpt4.common import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ assert isinstance(v, (float, int))
92
+ self.meters[k].update(v)
93
+
94
+ def __getattr__(self, attr):
95
+ if attr in self.meters:
96
+ return self.meters[attr]
97
+ if attr in self.__dict__:
98
+ return self.__dict__[attr]
99
+ raise AttributeError(
100
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
+ )
102
+
103
+ def __str__(self):
104
+ loss_str = []
105
+ for name, meter in self.meters.items():
106
+ loss_str.append("{}: {}".format(name, str(meter)))
107
+ return self.delimiter.join(loss_str)
108
+
109
+ def global_avg(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
+ return self.delimiter.join(loss_str)
114
+
115
+ def synchronize_between_processes(self):
116
+ for meter in self.meters.values():
117
+ meter.synchronize_between_processes()
118
+
119
+ def add_meter(self, name, meter):
120
+ self.meters[name] = meter
121
+
122
+ def log_every(self, iterable, print_freq, header=None):
123
+ i = 0
124
+ if not header:
125
+ header = ""
126
+ start_time = time.time()
127
+ end = time.time()
128
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
129
+ data_time = SmoothedValue(fmt="{avg:.4f}")
130
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
+ log_msg = [
132
+ header,
133
+ "[{0" + space_fmt + "}/{1}]",
134
+ "eta: {eta}",
135
+ "{meters}",
136
+ "time: {time}",
137
+ "data: {data}",
138
+ ]
139
+ if torch.cuda.is_available():
140
+ log_msg.append("max mem: {memory:.0f}")
141
+ log_msg = self.delimiter.join(log_msg)
142
+ MB = 1024.0 * 1024.0
143
+ for obj in iterable:
144
+ data_time.update(time.time() - end)
145
+ yield obj
146
+ iter_time.update(time.time() - end)
147
+ if i % print_freq == 0 or i == len(iterable) - 1:
148
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
+ if torch.cuda.is_available():
151
+ print(
152
+ log_msg.format(
153
+ i,
154
+ len(iterable),
155
+ eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time),
158
+ data=str(data_time),
159
+ memory=torch.cuda.max_memory_allocated() / MB,
160
+ )
161
+ )
162
+ else:
163
+ print(
164
+ log_msg.format(
165
+ i,
166
+ len(iterable),
167
+ eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time),
170
+ data=str(data_time),
171
+ )
172
+ )
173
+ i += 1
174
+ end = time.time()
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print(
178
+ "{} Total time: {} ({:.4f} s / it)".format(
179
+ header, total_time_str, total_time / len(iterable)
180
+ )
181
+ )
182
+
183
+
184
+ class AttrDict(dict):
185
+ def __init__(self, *args, **kwargs):
186
+ super(AttrDict, self).__init__(*args, **kwargs)
187
+ self.__dict__ = self
188
+
189
+
190
+ def setup_logger():
191
+ logging.basicConfig(
192
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
+ format="%(asctime)s [%(levelname)s] %(message)s",
194
+ handlers=[logging.StreamHandler()],
195
+ )
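A minimal sketch of the intended training-loop usage of `MetricLogger`; the loop body is a placeholder:

# Illustrative only: log smoothed metrics every few iterations.
metric_logger = MetricLogger(delimiter="  ")
for batch in metric_logger.log_every(range(100), print_freq=10, header="Train:"):
    loss = 0.0  # placeholder for the real loss value
    metric_logger.update(loss=loss)
print("Averaged stats:", metric_logger.global_avg())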
minigpt4/common/optims.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from minigpt4.common.registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ iters_per_epoch,
63
+ min_lr,
64
+ init_lr,
65
+ warmup_steps=0,
66
+ warmup_start_lr=-1,
67
+ **kwargs
68
+ ):
69
+ self.optimizer = optimizer
70
+
71
+ self.max_epoch = max_epoch
72
+ self.iters_per_epoch = iters_per_epoch
73
+ self.min_lr = min_lr
74
+
75
+ self.init_lr = init_lr
76
+ self.warmup_steps = warmup_steps
77
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
+
79
+ def step(self, cur_epoch, cur_step):
80
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
+ if total_cur_step < self.warmup_steps:
82
+ warmup_lr_schedule(
83
+ step=total_cur_step,
84
+ optimizer=self.optimizer,
85
+ max_step=self.warmup_steps,
86
+ init_lr=self.warmup_start_lr,
87
+ max_lr=self.init_lr,
88
+ )
89
+ else:
90
+ cosine_lr_schedule(
91
+ epoch=total_cur_step,
92
+ optimizer=self.optimizer,
93
+ max_epoch=self.max_epoch * self.iters_per_epoch,
94
+ init_lr=self.init_lr,
95
+ min_lr=self.min_lr,
96
+ )
97
+
98
+
99
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
+ """Decay the learning rate"""
101
+ lr = (init_lr - min_lr) * 0.5 * (
102
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
103
+ ) + min_lr
104
+ for param_group in optimizer.param_groups:
105
+ param_group["lr"] = lr
106
+
107
+
108
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
+ """Warmup the learning rate"""
110
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
+ for param_group in optimizer.param_groups:
112
+ param_group["lr"] = lr
113
+
114
+
115
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
+ """Decay the learning rate"""
117
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
118
+ for param_group in optimizer.param_groups:
119
+ param_group["lr"] = lr
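A brief sketch of how one of these schedulers is stepped; the model and optimizer construction is assumed and not part of the file:

# Illustrative only: step the cosine scheduler once per training iteration.
import torch
model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = LinearWarmupCosineLRScheduler(
    optimizer=optimizer, max_epoch=10, iters_per_epoch=1000,
    min_lr=1e-5, init_lr=1e-4, warmup_steps=500, warmup_start_lr=1e-6,
)
for epoch in range(10):
    for it in range(1000):
        scheduler.step(cur_epoch=epoch, cur_step=it)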
minigpt4/common/registry.py ADDED
@@ -0,0 +1,330 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+ class Registry:
10
+ mapping = {
11
+ "builder_name_mapping": {},
12
+ "task_name_mapping": {},
13
+ "processor_name_mapping": {},
14
+ "model_name_mapping": {},
15
+ "lr_scheduler_name_mapping": {},
16
+ "runner_name_mapping": {},
17
+ "state": {},
18
+ "paths": {},
19
+ }
20
+
21
+ @classmethod
22
+ def register_builder(cls, name):
23
+ r"""Register a dataset builder to registry with key 'name'
24
+
25
+ Args:
26
+ name: Key with which the builder will be registered.
27
+
28
+ Usage:
29
+
30
+ from minigpt4.common.registry import registry
31
+ from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
32
+ """
33
+
34
+ def wrap(builder_cls):
35
+ from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36
+
37
+ assert issubclass(
38
+ builder_cls, BaseDatasetBuilder
39
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40
+ builder_cls
41
+ )
42
+ if name in cls.mapping["builder_name_mapping"]:
43
+ raise KeyError(
44
+ "Name '{}' already registered for {}.".format(
45
+ name, cls.mapping["builder_name_mapping"][name]
46
+ )
47
+ )
48
+ cls.mapping["builder_name_mapping"][name] = builder_cls
49
+ return builder_cls
50
+
51
+ return wrap
52
+
53
+ @classmethod
54
+ def register_task(cls, name):
55
+ r"""Register a task to registry with key 'name'
56
+
57
+ Args:
58
+ name: Key with which the task will be registered.
59
+
60
+ Usage:
61
+
62
+ from minigpt4.common.registry import registry
63
+ """
64
+
65
+ def wrap(task_cls):
66
+ from minigpt4.tasks.base_task import BaseTask
67
+
68
+ assert issubclass(
69
+ task_cls, BaseTask
70
+ ), "All tasks must inherit BaseTask class"
71
+ if name in cls.mapping["task_name_mapping"]:
72
+ raise KeyError(
73
+ "Name '{}' already registered for {}.".format(
74
+ name, cls.mapping["task_name_mapping"][name]
75
+ )
76
+ )
77
+ cls.mapping["task_name_mapping"][name] = task_cls
78
+ return task_cls
79
+
80
+ return wrap
81
+
82
+ @classmethod
83
+ def register_model(cls, name):
84
+ r"""Register a model to registry with key 'name'
85
+
86
+ Args:
87
+ name: Key with which the model will be registered.
88
+
89
+ Usage:
90
+
91
+ from minigpt4.common.registry import registry
92
+ """
93
+
94
+ def wrap(model_cls):
95
+ # from minigpt4.models import BaseModel
96
+
97
+ # assert issubclass(
98
+ # model_cls, BaseModel
99
+ # ), "All models must inherit BaseModel class"
100
+
101
+ if name in cls.mapping["model_name_mapping"]:
102
+ raise KeyError(
103
+ "Name '{}' already registered for {}.".format(
104
+ name, cls.mapping["model_name_mapping"][name]
105
+ )
106
+ )
107
+ cls.mapping["model_name_mapping"][name] = model_cls
108
+ return model_cls
109
+
110
+ return wrap
111
+
112
+ @classmethod
113
+ def register_processor(cls, name):
114
+ r"""Register a processor to registry with key 'name'
115
+
116
+ Args:
117
+ name: Key with which the processor will be registered.
118
+
119
+ Usage:
120
+
121
+ from minigpt4.common.registry import registry
122
+ """
123
+
124
+ def wrap(processor_cls):
125
+ from minigpt4.processors import BaseProcessor
126
+
127
+ assert issubclass(
128
+ processor_cls, BaseProcessor
129
+ ), "All processors must inherit BaseProcessor class"
130
+ if name in cls.mapping["processor_name_mapping"]:
131
+ raise KeyError(
132
+ "Name '{}' already registered for {}.".format(
133
+ name, cls.mapping["processor_name_mapping"][name]
134
+ )
135
+ )
136
+ cls.mapping["processor_name_mapping"][name] = processor_cls
137
+ return processor_cls
138
+
139
+ return wrap
140
+
141
+ @classmethod
142
+ def register_lr_scheduler(cls, name):
143
+ r"""Register a lr scheduler to registry with key 'name'
144
+
145
+ Args:
146
+ name: Key with which the lr scheduler will be registered.
147
+
148
+ Usage:
149
+
150
+ from minigpt4.common.registry import registry
151
+ """
152
+
153
+ def wrap(lr_sched_cls):
154
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
155
+ raise KeyError(
156
+ "Name '{}' already registered for {}.".format(
157
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
158
+ )
159
+ )
160
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
161
+ return lr_sched_cls
162
+
163
+ return wrap
164
+
165
+ @classmethod
166
+ def register_runner(cls, name):
167
+ r"""Register a runner to registry with key 'name'
168
+
169
+ Args:
170
+ name: Key with which the runner will be registered.
171
+
172
+ Usage:
173
+
174
+ from minigpt4.common.registry import registry
175
+ """
176
+
177
+ def wrap(runner_cls):
178
+ if name in cls.mapping["runner_name_mapping"]:
179
+ raise KeyError(
180
+ "Name '{}' already registered for {}.".format(
181
+ name, cls.mapping["runner_name_mapping"][name]
182
+ )
183
+ )
184
+ cls.mapping["runner_name_mapping"][name] = runner_cls
185
+ return runner_cls
186
+
187
+ return wrap
188
+
189
+ @classmethod
190
+ def register_path(cls, name, path):
191
+ r"""Register a path to registry with key 'name'
192
+
193
+ Args:
194
+ name: Key with which the path will be registered.
195
+
196
+ Usage:
197
+
198
+ from minigpt4.common.registry import registry
199
+ """
200
+ assert isinstance(path, str), "All paths must be str."
201
+ if name in cls.mapping["paths"]:
202
+ raise KeyError("Name '{}' already registered.".format(name))
203
+ cls.mapping["paths"][name] = path
204
+
205
+ @classmethod
206
+ def register(cls, name, obj):
207
+ r"""Register an item to registry with key 'name'
208
+
209
+ Args:
210
+ name: Key with which the item will be registered.
211
+
212
+ Usage::
213
+
214
+ from minigpt4.common.registry import registry
215
+
216
+ registry.register("config", {})
217
+ """
218
+ path = name.split(".")
219
+ current = cls.mapping["state"]
220
+
221
+ for part in path[:-1]:
222
+ if part not in current:
223
+ current[part] = {}
224
+ current = current[part]
225
+
226
+ current[path[-1]] = obj
227
+
228
+ # @classmethod
229
+ # def get_trainer_class(cls, name):
230
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
231
+
232
+ @classmethod
233
+ def get_builder_class(cls, name):
234
+ return cls.mapping["builder_name_mapping"].get(name, None)
235
+
236
+ @classmethod
237
+ def get_model_class(cls, name):
238
+ return cls.mapping["model_name_mapping"].get(name, None)
239
+
240
+ @classmethod
241
+ def get_task_class(cls, name):
242
+ return cls.mapping["task_name_mapping"].get(name, None)
243
+
244
+ @classmethod
245
+ def get_processor_class(cls, name):
246
+ return cls.mapping["processor_name_mapping"].get(name, None)
247
+
248
+ @classmethod
249
+ def get_lr_scheduler_class(cls, name):
250
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
251
+
252
+ @classmethod
253
+ def get_runner_class(cls, name):
254
+ return cls.mapping["runner_name_mapping"].get(name, None)
255
+
256
+ @classmethod
257
+ def list_runners(cls):
258
+ return sorted(cls.mapping["runner_name_mapping"].keys())
259
+
260
+ @classmethod
261
+ def list_models(cls):
262
+ return sorted(cls.mapping["model_name_mapping"].keys())
263
+
264
+ @classmethod
265
+ def list_tasks(cls):
266
+ return sorted(cls.mapping["task_name_mapping"].keys())
267
+
268
+ @classmethod
269
+ def list_processors(cls):
270
+ return sorted(cls.mapping["processor_name_mapping"].keys())
271
+
272
+ @classmethod
273
+ def list_lr_schedulers(cls):
274
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
275
+
276
+ @classmethod
277
+ def list_datasets(cls):
278
+ return sorted(cls.mapping["builder_name_mapping"].keys())
279
+
280
+ @classmethod
281
+ def get_path(cls, name):
282
+ return cls.mapping["paths"].get(name, None)
283
+
284
+ @classmethod
285
+ def get(cls, name, default=None, no_warning=False):
286
+ r"""Get an item from registry with key 'name'
287
+
288
+ Args:
289
+ name (string): Key whose value needs to be retrieved.
290
+ default: If passed and key is not in registry, default value will
291
+ be returned with a warning. Default: None
292
+ no_warning (bool): If passed as True, warning when key doesn't exist
293
+ will not be generated. Useful for MMF's
294
+ internal operations. Default: False
295
+ """
296
+ original_name = name
297
+ name = name.split(".")
298
+ value = cls.mapping["state"]
299
+ for subname in name:
300
+ value = value.get(subname, default)
301
+ if value is default:
302
+ break
303
+
304
+ if (
305
+ "writer" in cls.mapping["state"]
306
+ and value == default
307
+ and no_warning is False
308
+ ):
309
+ cls.mapping["state"]["writer"].warning(
310
+ "Key {} is not present in registry, returning default value "
311
+ "of {}".format(original_name, default)
312
+ )
313
+ return value
314
+
315
+ @classmethod
316
+ def unregister(cls, name):
317
+ r"""Remove an item from registry with key 'name'
318
+
319
+ Args:
320
+ name: Key which needs to be removed.
321
+ Usage::
322
+
323
+ from mmf.common.registry import registry
324
+
325
+ config = registry.unregister("config")
326
+ """
327
+ return cls.mapping["state"].pop(name, None)
328
+
329
+
330
+ registry = Registry()
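For orientation, a small sketch of registering and looking up a component; the scheduler name and class below are hypothetical:

# Illustrative only: register a component, then retrieve it by name.
@registry.register_lr_scheduler("constant_lr_demo")  # hypothetical key
class ConstantLRDemoScheduler:
    def __init__(self, optimizer, **kwargs):
        self.optimizer = optimizer
    def step(self, cur_epoch, cur_step):
        pass  # keep the lr unchanged

cls = registry.get_lr_scheduler_class("constant_lr_demo")
print(registry.list_lr_schedulers())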
minigpt4/common/utils.py ADDED
@@ -0,0 +1,424 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from minigpt4.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. Look for http(s)://, ignoring the case.
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
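As a short sketch, the save/load helpers round-trip like this; the temporary path is made up:

# Illustrative only: write a small JSON record and read it back.
save_file({"step": 1, "loss": 0.5}, "/tmp/example_stats.json", append_to_json=False)
stats = load_file("/tmp/example_stats.json")
print(stats, f"({get_file_size('/tmp/example_stats.json'):.4f} MB)")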
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py ADDED
@@ -0,0 +1,89 @@
1
+ # coding: utf-8
2
+
3
+ import sys
4
+ dataDir = '../../VQA'
5
+ sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
6
+ from vqa import VQA
7
+ from vqaEvaluation.vqaEval import VQAEval
8
+ import matplotlib.pyplot as plt
9
+ import skimage.io as io
10
+ import json
11
+ import random
12
+ import os
13
+
14
+ # set up file names and paths
15
+ versionType ='v2_' # this should be '' when using VQA v2.0 dataset
16
+ taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
17
+ dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
18
+ dataSubType ='train2014'
19
+ annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
20
+ quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
21
+ imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
22
+ resultType ='fake'
23
+ fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
24
+
25
+ # An example result json file has been provided in './Results' folder.
26
+
27
+ [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
28
+ resultType, fileType) for fileType in fileTypes]
29
+
30
+ # create vqa object and vqaRes object
31
+ vqa = VQA(annFile, quesFile)
32
+ vqaRes = vqa.loadRes(resFile, quesFile)
33
+
34
+ # create vqaEval object by taking vqa and vqaRes
35
+ vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
36
+
37
+ # evaluate results
38
+ """
39
+ If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
40
+ By default it uses all the question ids in annotation file
41
+ """
42
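+ # e.g. vqaEval.evaluate(quesIds=some_question_id_list) to score only a subset of questions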
+ vqaEval.evaluate()
43
+
44
+ # print accuracies
45
+ print "\n"
46
+ print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
47
+ print "Per Question Type Accuracy is the following:"
48
+ for quesType in vqaEval.accuracy['perQuestionType']:
49
+ print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
50
+ print "\n"
51
+ print "Per Answer Type Accuracy is the following:"
52
+ for ansType in vqaEval.accuracy['perAnswerType']:
53
+ print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
54
+ print "\n"
55
+ # demo how to use evalQA to retrieve low score result
56
+ evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
57
+ if len(evals) > 0:
58
+ print('ground truth answers')
59
+ randomEval = random.choice(evals)
60
+ randomAnn = vqa.loadQA(randomEval)
61
+ vqa.showQA(randomAnn)
62
+
63
+ print('\n')
64
+ print('generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval]))
65
+ ann = vqaRes.loadQA(randomEval)[0]
66
+ print "Answer: %s\n" %(ann['answer'])
67
+
68
+ imgId = randomAnn[0]['image_id']
69
+ imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
70
+ if os.path.isfile(imgDir + imgFilename):
71
+ I = io.imread(imgDir + imgFilename)
72
+ plt.imshow(I)
73
+ plt.axis('off')
74
+ plt.show()
75
+
76
+ # plot accuracy for various question types
77
+ plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
78
+ plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10)
79
+ plt.title('Per Question Type Accuracy', fontsize=10)
80
+ plt.xlabel('Question Types', fontsize=10)
81
+ plt.ylabel('Accuracy', fontsize=10)
82
+ plt.show()
83
+
84
+ # save evaluation results to ./Results folder
85
+ json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
86
+ json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
87
+ json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
88
+ json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
89
+
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'aagrawal'
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py ADDED
@@ -0,0 +1,192 @@
1
+ # coding=utf-8
2
+
3
+ __author__='aagrawal'
4
+
5
+ import re
6
+ # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7
+ # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
8
+ import sys
9
+
10
+
11
+ class VQAEval:
12
+ def __init__(self, vqa, vqaRes, n=2):
13
+ self.n = n
14
+ self.accuracy = {}
15
+ self.evalQA = {}
16
+ self.evalQuesType = {}
17
+ self.evalAnsType = {}
18
+ self.vqa = vqa
19
+ self.vqaRes = vqaRes
20
+ self.params = {'question_id': vqa.getQuesIds()}
21
+ self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
22
+ "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
23
+ "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
24
+ "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
25
+ "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
26
+ "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
27
+ "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
28
+ "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
29
+ "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
30
+ "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
31
+ "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
32
+ "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
33
+ "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
34
+ "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
35
+ "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
36
+ "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
37
+ "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
38
+ "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
39
+ "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
40
+ "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
41
+ "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
42
+ "youll": "you'll", "youre": "you're", "youve": "you've"}
43
+ self.manualMap = { 'none': '0',
44
+ 'zero': '0',
45
+ 'one': '1',
46
+ 'two': '2',
47
+ 'three': '3',
48
+ 'four': '4',
49
+ 'five': '5',
50
+ 'six': '6',
51
+ 'seven': '7',
52
+ 'eight': '8',
53
+ 'nine': '9',
54
+ 'ten': '10'
55
+ }
56
+ self.articles = ['a',
57
+ 'an',
58
+ 'the'
59
+ ]
60
+
61
+
62
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
63
+ self.commaStrip = re.compile("(\d)(\,)(\d)")
64
+ self.punct = [';', r"/", '[', ']', '"', '{', '}',
65
+ '(', ')', '=', '+', '\\', '_', '-',
66
+ '>', '<', '@', '`', ',', '?', '!']
67
+
68
+
69
+ def evaluate(self, quesIds=None):
70
+ if quesIds == None:
71
+ quesIds = [quesId for quesId in self.params['question_id']]
72
+ gts = {}
73
+ res = {}
74
+ for quesId in quesIds:
75
+ gts[quesId] = self.vqa.qa[quesId]
76
+ res[quesId] = self.vqaRes.qa[quesId]
77
+
78
+ # =================================================
79
+ # Compute accuracy
80
+ # =================================================
81
+ accQA = []
82
+ accQuesType = {}
83
+ accAnsType = {}
84
+ # print "computing accuracy"
85
+ step = 0
86
+ for quesId in quesIds:
87
+ for ansDic in gts[quesId]['answers']:
88
+ ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
89
+ ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
90
+ ansDic['answer'] = ansDic['answer'].strip()
91
+ resAns = res[quesId]['answer']
92
+ resAns = resAns.replace('\n', ' ')
93
+ resAns = resAns.replace('\t', ' ')
94
+ resAns = resAns.strip()
95
+ gtAcc = []
96
+ gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
97
+
98
+ if len(set(gtAnswers)) > 1:
99
+ for ansDic in gts[quesId]['answers']:
100
+ ansDic['answer'] = self.processPunctuation(ansDic['answer'])
101
+ ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
102
+ resAns = self.processPunctuation(resAns)
103
+ resAns = self.processDigitArticle(resAns)
104
+
105
+ for gtAnsDatum in gts[quesId]['answers']:
106
+ otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
107
+ matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
108
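+ # VQA accuracy with this annotator held out: min(1, #other annotators giving the predicted answer / 3); averaged into avgGTAcc below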
+ acc = min(1, float(len(matchingAns))/3)
109
+ gtAcc.append(acc)
110
+ quesType = gts[quesId]['question_type']
111
+ ansType = gts[quesId]['answer_type']
112
+ avgGTAcc = float(sum(gtAcc))/len(gtAcc)
113
+ accQA.append(avgGTAcc)
114
+ if quesType not in accQuesType:
115
+ accQuesType[quesType] = []
116
+ accQuesType[quesType].append(avgGTAcc)
117
+ if ansType not in accAnsType:
118
+ accAnsType[ansType] = []
119
+ accAnsType[ansType].append(avgGTAcc)
120
+ self.setEvalQA(quesId, avgGTAcc)
121
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
122
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
123
+ if step%100 == 0:
124
+ self.updateProgress(step/float(len(quesIds)))
125
+ step = step + 1
126
+
127
+ self.setAccuracy(accQA, accQuesType, accAnsType)
128
+ # print "Done computing accuracy"
129
+
130
+ def processPunctuation(self, inText):
131
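+ # Strip punctuation from an answer string before comparison (commas inside numbers and decimal points are handled specially)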
+ outText = inText
132
+ for p in self.punct:
133
+ if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
134
+ outText = outText.replace(p, '')
135
+ else:
136
+ outText = outText.replace(p, ' ')
137
+ outText = self.periodStrip.sub("",
138
+ outText,
139
+ re.UNICODE)
140
+ return outText
141
+
142
+ def processDigitArticle(self, inText):
143
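+ # Lowercase and split the answer, map number words to digits, drop articles, and normalize known contractions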
+ outText = []
144
+ tempText = inText.lower().split()
145
+ for word in tempText:
146
+ word = self.manualMap.setdefault(word, word)
147
+ if word not in self.articles:
148
+ outText.append(word)
149
+ else:
150
+ pass
151
+ for wordId, word in enumerate(outText):
152
+ if word in self.contractions:
153
+ outText[wordId] = self.contractions[word]
154
+ outText = ' '.join(outText)
155
+ return outText
156
+
157
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
158
+ self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
159
+ self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
160
+ self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
161
+
162
+ def setEvalQA(self, quesId, acc):
163
+ self.evalQA[quesId] = round(100*acc, self.n)
164
+
165
+ def setEvalQuesType(self, quesId, quesType, acc):
166
+ if quesType not in self.evalQuesType:
167
+ self.evalQuesType[quesType] = {}
168
+ self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
169
+
170
+ def setEvalAnsType(self, quesId, ansType, acc):
171
+ if ansType not in self.evalAnsType:
172
+ self.evalAnsType[ansType] = {}
173
+ self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
174
+
175
+ def updateProgress(self, progress):
176
+ barLength = 20
177
+ status = ""
178
+ if isinstance(progress, int):
179
+ progress = float(progress)
180
+ if not isinstance(progress, float):
181
+ progress = 0
182
+ status = "error: progress var must be float\r\n"
183
+ if progress < 0:
184
+ progress = 0
185
+ status = "Halt...\r\n"
186
+ if progress >= 1:
187
+ progress = 1
188
+ status = "Done...\r\n"
189
+ block = int(round(barLength*progress))
190
+ text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
191
+ sys.stdout.write(text)
192
+ sys.stdout.flush()
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py ADDED
@@ -0,0 +1,73 @@
1
+ # coding: utf-8
2
+
3
+ from vqaTools.vqa import VQA
4
+ import random
5
+ import skimage.io as io
6
+ import matplotlib.pyplot as plt
7
+ import os
8
+
9
+ dataDir ='../../VQA'
10
+ versionType ='v2_' # this should be '' when using VQA v2.0 dataset
11
+ taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
12
+ dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
13
+ dataSubType ='train2014'
14
+ annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
15
+ quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
16
+ imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
17
+
18
+ # initialize VQA api for QA annotations
19
+ vqa=VQA(annFile, quesFile)
20
+
21
+ # load and display QA annotations for given question types
22
+ """
23
+ All possible quesTypes for abstract and mscoco have been provided in the respective text files in the ../QuestionTypes/ folder.
24
+ """
25
+ annIds = vqa.getQuesIds(quesTypes='how many');
26
+ anns = vqa.loadQA(annIds)
27
+ randomAnn = random.choice(anns)
28
+ vqa.showQA([randomAnn])
29
+ imgId = randomAnn['image_id']
30
+ imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
31
+ if os.path.isfile(imgDir + imgFilename):
32
+ I = io.imread(imgDir + imgFilename)
33
+ plt.imshow(I)
34
+ plt.axis('off')
35
+ plt.show()
36
+
37
+ # load and display QA annotations for given answer types
38
+ """
39
+ ansTypes can be one of the following
40
+ yes/no
41
+ number
42
+ other
43
+ """
44
+ annIds = vqa.getQuesIds(ansTypes='yes/no');
45
+ anns = vqa.loadQA(annIds)
46
+ randomAnn = random.choice(anns)
47
+ vqa.showQA([randomAnn])
48
+ imgId = randomAnn['image_id']
49
+ imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
50
+ if os.path.isfile(imgDir + imgFilename):
51
+ I = io.imread(imgDir + imgFilename)
52
+ plt.imshow(I)
53
+ plt.axis('off')
54
+ plt.show()
55
+
56
+ # load and display QA annotations for given images
57
+ """
58
+ Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
59
+ Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
60
+ """
61
+ ids = vqa.getImgIds()
62
+ annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
63
+ anns = vqa.loadQA(annIds)
64
+ randomAnn = random.choice(anns)
65
+ vqa.showQA([randomAnn])
66
+ imgId = randomAnn['image_id']
67
+ imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
68
+ if os.path.isfile(imgDir + imgFilename):
69
+ I = io.imread(imgDir + imgFilename)
70
+ plt.imshow(I)
71
+ plt.axis('off')
72
+ plt.show()
73
+
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'aagrawal'
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py ADDED
@@ -0,0 +1,179 @@
1
+ __author__ = 'aagrawal'
2
+ __version__ = '0.9'
3
+
4
+ # Interface for accessing the VQA dataset.
5
+
6
+ # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7
+ # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
8
+
9
+ # The following functions are defined:
10
+ # VQA - VQA class that loads VQA annotation file and prepares data structures.
11
+ # getQuesIds - Get question ids that satisfy given filter conditions.
12
+ # getImgIds - Get image ids that satisfy given filter conditions.
13
+ # loadQA - Load questions and answers with the specified question ids.
14
+ # showQA - Display the specified questions and answers.
15
+ # loadRes - Load result file and create result object.
16
+
17
+ # Help on each function can be accessed by: "help(VQA.function)"
18
+
19
+ import json
20
+ import datetime
21
+ import copy
22
+
23
+
24
+ class VQA:
25
+ def __init__(self, annotation_file=None, question_file=None):
26
+ """
27
+ Constructor of VQA helper class for reading and visualizing questions and answers.
28
+ :param annotation_file (str): location of VQA annotation file
29
+ :return:
30
+ """
31
+ # load dataset
32
+ self.dataset = {}
33
+ self.questions = {}
34
+ self.qa = {}
35
+ self.qqa = {}
36
+ self.imgToQA = {}
37
+ if not annotation_file == None and not question_file == None:
38
+ # print 'loading VQA annotations and questions into memory...'
39
+ time_t = datetime.datetime.utcnow()
40
+ dataset = json.load(open(annotation_file, 'r'))
41
+ questions = json.load(open(question_file, 'r'))
42
+ # print datetime.datetime.utcnow() - time_t
43
+ self.dataset = dataset
44
+ self.questions = questions
45
+ self.createIndex()
46
+
47
+ def createIndex(self):
48
+ imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
49
+ qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
50
+ qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
51
+ for ann in self.dataset['annotations']:
52
+ imgToQA[ann['image_id']] += [ann]
53
+ qa[ann['question_id']] = ann
54
+ for ques in self.questions['questions']:
55
+ qqa[ques['question_id']] = ques
56
+ # print 'index created!'
57
+
58
+ # create class members
59
+ self.qa = qa
60
+ self.qqa = qqa
61
+ self.imgToQA = imgToQA
62
+
63
+ def info(self):
64
+ """
65
+ Print information about the VQA annotation file.
66
+ :return:
67
+ """
68
+
69
+ # for key, value in self.datset['info'].items():
70
+ # print '%s: %s'%(key, value)
71
+
72
+ def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
73
+ """
74
+ Get question ids that satisfy given filter conditions. default skips that filter
75
+ :param imgIds (int array) : get question ids for given imgs
76
+ quesTypes (str array) : get question ids for given question types
77
+ ansTypes (str array) : get question ids for given answer types
78
+ :return: ids (int array) : integer array of question ids
79
+ """
80
+ imgIds = imgIds if type(imgIds) == list else [imgIds]
81
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
82
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
83
+
84
+ if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
85
+ anns = self.dataset['annotations']
86
+ else:
87
+ if not len(imgIds) == 0:
88
+ anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
89
+ else:
90
+ anns = self.dataset['annotations']
91
+ anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
92
+ anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
93
+ ids = [ann['question_id'] for ann in anns]
94
+ return ids
95
+
96
+ def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
97
+ """
98
+ Get image ids that satisfy given filter conditions. default skips that filter
99
+ :param quesIds (int array) : get image ids for given question ids
100
+ quesTypes (str array) : get image ids for given question types
101
+ ansTypes (str array) : get image ids for given answer types
102
+ :return: ids (int array) : integer array of image ids
103
+ """
104
+ quesIds = quesIds if type(quesIds) == list else [quesIds]
105
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
106
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
107
+
108
+ if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
109
+ anns = self.dataset['annotations']
110
+ else:
111
+ if not len(quesIds) == 0:
112
+ anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
113
+ else:
114
+ anns = self.dataset['annotations']
115
+ anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
116
+ anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
117
+ ids = [ann['image_id'] for ann in anns]
118
+ return ids
119
+
120
+ def loadQA(self, ids=[]):
121
+ """
122
+ Load questions and answers with the specified question ids.
123
+ :param ids (int array) : integer ids specifying question ids
124
+ :return: qa (object array) : loaded qa objects
125
+ """
126
+ if type(ids) == list:
127
+ return [self.qa[id] for id in ids]
128
+ elif type(ids) == int:
129
+ return [self.qa[ids]]
130
+
131
+ def showQA(self, anns):
132
+ """
133
+ Display the specified annotations.
134
+ :param anns (array of object): annotations to display
135
+ :return: None
136
+ """
137
+ if len(anns) == 0:
138
+ return 0
139
+ for ann in anns:
140
+ quesId = ann['question_id']
141
+ print("Question: %s" % (self.qqa[quesId]['question']))
142
+ for ans in ann['answers']:
143
+ print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
144
+
145
+ def loadRes(self, resFile, quesFile):
146
+ """
147
+ Load result file and return a result object.
148
+ :param resFile (str) : file name of result file
149
+ :return: res (obj) : result api object
150
+ """
151
+ res = VQA()
152
+ res.questions = json.load(open(quesFile))
153
+ res.dataset['info'] = copy.deepcopy(self.questions['info'])
154
+ res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
155
+ res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
156
+ res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
157
+ res.dataset['license'] = copy.deepcopy(self.questions['license'])
158
+
159
+ # print 'Loading and preparing results... '
160
+ time_t = datetime.datetime.utcnow()
161
+ anns = json.load(open(resFile))
162
+ assert type(anns) == list, 'results is not an array of objects'
163
+ annsQuesIds = [ann['question_id'] for ann in anns]
164
+ assert set(annsQuesIds) == set(self.getQuesIds()), \
165
+ 'Results do not correspond to the current VQA set. Either the results do not have predictions for all question ids in the annotation file, or there is at least one question id that does not belong to the question ids in the annotation file.'
166
+ for ann in anns:
167
+ quesId = ann['question_id']
168
+ if res.dataset['task_type'] == 'Multiple Choice':
169
+ assert ann['answer'] in self.qqa[quesId][
170
+ 'multiple_choices'], 'predicted answer is not one of the multiple choices'
171
+ qaAnn = self.qa[quesId]
172
+ ann['image_id'] = qaAnn['image_id']
173
+ ann['question_type'] = qaAnn['question_type']
174
+ ann['answer_type'] = qaAnn['answer_type']
175
+ # print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
176
+
177
+ res.dataset['annotations'] = anns
178
+ res.createIndex()
179
+ return res
minigpt4/common/vqa_tools/VQA/README.md ADDED
@@ -0,0 +1,80 @@
1
+ Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
2
+ ===================
3
+ ## VQA v2.0 release ##
4
+ This release consists of
5
+ - Real
6
+ - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from the [MS COCO website](http://mscoco.org/dataset/#download))
7
+ - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
8
+ - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
9
+
10
+ There is only one type of task
11
+ - Open-ended task
12
+
13
+ ## VQA v1.0 release ##
14
+ This release consists of
15
+ - Real
16
+ - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from the [MS COCO website](http://mscoco.org/dataset/#download))
17
+ - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
18
+ - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
19
+ - Abstract
20
+ - 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images
21
+ - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
22
+ - 600,000 answers for training and 300,000 answers for validation (10 per question)
23
+
24
+ There are two types of tasks
25
+ - Open-ended task
26
+ - Multiple-choice task (18 choices per question)
27
+
28
+ ## Requirements ##
29
+ - python 2.7
30
+ - scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
31
+ - matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
32
+
33
+ ## Files ##
34
+ ./Questions
35
+ - For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
36
+ - For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
37
+ - Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
38
+ - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
39
+ - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
40
+ - Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
41
+
42
+ ./Annotations
43
+ - For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
44
+ - For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
45
+ - Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
46
+ - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
47
+ - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
48
+ - Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
49
+
50
+ ./Images
51
+ - For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
52
+ - For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
53
+
54
+ ./PythonHelperTools
55
+ - This directory contains the Python API to read and visualize the VQA dataset
56
+ - vqaDemo.py (demo script)
57
+ - vqaTools (API to read and visualize data)
58
+
59
+ ./PythonEvaluationTools
60
+ - This directory contains the Python evaluation code
61
+ - vqaEvalDemo.py (evaluation demo script)
62
+ - vqaEvaluation (evaluation code)
63
+
64
+ ./Results
65
+ - OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
66
+ - Visit the [VQA evaluation page](http://visualqa.org/evaluation) for more details.
67
+
68
+ ./QuestionTypes
69
+ - This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
70
+ - mscoco_question_types.txt
71
+ - abstract_v002_question_types.txt
72
+
73
+ ## References ##
74
+ - [VQA: Visual Question Answering](http://visualqa.org/)
75
+ - [Microsoft COCO](http://mscoco.org/)
76
+
77
+ ## Developers ##
78
+ - Aishwarya Agrawal (Virginia Tech)
79
+ - Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
80
+ - The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
minigpt4/common/vqa_tools/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ __author__ = "aagrawal"
minigpt4/common/vqa_tools/aokvqa/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2022 Allen Institute for Artificial Intelligence
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
minigpt4/common/vqa_tools/aokvqa/README.md ADDED
@@ -0,0 +1,207 @@
1
+ # A-OKVQA
2
+
3
+ Official repository for **A-OKVQA: A Benchmark for Visual Question Answering using World Knowledge**.
4
+
5
+ Links: [[Paper]](https://arxiv.org/abs/2206.01718) [[Website]](http://a-okvqa.allenai.org) [[Leaderboard]](https://leaderboard.allenai.org/a-okvqa/submissions/public)
6
+
7
+ ### Abstract
8
+
9
+ The Visual Question Answering (VQA) task aspires to provide a meaningful testbed for the development of AI models that can jointly reason over visual and natural language inputs. Despite a proliferation of VQA datasets, this goal is hindered by a set of common limitations. These include a reliance on relatively simplistic questions that are repetitive in both concepts and linguistic structure, little world knowledge needed outside of the paired image, and limited reasoning required to arrive at the correct answer. We introduce A-OKVQA, a crowdsourced dataset composed of a diverse set of about 25K questions requiring a broad base of commonsense and world knowledge to answer. In contrast to the existing knowledge-based VQA datasets, the questions generally cannot be answered by simply querying a knowledge base, and instead require some form of commonsense reasoning about the scene depicted in the image. We demonstrate the potential of this new dataset through a detailed analysis of its contents and baseline performance measurements over a variety of state-of-the-art vision–language models.
10
+
11
+ ![dataset_web](https://user-images.githubusercontent.com/28768645/170799740-f0d9ea60-6aff-4322-98d5-cae8e05983f4.svg)
12
+
13
+ <hr>
14
+
15
+ #### Table of Contents
16
+
17
+ - [Getting started](#getting-started)
18
+ * [Downloading the dataset](#downloading-the-dataset)
19
+ - [Evaluation & Leaderboard](#evaluation)
20
+ - [Codebase](#codebase)
21
+ * [Preparing data](#preparing-data)
22
+ * [Models and Predictions](#models-and-predictions)
23
+
24
+ <hr>
25
+
26
+ ## Getting started
27
+
28
+ ```bash
29
+ git clone --single-branch --recurse-submodules https://github.com/allenai/aokvqa.git
30
+
31
+ cd aokvqa
32
+ export PYTHONPATH=.
33
+
34
+ conda env create --name aokvqa
35
+ conda activate aokvqa
36
+ ```
37
+
38
+ ### Downloading the dataset
39
+
40
+ ```bash
41
+ export AOKVQA_DIR=./datasets/aokvqa/
42
+ mkdir -p ${AOKVQA_DIR}
43
+
44
+ curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
45
+ ```
46
+
47
+ <details> <summary><b>Downloading COCO 2017</b></summary>
48
+
49
+ ```bash
50
+ export COCO_DIR=./datasets/coco/
51
+ mkdir -p ${COCO_DIR}
52
+
53
+ for split in train val test; do
54
+ wget "http://images.cocodataset.org/zips/${split}2017.zip"
55
+ unzip "${split}2017.zip" -d ${COCO_DIR}; rm "${split}2017.zip"
56
+ done
57
+
58
+ wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
59
+ unzip annotations_trainval2017.zip -d ${COCO_DIR}; rm annotations_trainval2017.zip
60
+ ```
61
+
62
+ </details>
63
+
64
+ Loading our dataset is easy! Just grab our [load_aokvqa.py](https://github.com/allenai/aokvqa/blob/main/load_aokvqa.py) file and refer to the following code.
65
+
66
+ ```python
67
+ import os
68
+ aokvqa_dir = os.getenv('AOKVQA_DIR')
69
+
70
+ from load_aokvqa import load_aokvqa, get_coco_path
71
+ train_dataset = load_aokvqa(aokvqa_dir, 'train') # also 'val' or 'test'
72
+ ```
73
+
74
+ <details> <summary><b>Example dataset entry</b></summary>
75
+
76
+ ```python
77
+ dataset_example = train_dataset[0]
78
+
79
+ print(dataset_example['question_id'])
80
+ # 22MexNkBPpdZGX6sxbxVBH
81
+
82
+ coco_dir = os.getenv('COCO_DIR')
83
+ image_path = get_coco_path('train', dataset_example['image_id'], coco_dir)
84
+ print(image_path)
85
+ # ./datasets/coco/train2017/000000299207.jpg
86
+
87
+ print(dataset_example['question'])
88
+ print(dataset_example['choices'])
89
+ # What is the man by the bags awaiting?
90
+ # ['skateboarder', 'train', 'delivery', 'cab']
91
+
92
+ correct_choice = dataset_example['choices'][ dataset_example['correct_choice_idx'] ]
93
+ # Correct: cab
94
+
95
+ print(dataset_example['rationales'][0])
96
+ # A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him so a cab is the only possible answer.
97
+ ```
98
+
99
+ </details>
100
+
101
+ ## Evaluation
102
+
103
+ Please prepare `predictions_{split}.json` files (for `split: {val,test}`) in the format below. You may omit either the `multiple_choice` or the `direct_answer` field if you only want to evaluate one setting.
104
+
105
+ ```python
106
+ {
107
+ '<question_id>' : {
108
+ 'multiple_choice' : '<prediction>',
109
+ 'direct_answer' : '<prediction>'
110
+ }
111
+ }
112
+ ```
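+
+ As a minimal sketch, assuming hypothetical `predict_mc(d)` and `predict_da(d)` functions wrapping your own model, a predictions file in this format could be written as follows:
+
+ ```python
+ import os
+ import json
+
+ from load_aokvqa import load_aokvqa
+
+ aokvqa_dir = os.getenv('AOKVQA_DIR')
+ val_set = load_aokvqa(aokvqa_dir, 'val')
+
+ predictions = {}
+ for d in val_set:
+     predictions[d['question_id']] = {
+         'multiple_choice': predict_mc(d),  # hypothetical: returns one string from d['choices']
+         'direct_answer': predict_da(d),    # hypothetical: returns a free-form answer string
+     }
+
+ with open('./predictions_val.json', 'w') as f:
+     json.dump(predictions, f)
+ ```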
113
+
114
+ You can run evaluation on the validation set as follows.
115
+
116
+ ```bash
117
+ python evaluation/eval_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --preds ./predictions_val.json
118
+ ```
119
+
120
+ ### Leaderboard
121
+
122
+ You may submit `predictions_test.json` to the [leaderboard](https://leaderboard.allenai.org/a-okvqa/submissions/get-started).
123
+
124
+ ## Codebase
125
+
126
+ We provide all code and pretrained models necessary to replicate our experiments for Large-Scale Pretrained Models (sec. 5.2) and Rationale Generation (sec. 5.3).
127
+
128
+ ### Preparing data
129
+
130
+ ```bash
131
+ export FEATURES_DIR=./features/
132
+ mkdir -p ${FEATURES_DIR}
133
+ ```
134
+
135
+ You can compute CLIP features for our vocabulary and dataset. These are most commonly used by our other experiments.
136
+
137
+ ```bash
138
+ python data_scripts/encode_vocab_clip.py --vocab ${AOKVQA_DIR}/large_vocab_train.csv --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
139
+
140
+ for split in train val test; do
141
+ python data_scripts/extract_clip_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_${split}.pt
142
+ done
143
+ ```
144
+
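+ Each extraction script above saves a dict keyed by `question_id` (see `data_scripts/extract_clip_features.py`). As a small sketch, assuming the output path produced by the command above, the saved features can be reloaded like so:
+
+ ```python
+ import torch
+
+ # {question_id: {'question': Tensor, 'image': Tensor}}
+ feats = torch.load('./features/clip-ViT-B-32_val.pt')
+ qid = next(iter(feats))
+ print(feats[qid]['question'].shape, feats[qid]['image'].shape)
+ ```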
145
+ <details> <summary><b>For training ClipCap with a transformer mapping network</b></summary>
146
+
147
+ If you want to train our ClipCap models with the transformer mapping network (instead of an MLP, like we do), you'll also need to run `extract_clip_features.py` with `--model-type RN50x4`.
148
+
149
+ </details>
150
+
151
+ <details> <summary><b>For ResNet and BERT input features</b></summary>
152
+
153
+ Our ResNet and BERT classification experiments require these respective features instead of CLIP. To generate these, please run the following commands:
154
+
155
+ ```bash
156
+ # ResNet
157
+ for split in train val test; do
158
+ python data_scripts/extract_resnet_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --out ${FEATURES_DIR}/resnet_${split}.pt
159
+ done
160
+
161
+ # BERT
162
+ for split in train val test; do
163
+ python data_scripts/extract_bert_features.py --aokvqa-dir ${AOKVQA_DIR} --split ${split} --out ${FEATURES_DIR}/bert_${split}.pt
164
+ done
165
+ ```
166
+
167
+ </details>
168
+
169
+ ### Models and Predictions
170
+
171
+ ```bash
172
+ export LOG_DIR=./logs/
173
+ export PREDS_DIR=./predictions/
174
+ export PT_MODEL_DIR=./pretrained_models/
175
+ mkdir -p ${LOG_DIR} ${PREDS_DIR} ${PT_MODEL_DIR}
176
+ ```
177
+
178
+ <details> <summary><b>Download our pretrained model weights</b></summary>
179
+
180
+ ```bash
181
+ # Checkpoints for transfer learning experiments
182
+ curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/transfer_exp_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
183
+
184
+ # Checkpoints for ClipCap models (generating answers and rationales)
185
+ curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/clipcap_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
186
+ ```
187
+
188
+ </details>
189
+
190
+ We have included instructions for replicating each of our experiments (see README.md files below).
191
+
192
+ All Python scripts should be run from the root of this repository. Please be sure to first run the installation and data preparation as directed above.
193
+
194
+ - [Heuristics](./heuristics/README.md)
195
+ - [Transfer Learning Experiments](./transfer_experiments/README.md)
196
+ - [Querying GPT-3](./gpt3/README.md)
197
+ - [ClipCap](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
198
+ - [Generating Captions & Rationales](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
199
+
200
+ For each experiment, we follow this prediction file naming scheme: `{model-name}_{split}-{setting}.json` (e.g. `random-weighted_val-mc.json` or `random-weighted_test-da.json`). As examples in these Readme files, we produce predictions on the validation set.
201
+
202
+ We unify predictions for each split before evaluation. (You can omit either the `--mc` or the `--da` prediction file if you only want to evaluate one setting.)
203
+
204
+ ```bash
205
+ python evaluation/prepare_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc ./predictions_val-mc.json --da ./predictions_val-da.json --out ./predictions_val.json
206
+ # repeat for test split ...
207
+ ```
minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import argparse
3
+ from collections import Counter
4
+ import pathlib
5
+
6
+ from load_aokvqa import load_aokvqa
7
+
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
11
+ parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
12
+ args = parser.parse_args()
13
+
14
+
15
+ # Build vocab from the train set: correct choices + (choices and direct answers appearing >= 3 times)
16
+
17
+ train_set = load_aokvqa(args.aokvqa_dir, 'train')
18
+
19
+ vocab = []
20
+ all_choices = Counter()
21
+ direct_answers = Counter()
22
+
23
+ for i in train_set:
24
+ vocab.append( i['choices'][i['correct_choice_idx']] )
25
+ all_choices.update(i['choices'])
26
+ direct_answers.update(set(i['direct_answers']))
27
+ vocab += [k for k,v in all_choices.items() if v >= 3]
28
+ vocab += [k for k,v in direct_answers.items() if v >= 3]
29
+
30
+ vocab = sorted(set(vocab))
31
+ print(f"Vocab size: {len(vocab)}")
32
+
33
+ # Save vocabulary output
34
+
35
+ with open(args.output_file, 'w') as f:
36
+ for v in vocab:
37
+ print(v, file=f)
38
+
39
+ ## Check validation set coverage
40
+
41
+ val_set = load_aokvqa(args.aokvqa_dir, 'val')
42
+
43
+ val_acc = [v['choices'][v['correct_choice_idx']] in vocab for v in val_set]
44
+ val_acc = sum(val_acc) / len(val_acc) * 100
45
+ print(f"Val set coverage: {val_acc:.2f}" )
minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py ADDED
@@ -0,0 +1,26 @@
1
+ import json
2
+ from tqdm import tqdm
3
+ import argparse
4
+ import pathlib
5
+
6
+ import torch
7
+ import clip
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--vocab', type=pathlib.Path, required=True, dest='vocab_file')
11
+ parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
12
+ parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
13
+ args = parser.parse_args()
14
+
15
+ assert args.output_file.suffix == '.pt'
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ model, preprocess = clip.load(args.model_type, device=device)
19
+
20
+ with torch.no_grad():
21
+ a = open(args.vocab_file).read().splitlines()
22
+ mc_text = clip.tokenize(a).to(device)
23
+ mc_text_features = torch.stack([model.encode_text(mct.unsqueeze(0)).cpu() for mct in tqdm(mc_text)], dim=1)[0]
24
+ mc_text_features = mc_text_features.float()
25
+ model_name = args.model_type.replace('/', '-').replace('@', '-')
26
+ torch.save(mc_text_features, args.output_file)
minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ import argparse
3
+ import pathlib
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModel
8
+
9
+ from load_aokvqa import load_aokvqa
10
+
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
14
+ parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
15
+ parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
16
+ args = parser.parse_args()
17
+
18
+ assert args.output_file.suffix == '.pt'
19
+
20
+ ## Load dataset
21
+
22
+ dataset = load_aokvqa(args.aokvqa_dir, args.split)
23
+
24
+ ## Load model
25
+
26
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
27
+ model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ model = model.to(device)
30
+ model.eval()
31
+
32
+ def mean_pooling(model_output, attention_mask):
33
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
34
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
35
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
36
+
37
+ ## Encoding loop
38
+
39
+ with torch.no_grad():
40
+ embeddings = {}
41
+
42
+ for d in tqdm(dataset):
43
+ encoded_input = tokenizer([d['question']], padding=True, return_tensors='pt')
44
+ encoded_input = {k:v.to(device) for k,v in encoded_input.items()}
45
+ e = mean_pooling(model(**encoded_input), encoded_input['attention_mask'])
46
+ embeddings[d['question_id']] = {
47
+ 'question' : e[0].cpu()
48
+ }
49
+
50
+ torch.save(embeddings, args.output_file)
minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py ADDED
@@ -0,0 +1,51 @@
+ import os
+ from PIL import Image
+ from tqdm import tqdm
+ import argparse
+ import pathlib
+
+ import torch
+ import clip
+
+ from load_aokvqa import load_aokvqa, get_coco_path
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
+ parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+ parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
+ parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
+ args = parser.parse_args()
+
+ assert args.output_file.suffix == '.pt'
+
+ ## Load dataset
+
+ dataset = load_aokvqa(args.aokvqa_dir, args.split)
+
+ ## Load model
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model, preprocess = clip.load(args.model_type, device=device)
+
+ ## Encoding loop
+
+ with torch.no_grad():
+     embeddings = {}
+
+     for d in tqdm(dataset):
+         q = d["question"]
+         q_text = clip.tokenize(q).to(device)
+         q_text_features = model.encode_text(q_text)
+
+         img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir))
+         img = preprocess(img).unsqueeze(0).to(device)
+         image_features = model.encode_image(img)
+
+         embeddings[d['question_id']] = {
+             'question' : q_text_features[0].float().cpu(),
+             'image' : image_features[0].float().cpu(),
+         }
+
+ torch.save(embeddings, args.output_file)
minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py ADDED
@@ -0,0 +1,62 @@
+ import os
+ import argparse
+ import pathlib
+ from tqdm import tqdm
+ from PIL import Image
+
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ from torchvision import transforms as T
+
+ from load_aokvqa import load_aokvqa, get_coco_path
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
+ parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+ parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
+ args = parser.parse_args()
+
+ assert args.output_file.suffix == '.pt'
+
+ ## Load dataset
+
+ dataset = load_aokvqa(args.aokvqa_dir, args.split)
+
+ ## Load model
+
+ resnet_preprocess = T.Compose([
+     T.Resize(size=224, interpolation=T.InterpolationMode.BICUBIC),
+     T.CenterCrop(size=(224, 224)),
+     T.ToTensor(),
+     T.Normalize(
+         mean=[0.485, 0.456, 0.406],
+         std=[0.229, 0.224, 0.225]
+     )
+ ])
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ resnet_model = models.resnet50(pretrained=True)
+ resnet_model = torch.nn.Sequential(
+     *list(resnet_model.children())[:-1],
+     nn.Flatten()
+ )  # strip classification layer
+ resnet_model = resnet_model.to(device)
+
+ ## Encoding loop
+
+ with torch.no_grad():
+     embeddings = {}
+
+     for d in tqdm(dataset):
+         img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)).convert('RGB')
+         resnet_input = resnet_preprocess(img).unsqueeze(0).to(device)
+         resnet_features = resnet_model(resnet_input)
+         embeddings[d['question_id']] = {
+             'image' : resnet_features[0].cpu()
+         }
+
+ torch.save(embeddings, args.output_file)
minigpt4/common/vqa_tools/aokvqa/environment.yml ADDED
@@ -0,0 +1,36 @@
+ name: aokvqa
+ channels:
+   - pytorch
+   - nvidia
+   - huggingface
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.7
+   - cudatoolkit=11.3
+   - numpy=1.21.6
+   - pytorch=1.11.0
+   - torchvision=0.12.0
+   - pytorch-lightning=1.6.3
+   - torchmetrics=0.8.1
+   - gdown=4.4.0
+   - pip=22.0.4
+   - pip:
+     - argparse==1.4.0
+     - Pillow==9.0.1
+     - tensorboard==2.9.0
+     - ftfy==6.1.1
+     - regex==2022.3.15
+     - tqdm==4.64.0
+     - clip @ git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620
+     - openai==0.18.1
+     - nltk==3.7
+     - sacrebleu==2.0.0
+     - sacremoses==0.0.53
+     - sentence-transformers==2.2.0
+     - datasets==2.1.0
+     - tokenizers==0.10.3
+     - transformers==4.10.3
+
+ # Next: resolve conflict between sentence-transformers and pytorch-lightning
+ # pip uninstall sentencepiece
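A minimal setup sketch based on this file; the final `pip uninstall` step is simply the workaround suggested by the trailing comment above, and the environment name `aokvqa` comes from the file itself:

```bash
# Create and activate the conda environment defined in environment.yml
conda env create -f environment.yml
conda activate aokvqa

# Workaround for the sentence-transformers / pytorch-lightning conflict noted in the file
pip uninstall -y sentencepiece
```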
minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py ADDED
@@ -0,0 +1,97 @@
+ import argparse
+ import pathlib
+ import json
+ import glob
+
+ from load_aokvqa import load_aokvqa
+
+
+ def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True):
+
+     if isinstance(dataset, list):
+         dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
+
+     if multiple_choice is False:
+         dataset = {k:v for k,v in dataset.items() if v['difficult_direct_answer'] is False}
+
+     if strict:
+         dataset_qids = set(dataset.keys())
+         preds_qids = set(preds.keys())
+         assert dataset_qids.issubset(preds_qids)
+
+     # dataset = q_id (str) : dataset element (dict)
+     # preds = q_id (str) : prediction (str)
+
+     acc = []
+
+     for q in dataset.keys():
+         if q not in preds.keys():
+             acc.append(0.0)
+             continue
+
+         pred = preds[q]
+         choices = dataset[q]['choices']
+         direct_answers = dataset[q]['direct_answers']
+
+         ## Multiple Choice setting
+         if multiple_choice:
+             if strict:
+                 assert pred in choices, 'Prediction must be a valid choice'
+             correct_choice_idx = dataset[q]['correct_choice_idx']
+             acc.append( float(pred == choices[correct_choice_idx]) )
+         ## Direct Answer setting
+         else:
+             num_match = sum([pred.lower() == da.lower() for da in direct_answers])
+             vqa_acc = min(1.0, num_match / 3.0)
+             acc.append(vqa_acc)
+
+     acc = sum(acc) / len(acc) * 100
+
+     return acc
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+     parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+     parser.add_argument('--preds', type=str, required=True, dest='prediction_files')
+     args = parser.parse_args()
+
+     dataset = load_aokvqa(args.aokvqa_dir, args.split)
+
+     for prediction_file in glob.glob(args.prediction_files):
+         predictions = json.load(open(prediction_file, 'r'))
+
+         # Multiple choice
+
+         mc_predictions = {}
+
+         for q in predictions.keys():
+             if 'multiple_choice' in predictions[q].keys():
+                 mc_predictions[q] = predictions[q]['multiple_choice']
+
+         if mc_predictions != {}:
+             mc_acc = eval_aokvqa(
+                 dataset,
+                 mc_predictions,
+                 multiple_choice=True,
+                 strict=False
+             )
+             print(prediction_file, 'MC', mc_acc)
+
+         # Direct Answer
+
+         da_predictions = {}
+
+         for q in predictions.keys():
+             if 'direct_answer' in predictions[q].keys():
+                 da_predictions[q] = predictions[q]['direct_answer']
+
+         if da_predictions != {}:
+             da_acc = eval_aokvqa(
+                 dataset,
+                 da_predictions,
+                 multiple_choice=False,
+                 strict=False
+             )
+             print(prediction_file, 'DA', da_acc)
minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py ADDED
@@ -0,0 +1,13 @@
+ import os
+ import json
+
+
+ def load_aokvqa(aokvqa_dir, split, version='v1p0'):
+     assert split in ['train', 'val', 'test', 'test_w_ans']
+     dataset = json.load(open(
+         os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
+     ))
+     return dataset
+
+ def get_coco_path(split, image_id, coco_dir):
+     return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")
minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py ADDED
@@ -0,0 +1,31 @@
+ import argparse
+ import pathlib
+ import json
+
+ from load_aokvqa import load_aokvqa
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+     parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+     parser.add_argument('--mc', type=argparse.FileType('r'), dest='mc_pred_file')
+     parser.add_argument('--da', type=argparse.FileType('r'), dest='da_pred_file')
+     parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
+     args = parser.parse_args()
+     assert args.mc_pred_file or args.da_pred_file
+
+     dataset = load_aokvqa(args.aokvqa_dir, args.split)
+     mc_preds = json.load(args.mc_pred_file) if args.mc_pred_file else None
+     da_preds = json.load(args.da_pred_file) if args.da_pred_file else None
+     predictions = {}
+
+     for d in dataset:
+         q = d['question_id']
+         predictions[q] = {}
+         if mc_preds and q in mc_preds.keys():
+             predictions[q]['multiple_choice'] = mc_preds[q]
+         if da_preds and q in da_preds.keys():
+             predictions[q]['direct_answer'] = da_preds[q]
+
+     json.dump(predictions, args.output_file)
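A rough sketch of how these two evaluation scripts fit together: `prepare_predictions.py` merges per-setting prediction files into the `{question_id: {"multiple_choice": ..., "direct_answer": ...}}` schema that `eval_predictions.py` reads. The `${AOKVQA_DIR}`/`${PREDS_DIR}` variables, the `model_val*.json` filenames, and the assumption that you run from the aokvqa tool directory are all placeholders, not fixed by this repo:

```bash
# Merge separate MC and DA prediction files into one file per model
python evaluation/prepare_predictions.py \
    --aokvqa-dir ${AOKVQA_DIR} --split val \
    --mc ${PREDS_DIR}/model_val-mc.json \
    --da ${PREDS_DIR}/model_val-da.json \
    --out ${PREDS_DIR}/model_val.json

# Print MC and DA accuracy; --preds accepts a glob over prediction files
python evaluation/eval_predictions.py \
    --aokvqa-dir ${AOKVQA_DIR} --split val \
    --preds "${PREDS_DIR}/model_val.json"
```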
minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py ADDED
@@ -0,0 +1,44 @@
+ import argparse
+ import pathlib
+ import json
+ from tqdm import tqdm
+
+ from sentence_transformers import SentenceTransformer
+ from sentence_transformers.util import cos_sim
+
+ from load_aokvqa import load_aokvqa
+
+
+ def map_to_choices(dataset, predictions, device='cpu'):
+     if isinstance(dataset, list):
+         dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
+
+     if all([p in dataset[q]['choices'] for q, p in predictions.items()]):
+         return predictions
+
+     model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d')
+     model.to(device)
+     for q in tqdm(predictions.keys()):
+         choices = dataset[q]['choices']
+         if predictions[q] not in choices:
+             choice_embeddings = model.encode([predictions[q]] + choices, convert_to_tensor=True)
+             a_idx = cos_sim(choice_embeddings[0], choice_embeddings[1:]).argmax().item()
+             predictions[q] = choices[a_idx]
+
+     return predictions
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+     parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+     parser.add_argument('--pred', type=argparse.FileType('r'), required=True, dest='prediction_file')
+     parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+     args = parser.parse_args()
+
+
+     dataset = load_aokvqa(args.aokvqa_dir, args.split)
+     predictions = json.load(args.prediction_file)
+     predictions = map_to_choices(dataset, predictions)
+
+     json.dump(predictions, args.output_file)
minigpt4/common/vqa_tools/aokvqa/gpt3/README.md ADDED
@@ -0,0 +1,14 @@
+ ## Querying GPT-3
+
+ To reproduce our experiments that use GPT-3, you need access to the [OpenAI API](https://openai.com/api/) (at cost). Retrieve your [organization](https://beta.openai.com/account/org-settings) and [API](https://beta.openai.com/account/api-keys) keys and set them as environment variables.
+
+ ```bash
+ export OPENAI_ORG=....
+ export OPENAI_API_KEY=...
+ ```
+
+ To produce predictions for both the DA and MC settings, run:
+ ```bash
+ python gpt3/query_gpt3.py --aokvqa-dir ${AOKVQA_DIR} --split val --out ${PREDS_DIR}/gpt3_val-da.json
+ python remap_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --pred ${PREDS_DIR}/gpt3_val-da.json --out ${PREDS_DIR}/gpt3_val-mc.json
+ ```
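A further sketch for the MC setting, relying on the `--include-choices` flag defined in `gpt3/query_gpt3.py` to append the answer choices to each prompt; the output filenames are illustrative and this is not necessarily the exact recipe used in the paper:

```bash
# Prompt GPT-3 with the answer choices shown, then snap any off-list completion onto a valid choice
python gpt3/query_gpt3.py --aokvqa-dir ${AOKVQA_DIR} --split val --include-choices --out ${PREDS_DIR}/gpt3-choices_val.json
python remap_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --pred ${PREDS_DIR}/gpt3-choices_val.json --out ${PREDS_DIR}/gpt3-choices_val-mc.json
```

The remap step uses `map_to_choices` from `evaluation/remap_predictions.py`, which maps a free-form answer to its nearest choice by embedding similarity.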
minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py ADDED
@@ -0,0 +1,23 @@
+ import os
+ import json
+ import argparse
+ import pathlib
+
+ from load_aokvqa import load_aokvqa
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
+ parser.add_argument('--split', type=str, choices=['train', 'val'], required=True)
+ parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+ args = parser.parse_args()
+
+ aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
+
+ coco_captions = json.load(open(os.path.join(args.coco_dir, 'annotations', f'captions_{args.split}2017.json')))['annotations']
+ coco_captions = {c['image_id'] : c['caption'] for c in coco_captions}
+
+ captions = { d['question_id'] : coco_captions[d['image_id']] for d in aokvqa_set }
+
+ json.dump(captions, args.output_file)
minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py ADDED
@@ -0,0 +1,79 @@
+ import os
+ import random
+ import json
+ from tqdm import tqdm
+ import argparse
+ import pathlib
+
+ import openai
+ openai.organization = os.getenv('OPENAI_ORG')
+ openai.api_key = os.getenv('OPENAI_API_KEY')
+
+ from load_aokvqa import load_aokvqa
+
+
+ random.seed(0)
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+     parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+     parser.add_argument('--n', type=int, default=10, dest='num_examples')
+     parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
+     parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
+     parser.add_argument('--include-choices', action='store_true', dest='include_choices')
+     parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
+     parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+     args = parser.parse_args()
+
+
+     train_set = load_aokvqa(args.aokvqa_dir, 'train')
+     eval_set = load_aokvqa(args.aokvqa_dir, args.split)
+
+     train_context = {}
+     context = {}
+     if args.context_file is not None:
+         train_context = json.load(args.train_context_file)
+         context = json.load(args.context_file)
+
+     predictions = {}
+
+     for d in tqdm(eval_set):
+         q = d['question_id']
+
+         prompt = args.prompt_prefix
+         for e in random.sample(train_set, args.num_examples):
+             prompt += prompt_element(e,
+                 context=train_context.get(e['question_id'], None),  # look up context by the in-context example's own question id
+                 include_choices=args.include_choices,
+                 answer=True
+             )
+             prompt += '\n\n'
+
+         prompt += prompt_element(d,
+             context=context.get(q, None),
+             include_choices=args.include_choices,
+             answer=False
+         )
+
+         response = openai.Completion.create(
+             engine="text-curie-001",
+             prompt=prompt,
+             temperature=0.0,
+             max_tokens=10,
+         )
+
+         predictions[q] = response.choices[0].text.strip()
+
+     json.dump(predictions, args.output_file)
+
+
+ def prompt_element(d, context=None, include_choices=False, answer=False):
+     return (f"Context: {context}\n" if context is not None else '') + \
+         f"Q: {d['question']}\n" + \
+         (f"Choices: {', '.join(d['choices'])}.\n" if include_choices else '') + \
+         f"A:" + (f" {d['choices'][d['correct_choice_idx']]}" if answer else '')
+
+ if __name__ == '__main__':
+     main()
minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py ADDED
@@ -0,0 +1,16 @@
+ import json
+ import argparse
+ import pathlib
+
+ from load_aokvqa import load_aokvqa
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--split', type=str, choices=['train', 'val', 'test_w_ans'], required=True)
+ parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+ args = parser.parse_args()
+
+ aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
+ rationales = {d['question_id'] : d['rationales'][0] for d in aokvqa_set}
+ json.dump(rationales, args.output_file)