diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..33e33c4f9da6fcd2fbab53911e80e142acaf1d73 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes index d39c3a12b82626027eeb8167291b436617976543..0a90b7bba124d23d38ba55c8e1a58621e90ee599 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text demo4.mp4 filter=lfs diff=lfs merge=lfs -text +videos/*.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ba97919e5b9568c8b9c42ea85251f01049a220e --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,14 @@ +BSD 3-Clause License + +Copyright (c) 2022 Salesforce, Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..aead17fe5bb1086ea8b3054750ff5721c3ba6601 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,7 @@ +recursive-include lavis/configs *.yaml *.json +recursive-include lavis/projects *.yaml *.json + +recursive-exclude lavis/datasets/download_scripts * +recursive-exclude lavis/output * + +include requirements.txt diff --git a/README.md b/README.md index e9aa6ce729e467028080d237bf015eb9d49c88c9..d708f4d957ee173263e60f10ad23bda1a3d5978c 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,112 @@ ---- -title: SeViLA -emoji: 📉 -colorFrom: pink -colorTo: yellow -sdk: gradio -sdk_version: 3.29.0 -app_file: app.py -pinned: false -license: openrail ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Self-Chained Image-Language Model for Video Localization and Question Answering + +* Authors: [Shoubin Yu](https://yui010206.github.io/), [Jaemin Cho](https://j-min.io), [Prateek Yadav](https://prateek-yadav.github.io/), [Mohit Bansal](https://www.cs.unc.edu/~mbansal/) +* [arXiv](https://arxiv.org/abs/2305.06988) +teaser image + +teaser image + +teaser image + + +# Code structure +```bash + +# Data & Data Preprocessing +./sevila_data + +# Pretrained Checkpoints +./sevila_checkpoints + +# SeViLA code +./lavis/ + +# running scripts for SeViLA localizer/answerer training/inference +./run_scripts + +``` + +# Setup + +## Install Dependencies + +1. (Optional) Create a conda environment + +```bash +conda create -n sevila python=3.8 +conda activate sevila +``` + +2. Build from source + +```bash +pip install -e . +``` + +## Download Pretrained Models +We pre-train the SeViLA localizer on QVHighlights and host the checkpoint on [Hugging Face](https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth). +Download the checkpoint and put it under ./sevila_checkpoints. +The checkpoint (814.55M) contains the pre-trained localizer and the zero-shot answerer. + + + +# Dataset Preparation +We test our model on: ++ [NExT-QA](https://doc-doc.github.io/docs/nextqa.html) + ++ [STAR](https://star.csail.mit.edu/) + ++ [How2QA](https://value-benchmark.github.io/index.html) + ++ [TVQA](https://tvqa.cs.unc.edu/) + ++ [VLEP](https://value-benchmark.github.io/index.html) + ++ [QVHighlights](https://github.com/jayleicn/moment_detr) + +Please download the original data and preprocess it via our [scripts](sevila_data/) under ./sevila_data/. + + +# Training and Inference +We provide SeViLA training and inference script examples as follows: +## 1) Localizer Pre-training +```bash +sh run_scripts/sevila/pre-train/pretrain_qvh.sh +``` + +## 2) Localizer Self-refinement + +```bash +sh run_scripts/sevila/refinement/nextqa_sr.sh +``` + +## 3) Answerer Fine-tuning + +```bash +sh run_scripts/sevila/finetune/nextqa_ft.sh +``` + +## 4) Inference + +```bash +sh run_scripts/sevila/inference/nextqa_infer.sh +``` + + +# Acknowledgments +We thank the developers of [LAVIS](https://github.com/salesforce/LAVIS), [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2), [CLIP](https://github.com/openai/CLIP), and [All-in-one](https://github.com/showlab/all-in-one) for their public code release. 
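For reference, a minimal sketch of the checkpoint download step described in the Setup section above. It assumes `wget` is available and uses the Hugging Face URL from this README; the `sevila_checkpoints` directory name follows the Code structure section, so adjust the path if your layout differs.

```bash
# Sketch only: fetch the pre-trained SeViLA checkpoint into ./sevila_checkpoints
mkdir -p sevila_checkpoints
wget -P sevila_checkpoints https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth
```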
+ + +# Reference +Please cite our paper if you use our models in your works: + + +```bibtex +@misc{yu2023selfchained, + title={Self-Chained Image-Language Model for Video Localization and Question Answering}, + author={Shoubin Yu and Jaemin Cho and Prateek Yadav and Mohit Bansal}, + year={2023}, + eprint={2305.06988}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e08158601aada02ca8a2955da83aa9a600e005c4 --- /dev/null +++ b/app.py @@ -0,0 +1,206 @@ +import gradio as gr +import os +import torch +from torchvision import transforms +from lavis.processors import transforms_video +from lavis.datasets.data_utils import load_video_demo +from lavis.processors.blip_processors import ToUint8, ToTHWC +from lavis.models.sevila_models.sevila import SeViLA +from typing import Optional +import warnings +# model config +img_size = 224 +num_query_token = 32 +t5_model = 'google/flan-t5-xl' +drop_path_rate = 0 +use_grad_checkpoint = False +vit_precision = "fp16" +freeze_vit = True +prompt = '' +max_txt_len = 77 +answer_num = 5 +apply_lemmatizer = False +task = 'freeze_loc_freeze_qa_vid' + +# prompt +LOC_propmpt = 'Does the information within the frame provide the necessary details to accurately answer the given question?' +QA_prompt = 'Considering the information presented in the frame, select the correct answer from the options.' + +# processors config +mean = (0.48145466, 0.4578275, 0.40821073) +std = (0.26862954, 0.26130258, 0.27577711) +normalize = transforms.Normalize(mean, std) +image_size = img_size +transform = transforms.Compose([ToUint8(), ToTHWC(), transforms_video.ToTensorVideo(), normalize]) + +print('model loading') +sevila = SeViLA( + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + num_query_token=num_query_token, + t5_model=t5_model, + prompt=prompt, + max_txt_len=max_txt_len, + apply_lemmatizer=apply_lemmatizer, + frame_num=4, + answer_num=answer_num, + task=task, + ) + +sevila.load_checkpoint(url_or_filename='https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth') +print('model loaded') + +ANS_MAPPING = {0 : 'A', 1 : 'B', 2 : 'C', 3 : 'D', 4 : 'E'} + +# os.mkdir('video') + +def sevila_demo(video, + question, + option1, option2, option3, + video_frame_num, + keyframe_num): + + if torch.cuda.is_available(): + device = 0 + else: + device = 'cpu' + + global sevila + if device == "cpu": + sevila = sevila.float() + else: + sevila = sevila.to(int(device)) + + vpath = video + raw_clip, indice, fps, vlen = load_video_demo( + video_path=vpath, + n_frms=int(video_frame_num), + height=image_size, + width=image_size, + sampling="uniform", + clip_proposal=None + ) + clip = transform(raw_clip.permute(1,0,2,3)) + clip = clip.float().to(int(device)) + clip = clip.unsqueeze(0) + # check + if option1[-1] != '.': + option1 += '.' + if option2[-1] != '.': + option2 += '.' + if option3[-1] != '.': + option3 += '.' 
+ option_dict = {0:option1, 1:option2, 2:option3} + options = 'Option A:{} Option B:{} Option C:{}'.format(option1, option2, option3) + text_input_qa = 'Question: ' + question + ' ' + options + ' ' + QA_prompt + text_input_loc = 'Question: ' + question + ' ' + options + ' ' + LOC_propmpt + + out = sevila.generate_demo(clip, text_input_qa, text_input_loc, int(keyframe_num)) + # print(out) + answer_id = out['output_text'][0] + answer = option_dict[answer_id] + select_index = out['frame_idx'][0] + # images = [] + keyframes = [] + timestamps =[] + + # print('raw_clip', len(raw_clip)) + # for j in range(int(video_frame_num)): + # image = raw_clip[:, j, :, :].int() + # image = image.permute(1, 2, 0).numpy() + # images.append(image) + + video_len = vlen/fps # seconds + + for i in select_index: + image = raw_clip[:, i, :, :].int() + image = image.permute(1, 2, 0).numpy() + keyframes.append(image) + select_i = indice[i] + time = round((select_i / vlen) * video_len, 2) + timestamps.append(str(time)+'s') + + gr.components.Gallery(keyframes) + #gr.components.Gallery(images) + timestamps_des = '' + for i in range(len(select_index)): + timestamps_des += 'Keyframe {}: {} \n'.format(str(i+1), timestamps[i]) + + return keyframes, timestamps_des, answer + +with gr.Blocks(title="SeViLA demo") as demo: + description = """

+ Self-Chained Image-Language Model for Video Localization and Question Answering +
+ + Shoubin Yu, + Jaemin Cho, + Prateek Yadav, + Mohit Bansal + +
+ + [GitHub] + [Paper] + +

+

+ To locate keyframes in a video and answer a question, please: +
+ (1) upload your video; (2) write your question/options and set # video frame/# keyframe/running device; (3) click Locate and Answer! +
+ Just a heads up - loading the SeViLA model can take a few minutes (typically 2-3), and running examples requires about 12GB of memory. +
+ We've got you covered! We've provided some example videos and questions below to help you get started. Feel free to try out SeViLA with these! +

+ """ + gr.HTML(description) + with gr.Row(): + with gr.Column(scale=1, min_width=600): + video = gr.Video(label='Video') + question = gr.Textbox(placeholder="Why did the two ladies put their hands above their eyes while staring out?", label='Question') + with gr.Row(): + option1 = gr.Textbox(placeholder="practicing cheer", label='Option 1') + option2 = gr.Textbox(placeholder="posing for photo", label='Option 2') + option3 = gr.Textbox(placeholder="to see better", label='Option 3') + video_frame_num = gr.Textbox(placeholder=32, label='# Video Frame') + keyframe_num = gr.Textbox(placeholder=4, label='# Keyframe') + # device = gr.Textbox(placeholder=0, label='Device') + gen_btn = gr.Button(value='Locate and Answer!') + with gr.Column(scale=2, min_width=600): + keyframes = gr.Gallery( + label="Keyframes", show_label=False, elem_id="gallery" + ).style(columns=[4], rows=[1], object_fit="contain", height="auto") + #keyframes = gr.Gallery(label='Keyframes') + timestamps = gr.outputs.Textbox(label="Keyframe Timestamps") + answer = gr.outputs.Textbox(label="Output Answer") + + gen_btn.click( + sevila_demo, + inputs=[video, question, option1, option2, option3, video_frame_num, keyframe_num], + outputs=[keyframes, timestamps, answer], + queue=True + ) + #demo = gr.Interface(sevila_demo, + # inputs=[gr.Video(), question, option1, option2, option3, video_frame_num, keyframe_num, device], + # outputs=['gallery', timestamps, answer], + # examples=[['videos/demo1.mp4', 'Why did the two ladies put their hands above their eyes while staring out?', 'practicing cheer.', 'play ball.', 'to see better.', 32, 4, 0], + # ['videos/demo2.mp4', 'What did both of them do after completing skiing?', 'jump and pose.' , 'bend down.','raised their hands.', 32, 4, 0], + # ['videos/demo3.mp4', 'What room was Wilson breaking into when House found him?', 'the kitchen.' , 'the dining room.','the bathroom.', 32, 4, 0]] + # ) + with gr.Column(): + gr.Examples( + inputs=[video, question, option1, option2, option3, video_frame_num, keyframe_num], + outputs=[keyframes, timestamps, answer], + fn=sevila_demo, + examples=[['videos/demo1.mp4', 'Why did the two ladies put their hands above their eyes while staring out?', 'practicing cheer', 'play ball', 'to see better', 32, 4], + ['videos/demo2.mp4', 'What did both of them do after completing skiing?', 'jump and pose' , 'bend down','raised their hands', 32, 4], + ['videos/demo3.mp4', 'What room was Wilson breaking into when House found him?', 'the kitchen' , 'the dining room','the bathroom', 32, 4], + ['videos/demo4.mp4', 'what kind of bird is it?', 'chikadee' , 'eagle','seagull', 32, 1]], + cache_examples=False, + ) +demo.queue(concurrency_count=1, api_open=False) +demo.launch(share=False) diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0552204081d813fbd3fb28c81dbada83d4b53021 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,26 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. 
+ # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from PIL import Image +import requests + +import streamlit as st +import torch + + +@st.cache() +def load_demo_image(): + img_url = ( + "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg" + ) + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +cache_root = "/export/home/.cache/lavis/" diff --git a/app/calculate_coco_features.py b/app/calculate_coco_features.py new file mode 100644 index 0000000000000000000000000000000000000000..168e8503e943b715fbc3e010444bfc57901b8ffc --- /dev/null +++ b/app/calculate_coco_features.py @@ -0,0 +1,87 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from PIL import Image +import requests +import torch + +import os + +from lavis.common.registry import registry +from lavis.processors import * +from lavis.models import * +from lavis.common.utils import build_default_model + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def load_demo_image(): + img_url = ( + "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg" + ) + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + + return raw_image + + +def read_img(filepath): + raw_image = Image.open(filepath).convert("RGB") + + return raw_image + + +# model +model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth" +feature_extractor = BlipFeatureExtractor(pretrained=model_url) + +feature_extractor.eval() +feature_extractor = feature_extractor.to(device) + +# preprocessors +vis_processor = BlipImageEvalProcessor(image_size=224) +text_processor = BlipCaptionProcessor() + +# files to process +# file_root = "/export/home/.cache/lavis/coco/images/val2014" +file_root = "/export/home/.cache/lavis/coco/images/train2014" +filepaths = os.listdir(file_root) + +print(len(filepaths)) + +caption = "dummy" + +path2feat = dict() +bsz = 256 + +images_in_batch = [] +filepaths_in_batch = [] + +for i, filename in enumerate(filepaths): + if i % bsz == 0 and i > 0: + images_in_batch = torch.cat(images_in_batch, dim=0).to(device) + with torch.no_grad(): + image_features = feature_extractor( + images_in_batch, caption, mode="image", normalized=True + )[:, 0] + + for filepath, image_feat in zip(filepaths_in_batch, image_features): + path2feat[os.path.basename(filepath)] = image_feat.detach().cpu() + + images_in_batch = [] + filepaths_in_batch = [] + + print(len(path2feat), image_features.shape) + else: + filepath = os.path.join(file_root, filename) + + image = read_img(filepath) + image = vis_processor(image).unsqueeze(0) + + images_in_batch.append(image) + filepaths_in_batch.append(filepath) + +torch.save(path2feat, "path2feat_coco_train2014.pth") diff --git a/app/caption.py b/app/caption.py new file mode 100644 index 0000000000000000000000000000000000000000..ad118988e8692f64261e344ebe76b264f9ab02d7 --- /dev/null +++ b/app/caption.py @@ -0,0 +1,98 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. 
+ # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import streamlit as st +from app import device, load_demo_image +from app.utils import load_model_cache +from lavis.processors import load_processor +from PIL import Image + + +def app(): + # ===== layout ===== + model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"]) + + sampling_method = st.sidebar.selectbox( + "Sampling method:", ["Beam search", "Nucleus sampling"] + ) + + st.markdown( + "

Image Description Generation

", + unsafe_allow_html=True, + ) + + instructions = """Try the provided image or upload your own:""" + file = st.file_uploader(instructions) + + use_beam = sampling_method == "Beam search" + + col1, col2 = st.columns(2) + + if file: + raw_img = Image.open(file).convert("RGB") + else: + raw_img = load_demo_image() + + col1.header("Image") + + w, h = raw_img.size + scaling_factor = 720 / w + resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor))) + + col1.image(resized_image, use_column_width=True) + col2.header("Description") + + cap_button = st.button("Generate") + + # ==== event ==== + vis_processor = load_processor("blip_image_eval").build(image_size=384) + + if cap_button: + if model_type.startswith("BLIP"): + blip_type = model_type.split("_")[1].lower() + model = load_model_cache( + "blip_caption", + model_type=f"{blip_type}_coco", + is_eval=True, + device=device, + ) + + img = vis_processor(raw_img).unsqueeze(0).to(device) + captions = generate_caption( + model=model, image=img, use_nucleus_sampling=not use_beam + ) + + col2.write("\n\n".join(captions), use_column_width=True) + + +def generate_caption( + model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5 +): + samples = {"image": image} + + captions = [] + if use_nucleus_sampling: + for _ in range(5): + caption = model.generate( + samples, + use_nucleus_sampling=True, + max_length=max_length, + min_length=min_length, + top_p=0.9, + ) + captions.append(caption[0]) + else: + caption = model.generate( + samples, + use_nucleus_sampling=False, + num_beams=num_beams, + max_length=max_length, + min_length=min_length, + ) + captions.append(caption[0]) + + return captions diff --git a/app/classification.py b/app/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5bd896474d5df6814220c79aeda6d5a895ab92 --- /dev/null +++ b/app/classification.py @@ -0,0 +1,216 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. 
+ # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import plotly.graph_objects as go +import requests +import streamlit as st +import torch +from lavis.models import load_model +from lavis.processors import load_processor +from lavis.processors.blip_processors import BlipCaptionProcessor +from PIL import Image + +from app import device, load_demo_image +from app.utils import load_blip_itm_model +from lavis.processors.clip_processors import ClipImageEvalProcessor + + +@st.cache() +def load_demo_image(img_url=None): + if not img_url: + img_url = "https://img.atlasobscura.com/yDJ86L8Ou6aIjBsxnlAy5f164w1rjTgcHZcx2yUs4mo/rt:fit/w:1200/q:81/sm:1/scp:1/ar:1/aHR0cHM6Ly9hdGxh/cy1kZXYuczMuYW1h/em9uYXdzLmNvbS91/cGxvYWRzL3BsYWNl/X2ltYWdlcy85MDll/MDRjOS00NTJjLTQx/NzQtYTY4MS02NmQw/MzI2YWIzNjk1ZGVk/MGZhMTJiMTM5MmZi/NGFfUmVhcl92aWV3/X29mX3RoZV9NZXJs/aW9uX3N0YXR1ZV9h/dF9NZXJsaW9uX1Bh/cmssX1NpbmdhcG9y/ZSxfd2l0aF9NYXJp/bmFfQmF5X1NhbmRz/X2luX3RoZV9kaXN0/YW5jZV8tXzIwMTQw/MzA3LmpwZw.jpg" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +@st.cache( + hash_funcs={ + torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach() + .cpu() + .numpy() + }, + allow_output_mutation=True, +) +def load_model_cache(model_type, device): + if model_type == "blip": + model = load_model( + "blip_feature_extractor", model_type="base", is_eval=True, device=device + ) + elif model_type == "albef": + model = load_model( + "albef_feature_extractor", model_type="base", is_eval=True, device=device + ) + elif model_type == "CLIP_ViT-B-32": + model = load_model( + "clip_feature_extractor", "ViT-B-32", is_eval=True, device=device + ) + elif model_type == "CLIP_ViT-B-16": + model = load_model( + "clip_feature_extractor", "ViT-B-16", is_eval=True, device=device + ) + elif model_type == "CLIP_ViT-L-14": + model = load_model( + "clip_feature_extractor", "ViT-L-14", is_eval=True, device=device + ) + + return model + + +def app(): + model_type = st.sidebar.selectbox( + "Model:", + ["ALBEF", "BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"], + ) + score_type = st.sidebar.selectbox("Score type:", ["Cosine", "Multimodal"]) + + # ===== layout ===== + st.markdown( + "

Zero-shot Classification

", + unsafe_allow_html=True, + ) + + instructions = """Try the provided image or upload your own:""" + file = st.file_uploader(instructions) + + st.header("Image") + if file: + raw_img = Image.open(file).convert("RGB") + else: + raw_img = load_demo_image() + + st.image(raw_img) # , use_column_width=True) + + col1, col2 = st.columns(2) + + col1.header("Categories") + + cls_0 = col1.text_input("category 1", value="merlion") + cls_1 = col1.text_input("category 2", value="sky") + cls_2 = col1.text_input("category 3", value="giraffe") + cls_3 = col1.text_input("category 4", value="fountain") + cls_4 = col1.text_input("category 5", value="marina bay") + + cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4] + cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0] + + if len(cls_names) != len(set(cls_names)): + st.error("Please provide unique class names") + return + + button = st.button("Submit") + + col2.header("Prediction") + + # ===== event ===== + + if button: + if model_type.startswith("BLIP"): + text_processor = BlipCaptionProcessor(prompt="A picture of ") + cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names] + + if score_type == "Cosine": + vis_processor = load_processor("blip_image_eval").build(image_size=224) + img = vis_processor(raw_img).unsqueeze(0).to(device) + + feature_extractor = load_model_cache(model_type="blip", device=device) + + sample = {"image": img, "text_input": cls_prompt} + + with torch.no_grad(): + image_features = feature_extractor.extract_features( + sample, mode="image" + ).image_embeds_proj[:, 0] + text_features = feature_extractor.extract_features( + sample, mode="text" + ).text_embeds_proj[:, 0] + sims = (image_features @ text_features.t())[ + 0 + ] / feature_extractor.temp + + else: + vis_processor = load_processor("blip_image_eval").build(image_size=384) + img = vis_processor(raw_img).unsqueeze(0).to(device) + + model = load_blip_itm_model(device) + + output = model(img, cls_prompt, match_head="itm") + sims = output[:, 1] + + sims = torch.nn.Softmax(dim=0)(sims) + inv_sims = [sim * 100 for sim in sims.tolist()[::-1]] + + elif model_type.startswith("ALBEF"): + vis_processor = load_processor("blip_image_eval").build(image_size=224) + img = vis_processor(raw_img).unsqueeze(0).to(device) + + text_processor = BlipCaptionProcessor(prompt="A picture of ") + cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names] + + feature_extractor = load_model_cache(model_type="albef", device=device) + + sample = {"image": img, "text_input": cls_prompt} + + with torch.no_grad(): + image_features = feature_extractor.extract_features( + sample, mode="image" + ).image_embeds_proj[:, 0] + text_features = feature_extractor.extract_features( + sample, mode="text" + ).text_embeds_proj[:, 0] + + st.write(image_features.shape) + st.write(text_features.shape) + + sims = (image_features @ text_features.t())[0] / feature_extractor.temp + + sims = torch.nn.Softmax(dim=0)(sims) + inv_sims = [sim * 100 for sim in sims.tolist()[::-1]] + + elif model_type.startswith("CLIP"): + if model_type == "CLIP_ViT-B-32": + model = load_model_cache(model_type="CLIP_ViT-B-32", device=device) + elif model_type == "CLIP_ViT-B-16": + model = load_model_cache(model_type="CLIP_ViT-B-16", device=device) + elif model_type == "CLIP_ViT-L-14": + model = load_model_cache(model_type="CLIP_ViT-L-14", device=device) + else: + raise ValueError(f"Unknown model type {model_type}") + + if score_type == "Cosine": + # image_preprocess = ClipImageEvalProcessor(image_size=336) + image_preprocess = 
ClipImageEvalProcessor(image_size=224) + img = image_preprocess(raw_img).unsqueeze(0).to(device) + + sample = {"image": img, "text_input": cls_names} + + with torch.no_grad(): + clip_features = model.extract_features(sample) + + image_features = clip_features.image_embeds_proj + text_features = clip_features.text_embeds_proj + + sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1) + inv_sims = sims.tolist()[::-1] + else: + st.warning("CLIP does not support multimodal scoring.") + return + + fig = go.Figure( + go.Bar( + x=inv_sims, + y=cls_names[::-1], + text=["{:.2f}".format(s) for s in inv_sims], + orientation="h", + ) + ) + fig.update_traces( + textfont_size=12, + textangle=0, + textposition="outside", + cliponaxis=False, + ) + col2.plotly_chart(fig, use_container_width=True) diff --git a/app/dataset_browser.py b/app/dataset_browser.py new file mode 100644 index 0000000000000000000000000000000000000000..6b761d899731940b8963c8894473848359418a74 --- /dev/null +++ b/app/dataset_browser.py @@ -0,0 +1,240 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import random +from collections import OrderedDict +from functools import reduce +from tkinter import N + +import streamlit as st +from lavis.common.registry import registry +from lavis.datasets.builders import dataset_zoo, load_dataset +from lavis.datasets.builders.base_dataset_builder import load_dataset_config +from PIL import Image + +IMAGE_LAYOUT = 3, 4 +VIDEO_LAYOUT = 1, 2 + +PREV_STR = "Prev" +NEXT_STR = "Next" + + +def sample_dataset(dataset, indices): + samples = [dataset.displ_item(idx) for idx in indices] + + return samples + + +def get_concat_v(im1, im2): + margin = 5 + + canvas_size = (im1.width + im2.width + margin, max(im1.height, im2.height)) + canvas = Image.new("RGB", canvas_size, "White") + canvas.paste(im1, (0, 0)) + canvas.paste(im2, (im1.width + margin, 0)) + + return canvas + + +def resize_img_w(raw_img, new_w=224): + if isinstance(raw_img, list): + resized_imgs = [resize_img_w(img, 196) for img in raw_img] + # concatenate images + resized_image = reduce(get_concat_v, resized_imgs) + else: + w, h = raw_img.size + scaling_factor = new_w / w + resized_image = raw_img.resize( + (int(w * scaling_factor), int(h * scaling_factor)) + ) + + return resized_image + + +def get_visual_key(dataset): + if "image" in dataset[0]: + return "image" + elif "image0" in dataset[0]: # NLVR2 dataset + return "image" + elif "video" in dataset[0]: + return "video" + else: + raise ValueError("Visual key not found.") + + +def gather_items(samples, exclude=[]): + gathered = [] + + for s in samples: + ns = OrderedDict() + for k in s.keys(): + if k not in exclude: + ns[k] = s[k] + + gathered.append(ns) + + return gathered + + +@st.cache(allow_output_mutation=True) +def load_dataset_cache(name): + return load_dataset(name) + + +def format_text(text): + md = "\n\n".join([f"**{k}**: {v}" for k, v in text.items()]) + + return md + + +def show_samples(dataset, offset=0, is_next=False): + visual_key = get_visual_key(dataset) + + num_rows, num_cols = IMAGE_LAYOUT if visual_key == "image" else VIDEO_LAYOUT + n_samples = num_rows * num_cols + + if not shuffle: + if is_next: + start = min(int(start_idx) + offset + n_samples, len(dataset) - n_samples) + else: + start = max(0, int(start_idx) + offset - n_samples) + + st.session_state.last_start = start + end 
= min(start + n_samples, len(dataset)) + + indices = list(range(start, end)) + else: + indices = random.sample(range(len(dataset)), n_samples) + samples = sample_dataset(dataset, indices) + + visual_info = ( + iter([resize_img_w(s[visual_key]) for s in samples]) + if visual_key == "image" + # else iter([s[visual_key] for s in samples]) + else iter([s["file"] for s in samples]) + ) + text_info = gather_items(samples, exclude=["image", "video"]) + text_info = iter([format_text(s) for s in text_info]) + + st.markdown( + """
""", + unsafe_allow_html=True, + ) + for _ in range(num_rows): + with st.container(): + for col in st.columns(num_cols): + # col.text(next(text_info)) + # col.caption(next(text_info)) + try: + col.markdown(next(text_info)) + if visual_key == "image": + col.image(next(visual_info), use_column_width=True, clamp=True) + elif visual_key == "video": + col.markdown( + "![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)" + ) + except StopIteration: + break + + st.markdown( + """
""", + unsafe_allow_html=True, + ) + + st.session_state.n_display = n_samples + + +if __name__ == "__main__": + st.set_page_config( + page_title="LAVIS Dataset Explorer", + # layout="wide", + initial_sidebar_state="expanded", + ) + + dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names()) + + function = st.sidebar.selectbox("Function:", ["Browser"], index=0) + + if function == "Browser": + shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0) + + dataset = load_dataset_cache(dataset_name) + split = st.sidebar.selectbox("Split:", dataset.keys()) + + dataset_len = len(dataset[split]) + st.success( + f"Loaded {dataset_name}/{split} with **{dataset_len}** records. **Image/video directory**: {dataset[split].vis_root}" + ) + + if "last_dataset" not in st.session_state: + st.session_state.last_dataset = dataset_name + st.session_state.last_split = split + + if "last_start" not in st.session_state: + st.session_state.last_start = 0 + + if "start_idx" not in st.session_state: + st.session_state.start_idx = 0 + + if "shuffle" not in st.session_state: + st.session_state.shuffle = shuffle + + if "first_run" not in st.session_state: + st.session_state.first_run = True + elif ( + st.session_state.last_dataset != dataset_name + or st.session_state.last_split != split + ): + st.session_state.first_run = True + + st.session_state.last_dataset = dataset_name + st.session_state.last_split = split + elif st.session_state.shuffle != shuffle: + st.session_state.shuffle = shuffle + st.session_state.first_run = True + + if not shuffle: + n_col, p_col = st.columns([0.05, 1]) + + prev_button = n_col.button(PREV_STR) + next_button = p_col.button(NEXT_STR) + + else: + next_button = st.button(NEXT_STR) + + if not shuffle: + start_idx = st.sidebar.text_input(f"Begin from (total {dataset_len})", 0) + + if not start_idx.isdigit(): + st.error(f"Input to 'Begin from' must be digits, found {start_idx}.") + else: + if int(start_idx) != st.session_state.start_idx: + st.session_state.start_idx = int(start_idx) + st.session_state.last_start = int(start_idx) + + if prev_button: + show_samples( + dataset[split], + offset=st.session_state.last_start - st.session_state.start_idx, + is_next=False, + ) + + if next_button: + show_samples( + dataset[split], + offset=st.session_state.last_start - st.session_state.start_idx, + is_next=True, + ) + + if st.session_state.first_run: + st.session_state.first_run = False + + show_samples( + dataset[split], + offset=st.session_state.last_start - st.session_state.start_idx, + is_next=True, + ) diff --git a/app/image_text_match.py b/app/image_text_match.py new file mode 100644 index 0000000000000000000000000000000000000000..e7957384e6ebba19acf47658f6c97446b54b3aeb --- /dev/null +++ b/app/image_text_match.py @@ -0,0 +1,87 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. 
+ # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import numpy as np +import streamlit as st +import torch +from lavis.models.blip_models.blip_image_text_matching import compute_gradcam +from lavis.processors import load_processor +from PIL import Image + +from app import device, load_demo_image +from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model + + +def app(): + model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"]) + + if model_type.startswith("BLIP"): + blip_type = model_type.split("_")[1] + model = load_blip_itm_model(device, model_type=blip_type) + + vis_processor = load_processor("blip_image_eval").build(image_size=384) + + st.markdown( + "

Image Text Matching

", + unsafe_allow_html=True, + ) + + values = list(range(1, 12)) + default_layer_num = values.index(7) + layer_num = ( + st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1 + ) + + instructions = """Try the provided image or upload your own:""" + file = st.file_uploader(instructions) + + col1, col2 = st.columns(2) + col1.header("Image") + col2.header("GradCam") + if file: + raw_img = Image.open(file).convert("RGB") + else: + raw_img = load_demo_image() + + w, h = raw_img.size + scaling_factor = 720 / w + resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor))) + col1.image(resized_image, use_column_width=True) + + col3, col4 = st.columns(2) + col3.header("Text") + user_question = col3.text_input( + "Input your sentence!", "a woman sitting on the beach with a dog" + ) + submit_button = col3.button("Submit") + + col4.header("Matching score") + + if submit_button: + tokenizer = init_bert_tokenizer() + + img = vis_processor(raw_img).unsqueeze(0).to(device) + text_processor = load_processor("blip_caption").build() + + qry = text_processor(user_question) + + norm_img = np.float32(resized_image) / 255 + + qry_tok = tokenizer(qry, return_tensors="pt").to(device) + gradcam, output = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num) + + avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True) + + col2.image(avg_gradcam, use_column_width=True, clamp=True) + # output = model(img, question) + itm_score = torch.nn.functional.softmax(output, dim=1) + new_title = ( + '

\n{:.3f}%

'.format( + itm_score[0][1].item() * 100 + ) + ) + col4.markdown(new_title, unsafe_allow_html=True) diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000000000000000000000000000000000000..108c46f8cce738499adeb3a65091f3b1919563e0 --- /dev/null +++ b/app/main.py @@ -0,0 +1,25 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from app.multipage import MultiPage +from app import vqa, caption +from app import image_text_match as itm +from app import text_localization as tl +from app import multimodal_search as ms +from app import classification as cl + + +if __name__ == "__main__": + app = MultiPage() + + app.add_page("Image Description Generation", caption.app) + app.add_page("Multimodal Search", ms.app) + app.add_page("Visual Question Answering", vqa.app) + app.add_page("Image Text Matching", itm.app) + app.add_page("Text Localization", tl.app) + app.add_page("Classification", cl.app) + app.run() diff --git a/app/multimodal_search.py b/app/multimodal_search.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc9766429e4922eac34db8b643445d3bc1622a3 --- /dev/null +++ b/app/multimodal_search.py @@ -0,0 +1,230 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os + +import numpy as np +import streamlit as st +import torch +import torch.nn.functional as F +from app import cache_root, device +from app.utils import ( + getAttMap, + init_bert_tokenizer, + load_blip_itm_model, + read_img, + resize_img, +) +from lavis.models import load_model +from lavis.processors import load_processor + + +@st.cache( + hash_funcs={ + torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach() + .cpu() + .numpy() + }, + allow_output_mutation=True, +) +def load_feat(): + from lavis.common.utils import download_url + + dirname = os.path.join(os.path.dirname(__file__), "assets") + filename = "path2feat_coco_train2014.pth" + filepath = os.path.join(dirname, filename) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth" + + if not os.path.exists(filepath): + download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth") + + path2feat = torch.load(filepath) + paths = sorted(path2feat.keys()) + + all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device) + + return path2feat, paths, all_img_feats + + +@st.cache( + hash_funcs={ + torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach() + .cpu() + .numpy() + }, + allow_output_mutation=True, +) +def load_feature_extractor_model(device): + model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth" + + model = load_model( + "blip_feature_extractor", model_type="base", is_eval=True, device=device + ) + model.load_from_pretrained(model_url) + + return model + + +def app(): + # === layout === + model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"]) + file_root = os.path.join(cache_root, "coco/images/train2014/") + + values = [12, 24, 48] + default_layer_num = values.index(24) + num_display = st.sidebar.selectbox( + "Number of images:", values, index=default_layer_num + ) + show_gradcam = 
st.sidebar.selectbox("Show GradCam:", [True, False], index=1) + itm_ranking = st.sidebar.selectbox("Multimodal re-ranking:", [True, False], index=0) + + # st.title('Multimodal Search') + st.markdown( + "

Multimodal Search

", unsafe_allow_html=True + ) + + # === event === + vis_processor = load_processor("blip_image_eval").build(image_size=384) + text_processor = load_processor("blip_caption") + + user_question = st.text_input( + "Search query", "A dog running on the grass.", help="Type something to search." + ) + user_question = text_processor(user_question) + feature_extractor = load_feature_extractor_model(device) + + # ======= ITC ========= + sample = {"text_input": user_question} + + with torch.no_grad(): + text_feature = feature_extractor.extract_features( + sample, mode="text" + ).text_embeds_proj[0, 0] + + path2feat, paths, all_img_feats = load_feat() + all_img_feats.to(device) + all_img_feats = F.normalize(all_img_feats, dim=1) + + num_cols = 4 + num_rows = int(num_display / num_cols) + + similarities = text_feature @ all_img_feats.T + indices = torch.argsort(similarities, descending=True)[:num_display] + + top_paths = [paths[ind.detach().cpu().item()] for ind in indices] + sorted_similarities = [similarities[idx] for idx in indices] + filenames = [os.path.join(file_root, p) for p in top_paths] + + # ========= ITM and GradCam ========== + bsz = 4 # max number of images to avoid cuda oom + if model_type.startswith("BLIP"): + blip_type = model_type.split("_")[1] + + itm_model = load_blip_itm_model(device, model_type=blip_type) + + tokenizer = init_bert_tokenizer() + queries_batch = [user_question] * bsz + queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device) + + num_batches = int(num_display / bsz) + + avg_gradcams = [] + all_raw_images = [] + itm_scores = [] + + for i in range(num_batches): + filenames_in_batch = filenames[i * bsz : (i + 1) * bsz] + raw_images, images = read_and_process_images(filenames_in_batch, vis_processor) + gradcam, itm_output = compute_gradcam_batch( + itm_model, images, queries_batch, queries_tok_batch + ) + + all_raw_images.extend([resize_img(r_img) for r_img in raw_images]) + norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images] + + for norm_img, grad_cam in zip(norm_imgs, gradcam): + avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True) + avg_gradcams.append(avg_gradcam) + + with torch.no_grad(): + itm_score = torch.nn.functional.softmax(itm_output, dim=1) + + itm_scores.append(itm_score) + + # ========= ITM re-ranking ========= + itm_scores = torch.cat(itm_scores)[:, 1] + if itm_ranking: + itm_scores_sorted, indices = torch.sort(itm_scores, descending=True) + + avg_gradcams_sorted = [] + all_raw_images_sorted = [] + for idx in indices: + avg_gradcams_sorted.append(avg_gradcams[idx]) + all_raw_images_sorted.append(all_raw_images[idx]) + + avg_gradcams = avg_gradcams_sorted + all_raw_images = all_raw_images_sorted + + if show_gradcam: + images_to_show = iter(avg_gradcams) + else: + images_to_show = iter(all_raw_images) + + for _ in range(num_rows): + with st.container(): + for col in st.columns(num_cols): + col.image(next(images_to_show), use_column_width=True, clamp=True) + + +def read_and_process_images(image_paths, vis_processor): + raw_images = [read_img(path) for path in image_paths] + images = [vis_processor(r_img) for r_img in raw_images] + images_tensors = torch.stack(images).to(device) + + return raw_images, images_tensors + + +def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6): + model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.save_attention = True + + output = model({"image": visual_input, "text_input": text_input}, match_head="itm") + loss = 
output[:, 1].sum() + + model.zero_grad() + loss.backward() + with torch.no_grad(): + mask = tokenized_text.attention_mask.view( + tokenized_text.attention_mask.size(0), 1, -1, 1, 1 + ) # (bsz,1,token_len, 1,1) + token_length = mask.sum() - 2 + token_length = token_length.cpu() + # grads and cams [bsz, num_head, seq_len, image_patch] + grads = model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.get_attn_gradients() + cams = model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.get_attention_map() + + # assume using vit large with 576 num image patch + cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask + grads = ( + grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24) + * mask + ) + + gradcam = cams * grads + # [enc token gradcam, average gradcam across token, gradcam for individual token] + # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :])) + gradcam = gradcam.mean(1).cpu().detach() + gradcam = ( + gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length + ) + + return gradcam, output diff --git a/app/multipage.py b/app/multipage.py new file mode 100644 index 0000000000000000000000000000000000000000..040f76ebd2f86d7ded9e8a224a20ce779862c607 --- /dev/null +++ b/app/multipage.py @@ -0,0 +1,41 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +""" +This file is the framework for generating multiple Streamlit applications +through an object oriented framework. +""" + +# Import necessary libraries +import streamlit as st + +# Define the multipage class to manage the multiple apps in our program +class MultiPage: + """Framework for combining multiple streamlit applications.""" + + def __init__(self) -> None: + """Constructor class to generate a list which will store all our applications as an instance variable.""" + self.pages = [] + + def add_page(self, title, func) -> None: + """Class Method to Add pages to the project + Args: + title ([str]): The title of page which we are adding to the list of apps + + func: Python function to render this page in Streamlit + """ + + self.pages.append({"title": title, "function": func}) + + def run(self): + # Drodown to select the page to run + page = st.sidebar.selectbox( + "Navigation", self.pages, format_func=lambda page: page["title"] + ) + + # run the app function + page["function"]() diff --git a/app/text_localization.py b/app/text_localization.py new file mode 100644 index 0000000000000000000000000000000000000000..d01655b97d7c0e495caf42c81c83b59e1bc3c811 --- /dev/null +++ b/app/text_localization.py @@ -0,0 +1,105 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. 
+ # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import math + +import numpy as np +import streamlit as st +from lavis.models.blip_models.blip_image_text_matching import compute_gradcam +from lavis.processors import load_processor +from PIL import Image + +from app import device, load_demo_image +from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model + + +def app(): + model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"]) + + values = list(range(1, 12)) + default_layer_num = values.index(7) + layer_num = ( + st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1 + ) + + st.markdown( + "

Text Localization

", unsafe_allow_html=True + ) + + vis_processor = load_processor("blip_image_eval").build(image_size=384) + text_processor = load_processor("blip_caption") + + tokenizer = init_bert_tokenizer() + + instructions = "Try the provided image and text or use your own ones." + file = st.file_uploader(instructions) + + query = st.text_input( + "Try a different input.", "A girl playing with her dog on the beach." + ) + + submit_button = st.button("Submit") + + col1, col2 = st.columns(2) + + if file: + raw_img = Image.open(file).convert("RGB") + else: + raw_img = load_demo_image() + + col1.header("Image") + w, h = raw_img.size + scaling_factor = 720 / w + resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor))) + col1.image(resized_image, use_column_width=True) + + col2.header("GradCam") + + if submit_button: + if model_type.startswith("BLIP"): + blip_type = model_type.split("_")[1] + model = load_blip_itm_model(device, model_type=blip_type) + + img = vis_processor(raw_img).unsqueeze(0).to(device) + qry = text_processor(query) + + qry_tok = tokenizer(qry, return_tensors="pt").to(device) + + norm_img = np.float32(resized_image) / 255 + + gradcam, _ = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num) + + avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True) + col2.image(avg_gradcam, use_column_width=True, clamp=True) + + num_cols = 4.0 + num_tokens = len(qry_tok.input_ids[0]) - 2 + + num_rows = int(math.ceil(num_tokens / num_cols)) + + gradcam_iter = iter(gradcam[0][2:-1]) + token_id_iter = iter(qry_tok.input_ids[0][1:-1]) + + for _ in range(num_rows): + with st.container(): + for col in st.columns(int(num_cols)): + token_id = next(token_id_iter, None) + if not token_id: + break + gradcam_img = next(gradcam_iter) + + word = tokenizer.decode([token_id]) + gradcam_todraw = getAttMap(norm_img, gradcam_img, blur=True) + + new_title = ( + '

{}

'.format( + word + ) + ) + col.markdown(new_title, unsafe_allow_html=True) + # st.image(image, channels="BGR") + col.image(gradcam_todraw, use_column_width=True, clamp=True) diff --git a/app/utils.py b/app/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4f209d6b90f6747f4f0a090276d5032c1049db --- /dev/null +++ b/app/utils.py @@ -0,0 +1,81 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import numpy as np +import streamlit as st +import torch +from lavis.models import BlipBase, load_model +from matplotlib import pyplot as plt +from PIL import Image +from scipy.ndimage import filters +from skimage import transform as skimage_transform + + +def resize_img(raw_img): + w, h = raw_img.size + scaling_factor = 240 / w + resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor))) + return resized_image + + +def read_img(filepath): + raw_image = Image.open(filepath).convert("RGB") + + return raw_image + + +@st.cache( + hash_funcs={ + torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach() + .cpu() + .numpy() + }, + allow_output_mutation=True, +) +def load_model_cache(name, model_type, is_eval, device): + return load_model(name, model_type, is_eval, device) + + +@st.cache(allow_output_mutation=True) +def init_bert_tokenizer(): + tokenizer = BlipBase.init_tokenizer() + return tokenizer + + +def getAttMap(img, attMap, blur=True, overlap=True): + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") + if blur: + attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) + attMap -= attMap.min() + attMap /= attMap.max() + cmap = plt.get_cmap("jet") + attMapV = cmap(attMap) + attMapV = np.delete(attMapV, 3, 2) + if overlap: + attMap = ( + 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV + ) + return attMap + + +@st.cache( + hash_funcs={ + torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach() + .cpu() + .numpy() + }, + allow_output_mutation=True, +) +def load_blip_itm_model(device, model_type="base"): + model = load_model( + "blip_image_text_matching", model_type, is_eval=True, device=device + ) + return model diff --git a/app/vqa.py b/app/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..c505a985d2450a4a2065faca67a1e8974e899c95 --- /dev/null +++ b/app/vqa.py @@ -0,0 +1,63 @@ +""" + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import streamlit as st +from app import load_demo_image, device +from app.utils import load_model_cache +from lavis.processors import load_processor +from PIL import Image + + +def app(): + model_type = st.sidebar.selectbox("Model:", ["BLIP"]) + + # ===== layout ===== + st.markdown( + "

Visual Question Answering

", + unsafe_allow_html=True, + ) + + instructions = """Try the provided image or upload your own:""" + file = st.file_uploader(instructions) + + col1, col2 = st.columns(2) + + col1.header("Image") + if file: + raw_img = Image.open(file).convert("RGB") + else: + raw_img = load_demo_image() + + w, h = raw_img.size + scaling_factor = 720 / w + resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor))) + + col1.image(resized_image, use_column_width=True) + col2.header("Question") + + user_question = col2.text_input("Input your question!", "What are objects there?") + qa_button = st.button("Submit") + + col2.header("Answer") + + # ===== event ===== + vis_processor = load_processor("blip_image_eval").build(image_size=480) + text_processor = load_processor("blip_question").build() + + if qa_button: + if model_type.startswith("BLIP"): + model = load_model_cache( + "blip_vqa", model_type="vqav2", is_eval=True, device=device + ) + + img = vis_processor(raw_img).unsqueeze(0).to(device) + question = text_processor(user_question) + + vqa_samples = {"image": img, "text_input": [question]} + answers = model.predict_answers(vqa_samples, inference_method="generate") + + col2.write("\n".join(answers), use_column_width=True) diff --git a/assets/.DS_Store b/assets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..203422b51a8d699aae8877abaa4427adac9b1717 Binary files /dev/null and b/assets/.DS_Store differ diff --git a/assets/chain.png b/assets/chain.png new file mode 100644 index 0000000000000000000000000000000000000000..2b8082e4b623046824faefd50e399ede1265f92d Binary files /dev/null and b/assets/chain.png differ diff --git a/assets/model.png b/assets/model.png new file mode 100644 index 0000000000000000000000000000000000000000..ae5e046f6f9da4ba497943ee51313195a93f4d51 Binary files /dev/null and b/assets/model.png differ diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..ae9ec56778ca181ee1f373be37cb29554712d934 Binary files /dev/null and b/assets/teaser.png differ diff --git a/docs/.DS_Store b/docs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..84cac25de99fc9a2f48a27359a8cb468927a90e1 Binary files /dev/null and b/docs/.DS_Store differ diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/.DS_Store b/docs/_static/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb8647072d886037512d8810bac076e2c7675e4e Binary files /dev/null and b/docs/_static/.DS_Store differ diff --git a/docs/_static/architecture.png b/docs/_static/architecture.png new file mode 100755 index 0000000000000000000000000000000000000000..817e139e8141797502b5329e910994a1d96241f8 Binary files /dev/null and b/docs/_static/architecture.png differ diff --git a/docs/_static/logo_final.png b/docs/_static/logo_final.png new file mode 100644 index 0000000000000000000000000000000000000000..632d094422f3e3ffc3b6ecfaec8bc2df3b9e2a42 Binary files /dev/null and b/docs/_static/logo_final.png differ diff --git a/docs/benchmark.rst b/docs/benchmark.rst new file mode 100644 index 0000000000000000000000000000000000000000..eeefa34dfe049981f6330bc011057bec26dc945d --- /dev/null +++ b/docs/benchmark.rst @@ -0,0 +1,348 @@ +Benchmark +############ + +We provide scripts for evaluating and training models on task datasets. The following benchmark results are included for reference. + + +ALBEF +******* +.. list-table:: + :widths: 30 80 20 + + * - **Pretraining** + - COCO (`download `__) + - `script `__ + * - + - Visual Genome (`download `__) + - + * - + - SBU (`download `__) + - + * - + - CC3M (`download `__) + - + * - + - CC12M (`download `__) + - + +.. list-table:: + :widths: 30 40 20 20 20 30 30 + :header-rows: 1 + + * - + - **Retrieval** + - **R1** + - **R5** + - **R10** + - **Training** + - **Evaluation** + * - TR + - COCO (`download `__) + - 77.6 + - 94.1 + - 97.2 + - `script `__ + - `script `__ + * - IR + - COCO (`download `__) + - 61.0 + - 84.5 + - 90.7 + - `script `__ + - `script `__ + * - TR + - Flickr30k (`download `__) + - 77.6 + - 94.1 + - 97.2 + - `script `__ + - `script `__ + * - IR + - Flickr30k (`download `__) + - 61.0 + - 84.5 + - 90.7 + - `script `__ + - `script `__ + + +.. list-table:: + :widths: 20 20 20 20 20 + :header-rows: 1 + + * - **VQA** + - **test-dev** + - **test-std/test** + - **Training** + - **Evaluation** + * - VQAv2 (`download `__) + - 76.35 + - 76.54 + - `script `__ + - `script `__ + * - OKVQA (`download `__) + - NA + - 54.7 + - `script `__ + - NA + * - AOKVQA (`download `__) + - 54.5 + - NA + - `script `__ + - NA + + +.. list-table:: + :widths: 20 20 20 20 20 + :header-rows: 1 + + * - **Multimodal Classification** + - **val** + - **test** + - **Training** + - **Evaluation** + * - SNLI-VE (`download `__) + - 80.60 + - 81.04 + - `script `__ + - `script `__ + * - NLVR2 (`download `__) + - 82.47 + - 82.91 + - `script `__ + - `script `__ + +BLIP +******* +.. list-table:: + :widths: 30 80 20 + + * - **Pretraining (14M)** + - COCO (`download `__) + - `script `__ + * - + - Visual Genome (`download `__) + - + * - + - SBU (`download `__) + - + * - + - CC3M (`download `__) + - + * - + - CC12M (`download `__) + - + +.. list-table:: + :widths: 30 40 20 20 20 30 30 + :header-rows: 1 + + * - **Tasks** + - **Retrieval** + - **R1** + - **R5** + - **R10** + - **Training** + - **Evaluation** + * - TR + - COCO (`download `__) + - 82.0 + - 95.8 + - 98.1 + - `script `__ + - `script `__ + * - IR + - COCO (`download `__) + - 64.5 + - 86.0 + - 91.7 + - `script `__ + - `script `__ + * - TR + - Flickr30k (`download `__) + - 96.9 + - 99.9 + - 100.0 + - `script `__ + - `script `__ + * - IR + - Flickr30k (`download `__) + - 87.5 + - 97.6 + - 98.9 + - `script `__ + - `script `__ + + +.. 
list-table:: + :widths: 20 20 20 20 20 + :header-rows: 1 + + * - **VQA** + - **test-dev** + - **test-std/test** + - **Training** + - **Evaluation** + * - VQAv2 (`download `__) + - 78.23 + - 78.29 + - `script `__ + - `script `__ + * - OKVQA (`download `__) + - NA + - 55.4 + - `script `__ + - `script `__ + * - AOKVQA (`download `__) + - 56.2 + - 50.1 + - `script `__ + - `script `__ + + +.. list-table:: + :widths: 20 20 20 20 20 20 + :header-rows: 1 + + * - **Image Captioning** + - **BLEU@4** + - **CIDEr** + - **SPICE** + - **Training** + - **Evaluation** + * - COCO (`download `__) + - 39.9 + - 133.5 + - 23.7 + - `script `__ + - `script `__ + * - NoCaps (`download `__) + - 31.9 + - 109.1 + - 14.7 + - NA + - `script `__ + + +.. list-table:: + :widths: 20 20 20 20 20 + :header-rows: 1 + + * - **Multimodal Classification** + - **val** + - **test** + - **Training** + - **Evaluation** + * - NLVR2 (`download `__) + - 82.48 + - 83.25 + - `script `__ + - `script `__ + +CLIP +******* +.. list-table:: + :widths: 30 40 20 20 20 30 + :header-rows: 1 + + * - **Tasks** + - **Retrieval (Zero-shot)** + - **R1** + - **R5** + - **R10** + - **Evaluation** + * - TR + - COCO (`download `__) + - 57.2 + - 80.5 + - 87.8 + - `script `__ + * - IR + - COCO (`download `__) + - 36.5 + - 60.8 + - 71.0 + - `script `__ + * - TR + - Flickr30k (`download `__) + - 86.5 + - 98.0 + - 99.1 + - `script `__ + * - IR + - Flickr30k (`download `__) + - 67.0 + - 88.9 + - 93.3 + - `script `__ + +.. list-table:: + :widths: 20 20 20 + :header-rows: 1 + + * - **Multimodal Classification** + - **val** + - **Evaluation** + * - ImageNet + - 76.5 + - `script `__ + + +ALPRO +******* +.. list-table:: + :widths: 30 40 20 20 20 20 30 + :header-rows: 1 + + * - **Tasks** + - **Retrieval** + - **R1** + - **R5** + - **R10** + - **Training** + - **Evaluation** + * - TR + - MSRVTT (`download `__) + - 33.2 + - 60.5 + - 71.7 + - `script `__ + - `script `__ + * - VR + - MSRVTT (`download `__) + - 33.8 + - 61.4 + - 72.7 + - `script `__ + - `script `__ + * - TR + - DiDeMo (`download `__) + - 38.8 + - 66.4 + - 76.8 + - `script `__ + - `script `__ + * - VR + - DiDeMo (`download `__) + - 36.6 + - 67.5 + - 77.9 + - `script `__ + - `script `__ + +.. list-table:: + :widths: 20 20 20 20 + :header-rows: 1 + + * - **Video QA** + - **test** + - **Training** + - **Evaluation** + * - MSRVTT + - 42.1 + - `script `__ + - `script `__ + * - MSVD + - 46.0 + - `script `__ + - `script `__ \ No newline at end of file diff --git a/docs/build_docs.sh b/docs/build_docs.sh new file mode 100755 index 0000000000000000000000000000000000000000..122172f753f29d49839de55a4555365c6ab0c020 --- /dev/null +++ b/docs/build_docs.sh @@ -0,0 +1,101 @@ +#!/bin/bash +set -euo pipefail + +# Change to root directory of repo +DIRNAME=$(cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +cd "${DIRNAME}/.." + +# # Set up virtual environment +pip3 install setuptools wheel virtualenv +if [ ! 
-d venv ]; then + rm -f venv + virtualenv venv +fi +source venv/bin/activate + +# # Get current git branch & stash unsaved changes +GIT_BRANCH=$(git branch --show-current) +if [ -z "${GIT_BRANCH}" ]; then + GIT_BRANCH="main" +fi +git stash + +# Set up exit handler to restore git state & delete temp branches +# function exit_handler { +# git reset --hard +# git checkout "${GIT_BRANCH}" -- +# git stash pop || true +# for version in $(git tag --list 'v[0-9]*'); do +# branch="${version}_local_docs_only" +# if git show-ref --verify --quiet "refs/heads/$branch"; then +# git branch -D "$branch" +# fi +# done +# } +# trap exit_handler EXIT + +# Clean up build directory and install Sphinx requirements +pip3 install -r "${DIRNAME}/requirements.txt" +sphinx-build -M clean "${DIRNAME}" "${DIRNAME}/_build" + +# Build API docs for current head +export current_version="latest" +pip3 install "." +sphinx-build -b html "${DIRNAME}" "${DIRNAME}/_build/html/${current_version}" -W --keep-going +rm -rf "${DIRNAME}/_build/html/${current_version}/.doctrees" +#pip3 uninstall -y omnixai + +# Install all previous released versions +# and use them to build the appropriate API docs. +# Uninstall after we're done with each one. +# versions=() +# checkout_files=("${DIRNAME}/*.rst" "lavis" "tutorials" "setup.py") +# for version in $(git tag --list 'v[0-9]*'); do +# versions+=("$version") +# git checkout -b "${version}_local_docs_only" +# for f in $(git diff --name-only --diff-filter=A "tags/${version}" "${DIRNAME}/*.rst"); do +# git rm "$f" +# done +# git checkout "tags/${version}" -- "${checkout_files[@]}" +# export current_version=${version} +# pip3 install ".[all]" +# sphinx-build -b html "${DIRNAME}" "${DIRNAME}/_build/html/${current_version}" -W --keep-going +# rm -rf "${DIRNAME}/_build/html/${current_version}/.doctrees" +# #pip3 uninstall -y omnixai +# git reset --hard +# git checkout "${GIT_BRANCH}" -- +# done + +# Determine the latest stable version if there is one +# if (( ${#versions[@]} > 0 )); then +# stable_hash=$(git rev-list --tags --max-count=1) +# stable_version=$(git describe --tags "$stable_hash") +# export stable_version +# else +export stable_version="latest" +# fi + +# Create dummy HTML's for the stable version in the base directory +while read -r filename; do + filename=$(echo "$filename" | sed "s/\.\///") + n_sub=$(echo "$filename" | (grep -o "/" || true) | wc -l) + prefix="" + for (( i=0; i "${DIRNAME}/_build/html/$filename" < + + + LAVIS Documentation + + + +

Please wait while you're redirected to our documentation.

+ + +EOF +done < <(cd "${DIRNAME}/_build/html/$stable_version" && find . -name "*.html") +echo "Finished writing to _build/html." \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4f8ab1af27d21f1a1a51e57e5bb49cc4485c05 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,56 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = "LAVIS" +copyright = "2022, salesforce.com inc." +author = ( + "Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven C.H. Hoi" +) + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = ["nbsphinx"] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +# html_theme = "alabaster" +html_theme = "sphinx_rtd_theme" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# pygments_style = "sphinx" diff --git a/docs/getting_started.rst b/docs/getting_started.rst new file mode 100644 index 0000000000000000000000000000000000000000..b90e20f8845ad3d79c53bed09728153eb75f675c --- /dev/null +++ b/docs/getting_started.rst @@ -0,0 +1,233 @@ +Dataset Zoo +################## +LAVIS inherently supports a wide variety of common language-vision datasets by providing automatic download scripts to help download and organize these datasets; +and implements PyTorch datasets for these datasets. To view supported datasets, use the following code: + +.. 
code-block:: python + + from lavis.datasets.builders import dataset_zoo + dataset_names = dataset_zoo.get_names() + print(dataset_names) + # ['aok_vqa', 'coco_caption', 'coco_retrieval', 'coco_vqa', 'conceptual_caption_12m', + # 'conceptual_caption_3m', 'didemo_retrieval', 'flickr30k', 'imagenet', 'laion2B_multi', + # 'msrvtt_caption', 'msrvtt_qa', 'msrvtt_retrieval', 'msvd_caption', 'msvd_qa', 'nlvr', + # 'nocaps', 'ok_vqa', 'sbu_caption', 'snli_ve', 'vatex_caption', 'vg_caption', 'vg_vqa'] + print(len(dataset_names)) + # 23 + + +Auto-Downloading and Loading Datasets +###################################### +We now take COCO caption dataset as an example to demonstrate how to download and prepare the dataset. + +In ``lavis/datasets/download_scripts/``, we provide tools to download most common public language-vision datasets supported by LAVIS. +The COCO caption dataset uses images from COCO dataset. Therefore, we first download COCO images via: + +.. code-block:: bash + + cd lavis/datasets/download_scripts/ && python download_coco.py + +This will automatically download and extract COCO images to the default LAVIS cache location. +The default cache location is ``~/.cache/lavis``, defined in ``lavis/configs/default.yaml``. + +After downloading the images, we can use ``load_dataset()`` to obtain the dataset. On the first run, this will automatically download and cache annotation files. + +.. code-block:: python + + from lavis.datasets.builders import load_dataset + coco_dataset = load_dataset("coco_caption") + + print(coco_dataset.keys()) + # dict_keys(['train', 'val', 'test']) + + print(len(coco_dataset["train"])) + # 566747 + + print(coco_dataset["train"][0]) + # {'image': , + # 'text_input': 'A woman wearing a net on her head cutting a cake. ', + # 'image_id': 0} + +If you already host a local copy of the dataset, you can pass in the ``vis_path`` argument to change the default location to load images. + +.. code-block:: python + + coco_dataset = load_dataset("coco_caption", vis_path=YOUR_LOCAL_PATH) + + +Model Zoo +#################################### +LAVIS supports a growing list of pre-trained models for different tasks, +datatsets and of varying sizes. Let's get started by viewing the supported models. + +.. code-block:: python + + from lavis.models import model_zoo + print(model_zoo) + # ================================================== + # Architectures Types + # ================================================== + # albef_classification base, ve + # albef_nlvr base + # albef_pretrain base + # albef_retrieval base, coco, flickr + # albef_vqa base, vqav2 + # alpro_qa base, msrvtt, msvd + # alpro_retrieval base, msrvtt, didemo + # blip_caption base, base_coco, large, large_coco + # blip_classification base + # blip_feature_extractor base + # blip_nlvr base + # blip_pretrain base + # blip_retrieval base, coco, flickr + # blip_vqa base, vqav2 + # clip ViT-B-32, ViT-B-16, ViT-L-14, ViT-L-14-336, RN50 + + # show total number of support model variants + len(model_zoo) + # 33 + + +Inference with Pre-trained Models +#################################### + +Now let's see how to use models in LAVIS to perform inference on example data. We first +load a sample image from local. + +.. 
code-block:: python + + from PIL import Image + + # setup device to use + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # load sample image + raw_image = Image.open("docs/_static/merlion.png").convert("RGB") + +This example image shows `Merlion park `_ (`image credit `_), a landmark in Singapore. + +.. image:: _static/merlion.png + +Image Captioning +******************************* +We now use the BLIP model to generate a caption for the image. To make inference even easier, we also associate each +pre-trained model with its preprocessors (transforms), we use ``load_model_and_preprocess()`` with the following arguments: + +- ``name``: The name of the model to load. This could be a pre-trained model, task model, or feature extractor. See ``model_zoo`` for a full list of model names. +- ``model_type``: Each architecture has variants trained on different datasets and at different scale. See Types column in ``model_zoo`` for a full list of model types. +- ``is_eval``: if `True`, set the model to evaluation mode. This is desired for inference or feature extraction. +- ``devce``: device to load the model to. + +.. code-block:: python + + from lavis.models import load_model_and_preprocess + # loads BLIP caption base model, with finetuned checkpoints on MSCOCO captioning dataset. + # this also loads the associated image processors + model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=device) + + # preprocess the image + # vis_processors stores image transforms for "train" and "eval" (validation / testing / inference) + image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) + + # generate caption + model.generate({"image": image}) + # ['a large fountain spewing water into the air'] + + +You may also load models and their preprocessors separately via ``load_model()`` and ``load_processor()``. +In BLIP, you can also generate diverse captions by turning nucleus sampling on. + +.. code-block:: python + + from lavis.processors import load_processor + from lavis.models import load_model + + # load image preprocesser used for BLIP + vis_processor = load_processor("blip_image_eval").build(image_size=384) + model = load_model(name="blip_caption", model_type="base_coco", is_eval=True, device=device) + + image = vis_processor(image).unsqueeze(0).to(device) + model.generate({"image": raw_image}, use_nucleus_sampling=True) + # one generated random sample: ['some very pretty buildings and some water jets'] + + +Visual question answering (VQA) +******************************* +BLIP model is able to answer free-form questions about images in natural language. +To access the VQA model, simply replace the ``name`` and ``model_type`` arguments +passed to ``load_model_and_preprocess()``. + +.. code-block:: python + + from lavis.models import load_model_and_preprocess + model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_vqa", model_type="vqav2", is_eval=True, device=device) + + # ask a random question. + question = "Which city is this photo taken?" + + image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) + question = txt_processors["eval"](question) + + model.predict_answers(samples={"image": image, "text_input": question}, inference_method="generate") + # ['singapore'] + + +Unified Feature Extraction Interface +#################################### + +LAVIS provides a unified interface to extract multimodal features from each architecture. 
+To extract features, we load the feature extractor variants of each model. +The multimodal feature can be used for multimodal classification. The low-dimensional unimodal features can be used to compute cross-modal similarity. + +.. code-block:: python + + from lavis.models import load_model_and_preprocess + + model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_feature_extractor", model_type="base", is_eval=True, device=device) + caption = "a large fountain spewing water into the air" + + image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) + text_input = txt_processors["eval"](caption) + + sample = {"image": image, "text_input": [text_input]} + + features_multimodal = model.extract_features(sample) + print(features_multimodal.keys()) + # odict_keys(['image_embeds', 'multimodal_embeds']) + print(features_multimodal.multimodal_embeds.shape) + # torch.Size([1, 12, 768]), use features_multimodal[:, 0, :] for multimodal classification tasks + + features_image = model.extract_features(sample, mode="image") + print(features_image.keys()) + # odict_keys(['image_embeds', 'image_embeds_proj']) + print(features_image.image_embeds.shape) + # torch.Size([1, 197, 768]) + print(features_image.image_embeds_proj.shape) + # torch.Size([1, 197, 256]) + + features_text = model.extract_features(sample, mode="text") + print(features_text.keys()) + # odict_keys(['text_embeds', 'text_embeds_proj']) + print(features_text.text_embeds.shape) + # torch.Size([1, 12, 768]) + print(features_text.text_embeds_proj.shape) + # torch.Size([1, 12, 256]) + + similarity = features_image.image_embeds_proj[:, 0, :] @ features_text.text_embeds_proj[:, 0, :].t() + print(similarity) + # tensor([[0.2622]]) + +Since LAVIS supports a unified feature extraction interface, minimal changes are necessary to use a different model as feature extractor. For example, +to use ALBEF as the feature extractor, one only needs to change the following line: + +.. code-block:: python + + model, vis_processors, txt_processors = load_model_and_preprocess(name="albef_feature_extractor", model_type="base", is_eval=True, device=device) + +Similarly, to use CLIP as feature extractor: + +.. code-block:: python + + model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="base", is_eval=True, device=device) + # model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="RN50", is_eval=True, device=device) + # model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="ViT-L-14", is_eval=True, device=device) diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..0e2a8458d04c2020df562bb59a0ec7833126d6c4 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,46 @@ +.. LAVIS documentation master file, created by + sphinx-quickstart on Sun Jul 31 10:32:27 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to LAVIS's documentation! +================================= + +.. toctree:: + :maxdepth: 1 + :caption: Introduction + + intro + + +.. toctree:: + :maxdepth: 1 + :caption: Getting Started + + getting_started + + +.. :maxdepth: 1 +.. :caption: Advanced Training + +.. advanced_training + + +.. toctree:: + :maxdepth: 2 + :caption: Advanced Usage + + benchmark + tutorial + + +.. Documentations +.. 
=================== + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/intro.rst b/docs/intro.rst new file mode 100644 index 0000000000000000000000000000000000000000..bc8510bd05ba3fcded7fee6ec436f742d122ed25 --- /dev/null +++ b/docs/intro.rst @@ -0,0 +1,99 @@ +What is LAVIS? +#################################### + +LAVIS is a Python deep learning library for LAnguage-and-VISion research and applications. +It features a unified design to access state-of-the-art foundation language-vision models (`ALBEF `_, +`BLIP `_, `ALPRO `_, `CLIP `_), common tasks +(retrieval, captioning, visual question answering, multimodal classification etc.) and datasets (COCO, Flickr, Nocaps, Conceptual +Commons, SBU, etc.). + +This library aims to provide engineers and researchers with a one-stop solution to rapidly develop models for their specific multimodal +scenarios, and benchmark them across standard and customized datasets. + +Key features of LAVIS include: + +- **Modular and Extensible Library Design**: facilitating to easily utilize and repurpose existing modules (datasets, models, preprocessors), also to add new modules. + +- **Easy Off-the-shelf Inference and Feature Extraction**: readily available pre-trained models let you take advantage of state-of-the-art multimodal understanding and generation capabilities on your own data. + +- **Reproducible Model Zoo**: provided training/pre-training recipies to easily replicate and extend state-of-the-art models. + +- **Dataset Zoo and Automatic Downloading Tools**: it can be a hassle to prepare the many language-vision datasets. LAVIS provides automatic downloaing scripts to help prepare a large variety of datasets and their annotations. + +Other features include: + +- **Distributed Training** using multiple GPUs on one machine or across multiple machines. + +- **Web Demo**: try supported models on your own pictures, questions etc. + +- **Leaderboard**: comparing state-of-the-art models across standard datasets. + +- **Dataset Explorer**: help browse and understand language-vision datasets. + +Supported Tasks, Models and Datasets +#################################### + +The following table shows the supported models and language-vision tasks by LAVIS. Adapting existing models to more tasks is possible and next to come in future releases. 
+ +======================================== =========================== ============================================= ============ +Tasks Supported Models Supported Datasets Modalities +======================================== =========================== ============================================= ============ +Image-text Pre-training ALBEF, BLIP COCO, VisualGenome, SBU, ConceptualCaptions image, text +Image-text Retrieval ALBEF, BLIP, CLIP COCO, Flickr30k image, text +Text-image Retrieval ALBEF, BLIP, CLIP COCO, Flickr30k image, text +Visual Question Answering ALBEF, BLIP VQAv2, OKVQA, A-OKVQA image, text +Image Captioning BLIP COCO, NoCaps image, text +Image Classification CLIP ImageNet image +Natural Language Visual Reasoning (NLVR) ALBEF, BLIP NLVR2 image, text +Visual Entailment (VE) ALBEF SNLI-VE image, text +Visual Dialogue BLIP VisDial image, text +Video-text Retrieval BLIP, ALPRO MSRVTT, DiDeMo video, text +Text-video Retrieval BLIP, ALPRO MSRVTT, DiDeMo video, text +Video Question Answering (VideoQA) BLIP, ALPRO MSRVTT, MSVD video, text +Video Dialogue VGD-GPT AVSD video, text +Multimodal Feature Extraction ALBEF, CLIP, BLIP, ALPRO customized image, text +======================================== =========================== ============================================= ============ + +Library Design +#################################### + +.. image:: _static/architecture.png + :width: 550 + +LAVIS has six key modules. + +- ``lavis.runners`` manages the overall training and evaluation lifecycle. It is also responsible for creating required components lazily as per demand, such as optimizers, learning rate schedulers and dataloaders. Currently ``RunnerBase`` implements epoch-based training and ``RunerIters`` implements iteration-based training. +- ``lavis.tasks`` implements concrete training and evaluation logic per task. A task could be, for example, retrieval, captioning, pre-training. The rationale to have an abstraction of task is to accomodate task-specific training and evaluation. For example, evaluating a retrieval model is different from a classification model. +- ``lavis.datasets`` is responsible for creating datasets, where ``lavis.datasets.builders`` loads dataset configurations, downloads annotations and returns a dataset object; ``lavis.datasets.datasets`` defines the supported datasets, each is a ``torch.utils.data.Dataset`` instance. We also provide `automatic dataset downloading tools` in ``datasets/download_scripts`` to help prepare common public datasets. +- ``lavis.models`` holds definition for the supported models and shared model layers. +- ``lavis.processors`` handles preprocessing of text and images/videos before feeding the model. For images and videos, a processor can be thought as transfroms in torchvision; for text input, this may include lowering case, truncation etc. +- ``lavis.common`` module contains shared classes and methods used by multiple other modules. For example, + + - ``lavis.common.config`` contains classes to store and manipulate configuration files used by LAVIS. In particular, we use a hierarchical configuration design, to allow highly customizable training and evaluation. + - ``lavis.common.registry`` serves as a centralized place to manage modules that share the same functionalities. It allows building datasets, models, tasks, and learning rate schedulers during runtime, by specifying their names as string in the configuration file. + - ``lavis.common.optims`` contains definitions of learning rate schedulers. 
+ - ``lavis.common.dist_utils`` contains utilities for distributed training and evaluation. + - ``lavis.common.utils`` contains miscellaneous utilities, mostly IO-related helper functions. + + +Installation +############ +1. (Optional) Creating conda environment + +.. code-block:: bash + + conda create -n lavis python=3.8 + conda activate lavis + +2. Cloning and building from source + +.. code-block:: bash + + git clone https://github.com/salesforce/LAVIS.git + cd LAVIS + pip install . + +If you would like to develop on LAVIS, you may find it easier to build with editable mode:: + + pip install -e . + diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..6247f7e231716482115f34084ac61030743e0715 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..da594f6d5cee1569d227e8c30af5c3113304bd8f --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,7 @@ +GitPython +ipykernel +nbsphinx==0.8.7 +pandoc +sphinx +sphinx_autodoc_typehints +sphinx_rtd_theme \ No newline at end of file diff --git a/docs/tutorial.configs.rst b/docs/tutorial.configs.rst new file mode 100644 index 0000000000000000000000000000000000000000..7841f8ea66e26d3f8c9342a07eec3a89da3d9276 --- /dev/null +++ b/docs/tutorial.configs.rst @@ -0,0 +1,172 @@ +.. _config: + +Training Models on Task Datasets (Commands and Configurations) +################################################################# + +LAVIS provides scripts to pre-train and finetune supported models on standard language-vision tasks, stored at ``lavis/run_scripts/``. +To replicate the experiments, just run these bash scripts. For example, to train BLIP model on the image-text retrieval task with MSCOCO dataset, we can run + +.. code-block:: + + bash run_scripts/lavis/blip/train/train_retrieval_coco.sh + +Inside the scripts, we can see + +.. code-block:: bash + + python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/retrieval_coco_ft.yaml + +where we start a pytorch distributed training on 8 GPUs (you may change according to your own hardware setup). The ``--cfg-path`` specifys a `runtime configuration file`, specifying +the task, model, dataset and training recipes. + +Available options and their descriptions are as below. + +.. LAVIS executes training and evaluation based on arguments specified in the configuration files. The default model and dataset configurations are defined in ``lavis/configs``. The task-specific configurations are defined in ``lavis/projects``. Task-specific configurations have higher priority over the default configurations. + +.. 
The following tables provide explanations for the arguments in the configuration files. + +.. list-table:: + :widths: 30 40 + :header-rows: 1 + + * - Model Configurations + - Functionalities + * - arch + - | name of the model from the model zoo + | default: task-dependent + * - model_type + - | the type of the model (e.g., base) + | default: task-dependent + * - load_pretrained + - | load pretrained weights + | default: True (for finetuning task) | False (for pretraining task) + * - load_finetuned + - | load task-specific finetuned weights + | default: False (for finetuning task) | True (for evaluation) + * - pretrained + - | URL or local path which stores the pretrained model, defined in the default model configuration file + | default: task-dependent + * - finetuned + - | URL or local path which stores the finetuned model, defined in the default model configuration file + | default: task-dependent + +.. list-table:: + :widths: 30 50 + :header-rows: 1 + + * - Dataset Configurations + - Functionalities + * - vis_processor + - | pre-processing of visual input + | default: task-dependent + * - text_processor + - | pre-processing of text input + | default: task-dependent + * - build_info + - | dataset information including the storage location, defined in the default dataset configuration file + | default: task-dependent + +.. list-table:: + :widths: 30 50 + :header-rows: 1 + + * - Runtime Configurations + - Functionalities + * - task + - | name of the task + | default: task-dependent + * - lr_sched + - | learning rate schedular + | default: linear_warmup_cosine_lr + * - init_lr + - | initial learning rate (after warmup) + | default: task-dependent + * - min_lr + - | final learning rate after decay + | default: task-dependent + * - warmup_lr + - | starting learning rate for warmup + | default: init_lr (no warmup) + * - lr_decay_rate + - | learning rate decay per epoch for step_lr_shedule + | default: 0.9 + * - warmup_steps + - | number of steps for learning rate warmup + | default: 0 + * - max_epoch + - | total number of training epochs + | default: task-dependent + * - weight_decay + - | weight decay coefficient for the optimizer + | default: 0.05 + * - batch_size_train + - | batch size during training + | default: task-dependent + * - batch_size_eval + - | batch size during evaluation + | default: task-dependent + * - seed + - | pseudo random number generator seed + | default: 42 + * - output_dir + - | directory to store logs, results and checkpoints + | default: task-dependent + * - resume_ckpt_path + - | path of the checkpoint to resume training from + | default: None + * - evaluate + - | only perform evaluation without training + | default: False + * - train_splits + - | dataset splits used for training + | default: ["train"] + * - valid_splits + - | dataset splits used for validation + | default: ["val"] + * - test + - | dataset splits used for test + | default: ["test"] + * - device + - | use cpu or gpu (cuda) + | default: cuda + * - world_size + - | number of processes participating in the job + | default: 1 + * - dist_url + - | URL specifying how to initialize the process group + | default: "env://" + * - distributed + - | use distributed training + | default: True + * - amp + - | use automatic mixed precision training + | default: False + +.. 
list-table:: + :widths: 40 50 + :header-rows: 1 + + * - Text Generation Configurations + - Functionalities + * - max_len + - | maximum number of text tokens to generate + | default: 20 (for image captioning) + * - min_len + - | minimum number of text tokens to generate + | default: 5 (for image captioning) + * - num_beams + - | number of beams to perform beam search + | default: 3 + +.. list-table:: + :widths: 40 50 + :header-rows: 1 + + * - Multimodal Retrieval Configurations + - Functionalities + * - negative_all_rank + - | collect negatives from all processes for the image-text matching loss + | default: True (for coco) + * - k_test + - | number of retrieval candidates ranked from contrastive similarity + | default: 256 (for coco) diff --git a/docs/tutorial.datasets.rst b/docs/tutorial.datasets.rst new file mode 100644 index 0000000000000000000000000000000000000000..ee026c03f3fecfe2c2bc3cf0076b5f2242898d42 --- /dev/null +++ b/docs/tutorial.datasets.rst @@ -0,0 +1,424 @@ +Adding Datasets +################################################ + +This is a tutorial on adding a new dataset using ``lavis.datasets`` module. + +The LAVIS library includes a standard dataset module, which allows customization to add new datasets. +The ``lavis.datasets`` module is designed such that any new dataset class can be easily added and adapted from our code base, including creating dataset configuration, and defining and associating new dataset classes. + +In this tutorial, we will replicate the steps to add a dataset class for the `Audio-Visual Scene-Aware Dialogue (AVSD) `_ benchmark for the video-grounded dialogue task. + +Dataset Configuration ``lavis.configs.datasets`` +************************************************************** + +First, we define the basic configurations for this dataset, including a new dataset class ``avsd_dialogue``, dataset card, and data types. +We can define any new dataset configuration in ``lavis.configs.datasets``. For instance, under this module, we can set up a configuration file ``avsd/defaults_dial.yaml`` as follows: + +.. code-block:: yaml + + datasets: + avsd_dialogue: # name of the dataset builder + dataset_card: dataset_card/avsd_dialogue.md # path to the dataset card + data_type: features # [images|videos|features] we use features in this case for extracted video features + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: /export/home/data/avsd/train_set4DSTC7-AVSD.json + storage: avsd/annotations/train.json + val: + url: /export/home/data/avsd/valid_set4DSTC7-AVSD.json + storage: avsd/annotations/val.json + test: + url: /export/home/data/avsd/test_set4DSTC7-AVSD.json + storage: avsd/annotations/test.json + features: + storage: /export/home/data/avsd/features/ + + +Dataset Card +=============== +One optional step to set up dataset configuration is defining a dataset card, which contains more details about the dataset such as description, tasks, and metrics. +For instance, we can define a dataset card for the AVSD benchmark in ``dataset_card/avsd_dialogue.md``. +Depending on the dataset, we included in its corresponding dataset card the command for auto-downloading data (with python code defined in ``lavis.datasets.download_scripts``) that will automatically load the data and store it in a specific folder. +Else, you should describe in the dataset card the external download instructions from the original data source to load the dataset properly. 
One example of a dataset card for the AVSD benchmark is:

.. code-block:: md

    ![Samples from the AVSD dataset (Image credit: "https://arxiv.org/pdf/1901.09107.pdf").](imgs/avsd_dialogue.png)(Samples from the AVSD dataset. Image credit: "https://arxiv.org/pdf/1901.09107.pdf")

    # Audio-Visual Scene-Aware Dialogues (AVSD)

    ## Description
    [Audio-Visual Scene-Aware Dialogues (AVSD)](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) contains more than 10,000 dialogues, each of which is grounded on a unique video. In the test split, 6 reference dialogue responses are provided for each test sample.

    ## Task

    (https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge)

    In a **video-grounded dialogue task**, the system must generate responses to user input in the context of a given dialog.
    This context consists of a dialog history (previous utterances by both user and system) in addition to the video and audio information that comprise the scene. The quality of a system's automatically generated sentences is evaluated using objective measures to determine whether or not the generated responses are natural and informative.

    ## Metrics
    Models are typically evaluated according to [BLEU](https://aclanthology.org/P02-1040/), [CIDER](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.pdf), [METEOR](https://aclanthology.org/W05-0909/), and [ROUGE-L](https://aclanthology.org/W04-1013/) metrics.

    ## Leaderboard

    ....

    ## Auto-Downloading

    Please refer to the [benchmark website](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) for instructions to download the dataset.

    ## References
    "Audio Visual Scene-Aware Dialog", Huda Alamri, Vincent Cartillier, Abhishek Das, Jue Wang, Anoop Cherian, Irfan Essa, Dhruv Batra, Tim K. Marks, Chiori Hori, Peter Anderson, Stefan Lee, Devi Parikh

Visual Data Type
==============================
We currently limit the visual data types to one of three options: ``images``, ``videos``, and ``features``.
"Images" and "videos" refer to raw visual data, which is appropriate for models that process visual inputs in their original form (e.g. ViT models).
"Features" are visual representations extracted from pretrained models (e.g. CNN models).
In this tutorial, the AVSD benchmark consists of video features extracted with 3D-CNN models.

Build Info
==============================
Build info refers to the specific locations where data is stored and cached.

For text annotations (e.g. captioning or dialogues), we include three data splits by default, namely "train", "val", and "test", as typically used in machine learning projects.
For each split, we specify two parameters: ``url`` and ``storage``.
``url`` can be either an online URL from which the dataset can be downloaded automatically (e.g. from *googleapis*), or a local directory where the data has already been downloaded beforehand.
``storage`` is the directory where the data is cached over time, avoiding repeated downloads.

For visual data annotations, ensure the field name matches the data type defined earlier (e.g. one of "images", "videos", or "features").
As visual features are usually large and should be downloaded beforehand, we maintain only a ``storage`` parameter where the visual data is cached.
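With the configuration above in place, and once the corresponding builder (described later in this tutorial) is registered, the new dataset can be loaded through the same interface as the built-in datasets. The snippet below is a minimal sketch, assuming the ``avsd_dialogue`` builder has been registered and the annotations and video features are available at the locations declared in ``build_info``:

.. code-block:: python

    from lavis.datasets.builders import load_dataset

    # builds the "train", "val" and "test" splits declared in the configuration;
    # annotations are fetched from ``url`` and cached under ``storage`` on first use
    avsd_dataset = load_dataset("avsd_dialogue")

    print(avsd_dataset.keys())
    # expected: dict_keys(['train', 'val', 'test'])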
+ +Dataset ``lavis.datasets.datasets`` +************************************************************** + +Base Dataset ``lavis.datasets.datasets.base_dataset`` +======================================================= +In this step, we want to define new dataset classes that inherit our base dataset class ``lavis.datasets.datasets.base_dataset``. This base dataset class already defines standard methods such as ``collater`` which uses the default collator from Pytorch. + +.. code-block:: python + + import json + from typing import Iterable + + from torch.utils.data import Dataset, ConcatDataset + from torch.utils.data.dataloader import default_collate + + class BaseDataset(Dataset): + def __init__( + self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] + ): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.annotation = [] + for ann_path in ann_paths: + self.annotation.extend(json.load(open(ann_path, "r"))) + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __len__(self): + return len(self.annotation) + + def collater(self, samples): + return default_collate(samples) + + def set_processors(self, vis_processor, text_processor): + self.vis_processor = vis_processor + self.text_processor = text_processor + + def _add_instance_ids(self, key="instance_id"): + for idx, ann in enumerate(self.annotation): + ann[key] = str(idx) + +Any dataset subclass will inherit these methods and it is optional to define and overwrite these methods accordingly to the specifications of the dataset. +We encourage users not to modify the base dataset class as any modification will have cascading impacts on any other dataset classes that inherit this base dataset. +Instead, the users should independently create new dataset classes to cater to their specific requirements. + +Dialogue Datasets ``lavis.datasets.datasets.dialogue_datasets`` +====================================================================== + +For example, for the AVSD dataset, we want to define a new dataset subclass ``DialogueDataset`` for dialogue tasks. We can define this dataset class in ``lavis.datasets.datasets.dialogue_datasets`` as following: + +.. code-block:: python + + import os + from collections import OrderedDict + + from lavis.datasets.datasets.base_dataset import BaseDataset + + import json + import copy + + class DialogueDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_processor (string): visual processor + text_processor (string): textual processor + vis_root (string): Root directory of images (e.g. coco/images/) + ann_paths (string): Root directory of images (e.g. 
coco/images/) + """ + + self.vis_root = vis_root + + self.annotation = [] + for ann_path in ann_paths: + dialogs = json.load(open(ann_path, "r"))['dialogs'] + for dialog in dialogs: + all_turns = dialog['dialog'] + dialogue_context = [] + for turn in all_turns: + dialog_instance = copy.deepcopy(dialog) + question = turn['question'] + answer = turn['answer'] + + dialog_instance['dialog'] = copy.deepcopy(dialogue_context) + dialog_instance['question'] = question + dialog_instance['answer'] = answer + self.annotation.append(dialog_instance) + dialogue_context.append(turn) + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + +Class inheritance allows us to define multiple subclasses. For instance, we want another dialogue dataset class that is defined only for the test split. We can define another dataset class ``DialogueEvalDataset`` as similarly defined above but the annotations are processed differently. +Typically, in dialogue tasks, during test time, only a single test sample is constructed per dialogue (rather than decomposing all dialogue turns as samples during training time). +The dataset class can then be defined as: + +.. code-block:: python + + class DialogueEvalDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + # ... + # defined similarly as DialogueDataset above + # except for the loading of dialogue annotation data + + self.annotation = [] + for ann_path in ann_paths: + dialogs = json.load(open(ann_path, "r"))['dialogs'] + for dialog in dialogs: + all_turns = dialog['dialog'] + dialogue_context = all_turns[:-1] + last_turn = all_turns[-1] + + question = last_turn['question'] + answer = last_turn['answer'] + + dialog['dialog'] = dialogue_context + dialog['question'] = question + dialog['answer'] = answer + + self.annotation.append(dialog) + + +Using class inheritance to define datasets also allows us to develop more fine-grain class implementations, each of which is specifically designated for a benchmark. +For instance, under the dialogue-based tasks, we can further define another dataset subclass that is specified for the AVSD dataset. +We can define a new class ``AVSDDialDataset`` that further specifies how to load individual samples and collate them accordingly to specific requirements: + +.. 
code-block:: python + + import os + from lavis.datasets.datasets.base_dataset import BaseDataset + from lavis.datasets.datasets.dialogue_datasets import DialogueDataset, DialogueEvalDataset + + import torch + + class AVSDDialDataset(DialogueDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + vname = ann["image_id"] + + video = self.vis_processor(self.vis_root, vname) + + dialogue = self.text_processor(ann) + + return { + "video_fts": video['video_fts'], + "video_token_type_ids": video['token_type_ids'], + "input_ids": dialogue['input_ids'], + "token_type_ids": dialogue['token_type_ids'], + "labels": dialogue['labels'], + "image_id": ann["image_id"], + "instance_id": ann["instance_id"] + } + + def collater(self, samples): + + input_ids, token_type_ids, labels, video_fts, video_token_type_ids = [], [], [], [], [] + + for i in samples: + input_ids.append(i['input_ids']) + token_type_ids.append(i['token_type_ids']) + labels.append(i['labels']) + video_fts.append(i['video_fts']) + video_token_type_ids.append(i['video_token_type_ids']) + + input_ids = self.text_processor.padding(input_ids) + + labels = self.text_processor.padding(labels, -1) + video_fts = self.vis_processor.padding(video_fts) + + token_type_ids = self.text_processor.padding(token_type_ids) + video_token_type_ids = self.text_processor.padding(video_token_type_ids) + token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1) + + attn_mask = self.text_processor.get_attention_mask(input_ids) + video_mask = self.vis_processor.get_attention_mask(video_fts) + attn_mask = torch.cat([video_mask, attn_mask], dim=1) + + video_labels = torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1 # ignore token indice -1 by default + + labels = torch.cat([video_labels, labels], dim=1) + + samples = {} + samples['input_ids'] = input_ids + samples['token_type_ids'] = token_type_ids + samples['labels'] = labels + samples['video_fts'] = video_fts + samples['attn_mask'] = attn_mask + + return samples + +Note that in a dataset subclass, if methods such as ``__getitem__`` and ``collater`` are not defined, the same functions from the corresponding superclass will be used. +For instance, by default, we always use the collater from the ``BaseDataset`` class to collate data samples. + +Dataset Builder ``lavis.datasets.builders`` +************************************************************** +Dataset Builder is the data processing module that controls the dataset classes (by training or evaluation split) and associates the specific dataset configurations to these dataset classes. + +Base Dataset Builder ``lavis.datasets.builders.base_dataset_builder`` +====================================================================== + +Note that any new builder class definition should inherit the base dataset builder class ``lavis.datasets.builders.base_dataset_builder``: + +.. code-block:: python + + class BaseDatasetBuilder: + train_dataset_cls, eval_dataset_cls = None, None + ... + +This allows us to standardize the operations of dataset builders across all builder classes. We advise the users to carefully review the standard methods defined in the base builder class, including methods such as ``_download_data`` and ``build_dataset`` that will load download the data and create instances of dataset classes: + +.. code-block:: python + + class BaseDatasetBuilder: + ... 
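        # build_datasets() below downloads the annotations and visual data on the
        # main process only, waits for every process at the barrier, then calls
        # self.build() to instantiate the registered train/eval dataset classes
        # for each split.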
+ + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def _download_data(self): + self._download_ann() + self._download_vis() + +We encourage users not to modify the implementation of the base dataset builder class as this will affect all existing dataset builder subclasses. + +Dialogue Dataset Builder ``lavis.datasets.builders.dialogue_builder`` +====================================================================== +We can define any new builder subclass and associate this builder with the corresponding dataset classes and dataset configurations. +For instance, for the AVSD dataset, we can define a builder ``lavis.datasets.builders.dialogue_builder`` for dialogue-based datasets as follows: + +.. code-block:: python + + from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder + from lavis.datasets.datasets.avsd_dialogue_datasets import ( + AVSDDialDataset, + AVSDDialEvalDataset + ) + + from lavis.common.registry import registry + + + @registry.register_builder("avsd_dialogue") + class AVSDDialBuilder(BaseDatasetBuilder): + train_dataset_cls = AVSDDialDataset + eval_dataset_cls = AVSDDialEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/avsd/defaults_dial.yaml" + } + +Note that we chose to separately define the parameters ``train_dataset_cls`` and ``eval_dataset_cls`` to consider cases where data is processed differently between training and test time. +For instance, in captioning tasks, during test time, each data sample often includes multiple ground-truth captions rather than just a single ground-truth during training time. +If the data processing is the same in both training and test time, the two parameters can be linked to the same dataset class. + +Finally, define ``DATASET_CONFIG_DICT`` to associate the dataset configurations to the assigned dataset classes. + +Registering Builder ``lavis.datasets.builders.__init__`` +====================================================================== + +To add a new builder class, ensure to first include the class within the ``__init__.py``. For instance, to define a new builder for the AVSD dataset: + +.. code-block:: python + + from lavis.datasets.builders.dialogue_builder import ( + AVSDDialBuilder + ) + + __all__ = [ + ..., + "AVSDDialBuilder" + ] + +Assigning Builder +====================================================================== +Note that during data loading and processing, the builder being assigned must have the correct registry to be able to load it properly. +For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``: + +.. code-block:: yaml + + datasets: + avsd_dialogue: # name of the dataset builder + ... + # processor configuration + ... + +Subsequently, any processes (e.g. training) should load this configuration file to assign the correct builder which will then associate the correct dataset classes to construct data samples. + +.. 
code-block:: sh + + python train.py --cfg-path dialogue_avsd_ft.yaml diff --git a/docs/tutorial.evaluation.rst b/docs/tutorial.evaluation.rst new file mode 100644 index 0000000000000000000000000000000000000000..b20a295456494ab746cb90766047a2b79fcec3e9 --- /dev/null +++ b/docs/tutorial.evaluation.rst @@ -0,0 +1,40 @@ +Evaluating Pre-trained Models on Task Datasets +############################################### +LAVIS provides pre-trained and finetuned model for off-the-shelf evaluation on task dataset. +Let's now see an example to evaluate BLIP model on the captioning task, using MSCOCO dataset. + +.. _prep coco: + +Preparing Datasets +****************** +First, let's download the dataset. LAVIS provides `automatic downloading scripts` to help prepare +most of the public dataset, to download MSCOCO dataset, simply run + +.. code-block:: bash + + cd lavis/datasets/download_scripts && bash download_coco.py + +This will put the downloaded dataset at a default cache location ``cache`` used by LAVIS. + +If you want to use a different cache location, you can specify it by updating ``cache_root`` in ``lavis/configs/default.yaml``. + +If you have a local copy of the dataset, it is recommended to create a symlink from the cache location to the local copy, e.g. + +.. code-block:: bash + + ln -s /path/to/local/coco cache/coco + +Evaluating pre-trained models +****************************** + +To evaluate pre-trained model, simply run + +.. code-block:: bash + + bash run_scripts/lavis/blip/eval/eval_coco_cap.sh + +Or to evaluate a large model: + +.. code-block:: bash + + bash run_scripts/lavis/blip/eval/eval_coco_cap_large.sh \ No newline at end of file diff --git a/docs/tutorial.models.rst b/docs/tutorial.models.rst new file mode 100644 index 0000000000000000000000000000000000000000..61a4ce0c228e206ed9f01d7df10d20624339de15 --- /dev/null +++ b/docs/tutorial.models.rst @@ -0,0 +1,245 @@ +Adding Models +#################################### + +This is a tutorial on adding new models using ``lavis.models`` module. + +The LAVIS library includes a standard model module that builds the foundation for many major language-vision models such as `ALBEF `_, +`BLIP `_, `ALPRO `_, and `CLIP `_. +The ``lavis.models`` module is designed such that any new models can be added and integrated into the LAVIS library, with minimal steps to develop training and testing procedures. +In this tutorial, we will replicate the steps to add a GPT-style model specifically for `video-grounded dialogue tasks `_. + +Base Model ``lavis.models.base_model`` +************************************************************** + +Note that any new model definition should inherit the base model class ``BaseModel``: + +.. 
code-block:: python + + from omegaconf import OmegaConf + + import numpy as np + + import torch + import torch.nn as nn + + from lavis.common.utils import get_abs_path + + class BaseModel(nn.Module): + """Base class for models.""" + + def __init__(self): + super().__init__() + + def forward_features(self, *args, **kwargs): + """Similar to *forward* but only return features.""" + raise NotImplementedError + + def load_from_pretrained(self, url_or_filename): + raise NotImplementedError + + @classmethod + def _from_config(cls, cfg=None, model_type="base"): + if not cfg: + # useful when building model without a provided configuration file + cfg = OmegaConf.load(cls.default_config_path(model_type)).model + + return cls.from_config(cfg) + + @classmethod + def from_pretrained(cls, model_type="base"): + """ + Build a pretrained model from the default configuration file, specified by model_type. + """ + return cls._from_config(cfg=None, model_type=model_type) + + @property + def device(self): + return list(self.parameters())[0].device + + @classmethod + def default_config_path(cls, model_type="base"): + assert ( + model_type in cls.PRETRAINED_MODEL_CONFIG_DICT + ), "Unknown model type {}".format(model_type) + return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) + + def before_evaluation(self, **kwargs): + pass + + def show_n_params(self, return_str=True): + tot = 0 + for p in self.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return "{:.1f}M".format(tot / 1e6) + else: + return "{:.1f}K".format(tot / 1e3) + else: + return tot + + +In this base model, we already declare and standardize many common methods such as ``_from_config`` and ``_from_pretrained``. +Inheriting this base model class allows us to standardize operations of models across all model classes while still allowing customizations. +We advise users not to change the implementation of the base model class as this will affect all existing model subclasses. + +GPT-style Video-grounded Dialogue Model ``lavis.models.gpt_models.gpt_dialogue`` +******************************************************************************** + +In this step, we can define a new model class, e.g. under ``lavis.models.gpt_models.gpt_dialogue``, for GPT-based dialogue models designed specifically for video-grounded dialogues. +Note that we assume the model class inherits from the standard model super class ``GPT2LMHeadModel`` from the ``transformers`` `library `_. +We also enforce model integration to the LAVIS framework through the inheritance of the ``BaseModel`` from the LAVIS library, as the secondary super class. + +.. code-block:: python + + import torch + from lavis.common.registry import registry + from lavis.models.base_model import BaseModel + + from transformers import GPT2Model, GPT2LMHeadModel + from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + import math + import torch + import torch.nn as nn + from torch.nn import CrossEntropyLoss, MSELoss + + @registry.register_model("gpt_dialogue") + class GPTDialogue(GPT2LMHeadModel, BaseModel): + ... + +Next, we can modify the architecture of the model during model initialization to fit the tasks of interest, i.e. video-grounded dialogues. +In this case, we want to add additional model parameters for a linear network to transform the video feature representations to the model dimension. + +.. 
code-block:: python + + class GPTDialogue(GPT2LMHeadModel, BaseModel): + + def __init__(self, config, len_video_ft=4224): + + super().__init__(config) + + self.video_ff = nn.Linear(len_video_ft, config.n_embd) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + +Note that for each new model class, we advise redefining the ``from_config`` method which is inherited from the ``BaseModel`` class. +As each model usually has its own unique configurations, redefining the method will ensure the model instances are created properly. +For instance, ``GPTDialogue`` requires an additional parameter of video feature length (``len_video_ft``) which should be part of the model initialization procedure. +Another additional parameter is the number of tokens/words (as we include additional special tokens in the vocabulary for dialogue tasks). + +.. code-block:: python + + class GPTDialogue(GPT2LMHeadModel, BaseModel): + ... + @classmethod + def from_config(cls, cfg): + model = cls.from_pretrained('gpt2', len_video_ft=cfg['len_video_ft']) + model.resize_token_embeddings(cfg['len_tokenizer']) + return model + +Other basic methods should also be defined explicitly in the new model class, including the ``forward`` function. +For instance, in GPT models for video-grounded dialogue tasks, we want the forward operation also includes the transformation and integration of video features before passing the representations to the Transformer layers. + +.. code-block:: python + + class GPTDialogue(GPT2LMHeadModel, BaseModel): + ... + + def forward(self, samples, + past_key_values=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None): + + input_embs = self.transformer.wte(samples['input_ids']) + video_embs = self.video_ff(samples['video_fts']) + input_embs = torch.cat([video_embs, input_embs], dim=1) + + transformer_outputs = self.transformer( + attention_mask=samples['attn_mask'], + token_type_ids=samples['token_type_ids'], + inputs_embeds=input_embs, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + ... + +Registering New Model ``lavis.models.__init__`` +******************************************************************************** + +Any new model must be officially registered as part of the ``lavis.models`` module. +For instance, to add a model class for GPT-based dialogue models, we can modify the ``__init__.py`` as follows: + +.. code-block:: python + + from lavis.models.gpt_models.gpt_dialogue import GPTDialogue + + __all__ = [ + ... + "GPTDialogue" + ] + +Assigning Model +******************************************************************************** + +From the above example of a model class, note that we define a ``from_config method`` for the new model class. +This method will process a configuration file and pass specific parameters to initialize the model classes properly. +To do this, we can assign/ associate the correct registry of model classes in a configuration file. +For instance, the following should be specified in a configuration file e.g. 
Assigning Model +******************************************************************************** + +From the above example of a model class, note that we define a ``from_config`` method for the new model class. +This method processes a configuration file and passes specific parameters to initialize the model class properly. +To do this, we specify the corresponding model class registry name in a configuration file. +For instance, the following should be specified in a configuration file, e.g. ``dialogue_avsd_ft.yaml``: + +.. code-block:: yaml + + model: + arch: gpt_dialogue # name of the model + model_type: base + + +Subsequently, any process (e.g. training) should load this configuration file to assign the correct model. + +.. code-block:: sh + + python train.py --cfg-path dialogue_avsd_ft.yaml + +Note that to simplify the model configuration, we only enable two main parameters here: ``arch`` and ``model_type``. ``arch`` refers to the model class registry name, and ``model_type`` is the corresponding model type under this model family. +For instance, for ``gpt_dialogue`` we have a ``base`` model type, which has its own configuration in a separate configuration file, e.g. ``gpt_dialogue_base.yaml``: + +.. code-block:: yaml + + model: + arch: gpt_dialogue + len_tokenizer: 50264 # 50257 tokens from the default gpt2 tokenizer + additional special tokens + len_video_ft: 4224 # i3d_rgb: 2048 i3d_flow: 2048 vggish: 128 + +We can load this configuration and pass its parameters to the above ``from_config`` method to initialize the model accordingly. +We advise users to maintain, in the model class definition, a dictionary that contains the default paths to model configurations. +By default, the LAVIS framework will search for configurations from each model class defined as ``model.PRETRAINED_MODEL_CONFIG_DICT``. + +.. code-block:: python + + class GPTDialogue(GPT2LMHeadModel, BaseModel): + PRETRAINED_MODEL_CONFIG_DICT = { + "base": "configs/models/gpt_dialogue_base.yaml" + } + ... diff --git a/docs/tutorial.processors.rst b/docs/tutorial.processors.rst new file mode 100644 index 0000000000000000000000000000000000000000..14566b5f189de9411d19e2a2bc13045d8b087f83 --- /dev/null +++ b/docs/tutorial.processors.rst @@ -0,0 +1,233 @@ +Adding Processors +################################################ + +This is a tutorial on adding new processors using the ``lavis.processors`` module. + +The LAVIS library includes a standard processor module that preprocesses data, e.g. through image transformation and sequence concatenation. +The ``lavis.processors`` module is designed so that new processors can be added to match the requirements of the corresponding models of interest. +In this tutorial, we will replicate the steps to add visual and textual processors specifically for video-grounded dialogue tasks. +In addition, we want the processors to include features that make the data samples compatible with GPT-style models. + +Base Processor ``lavis.processors.base_processors`` +***************************************************** + +Note that any new processor definition should inherit the base processor class ``BaseProcessor``: + +.. code-block:: python + + from omegaconf import OmegaConf + + class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + return cls() + + def build(self, **kwargs): + cfg = OmegaConf.create(kwargs) + + return self.from_config(cfg) + +This allows us to standardize operations of processors across all processor classes while still allowing customization specific to each data and model type. +We encourage users not to modify the implementation of the base processor class, as this will have an impact on all existing processor subclasses; new behaviour belongs in a subclass instead, as sketched below.
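As an illustration of that subclassing pattern, here is a minimal, hypothetical text processor built on ``BaseProcessor``. The class name, the registry key ``"lowercase_text"``, and the lowercasing/truncation behaviour are invented for this example; it only assumes that ``BaseProcessor`` is importable from ``lavis.processors.base_processor`` and that processors are registered with ``registry.register_processor``, mirroring the built-in LAVIS processors.

.. code-block:: python

    from lavis.common.registry import registry
    from lavis.processors.base_processor import BaseProcessor


    @registry.register_processor("lowercase_text")
    class LowercaseTextProcessor(BaseProcessor):
        """Hypothetical processor: lowercases raw text and truncates it to max_words."""

        def __init__(self, max_words=50):
            self.max_words = max_words

        def __call__(self, item):
            words = str(item).lower().split()
            return " ".join(words[: self.max_words])

        @classmethod
        def from_config(cls, cfg=None):
            if cfg is None:
                return cls()
            return cls(max_words=cfg.get("max_words", 50))

Because it overrides ``from_config``, such a processor can be created either directly or from a configuration block, in the same way as the GPT-style processors described next.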
GPT-style Processors ``lavis.processors.gpt_processors`` +************************************************************** + +In this step, we define new processor classes, e.g. under ``lavis.processors.gpt_processors``, for GPT models designed specifically for video-grounded dialogues. +First, we process video features by defining a ``GPTVideoFeatureProcessor`` class. +In this tutorial, we assume video features are extracted beforehand, and this processor simply loads the features from ``npy`` files. +Other methods that are specifically defined are ``padding`` (used by dataset instances to pad multiple video samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models). + +.. code-block:: python + + SPECIAL_TOKENS_DICT = {'bos_token': "<bos>", 'eos_token': "<eos>", 'additional_special_tokens': ["<speaker1>", "<speaker2>", "