shoubin committed
Commit 7e8784c
1 Parent(s): 0054d8c

upload_demo

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
.DS_Store ADDED
Binary file (10.2 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  demo4.mp4 filter=lfs diff=lfs merge=lfs -text
36
+ videos/*.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,14 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2022 Salesforce, Inc.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9
+
10
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11
+
12
+ 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in ADDED
@@ -0,0 +1,7 @@
1
+ recursive-include lavis/configs *.yaml *.json
2
+ recursive-include lavis/projects *.yaml *.json
3
+
4
+ recursive-exclude lavis/datasets/download_scripts *
5
+ recursive-exclude lavis/output *
6
+
7
+ include requirements.txt
README.md CHANGED
@@ -1,13 +1,112 @@
1
- ---
2
- title: SeViLA
3
- emoji: 📉
4
- colorFrom: pink
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.29.0
8
- app_file: app.py
9
- pinned: false
10
- license: openrail
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Self-Chained Image-Language Model for Video Localization and Question Answering
2
+
3
+ * Authors: [Shoubin Yu](https://yui010206.github.io/), [Jaemin Cho](https://j-min.io), [Prateek Yadav](https://prateek-yadav.github.io/), [Mohit Bansal](https://www.cs.unc.edu/~mbansal/)
4
+ * [arXiv](https://arxiv.org/abs/2305.06988)
5
+ <img src="./assets/teaser.png" alt="teaser image" width="800"/>
6
+
7
+ <img src="./assets/model.png" alt="model overview image" width="800"/>
8
+
9
+ <img src="./assets/chain.png" alt="chain image" width="800"/>
10
+
11
+
12
+ # Code structure
13
+ ```bash
14
+
15
+ # Data & Data Preprocessing
16
+ ./sevila_data
17
+
18
+ # Pretrained Checkpoints
19
+ ./sevila_checkpoints
20
+
21
+ # SeViLA code
22
+ ./lavis/
23
+
24
+ # running scripts for SeViLA localizer/answerer training/inference
25
+ ./run_scripts
26
+
27
+ ```
28
+
29
+ # Setup
30
+
31
+ ## Install Dependencies
32
+
33
+ 1. (Optional) Create a conda environment
34
+
35
+ ```bash
36
+ conda create -n sevila python=3.8
37
+ conda activate sevila
38
+ ```
39
+
40
+ 2. Build from source
41
+
42
+ ```bash
43
+ pip install -e .
44
+ ```
45
+
46
+ ## Download Pretrained Models
47
+ We pre-train the SeViLA localizer on QVHighlights and host the checkpoint on [Huggingface](https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth).
48
+ Download the checkpoint and put it under ./sevila_checkpoints.
49
+ The checkpoint (814.55M) contains the pre-trained localizer and the zero-shot answerer; a download sketch is shown below.
50
+
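+ For instance, a minimal sketch of fetching the checkpoint from the command line (assuming `wget` is available and the command is run from the repository root; the URL is the Huggingface link above):
+
+ ```bash
+ # create the expected checkpoint folder and download the pre-trained weights into it
+ mkdir -p sevila_checkpoints
+ wget -P sevila_checkpoints https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth
+ ```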
51
+
52
+
53
+ # Dataset Preparation
54
+ We test our model on:
55
+ + [NExT-QA](https://doc-doc.github.io/docs/nextqa.html)
56
+
57
+ + [STAR](https://star.csail.mit.edu/)
58
+
59
+ + [How2QA](https://value-benchmark.github.io/index.html)
60
+
61
+ + [TVQA](https://tvqa.cs.unc.edu/)
62
+
63
+ + [VLEP](https://value-benchmark.github.io/index.html)
64
+
65
+ + [QVHighlights](https://github.com/jayleicn/moment_detr)
66
+
67
+ Please download the original data and preprocess it via our [scripts](sevila_data/) under ./sevila_data/.
68
+
69
+
70
+ # Training and Inference
71
+ We provide SeViLA training and inference script examples as follows:
72
+ ## 1) Localizer Pre-training
73
+ ```bash
74
+ sh run_scripts/sevila/pre-train/pretrain_qvh.sh
75
+ ```
76
+
77
+ ## 2) Localizer Self-refinement
78
+
79
+ ```bash
80
+ sh run_scripts/sevila/refinement/nextqa_sr.sh
81
+ ```
82
+
83
+ ## 3) Answerer Fine-tuning
84
+
85
+ ```bash
86
+ sh run_scripts/sevila/finetune/nextqa_ft.sh
87
+ ```
88
+
89
+ ## 4) Inference
90
+
91
+ ```bash
92
+ sh run_scripts/sevila/inference/nextqa_infer.sh
93
+ ```
94
+
95
+
96
+ # Acknowledgments
97
+ We thank the developers of [LAVIS](https://github.com/salesforce/LAVIS), [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2), [CLIP](https://github.com/openai/CLIP), and [All-in-one](https://github.com/showlab/all-in-one) for their public code releases.
98
+
99
+
100
+ # Reference
101
+ Please cite our paper if you use our models in your work:
102
+
103
+
104
+ ```bibtex
105
+ @misc{yu2023selfchained,
106
+ title={Self-Chained Image-Language Model for Video Localization and Question Answering},
107
+ author={Shoubin Yu and Jaemin Cho and Prateek Yadav and Mohit Bansal},
108
+ year={2023},
109
+ eprint={2305.06988},
110
+ archivePrefix={arXiv},
111
+ primaryClass={cs.CV}
112
+ }
+ ```
app.py ADDED
@@ -0,0 +1,206 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from torchvision import transforms
5
+ from lavis.processors import transforms_video
6
+ from lavis.datasets.data_utils import load_video_demo
7
+ from lavis.processors.blip_processors import ToUint8, ToTHWC
8
+ from lavis.models.sevila_models.sevila import SeViLA
9
+ from typing import Optional
10
+ import warnings
11
+ # model config
12
+ img_size = 224
13
+ num_query_token = 32
14
+ t5_model = 'google/flan-t5-xl'
15
+ drop_path_rate = 0
16
+ use_grad_checkpoint = False
17
+ vit_precision = "fp16"
18
+ freeze_vit = True
19
+ prompt = ''
20
+ max_txt_len = 77
21
+ answer_num = 5
22
+ apply_lemmatizer = False
23
+ task = 'freeze_loc_freeze_qa_vid'
24
+
25
+ # prompt
26
+ LOC_prompt = 'Does the information within the frame provide the necessary details to accurately answer the given question?'
27
+ QA_prompt = 'Considering the information presented in the frame, select the correct answer from the options.'
28
+
29
+ # processors config
30
+ mean = (0.48145466, 0.4578275, 0.40821073)
31
+ std = (0.26862954, 0.26130258, 0.27577711)
32
+ normalize = transforms.Normalize(mean, std)
33
+ image_size = img_size
34
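+ # video preprocessing: ToUint8 -> THWC -> float video tensor -> normalize with the mean/std above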
+ transform = transforms.Compose([ToUint8(), ToTHWC(), transforms_video.ToTensorVideo(), normalize])
35
+
36
+ print('model loading')
37
+ sevila = SeViLA(
38
+ img_size=img_size,
39
+ drop_path_rate=drop_path_rate,
40
+ use_grad_checkpoint=use_grad_checkpoint,
41
+ vit_precision=vit_precision,
42
+ freeze_vit=freeze_vit,
43
+ num_query_token=num_query_token,
44
+ t5_model=t5_model,
45
+ prompt=prompt,
46
+ max_txt_len=max_txt_len,
47
+ apply_lemmatizer=apply_lemmatizer,
48
+ frame_num=4,
49
+ answer_num=answer_num,
50
+ task=task,
51
+ )
52
+
53
+ sevila.load_checkpoint(url_or_filename='https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth')
54
+ print('model loaded')
55
+
56
+ ANS_MAPPING = {0 : 'A', 1 : 'B', 2 : 'C', 3 : 'D', 4 : 'E'}
57
+
58
+ # os.mkdir('video')
59
+
60
+ def sevila_demo(video,
61
+ question,
62
+ option1, option2, option3,
63
+ video_frame_num,
64
+ keyframe_num):
65
+
66
+ if torch.cuda.is_available():
67
+ device = 0
68
+ else:
69
+ device = 'cpu'
70
+
71
+ global sevila
72
+ if device == "cpu":
73
+ sevila = sevila.float()
74
+ else:
75
+ sevila = sevila.to(int(device))
76
+
77
+ vpath = video
78
+ raw_clip, indice, fps, vlen = load_video_demo(
79
+ video_path=vpath,
80
+ n_frms=int(video_frame_num),
81
+ height=image_size,
82
+ width=image_size,
83
+ sampling="uniform",
84
+ clip_proposal=None
85
+ )
86
+ clip = transform(raw_clip.permute(1,0,2,3))
87
+ clip = clip.float().to(int(device))
88
+ clip = clip.unsqueeze(0)
89
+ # check
90
+ if option1[-1] != '.':
91
+ option1 += '.'
92
+ if option2[-1] != '.':
93
+ option2 += '.'
94
+ if option3[-1] != '.':
95
+ option3 += '.'
96
+ option_dict = {0:option1, 1:option2, 2:option3}
97
+ options = 'Option A:{} Option B:{} Option C:{}'.format(option1, option2, option3)
98
+ text_input_qa = 'Question: ' + question + ' ' + options + ' ' + QA_prompt
99
+ text_input_loc = 'Question: ' + question + ' ' + options + ' ' + LOC_prompt
100
+
101
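+ # generate_demo returns the predicted option index ('output_text') and the selected keyframe indices ('frame_idx')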
+ out = sevila.generate_demo(clip, text_input_qa, text_input_loc, int(keyframe_num))
102
+ # print(out)
103
+ answer_id = out['output_text'][0]
104
+ answer = option_dict[answer_id]
105
+ select_index = out['frame_idx'][0]
106
+ # images = []
107
+ keyframes = []
108
+ timestamps =[]
109
+
110
+ # print('raw_clip', len(raw_clip))
111
+ # for j in range(int(video_frame_num)):
112
+ # image = raw_clip[:, j, :, :].int()
113
+ # image = image.permute(1, 2, 0).numpy()
114
+ # images.append(image)
115
+
116
+ video_len = vlen/fps # seconds
117
+
118
+ for i in select_index:
119
+ image = raw_clip[:, i, :, :].int()
120
+ image = image.permute(1, 2, 0).numpy()
121
+ keyframes.append(image)
122
+ select_i = indice[i]
123
+ time = round((select_i / vlen) * video_len, 2)
124
+ timestamps.append(str(time)+'s')
125
+
126
+ gr.components.Gallery(keyframes)
127
+ #gr.components.Gallery(images)
128
+ timestamps_des = ''
129
+ for i in range(len(select_index)):
130
+ timestamps_des += 'Keyframe {}: {} \n'.format(str(i+1), timestamps[i])
131
+
132
+ return keyframes, timestamps_des, answer
133
+
134
+ with gr.Blocks(title="SeViLA demo") as demo:
135
+ description = """<p style="text-align: center; font-weight: bold;">
136
+ <span style="font-size: 28px">Self-Chained Image-Language Model for Video Localization and Question Answering</span>
137
+ <br>
138
+ <span style="font-size: 18px" id="author-info">
139
+ <a href="https://yui010206.github.io/" target="_blank">Shoubin Yu</a>,
140
+ <a href="https://j-min.io/" target="_blank">Jaemin Cho</a>,
141
+ <a href="https://prateek-yadav.github.io/" target="_blank">Prateek Yadav</a>,
142
+ <a href="https://www.cs.unc.edu/~mbansal/" target="_blank">Mohit Bansal</a>
143
+ </span>
144
+ <br>
145
+ <span style="font-size: 18px" id="paper-info">
146
+ [<a href="https://github.com/Yui010206/SeViLA" target="_blank">GitHub</a>]
147
+ [<a href="https://arxiv.org/abs/2305.06988" target="_blank">Paper</a>]
148
+ </span>
149
+ </p>
150
+ <p>
151
+ To locate keyframes in a video and answer a question, please:
152
+ <br>
153
+ (1) upload your video; (2) write your question/options and set # video frames / # keyframes; (3) click Locate and Answer!
154
+ <br>
155
+ Just a heads up - loading the SeViLA model can take a few minutes (typically 2-3), and running examples requires about 12GB of memory.
156
+ <br>
157
+ We've got you covered! We've provided some example videos and questions below to help you get started. Feel free to try out SeViLA with these!
158
+ </p>
159
+ """
160
+ gr.HTML(description)
161
+ with gr.Row():
162
+ with gr.Column(scale=1, min_width=600):
163
+ video = gr.Video(label='Video')
164
+ question = gr.Textbox(placeholder="Why did the two ladies put their hands above their eyes while staring out?", label='Question')
165
+ with gr.Row():
166
+ option1 = gr.Textbox(placeholder="practicing cheer", label='Option 1')
167
+ option2 = gr.Textbox(placeholder="posing for photo", label='Option 2')
168
+ option3 = gr.Textbox(placeholder="to see better", label='Option 3')
169
+ video_frame_num = gr.Textbox(placeholder=32, label='# Video Frame')
170
+ keyframe_num = gr.Textbox(placeholder=4, label='# Keyframe')
171
+ # device = gr.Textbox(placeholder=0, label='Device')
172
+ gen_btn = gr.Button(value='Locate and Answer!')
173
+ with gr.Column(scale=2, min_width=600):
174
+ keyframes = gr.Gallery(
175
+ label="Keyframes", show_label=False, elem_id="gallery"
176
+ ).style(columns=[4], rows=[1], object_fit="contain", height="auto")
177
+ #keyframes = gr.Gallery(label='Keyframes')
178
+ timestamps = gr.outputs.Textbox(label="Keyframe Timestamps")
179
+ answer = gr.outputs.Textbox(label="Output Answer")
180
+
181
+ gen_btn.click(
182
+ sevila_demo,
183
+ inputs=[video, question, option1, option2, option3, video_frame_num, keyframe_num],
184
+ outputs=[keyframes, timestamps, answer],
185
+ queue=True
186
+ )
187
+ #demo = gr.Interface(sevila_demo,
188
+ # inputs=[gr.Video(), question, option1, option2, option3, video_frame_num, keyframe_num, device],
189
+ # outputs=['gallery', timestamps, answer],
190
+ # examples=[['videos/demo1.mp4', 'Why did the two ladies put their hands above their eyes while staring out?', 'practicing cheer.', 'play ball.', 'to see better.', 32, 4, 0],
191
+ # ['videos/demo2.mp4', 'What did both of them do after completing skiing?', 'jump and pose.' , 'bend down.','raised their hands.', 32, 4, 0],
192
+ # ['videos/demo3.mp4', 'What room was Wilson breaking into when House found him?', 'the kitchen.' , 'the dining room.','the bathroom.', 32, 4, 0]]
193
+ # )
194
+ with gr.Column():
195
+ gr.Examples(
196
+ inputs=[video, question, option1, option2, option3, video_frame_num, keyframe_num],
197
+ outputs=[keyframes, timestamps, answer],
198
+ fn=sevila_demo,
199
+ examples=[['videos/demo1.mp4', 'Why did the two ladies put their hands above their eyes while staring out?', 'practicing cheer', 'play ball', 'to see better', 32, 4],
200
+ ['videos/demo2.mp4', 'What did both of them do after completing skiing?', 'jump and pose' , 'bend down','raised their hands', 32, 4],
201
+ ['videos/demo3.mp4', 'What room was Wilson breaking into when House found him?', 'the kitchen' , 'the dining room','the bathroom', 32, 4],
202
+ ['videos/demo4.mp4', 'what kind of bird is it?', 'chickadee' , 'eagle','seagull', 32, 1]],
203
+ cache_examples=False,
204
+ )
205
+ demo.queue(concurrency_count=1, api_open=False)
206
+ demo.launch(share=False)
app/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from PIL import Image
9
+ import requests
10
+
11
+ import streamlit as st
12
+ import torch
13
+
14
+
15
+ @st.cache()
16
+ def load_demo_image():
17
+ img_url = (
18
+ "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
19
+ )
20
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
21
+ return raw_image
22
+
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ cache_root = "/export/home/.cache/lavis/"
app/calculate_coco_features.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from PIL import Image
9
+ import requests
10
+ import torch
11
+
12
+ import os
13
+
14
+ from lavis.common.registry import registry
15
+ from lavis.processors import *
16
+ from lavis.models import *
17
+ from lavis.common.utils import build_default_model
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+
22
+ def load_demo_image():
23
+ img_url = (
24
+ "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
25
+ )
26
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
27
+
28
+ return raw_image
29
+
30
+
31
+ def read_img(filepath):
32
+ raw_image = Image.open(filepath).convert("RGB")
33
+
34
+ return raw_image
35
+
36
+
37
+ # model
38
+ model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
39
+ feature_extractor = BlipFeatureExtractor(pretrained=model_url)
40
+
41
+ feature_extractor.eval()
42
+ feature_extractor = feature_extractor.to(device)
43
+
44
+ # preprocessors
45
+ vis_processor = BlipImageEvalProcessor(image_size=224)
46
+ text_processor = BlipCaptionProcessor()
47
+
48
+ # files to process
49
+ # file_root = "/export/home/.cache/lavis/coco/images/val2014"
50
+ file_root = "/export/home/.cache/lavis/coco/images/train2014"
51
+ filepaths = os.listdir(file_root)
52
+
53
+ print(len(filepaths))
54
+
55
+ caption = "dummy"
56
+
57
+ path2feat = dict()
58
+ bsz = 256
59
+
60
+ images_in_batch = []
61
+ filepaths_in_batch = []
62
+
63
+ for i, filename in enumerate(filepaths):
64
+ if i % bsz == 0 and i > 0:
65
+ images_in_batch = torch.cat(images_in_batch, dim=0).to(device)
66
+ with torch.no_grad():
67
+ image_features = feature_extractor(
68
+ images_in_batch, caption, mode="image", normalized=True
69
+ )[:, 0]
70
+
71
+ for filepath, image_feat in zip(filepaths_in_batch, image_features):
72
+ path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()
73
+
74
+ images_in_batch = []
75
+ filepaths_in_batch = []
76
+
77
+ print(len(path2feat), image_features.shape)
78
+ else:
79
+ filepath = os.path.join(file_root, filename)
80
+
81
+ image = read_img(filepath)
82
+ image = vis_processor(image).unsqueeze(0)
83
+
84
+ images_in_batch.append(image)
85
+ filepaths_in_batch.append(filepath)
86
+
87
+ torch.save(path2feat, "path2feat_coco_train2014.pth")
app/caption.py ADDED
@@ -0,0 +1,98 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import streamlit as st
9
+ from app import device, load_demo_image
10
+ from app.utils import load_model_cache
11
+ from lavis.processors import load_processor
12
+ from PIL import Image
13
+
14
+
15
+ def app():
16
+ # ===== layout =====
17
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
18
+
19
+ sampling_method = st.sidebar.selectbox(
20
+ "Sampling method:", ["Beam search", "Nucleus sampling"]
21
+ )
22
+
23
+ st.markdown(
24
+ "<h1 style='text-align: center;'>Image Description Generation</h1>",
25
+ unsafe_allow_html=True,
26
+ )
27
+
28
+ instructions = """Try the provided image or upload your own:"""
29
+ file = st.file_uploader(instructions)
30
+
31
+ use_beam = sampling_method == "Beam search"
32
+
33
+ col1, col2 = st.columns(2)
34
+
35
+ if file:
36
+ raw_img = Image.open(file).convert("RGB")
37
+ else:
38
+ raw_img = load_demo_image()
39
+
40
+ col1.header("Image")
41
+
42
+ w, h = raw_img.size
43
+ scaling_factor = 720 / w
44
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
45
+
46
+ col1.image(resized_image, use_column_width=True)
47
+ col2.header("Description")
48
+
49
+ cap_button = st.button("Generate")
50
+
51
+ # ==== event ====
52
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
53
+
54
+ if cap_button:
55
+ if model_type.startswith("BLIP"):
56
+ blip_type = model_type.split("_")[1].lower()
57
+ model = load_model_cache(
58
+ "blip_caption",
59
+ model_type=f"{blip_type}_coco",
60
+ is_eval=True,
61
+ device=device,
62
+ )
63
+
64
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
65
+ captions = generate_caption(
66
+ model=model, image=img, use_nucleus_sampling=not use_beam
67
+ )
68
+
69
+ col2.write("\n\n".join(captions), use_column_width=True)
70
+
71
+
72
+ def generate_caption(
73
+ model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5
74
+ ):
75
+ samples = {"image": image}
76
+
77
+ captions = []
78
+ if use_nucleus_sampling:
79
+ for _ in range(5):
80
+ caption = model.generate(
81
+ samples,
82
+ use_nucleus_sampling=True,
83
+ max_length=max_length,
84
+ min_length=min_length,
85
+ top_p=0.9,
86
+ )
87
+ captions.append(caption[0])
88
+ else:
89
+ caption = model.generate(
90
+ samples,
91
+ use_nucleus_sampling=False,
92
+ num_beams=num_beams,
93
+ max_length=max_length,
94
+ min_length=min_length,
95
+ )
96
+ captions.append(caption[0])
97
+
98
+ return captions
app/classification.py ADDED
@@ -0,0 +1,216 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import plotly.graph_objects as go
9
+ import requests
10
+ import streamlit as st
11
+ import torch
12
+ from lavis.models import load_model
13
+ from lavis.processors import load_processor
14
+ from lavis.processors.blip_processors import BlipCaptionProcessor
15
+ from PIL import Image
16
+
17
+ from app import device, load_demo_image
18
+ from app.utils import load_blip_itm_model
19
+ from lavis.processors.clip_processors import ClipImageEvalProcessor
20
+
21
+
22
+ @st.cache()
23
+ def load_demo_image(img_url=None):
24
+ if not img_url:
25
+ img_url = "https://img.atlasobscura.com/yDJ86L8Ou6aIjBsxnlAy5f164w1rjTgcHZcx2yUs4mo/rt:fit/w:1200/q:81/sm:1/scp:1/ar:1/aHR0cHM6Ly9hdGxh/cy1kZXYuczMuYW1h/em9uYXdzLmNvbS91/cGxvYWRzL3BsYWNl/X2ltYWdlcy85MDll/MDRjOS00NTJjLTQx/NzQtYTY4MS02NmQw/MzI2YWIzNjk1ZGVk/MGZhMTJiMTM5MmZi/NGFfUmVhcl92aWV3/X29mX3RoZV9NZXJs/aW9uX3N0YXR1ZV9h/dF9NZXJsaW9uX1Bh/cmssX1NpbmdhcG9y/ZSxfd2l0aF9NYXJp/bmFfQmF5X1NhbmRz/X2luX3RoZV9kaXN0/YW5jZV8tXzIwMTQw/MzA3LmpwZw.jpg"
26
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
27
+ return raw_image
28
+
29
+
30
+ @st.cache(
31
+ hash_funcs={
32
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
33
+ .cpu()
34
+ .numpy()
35
+ },
36
+ allow_output_mutation=True,
37
+ )
38
+ def load_model_cache(model_type, device):
39
+ if model_type == "blip":
40
+ model = load_model(
41
+ "blip_feature_extractor", model_type="base", is_eval=True, device=device
42
+ )
43
+ elif model_type == "albef":
44
+ model = load_model(
45
+ "albef_feature_extractor", model_type="base", is_eval=True, device=device
46
+ )
47
+ elif model_type == "CLIP_ViT-B-32":
48
+ model = load_model(
49
+ "clip_feature_extractor", "ViT-B-32", is_eval=True, device=device
50
+ )
51
+ elif model_type == "CLIP_ViT-B-16":
52
+ model = load_model(
53
+ "clip_feature_extractor", "ViT-B-16", is_eval=True, device=device
54
+ )
55
+ elif model_type == "CLIP_ViT-L-14":
56
+ model = load_model(
57
+ "clip_feature_extractor", "ViT-L-14", is_eval=True, device=device
58
+ )
59
+
60
+ return model
61
+
62
+
63
+ def app():
64
+ model_type = st.sidebar.selectbox(
65
+ "Model:",
66
+ ["ALBEF", "BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"],
67
+ )
68
+ score_type = st.sidebar.selectbox("Score type:", ["Cosine", "Multimodal"])
69
+
70
+ # ===== layout =====
71
+ st.markdown(
72
+ "<h1 style='text-align: center;'>Zero-shot Classification</h1>",
73
+ unsafe_allow_html=True,
74
+ )
75
+
76
+ instructions = """Try the provided image or upload your own:"""
77
+ file = st.file_uploader(instructions)
78
+
79
+ st.header("Image")
80
+ if file:
81
+ raw_img = Image.open(file).convert("RGB")
82
+ else:
83
+ raw_img = load_demo_image()
84
+
85
+ st.image(raw_img) # , use_column_width=True)
86
+
87
+ col1, col2 = st.columns(2)
88
+
89
+ col1.header("Categories")
90
+
91
+ cls_0 = col1.text_input("category 1", value="merlion")
92
+ cls_1 = col1.text_input("category 2", value="sky")
93
+ cls_2 = col1.text_input("category 3", value="giraffe")
94
+ cls_3 = col1.text_input("category 4", value="fountain")
95
+ cls_4 = col1.text_input("category 5", value="marina bay")
96
+
97
+ cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4]
98
+ cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0]
99
+
100
+ if len(cls_names) != len(set(cls_names)):
101
+ st.error("Please provide unique class names")
102
+ return
103
+
104
+ button = st.button("Submit")
105
+
106
+ col2.header("Prediction")
107
+
108
+ # ===== event =====
109
+
110
+ if button:
111
+ if model_type.startswith("BLIP"):
112
+ text_processor = BlipCaptionProcessor(prompt="A picture of ")
113
+ cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]
114
+
115
+ if score_type == "Cosine":
116
+ vis_processor = load_processor("blip_image_eval").build(image_size=224)
117
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
118
+
119
+ feature_extractor = load_model_cache(model_type="blip", device=device)
120
+
121
+ sample = {"image": img, "text_input": cls_prompt}
122
+
123
+ with torch.no_grad():
124
+ image_features = feature_extractor.extract_features(
125
+ sample, mode="image"
126
+ ).image_embeds_proj[:, 0]
127
+ text_features = feature_extractor.extract_features(
128
+ sample, mode="text"
129
+ ).text_embeds_proj[:, 0]
130
+ sims = (image_features @ text_features.t())[
131
+ 0
132
+ ] / feature_extractor.temp
133
+
134
+ else:
135
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
136
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
137
+
138
+ model = load_blip_itm_model(device)
139
+
140
+ output = model(img, cls_prompt, match_head="itm")
141
+ sims = output[:, 1]
142
+
143
+ sims = torch.nn.Softmax(dim=0)(sims)
144
+ inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
145
+
146
+ elif model_type.startswith("ALBEF"):
147
+ vis_processor = load_processor("blip_image_eval").build(image_size=224)
148
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
149
+
150
+ text_processor = BlipCaptionProcessor(prompt="A picture of ")
151
+ cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]
152
+
153
+ feature_extractor = load_model_cache(model_type="albef", device=device)
154
+
155
+ sample = {"image": img, "text_input": cls_prompt}
156
+
157
+ with torch.no_grad():
158
+ image_features = feature_extractor.extract_features(
159
+ sample, mode="image"
160
+ ).image_embeds_proj[:, 0]
161
+ text_features = feature_extractor.extract_features(
162
+ sample, mode="text"
163
+ ).text_embeds_proj[:, 0]
164
+
165
+ st.write(image_features.shape)
166
+ st.write(text_features.shape)
167
+
168
+ sims = (image_features @ text_features.t())[0] / feature_extractor.temp
169
+
170
+ sims = torch.nn.Softmax(dim=0)(sims)
171
+ inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
172
+
173
+ elif model_type.startswith("CLIP"):
174
+ if model_type == "CLIP_ViT-B-32":
175
+ model = load_model_cache(model_type="CLIP_ViT-B-32", device=device)
176
+ elif model_type == "CLIP_ViT-B-16":
177
+ model = load_model_cache(model_type="CLIP_ViT-B-16", device=device)
178
+ elif model_type == "CLIP_ViT-L-14":
179
+ model = load_model_cache(model_type="CLIP_ViT-L-14", device=device)
180
+ else:
181
+ raise ValueError(f"Unknown model type {model_type}")
182
+
183
+ if score_type == "Cosine":
184
+ # image_preprocess = ClipImageEvalProcessor(image_size=336)
185
+ image_preprocess = ClipImageEvalProcessor(image_size=224)
186
+ img = image_preprocess(raw_img).unsqueeze(0).to(device)
187
+
188
+ sample = {"image": img, "text_input": cls_names}
189
+
190
+ with torch.no_grad():
191
+ clip_features = model.extract_features(sample)
192
+
193
+ image_features = clip_features.image_embeds_proj
194
+ text_features = clip_features.text_embeds_proj
195
+
196
+ sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1)
197
+ inv_sims = sims.tolist()[::-1]
198
+ else:
199
+ st.warning("CLIP does not support multimodal scoring.")
200
+ return
201
+
202
+ fig = go.Figure(
203
+ go.Bar(
204
+ x=inv_sims,
205
+ y=cls_names[::-1],
206
+ text=["{:.2f}".format(s) for s in inv_sims],
207
+ orientation="h",
208
+ )
209
+ )
210
+ fig.update_traces(
211
+ textfont_size=12,
212
+ textangle=0,
213
+ textposition="outside",
214
+ cliponaxis=False,
215
+ )
216
+ col2.plotly_chart(fig, use_container_width=True)
app/dataset_browser.py ADDED
@@ -0,0 +1,240 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import random
9
+ from collections import OrderedDict
10
+ from functools import reduce
11
+ from tkinter import N
12
+
13
+ import streamlit as st
14
+ from lavis.common.registry import registry
15
+ from lavis.datasets.builders import dataset_zoo, load_dataset
16
+ from lavis.datasets.builders.base_dataset_builder import load_dataset_config
17
+ from PIL import Image
18
+
19
+ IMAGE_LAYOUT = 3, 4
20
+ VIDEO_LAYOUT = 1, 2
21
+
22
+ PREV_STR = "Prev"
23
+ NEXT_STR = "Next"
24
+
25
+
26
+ def sample_dataset(dataset, indices):
27
+ samples = [dataset.displ_item(idx) for idx in indices]
28
+
29
+ return samples
30
+
31
+
32
+ def get_concat_v(im1, im2):
33
+ margin = 5
34
+
35
+ canvas_size = (im1.width + im2.width + margin, max(im1.height, im2.height))
36
+ canvas = Image.new("RGB", canvas_size, "White")
37
+ canvas.paste(im1, (0, 0))
38
+ canvas.paste(im2, (im1.width + margin, 0))
39
+
40
+ return canvas
41
+
42
+
43
+ def resize_img_w(raw_img, new_w=224):
44
+ if isinstance(raw_img, list):
45
+ resized_imgs = [resize_img_w(img, 196) for img in raw_img]
46
+ # concatenate images
47
+ resized_image = reduce(get_concat_v, resized_imgs)
48
+ else:
49
+ w, h = raw_img.size
50
+ scaling_factor = new_w / w
51
+ resized_image = raw_img.resize(
52
+ (int(w * scaling_factor), int(h * scaling_factor))
53
+ )
54
+
55
+ return resized_image
56
+
57
+
58
+ def get_visual_key(dataset):
59
+ if "image" in dataset[0]:
60
+ return "image"
61
+ elif "image0" in dataset[0]: # NLVR2 dataset
62
+ return "image"
63
+ elif "video" in dataset[0]:
64
+ return "video"
65
+ else:
66
+ raise ValueError("Visual key not found.")
67
+
68
+
69
+ def gather_items(samples, exclude=[]):
70
+ gathered = []
71
+
72
+ for s in samples:
73
+ ns = OrderedDict()
74
+ for k in s.keys():
75
+ if k not in exclude:
76
+ ns[k] = s[k]
77
+
78
+ gathered.append(ns)
79
+
80
+ return gathered
81
+
82
+
83
+ @st.cache(allow_output_mutation=True)
84
+ def load_dataset_cache(name):
85
+ return load_dataset(name)
86
+
87
+
88
+ def format_text(text):
89
+ md = "\n\n".join([f"**{k}**: {v}" for k, v in text.items()])
90
+
91
+ return md
92
+
93
+
94
+ def show_samples(dataset, offset=0, is_next=False):
95
+ visual_key = get_visual_key(dataset)
96
+
97
+ num_rows, num_cols = IMAGE_LAYOUT if visual_key == "image" else VIDEO_LAYOUT
98
+ n_samples = num_rows * num_cols
99
+
100
+ if not shuffle:
101
+ if is_next:
102
+ start = min(int(start_idx) + offset + n_samples, len(dataset) - n_samples)
103
+ else:
104
+ start = max(0, int(start_idx) + offset - n_samples)
105
+
106
+ st.session_state.last_start = start
107
+ end = min(start + n_samples, len(dataset))
108
+
109
+ indices = list(range(start, end))
110
+ else:
111
+ indices = random.sample(range(len(dataset)), n_samples)
112
+ samples = sample_dataset(dataset, indices)
113
+
114
+ visual_info = (
115
+ iter([resize_img_w(s[visual_key]) for s in samples])
116
+ if visual_key == "image"
117
+ # else iter([s[visual_key] for s in samples])
118
+ else iter([s["file"] for s in samples])
119
+ )
120
+ text_info = gather_items(samples, exclude=["image", "video"])
121
+ text_info = iter([format_text(s) for s in text_info])
122
+
123
+ st.markdown(
124
+ """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
125
+ unsafe_allow_html=True,
126
+ )
127
+ for _ in range(num_rows):
128
+ with st.container():
129
+ for col in st.columns(num_cols):
130
+ # col.text(next(text_info))
131
+ # col.caption(next(text_info))
132
+ try:
133
+ col.markdown(next(text_info))
134
+ if visual_key == "image":
135
+ col.image(next(visual_info), use_column_width=True, clamp=True)
136
+ elif visual_key == "video":
137
+ col.markdown(
138
+ "![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)"
139
+ )
140
+ except StopIteration:
141
+ break
142
+
143
+ st.markdown(
144
+ """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
145
+ unsafe_allow_html=True,
146
+ )
147
+
148
+ st.session_state.n_display = n_samples
149
+
150
+
151
+ if __name__ == "__main__":
152
+ st.set_page_config(
153
+ page_title="LAVIS Dataset Explorer",
154
+ # layout="wide",
155
+ initial_sidebar_state="expanded",
156
+ )
157
+
158
+ dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names())
159
+
160
+ function = st.sidebar.selectbox("Function:", ["Browser"], index=0)
161
+
162
+ if function == "Browser":
163
+ shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0)
164
+
165
+ dataset = load_dataset_cache(dataset_name)
166
+ split = st.sidebar.selectbox("Split:", dataset.keys())
167
+
168
+ dataset_len = len(dataset[split])
169
+ st.success(
170
+ f"Loaded {dataset_name}/{split} with **{dataset_len}** records. **Image/video directory**: {dataset[split].vis_root}"
171
+ )
172
+
173
+ if "last_dataset" not in st.session_state:
174
+ st.session_state.last_dataset = dataset_name
175
+ st.session_state.last_split = split
176
+
177
+ if "last_start" not in st.session_state:
178
+ st.session_state.last_start = 0
179
+
180
+ if "start_idx" not in st.session_state:
181
+ st.session_state.start_idx = 0
182
+
183
+ if "shuffle" not in st.session_state:
184
+ st.session_state.shuffle = shuffle
185
+
186
+ if "first_run" not in st.session_state:
187
+ st.session_state.first_run = True
188
+ elif (
189
+ st.session_state.last_dataset != dataset_name
190
+ or st.session_state.last_split != split
191
+ ):
192
+ st.session_state.first_run = True
193
+
194
+ st.session_state.last_dataset = dataset_name
195
+ st.session_state.last_split = split
196
+ elif st.session_state.shuffle != shuffle:
197
+ st.session_state.shuffle = shuffle
198
+ st.session_state.first_run = True
199
+
200
+ if not shuffle:
201
+ n_col, p_col = st.columns([0.05, 1])
202
+
203
+ prev_button = n_col.button(PREV_STR)
204
+ next_button = p_col.button(NEXT_STR)
205
+
206
+ else:
207
+ next_button = st.button(NEXT_STR)
208
+
209
+ if not shuffle:
210
+ start_idx = st.sidebar.text_input(f"Begin from (total {dataset_len})", 0)
211
+
212
+ if not start_idx.isdigit():
213
+ st.error(f"Input to 'Begin from' must be digits, found {start_idx}.")
214
+ else:
215
+ if int(start_idx) != st.session_state.start_idx:
216
+ st.session_state.start_idx = int(start_idx)
217
+ st.session_state.last_start = int(start_idx)
218
+
219
+ if prev_button:
220
+ show_samples(
221
+ dataset[split],
222
+ offset=st.session_state.last_start - st.session_state.start_idx,
223
+ is_next=False,
224
+ )
225
+
226
+ if next_button:
227
+ show_samples(
228
+ dataset[split],
229
+ offset=st.session_state.last_start - st.session_state.start_idx,
230
+ is_next=True,
231
+ )
232
+
233
+ if st.session_state.first_run:
234
+ st.session_state.first_run = False
235
+
236
+ show_samples(
237
+ dataset[split],
238
+ offset=st.session_state.last_start - st.session_state.start_idx,
239
+ is_next=True,
240
+ )
app/image_text_match.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import numpy as np
9
+ import streamlit as st
10
+ import torch
11
+ from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
12
+ from lavis.processors import load_processor
13
+ from PIL import Image
14
+
15
+ from app import device, load_demo_image
16
+ from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model
17
+
18
+
19
+ def app():
20
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
21
+
22
+ if model_type.startswith("BLIP"):
23
+ blip_type = model_type.split("_")[1]
24
+ model = load_blip_itm_model(device, model_type=blip_type)
25
+
26
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
27
+
28
+ st.markdown(
29
+ "<h1 style='text-align: center;'>Image Text Matching</h1>",
30
+ unsafe_allow_html=True,
31
+ )
32
+
33
+ values = list(range(1, 12))
34
+ default_layer_num = values.index(7)
35
+ layer_num = (
36
+ st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1
37
+ )
38
+
39
+ instructions = """Try the provided image or upload your own:"""
40
+ file = st.file_uploader(instructions)
41
+
42
+ col1, col2 = st.columns(2)
43
+ col1.header("Image")
44
+ col2.header("GradCam")
45
+ if file:
46
+ raw_img = Image.open(file).convert("RGB")
47
+ else:
48
+ raw_img = load_demo_image()
49
+
50
+ w, h = raw_img.size
51
+ scaling_factor = 720 / w
52
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
53
+ col1.image(resized_image, use_column_width=True)
54
+
55
+ col3, col4 = st.columns(2)
56
+ col3.header("Text")
57
+ user_question = col3.text_input(
58
+ "Input your sentence!", "a woman sitting on the beach with a dog"
59
+ )
60
+ submit_button = col3.button("Submit")
61
+
62
+ col4.header("Matching score")
63
+
64
+ if submit_button:
65
+ tokenizer = init_bert_tokenizer()
66
+
67
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
68
+ text_processor = load_processor("blip_caption").build()
69
+
70
+ qry = text_processor(user_question)
71
+
72
+ norm_img = np.float32(resized_image) / 255
73
+
74
+ qry_tok = tokenizer(qry, return_tensors="pt").to(device)
75
+ gradcam, output = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)
76
+
77
+ avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)
78
+
79
+ col2.image(avg_gradcam, use_column_width=True, clamp=True)
80
+ # output = model(img, question)
81
+ itm_score = torch.nn.functional.softmax(output, dim=1)
82
+ new_title = (
83
+ '<p style="text-align: left; font-size: 25px;">\n{:.3f}%</p>'.format(
84
+ itm_score[0][1].item() * 100
85
+ )
86
+ )
87
+ col4.markdown(new_title, unsafe_allow_html=True)
app/main.py ADDED
@@ -0,0 +1,25 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from app.multipage import MultiPage
9
+ from app import vqa, caption
10
+ from app import image_text_match as itm
11
+ from app import text_localization as tl
12
+ from app import multimodal_search as ms
13
+ from app import classification as cl
14
+
15
+
16
+ if __name__ == "__main__":
17
+ app = MultiPage()
18
+
19
+ app.add_page("Image Description Generation", caption.app)
20
+ app.add_page("Multimodal Search", ms.app)
21
+ app.add_page("Visual Question Answering", vqa.app)
22
+ app.add_page("Image Text Matching", itm.app)
23
+ app.add_page("Text Localization", tl.app)
24
+ app.add_page("Classification", cl.app)
25
+ app.run()
app/multimodal_search.py ADDED
@@ -0,0 +1,230 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+
10
+ import numpy as np
11
+ import streamlit as st
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from app import cache_root, device
15
+ from app.utils import (
16
+ getAttMap,
17
+ init_bert_tokenizer,
18
+ load_blip_itm_model,
19
+ read_img,
20
+ resize_img,
21
+ )
22
+ from lavis.models import load_model
23
+ from lavis.processors import load_processor
24
+
25
+
26
+ @st.cache(
27
+ hash_funcs={
28
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
29
+ .cpu()
30
+ .numpy()
31
+ },
32
+ allow_output_mutation=True,
33
+ )
34
+ def load_feat():
35
+ from lavis.common.utils import download_url
36
+
37
+ dirname = os.path.join(os.path.dirname(__file__), "assets")
38
+ filename = "path2feat_coco_train2014.pth"
39
+ filepath = os.path.join(dirname, filename)
40
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth"
41
+
42
+ if not os.path.exists(filepath):
43
+ download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth")
44
+
45
+ path2feat = torch.load(filepath)
46
+ paths = sorted(path2feat.keys())
47
+
48
+ all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)
49
+
50
+ return path2feat, paths, all_img_feats
51
+
52
+
53
+ @st.cache(
54
+ hash_funcs={
55
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
56
+ .cpu()
57
+ .numpy()
58
+ },
59
+ allow_output_mutation=True,
60
+ )
61
+ def load_feature_extractor_model(device):
62
+ model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
63
+
64
+ model = load_model(
65
+ "blip_feature_extractor", model_type="base", is_eval=True, device=device
66
+ )
67
+ model.load_from_pretrained(model_url)
68
+
69
+ return model
70
+
71
+
72
+ def app():
73
+ # === layout ===
74
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
75
+ file_root = os.path.join(cache_root, "coco/images/train2014/")
76
+
77
+ values = [12, 24, 48]
78
+ default_layer_num = values.index(24)
79
+ num_display = st.sidebar.selectbox(
80
+ "Number of images:", values, index=default_layer_num
81
+ )
82
+ show_gradcam = st.sidebar.selectbox("Show GradCam:", [True, False], index=1)
83
+ itm_ranking = st.sidebar.selectbox("Multimodal re-ranking:", [True, False], index=0)
84
+
85
+ # st.title('Multimodal Search')
86
+ st.markdown(
87
+ "<h1 style='text-align: center;'>Multimodal Search</h1>", unsafe_allow_html=True
88
+ )
89
+
90
+ # === event ===
91
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
92
+ text_processor = load_processor("blip_caption")
93
+
94
+ user_question = st.text_input(
95
+ "Search query", "A dog running on the grass.", help="Type something to search."
96
+ )
97
+ user_question = text_processor(user_question)
98
+ feature_extractor = load_feature_extractor_model(device)
99
+
100
+ # ======= ITC =========
101
+ sample = {"text_input": user_question}
102
+
103
+ with torch.no_grad():
104
+ text_feature = feature_extractor.extract_features(
105
+ sample, mode="text"
106
+ ).text_embeds_proj[0, 0]
107
+
108
+ path2feat, paths, all_img_feats = load_feat()
109
+ all_img_feats.to(device)
110
+ all_img_feats = F.normalize(all_img_feats, dim=1)
111
+
112
+ num_cols = 4
113
+ num_rows = int(num_display / num_cols)
114
+
115
+ similarities = text_feature @ all_img_feats.T
116
+ indices = torch.argsort(similarities, descending=True)[:num_display]
117
+
118
+ top_paths = [paths[ind.detach().cpu().item()] for ind in indices]
119
+ sorted_similarities = [similarities[idx] for idx in indices]
120
+ filenames = [os.path.join(file_root, p) for p in top_paths]
121
+
122
+ # ========= ITM and GradCam ==========
123
+ bsz = 4 # max number of images to avoid cuda oom
124
+ if model_type.startswith("BLIP"):
125
+ blip_type = model_type.split("_")[1]
126
+
127
+ itm_model = load_blip_itm_model(device, model_type=blip_type)
128
+
129
+ tokenizer = init_bert_tokenizer()
130
+ queries_batch = [user_question] * bsz
131
+ queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device)
132
+
133
+ num_batches = int(num_display / bsz)
134
+
135
+ avg_gradcams = []
136
+ all_raw_images = []
137
+ itm_scores = []
138
+
139
+ for i in range(num_batches):
140
+ filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]
141
+ raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)
142
+ gradcam, itm_output = compute_gradcam_batch(
143
+ itm_model, images, queries_batch, queries_tok_batch
144
+ )
145
+
146
+ all_raw_images.extend([resize_img(r_img) for r_img in raw_images])
147
+ norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
148
+
149
+ for norm_img, grad_cam in zip(norm_imgs, gradcam):
150
+ avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)
151
+ avg_gradcams.append(avg_gradcam)
152
+
153
+ with torch.no_grad():
154
+ itm_score = torch.nn.functional.softmax(itm_output, dim=1)
155
+
156
+ itm_scores.append(itm_score)
157
+
158
+ # ========= ITM re-ranking =========
159
+ itm_scores = torch.cat(itm_scores)[:, 1]
160
+ if itm_ranking:
161
+ itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)
162
+
163
+ avg_gradcams_sorted = []
164
+ all_raw_images_sorted = []
165
+ for idx in indices:
166
+ avg_gradcams_sorted.append(avg_gradcams[idx])
167
+ all_raw_images_sorted.append(all_raw_images[idx])
168
+
169
+ avg_gradcams = avg_gradcams_sorted
170
+ all_raw_images = all_raw_images_sorted
171
+
172
+ if show_gradcam:
173
+ images_to_show = iter(avg_gradcams)
174
+ else:
175
+ images_to_show = iter(all_raw_images)
176
+
177
+ for _ in range(num_rows):
178
+ with st.container():
179
+ for col in st.columns(num_cols):
180
+ col.image(next(images_to_show), use_column_width=True, clamp=True)
181
+
182
+
183
+ def read_and_process_images(image_paths, vis_processor):
184
+ raw_images = [read_img(path) for path in image_paths]
185
+ images = [vis_processor(r_img) for r_img in raw_images]
186
+ images_tensors = torch.stack(images).to(device)
187
+
188
+ return raw_images, images_tensors
189
+
190
+
191
+ def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6):
192
+ model.text_encoder.base_model.base_model.encoder.layer[
193
+ block_num
194
+ ].crossattention.self.save_attention = True
195
+
196
+ output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
197
+ loss = output[:, 1].sum()
198
+
199
+ model.zero_grad()
200
+ loss.backward()
201
+ with torch.no_grad():
202
+ mask = tokenized_text.attention_mask.view(
203
+ tokenized_text.attention_mask.size(0), 1, -1, 1, 1
204
+ ) # (bsz,1,token_len, 1,1)
205
+ token_length = mask.sum() - 2
206
+ token_length = token_length.cpu()
207
+ # grads and cams [bsz, num_head, seq_len, image_patch]
208
+ grads = model.text_encoder.base_model.base_model.encoder.layer[
209
+ block_num
210
+ ].crossattention.self.get_attn_gradients()
211
+ cams = model.text_encoder.base_model.base_model.encoder.layer[
212
+ block_num
213
+ ].crossattention.self.get_attention_map()
214
+
215
+ # assume using vit large with 576 num image patch
216
+ cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
217
+ grads = (
218
+ grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)
219
+ * mask
220
+ )
221
+
222
+ gradcam = cams * grads
223
+ # [enc token gradcam, average gradcam across token, gradcam for individual token]
224
+ # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
225
+ gradcam = gradcam.mean(1).cpu().detach()
226
+ gradcam = (
227
+ gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length
228
+ )
229
+
230
+ return gradcam, output
app/multipage.py ADDED
@@ -0,0 +1,41 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ """
9
+ This file is the framework for generating multiple Streamlit applications
10
+ through an object oriented framework.
11
+ """
12
+
13
+ # Import necessary libraries
14
+ import streamlit as st
15
+
16
+ # Define the multipage class to manage the multiple apps in our program
17
+ class MultiPage:
18
+ """Framework for combining multiple streamlit applications."""
19
+
20
+ def __init__(self) -> None:
21
+ """Constructor class to generate a list which will store all our applications as an instance variable."""
22
+ self.pages = []
23
+
24
+ def add_page(self, title, func) -> None:
25
+ """Class Method to Add pages to the project
26
+ Args:
27
+ title ([str]): The title of page which we are adding to the list of apps
28
+
29
+ func: Python function to render this page in Streamlit
30
+ """
31
+
32
+ self.pages.append({"title": title, "function": func})
33
+
34
+ def run(self):
35
+ # Drodown to select the page to run
36
+ page = st.sidebar.selectbox(
37
+ "Navigation", self.pages, format_func=lambda page: page["title"]
38
+ )
39
+
40
+ # run the app function
41
+ page["function"]()
app/text_localization.py ADDED
@@ -0,0 +1,105 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ import numpy as np
11
+ import streamlit as st
12
+ from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
13
+ from lavis.processors import load_processor
14
+ from PIL import Image
15
+
16
+ from app import device, load_demo_image
17
+ from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model
18
+
19
+
20
+ def app():
21
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
22
+
23
+ values = list(range(1, 12))
24
+ default_layer_num = values.index(7)
25
+ layer_num = (
26
+ st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1
27
+ )
28
+
29
+ st.markdown(
30
+ "<h1 style='text-align: center;'>Text Localization</h1>", unsafe_allow_html=True
31
+ )
32
+
33
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
34
+ text_processor = load_processor("blip_caption")
35
+
36
+ tokenizer = init_bert_tokenizer()
37
+
38
+ instructions = "Try the provided image and text or use your own ones."
39
+ file = st.file_uploader(instructions)
40
+
41
+ query = st.text_input(
42
+ "Try a different input.", "A girl playing with her dog on the beach."
43
+ )
44
+
45
+ submit_button = st.button("Submit")
46
+
47
+ col1, col2 = st.columns(2)
48
+
49
+ if file:
50
+ raw_img = Image.open(file).convert("RGB")
51
+ else:
52
+ raw_img = load_demo_image()
53
+
54
+ col1.header("Image")
55
+ w, h = raw_img.size
56
+ scaling_factor = 720 / w
57
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
58
+ col1.image(resized_image, use_column_width=True)
59
+
60
+ col2.header("GradCam")
61
+
62
+ if submit_button:
63
+ if model_type.startswith("BLIP"):
64
+ blip_type = model_type.split("_")[1]
65
+ model = load_blip_itm_model(device, model_type=blip_type)
66
+
67
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
68
+ qry = text_processor(query)
69
+
70
+ qry_tok = tokenizer(qry, return_tensors="pt").to(device)
71
+
72
+ norm_img = np.float32(resized_image) / 255
73
+
74
+ gradcam, _ = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)
75
+
76
+ avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)
77
+ col2.image(avg_gradcam, use_column_width=True, clamp=True)
78
+
79
+ num_cols = 4.0
80
+ num_tokens = len(qry_tok.input_ids[0]) - 2
81
+
82
+ num_rows = int(math.ceil(num_tokens / num_cols))
83
+
84
+ gradcam_iter = iter(gradcam[0][2:-1])
85
+ token_id_iter = iter(qry_tok.input_ids[0][1:-1])
86
+
87
+ for _ in range(num_rows):
88
+ with st.container():
89
+ for col in st.columns(int(num_cols)):
90
+ token_id = next(token_id_iter, None)
91
+ if not token_id:
92
+ break
93
+ gradcam_img = next(gradcam_iter)
94
+
95
+ word = tokenizer.decode([token_id])
96
+ gradcam_todraw = getAttMap(norm_img, gradcam_img, blur=True)
97
+
98
+ new_title = (
99
+ '<p style="text-align: center; font-size: 25px;">{}</p>'.format(
100
+ word
101
+ )
102
+ )
103
+ col.markdown(new_title, unsafe_allow_html=True)
104
+ # st.image(image, channels="BGR")
105
+ col.image(gradcam_todraw, use_column_width=True, clamp=True)
app/utils.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import numpy as np
9
+ import streamlit as st
10
+ import torch
11
+ from lavis.models import BlipBase, load_model
12
+ from matplotlib import pyplot as plt
13
+ from PIL import Image
14
+ from scipy.ndimage import filters
15
+ from skimage import transform as skimage_transform
16
+
17
+
18
+ def resize_img(raw_img):
19
+ w, h = raw_img.size
20
+ scaling_factor = 240 / w
21
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
22
+ return resized_image
23
+
24
+
25
+ def read_img(filepath):
26
+ raw_image = Image.open(filepath).convert("RGB")
27
+
28
+ return raw_image
29
+
30
+
31
+ @st.cache(
32
+ hash_funcs={
33
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
34
+ .cpu()
35
+ .numpy()
36
+ },
37
+ allow_output_mutation=True,
38
+ )
39
+ def load_model_cache(name, model_type, is_eval, device):
40
+ return load_model(name, model_type, is_eval, device)
41
+
42
+
43
+ @st.cache(allow_output_mutation=True)
44
+ def init_bert_tokenizer():
45
+ tokenizer = BlipBase.init_tokenizer()
46
+ return tokenizer
47
+
48
+
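+ # overlay a normalized (optionally blurred) attention map on the image using the jet colormap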
49
+ def getAttMap(img, attMap, blur=True, overlap=True):
50
+ attMap -= attMap.min()
51
+ if attMap.max() > 0:
52
+ attMap /= attMap.max()
53
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
54
+ if blur:
55
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
56
+ attMap -= attMap.min()
57
+ attMap /= attMap.max()
58
+ cmap = plt.get_cmap("jet")
59
+ attMapV = cmap(attMap)
60
+ attMapV = np.delete(attMapV, 3, 2)
61
+ if overlap:
62
+ attMap = (
63
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
64
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
65
+ )
66
+ return attMap
67
+
68
+
69
+ @st.cache(
70
+ hash_funcs={
71
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
72
+ .cpu()
73
+ .numpy()
74
+ },
75
+ allow_output_mutation=True,
76
+ )
77
+ def load_blip_itm_model(device, model_type="base"):
78
+ model = load_model(
79
+ "blip_image_text_matching", model_type, is_eval=True, device=device
80
+ )
81
+ return model
app/vqa.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import streamlit as st
9
+ from app import load_demo_image, device
10
+ from app.utils import load_model_cache
11
+ from lavis.processors import load_processor
12
+ from PIL import Image
13
+
14
+
15
+ def app():
16
+ model_type = st.sidebar.selectbox("Model:", ["BLIP"])
17
+
18
+ # ===== layout =====
19
+ st.markdown(
20
+ "<h1 style='text-align: center;'>Visual Question Answering</h1>",
21
+ unsafe_allow_html=True,
22
+ )
23
+
24
+ instructions = """Try the provided image or upload your own:"""
25
+ file = st.file_uploader(instructions)
26
+
27
+ col1, col2 = st.columns(2)
28
+
29
+ col1.header("Image")
30
+ if file:
31
+ raw_img = Image.open(file).convert("RGB")
32
+ else:
33
+ raw_img = load_demo_image()
34
+
35
+ w, h = raw_img.size
36
+ scaling_factor = 720 / w
37
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
38
+
39
+ col1.image(resized_image, use_column_width=True)
40
+ col2.header("Question")
41
+
42
+ user_question = col2.text_input("Input your question!", "What objects are there?")
43
+ qa_button = st.button("Submit")
44
+
45
+ col2.header("Answer")
46
+
47
+ # ===== event =====
48
+ vis_processor = load_processor("blip_image_eval").build(image_size=480)
49
+ text_processor = load_processor("blip_question").build()
50
+
51
+ if qa_button:
52
+ if model_type.startswith("BLIP"):
53
+ model = load_model_cache(
54
+ "blip_vqa", model_type="vqav2", is_eval=True, device=device
55
+ )
56
+
57
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
58
+ question = text_processor(user_question)
59
+
60
+ vqa_samples = {"image": img, "text_input": [question]}
61
+ answers = model.predict_answers(vqa_samples, inference_method="generate")
62
+
63
+ col2.write("\n".join(answers), use_column_width=True)
assets/.DS_Store ADDED
Binary file (6.15 kB). View file
 
assets/chain.png ADDED
assets/model.png ADDED
assets/teaser.png ADDED
docs/.DS_Store ADDED
Binary file (8.2 kB). View file
 
docs/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = source
9
+ BUILDDIR = build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/_static/.DS_Store ADDED
Binary file (6.15 kB). View file
 
docs/_static/architecture.png ADDED
docs/_static/logo_final.png ADDED
docs/benchmark.rst ADDED
@@ -0,0 +1,348 @@
 
 
 
 
1
+ Benchmark
2
+ ############
3
+
4
+ We provide scripts for evaluating and training models on task datasets. The following benchmark results are included for reference.
5
+
6
+
7
+ ALBEF
8
+ *******
9
+ .. list-table::
10
+ :widths: 30 80 20
11
+
12
+ * - **Pretraining**
13
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
14
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/pretrain.sh>`__
15
+ * -
16
+ - Visual Genome (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_vg.py>`__)
17
+ -
18
+ * -
19
+ - SBU (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_sbu.py>`__)
20
+ -
21
+ * -
22
+ - CC3M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py>`__)
23
+ -
24
+ * -
25
+ - CC12M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py>`__)
26
+ -
27
+
28
+ .. list-table::
29
+ :widths: 30 40 20 20 20 30 30
30
+ :header-rows: 1
31
+
32
+ * - **Tasks**
33
+ - **Retrieval**
34
+ - **R1**
35
+ - **R5**
36
+ - **R10**
37
+ - **Training**
38
+ - **Evaluation**
39
+ * - TR
40
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
41
+ - 77.6
42
+ - 94.1
43
+ - 97.2
44
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_coco_retrieval_albef.sh>`__
45
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_coco_retrieval.sh>`__
46
+ * - IR
47
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
48
+ - 61.0
49
+ - 84.5
50
+ - 90.7
51
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_coco_retrieval_albef.sh>`__
52
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_coco_retrieval.sh>`__
53
+ * - TR
54
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
55
+ - 77.6
56
+ - 94.1
57
+ - 97.2
58
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_flickr30k_retrieval_albef.sh>`__
59
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_flickr30k_retrieval.sh>`__
60
+ * - IR
61
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
62
+ - 61.0
63
+ - 84.5
64
+ - 90.7
65
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_flickr30k_retrieval_albef.sh>`__
66
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_flickr30k_retrieval.sh>`__
67
+
68
+
69
+ .. list-table::
70
+ :widths: 20 20 20 20 20
71
+ :header-rows: 1
72
+
73
+ * - **VQA**
74
+ - **test-dev**
75
+ - **test-std/test**
76
+ - **Training**
77
+ - **Evaluation**
78
+ * - VQAv2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
79
+ - 76.35
80
+ - 76.54
81
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_vqa_albef.sh>`__
82
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/test_albef_vqa.sh>`__
83
+ * - OKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
84
+ - NA
85
+ - 54.7
86
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_okvqa_albef.sh>`__
87
+ - NA
88
+ * - AOKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
89
+ - 54.5
90
+ - NA
91
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_aokvqa_albef.sh>`__
92
+ - NA
93
+
94
+
95
+ .. list-table::
96
+ :widths: 20 20 20 20 20
97
+ :header-rows: 1
98
+
99
+ * - **Multimodal Classification**
100
+ - **val**
101
+ - **test**
102
+ - **Training**
103
+ - **Evaluation**
104
+ * - SNLI-VE (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
105
+ - 80.60
106
+ - 81.04
107
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_ve_albef.sh>`__
108
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_albef_ve.sh>`__
109
+ * - NLVR2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
110
+ - 82.47
111
+ - 82.91
112
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_nlvr_albef.sh>`__
113
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_albef_nlvr.sh>`__
114
+
115
+ BLIP
116
+ *******
117
+ .. list-table::
118
+ :widths: 30 80 20
119
+
120
+ * - **Pretraining (14M)**
121
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
122
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/pretrain.sh>`__
123
+ * -
124
+ - Visual Genome (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_vg.py>`__)
125
+ -
126
+ * -
127
+ - SBU (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_sbu.py>`__)
128
+ -
129
+ * -
130
+ - CC3M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py>`__)
131
+ -
132
+ * -
133
+ - CC12M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py>`__)
134
+ -
135
+
136
+ .. list-table::
137
+ :widths: 30 40 20 20 20 30 30
138
+ :header-rows: 1
139
+
140
+ * - **Tasks**
141
+ - **Retrieval**
142
+ - **R1**
143
+ - **R5**
144
+ - **R10**
145
+ - **Training**
146
+ - **Evaluation**
147
+ * - TR
148
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
149
+ - 82.0
150
+ - 95.8
151
+ - 98.1
152
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_coco.sh>`__
153
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_coco.sh>`__
154
+ * - IR
155
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
156
+ - 64.5
157
+ - 86.0
158
+ - 91.7
159
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_coco.sh>`__
160
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_coco.sh>`__
161
+ * - TR
162
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
163
+ - 96.9
164
+ - 99.9
165
+ - 100.0
166
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_flickr.sh>`__
167
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_flickr.sh>`__
168
+ * - IR
169
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
170
+ - 87.5
171
+ - 97.6
172
+ - 98.9
173
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_flickr.sh>`__
174
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_flickr.sh>`__
175
+
176
+
177
+ .. list-table::
178
+ :widths: 20 20 20 20 20
179
+ :header-rows: 1
180
+
181
+ * - **VQA**
182
+ - **test-dev**
183
+ - **test-std/test**
184
+ - **Training**
185
+ - **Evaluation**
186
+ * - VQAv2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
187
+ - 78.23
188
+ - 78.29
189
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_vqa_albef.sh>`__
190
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/test_albef_vqa.sh>`__
191
+ * - OKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
192
+ - NA
193
+ - 55.4
194
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_okvqa.sh>`__
195
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_okvqa.sh>`__
196
+ * - AOKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
197
+ - 56.2
198
+ - 50.1
199
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_aokvqa.sh>`__
200
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_aokvqa.sh>`__
201
+
202
+
203
+ .. list-table::
204
+ :widths: 20 20 20 20 20 20
205
+ :header-rows: 1
206
+
207
+ * - **Image Captioning**
208
+ - **BLEU@4**
209
+ - **CIDEr**
210
+ - **SPICE**
211
+ - **Training**
212
+ - **Evaluation**
213
+ * - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
214
+ - 39.9
215
+ - 133.5
216
+ - 23.7
217
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_caption_coco.sh>`__
218
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_coco_cap.sh>`__
219
+ * - NoCaps (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_nocaps.py>`__)
220
+ - 31.9
221
+ - 109.1
222
+ - 14.7
223
+ - NA
224
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_nocaps.sh>`__
225
+
226
+
227
+ .. list-table::
228
+ :widths: 20 20 20 20 20
229
+ :header-rows: 1
230
+
231
+ * - **Multimodal Classification**
232
+ - **val**
233
+ - **test**
234
+ - **Training**
235
+ - **Evaluation**
236
+ * - NLVR2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
237
+ - 82.48
238
+ - 83.25
239
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_nlvr.sh>`__
240
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_nlvr.sh>`__
241
+
242
+ CLIP
243
+ *******
244
+ .. list-table::
245
+ :widths: 30 40 20 20 20 30
246
+ :header-rows: 1
247
+
248
+ * - **Tasks**
249
+ - **Retrieval (Zero-shot)**
250
+ - **R1**
251
+ - **R5**
252
+ - **R10**
253
+ - **Evaluation**
254
+ * - TR
255
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
256
+ - 57.2
257
+ - 80.5
258
+ - 87.8
259
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_coco.sh>`__
260
+ * - IR
261
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
262
+ - 36.5
263
+ - 60.8
264
+ - 71.0
265
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_coco.sh>`__
266
+ * - TR
267
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
268
+ - 86.5
269
+ - 98.0
270
+ - 99.1
271
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_flickr.sh>`__
272
+ * - IR
273
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
274
+ - 67.0
275
+ - 88.9
276
+ - 93.3
277
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_flickr.sh>`__
278
+
279
+ .. list-table::
280
+ :widths: 20 20 20
281
+ :header-rows: 1
282
+
283
+ * - **Multimodal Classification**
284
+ - **val**
285
+ - **Evaluation**
286
+ * - ImageNet
287
+ - 76.5
288
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_zs_imnet.sh>`__
289
+
290
+
291
+ ALPRO
292
+ *******
293
+ .. list-table::
294
+ :widths: 30 40 20 20 20 20 30
295
+ :header-rows: 1
296
+
297
+ * - **Tasks**
298
+ - **Retrieval**
299
+ - **R1**
300
+ - **R5**
301
+ - **R10**
302
+ - **Training**
303
+ - **Evaluation**
304
+ * - TR
305
+ - MSRVTT (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_msrvtt.py>`__)
306
+ - 33.2
307
+ - 60.5
308
+ - 71.7
309
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_ret.sh>`__
310
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_ret.sh>`__
311
+ * - VR
312
+ - MSRVTT (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_msrvtt.py>`__)
313
+ - 33.8
314
+ - 61.4
315
+ - 72.7
316
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_ret.sh>`__
317
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_ret.sh>`__
318
+ * - TR
319
+ - DiDeMo (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_didemo.py>`__)
320
+ - 38.8
321
+ - 66.4
322
+ - 76.8
323
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_didemo_ret.sh>`__
324
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_didemo_ret.sh>`__
325
+ * - VR
326
+ - DiDeMo (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_didemo.py>`__)
327
+ - 36.6
328
+ - 67.5
329
+ - 77.9
330
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_didemo_ret.sh>`__
331
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_didemo_ret.sh>`__
332
+
333
+ .. list-table::
334
+ :widths: 20 20 20 20
335
+ :header-rows: 1
336
+
337
+ * - **Video QA**
338
+ - **test**
339
+ - **Training**
340
+ - **Evaluation**
341
+ * - MSRVTT
342
+ - 42.1
343
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_qa.sh>`__
344
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_qa.sh>`__
345
+ * - MSVD
346
+ - 46.0
347
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msvd_qa.sh>`__
348
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msvd_qa.sh>`__
docs/build_docs.sh ADDED
@@ -0,0 +1,101 @@
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ # Change to root directory of repo
5
+ DIRNAME=$(cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
6
+ cd "${DIRNAME}/.."
7
+
8
+ # # Set up virtual environment
9
+ pip3 install setuptools wheel virtualenv
10
+ if [ ! -d venv ]; then
11
+ rm -f venv
12
+ virtualenv venv
13
+ fi
14
+ source venv/bin/activate
15
+
16
+ # # Get current git branch & stash unsaved changes
17
+ GIT_BRANCH=$(git branch --show-current)
18
+ if [ -z "${GIT_BRANCH}" ]; then
19
+ GIT_BRANCH="main"
20
+ fi
21
+ git stash
22
+
23
+ # Set up exit handler to restore git state & delete temp branches
24
+ # function exit_handler {
25
+ # git reset --hard
26
+ # git checkout "${GIT_BRANCH}" --
27
+ # git stash pop || true
28
+ # for version in $(git tag --list 'v[0-9]*'); do
29
+ # branch="${version}_local_docs_only"
30
+ # if git show-ref --verify --quiet "refs/heads/$branch"; then
31
+ # git branch -D "$branch"
32
+ # fi
33
+ # done
34
+ # }
35
+ # trap exit_handler EXIT
36
+
37
+ # Clean up build directory and install Sphinx requirements
38
+ pip3 install -r "${DIRNAME}/requirements.txt"
39
+ sphinx-build -M clean "${DIRNAME}" "${DIRNAME}/_build"
40
+
41
+ # Build API docs for current head
42
+ export current_version="latest"
43
+ pip3 install "."
44
+ sphinx-build -b html "${DIRNAME}" "${DIRNAME}/_build/html/${current_version}" -W --keep-going
45
+ rm -rf "${DIRNAME}/_build/html/${current_version}/.doctrees"
46
+ #pip3 uninstall -y omnixai
47
+
48
+ # Install all previous released versions
49
+ # and use them to build the appropriate API docs.
50
+ # Uninstall after we're done with each one.
51
+ # versions=()
52
+ # checkout_files=("${DIRNAME}/*.rst" "lavis" "tutorials" "setup.py")
53
+ # for version in $(git tag --list 'v[0-9]*'); do
54
+ # versions+=("$version")
55
+ # git checkout -b "${version}_local_docs_only"
56
+ # for f in $(git diff --name-only --diff-filter=A "tags/${version}" "${DIRNAME}/*.rst"); do
57
+ # git rm "$f"
58
+ # done
59
+ # git checkout "tags/${version}" -- "${checkout_files[@]}"
60
+ # export current_version=${version}
61
+ # pip3 install ".[all]"
62
+ # sphinx-build -b html "${DIRNAME}" "${DIRNAME}/_build/html/${current_version}" -W --keep-going
63
+ # rm -rf "${DIRNAME}/_build/html/${current_version}/.doctrees"
64
+ # #pip3 uninstall -y omnixai
65
+ # git reset --hard
66
+ # git checkout "${GIT_BRANCH}" --
67
+ # done
68
+
69
+ # Determine the latest stable version if there is one
70
+ # if (( ${#versions[@]} > 0 )); then
71
+ # stable_hash=$(git rev-list --tags --max-count=1)
72
+ # stable_version=$(git describe --tags "$stable_hash")
73
+ # export stable_version
74
+ # else
75
+ export stable_version="latest"
76
+ # fi
77
+
78
+ # Create dummy HTML's for the stable version in the base directory
79
+ while read -r filename; do
80
+ filename=$(echo "$filename" | sed "s/\.\///")
81
+ n_sub=$(echo "$filename" | (grep -o "/" || true) | wc -l)
82
+ prefix=""
83
+ for (( i=0; i<n_sub; i++ )); do
84
+ prefix+="../"
85
+ done
86
+ url="${prefix}${stable_version}/$filename"
87
+ mkdir -p "${DIRNAME}/_build/html/$(dirname "$filename")"
88
+ cat > "${DIRNAME}/_build/html/$filename" <<EOF
89
+ <!DOCTYPE html>
90
+ <html>
91
+ <head>
92
+ <title>LAVIS Documentation</title>
93
+ <meta http-equiv = "refresh" content="0; url='$url'" />
94
+ </head>
95
+ <body>
96
+ <p>Please wait while you're redirected to our <a href="$url">documentation</a>.</p>
97
+ </body>
98
+ </html>
99
+ EOF
100
+ done < <(cd "${DIRNAME}/_build/html/$stable_version" && find . -name "*.html")
101
+ echo "Finished writing to _build/html."
docs/conf.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
1
+ # Configuration file for the Sphinx documentation builder.
2
+ #
3
+ # This file only contains a selection of the most common options. For a full
4
+ # list see the documentation:
5
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
6
+
7
+ # -- Path setup --------------------------------------------------------------
8
+
9
+ # If extensions (or modules to document with autodoc) are in another directory,
10
+ # add these directories to sys.path here. If the directory is relative to the
11
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
12
+ #
13
+ # import os
14
+ # import sys
15
+ # sys.path.insert(0, os.path.abspath('.'))
16
+
17
+
18
+ # -- Project information -----------------------------------------------------
19
+
20
+ project = "LAVIS"
21
+ copyright = "2022, salesforce.com inc."
22
+ author = (
23
+ "Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven C.H. Hoi"
24
+ )
25
+
26
+
27
+ # -- General configuration ---------------------------------------------------
28
+
29
+ # Add any Sphinx extension module names here, as strings. They can be
30
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
31
+ # ones.
32
+ extensions = ["nbsphinx"]
33
+
34
+ # Add any paths that contain templates here, relative to this directory.
35
+ templates_path = ["_templates"]
36
+
37
+ # List of patterns, relative to source directory, that match files and
38
+ # directories to ignore when looking for source files.
39
+ # This pattern also affects html_static_path and html_extra_path.
40
+ exclude_patterns = []
41
+
42
+
43
+ # -- Options for HTML output -------------------------------------------------
44
+
45
+ # The theme to use for HTML and HTML Help pages. See the documentation for
46
+ # a list of builtin themes.
47
+ #
48
+ # html_theme = "alabaster"
49
+ html_theme = "sphinx_rtd_theme"
50
+
51
+ # Add any paths that contain custom static files (such as style sheets) here,
52
+ # relative to this directory. They are copied after the builtin static files,
53
+ # so a file named "default.css" will overwrite the builtin "default.css".
54
+ html_static_path = ["_static"]
55
+
56
+ # pygments_style = "sphinx"
docs/getting_started.rst ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
1
+ Dataset Zoo
2
+ ##################
3
+ LAVIS inherently supports a wide variety of common language-vision datasets by providing automatic download scripts to help download and organize these datasets;
4
+ and implements PyTorch datasets for these datasets. To view supported datasets, use the following code:
5
+
6
+ .. code-block:: python
7
+
8
+ from lavis.datasets.builders import dataset_zoo
9
+ dataset_names = dataset_zoo.get_names()
10
+ print(dataset_names)
11
+ # ['aok_vqa', 'coco_caption', 'coco_retrieval', 'coco_vqa', 'conceptual_caption_12m',
12
+ # 'conceptual_caption_3m', 'didemo_retrieval', 'flickr30k', 'imagenet', 'laion2B_multi',
13
+ # 'msrvtt_caption', 'msrvtt_qa', 'msrvtt_retrieval', 'msvd_caption', 'msvd_qa', 'nlvr',
14
+ # 'nocaps', 'ok_vqa', 'sbu_caption', 'snli_ve', 'vatex_caption', 'vg_caption', 'vg_vqa']
15
+ print(len(dataset_names))
16
+ # 23
17
+
18
+
19
+ Auto-Downloading and Loading Datasets
20
+ ######################################
21
+ We now take COCO caption dataset as an example to demonstrate how to download and prepare the dataset.
22
+
23
+ In ``lavis/datasets/download_scripts/``, we provide tools to download most common public language-vision datasets supported by LAVIS.
24
+ The COCO caption dataset uses images from COCO dataset. Therefore, we first download COCO images via:
25
+
26
+ .. code-block:: bash
27
+
28
+ cd lavis/datasets/download_scripts/ && python download_coco.py
29
+
30
+ This will automatically download and extract COCO images to the default LAVIS cache location.
31
+ The default cache location is ``~/.cache/lavis``, defined in ``lavis/configs/default.yaml``.
32
+
33
+ After downloading the images, we can use ``load_dataset()`` to obtain the dataset. On the first run, this will automatically download and cache annotation files.
34
+
35
+ .. code-block:: python
36
+
37
+ from lavis.datasets.builders import load_dataset
38
+ coco_dataset = load_dataset("coco_caption")
39
+
40
+ print(coco_dataset.keys())
41
+ # dict_keys(['train', 'val', 'test'])
42
+
43
+ print(len(coco_dataset["train"]))
44
+ # 566747
45
+
46
+ print(coco_dataset["train"][0])
47
+ # {'image': <PIL.Image.Image image mode=RGB size=640x480>,
48
+ # 'text_input': 'A woman wearing a net on her head cutting a cake. ',
49
+ # 'image_id': 0}
50
+
51
+ If you already host a local copy of the dataset, you can pass in the ``vis_path`` argument to change the default location to load images.
52
+
53
+ .. code-block:: python
54
+
55
+ coco_dataset = load_dataset("coco_caption", vis_path=YOUR_LOCAL_PATH)
56
+
57
+
58
+ Model Zoo
59
+ ####################################
60
+ LAVIS supports a growing list of pre-trained models for different tasks,
61
+ datatsets and of varying sizes. Let's get started by viewing the supported models.
62
+
63
+ .. code-block:: python
64
+
65
+ from lavis.models import model_zoo
66
+ print(model_zoo)
67
+ # ==================================================
68
+ # Architectures Types
69
+ # ==================================================
70
+ # albef_classification base, ve
71
+ # albef_nlvr base
72
+ # albef_pretrain base
73
+ # albef_retrieval base, coco, flickr
74
+ # albef_vqa base, vqav2
75
+ # alpro_qa base, msrvtt, msvd
76
+ # alpro_retrieval base, msrvtt, didemo
77
+ # blip_caption base, base_coco, large, large_coco
78
+ # blip_classification base
79
+ # blip_feature_extractor base
80
+ # blip_nlvr base
81
+ # blip_pretrain base
82
+ # blip_retrieval base, coco, flickr
83
+ # blip_vqa base, vqav2
84
+ # clip ViT-B-32, ViT-B-16, ViT-L-14, ViT-L-14-336, RN50
85
+
86
+ # show total number of support model variants
87
+ len(model_zoo)
88
+ # 33
89
+
90
+
91
+ Inference with Pre-trained Models
92
+ ####################################
93
+
94
+ Now let's see how to use models in LAVIS to perform inference on example data. We first
95
+ load a sample image from local.
96
+
97
+ .. code-block:: python
98
+
99
+ from PIL import Image
100
+
101
+ # setup device to use
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+
104
+ # load sample image
105
+ raw_image = Image.open("docs/_static/merlion.png").convert("RGB")
106
+
107
+ This example image shows `Merlion park <https://en.wikipedia.org/wiki/Merlion>`_ (`image credit <https://theculturetrip.com/asia/singapore/articles/what-exactly-is-singapores-merlion-anyway/>`_), a landmark in Singapore.
108
+
109
+ .. image:: _static/merlion.png
110
+
111
+ Image Captioning
112
+ *******************************
113
+ We now use the BLIP model to generate a caption for the image. To make inference even easier, we also associate each
114
+ pre-trained model with its preprocessors (transforms), we use ``load_model_and_preprocess()`` with the following arguments:
115
+
116
+ - ``name``: The name of the model to load. This could be a pre-trained model, task model, or feature extractor. See ``model_zoo`` for a full list of model names.
117
+ - ``model_type``: Each architecture has variants trained on different datasets and at different scale. See Types column in ``model_zoo`` for a full list of model types.
118
+ - ``is_eval``: if `True`, set the model to evaluation mode. This is desired for inference or feature extraction.
119
+ - ``devce``: device to load the model to.
120
+
121
+ .. code-block:: python
122
+
123
+ from lavis.models import load_model_and_preprocess
124
+ # loads BLIP caption base model, with finetuned checkpoints on MSCOCO captioning dataset.
125
+ # this also loads the associated image processors
126
+ model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=device)
127
+
128
+ # preprocess the image
129
+ # vis_processors stores image transforms for "train" and "eval" (validation / testing / inference)
130
+ image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
131
+
132
+ # generate caption
133
+ model.generate({"image": image})
134
+ # ['a large fountain spewing water into the air']
135
+
136
+
137
+ You may also load models and their preprocessors separately via ``load_model()`` and ``load_processor()``.
138
+ In BLIP, you can also generate diverse captions by turning nucleus sampling on.
139
+
140
+ .. code-block:: python
141
+
142
+ from lavis.processors import load_processor
143
+ from lavis.models import load_model
144
+
145
+ # load image preprocesser used for BLIP
146
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
147
+ model = load_model(name="blip_caption", model_type="base_coco", is_eval=True, device=device)
148
+
149
+ image = vis_processor(image).unsqueeze(0).to(device)
150
+ model.generate({"image": raw_image}, use_nucleus_sampling=True)
151
+ # one generated random sample: ['some very pretty buildings and some water jets']
152
+
153
+
154
+ Visual question answering (VQA)
155
+ *******************************
156
+ BLIP model is able to answer free-form questions about images in natural language.
157
+ To access the VQA model, simply replace the ``name`` and ``model_type`` arguments
158
+ passed to ``load_model_and_preprocess()``.
159
+
160
+ .. code-block:: python
161
+
162
+ from lavis.models import load_model_and_preprocess
163
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_vqa", model_type="vqav2", is_eval=True, device=device)
164
+
165
+ # ask a random question.
166
+ question = "Which city is this photo taken?"
167
+
168
+ image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
169
+ question = txt_processors["eval"](question)
170
+
171
+ model.predict_answers(samples={"image": image, "text_input": question}, inference_method="generate")
172
+ # ['singapore']
173
+
174
+
175
+ Unified Feature Extraction Interface
176
+ ####################################
177
+
178
+ LAVIS provides a unified interface to extract multimodal features from each architecture.
179
+ To extract features, we load the feature extractor variants of each model.
180
+ The multimodal feature can be used for multimodal classification. The low-dimensional unimodal features can be used to compute cross-modal similarity.
181
+
182
+ .. code-block:: python
183
+
184
+ from lavis.models import load_model_and_preprocess
185
+
186
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_feature_extractor", model_type="base", is_eval=True, device=device)
187
+ caption = "a large fountain spewing water into the air"
188
+
189
+ image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
190
+ text_input = txt_processors["eval"](caption)
191
+
192
+ sample = {"image": image, "text_input": [text_input]}
193
+
194
+ features_multimodal = model.extract_features(sample)
195
+ print(features_multimodal.keys())
196
+ # odict_keys(['image_embeds', 'multimodal_embeds'])
197
+ print(features_multimodal.multimodal_embeds.shape)
198
+ # torch.Size([1, 12, 768]), use features_multimodal[:, 0, :] for multimodal classification tasks
199
+
200
+ features_image = model.extract_features(sample, mode="image")
201
+ print(features_image.keys())
202
+ # odict_keys(['image_embeds', 'image_embeds_proj'])
203
+ print(features_image.image_embeds.shape)
204
+ # torch.Size([1, 197, 768])
205
+ print(features_image.image_embeds_proj.shape)
206
+ # torch.Size([1, 197, 256])
207
+
208
+ features_text = model.extract_features(sample, mode="text")
209
+ print(features_text.keys())
210
+ # odict_keys(['text_embeds', 'text_embeds_proj'])
211
+ print(features_text.text_embeds.shape)
212
+ # torch.Size([1, 12, 768])
213
+ print(features_text.text_embeds_proj.shape)
214
+ # torch.Size([1, 12, 256])
215
+
216
+ similarity = features_image.image_embeds_proj[:, 0, :] @ features_text.text_embeds_proj[:, 0, :].t()
217
+ print(similarity)
218
+ # tensor([[0.2622]])
219
+
220
+ Since LAVIS supports a unified feature extraction interface, minimal changes are necessary to use a different model as feature extractor. For example,
221
+ to use ALBEF as the feature extractor, one only needs to change the following line:
222
+
223
+ .. code-block:: python
224
+
225
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="albef_feature_extractor", model_type="base", is_eval=True, device=device)
226
+
227
+ Similarly, to use CLIP as feature extractor:
228
+
229
+ .. code-block:: python
230
+
231
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="base", is_eval=True, device=device)
232
+ # model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="RN50", is_eval=True, device=device)
233
+ # model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="ViT-L-14", is_eval=True, device=device)
docs/index.rst ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
1
+ .. LAVIS documentation master file, created by
2
+ sphinx-quickstart on Sun Jul 31 10:32:27 2022.
3
+ You can adapt this file completely to your liking, but it should at least
4
+ contain the root `toctree` directive.
5
+
6
+ Welcome to LAVIS's documentation!
7
+ =================================
8
+
9
+ .. toctree::
10
+ :maxdepth: 1
11
+ :caption: Introduction
12
+
13
+ intro
14
+
15
+
16
+ .. toctree::
17
+ :maxdepth: 1
18
+ :caption: Getting Started
19
+
20
+ getting_started
21
+
22
+
23
+ .. :maxdepth: 1
24
+ .. :caption: Advanced Training
25
+
26
+ .. advanced_training
27
+
28
+
29
+ .. toctree::
30
+ :maxdepth: 2
31
+ :caption: Advanced Usage
32
+
33
+ benchmark
34
+ tutorial
35
+
36
+
37
+ .. Documentations
38
+ .. ===================
39
+
40
+
41
+ Indices and tables
42
+ ==================
43
+
44
+ * :ref:`genindex`
45
+ * :ref:`modindex`
46
+ * :ref:`search`
docs/intro.rst ADDED
@@ -0,0 +1,99 @@
 
 
 
 
1
+ What is LAVIS?
2
+ ####################################
3
+
4
+ LAVIS is a Python deep learning library for LAnguage-and-VISion research and applications.
5
+ It features a unified design to access state-of-the-art foundation language-vision models (`ALBEF <https://arxiv.org/pdf/2107.07651.pdf>`_,
6
+ `BLIP <https://arxiv.org/pdf/2201.12086.pdf>`_, `ALPRO <https://arxiv.org/pdf/2112.09583.pdf>`_, `CLIP <https://arxiv.org/pdf/2103.00020.pdf>`_), common tasks
7
+ (retrieval, captioning, visual question answering, multimodal classification etc.) and datasets (COCO, Flickr, Nocaps, Conceptual
8
+ Commons, SBU, etc.).
9
+
10
+ This library aims to provide engineers and researchers with a one-stop solution to rapidly develop models for their specific multimodal
11
+ scenarios, and benchmark them across standard and customized datasets.
12
+
13
+ Key features of LAVIS include:
14
+
15
+ - **Modular and Extensible Library Design**: makes it easy to utilize and repurpose existing modules (datasets, models, preprocessors), and to add new modules.
16
+
17
+ - **Easy Off-the-shelf Inference and Feature Extraction**: readily available pre-trained models let you take advantage of state-of-the-art multimodal understanding and generation capabilities on your own data.
18
+
19
+ - **Reproducible Model Zoo**: provides training/pre-training recipes to easily replicate and extend state-of-the-art models.
20
+
21
+ - **Dataset Zoo and Automatic Downloading Tools**: it can be a hassle to prepare the many language-vision datasets. LAVIS provides automatic downloading scripts to help prepare a large variety of datasets and their annotations.
22
+
23
+ Other features include:
24
+
25
+ - **Distributed Training** using multiple GPUs on one machine or across multiple machines.
26
+
27
+ - **Web Demo**: try supported models on your own pictures, questions etc.
28
+
29
+ - **Leaderboard**: comparing state-of-the-art models across standard datasets.
30
+
31
+ - **Dataset Explorer**: help browse and understand language-vision datasets.
32
+
33
+ Supported Tasks, Models and Datasets
34
+ ####################################
35
+
36
+ The following table shows the language-vision tasks and models currently supported by LAVIS. Adapting existing models to more tasks is possible and will come in future releases.
37
+
38
+ ======================================== =========================== ============================================= ============
39
+ Tasks Supported Models Supported Datasets Modalities
40
+ ======================================== =========================== ============================================= ============
41
+ Image-text Pre-training ALBEF, BLIP COCO, VisualGenome, SBU, ConceptualCaptions image, text
42
+ Image-text Retrieval ALBEF, BLIP, CLIP COCO, Flickr30k image, text
43
+ Text-image Retrieval ALBEF, BLIP, CLIP COCO, Flickr30k image, text
44
+ Visual Question Answering ALBEF, BLIP VQAv2, OKVQA, A-OKVQA image, text
45
+ Image Captioning BLIP COCO, NoCaps image, text
46
+ Image Classification CLIP ImageNet image
47
+ Natural Language Visual Reasoning (NLVR) ALBEF, BLIP NLVR2 image, text
48
+ Visual Entailment (VE) ALBEF SNLI-VE image, text
49
+ Visual Dialogue BLIP VisDial image, text
50
+ Video-text Retrieval BLIP, ALPRO MSRVTT, DiDeMo video, text
51
+ Text-video Retrieval BLIP, ALPRO MSRVTT, DiDeMo video, text
52
+ Video Question Answering (VideoQA) BLIP, ALPRO MSRVTT, MSVD video, text
53
+ Video Dialogue VGD-GPT AVSD video, text
54
+ Multimodal Feature Extraction ALBEF, CLIP, BLIP, ALPRO customized image, text
55
+ ======================================== =========================== ============================================= ============
56
+
57
+ Library Design
58
+ ####################################
59
+
60
+ .. image:: _static/architecture.png
61
+ :width: 550
62
+
63
+ LAVIS has six key modules.
64
+
65
+ - ``lavis.runners`` manages the overall training and evaluation lifecycle. It is also responsible for creating required components lazily on demand, such as optimizers, learning rate schedulers and dataloaders. Currently ``RunnerBase`` implements epoch-based training and ``RunnerIter`` implements iteration-based training.
66
+ - ``lavis.tasks`` implements concrete training and evaluation logic per task. A task could be, for example, retrieval, captioning, or pre-training. The rationale for the task abstraction is to accommodate task-specific training and evaluation; for example, evaluating a retrieval model is different from evaluating a classification model.
67
+ - ``lavis.datasets`` is responsible for creating datasets, where ``lavis.datasets.builders`` loads dataset configurations, downloads annotations and returns a dataset object; ``lavis.datasets.datasets`` defines the supported datasets, each is a ``torch.utils.data.Dataset`` instance. We also provide `automatic dataset downloading tools` in ``datasets/download_scripts`` to help prepare common public datasets.
68
+ - ``lavis.models`` holds definition for the supported models and shared model layers.
69
+ - ``lavis.processors`` handles preprocessing of text and images/videos before feeding them to the model. For images and videos, a processor can be thought of as transforms in torchvision; for text input, this may include lowercasing, truncation, etc.
70
+ - ``lavis.common`` module contains shared classes and methods used by multiple other modules. For example,
71
+
72
+ - ``lavis.common.config`` contains classes to store and manipulate configuration files used by LAVIS. In particular, we use a hierarchical configuration design, to allow highly customizable training and evaluation.
73
+ - ``lavis.common.registry`` serves as a centralized place to manage modules that share the same functionalities. It allows building datasets, models, tasks, and learning rate schedulers at runtime by specifying their names as strings in the configuration file (a usage sketch follows this list).
74
+ - ``lavis.common.optims`` contains definitions of learning rate schedulers.
75
+ - ``lavis.common.dist_utils`` contains utilities for distributed training and evaluation.
76
+ - ``lavis.common.utils`` contains miscellaneous utilities, mostly IO-related helper functions.
77
+
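+ As a rough illustration of how the registry ties these pieces together (a minimal sketch; the exact getter names on ``registry`` are an assumption and may differ between LAVIS versions):
+
+ .. code-block:: python
+
+ from lavis.common.registry import registry
+
+ # look up the model class registered under the string name used in configuration files
+ # (similar getters are assumed for tasks, dataset builders and lr schedulers)
+ model_cls = registry.get_model_class("blip_caption")
+ print(model_cls.__name__)
+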
78
+
79
+ Installation
80
+ ############
81
+ 1. (Optional) Creating conda environment
82
+
83
+ .. code-block:: bash
84
+
85
+ conda create -n lavis python=3.8
86
+ conda activate lavis
87
+
88
+ 2. Cloning and building from source
89
+
90
+ .. code-block:: bash
91
+
92
+ git clone https://github.com/salesforce/LAVIS.git
93
+ cd LAVIS
94
+ pip install .
95
+
96
+ If you would like to develop on LAVIS, you may find it easier to build with editable mode::
97
+
98
+ pip install -e .
99
+
docs/make.bat ADDED
@@ -0,0 +1,35 @@
 
 
 
 
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=sphinx-build
9
+ )
10
+ set SOURCEDIR=source
11
+ set BUILDDIR=build
12
+
13
+ if "%1" == "" goto help
14
+
15
+ %SPHINXBUILD% >NUL 2>NUL
16
+ if errorlevel 9009 (
17
+ echo.
18
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19
+ echo.installed, then set the SPHINXBUILD environment variable to point
20
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
21
+ echo.may add the Sphinx directory to PATH.
22
+ echo.
23
+ echo.If you don't have Sphinx installed, grab it from
24
+ echo.http://sphinx-doc.org/
25
+ exit /b 1
26
+ )
27
+
28
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29
+ goto end
30
+
31
+ :help
32
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33
+
34
+ :end
35
+ popd
docs/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
1
+ GitPython
2
+ ipykernel
3
+ nbsphinx==0.8.7
4
+ pandoc
5
+ sphinx
6
+ sphinx_autodoc_typehints
7
+ sphinx_rtd_theme
docs/tutorial.configs.rst ADDED
@@ -0,0 +1,172 @@
 
 
 
 
1
+ .. _config:
2
+
3
+ Training Models on Task Datasets (Commands and Configurations)
4
+ #################################################################
5
+
6
+ LAVIS provides scripts to pre-train and finetune supported models on standard language-vision tasks, stored at ``lavis/run_scripts/``.
7
+ To replicate the experiments, just run these bash scripts. For example, to train BLIP model on the image-text retrieval task with MSCOCO dataset, we can run
8
+
9
+ .. code-block::
10
+
11
+ bash run_scripts/lavis/blip/train/train_retrieval_coco.sh
12
+
13
+ Inside the scripts, we can see
14
+
15
+ .. code-block:: bash
16
+
17
+ python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/retrieval_coco_ft.yaml
18
+
19
+ where we start a pytorch distributed training on 8 GPUs (you may change according to your own hardware setup). The ``--cfg-path`` specifys a `runtime configuration file`, specifying
20
+ the task, model, dataset and training recipes.
21
+
22
+ Available options and their descriptions are as below.
23
+
24
+ .. LAVIS executes training and evaluation based on arguments specified in the configuration files. The default model and dataset configurations are defined in ``lavis/configs``. The task-specific configurations are defined in ``lavis/projects``. Task-specific configurations have higher priority over the default configurations.
25
+
26
+ .. The following tables provide explanations for the arguments in the configuration files.
27
+
28
+ .. list-table::
29
+ :widths: 30 40
30
+ :header-rows: 1
31
+
32
+ * - Model Configurations
33
+ - Functionalities
34
+ * - arch
35
+ - | name of the model from the model zoo
36
+ | default: task-dependent
37
+ * - model_type
38
+ - | the type of the model (e.g., base)
39
+ | default: task-dependent
40
+ * - load_pretrained
41
+ - | load pretrained weights
42
+ | default: True (for finetuning task) | False (for pretraining task)
43
+ * - load_finetuned
44
+ - | load task-specific finetuned weights
45
+ | default: False (for finetuning task) | True (for evaluation)
46
+ * - pretrained
47
+ - | URL or local path which stores the pretrained model, defined in the default model configuration file
48
+ | default: task-dependent
49
+ * - finetuned
50
+ - | URL or local path which stores the finetuned model, defined in the default model configuration file
51
+ | default: task-dependent
52
+
53
+ .. list-table::
54
+ :widths: 30 50
55
+ :header-rows: 1
56
+
57
+ * - Dataset Configurations
58
+ - Functionalities
59
+ * - vis_processor
60
+ - | pre-processing of visual input
61
+ | default: task-dependent
62
+ * - text_processor
63
+ - | pre-processing of text input
64
+ | default: task-dependent
65
+ * - build_info
66
+ - | dataset information including the storage location, defined in the default dataset configuration file
67
+ | default: task-dependent
68
+
69
+ .. list-table::
70
+ :widths: 30 50
71
+ :header-rows: 1
72
+
73
+ * - Runtime Configurations
74
+ - Functionalities
75
+ * - task
76
+ - | name of the task
77
+ | default: task-dependent
78
+ * - lr_sched
79
+ - | learning rate schedular
80
+ | default: linear_warmup_cosine_lr
81
+ * - init_lr
82
+ - | initial learning rate (after warmup)
83
+ | default: task-dependent
84
+ * - min_lr
85
+ - | final learning rate after decay
86
+ | default: task-dependent
87
+ * - warmup_lr
88
+ - | starting learning rate for warmup
89
+ | default: init_lr (no warmup)
90
+ * - lr_decay_rate
91
+ - | learning rate decay per epoch for step_lr_shedule
92
+ | default: 0.9
93
+ * - warmup_steps
94
+ - | number of steps for learning rate warmup
95
+ | default: 0
96
+ * - max_epoch
97
+ - | total number of training epochs
98
+ | default: task-dependent
99
+ * - weight_decay
100
+ - | weight decay coefficient for the optimizer
101
+ | default: 0.05
102
+ * - batch_size_train
103
+ - | batch size during training
104
+ | default: task-dependent
105
+ * - batch_size_eval
106
+ - | batch size during evaluation
107
+ | default: task-dependent
108
+ * - seed
109
+ - | pseudo random number generator seed
110
+ | default: 42
111
+ * - output_dir
112
+ - | directory to store logs, results and checkpoints
113
+ | default: task-dependent
114
+ * - resume_ckpt_path
115
+ - | path of the checkpoint to resume training from
116
+ | default: None
117
+ * - evaluate
118
+ - | only perform evaluation without training
119
+ | default: False
120
+ * - train_splits
121
+ - | dataset splits used for training
122
+ | default: ["train"]
123
+ * - valid_splits
124
+ - | dataset splits used for validation
125
+ | default: ["val"]
126
+ * - test
127
+ - | dataset splits used for test
128
+ | default: ["test"]
129
+ * - device
130
+ - | use cpu or gpu (cuda)
131
+ | default: cuda
132
+ * - world_size
133
+ - | number of processes participating in the job
134
+ | default: 1
135
+ * - dist_url
136
+ - | URL specifying how to initialize the process group
137
+ | default: "env://"
138
+ * - distributed
139
+ - | use distributed training
140
+ | default: True
141
+ * - amp
142
+ - | use automatic mixed precision training
143
+ | default: False
144
+
145
+ .. list-table::
146
+ :widths: 40 50
147
+ :header-rows: 1
148
+
149
+ * - Text Generation Configurations
150
+ - Functionalities
151
+ * - max_len
152
+ - | maximum number of text tokens to generate
153
+ | default: 20 (for image captioning)
154
+ * - min_len
155
+ - | minimum number of text tokens to generate
156
+ | default: 5 (for image captioning)
157
+ * - num_beams
158
+ - | number of beams to perform beam search
159
+ | default: 3
160
+
161
+ .. list-table::
162
+ :widths: 40 50
163
+ :header-rows: 1
164
+
165
+ * - Multimodal Retrieval Configurations
166
+ - Functionalities
167
+ * - negative_all_rank
168
+ - | collect negatives from all processes for the image-text matching loss
169
+ | default: True (for coco)
170
+ * - k_test
171
+ - | number of retrieval candidates ranked from contrastive similarity
172
+ | default: 256 (for coco)
docs/tutorial.datasets.rst ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
1
+ Adding Datasets
2
+ ################################################
3
+
4
+ This is a tutorial on adding a new dataset using ``lavis.datasets`` module.
5
+
6
+ The LAVIS library includes a standard dataset module, which allows customization to add new datasets.
7
+ The ``lavis.datasets`` module is designed such that any new dataset class can be easily added and adapted from our code base, including creating a dataset configuration, and defining and associating new dataset classes.
8
+
9
+ In this tutorial, we will replicate the steps to add a dataset class for the `Audio-Visual Scene-Aware Dialogue (AVSD) <https://arxiv.org/pdf/1901.09107.pdf>`_ benchmark for the video-grounded dialogue task.
10
+
11
+ Dataset Configuration ``lavis.configs.datasets``
12
+ **************************************************************
13
+
14
+ First, we define the basic configurations for this dataset, including a new dataset class ``avsd_dialogue``, dataset card, and data types.
15
+ We can define any new dataset configuration in ``lavis.configs.datasets``. For instance, under this module, we can set up a configuration file ``avsd/defaults_dial.yaml`` as follows:
16
+
17
+ .. code-block:: yaml
18
+
19
+ datasets:
20
+ avsd_dialogue: # name of the dataset builder
21
+ dataset_card: dataset_card/avsd_dialogue.md # path to the dataset card
22
+ data_type: features # [images|videos|features] we use features in this case for extracted video features
23
+
24
+ build_info:
25
+ # Be careful not to append minus sign (-) before split to avoid itemizing
26
+ annotations:
27
+ train:
28
+ url: /export/home/data/avsd/train_set4DSTC7-AVSD.json
29
+ storage: avsd/annotations/train.json
30
+ val:
31
+ url: /export/home/data/avsd/valid_set4DSTC7-AVSD.json
32
+ storage: avsd/annotations/val.json
33
+ test:
34
+ url: /export/home/data/avsd/test_set4DSTC7-AVSD.json
35
+ storage: avsd/annotations/test.json
36
+ features:
37
+ storage: /export/home/data/avsd/features/
38
+
39
+
40
+ Dataset Card
41
+ ===============
42
+ One optional step to set up dataset configuration is defining a dataset card, which contains more details about the dataset such as description, tasks, and metrics.
43
+ For instance, we can define a dataset card for the AVSD benchmark in ``dataset_card/avsd_dialogue.md``.
44
+ Depending on the dataset, its dataset card may include a command for auto-downloading the data (with Python code defined in ``lavis.datasets.download_scripts``) that automatically fetches the data and stores it in a specific folder.
45
+ Otherwise, the dataset card should describe how to download the data from the original source so that the dataset can be loaded properly.
46
+
47
+ One example of a dataset card for the AVSD benchmark is:
48
+
49
+ .. code-block:: md
50
+
51
+ ![Samples from the AVSD dataset (image credit: https://arxiv.org/pdf/1901.09107.pdf)](imgs/avsd_dialogue.png)
52
+
53
+ # Audio-Visual Scene-Aware Dialogues (AVSD)
54
+
55
+ ## Description
56
+ [Audio-Visual Scene-Aware Dialogues (AVSD)](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) contains more than 10,000 dialogues, each of which is grounded on a unique video. In the test split, for each test sample, 6 reference dialogue responses are provided.
57
+
58
+
59
+ ## Task
60
+
61
+ (https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge)
62
+
63
+ In a **video-grounded dialogue task**, the system must generate responses to user input in the context of a given dialog.
64
+ This context consists of a dialog history (previous utterances by both user and system) in addition to the video and audio information that comprises the scene. The quality of a system’s automatically generated sentences is evaluated using objective measures to determine whether or not the generated responses are natural and informative.
65
+
66
+ ## Metrics
67
+ Models are typically evaluated according to [BLEU](https://aclanthology.org/P02-1040/), [CIDEr](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.pdf), [METEOR](https://aclanthology.org/W05-0909/), and [ROUGE-L](https://aclanthology.org/W04-1013/) metrics.
68
+
69
+ ## Leaderboard
70
+
71
+ ....
72
+
73
+
74
+ ## Auto-Downloading
75
+
76
+ Please refer to the [benchmark website](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) for instructions to download the dataset.
77
+
78
+
79
+ ## References
80
+ "Audio Visual Scene-Aware Dialog", Huda Alamri, Vincent Cartillier, Abhishek Das, Jue Wang, Anoop Cherian, Irfan Essa, Dhruv Batra, Tim K. Marks, Chiori Hori, Peter Anderson, Stefan Lee, Devi Parikh
81
+
82
+ Visual Data Type
83
+ ==============================
84
+ We currently limit the visual data types to one of three options: ``images``, ``videos``, and ``features``.
85
+ "Images" and "videos" refer to the raw visual data, which is appropriate for models processing visual data in their original forms (e.g. ViT models).
86
+ "Features" are visual representations extracted from pretrained models (e.g. CNN models).
87
+ In this tutorial, the AVSD benchmark consists of video features extracted from 3D-CNN models.
88
+
89
+ Build Info
90
+ ==============================
91
+ Build info refers to the specific locations where data is stored and cached.
92
+
93
+ For text annotations (e.g. captioning or dialogues), by default, we include three data splits, namely "train", "val", and "test", typically used in all machine learning projects.
94
+ For each split, we specify 2 parameters: ``url`` and ``storage``.
95
+ ``url`` can be either an online URL where the dataset can be loaded automatically (e.g. from *googleapis*), or a local directory where data is already downloaded beforehand.
96
+ ``storage`` is the directory where the data will be cached over time, avoiding downloading data repeatedly.
97
+
98
+ For visual data annotations, ensure the field name matches the data type defined earlier (i.e. one of "images", "videos", or "features").
99
+ As visual features are usually large and should be downloaded beforehand, we maintain only a ``storage`` parameter where visual data is cached.
100
+
101
+ Dataset ``lavis.datasets.datasets``
102
+ **************************************************************
103
+
104
+ Base Dataset ``lavis.datasets.datasets.base_dataset``
105
+ =======================================================
106
+ In this step, we want to define new dataset classes that inherit our base dataset class ``lavis.datasets.datasets.base_dataset``. This base dataset class already defines standard methods such as ``collater``, which uses the default collate function from PyTorch.
107
+
108
+ .. code-block:: python
109
+
110
+ import json
111
+ from typing import Iterable
112
+
113
+ from torch.utils.data import Dataset, ConcatDataset
114
+ from torch.utils.data.dataloader import default_collate
115
+
116
+ class BaseDataset(Dataset):
117
+ def __init__(
118
+ self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
119
+ ):
120
+ """
121
+ vis_root (string): Root directory of images (e.g. coco/images/)
122
+ ann_root (string): directory to store the annotation file
123
+ """
124
+ self.vis_root = vis_root
125
+
126
+ self.annotation = []
127
+ for ann_path in ann_paths:
128
+ self.annotation.extend(json.load(open(ann_path, "r")))
129
+
130
+ self.vis_processor = vis_processor
131
+ self.text_processor = text_processor
132
+
133
+ self._add_instance_ids()
134
+
135
+ def __len__(self):
136
+ return len(self.annotation)
137
+
138
+ def collater(self, samples):
139
+ return default_collate(samples)
140
+
141
+ def set_processors(self, vis_processor, text_processor):
142
+ self.vis_processor = vis_processor
143
+ self.text_processor = text_processor
144
+
145
+ def _add_instance_ids(self, key="instance_id"):
146
+ for idx, ann in enumerate(self.annotation):
147
+ ann[key] = str(idx)
148
+
149
+ Any dataset subclass inherits these methods, and you may optionally override them according to the specifications of your dataset.
150
+ We encourage users not to modify the base dataset class, as any modification will have cascading effects on all dataset classes that inherit from it.
151
+ Instead, users should create new dataset classes to cater to their specific requirements.
152
+
153
+ Dialogue Datasets ``lavis.datasets.datasets.dialogue_datasets``
154
+ ======================================================================
155
+
156
+ For example, for the AVSD dataset, we want to define a new dataset subclass ``DialogueDataset`` for dialogue tasks. We can define this dataset class in ``lavis.datasets.datasets.dialogue_datasets`` as follows:
157
+
158
+ .. code-block:: python
159
+
160
+ import os
161
+ from collections import OrderedDict
162
+
163
+ from lavis.datasets.datasets.base_dataset import BaseDataset
164
+
165
+ import json
166
+ import copy
167
+
168
+ class DialogueDataset(BaseDataset):
169
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
170
+ """
171
+ vis_processor (string): visual processor
172
+ text_processor (string): textual processor
173
+ vis_root (string): Root directory of images (e.g. coco/images/)
174
+ ann_paths (list): paths to the annotation files
175
+ """
176
+
177
+ self.vis_root = vis_root
178
+
179
+ self.annotation = []
180
+ for ann_path in ann_paths:
181
+ dialogs = json.load(open(ann_path, "r"))['dialogs']
182
+ for dialog in dialogs:
183
+ all_turns = dialog['dialog']
184
+ dialogue_context = []
185
+ for turn in all_turns:
186
+ dialog_instance = copy.deepcopy(dialog)
187
+ question = turn['question']
188
+ answer = turn['answer']
189
+
190
+ dialog_instance['dialog'] = copy.deepcopy(dialogue_context)
191
+ dialog_instance['question'] = question
192
+ dialog_instance['answer'] = answer
193
+ self.annotation.append(dialog_instance)
194
+ dialogue_context.append(turn)
195
+
196
+ self.vis_processor = vis_processor
197
+ self.text_processor = text_processor
198
+
199
+ self._add_instance_ids()
200
+
201
+ self.img_ids = {}
202
+ n = 0
203
+ for ann in self.annotation:
204
+ img_id = ann["image_id"]
205
+ if img_id not in self.img_ids.keys():
206
+ self.img_ids[img_id] = n
207
+ n += 1
208
+
209
+ Class inheritance allows us to define multiple subclasses. For instance, we may want another dialogue dataset class that is used only for the test split. We can define a ``DialogueEvalDataset`` class similar to the one above, except that the annotations are processed differently.
210
+ Typically, in dialogue tasks, only a single test sample is constructed per dialogue at test time (rather than decomposing every dialogue turn into a sample, as is done at training time).
211
+ The dataset class can then be defined as:
212
+
213
+ .. code-block:: python
214
+
215
+ class DialogueEvalDataset(BaseDataset):
216
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
217
+ # ...
218
+ # defined similarly as DialogueDataset above
219
+ # except for the loading of dialogue annotation data
220
+
221
+ self.annotation = []
222
+ for ann_path in ann_paths:
223
+ dialogs = json.load(open(ann_path, "r"))['dialogs']
224
+ for dialog in dialogs:
225
+ all_turns = dialog['dialog']
226
+ dialogue_context = all_turns[:-1]
227
+ last_turn = all_turns[-1]
228
+
229
+ question = last_turn['question']
230
+ answer = last_turn['answer']
231
+
232
+ dialog['dialog'] = dialogue_context
233
+ dialog['question'] = question
234
+ dialog['answer'] = answer
235
+
236
+ self.annotation.append(dialog)
237
+
238
+
239
+ Using class inheritance to define datasets also allows us to develop more fine-grained class implementations, each of which is designated for a specific benchmark.
240
+ For instance, under dialogue-based tasks, we can further define another dataset subclass specific to the AVSD dataset.
241
+ We can define a new class ``AVSDDialDataset`` that further specifies how to load individual samples and collate them according to specific requirements:
242
+
243
+ .. code-block:: python
244
+
245
+ import os
246
+ from lavis.datasets.datasets.base_dataset import BaseDataset
247
+ from lavis.datasets.datasets.dialogue_datasets import DialogueDataset, DialogueEvalDataset
248
+
249
+ import torch
250
+
251
+ class AVSDDialDataset(DialogueDataset):
252
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
253
+
254
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
255
+
256
+ def __getitem__(self, index):
257
+
258
+ ann = self.annotation[index]
259
+
260
+ vname = ann["image_id"]
261
+
262
+ video = self.vis_processor(self.vis_root, vname)
263
+
264
+ dialogue = self.text_processor(ann)
265
+
266
+ return {
267
+ "video_fts": video['video_fts'],
268
+ "video_token_type_ids": video['token_type_ids'],
269
+ "input_ids": dialogue['input_ids'],
270
+ "token_type_ids": dialogue['token_type_ids'],
271
+ "labels": dialogue['labels'],
272
+ "image_id": ann["image_id"],
273
+ "instance_id": ann["instance_id"]
274
+ }
275
+
276
+ def collater(self, samples):
277
+
278
+ input_ids, token_type_ids, labels, video_fts, video_token_type_ids = [], [], [], [], []
279
+
280
+ for i in samples:
281
+ input_ids.append(i['input_ids'])
282
+ token_type_ids.append(i['token_type_ids'])
283
+ labels.append(i['labels'])
284
+ video_fts.append(i['video_fts'])
285
+ video_token_type_ids.append(i['video_token_type_ids'])
286
+
287
+ input_ids = self.text_processor.padding(input_ids)
288
+
289
+ labels = self.text_processor.padding(labels, -1)
290
+ video_fts = self.vis_processor.padding(video_fts)
291
+
292
+ token_type_ids = self.text_processor.padding(token_type_ids)
293
+ video_token_type_ids = self.text_processor.padding(video_token_type_ids)
294
+ token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1)
295
+
296
+ attn_mask = self.text_processor.get_attention_mask(input_ids)
297
+ video_mask = self.vis_processor.get_attention_mask(video_fts)
298
+ attn_mask = torch.cat([video_mask, attn_mask], dim=1)
299
+
300
+ video_labels = torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1 # label index -1 is ignored by default
301
+
302
+ labels = torch.cat([video_labels, labels], dim=1)
303
+
304
+ samples = {}
305
+ samples['input_ids'] = input_ids
306
+ samples['token_type_ids'] = token_type_ids
307
+ samples['labels'] = labels
308
+ samples['video_fts'] = video_fts
309
+ samples['attn_mask'] = attn_mask
310
+
311
+ return samples
312
+
313
+ Note that in a dataset subclass, if methods such as ``__getitem__`` and ``collater`` are not defined, the same functions from the corresponding superclass will be used.
314
+ For instance, by default, we always use the collater from the ``BaseDataset`` class to collate data samples.
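+
+ As a minimal, hypothetical sketch (the class and annotation field names below are illustrative, not part of LAVIS), a subclass that only overrides ``__getitem__`` automatically reuses ``BaseDataset.collater``:
+
+ .. code-block:: python
+
+     class MyImageTextDataset(BaseDataset):
+         # No collater is defined here, so BaseDataset.collater
+         # (PyTorch's default_collate) is used when batching.
+         def __getitem__(self, index):
+             ann = self.annotation[index]
+             image = self.vis_processor(ann["image"])      # hypothetical field
+             text = self.text_processor(ann["caption"])    # hypothetical field
+             return {"image": image, "text_input": text, "instance_id": ann["instance_id"]}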
315
+
316
+ Dataset Builder ``lavis.datasets.builders``
317
+ **************************************************************
318
+ The dataset builder is the data processing module that manages the dataset classes (by training or evaluation split) and associates the specific dataset configurations with these dataset classes.
319
+
320
+ Base Dataset Builder ``lavis.datasets.builders.base_dataset_builder``
321
+ ======================================================================
322
+
323
+ Note that any new builder class definition should inherit the base dataset builder class ``lavis.datasets.builders.base_dataset_builder``:
324
+
325
+ .. code-block:: python
326
+
327
+ class BaseDatasetBuilder:
328
+ train_dataset_cls, eval_dataset_cls = None, None
329
+ ...
330
+
331
+ This allows us to standardize the operations of dataset builders across all builder classes. We advise users to carefully review the standard methods defined in the base builder class, including methods such as ``_download_data`` and ``build_datasets``, which download the data and create instances of the dataset classes:
332
+
333
+ .. code-block:: python
334
+
335
+ class BaseDatasetBuilder:
336
+ ...
337
+
338
+ def build_datasets(self):
339
+ # download, split, etc...
340
+ # only called on 1 GPU/TPU in distributed
341
+
342
+ if is_main_process():
343
+ self._download_data()
344
+
345
+ if is_dist_avail_and_initialized():
346
+ dist.barrier()
347
+
348
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
349
+ logging.info("Building datasets...")
350
+ datasets = self.build() # dataset['train'/'val'/'test']
351
+
352
+ return datasets
353
+
354
+ def _download_data(self):
355
+ self._download_ann()
356
+ self._download_vis()
357
+
358
+ We encourage users not to modify the implementation of the base dataset builder class as this will affect all existing dataset builder subclasses.
359
+
360
+ Dialogue Dataset Builder ``lavis.datasets.builders.dialogue_builder``
361
+ ======================================================================
362
+ We can define any new builder subclass and associate this builder with the corresponding dataset classes and dataset configurations.
363
+ For instance, for the AVSD dataset, we can define a builder ``lavis.datasets.builders.dialogue_builder`` for dialogue-based datasets as follows:
364
+
365
+ .. code-block:: python
366
+
367
+ from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
368
+ from lavis.datasets.datasets.avsd_dialogue_datasets import (
369
+ AVSDDialDataset,
370
+ AVSDDialEvalDataset
371
+ )
372
+
373
+ from lavis.common.registry import registry
374
+
375
+
376
+ @registry.register_builder("avsd_dialogue")
377
+ class AVSDDialBuilder(BaseDatasetBuilder):
378
+ train_dataset_cls = AVSDDialDataset
379
+ eval_dataset_cls = AVSDDialEvalDataset
380
+
381
+ DATASET_CONFIG_DICT = {
382
+ "default": "configs/datasets/avsd/defaults_dial.yaml"
383
+ }
384
+
385
+ Note that we chose to define ``train_dataset_cls`` and ``eval_dataset_cls`` separately to handle cases where data is processed differently at training and test time.
386
+ For instance, in captioning tasks, each test sample often includes multiple ground-truth captions rather than the single ground truth used at training time.
387
+ If the data is processed the same way at training and test time, the two parameters can point to the same dataset class, as in the sketch below.
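+
+ As a hedged sketch (``MyCaptionDataset``, the builder name, and the config path are hypothetical; imports as in the AVSD example above), such a builder simply points both attributes to the same class:
+
+ .. code-block:: python
+
+     @registry.register_builder("my_caption")
+     class MyCaptionBuilder(BaseDatasetBuilder):
+         # identical processing at training and test time
+         train_dataset_cls = MyCaptionDataset
+         eval_dataset_cls = MyCaptionDataset
+
+         DATASET_CONFIG_DICT = {"default": "configs/datasets/my_caption/defaults.yaml"}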
388
+
389
+ Finally, define ``DATASET_CONFIG_DICT`` to associate the dataset configurations with the assigned dataset classes.
390
+
391
+ Registering Builder ``lavis.datasets.builders.__init__``
392
+ ======================================================================
393
+
394
+ To add a new builder class, first include the class in ``__init__.py``. For instance, to expose the new builder for the AVSD dataset:
395
+
396
+ .. code-block:: python
397
+
398
+ from lavis.datasets.builders.dialogue_builder import (
399
+ AVSDDialBuilder
400
+ )
401
+
402
+ __all__ = [
403
+ ...,
404
+ "AVSDDialBuilder"
405
+ ]
406
+
407
+ Assigning Builder
408
+ ======================================================================
409
+ Note that during data loading and processing, the assigned builder must be referenced by its registered name so that it can be loaded properly.
410
+ For instance, the following should be specified in a configuration file, e.g. ``dialogue_avsd_ft.yaml``:
411
+
412
+ .. code-block:: yaml
413
+
414
+ datasets:
415
+ avsd_dialogue: # name of the dataset builder
416
+ ...
417
+ # processor configuration
418
+ ...
419
+
420
+ Subsequently, any process (e.g. training) should load this configuration file to assign the correct builder, which will then associate the correct dataset classes to construct data samples.
421
+
422
+ .. code-block:: sh
423
+
424
+ python train.py --cfg-path dialogue_avsd_ft.yaml
docs/tutorial.evaluation.rst ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Evaluating Pre-trained Models on Task Datasets
2
+ ###############################################
3
+ LAVIS provides pre-trained and finetuned models for off-the-shelf evaluation on task datasets.
4
+ Let's now walk through an example of evaluating the BLIP model on the captioning task, using the MSCOCO dataset.
5
+
6
+ .. _prep coco:
7
+
8
+ Preparing Datasets
9
+ ******************
10
+ First, let's download the dataset. LAVIS provides `automatic downloading scripts` to help prepare
11
+ most of the public datasets. To download the MSCOCO dataset, simply run
12
+
13
+ .. code-block:: bash
14
+
15
+ cd lavis/datasets/download_scripts && python download_coco.py
16
+
17
+ This will put the downloaded dataset at a default cache location ``cache`` used by LAVIS.
18
+
19
+ If you want to use a different cache location, you can specify it by updating ``cache_root`` in ``lavis/configs/default.yaml``.
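+
+ For reference, LAVIS resolves ``cache_root`` relative to the repository root when the package is imported (see ``lavis/__init__.py``). A quick, non-authoritative way to check the value currently in use is to read the default config directly:
+
+ .. code-block:: python
+
+     from omegaconf import OmegaConf
+
+     # path assumes you run this from the repository root
+     default_cfg = OmegaConf.load("lavis/configs/default.yaml")
+     print(default_cfg.env.cache_root)  # e.g. "cache"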
20
+
21
+ If you have a local copy of the dataset, it is recommended to create a symlink from the cache location to the local copy, e.g.
22
+
23
+ .. code-block:: bash
24
+
25
+ ln -s /path/to/local/coco cache/coco
26
+
27
+ Evaluating pre-trained models
28
+ ******************************
29
+
30
+ To evaluate a pre-trained model, simply run
31
+
32
+ .. code-block:: bash
33
+
34
+ bash run_scripts/lavis/blip/eval/eval_coco_cap.sh
35
+
36
+ Or to evaluate a large model:
37
+
38
+ .. code-block:: bash
39
+
40
+ bash run_scripts/lavis/blip/eval/eval_coco_cap_large.sh
docs/tutorial.models.rst ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Adding Models
2
+ ####################################
3
+
4
+ This is a tutorial on adding new models using ``lavis.models`` module.
5
+
6
+ The LAVIS library includes a standard model module that builds the foundation for many major language-vision models such as `ALBEF <https://arxiv.org/pdf/2107.07651.pdf>`_,
7
+ `BLIP <https://arxiv.org/pdf/2201.12086.pdf>`_, `ALPRO <https://arxiv.org/pdf/2112.09583.pdf>`_, and `CLIP <https://arxiv.org/pdf/2103.00020.pdf>`_.
8
+ The ``lavis.models`` module is designed such that any new models can be added and integrated into the LAVIS library, with minimal steps to develop training and testing procedures.
9
+ In this tutorial, we will replicate the steps to add a GPT-style model specifically for `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
10
+
11
+ Base Model ``lavis.models.base_model``
12
+ **************************************************************
13
+
14
+ Note that any new model definition should inherit the base model class ``BaseModel``:
15
+
16
+ .. code-block:: python
17
+
18
+ from omegaconf import OmegaConf
19
+
20
+ import numpy as np
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from lavis.common.utils import get_abs_path
26
+
27
+ class BaseModel(nn.Module):
28
+ """Base class for models."""
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+
33
+ def forward_features(self, *args, **kwargs):
34
+ """Similar to *forward* but only return features."""
35
+ raise NotImplementedError
36
+
37
+ def load_from_pretrained(self, url_or_filename):
38
+ raise NotImplementedError
39
+
40
+ @classmethod
41
+ def _from_config(cls, cfg=None, model_type="base"):
42
+ if not cfg:
43
+ # useful when building model without a provided configuration file
44
+ cfg = OmegaConf.load(cls.default_config_path(model_type)).model
45
+
46
+ return cls.from_config(cfg)
47
+
48
+ @classmethod
49
+ def from_pretrained(cls, model_type="base"):
50
+ """
51
+ Build a pretrained model from the default configuration file, specified by model_type.
52
+ """
53
+ return cls._from_config(cfg=None, model_type=model_type)
54
+
55
+ @property
56
+ def device(self):
57
+ return list(self.parameters())[0].device
58
+
59
+ @classmethod
60
+ def default_config_path(cls, model_type="base"):
61
+ assert (
62
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
63
+ ), "Unknown model type {}".format(model_type)
64
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
65
+
66
+ def before_evaluation(self, **kwargs):
67
+ pass
68
+
69
+ def show_n_params(self, return_str=True):
70
+ tot = 0
71
+ for p in self.parameters():
72
+ w = 1
73
+ for x in p.shape:
74
+ w *= x
75
+ tot += w
76
+ if return_str:
77
+ if tot >= 1e6:
78
+ return "{:.1f}M".format(tot / 1e6)
79
+ else:
80
+ return "{:.1f}K".format(tot / 1e3)
81
+ else:
82
+ return tot
83
+
84
+
85
+ In this base model, we already declare and standardize many common methods such as ``_from_config`` and ``from_pretrained``.
86
+ Inheriting this base model class allows us to standardize operations of models across all model classes while still allowing customizations.
87
+ We advise users not to change the implementation of the base model class as this will affect all existing model subclasses.
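+
+ As a quick illustration (``SomeRegisteredModel`` is a placeholder for any subclass that defines ``PRETRAINED_MODEL_CONFIG_DICT``), these inherited methods can be used as follows:
+
+ .. code-block:: python
+
+     # build a model from its default config for the given model type
+     model = SomeRegisteredModel.from_pretrained(model_type="base")
+
+     print(model.show_n_params())  # human-readable parameter count
+     print(model.device)           # device of the first parameter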
88
+
89
+ GPT-style Video-grounded Dialogue Model ``lavis.models.gpt_models.gpt_dialogue``
90
+ ********************************************************************************
91
+
92
+ In this step, we can define a new model class, e.g. under ``lavis.models.gpt_models.gpt_dialogue``, for GPT-based dialogue models designed specifically for video-grounded dialogues.
93
+ Note that we assume the model class inherits from the standard model super class ``GPT2LMHeadModel`` from the ``transformers`` `library <https://huggingface.co/docs/transformers/index>`_.
94
+ We also enforce integration with the LAVIS framework by inheriting ``BaseModel`` from the LAVIS library as the secondary superclass.
95
+
96
+ .. code-block:: python
97
+
98
+ import torch
99
+ from lavis.common.registry import registry
100
+ from lavis.models.base_model import BaseModel
101
+
102
+ from transformers import GPT2Model, GPT2LMHeadModel
103
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
104
+ import math
105
+ import torch
106
+ import torch.nn as nn
107
+ from torch.nn import CrossEntropyLoss, MSELoss
108
+
109
+ @registry.register_model("gpt_dialogue")
110
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
111
+ ...
112
+
113
+ Next, we can modify the architecture of the model during model initialization to fit the tasks of interest, i.e. video-grounded dialogues.
114
+ In this case, we want to add additional model parameters for a linear network to transform the video feature representations to the model dimension.
115
+
116
+ .. code-block:: python
117
+
118
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
119
+
120
+ def __init__(self, config, len_video_ft=4224):
121
+
122
+ super().__init__(config)
123
+
124
+ self.video_ff = nn.Linear(len_video_ft, config.n_embd)
125
+
126
+ # Model parallel
127
+ self.model_parallel = False
128
+ self.device_map = None
129
+
130
+ # Initialize weights and apply final processing
131
+ self.post_init()
132
+
133
+ Note that for each new model class, we advise redefining the ``from_config`` method which is inherited from the ``BaseModel`` class.
134
+ As each model usually has its own unique configurations, redefining the method will ensure the model instances are created properly.
135
+ For instance, ``GPTDialogue`` requires an additional parameter of video feature length (``len_video_ft``) which should be part of the model initialization procedure.
136
+ Another additional parameter is the number of tokens/words (as we include additional special tokens in the vocabulary for dialogue tasks).
137
+
138
+ .. code-block:: python
139
+
140
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
141
+ ...
142
+ @classmethod
143
+ def from_config(cls, cfg):
144
+ model = cls.from_pretrained('gpt2', len_video_ft=cfg['len_video_ft'])
145
+ model.resize_token_embeddings(cfg['len_tokenizer'])
146
+ return model
147
+
148
+ Other basic methods should also be defined explicitly in the new model class, including the ``forward`` function.
149
+ For instance, in GPT models for video-grounded dialogue tasks, we want the forward pass to also include the transformation and integration of video features before passing the representations to the Transformer layers.
150
+
151
+ .. code-block:: python
152
+
153
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
154
+ ...
155
+
156
+ def forward(self, samples,
157
+ past_key_values=None,
158
+ position_ids=None,
159
+ head_mask=None,
160
+ encoder_hidden_states=None,
161
+ encoder_attention_mask=None,
162
+ use_cache=None,
163
+ output_attentions=None,
164
+ output_hidden_states=None,
165
+ return_dict=None):
166
+
167
+ input_embs = self.transformer.wte(samples['input_ids'])
168
+ video_embs = self.video_ff(samples['video_fts'])
169
+ input_embs = torch.cat([video_embs, input_embs], dim=1)
170
+
171
+ transformer_outputs = self.transformer(
172
+ attention_mask=samples['attn_mask'],
173
+ token_type_ids=samples['token_type_ids'],
174
+ inputs_embeds=input_embs,
175
+ position_ids=position_ids,
176
+ head_mask=head_mask,
177
+ encoder_hidden_states=encoder_hidden_states,
178
+ encoder_attention_mask=encoder_attention_mask,
179
+ use_cache=use_cache,
180
+ output_attentions=output_attentions,
181
+ output_hidden_states=output_hidden_states,
182
+ return_dict=return_dict,
183
+ )
184
+ hidden_states = transformer_outputs[0]
185
+
186
+ lm_logits = self.lm_head(hidden_states)
187
+ ...
188
+
189
+ Registering New Model ``lavis.models.__init__``
190
+ ********************************************************************************
191
+
192
+ Any new model must be officially registered as part of the ``lavis.models`` module.
193
+ For instance, to add a model class for GPT-based dialogue models, we can modify the ``__init__.py`` as follows:
194
+
195
+ .. code-block:: python
196
+
197
+ from lavis.models.gpt_models.gpt_dialogue import GPTDialogue
198
+
199
+ __all__ = [
200
+ ...
201
+ "GPTDialogue"
202
+ ]
203
+
204
+ Assigning Model
205
+ ********************************************************************************
206
+
207
+ From the above example of a model class, note that we define a ``from_config`` method for the new model class.
208
+ This method will process a configuration file and pass specific parameters to initialize the model classes properly.
209
+ To do this, we assign/associate the correct registry name of the model class in a configuration file.
210
+ For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
211
+
212
+ .. code-block:: yaml
213
+
214
+ model:
215
+ arch: gpt_dialogue # name of the model
216
+ model_type: base
217
+
218
+
219
+ Subsequently, any processes (e.g. training) should load this configuration file to assign the correct model.
220
+
221
+ .. code-block:: sh
222
+
223
+ python train.py --cfg-path dialogue_avsd_ft.yaml
224
+
225
+ Note that to simplify the model configuration, we only enable two main parameters here: ``arch`` and ``model_type``. ``arch`` refers to the registered name of the model class, and ``model_type`` is the corresponding model type within this model family.
226
+ For instance, with ``gpt_dialogue``, we have a model ``base`` which has its own configuration in a separate configuration file e.g. ``gpt_dialogue_base.yaml``:
227
+
228
+ .. code-block:: yaml
229
+
230
+ model:
231
+ arch: gpt_dialogue
232
+ len_tokenizer: 50264 # 50257 tokens from gpt2 default tokenizer + additional special tokens
233
+ len_video_ft: 4224 # i3d_rgb: 2048 i3d_flow: 2048 vggish: 128
234
+
235
+ We can load this configuration and pass the parameters to the above ``from_config`` method to initialize the model accordingly (a short sketch is given at the end of this section).
236
+ We advise users to maintain, in the model class definition, a dictionary that contains default paths to model configurations.
237
+ By default, the LAVIS framework will search for configurations from each model class defined as ``model.PRETRAINED_MODEL_CONFIG_DICT``.
238
+
239
+ .. code-block:: python
240
+
241
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
242
+ PRETRAINED_MODEL_CONFIG_DICT = {
243
+ "base": "configs/models/gpt_dialogue_base.yaml"
244
+ }
245
+ ...
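+
+ Putting this together, a minimal sketch of building the model from its default configuration (the config path comes from the dictionary above, and loading it manually like this is for illustration only) could look like:
+
+ .. code-block:: python
+
+     from omegaconf import OmegaConf
+     from lavis.models.gpt_models.gpt_dialogue import GPTDialogue
+
+     cfg = OmegaConf.load("lavis/configs/models/gpt_dialogue_base.yaml")
+     model = GPTDialogue.from_config(cfg.model)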
docs/tutorial.processors.rst ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Adding Processors
2
+ ################################################
3
+
4
+ This is a tutorial on adding new processors using ``lavis.processors`` module.
5
+
6
+ The LAVIS library includes a standard processor module that preprocesses data, e.g. image transformations and sequence concatenation.
7
+ The ``lavis.processors`` module is designed such that new processors can be added and tailored to the requirements of the corresponding models of interest.
8
+ In this tutorial, we will replicate the steps to add visual and textual processors specifically for `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
9
+ In addition, we want the processors to include the processing steps that make data samples compatible with GPT-style models.
10
+
11
+ Base Processor ``lavis.processors.base_processors``
12
+ *****************************************************
13
+
14
+ Note that any new processor definition should inherit the base processor class ``BaseProcessor``:
15
+
16
+ .. code-block:: python
17
+
18
+ from omegaconf import OmegaConf
19
+
20
+ class BaseProcessor:
21
+ def __init__(self):
22
+ self.transform = lambda x: x
23
+ return
24
+
25
+ def __call__(self, item):
26
+ return self.transform(item)
27
+
28
+ @classmethod
29
+ def from_config(cls, cfg=None):
30
+ return cls()
31
+
32
+ def build(self, **kwargs):
33
+ cfg = OmegaConf.create(kwargs)
34
+
35
+ return self.from_config(cfg)
36
+
37
+ This allows us to standardize the operations of processors across all processor classes while still allowing customization for specific data and model types.
38
+ We encourage users not to modify the implementation of the base processor class as this will have an impact on all existing processor subclasses.
39
+
40
+ GPT-style Processors ``lavis.processors.gpt_processors``
41
+ **************************************************************
42
+ In this step, we can define new processor classes, e.g. under ``lavis.processors.gpt_processors``, for GPT models designed specifically for video-grounded dialogues.
43
+ First, we want to process video features by defining a ``GPTVideoFeatureProcessor`` class.
44
+ In this tutorial, we assume video features are extracted beforehand and this processor simply loads the features from ``npy`` files.
45
+ Other methods that are specifically defined are ``padding`` (which is used by dataset instances to pad multiple video samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models).
46
+
47
+ .. code-block:: python
48
+
49
+ SPECIAL_TOKENS_DICT = {'bos_token': "<bos>", 'eos_token': "<eos>", 'additional_special_tokens': ["<speaker1>", "<speaker2>", "<video>", "<cap>"], 'pad_token': "<pad>"}
50
+ ...
51
+
52
+ @registry.register_processor("gpt_video_ft")
53
+ class GPTVideoFeatureProcessor(BaseProcessor):
54
+ def __init__(self, visual_ft, audio_ft):
55
+
56
+ self.visual_ft = visual_ft
57
+ self.audio_ft = audio_ft
58
+
59
+ self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
60
+ self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
61
+
62
+ def padding(self, seq):
63
+ padded_seq = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=1.0)
64
+ return padded_seq
65
+
66
+ def get_attention_mask(self, seq):
67
+ return torch.sum(seq != 1, dim=2) != 0
68
+
69
+ def __call__(self, ft_root, vname):
70
+ all_ft = []
71
+
72
+ for ft_name in self.visual_ft:
73
+ ft_path = os.path.join(ft_root, ft_name, vname)
74
+ all_ft.append(np.load(ft_path + '.npy'))
75
+
76
+ for ft_name in self.audio_ft:
77
+ ft_path = os.path.join(ft_root, ft_name, vname)
78
+ all_ft.append(np.load(ft_path + '.npy'))
79
+
80
+ min_len = min([len(ft) for ft in all_ft])
81
+
82
+ sampled_ft = [ft[:min_len] for ft in all_ft]
83
+ sampled_ft = np.concatenate(sampled_ft, axis=1)
84
+ item = {}
85
+ item['video_fts'] = torch.Tensor(sampled_ft)
86
+
87
+ video_type_token = self.tokenizer.convert_tokens_to_ids('<video>')
88
+ item['token_type_ids'] = torch.Tensor([video_type_token] * len(sampled_ft)).long()
89
+
90
+ return item
91
+
92
+ @classmethod
93
+ def from_config(cls, cfg=None):
94
+ if cfg is None:
95
+ cfg = OmegaConf.create()
96
+
97
+ visual_ft = cfg.get("visual_ft", ["i3d_rgb"])
98
+ audio_ft = cfg.get("audio_ft", ["vggish"])
99
+
100
+ return cls(
101
+ visual_ft=visual_ft,
102
+ audio_ft=audio_ft
103
+ )
104
+
105
+ Another useful processor class is one that processes dialogue data; here we define a ``GPTDialogueProcessor`` class.
106
+ This processor class receives raw annotations and constructs inputs as a concatenation of input sequences (questions, dialogue contexts, and responses) to facilitate application in GPT models.
107
+ Other methods that are specifically defined are ``padding`` (which is used by dataset instances to pad multiple sequence samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models).
108
+
109
+ .. code-block:: python
110
+
111
+ SPECIAL_TOKENS_DICT = {'bos_token': "<bos>", 'eos_token': "<eos>", 'additional_special_tokens': ["<speaker1>", "<speaker2>", "<video>", "<cap>"], 'pad_token': "<pad>"}
112
+ ...
113
+
114
+ @registry.register_processor("gpt_dialogue")
115
+ class GPTDialogueProcessor(BaseProcessor):
116
+ def __init__(self, max_turns=3, use_caption=True):
117
+ self.max_turns = max_turns
118
+ self.use_caption = use_caption
119
+ self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
120
+ self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
121
+
122
+ def sample_sequence(self, caption, history, answer):
123
+ bos, eos, speaker1, speaker2, cap = self.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-2])
124
+ instance = {}
125
+ sequence = [caption] + history + [answer]
126
+ sequence = [s + [eos] for s in sequence]
127
+
128
+ instance["input_ids"] = list(chain(*sequence))
129
+ instance["token_type_ids"] = [cap] * len(sequence[0]) + [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence[1:]) for _ in s]
130
+ instance["labels"] = ([-1]*sum(len(s) for s in sequence[:-1])) + sequence[-1]
131
+
132
+ assert len(instance["input_ids"])==len(instance["token_type_ids"])
133
+ assert len(instance["token_type_ids"])==len(instance["labels"])
134
+
135
+ for k,v in instance.items():
136
+ instance[k] = torch.Tensor(v).long()
137
+
138
+ return instance
139
+
140
+ def padding(self, seq, pad_token=-1):
141
+ if pad_token==-1: pad_token = self.tokenizer.pad_token_id
142
+ padded_seq = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=pad_token)
143
+ return padded_seq
144
+
145
+ def get_attention_mask(self, seq, pad_token=-1):
146
+ if pad_token==-1: pad_token = self.tokenizer.pad_token_id
147
+ return seq != pad_token
148
+
149
+ def __call__(self, ann):
150
+ if self.use_caption:
151
+ caption = ' '.join([ann['caption'], ann['summary']])
152
+ caption = self.tokenizer.encode(caption)
153
+ else:
154
+ caption = []
155
+
156
+ dial_history = []
157
+ for turn in ann['dialog'][-self.max_turns:]:
158
+ dial_history.append(turn['question'])
159
+ dial_history.append(turn['answer'])
160
+ dial_history.append(ann['question'])
161
+ dial_history = [self.tokenizer.encode(t) for t in dial_history]
162
+
163
+ answer = self.tokenizer.encode(ann['answer'])
164
+
165
+ item = self.sample_sequence(caption, dial_history, answer)
166
+
167
+ return item
168
+
169
+ @classmethod
170
+ def from_config(cls, cfg=None):
171
+ if cfg is None:
172
+ cfg = OmegaConf.create()
173
+
174
+ use_caption = cfg.get("use_caption", True)
175
+ max_turns = cfg.get("max_turns", 3)
176
+
177
+ return cls(max_turns=max_turns, use_caption=use_caption)
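+
+ A brief usage sketch (the feature directory, video id, and annotation dict below are hypothetical placeholders) showing how the two processors are typically called by a dataset instance:
+
+ .. code-block:: python
+
+     vis_processor = GPTVideoFeatureProcessor.from_config()   # defaults: i3d_rgb + vggish
+     text_processor = GPTDialogueProcessor.from_config()      # defaults: max_turns=3, use_caption=True
+
+     video_item = vis_processor("avsd/features", "some_video_id")  # loads <ft_name>/<video_id>.npy files
+     dialogue_item = text_processor(annotation)                    # one AVSD annotation dict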
178
+
179
+ Registering New Processors ``lavis.processors.__init__``
180
+ **************************************************************
181
+
182
+ Finally, any new processor must be officially registered as part of the ``lavis.processors`` module.
183
+ For instance, to add processor classes for GPT-based dialogue models, including one for dialogue data ``GPTDialogueProcessor`` and one for video features ``GPTVideoFeatureProcessor``, we can modify the ``__init__.py`` as follows:
184
+
185
+ .. code-block:: python
186
+
187
+ from lavis.processors.gpt_processors import (
188
+ GPTVideoFeatureProcessor,
189
+ GPTDialogueProcessor,
190
+ )
191
+
192
+ __all__ = [
193
+ ...
194
+ # GPT
195
+ "GPTVideoFeatureProcessor",
196
+ "GPTDialogueProcessor"
197
+ ]
198
+
199
+ Assigning Processors
200
+ **************************************************************
201
+ From the above example of processor classes, note that we define a ``from_config`` method for each class.
202
+ This method will process a configuration file and pass specific parameters, e.g. ``max_turns`` and ``visual_ft``, to initialize the processor classes properly.
203
+ To do this, we assign/associate the correct registry names of the processor classes in a configuration file.
204
+ For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
205
+
206
+ .. code-block:: yaml
207
+
208
+ datasets:
209
+ avsd_dialogue: # name of the dataset builder
210
+ vis_processor:
211
+ train:
212
+ name: "gpt_video_ft" # name of the visual processor for training data
213
+ visual_ft: ["i3d_flow", "i3d_rgb"]
214
+ audio_ft: ["vggish"]
215
+ eval:
216
+ name: "gpt_video_ft" # name of the visual processor for evaluation data
217
+ visual_ft: ["i3d_flow", "i3d_rgb"]
218
+ audio_ft: ["vggish"]
219
+ text_processor:
220
+ train:
221
+ name: "gpt_dialogue" # name of the textual processor for training data
222
+ max_turns: 3
223
+ use_caption: True
224
+ eval:
225
+ name: "gpt_dialogue" # name of the textual processor for evaluation data
226
+ max_turns: 3
227
+ use_caption: True
228
+
229
+ Subsequently, any processes (e.g. training) should load this configuration file to assign the correct processors.
230
+
231
+ .. code-block:: sh
232
+
233
+ python train.py --cfg-path dialogue_avsd_ft.yaml
docs/tutorial.rst ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tutorials
2
+ ==============================
3
+
4
+ .. toctree::
5
+ :maxdepth: 1
6
+
7
+ tutorial.evaluation
8
+ tutorial.training-example
9
+ tutorial.configs
10
+ tutorial.datasets
11
+ tutorial.processors
12
+ tutorial.models
13
+ tutorial.tasks
docs/tutorial.tasks.rst ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Adding Tasks
2
+ ####################################
3
+
4
+ This is a tutorial on adding new machine learning tasks using ``lavis.tasks`` module.
5
+
6
+ The LAVIS library includes a standard task module that centralizes the model training and evaluation procedure of machine learning tasks.
7
+ The ``lavis.tasks`` module is designed such that any new tasks can be added and integrated, catering to any customization in the training and testing procedures.
8
+ In this tutorial, we will replicate the steps to add a new task into LAVIS for the `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
9
+
10
+ Base Task ``lavis.tasks.base_task``
11
+ ********************************************************************************
12
+
13
+ Note that any new task definition should inherit the base task class ``BaseTask``:
14
+
15
+ .. code-block:: python
16
+
17
+ import logging
18
+ import os
19
+
20
+ import torch.distributed as dist
21
+ from lavis.common.dist_utils import get_rank, get_world_size, is_main_process
22
+ from lavis.common.logger import MetricLogger, SmoothedValue
23
+ from lavis.common.registry import registry
24
+ from lavis.datasets.data_utils import prepare_sample
25
+
26
+ class BaseTask:
27
+ def __init__(self, **kwargs):
28
+ super().__init__()
29
+
30
+ self.inst_id_key = "instance_id"
31
+
32
+ @classmethod
33
+ def setup_task(cls, **kwargs):
34
+ return cls()
35
+
36
+ def build_model(self, cfg):
37
+ model_config = cfg.model_cfg
38
+
39
+ model_cls = registry.get_model_class(model_config.arch)
40
+ return model_cls.from_config(model_config)
41
+
42
+ def build_datasets(self, cfg):
43
+ """
44
+ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
45
+ Download dataset and annotations automatically if not exist.
46
+
47
+ Args:
48
+ cfg (common.config.Config): configuration object used to build the datasets.
49
+
50
+ Returns:
51
+ dict: Dictionary of torch.utils.data.Dataset objects by split.
52
+ """
53
+
54
+ datasets = dict()
55
+
56
+ datasets_config = cfg.datasets_cfg
57
+
58
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
59
+
60
+ for name in datasets_config:
61
+ dataset_config = datasets_config[name]
62
+
63
+ builder = registry.get_builder_class(name)(dataset_config)
64
+ dataset = builder.build_datasets()
65
+
66
+ datasets[name] = dataset
67
+
68
+ return datasets
69
+
70
+ def train_step(self, model, samples):
71
+ loss = model(samples)["loss"]
72
+ return loss
73
+
74
+ ...
75
+
76
+ In this base task, we already declare and standardize many common methods such as ``train_step``, ``build_model``, and ``build_datasets``.
77
+ Inheriting this base task class allows us to standardize operations of tasks across all task classes.
78
+ We recommend that users not change the implementation of the base task class, as this will have an impact on all existing task subclasses.
79
+
80
+ Dialogue Task ``lavis.tasks.dialogue``
81
+ ********************************************************************************
82
+
83
+ In this step, we can define a new task class, e.g. under ``lavis.tasks.dialogue``, for video-grounded dialogues.
84
+ For instance, we define a new task class ``DialogueTask`` that inherits the super task class ``BaseTask``.
85
+
86
+ .. code-block:: python
87
+
88
+ import json
89
+ import os
90
+
91
+ from lavis.common.dist_utils import main_process
92
+ from lavis.common.logger import MetricLogger
93
+ from lavis.common.registry import registry
94
+ from lavis.tasks.base_task import BaseTask
95
+ from lavis.datasets.data_utils import prepare_sample
96
+
97
+ import numpy as np
98
+
99
+ @registry.register_task("dialogue")
100
+ class DialogueTask(BaseTask):
101
+ def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
102
+ super().__init__()
103
+
104
+ self.num_beams = num_beams
105
+ self.max_len = max_len
106
+ self.min_len = min_len
107
+ self.evaluate = evaluate
108
+
109
+ self.report_metric = report_metric
110
+
111
+ @classmethod
112
+ def setup_task(cls, cfg):
113
+ run_cfg = cfg.run_cfg
114
+
115
+ num_beams = run_cfg.num_beams
116
+ max_len = run_cfg.max_len
117
+ min_len = run_cfg.min_len
118
+ evaluate = run_cfg.evaluate
119
+
120
+ report_metric = run_cfg.get("report_metric", True)
121
+
122
+ return cls(
123
+ num_beams=num_beams,
124
+ max_len=max_len,
125
+ min_len=min_len,
126
+ evaluate=evaluate,
127
+ report_metric=report_metric,
128
+ )
129
+
130
+ def valid_step(self, model, samples):
131
+ results = []
132
+ loss = model(samples)["loss"].item()
133
+
134
+ return [loss]
135
+ ...
136
+
137
+ Note that for any new task, we advise users to carefully review the functions implemented within ``BaseTask`` and consider which methods should be modified.
138
+ For instance, the base task class already contains a standard implementation of the model training steps that are common across machine learning tasks.
139
+ Two methods we want to emphasize, which should be customized by each task, are ``valid_step`` and ``evaluation`` (see the sketch after this paragraph).
140
+ These operations were not fully implemented in the base task class due to the differences in evaluation procedures among many machine learning tasks.
141
+ Another method that should be considered is the ``setup_task`` method.
142
+ This method will receive configurations that set task-specific parameters to initialize any task instance.
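+
+ As a rough sketch (the aggregation below is an illustrative assumption, not the exact LAVIS implementation), a task-specific evaluation loop built on top of ``valid_step`` could look like:
+
+ .. code-block:: python
+
+     class DialogueTask(BaseTask):
+         ...
+
+         def evaluation(self, model, data_loader, cuda_enabled=True):
+             results = []
+             # iterate over the evaluation split and collect per-sample outputs
+             for samples in data_loader:
+                 samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+                 results.extend(self.valid_step(model=model, samples=samples))
+             return results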
143
+
144
+ Registering New Task ``lavis.tasks.__init__``
145
+ ********************************************************************************
146
+
147
+ Any new task must be officially registered as part of the ``lavis.tasks`` module. For instance, to add a new task for video-grounded dialogues, we can modify the ``__init__.py`` as follows:
148
+
149
+ .. code-block:: python
150
+
151
+ from lavis.tasks.dialogue import DialogueTask
152
+
153
+ ...
154
+ __all__ = [
155
+ ...
156
+ "DialogueTask"
157
+ ]
158
+
159
+ Assigning Task
160
+ ***************
161
+
162
+ From the above example of task class, note that we define a ``setup_task`` method for each task class.
163
+ This method will process a configuration file and pass specific parameters, e.g. ``num_beams`` (for beam search in generative tasks during inference), to initialize the task classes properly.
164
+ To assign and associate any task, we need to specify the correct registry of task classes in a configuration file.
165
+ For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
166
+
167
+ .. code-block:: yaml
168
+
169
+ run:
170
+ task: dialogue # name of the task
171
+
172
+ # optimizer
173
+ ...
174
+
175
+ max_len: 20
176
+ min_len: 5
177
+ num_beams: 3
178
+ ...
179
+
180
+ Subsequently, any processes (e.g. training) should load this configuration file to assign the correct task.
181
+
182
+ .. code-block:: sh
183
+
184
+ python train.py --cfg-path dialogue_avsd_ft.yaml
docs/tutorial.training-example.rst ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Example on Finetuning BLIP on COCO-Captioning
2
+ ################################################
3
+
4
+ To finetune the BLIP model on the COCO captioning dataset, first refer to :ref:`prep coco` to prepare the dataset if you have not done so.
5
+
6
+ To finetune the model, we have prepared a run script for you, which can be run as follows:
7
+
8
+ .. code-block:: bash
9
+
10
+ bash run_scripts/lavis/blip/train/train_caption_coco_large.sh
11
+
12
+ This will finetune the pre-trained BLIP large model into a new model that can be used for captioning.
13
+
14
+ Deep Dive
15
+ **********
16
+ Now let's take a closer look at the script and see what it does.
17
+
18
+ .. code-block:: bash
19
+
20
+ python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/caption_coco_large_ft.yaml
21
+
22
+ As can be seen, the script simply calls the :code:`train.py` with PyTorch distributed training enabled.
23
+ The :code:`--cfg-path` argument specifies the **runtime config** file to use. The config file is a YAML file that specifies the training parameters, shown as follows:
24
+
25
+ .. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
26
+ :language: yaml
27
+ :linenos:
28
+
29
+ The runtime config file is divided into 3 sections:
30
+ - :code:`model`: specifies the model architecture and type to use.
31
+ - :code:`data`: specifies the dataset to use.
32
+ - :code:`run`: specifies the runner arguments, such as tasks, optimizer, learning rate scheduler, etc.
33
+
34
+ We describe each section in detail below.
35
+
36
+ Model configurations
37
+ =====================
38
+
39
+ .. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
40
+ :language: yaml
41
+ :linenos:
42
+ :lines: 6-10
43
+
44
+ The :code:`arch` argument specifies the model architecture to use. In this case, we use the :code:`blip_caption` architecture.
45
+ You can find available architectures by inspecting the :code:`model_zoo`.
46
+ Once the architecture is specified, the runner will look for the model class registered with the name and try to instantiate a model instance.
47
+ In this case :code:`BlipCaption` is the model registered with the name :code:`blip_caption`.
48
+
49
+ The registry maintains a mapping from the name string to the model class.
50
+ This allows the runner to find the model class dynamically based on the name string from the config file.
51
+ The following segment in :code:`lavis/models/blip_models/blip_caption.py` shows how :code:`BlipCaption` is registered with the name string :code:`blip_caption`:
52
+
53
+ .. literalinclude:: ../lavis/models/blip_models/blip_caption.py
54
+ :language: python
55
+ :linenos:
56
+ :lines: 20-38
57
+
58
+ One same model architecture may be pre-trained or finetuned on different datasets or have different model configurations.
59
+ For example, :code:`BlipCaption` have:
60
+
61
+ - :code:`base_coco`: pre-trained base BLIP model adapted for COCO captioning finetuning.
62
+
63
+ - :code:`large_coco`: pre-trained large BLIP model adapted for COCO captioning finetuning.
64
+
65
+ Therefore, we also need to specify :code:`model_type`. Here we use :code:`large_coco`.
66
+ And we set :code:`load_finetuned` to :code:`False` to indicate that we are finetuning the model from the pre-trained weights.
67
+ If :code:`load_finetuned` set to :code:`True` as by default, the model will load finetuned weights on coco captioning.
68
+
69
+ Given the model architecture and type, the library will then look for the default model config for :code:`large_coco` in :code:`lavis/models/blip_models/blip_caption.py`.
70
+ As can be seen in the above code snippet, the corresponding config path is stored in :code:`BlipCaption.PRETRAINED_MODEL_CONFIG_DICT`.
71
+ Then the library will load :code:`lavis/configs/models/blip_caption_large_coco.yaml` as the configuration to build the model.
72
+
73
+ *Priority of Configs*: Note that the priority of the run config is higher than the default model config, meaning that arguments in the run config will override the default model config.
74
+ For example, in the default model config, :code:`load_finetuned` is set to :code:`True` by default, while in the run config, we set it to :code:`False` and finetuning from the pre-trained weights only.
75
+
76
+
77
+ Dataset configurations
78
+ =========================
79
+
80
+ The second section of the config file specifies the dataset(s) to use.
81
+
82
+ .. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
83
+ :language: yaml
84
+ :linenos:
85
+ :lines: 12-24
86
+
87
+ We associate each dataset with a :code:`vis_processor` and a :code:`text_processor`, responsible for processing the visual and textual input respectively.
88
+ Here we again use the registry mechanism to dynamically load the processor class based on the name string.
89
+ For example, :code:`blip_image_train` is the name string for the :code:`BlipImageTrainProcessor` class, which is registered in :code:`lavis/processors/blip_processors.py`.
90
+
91
+ Similarly, the dataset name string is also registered in the registry, pointing to a dataset builder :code:`COCOCapBuilder` class.
92
+ By default, the builder will load the default dataset configuration as in :code:`DATASET_CONFIG_DICT`. You may also add new dataset types by adding new entries to the dictionary.
93
+
94
+ The dataset configuration used here is:
95
+
96
+ .. literalinclude:: ../lavis/configs/datasets/coco/defaults_cap.yaml
97
+ :language: yaml
98
+ :linenos:
99
+ :lines: 6-28
100
+
101
+ In this configuration file, we specify the dataset name and mainly its building information.
102
+ The build information is divided into two parts: :code:`annotation` and :code:`images`. The annotation files will be automatically downloaded upon loading the dataset for the first time.
103
+ The :code:`images` part specifies the image root directory. This is a relative path to the cache directory, which is :code:`cache` by default. If you have a local copy of the dataset, you can specify the path to the local copy by
104
+ overwriting the :code:`images` part in the runtime config file. For example, you may alter the run config as below to use your local dataset copy:
105
+
106
+ .. code:: yaml
107
+
108
+ datasets:
109
+ coco_caption: # name of the dataset builder
110
+ vis_processor:
111
+ train:
112
+ name: "blip_image_train"
113
+ eval:
114
+ name: "blip_image_eval"
115
+ text_processor:
116
+ train:
117
+ name: "blip_caption"
118
+ prompt: "a picture of "
119
+ eval:
120
+ name: "blip_caption"
121
+ images:
122
+ YOUR_LOCAL_IMAGE_ROOT_DIR
123
+
124
+ LAVIS supports using multiple datasets for training. See an example in :code:`lavis/projects/blip/train/pretrain_14m.yaml`.
125
+
126
+
127
+ Runner configurations
128
+ =========================
129
+ The last section of the config file specifies the arguments for the runner, shown below:
130
+
131
+ .. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
132
+ :language: yaml
133
+ :linenos:
134
+ :lines: 26-56
135
+
136
+ Here we specify runner-related arguments, including
137
+ - task-specific arguments, such as :code:`task`, :code:`max_len`, :code:`min_len`, etc.
138
+ - learning rate schedulers, optimizer;
139
+ - distributed training settings;
140
+ - logging and checkpointing settings.
141
+
142
+ Available Configurations
143
+ #########################
144
+
145
+ See :ref:`config` for the full list of available configurations and their descriptions.
evaluate.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import argparse
9
+ import random
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.backends.cudnn as cudnn
14
+
15
+ import lavis.tasks as tasks
16
+ from lavis.common.config import Config
17
+ from lavis.common.dist_utils import get_rank, init_distributed_mode
18
+ from lavis.common.logger import setup_logger
19
+ from lavis.common.optims import (
20
+ LinearWarmupCosineLRScheduler,
21
+ LinearWarmupStepLRScheduler,
22
+ )
23
+ from lavis.common.utils import now
24
+
25
+ # imports modules for registration
26
+ from lavis.datasets.builders import *
27
+ from lavis.models import *
28
+ from lavis.processors import *
29
+ from lavis.runners.runner_base import RunnerBase
30
+ from lavis.tasks import *
31
+
32
+
33
+ def parse_args():
34
+ parser = argparse.ArgumentParser(description="Training")
35
+
36
+ parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
37
+ parser.add_argument(
38
+ "--options",
39
+ nargs="+",
40
+ help="override some settings in the used config, the key-value pair "
41
+ "in xxx=yyy format will be merged into config file (deprecate), "
42
+ "change to --cfg-options instead.",
43
+ )
44
+
45
+ args = parser.parse_args()
46
+ # if 'LOCAL_RANK' not in os.environ:
47
+ # os.environ['LOCAL_RANK'] = str(args.local_rank)
48
+
49
+ return args
50
+
51
+
52
+ def setup_seeds(config):
53
+ seed = config.run_cfg.seed + get_rank()
54
+
55
+ random.seed(seed)
56
+ np.random.seed(seed)
57
+ torch.manual_seed(seed)
58
+
59
+ cudnn.benchmark = False
60
+ cudnn.deterministic = True
61
+
62
+
63
+ def main():
64
+ # allow auto-dl to complete on the main process without timeout when using the NCCL backend.
65
+ # os.environ["NCCL_BLOCKING_WAIT"] = "1"
66
+
67
+ # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
68
+ job_id = now()
69
+
70
+ cfg = Config(parse_args())
71
+
72
+ init_distributed_mode(cfg.run_cfg)
73
+
74
+ setup_seeds(cfg)
75
+
76
+ # set after init_distributed_mode() to only log on master.
77
+ setup_logger()
78
+
79
+ cfg.pretty_print()
80
+
81
+ task = tasks.setup_task(cfg)
82
+ datasets = task.build_datasets(cfg)
83
+ model = task.build_model(cfg)
84
+
85
+ runner = RunnerBase(
86
+ cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
87
+ )
88
+ runner.evaluate(skip_reload=True)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ main()
lavis/.DS_Store ADDED
Binary file (10.2 kB). View file
 
lavis/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from lavis.common.registry import registry
14
+
15
+ from lavis.datasets.builders import *
16
+ from lavis.models import *
17
+ from lavis.processors import *
18
+ from lavis.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
lavis/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (988 Bytes). View file
 
lavis/common/.DS_Store ADDED
Binary file (6.15 kB). View file
 
lavis/common/__pycache__/config.cpython-38.pyc ADDED
Binary file (12.1 kB). View file
 
lavis/common/__pycache__/dist_utils.cpython-38.pyc ADDED
Binary file (3.76 kB). View file