To download and prepare the datasets, please check
+our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
+After the first stage, the visual features are mapped and can be understood by the language
+model.
+To launch the first stage training, run the following command. In our experiments, we use 4 A100 GPUs.
+You can change the save path in the config file
+[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml).
+
+```bash
+torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
+```
+
+A MiniGPT-4 checkpoint with only stage one training can be downloaded
+[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
+Compared to the model after stage two, this checkpoint frequently generates incomplete and repeated sentences.
+
+
+**2. Second finetuning stage**
+
+In the second stage, we use a small, high-quality image-text pair dataset that we created
+and convert it to a conversation format to further align MiniGPT-4.
+To download and prepare our second stage dataset, please check our
+[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
+To launch the second stage alignment,
+first specify the path to the checkpoint file trained in stage 1 in
+[train_configs/minigpt4_stage2_finetune.yaml](train_configs/minigpt4_stage2_finetune.yaml).
+You can also specify the output path there.
+Then, run the following command.

# MiniGPT-V

**MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning**

Jun Chen, Deyao Zhu, Xiaoqian Shen, Xiang Li, Zechun Liu, Pengchuan Zhang, Raghuraman Krishnamoorthi, Vikas Chandra, Yunyang Xiong☨, Mohamed Elhoseiny☨

☨equal last author

**MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models**

Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny

*equal contribution

*King Abdullah University of Science and Technology*

## News
[Oct.13 2023] Breaking! We release the first major update with our MiniGPT-v2
+
+[Aug.28 2023] We now provide a Llama 2 version of MiniGPT-4
+
+## Online Demo
+
+Click the image to chat with MiniGPT-v2 about your images
+[![demo](figs/minigpt2_demo.png)](https://minigpt-v2.github.io/)
+
+Click the image to chat with MiniGPT-4 about your images
+[![demo](figs/online_demo.png)](https://minigpt-4.github.io)
+
+
+## MiniGPT-v2 Examples
+
+![MiniGPT-v2 demos](figs/demo.png)
+
+
+
+## MiniGPT-4 Examples
+  |   |   |
+:-------------------------:|:-------------------------:
+![find wild](figs/examples/wop_2.png) |  ![write story](figs/examples/ad_2.png)
+![solve problem](figs/examples/fix_1.png)  |  ![write Poem](figs/examples/rhyme_1.png)
+
+More examples can be found on the [project page](https://minigpt-4.github.io).
+
+
+
+## Getting Started
+### Installation
+
+**1. Prepare the code and the environment**
+
+Clone our repository, create a Python environment, and activate it with the following commands:
+
+```bash
+git clone https://github.com/Vision-CAIR/MiniGPT-4.git
+cd MiniGPT-4
+conda env create -f environment.yml
+conda activate minigpt4
+```
+
+
+**2. Prepare the pretrained LLM weights**
+
+**MiniGPT-v2** is based on Llama2 Chat 7B. For **MiniGPT-4**, we have both Vicuna V0 and Llama 2 versions.
+Download the corresponding LLM weights from the following Hugging Face space by cloning the repository with git-lfs.
+
+|                                          Llama 2 Chat 7B                                           |                                           Vicuna V0 13B                                           |                                          Vicuna V0 7B                                          |
+:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
+[Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main)
+
+
+Then, set the variable *llama_model* in the model config file to the LLM weight path.
+
+* For MiniGPT-v2, set the LLM path
+[here](minigpt4/configs/models/minigpt_v2.yaml#L15) at Line 14.
+
+* For MiniGPT-4 (Llama2), set the LLM path
+[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
+
+* For MiniGPT-4 (Vicuna), set the LLM path
+[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18.
+
+**3. Prepare the pretrained model checkpoints**
+
+Download the pretrained model checkpoints.
+
+
+| MiniGPT-v2 (LLaMA-2 Chat 7B) |
+|------------------------------|
+| [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
+
+For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file
+in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at Line 8.
+
+
+
+| MiniGPT-4 (Vicuna 13B) | MiniGPT-4 (Vicuna 7B) | MiniGPT-4 (LLaMA-2 Chat 7B) |
+|----------------------------|---------------------------|---------------------------------|
+| [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) |
+
+For **MiniGPT-4**, set the path to the pretrained checkpoint in the evaluation config file
+in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for the Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for the Llama2 version.
+
+
+
+### Launching Demo Locally
+
+For MiniGPT-v2, run
+```
+python demo_v2.py --cfg-path eval_configs/minigptv2_eval.yaml  --gpu-id 0
+```
+
+For MiniGPT-4 (Vicuna version), run
+
+```
+python demo.py --cfg-path eval_configs/minigpt4_eval.yaml  --gpu-id 0
+```
+
+For MiniGPT-4 (Llama2 version), run
+
+```
+python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml  --gpu-id 0
+```
+
+
+To save GPU memory, the LLM is loaded in 8 bit by default, with a beam search width of 1.
+This configuration requires about 23G of GPU memory for the 13B LLM and 11.5G of GPU memory for the 7B LLM.
+For more powerful GPUs, you can run the model
+in 16 bit by setting `low_resource` to `False` in the relevant config file:
+
+* MiniGPT-v2: [minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#6)
+* MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6)
+* MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6)
+
+Thanks to [@WangRongsheng](https://github.com/WangRongsheng), you can also run MiniGPT-4 on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing).
+
+
+### Training
+For training details of MiniGPT-4, check [here](MiniGPT4_Train.md).
+
+
+
+
+## Acknowledgement
+
++ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check out this great open-source work if you haven't seen it before!
++ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
++ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
++ [LLaMA](https://github.com/facebookresearch/llama) The strong open-source LLaMA 2 language model.
+
+
+If you're using MiniGPT-4/MiniGPT-v2 in your research or applications, please cite using this BibTeX:
+```bibtex
+
+@article{Chen2023minigpt,
+      title={MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning},
+      author={Chen, Jun and Zhu, Deyao and Shen, Xiaoqian and Li, Xiang and Liu, Zechun and Zhang, Pengchuan and Krishnamoorthi, Raghuraman and Chandra, Vikas and Xiong, Yunyang and Elhoseiny, Mohamed},
+      journal={github},
+      year={2023}
+}
+
+@article{zhu2023minigpt,
+  title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
+  author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
+  journal={arXiv preprint arXiv:2304.10592},
+  year={2023}
+}
+```
+
+
+## License
+This repository is under [BSD 3-Clause License](LICENSE.md).
+Much of the code is based on [Lavis](https://github.com/salesforce/LAVIS) with
+BSD 3-Clause License [here](LICENSE_Lavis.md).
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000000000000000000000000000000000000..034e848032092eaf8ef96eac731b6ed5961987f3
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,21 @@
+# Security Policy
+
+## Supported Versions
+
+Use this section to tell people about which versions of your project are
+currently being supported with security updates.
+
+| Version | Supported          |
+| ------- | ------------------ |
+| 5.1.x   | :white_check_mark: |
+| 5.0.x   | :x:                |
+| 4.0.x   | :white_check_mark: |
+| < 4.0   | :x:                |
+
+## Reporting a Vulnerability
+
+Use this section to tell people how to report a vulnerability.
+
+Tell them where to go, how often they can expect to get an update on a
+reported vulnerability, what to expect if the vulnerability is accepted or
+declined, etc.
diff --git a/dataset/README_1_STAGE.md b/dataset/README_1_STAGE.md
new file mode 100644
index 0000000000000000000000000000000000000000..47ffaaef6ddf1677bb467e116e03b039febda759
--- /dev/null
+++ b/dataset/README_1_STAGE.md
@@ -0,0 +1,96 @@
+## Download the filtered Conceptual Captions, SBU, LAION datasets
+
+### Pre-training datasets download:
+We use the filtered synthetic captions prepared by BLIP. For more details about the dataset, please refer to [BLIP](https://github.com/salesforce/BLIP).
+
+Storing the LAION and CC3M+CC12M+SBU datasets requires ~2.3T of disk space.
+
+Image source | Filtered synthetic caption by ViT-L
+--- | :---:
+CC3M+CC12M+SBU | Download
+LAION115M | Download
+
+This will download two JSON files:
+```
+ccs_synthetic_filtered_large.json
+laion_synthetic_filtered_large.json
+```
+
+## Prepare the data step-by-step
+
+
+### Set up the dataset folder and move the annotation files to the data storage folder
+```
+export MINIGPT4_DATASET=/YOUR/PATH/FOR/LARGE/DATASET/
+mkdir ${MINIGPT4_DATASET}/cc_sbu
+mkdir ${MINIGPT4_DATASET}/laion
+mv ccs_synthetic_filtered_large.json ${MINIGPT4_DATASET}/cc_sbu
+mv laion_synthetic_filtered_large.json ${MINIGPT4_DATASET}/laion
+```
+
+### Copy the scripts to the data storage folder
+```
+cp convert_cc_sbu.py ${MINIGPT4_DATASET}/cc_sbu
+cp download_cc_sbu.sh ${MINIGPT4_DATASET}/cc_sbu
+cp convert_laion.py ${MINIGPT4_DATASET}/laion
+cp download_laion.sh ${MINIGPT4_DATASET}/laion
+```
+
+
+### Convert the laion and cc_sbu annotation files to the img2dataset format
+```
+cd ${MINIGPT4_DATASET}/cc_sbu
+python convert_cc_sbu.py
+
+cd ${MINIGPT4_DATASET}/laion
+python convert_laion.py
+```
+
+### Download the datasets with img2dataset
+```
+cd ${MINIGPT4_DATASET}/cc_sbu
+sh download_cc_sbu.sh
+cd ${MINIGPT4_DATASET}/laion
+sh download_laion.sh
+```
+
+
+The final dataset structure
+
+```
+.
+├── ${MINIGPT4_DATASET}
+│   ├── cc_sbu
+│       ├── convert_cc_sbu.py
+│       ├── download_cc_sbu.sh
+│       ├── ccs_synthetic_filtered_large.json
+│       ├── ccs_synthetic_filtered_large.tsv
+│       └── cc_sbu_dataset
+│           ├── 00000.tar
+│           ├── 00000.parquet
+│           ...
+│   ├── laion
+│       ├── convert_laion.py
+│       ├── download_laion.sh
+│       ├── laion_synthetic_filtered_large.json
+│       ├── laion_synthetic_filtered_large.tsv
+│       └── laion_dataset
+│           ├── 00000.tar
+│           ├── 00000.parquet
+│           ...
+...
+```
+
+
+## Set up the dataset configuration files
+
+Then, set up the LAION dataset loading path in
+[here](../minigpt4/configs/datasets/laion/defaults.yaml#L5) at Line 5 as
+${MINIGPT4_DATASET}/laion/laion_dataset/{00000..10488}.tar
+
+and the Conceptual Caption and SBU datasets loading path in
+[here](../minigpt4/configs/datasets/cc_sbu/defaults.yaml#L5) at Line 5 as
+${MINIGPT4_DATASET}/cc_sbu/cc_sbu_dataset/{00000..01255}.tar
+
+
+
diff --git a/dataset/README_2_STAGE.md b/dataset/README_2_STAGE.md
new file mode 100644
index 0000000000000000000000000000000000000000..b826765fef6675ab5f2e3a5dfe619d2e351614d3
--- /dev/null
+++ b/dataset/README_2_STAGE.md
@@ -0,0 +1,19 @@
+## Second Stage Data Preparation
+
+Our second stage dataset can be downloaded from
+[here](https://drive.google.com/file/d/1nJXhoEcy3KTExr17I7BXqY5Y9Lx_-n-9/view?usp=share_link).
+After extraction, you will get a data folder with the following structure:
+
+```
+cc_sbu_align
+├── filter_cap.json
+└── image
+    ├── 2.jpg
+    ├── 3.jpg
+    ...
+```
+
+Put the folder in any path you want.
+Then, set up the dataset path in the dataset config file
+[here](../minigpt4/configs/datasets/cc_sbu/align.yaml#L5) at Line 5.
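Before pointing the dataset config at the extracted second stage folder, it can help to confirm that the folder matches the layout described above. The following is a minimal sketch, not part of the repository: the function name `check_align_folder` is hypothetical, and it only checks for `filter_cap.json` and at least one `.jpg` file under `image/`, without making any assumption about the contents of the JSON file.

```python
from pathlib import Path


def check_align_folder(root):
    """Return a list of problems found in an extracted cc_sbu_align folder.

    Hypothetical helper: it checks only the layout shown in the README
    (filter_cap.json next to an image/ folder holding .jpg files).
    """
    root = Path(root)
    if not root.is_dir():
        return [f"{root} is not a directory"]

    problems = []
    # The annotation file sits directly inside the dataset folder.
    if not (root / "filter_cap.json").is_file():
        problems.append("missing filter_cap.json")

    # Images are expected under image/ as numbered .jpg files.
    image_dir = root / "image"
    if not image_dir.is_dir():
        problems.append("missing image/ folder")
    elif not any(image_dir.glob("*.jpg")):
        problems.append("image/ contains no .jpg files")
    return problems
```

Calling `check_align_folder("/your/path/cc_sbu_align")` returns an empty list when the layout looks as expected, and a list of short messages otherwise.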
+ diff --git a/dataset/convert_cc_sbu.py b/dataset/convert_cc_sbu.py new file mode 100644 index 0000000000000000000000000000000000000000..8c325ed3afa3ddb81c5535b5a6febc23d3d5ceee --- /dev/null +++ b/dataset/convert_cc_sbu.py @@ -0,0 +1,20 @@ +import json +import csv + +# specify input and output file paths +input_file = 'ccs_synthetic_filtered_large.json' +output_file = 'ccs_synthetic_filtered_large.tsv' + +# load JSON data from input file +with open(input_file, 'r') as f: + data = json.load(f) + +# extract header and data from JSON +header = data[0].keys() +rows = [x.values() for x in data] + +# write data to TSV file +with open(output_file, 'w') as f: + writer = csv.writer(f, delimiter='\t') + writer.writerow(header) + writer.writerows(rows) diff --git a/dataset/convert_laion.py b/dataset/convert_laion.py new file mode 100644 index 0000000000000000000000000000000000000000..b793579ce276b72a4313bba4f237b8cb0becb294 --- /dev/null +++ b/dataset/convert_laion.py @@ -0,0 +1,20 @@ +import json +import csv + +# specify input and output file paths +input_file = 'laion_synthetic_filtered_large.json' +output_file = 'laion_synthetic_filtered_large.tsv' + +# load JSON data from input file +with open(input_file, 'r') as f: + data = json.load(f) + +# extract header and data from JSON +header = data[0].keys() +rows = [x.values() for x in data] + +# write data to TSV file +with open(output_file, 'w') as f: + writer = csv.writer(f, delimiter='\t') + writer.writerow(header) + writer.writerows(rows) diff --git a/dataset/download_cc_sbu.sh b/dataset/download_cc_sbu.sh new file mode 100644 index 0000000000000000000000000000000000000000..64082eee0466bdad0fb5d377f4501758a82e805c --- /dev/null +++ b/dataset/download_cc_sbu.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +img2dataset --url_list ccs_synthetic_filtered_large.tsv --input_format "tsv"\ + --url_col "url" --caption_col "caption" --output_format webdataset\ + --output_folder cc_sbu_dataset --processes_count 16 --thread_count 128 --image_size 
224 \ + --enable_wandb True diff --git a/dataset/download_laion.sh b/dataset/download_laion.sh new file mode 100644 index 0000000000000000000000000000000000000000..42beb0c9af3535ef55045a1e8a1333d623f540ad --- /dev/null +++ b/dataset/download_laion.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +img2dataset --url_list laion_synthetic_filtered_large.tsv --input_format "tsv"\ + --url_col "url" --caption_col "caption" --output_format webdataset\ + --output_folder laion_dataset --processes_count 16 --thread_count 128 --image_size 224 \ + --enable_wandb True diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..c7646c43b51d59a29d5d6fe872c34c27c14981e5 --- /dev/null +++ b/demo.py @@ -0,0 +1,171 @@ +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import gradio as gr + +from transformers import StoppingCriteriaList + +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank +from minigpt4.common.registry import registry +from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", required=True, help="path to configuration file.") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + 
random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +# ======================================== +# Model Initialization +# ======================================== + +conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0, + 'pretrain_llama2': CONV_VISION_LLama2} + +print('Initializing Chat') +args = parse_args() +cfg = Config(args) + +model_config = cfg.model_cfg +model_config.device_8bit = args.gpu_id +model_cls = registry.get_model_class(model_config.arch) +model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) + +CONV_VISION = conv_dict[model_config.model_type] + +vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train +vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + +stop_words_ids = [[835], [2277, 29937]] +stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids] +stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + +chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria) +print('Initialization Finished') + + +# ======================================== +# Gradio Setting +# ======================================== + + +def gradio_reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list + + +def upload_img(gr_img, text_input, chat_state): + if gr_img is None: + return None, None, gr.update(interactive=True), chat_state, None + chat_state = CONV_VISION.copy() + img_list = [] + llm_message = chat.upload_img(gr_img, chat_state, img_list) + chat.encode_img(img_list) + return 
gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list + + +def gradio_ask(user_message, chatbot, chat_state): + if len(user_message) == 0: + return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state + chat.ask(user_message, chat_state) + chatbot = chatbot + [[user_message, None]] + return '', chatbot, chat_state + + +def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): + llm_message = chat.answer(conv=chat_state, + img_list=img_list, + num_beams=num_beams, + temperature=temperature, + max_new_tokens=300, + max_length=2000)[0] + chatbot[-1][1] = llm_message + return chatbot, chat_state, img_list + + +title = """

Demo of MiniGPT-4

""" +description = """

This is the demo of MiniGPT-4. Upload your images and start chatting!

""" +article = """

+""" + +#TODO show examples below + +with gr.Blocks() as demo: + gr.Markdown(title) + gr.Markdown(description) + gr.Markdown(article) + + with gr.Row(): + with gr.Column(scale=1): + image = gr.Image(type="pil") + upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") + clear = gr.Button("Restart") + + num_beams = gr.Slider( + minimum=1, + maximum=10, + value=1, + step=1, + interactive=True, + label="beam search numbers)", + ) + + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + with gr.Column(scale=2): + chat_state = gr.State() + img_list = gr.State() + chatbot = gr.Chatbot(label='MiniGPT-4') + text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False) + + upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list]) + + text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then( + gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] + ) + clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False) + +demo.launch(share=True, enable_queue=True) diff --git a/demo_v2.py b/demo_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4f66d533e58f822ac5dddd84f5407823fe641c0f --- /dev/null +++ b/demo_v2.py @@ -0,0 +1,662 @@ +import argparse +import os +import random +from collections import defaultdict + +import cv2 +import re + +import numpy as np +from PIL import Image +import torch +import html +import gradio as gr + +import torchvision.transforms as T +import torch.backends.cudnn as cudnn + +from minigpt4.common.config import Config + +from minigpt4.common.registry import registry +from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat + 
+# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml', + help="path to configuration file.") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +random.seed(42) +np.random.seed(42) +torch.manual_seed(42) + +cudnn.benchmark = False +cudnn.deterministic = True + +print('Initializing Chat') +args = parse_args() +cfg = Config(args) + +device = 'cuda:{}'.format(args.gpu_id) + +model_config = cfg.model_cfg +model_config.device_8bit = args.gpu_id +model_cls = registry.get_model_class(model_config.arch) +model = model_cls.from_config(model_config).to(device) +bounding_box_size = 100 + +vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train +vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + +model = model.eval() + +CONV_VISION = Conversation( + system="", + roles=(r"[INST] ", r" [/INST]"), + messages=[], + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="", +) + + +def extract_substrings(string): + # first check if there is no-finished bracket + index = string.rfind('}') + if index != -1: + string = string[:index + 1] + + pattern = r'

(.*?)\}(?!<)' + matches = re.findall(pattern, string) + substrings = [match for match in matches] + + return substrings + + +def is_overlapping(rect1, rect2): + x1, y1, x2, y2 = rect1 + x3, y3, x4, y4 = rect2 + return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4) + + +def computeIoU(bbox1, bbox2): + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + intersection_x1 = max(x1, x3) + intersection_y1 = max(y1, y3) + intersection_x2 = min(x2, x4) + intersection_y2 = min(y2, y4) + intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1) + bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1) + bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1) + union_area = bbox1_area + bbox2_area - intersection_area + iou = intersection_area / union_area + return iou + + +def save_tmp_img(visual_img): + file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg" + file_path = "/tmp/" + file_name + visual_img.save(file_path) + return file_path + + +def mask2bbox(mask): + if mask is None: + return '' + mask = mask.resize([100, 100], resample=Image.NEAREST) + mask = np.array(mask)[:, :, 0] + + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + if rows.sum(): + # Get the top, bottom, left, and right boundaries + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax) + else: + bbox = '' + + return bbox + + +def escape_markdown(text): + # List of Markdown special characters that need to be escaped + md_chars = ['<', '>'] + + # Escape each special character + for char in md_chars: + text = text.replace(char, '\\' + char) + + return text + + +def reverse_escape(text): + md_chars = ['\\<', '\\>'] + + for char in md_chars: + text = text.replace(char, char[1:]) + + return text + + +colors = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (210, 210, 0), + (255, 0, 255), + (0, 255, 255), + (114, 128, 250), + (0, 165, 255), + (0, 128, 0), + (144, 
238, 144), + (238, 238, 175), + (255, 191, 0), + (0, 128, 0), + (226, 43, 138), + (255, 0, 255), + (0, 215, 255), +] + +color_map = { + f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for + color_id, color in enumerate(colors) +} + +used_colors = colors + + +def visualize_all_bbox_together(image, generation): + if image is None: + return None, '' + + generation = html.unescape(generation) + print('gen begin', generation) + + image_width, image_height = image.size + image = image.resize([500, int(500 / image_width * image_height)]) + image_width, image_height = image.size + + string_list = extract_substrings(generation) + if string_list: # it is grounding or detection + mode = 'all' + entities = defaultdict(list) + i = 0 + j = 0 + for string in string_list: + try: + obj, string = string.split('

') + except ValueError: + print('wrong string: ', string) + continue + bbox_list = string.split('') + flag = False + for bbox_string in bbox_list: + integers = re.findall(r'-?\d+', bbox_string) + if len(integers) == 4: + x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3]) + left = x0 / bounding_box_size * image_width + bottom = y0 / bounding_box_size * image_height + right = x1 / bounding_box_size * image_width + top = y1 / bounding_box_size * image_height + + entities[obj].append([left, bottom, right, top]) + + j += 1 + flag = True + if flag: + i += 1 + else: + integers = re.findall(r'-?\d+', generation) + + if len(integers) == 4: # it is refer + mode = 'single' + + entities = list() + x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3]) + left = x0 / bounding_box_size * image_width + bottom = y0 / bounding_box_size * image_height + right = x1 / bounding_box_size * image_width + top = y1 / bounding_box_size * image_height + entities.append([left, bottom, right, top]) + else: + # don't detect any valid bbox to visualize + return None, '' + + if len(entities) == 0: + return None, '' + + if isinstance(image, Image.Image): + image_h = image.height + image_w = image.width + image = np.array(image) + + elif isinstance(image, str): + if os.path.exists(image): + pil_img = Image.open(image).convert("RGB") + image = np.array(pil_img)[:, :, [2, 1, 0]] + image_h = pil_img.height + image_w = pil_img.width + else: + raise ValueError(f"invaild image path, {image}") + elif isinstance(image, torch.Tensor): + + image_tensor = image.cpu() + reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None] + reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None] + image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean + pil_img = T.ToPILImage()(image_tensor) + image_h = pil_img.height + image_w = pil_img.width + image = np.array(pil_img)[:, :, [2, 1, 0]] 
+ else: + raise ValueError(f"invaild image format, {type(image)} for {image}") + + indices = list(range(len(entities))) + + new_image = image.copy() + + previous_bboxes = [] + # size of text + text_size = 0.5 + # thickness of text + text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1)) + box_line = 2 + (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line) + base_height = int(text_height * 0.675) + text_offset_original = text_height - base_height + text_spaces = 2 + + # num_bboxes = sum(len(x[-1]) for x in entities) + used_colors = colors # random.sample(colors, k=num_bboxes) + + color_id = -1 + for entity_idx, entity_name in enumerate(entities): + if mode == 'single' or mode == 'identify': + bboxes = entity_name + bboxes = [bboxes] + else: + bboxes = entities[entity_name] + color_id += 1 + for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes): + skip_flag = False + orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm) + + color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist()) + new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line) + + if mode == 'all': + l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1 + + x1 = orig_x1 - l_o + y1 = orig_y1 - l_o + + if y1 < text_height + text_offset_original + 2 * text_spaces: + y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces + x1 = orig_x1 + r_o + + # add text background + (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, + text_line) + text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - ( + text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1 + + for prev_bbox in previous_bboxes: + if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \ + prev_bbox['phrase'] == 
entity_name: + skip_flag = True + break + while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']): + text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces) + text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces) + y1 += (text_height + text_offset_original + 2 * text_spaces) + + if text_bg_y2 >= image_h: + text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces)) + text_bg_y2 = image_h + y1 = image_h + break + if not skip_flag: + alpha = 0.5 + for i in range(text_bg_y1, text_bg_y2): + for j in range(text_bg_x1, text_bg_x2): + if i < image_h and j < image_w: + if j < text_bg_x1 + 1.35 * c_width: + # original color + bg_color = color + else: + # white + bg_color = [255, 255, 255] + new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype( + np.uint8) + + cv2.putText( + new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), + cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA + ) + + previous_bboxes.append( + {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name}) + + if mode == 'all': + def color_iterator(colors): + while True: + for color in colors: + yield color + + color_gen = color_iterator(colors) + + # Add colors to phrases and remove

+ def colored_phrases(match): + phrase = match.group(1) + color = next(color_gen) + return f'<span style="color:rgb{color}">{phrase}</span>' + + print('gen before', generation) + generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation) + print('gen after', generation) + generation_colored = re.sub(r'<p>(.*?)</p>
', colored_phrases, generation) + else: + generation_colored = '' + + pil_image = Image.fromarray(new_image) + return pil_image, generation_colored + + +def gradio_reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat', + interactive=True), chat_state, img_list + + +def image_upload_trigger(upload_flag, replace_flag, img_list): + # set the upload flag to true when receive a new image. + # if there is an old image (and old conversation), set the replace flag to true to reset the conv later. + print('flag', upload_flag, replace_flag) + print("SET UPLOAD FLAG!") + upload_flag = 1 + if img_list: + print("SET REPLACE FLAG!") + replace_flag = 1 + print('flag', upload_flag, replace_flag) + return upload_flag, replace_flag + + +def example_trigger(text_input, image, upload_flag, replace_flag, img_list): + # set the upload flag to true when receive a new image. + # if there is an old image (and old conversation), set the replace flag to true to reset the conv later. 
+ print('flag', upload_flag, replace_flag) + print("SET UPLOAD FLAG!") + upload_flag = 1 + if img_list or replace_flag == 1: + print("SET REPLACE FLAG!") + replace_flag = 1 + + print('flag', upload_flag, replace_flag) + return upload_flag, replace_flag + + +def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag): + if isinstance(gr_img, dict): + gr_img, mask = gr_img['image'], gr_img['mask'] + else: + mask = None + + if '[identify]' in user_message: + # check if user provide bbox in the text input + integers = re.findall(r'-?\d+', user_message) + if len(integers) != 4: # no bbox in text + bbox = mask2bbox(mask) + user_message = user_message + bbox + + if len(user_message) == 0: + return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state + + if chat_state is None: + chat_state = CONV_VISION.copy() + + print('upload flag: {}'.format(upload_flag)) + if upload_flag: + if replace_flag: + print('RESET!!!!!!!') + chat_state = CONV_VISION.copy() # new image, reset everything + replace_flag = 0 + chatbot = [] + print('UPLOAD IMAGE!!') + img_list = [] + llm_message = chat.upload_img(gr_img, chat_state, img_list) + upload_flag = 0 + + chat.ask(user_message, chat_state) + + chatbot = chatbot + [[user_message, None]] + + if '[identify]' in user_message: + visual_img, _ = visualize_all_bbox_together(gr_img, user_message) + if visual_img is not None: + print('Visualizing the input') + file_path = save_tmp_img(visual_img) + chatbot = chatbot + [[(file_path,), None]] + + return '', chatbot, chat_state, img_list, upload_flag, replace_flag + + +def gradio_answer(chatbot, chat_state, img_list, temperature): + llm_message = chat.answer(conv=chat_state, + img_list=img_list, + temperature=temperature, + max_new_tokens=500, + max_length=2000)[0] + chatbot[-1][1] = llm_message + return chatbot, chat_state + + +def gradio_stream_answer(chatbot, chat_state, img_list, temperature): + print('chat state', 
chat_state.get_prompt()) + if not isinstance(img_list[0], torch.Tensor): + chat.encode_img(img_list) + streamer = chat.stream_answer(conv=chat_state, + img_list=img_list, + temperature=temperature, + max_new_tokens=500, + max_length=2000) + output = '' + for new_output in streamer: + escapped = escape_markdown(new_output) + output += escapped + chatbot[-1][1] = output + yield chatbot, chat_state + # print('message: ', chat_state.messages) + chat_state.messages[-1][1] = '
</s>' + return chatbot, chat_state + + +def gradio_visualize(chatbot, gr_img): + if isinstance(gr_img, dict): + gr_img, mask = gr_img['image'], gr_img['mask'] + + unescaped = reverse_escape(chatbot[-1][1]) + visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped) + if visual_img is not None: + print('Visualizing the output') + if len(generation_color): + chatbot[-1][1] = generation_color + file_path = save_tmp_img(visual_img) + chatbot = chatbot + [[None, (file_path,)]] + + return chatbot + + +def gradio_taskselect(idx): + prompt_list = [ + '', + '[grounding] describe this image in detail', + '[refer] ', + '[detection] ', + '[identify] what is this ', + '[vqa] ' + ] + instruct_list = [ + '**Hint:** Type in whatever you want', + '**Hint:** Send the command to generate a grounded image description', + '**Hint:** Type in a phrase about an object in the image and send the command', + '**Hint:** Type in a caption or phrase, and see object locations in the image', + '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button at the top right of the image before redrawing', + '**Hint:** Send a question to get a short answer', + ] + return prompt_list[idx], instruct_list[idx] + + + + +chat = Chat(model, vis_processor, device=device) + +title = """

MiniGPT-v2 Demo

""" +description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!' +# article = """

""" +article = """

""" + +introduction = ''' +For Abilities Involving Visual Grounding: +1. Grounding: CLICK **Send** to generate a grounded image description. +2. Refer: Input a referring object and CLICK **Send**. +3. Detection: Write a caption or phrase, and CLICK **Send**. +4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time). +5. VQA: Input a visual question and CLICK **Send**. +6. No Tag: Input whatever you want and CLICK **Send** without any tagging + +You can also simply chat in free form! +''' + +text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False, + scale=8) +with gr.Blocks() as demo: + gr.Markdown(title) + # gr.Markdown(description) + gr.Markdown(article) + + with gr.Row(): + with gr.Column(scale=0.5): + image = gr.Image(type="pil", tool='sketch', brush_radius=20) + + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + clear = gr.Button("Restart") + + gr.Markdown(introduction) + + with gr.Column(): + chat_state = gr.State(value=None) + img_list = gr.State(value=[]) + chatbot = gr.Chatbot(label='MiniGPT-v2') + + dataset = gr.Dataset( + components=[gr.Textbox(visible=False)], + samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']], + type="index", + label='Task Shortcuts', + ) + task_inst = gr.Markdown('**Hint:** Upload your image and chat') + with gr.Row(): + text_input.render() + send = gr.Button("Send", variant='primary', size='sm', scale=1) + + upload_flag = gr.State(value=0) + replace_flag = gr.State(value=0) + image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag]) + + with gr.Row(): + with gr.Column(): + gr.Examples(examples=[ + ["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag, + img_list], + 
["examples_v2/sofa.jpg", "[detection] sofas", upload_flag, replace_flag, img_list], + ["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag, + img_list], + ["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag, + replace_flag, img_list], + ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger, + outputs=[upload_flag, replace_flag]) + with gr.Column(): + gr.Examples(examples=[ + ["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek", + upload_flag, replace_flag, img_list], + ["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list], + ["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list], + ["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag, + replace_flag, img_list], + ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger, + outputs=[upload_flag, replace_flag]) + + dataset.click( + gradio_taskselect, + inputs=[dataset], + outputs=[text_input, task_inst], + show_progress="hidden", + postprocess=False, + queue=False, + ) + + text_input.submit( + gradio_ask, + [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag], + [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False + ).success( + gradio_stream_answer, + [chatbot, chat_state, img_list, temperature], + [chatbot, chat_state] + ).success( + gradio_visualize, + [chatbot, image], + [chatbot], + queue=False, + ) + + send.click( + gradio_ask, + [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag], + [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False + ).success( + gradio_stream_answer, + [chatbot, chat_state, img_list, temperature], + [chatbot, chat_state] + ).success( + gradio_visualize, + 
[chatbot, image], + [chatbot], + queue=False, + ) + + clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False) + +demo.launch(share=True, enable_queue=True) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..cf90e898375a8927fd1e6e3fa9b2605b1606c176 --- /dev/null +++ b/environment.yml @@ -0,0 +1,33 @@ +name: minigpt4 +channels: + - pytorch + - defaults + - anaconda +dependencies: + - python=3.9 + - cudatoolkit + - pip + - pip: + - torch==2.0.0 + - torchaudio + - torchvision + - huggingface-hub==0.18.0 + - matplotlib==3.7.0 + - psutil==5.9.4 + - iopath + - pyyaml==6.0 + - regex==2022.10.31 + - tokenizers==0.13.2 + - tqdm==4.64.1 + - transformers==4.30.0 + - timm==0.6.13 + - webdataset==0.2.48 + - omegaconf==2.3.0 + - opencv-python== + - decord==0.6.0 + - peft==0.2.0 + - sentence-transformers + - gradio==3.47.1 + - accelerate==0.20.3 + - bitsandbytes==0.37.0 + - wandb diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d73ba3aacb5a86fc31eb02949303a37e362103d9 --- /dev/null +++ b/eval_configs/minigpt4_eval.yaml @@ -0,0 +1,22 @@ +model: + arch: minigpt4 + model_type: pretrain_vicuna0 + max_txt_len: 160 + end_sym: "###" + low_resource: True + prompt_template: '###Human: {} ###Assistant: ' + ckpt: 'please set this value to the path of pretrained checkpoint' + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93efab1339aaa65057c0b7ea81b2b05e0b6f84de --- /dev/null +++ b/eval_configs/minigpt4_llama2_eval.yaml @@ -0,0 +1,22 @@ +model: + arch: minigpt4 + model_type: 
pretrain_llama2 + max_txt_len: 160 + end_sym: "" + low_resource: True + prompt_template: '[INST] {} [/INST] ' + ckpt: 'please set this value to the path of pretrained checkpoint' + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/eval_configs/minigptv2_eval.yaml b/eval_configs/minigptv2_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0479f2a7f35ed08113beb888bb7aef8f16ece257 --- /dev/null +++ b/eval_configs/minigptv2_eval.yaml @@ -0,0 +1,24 @@ +model: + arch: minigpt_v2 + model_type: pretrain + max_txt_len: 160 + end_sym: "" + low_resource: True + prompt_template: '[INST] {} [/INST]' + ckpt: 'please set this value to the path of pretrained checkpoint' + lora_r: 64 + lora_alpha: 16 + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 448 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/examples/ad_1.png b/examples/ad_1.png new file mode 100644 index 0000000000000000000000000000000000000000..d0378e43e9a0e797b2ab32f4d8f6261fa2224408 Binary files /dev/null and b/examples/ad_1.png differ diff --git a/examples/ad_2.png b/examples/ad_2.png new file mode 100644 index 0000000000000000000000000000000000000000..674248b723bee885c43a55d85d83ea1c0fa41477 Binary files /dev/null and b/examples/ad_2.png differ diff --git a/examples/cook_1.png b/examples/cook_1.png new file mode 100644 index 0000000000000000000000000000000000000000..d8cdb45c98492afd4f975b8626bb590b580616a5 Binary files /dev/null and b/examples/cook_1.png differ diff --git a/examples/cook_2.png b/examples/cook_2.png new file mode 100644 index 0000000000000000000000000000000000000000..d08272b3733dda976bfa78733d9ca4eb544fee52 Binary files /dev/null and b/examples/cook_2.png differ diff --git a/examples/describe_1.png b/examples/describe_1.png new 
file mode 100644 index 0000000000000000000000000000000000000000..02f3c92f54749fa354a5f8c617f24301728555b2 Binary files /dev/null and b/examples/describe_1.png differ diff --git a/examples/describe_2.png b/examples/describe_2.png new file mode 100644 index 0000000000000000000000000000000000000000..20bf8c7cd86c03f9ed77f95d912057438997277d Binary files /dev/null and b/examples/describe_2.png differ diff --git a/examples/fact_1.png b/examples/fact_1.png new file mode 100644 index 0000000000000000000000000000000000000000..1f7522871916c3d1bd113cfb88c45386d8abda7a Binary files /dev/null and b/examples/fact_1.png differ diff --git a/examples/fact_2.png b/examples/fact_2.png new file mode 100644 index 0000000000000000000000000000000000000000..de6ef53ef7afd72894711a3a5288a24d62c39182 Binary files /dev/null and b/examples/fact_2.png differ diff --git a/examples/fix_1.png b/examples/fix_1.png new file mode 100644 index 0000000000000000000000000000000000000000..023cfe6610747868805c70001e14f2b408f3cebb Binary files /dev/null and b/examples/fix_1.png differ diff --git a/examples/fix_2.png b/examples/fix_2.png new file mode 100644 index 0000000000000000000000000000000000000000..f60da5ff9bdef7018e98a92b19c0e59d31acd059 Binary files /dev/null and b/examples/fix_2.png differ diff --git a/examples/fun_1.png b/examples/fun_1.png new file mode 100644 index 0000000000000000000000000000000000000000..f720ea603f88019e24dbcb569328c3083c832baf Binary files /dev/null and b/examples/fun_1.png differ diff --git a/examples/fun_2.png b/examples/fun_2.png new file mode 100644 index 0000000000000000000000000000000000000000..1d37a8068feda3f7ecc2d0b22893d26071e68b64 Binary files /dev/null and b/examples/fun_2.png differ diff --git a/examples/logo_1.png b/examples/logo_1.png new file mode 100644 index 0000000000000000000000000000000000000000..8bbe438bdc05ce023251045575c6b7e7b04f210f Binary files /dev/null and b/examples/logo_1.png differ diff --git a/examples/op_1.png b/examples/op_1.png new file mode 
100644 index 0000000000000000000000000000000000000000..3dbb2ff51ca08f62171f48167bbb97ad604cc4d0 Binary files /dev/null and b/examples/op_1.png differ diff --git a/examples/op_2.png b/examples/op_2.png new file mode 100644 index 0000000000000000000000000000000000000000..2cd3e1f8b0326dea14d45bf866b244deb38ef409 Binary files /dev/null and b/examples/op_2.png differ diff --git a/examples/people_1.png b/examples/people_1.png new file mode 100644 index 0000000000000000000000000000000000000000..7e95c42c710aef5efe94a52da280fd7451f185d7 Binary files /dev/null and b/examples/people_1.png differ diff --git a/examples/people_2.png b/examples/people_2.png new file mode 100644 index 0000000000000000000000000000000000000000..aec6c83b217c96af91668dbc566e05a14238b2a8 Binary files /dev/null and b/examples/people_2.png differ diff --git a/examples/rhyme_1.png b/examples/rhyme_1.png new file mode 100644 index 0000000000000000000000000000000000000000..7d133878d8b534867253c7be7b2805faffbd6ad7 Binary files /dev/null and b/examples/rhyme_1.png differ diff --git a/examples/rhyme_2.png b/examples/rhyme_2.png new file mode 100644 index 0000000000000000000000000000000000000000..6cf9bf8958302e461dec8b58dd7cbbe2224a8e5c Binary files /dev/null and b/examples/rhyme_2.png differ diff --git a/examples/story_1.png b/examples/story_1.png new file mode 100644 index 0000000000000000000000000000000000000000..3eb6ccb93fb5c866eeb758ba962904e5c3d57875 Binary files /dev/null and b/examples/story_1.png differ diff --git a/examples/story_2.png b/examples/story_2.png new file mode 100644 index 0000000000000000000000000000000000000000..9d37142a9ae32f20d47ebd10cf1c395f10b363f7 Binary files /dev/null and b/examples/story_2.png differ diff --git a/examples/web_1.png b/examples/web_1.png new file mode 100644 index 0000000000000000000000000000000000000000..8943842c08609713b78d95ab3b5c418995569505 Binary files /dev/null and b/examples/web_1.png differ diff --git a/examples/wop_1.png b/examples/wop_1.png new file mode 
100644 index 0000000000000000000000000000000000000000..88f37d672bb2dd3dac34caced2ed4bebcfe15412 Binary files /dev/null and b/examples/wop_1.png differ diff --git a/examples/wop_2.png b/examples/wop_2.png new file mode 100644 index 0000000000000000000000000000000000000000..8255974176014db0b388617821630bdb438b5e6b Binary files /dev/null and b/examples/wop_2.png differ diff --git a/examples_v2/2000x1372_wmkn_0012149409555.jpg b/examples_v2/2000x1372_wmkn_0012149409555.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1250f7fa5e84e9301bd33f59a9626b904cc21a12 Binary files /dev/null and b/examples_v2/2000x1372_wmkn_0012149409555.jpg differ diff --git a/examples_v2/KFC-20-for-20-Nuggets.jpg b/examples_v2/KFC-20-for-20-Nuggets.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0ec641c2306645e1f6a2bb38c7e62cabcb295808 Binary files /dev/null and b/examples_v2/KFC-20-for-20-Nuggets.jpg differ diff --git a/examples_v2/cockdial.png b/examples_v2/cockdial.png new file mode 100644 index 0000000000000000000000000000000000000000..32aae7a8461fa5aca01f7ae9701706c136f57a89 --- /dev/null +++ b/examples_v2/cockdial.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48e6fcd1994b733174bb2484038a6eba18c36922686e9bffaaa6216ac704ea6e +size 1528183 diff --git a/examples_v2/float.png b/examples_v2/float.png new file mode 100644 index 0000000000000000000000000000000000000000..690801a47fef974c476e3180ffa298c9d6dda55c --- /dev/null +++ b/examples_v2/float.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee6365239cec6f1cceb156273ba30b43295bf92eef9b3e44f854eec335fa0646 +size 1248467 diff --git a/examples_v2/glip_test.jpg b/examples_v2/glip_test.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f9198f2733daf7d93488a2ae5574c1011c889c31 Binary files /dev/null and b/examples_v2/glip_test.jpg differ diff --git a/examples_v2/office.jpg b/examples_v2/office.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..e35bdc2e0091f8df6dd9c3be1ca0c926954c5757 Binary files /dev/null and b/examples_v2/office.jpg differ diff --git a/examples_v2/sofa.jpg b/examples_v2/sofa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..861059151b6baeca0369be0925c14a029fb3dd8c Binary files /dev/null and b/examples_v2/sofa.jpg differ diff --git a/examples_v2/thief.png b/examples_v2/thief.png new file mode 100644 index 0000000000000000000000000000000000000000..579ee5218d5bc7a403671378adccf500c05356e8 Binary files /dev/null and b/examples_v2/thief.png differ diff --git a/figs/demo.png b/figs/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..67b58774b720136f1491bd2ac39fd3aabebb9856 --- /dev/null +++ b/figs/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4b53cddf459e45298a2a682fa770126958f0ccb17a2acaf71c9f7b00dc40033 +size 1103091 diff --git a/figs/examples/ad_1.png b/figs/examples/ad_1.png new file mode 100644 index 0000000000000000000000000000000000000000..d0378e43e9a0e797b2ab32f4d8f6261fa2224408 Binary files /dev/null and b/figs/examples/ad_1.png differ diff --git a/figs/examples/ad_2.png b/figs/examples/ad_2.png new file mode 100644 index 0000000000000000000000000000000000000000..674248b723bee885c43a55d85d83ea1c0fa41477 Binary files /dev/null and b/figs/examples/ad_2.png differ diff --git a/figs/examples/cook_1.png b/figs/examples/cook_1.png new file mode 100644 index 0000000000000000000000000000000000000000..d8cdb45c98492afd4f975b8626bb590b580616a5 Binary files /dev/null and b/figs/examples/cook_1.png differ diff --git a/figs/examples/cook_2.png b/figs/examples/cook_2.png new file mode 100644 index 0000000000000000000000000000000000000000..d08272b3733dda976bfa78733d9ca4eb544fee52 Binary files /dev/null and b/figs/examples/cook_2.png differ diff --git a/figs/examples/describe_1.png b/figs/examples/describe_1.png new file mode 100644 index 
0000000000000000000000000000000000000000..02f3c92f54749fa354a5f8c617f24301728555b2 Binary files /dev/null and b/figs/examples/describe_1.png differ diff --git a/figs/examples/describe_2.png b/figs/examples/describe_2.png new file mode 100644 index 0000000000000000000000000000000000000000..20bf8c7cd86c03f9ed77f95d912057438997277d Binary files /dev/null and b/figs/examples/describe_2.png differ diff --git a/figs/examples/fact_1.png b/figs/examples/fact_1.png new file mode 100644 index 0000000000000000000000000000000000000000..1f7522871916c3d1bd113cfb88c45386d8abda7a Binary files /dev/null and b/figs/examples/fact_1.png differ diff --git a/figs/examples/fact_2.png b/figs/examples/fact_2.png new file mode 100644 index 0000000000000000000000000000000000000000..de6ef53ef7afd72894711a3a5288a24d62c39182 Binary files /dev/null and b/figs/examples/fact_2.png differ diff --git a/figs/examples/fix_1.png b/figs/examples/fix_1.png new file mode 100644 index 0000000000000000000000000000000000000000..023cfe6610747868805c70001e14f2b408f3cebb Binary files /dev/null and b/figs/examples/fix_1.png differ diff --git a/figs/examples/fix_2.png b/figs/examples/fix_2.png new file mode 100644 index 0000000000000000000000000000000000000000..f60da5ff9bdef7018e98a92b19c0e59d31acd059 Binary files /dev/null and b/figs/examples/fix_2.png differ diff --git a/figs/examples/fun_1.png b/figs/examples/fun_1.png new file mode 100644 index 0000000000000000000000000000000000000000..f720ea603f88019e24dbcb569328c3083c832baf Binary files /dev/null and b/figs/examples/fun_1.png differ diff --git a/figs/examples/fun_2.png b/figs/examples/fun_2.png new file mode 100644 index 0000000000000000000000000000000000000000..1d37a8068feda3f7ecc2d0b22893d26071e68b64 Binary files /dev/null and b/figs/examples/fun_2.png differ diff --git a/figs/examples/logo_1.png b/figs/examples/logo_1.png new file mode 100644 index 0000000000000000000000000000000000000000..8bbe438bdc05ce023251045575c6b7e7b04f210f Binary files /dev/null 
and b/figs/examples/logo_1.png differ diff --git a/figs/examples/op_1.png b/figs/examples/op_1.png new file mode 100644 index 0000000000000000000000000000000000000000..3dbb2ff51ca08f62171f48167bbb97ad604cc4d0 Binary files /dev/null and b/figs/examples/op_1.png differ diff --git a/figs/examples/op_2.png b/figs/examples/op_2.png new file mode 100644 index 0000000000000000000000000000000000000000..2cd3e1f8b0326dea14d45bf866b244deb38ef409 Binary files /dev/null and b/figs/examples/op_2.png differ diff --git a/figs/examples/people_1.png b/figs/examples/people_1.png new file mode 100644 index 0000000000000000000000000000000000000000..7e95c42c710aef5efe94a52da280fd7451f185d7 Binary files /dev/null and b/figs/examples/people_1.png differ diff --git a/figs/examples/people_2.png b/figs/examples/people_2.png new file mode 100644 index 0000000000000000000000000000000000000000..aec6c83b217c96af91668dbc566e05a14238b2a8 Binary files /dev/null and b/figs/examples/people_2.png differ diff --git a/figs/examples/rhyme_1.png b/figs/examples/rhyme_1.png new file mode 100644 index 0000000000000000000000000000000000000000..7d133878d8b534867253c7be7b2805faffbd6ad7 Binary files /dev/null and b/figs/examples/rhyme_1.png differ diff --git a/figs/examples/rhyme_2.png b/figs/examples/rhyme_2.png new file mode 100644 index 0000000000000000000000000000000000000000..6cf9bf8958302e461dec8b58dd7cbbe2224a8e5c Binary files /dev/null and b/figs/examples/rhyme_2.png differ diff --git a/figs/examples/story_1.png b/figs/examples/story_1.png new file mode 100644 index 0000000000000000000000000000000000000000..3eb6ccb93fb5c866eeb758ba962904e5c3d57875 Binary files /dev/null and b/figs/examples/story_1.png differ diff --git a/figs/examples/story_2.png b/figs/examples/story_2.png new file mode 100644 index 0000000000000000000000000000000000000000..9d37142a9ae32f20d47ebd10cf1c395f10b363f7 Binary files /dev/null and b/figs/examples/story_2.png differ diff --git a/figs/examples/web_1.png 
b/figs/examples/web_1.png new file mode 100644 index 0000000000000000000000000000000000000000..8943842c08609713b78d95ab3b5c418995569505 Binary files /dev/null and b/figs/examples/web_1.png differ diff --git a/figs/examples/wop_1.png b/figs/examples/wop_1.png new file mode 100644 index 0000000000000000000000000000000000000000..88f37d672bb2dd3dac34caced2ed4bebcfe15412 Binary files /dev/null and b/figs/examples/wop_1.png differ diff --git a/figs/examples/wop_2.png b/figs/examples/wop_2.png new file mode 100644 index 0000000000000000000000000000000000000000..8255974176014db0b388617821630bdb438b5e6b Binary files /dev/null and b/figs/examples/wop_2.png differ diff --git a/figs/minigpt2_demo.png b/figs/minigpt2_demo.png new file mode 100644 index 0000000000000000000000000000000000000000..4202fed0763da8fba875ce0493d4f086b1dbbf7f --- /dev/null +++ b/figs/minigpt2_demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f01d120ab3599c184e9f471e1a5052fc7b4f058951b684a108b660bb74bc1dc +size 1152844 diff --git a/figs/online_demo.png b/figs/online_demo.png new file mode 100644 index 0000000000000000000000000000000000000000..cc3dde7a4ac01cf1bcb8096de63c0ec583070846 --- /dev/null +++ b/figs/online_demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:236107e23897574578111e0a4bbe6c0475e794b5eca13943f692755b1c71c8df +size 1259978 diff --git a/figs/overview.png b/figs/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..83ac7cc1657d98262d34459268421713074b4a96 --- /dev/null +++ b/figs/overview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:378307bed387d530087773826ac3b4686e106cf2d1e4a24daed27573f8cc14cb +size 2532475 diff --git a/minigpt4/__init__.py b/minigpt4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb31f42f9107a0b748b878deb1c5768019d62b32 --- /dev/null +++ b/minigpt4/__init__.py @@ -0,0 +1,31 @@ +""" + Copyright (c) 2022, 
salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import sys + +from omegaconf import OmegaConf + +from minigpt4.common.registry import registry + +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.tasks import * + + +root_dir = os.path.dirname(os.path.abspath(__file__)) +default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) + +registry.register_path("library_root", root_dir) +repo_root = os.path.join(root_dir, "..") +registry.register_path("repo_root", repo_root) +cache_root = os.path.join(repo_root, default_cfg.env.cache_root) +registry.register_path("cache_root", cache_root) + +registry.register("MAX_INT", sys.maxsize) +registry.register("SPLIT_NAMES", ["train", "val", "test"]) diff --git a/minigpt4/common/__init__.py b/minigpt4/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/minigpt4/common/config.py b/minigpt4/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e184b1f9024957a42fc6a4f796d0e8a7804e1ef7 --- /dev/null +++ b/minigpt4/common/config.py @@ -0,0 +1,468 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +from typing import Dict + +from omegaconf import OmegaConf +from minigpt4.common.registry import registry + + +class Config: + def __init__(self, args): + self.config = {} + + self.args = args + + # Register the config and configuration for setup + registry.register("configuration", self) + + user_config = self._build_opt_list(self.args.options) + + config = OmegaConf.load(self.args.cfg_path) + + runner_config = self.build_runner_config(config) + model_config = self.build_model_config(config, **user_config) + dataset_config = self.build_dataset_config(config) + + # Validate the user-provided runner configuration + # model and dataset configuration are supposed to be validated by the respective classes + # [TODO] validate the model/dataset configuration + # self._validate_runner_config(runner_config) + + # Override the default configuration with user options. + self.config = OmegaConf.merge( + runner_config, model_config, dataset_config, user_config + ) + + def _validate_runner_config(self, runner_config): + """ + This method validates the configuration, such that + 1) all the user specified options are valid; + 2) no type mismatches between the user specified options and the config. + """ + runner_config_validator = create_runner_config_validator() + runner_config_validator.validate(runner_config) + + def _build_opt_list(self, opts): + opts_dot_list = self._convert_to_dot_list(opts) + return OmegaConf.from_dotlist(opts_dot_list) + + @staticmethod + def build_model_config(config, **kwargs): + model = config.get("model", None) + assert model is not None, "Missing model configuration file." + + model_cls = registry.get_model_class(model.arch) + assert model_cls is not None, f"Model '{model.arch}' has not been registered." 
+ + model_type = kwargs.get("model.model_type", None) + if not model_type: + model_type = model.get("model_type", None) + # else use the model type selected by user. + + assert model_type is not None, "Missing model_type." + + model_config_path = model_cls.default_config_path(model_type=model_type) + + model_config = OmegaConf.create() + # hierarchy override, customized config > default config + model_config = OmegaConf.merge( + model_config, + OmegaConf.load(model_config_path), + {"model": config["model"]}, + ) + + return model_config + + @staticmethod + def build_runner_config(config): + return {"run": config.run} + + @staticmethod + def build_dataset_config(config): + datasets = config.get("datasets", None) + if datasets is None: + raise KeyError( + "Expecting 'datasets' as the root key for dataset configuration." + ) + + dataset_config = OmegaConf.create() + + for dataset_name in datasets: + builder_cls = registry.get_builder_class(dataset_name) + + dataset_config_type = datasets[dataset_name].get("type", "default") + dataset_config_path = builder_cls.default_config_path( + type=dataset_config_type + ) + + # hierarchy override, customized config > default config + dataset_config = OmegaConf.merge( + dataset_config, + OmegaConf.load(dataset_config_path), + {"datasets": {dataset_name: config["datasets"][dataset_name]}}, + ) + + return dataset_config + + def _convert_to_dot_list(self, opts): + if opts is None: + opts = [] + + if len(opts) == 0: + return opts + + has_equal = opts[0].find("=") != -1 + + if has_equal: + return opts + + return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] + + def get_config(self): + return self.config + + @property + def run_cfg(self): + return self.config.run + + @property + def datasets_cfg(self): + return self.config.datasets + + @property + def model_cfg(self): + return self.config.model + + def pretty_print(self): + logging.info("\n===== Running Parameters =====") + 
logging.info(self._convert_node_to_json(self.config.run)) + + logging.info("\n====== Dataset Attributes ======") + datasets = self.config.datasets + + for dataset in datasets: + if dataset in self.config.datasets: + logging.info(f"\n======== {dataset} =======") + dataset_config = self.config.datasets[dataset] + logging.info(self._convert_node_to_json(dataset_config)) + else: + logging.warning(f"No dataset named '{dataset}' in config. Skipping") + + logging.info(f"\n====== Model Attributes ======") + logging.info(self._convert_node_to_json(self.config.model)) + + def _convert_node_to_json(self, node): + container = OmegaConf.to_container(node, resolve=True) + return json.dumps(container, indent=4, sort_keys=True) + + def to_dict(self): + return OmegaConf.to_container(self.config) + + +def node_to_dict(node): + return OmegaConf.to_container(node) + + +class ConfigValidator: + """ + This is a preliminary implementation to centralize and validate the configuration. + May be altered in the future. + + A helper class to validate configurations from yaml file. + + This serves the following purposes: + 1. Ensure all the options in the yaml are defined, raise error if not. + 2. when type mismatches are found, the validator will raise an error. + 3. a central place to store and display helpful messages for supported configurations. 
+ + """ + + class _Argument: + def __init__(self, name, choices=None, type=None, help=None): + self.name = name + self.val = None + self.choices = choices + self.type = type + self.help = help + + def __str__(self): + s = f"{self.name}={self.val}" + if self.type is not None: + s += f", ({self.type})" + if self.choices is not None: + s += f", choices: {self.choices}" + if self.help is not None: + s += f", ({self.help})" + return s + + def __init__(self, description): + self.description = description + + self.arguments = dict() + + self.parsed_args = None + + def __getitem__(self, key): + assert self.parsed_args is not None, "No arguments parsed yet." + + return self.parsed_args[key] + + def __str__(self) -> str: + return self.format_help() + + def add_argument(self, *args, **kwargs): + """ + Assume the first argument is the name of the argument. + """ + self.arguments[args[0]] = self._Argument(*args, **kwargs) + + def validate(self, config=None): + """ + Convert yaml config (dict-like) to list, required by argparse. + """ + for k, v in config.items(): + assert ( + k in self.arguments + ), f"""{k} is not a valid argument. 
Supported arguments are {self.format_arguments()}.""" + + if self.arguments[k].type is not None: + try: + self.arguments[k].val = self.arguments[k].type(v) + except ValueError: + raise ValueError(f"{k} is not a valid {self.arguments[k].type}.") + + if self.arguments[k].choices is not None: + assert ( + v in self.arguments[k].choices + ), f"""{k} must be one of {self.arguments[k].choices}.""" + + return config + + def format_arguments(self): + return str([f"{k}" for k in sorted(self.arguments.keys())]) + + def format_help(self): + # description + key-value pair string for each argument + help_msg = str(self.description) + return help_msg + ", available arguments: " + self.format_arguments() + + def print_help(self): + # display help message + print(self.format_help()) + + +def create_runner_config_validator(): + validator = ConfigValidator(description="Runner configurations") + + validator.add_argument( + "runner", + type=str, + choices=["runner_base", "runner_iter"], + help="""Runner to use. The "runner_base" uses epoch-based training while iter-based + runner runs based on iters. Default: runner_base""", + ) + # add arguments for training dataset ratios + validator.add_argument( + "train_dataset_ratios", + type=Dict[str, float], + help="""Ratios of training dataset. This is used in iteration-based runner. + Not supported for the epoch-based runner, because defining an epoch becomes tricky. + Default: None""", + ) + validator.add_argument( + "max_iters", + type=float, + help="Maximum number of iterations to run.", + ) + validator.add_argument( + "max_epoch", + type=int, + help="Maximum number of epochs to run.", + ) + # add arguments for iters_per_inner_epoch + validator.add_argument( + "iters_per_inner_epoch", + type=float, + help="Number of iterations per inner epoch. 
This is required when runner is runner_iter.", + ) + lr_scheds_choices = registry.list_lr_schedulers() + validator.add_argument( + "lr_sched", + type=str, + choices=lr_scheds_choices, + help="Learning rate scheduler to use, from {}".format(lr_scheds_choices), + ) + task_choices = registry.list_tasks() + validator.add_argument( + "task", + type=str, + choices=task_choices, + help="Task to use, from {}".format(task_choices), + ) + # add arguments for init_lr + validator.add_argument( + "init_lr", + type=float, + help="Initial learning rate. This will be the learning rate after warmup and before decay.", + ) + # add arguments for min_lr + validator.add_argument( + "min_lr", + type=float, + help="Minimum learning rate (after decay).", + ) + # add arguments for warmup_lr + validator.add_argument( + "warmup_lr", + type=float, + help="Starting learning rate for warmup.", + ) + # add arguments for learning rate decay rate + validator.add_argument( + "lr_decay_rate", + type=float, + help="Learning rate decay rate. Required if using a decaying learning rate scheduler.", + ) + # add arguments for weight decay + validator.add_argument( + "weight_decay", + type=float, + help="Weight decay rate.", + ) + # add arguments for training batch size + validator.add_argument( + "batch_size_train", + type=int, + help="Training batch size.", + ) + # add arguments for evaluation batch size + validator.add_argument( + "batch_size_eval", + type=int, + help="Evaluation batch size, including validation and testing.", + ) + # add arguments for number of workers for data loading + validator.add_argument( + "num_workers", + help="Number of workers for data loading.", + ) + # add arguments for warm up steps + validator.add_argument( + "warmup_steps", + type=int, + help="Number of warmup steps. 
Required if a warmup schedule is used.", + ) + # add arguments for random seed + validator.add_argument( + "seed", + type=int, + help="Random seed.", + ) + # add arguments for output directory + validator.add_argument( + "output_dir", + type=str, + help="Output directory to save checkpoints and logs.", + ) + # add arguments for whether only use evaluation + validator.add_argument( + "evaluate", + help="Whether to only evaluate the model. If true, training will not be performed.", + ) + # add arguments for splits used for training, e.g. ["train", "val"] + validator.add_argument( + "train_splits", + type=list, + help="Splits to use for training.", + ) + # add arguments for splits used for validation, e.g. ["val"] + validator.add_argument( + "valid_splits", + type=list, + help="Splits to use for validation. If not provided, will skip the validation.", + ) + # add arguments for splits used for testing, e.g. ["test"] + validator.add_argument( + "test_splits", + type=list, + help="Splits to use for testing. If not provided, will skip the testing.", + ) + # add arguments for accumulating gradient for iterations + validator.add_argument( + "accum_grad_iters", + type=int, + help="Number of iterations to accumulate gradient for.", + ) + + # ====== distributed training ====== + validator.add_argument( + "device", + type=str, + choices=["cpu", "cuda"], + help="Device to use. 
Supports 'cuda' or 'cpu' for now.", + ) + validator.add_argument( + "world_size", + type=int, + help="Number of processes participating in the job.", + ) + validator.add_argument("dist_url", type=str) + validator.add_argument("distributed", type=bool) + # add arguments to opt using distributed sampler during evaluation or not + validator.add_argument( + "use_dist_eval_sampler", + type=bool, + help="Whether to use distributed sampler during evaluation or not.", + ) + + # ====== task specific ====== + # generation task specific arguments + # add arguments for maximal length of text output + validator.add_argument( + "max_len", + type=int, + help="Maximal length of text output.", + ) + # add arguments for minimal length of text output + validator.add_argument( + "min_len", + type=int, + help="Minimal length of text output.", + ) + # add arguments for number of beams + validator.add_argument( + "num_beams", + type=int, + help="Number of beams used for beam search.", + ) + + # vqa task specific arguments + # add arguments for number of answer candidates + validator.add_argument( + "num_ans_candidates", + type=int, + help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""", + ) + # add arguments for inference method + validator.add_argument( + "inference_method", + type=str, + choices=["generate", "rank"], + help="""Inference method to use for question answering. If rank, requires an answer list.""", + ) + + # ====== model specific ====== + validator.add_argument( + "k_test", + type=int, + help="Number of top k most similar samples from ITC/VTC selection to be tested.", + ) + + return validator diff --git a/minigpt4/common/dist_utils.py b/minigpt4/common/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a6fc1b904dccccbffbd96326b1506f8ff3ca19c1 --- /dev/null +++ b/minigpt4/common/dist_utils.py @@ -0,0 +1,140 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import functools +import os + +import torch +import torch.distributed as dist +import timm.models.hub as timm_hub + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def init_distributed_mode(args): + if args.distributed is False: + print("Not using distributed mode") + return + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}, world {}): {}".format( + args.rank, args.world_size, args.dist_url + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + timeout=datetime.timedelta( + days=365 + ), # allow 
auto-downloading and de-compressing + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def get_dist_info(): + if torch.__version__ < "1.0": + initialized = dist._initialized + else: + initialized = dist.is_initialized() + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: # non-distributed training + rank = 0 + world_size = 1 + return rank, world_size + + +def main_process(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def download_cached_file(url, check_hash=True, progress=False): + """ + Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. + If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. + """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return get_cached_file_path() diff --git a/minigpt4/common/gradcam.py b/minigpt4/common/gradcam.py new file mode 100644 index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0 --- /dev/null +++ b/minigpt4/common/gradcam.py @@ -0,0 +1,24 @@ +import numpy as np +from matplotlib import pyplot as plt +from scipy.ndimage import filters +from skimage import transform as skimage_transform + + +def getAttMap(img, attMap, blur=True, overlap=True): + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") + if blur: + attMap = filters.gaussian_filter(attMap, 0.02 * 
max(img.shape[:2])) + attMap -= attMap.min() + attMap /= attMap.max() + cmap = plt.get_cmap("jet") + attMapV = cmap(attMap) + attMapV = np.delete(attMapV, 3, 2) + if overlap: + attMap = ( + 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV + ) + return attMap diff --git a/minigpt4/common/logger.py b/minigpt4/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5a727213c6478606a154172830cdc43aae6f5a --- /dev/null +++ b/minigpt4/common/logger.py @@ -0,0 +1,195 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + +from minigpt4.common import dist_utils + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not dist_utils.is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + 
header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def setup_logger(): + logging.basicConfig( + level=logging.INFO if dist_utils.is_main_process() else logging.WARN, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], + ) diff --git a/minigpt4/common/optims.py b/minigpt4/common/optims.py new file mode 100644 index 0000000000000000000000000000000000000000..58327f723d445633ce7d1b5c3cc799b041319a97 --- /dev/null +++ b/minigpt4/common/optims.py @@ -0,0 +1,119 @@ +""" + Copyright (c) 2022, 
salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import math + +from minigpt4.common.registry import registry + + +@registry.register_lr_scheduler("linear_warmup_step_lr") +class LinearWarmupStepLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + min_lr, + init_lr, + decay_rate=1, + warmup_start_lr=-1, + warmup_steps=0, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.min_lr = min_lr + + self.decay_rate = decay_rate + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + if cur_epoch == 0: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + step_lr_schedule( + epoch=cur_epoch, + optimizer=self.optimizer, + init_lr=self.init_lr, + min_lr=self.min_lr, + decay_rate=self.decay_rate, + ) + + +@registry.register_lr_scheduler("linear_warmup_cosine_lr") +class LinearWarmupCosineLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + iters_per_epoch, + min_lr, + init_lr, + warmup_steps=0, + warmup_start_lr=-1, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.iters_per_epoch = iters_per_epoch + self.min_lr = min_lr + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + total_cur_step = cur_epoch * self.iters_per_epoch + cur_step + if total_cur_step < self.warmup_steps: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + cosine_lr_schedule( + 
epoch=total_cur_step, + optimizer=self.optimizer, + max_epoch=self.max_epoch * self.iters_per_epoch, + init_lr=self.init_lr, + min_lr=self.min_lr, + ) + + +def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): + """Decay the learning rate""" + lr = (init_lr - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * epoch / max_epoch) + ) + min_lr + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): + """Warmup the learning rate""" + lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): + """Decay the learning rate""" + lr = max(min_lr, init_lr * (decay_rate**epoch)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr diff --git a/minigpt4/common/registry.py b/minigpt4/common/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..679467a7411eda19ed956b810c21234322f06779 --- /dev/null +++ b/minigpt4/common/registry.py @@ -0,0 +1,329 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + + +class Registry: + mapping = { + "builder_name_mapping": {}, + "task_name_mapping": {}, + "processor_name_mapping": {}, + "model_name_mapping": {}, + "lr_scheduler_name_mapping": {}, + "runner_name_mapping": {}, + "state": {}, + "paths": {}, + } + + @classmethod + def register_builder(cls, name): + r"""Register a dataset builder to registry with key 'name' + + Args: + name: Key with which the builder will be registered. 
+ + Usage: + + from minigpt4.common.registry import registry + from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder + """ + + def wrap(builder_cls): + from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder + + assert issubclass( + builder_cls, BaseDatasetBuilder + ), "All builders must inherit BaseDatasetBuilder class, found {}".format( + builder_cls + ) + if name in cls.mapping["builder_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["builder_name_mapping"][name] + ) + ) + cls.mapping["builder_name_mapping"][name] = builder_cls + return builder_cls + + return wrap + + @classmethod + def register_task(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(task_cls): + from minigpt4.tasks.base_task import BaseTask + + assert issubclass( + task_cls, BaseTask + ), "All tasks must inherit BaseTask class" + if name in cls.mapping["task_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["task_name_mapping"][name] + ) + ) + cls.mapping["task_name_mapping"][name] = task_cls + return task_cls + + return wrap + + @classmethod + def register_model(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the model will be registered. 
+ + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(model_cls): + from minigpt4.models import BaseModel + + assert issubclass( + model_cls, BaseModel + ), "All models must inherit BaseModel class" + if name in cls.mapping["model_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["model_name_mapping"][name] + ) + ) + cls.mapping["model_name_mapping"][name] = model_cls + return model_cls + + return wrap + + @classmethod + def register_processor(cls, name): + r"""Register a processor to registry with key 'name' + + Args: + name: Key with which the processor will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(processor_cls): + from minigpt4.processors import BaseProcessor + + assert issubclass( + processor_cls, BaseProcessor + ), "All processors must inherit BaseProcessor class" + if name in cls.mapping["processor_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["processor_name_mapping"][name] + ) + ) + cls.mapping["processor_name_mapping"][name] = processor_cls + return processor_cls + + return wrap + + @classmethod + def register_lr_scheduler(cls, name): + r"""Register a learning rate scheduler to registry with key 'name' + + Args: + name: Key with which the learning rate scheduler will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(lr_sched_cls): + if name in cls.mapping["lr_scheduler_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["lr_scheduler_name_mapping"][name] + ) + ) + cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls + return lr_sched_cls + + return wrap + + @classmethod + def register_runner(cls, name): + r"""Register a runner to registry with key 'name' + + Args: + name: Key with which the runner will be registered. 
+ + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(runner_cls): + if name in cls.mapping["runner_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["runner_name_mapping"][name] + ) + ) + cls.mapping["runner_name_mapping"][name] = runner_cls + return runner_cls + + return wrap + + @classmethod + def register_path(cls, name, path): + r"""Register a path to registry with key 'name' + + Args: + name: Key with which the path will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + assert isinstance(path, str), "All path must be str." + if name in cls.mapping["paths"]: + raise KeyError("Name '{}' already registered.".format(name)) + cls.mapping["paths"][name] = path + + @classmethod + def register(cls, name, obj): + r"""Register an item to registry with key 'name' + + Args: + name: Key with which the item will be registered. + + Usage:: + + from minigpt4.common.registry import registry + + registry.register("config", {}) + """ + path = name.split(".") + current = cls.mapping["state"] + + for part in path[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + current[path[-1]] = obj + + # @classmethod + # def get_trainer_class(cls, name): + # return cls.mapping["trainer_name_mapping"].get(name, None) + + @classmethod + def get_builder_class(cls, name): + return cls.mapping["builder_name_mapping"].get(name, None) + + @classmethod + def get_model_class(cls, name): + return cls.mapping["model_name_mapping"].get(name, None) + + @classmethod + def get_task_class(cls, name): + return cls.mapping["task_name_mapping"].get(name, None) + + @classmethod + def get_processor_class(cls, name): + return cls.mapping["processor_name_mapping"].get(name, None) + + @classmethod + def get_lr_scheduler_class(cls, name): + return cls.mapping["lr_scheduler_name_mapping"].get(name, None) + + @classmethod + def get_runner_class(cls, name): + return 
cls.mapping["runner_name_mapping"].get(name, None) + + @classmethod + def list_runners(cls): + return sorted(cls.mapping["runner_name_mapping"].keys()) + + @classmethod + def list_models(cls): + return sorted(cls.mapping["model_name_mapping"].keys()) + + @classmethod + def list_tasks(cls): + return sorted(cls.mapping["task_name_mapping"].keys()) + + @classmethod + def list_processors(cls): + return sorted(cls.mapping["processor_name_mapping"].keys()) + + @classmethod + def list_lr_schedulers(cls): + return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) + + @classmethod + def list_datasets(cls): + return sorted(cls.mapping["builder_name_mapping"].keys()) + + @classmethod + def get_path(cls, name): + return cls.mapping["paths"].get(name, None) + + @classmethod + def get(cls, name, default=None, no_warning=False): + r"""Get an item from registry with key 'name' + + Args: + name (string): Key whose value needs to be retrieved. + default: If passed and key is not in registry, default value will + be returned with a warning. Default: None + no_warning (bool): If passed as True, warning when key doesn't exist + will not be generated. Useful for MMF's + internal operations. Default: False + """ + original_name = name + name = name.split(".") + value = cls.mapping["state"] + for subname in name: + value = value.get(subname, default) + if value is default: + break + + if ( + "writer" in cls.mapping["state"] + and value == default + and no_warning is False + ): + cls.mapping["state"]["writer"].warning( + "Key {} is not present in registry, returning default value " + "of {}".format(original_name, default) + ) + return value + + @classmethod + def unregister(cls, name): + r"""Remove an item from registry with key 'name' + + Args: + name: Key which needs to be removed. 
+ Usage:: + + from minigpt4.common.registry import registry + + config = registry.unregister("config") + """ + return cls.mapping["state"].pop(name, None) + + +registry = Registry() diff --git a/minigpt4/common/utils.py b/minigpt4/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a3069cd10ce986a1ec249490fa813cae9254bd0d --- /dev/null +++ b/minigpt4/common/utils.py @@ -0,0 +1,424 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import io +import json +import logging +import os +import pickle +import re +import shutil +import urllib +import urllib.error +import urllib.request +from typing import Optional +from urllib.parse import urlparse + +import numpy as np +import pandas as pd +import yaml +from iopath.common.download import download +from iopath.common.file_io import file_lock, g_pathmgr +from minigpt4.common.registry import registry +from torch.utils.model_zoo import tqdm +from torchvision.datasets.utils import ( + check_integrity, + download_file_from_google_drive, + extract_archive, +) + + +def now(): + from datetime import datetime + + return datetime.now().strftime("%Y%m%d%H%M")[:-1] + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def get_cache_path(rel_path): + return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) + + +def get_abs_path(rel_path): + return os.path.join(registry.get_path("library_root"), rel_path) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +# The following are adapted from torchvision and vissl +# torchvision: https://github.com/pytorch/vision +# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py + + +def makedir(dir_path): + """ + Create the 
directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + print(f"Error creating directory: {dir_path}") + return is_success + + +def get_redirected_url(url: str): + """ + Given a URL, returns the URL it redirects to or the + original URL in case of no redirection + """ + import requests + + with requests.Session() as session: + with session.get(url, stream=True, allow_redirects=True) as response: + if response.history: + return response.url + else: + return url + + +def to_google_drive_download_url(view_url: str) -> str: + """ + Utility function to transform a view URL of google drive + to a download URL for google drive + Example input: + https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view + Example output: + https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp + """ + splits = view_url.split("/") + assert splits[-1] == "view" + file_id = splits[-2] + return f"https://drive.google.com/uc?export=download&id={file_id}" + + +def download_google_drive_url(url: str, output_path: str, output_file_name: str): + """ + Download a file from google drive + Downloading a URL from google drive requires confirmation when + the size of the file is too big (google drive notifies that + antivirus checks cannot be performed on such files) + """ + import requests + + with requests.Session() as session: + + # First get the confirmation token and append it to the URL + with session.get(url, stream=True, allow_redirects=True) as response: + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + url = url + "&confirm=" + v + + # Then download the content of the file + with session.get(url, stream=True, verify=True) as response: + makedir(output_path) + path = os.path.join(output_path, output_file_name) + total_size = int(response.headers.get("Content-length", 0)) + with open(path, "wb") as file: 
+ from tqdm import tqdm + + with tqdm(total=total_size) as progress_bar: + for block in response.iter_content( + chunk_size=io.DEFAULT_BUFFER_SIZE + ): + file.write(block) + progress_bar.update(len(block)) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen( + urllib.request.Request(url, headers={"User-Agent": "vissl"}) + ) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def download_url( + url: str, + root: str, + filename: Optional[str] = None, + md5: Optional[str] = None, +) -> None: + """Download a file from a url and place it in root. + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. + If None, use the basename of the URL. + md5 (str, optional): MD5 checksum of the download. 
If None, do not check + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + makedir(root) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + # expand redirect chain if needed + url = get_redirected_url(url) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." + " Downloading " + url + " to " + fpath + ) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) + + +def cache_url(url: str, cache_dir: str) -> str: + """ + This implementation downloads the remote resource and caches it locally. + The resource will only be downloaded if not previously requested. 
+ """ + parsed_url = urlparse(url) + dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) + makedir(dirname) + filename = url.split("/")[-1] + cached = os.path.join(dirname, filename) + with file_lock(cached): + if not os.path.isfile(cached): + logging.info(f"Downloading {url} to {cached} ...") + cached = download(url, dirname, filename=filename) + logging.info(f"URL {url} cached in {cached}") + return cached + + +# TODO (prigoyal): convert this into RAII-style API +def create_file_symlink(file1, file2): + """ + Simply create the symlinks for a given file1 to file2. + Useful during model checkpointing to symlinks to the + latest successful checkpoint. + """ + try: + if g_pathmgr.exists(file2): + g_pathmgr.rm(file2) + g_pathmgr.symlink(file1, file2) + except Exception as e: + logging.info(f"Could NOT create symlink. Error: {e}") + + +def save_file(data, filename, append_to_json=True, verbose=True): + """ + Common i/o utility to handle saving data to various file formats. + Supported: + .pkl, .pickle, .npy, .json + Specifically for .json, users have the option to either append (default) + or rewrite by passing in Boolean value to append_to_json. 
+ """ + if verbose: + logging.info(f"Saving data to file: {filename}") + file_ext = os.path.splitext(filename)[1] + if file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "wb") as fopen: + pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) + elif file_ext == ".npy": + with g_pathmgr.open(filename, "wb") as fopen: + np.save(fopen, data) + elif file_ext == ".json": + if append_to_json: + with g_pathmgr.open(filename, "a") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + else: + with g_pathmgr.open(filename, "w") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "w") as fopen: + dump = yaml.dump(data) + fopen.write(dump) + fopen.flush() + else: + raise Exception(f"Saving {file_ext} is not supported yet") + + if verbose: + logging.info(f"Saved data to file: {filename}") + + +def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): + """ + Common i/o utility to handle loading data from various file formats. + Supported: + .pkl, .pickle, .npy, .json + For the npy files, we support reading the files in mmap_mode. + If the mmap_mode of reading is not successful, we load data without the + mmap_mode. + """ + if verbose: + logging.info(f"Loading data from file: {filename}") + + file_ext = os.path.splitext(filename)[1] + if file_ext == ".txt": + with g_pathmgr.open(filename, "r") as fopen: + data = fopen.readlines() + elif file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "rb") as fopen: + data = pickle.load(fopen, encoding="latin1") + elif file_ext == ".npy": + if mmap_mode: + try: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load( + fopen, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + except ValueError as e: + logging.info( + f"Could not mmap {filename}: {e}. 
Trying without g_pathmgr" + ) + data = np.load( + filename, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + logging.info("Successfully loaded without g_pathmgr") + except Exception: + logging.info("Could not mmap without g_pathmgr. Trying without mmap") + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + else: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + elif file_ext == ".json": + with g_pathmgr.open(filename, "r") as fopen: + data = json.load(fopen) + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "r") as fopen: + data = yaml.load(fopen, Loader=yaml.FullLoader) + elif file_ext == ".csv": + with g_pathmgr.open(filename, "r") as fopen: + data = pd.read_csv(fopen) + else: + raise Exception(f"Reading from {file_ext} is not supported yet") + return data + + +def abspath(resource_path: str): + """ + Make a path absolute, but take into account prefixes like + "http://" or "manifold://" + """ + regex = re.compile(r"^\w+://") + if regex.match(resource_path) is None: + return os.path.abspath(resource_path) + else: + return resource_path + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_url(input_url): + """ + Check if an input string is a url. look for http(s):// and ignoring the case + """ + is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None + return is_url + + +def cleanup_dir(dir): + """ + Utility for deleting a directory. Useful for cleaning the storage space + that contains various training artifacts like checkpoints, data etc. 
+ """ + if os.path.exists(dir): + logging.info(f"Deleting directory: {dir}") + shutil.rmtree(dir) + logging.info(f"Deleted contents of directory: {dir}") + + +def get_file_size(filename): + """ + Given a file, get the size of file in MB + """ + size_in_mb = os.path.getsize(filename) / float(1024**2) + return size_in_mb diff --git a/minigpt4/configs/datasets/cc_sbu/align.yaml b/minigpt4/configs/datasets/cc_sbu/align.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5710834200fe45449d60185d467d6bcb90a98cca --- /dev/null +++ b/minigpt4/configs/datasets/cc_sbu/align.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu_align: + data_type: images + build_info: + storage: /path/to/cc_sbu_align/ diff --git a/minigpt4/configs/datasets/cc_sbu/defaults.yaml b/minigpt4/configs/datasets/cc_sbu/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60390eece551fe06a0f7c3ebb395351794b9f5f1 --- /dev/null +++ b/minigpt4/configs/datasets/cc_sbu/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu: + data_type: images + build_info: + storage: /path/to/cc_sbu_dataset/{00000..01255}.tar diff --git a/minigpt4/configs/datasets/laion/defaults.yaml b/minigpt4/configs/datasets/laion/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6bad62901619c0a9e34619a400290f3e18083899 --- /dev/null +++ b/minigpt4/configs/datasets/laion/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + laion: + data_type: images + build_info: + storage: /path/to/laion_dataset/{00000..10488}.tar diff --git a/minigpt4/configs/default.yaml b/minigpt4/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ff5a6a23fa2e3914938631b96c71fdf723dbbc10 --- /dev/null +++ b/minigpt4/configs/default.yaml @@ -0,0 +1,5 @@ +env: + # For default users + # cache_root: "cache" + # For internal use with persistent storage + cache_root: "/export/home/.cache/minigpt4" diff --git a/minigpt4/configs/models/minigpt4_llama2.yaml 
b/minigpt4/configs/models/minigpt4_llama2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fdd25e0947d2d9c1c6f6270a95be1c8f01538baa --- /dev/null +++ b/minigpt4/configs/models/minigpt4_llama2.yaml @@ -0,0 +1,29 @@ +model: + arch: minigpt4 + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + has_qformer: False + + # generation configs + prompt: "" + + llama_model: "please set this value to the path of llama2-chat-7b" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/minigpt4_vicuna0.yaml b/minigpt4/configs/models/minigpt4_vicuna0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..718054c8b62a17f3b72bbb3efb3eda95e55a64f1 --- /dev/null +++ b/minigpt4/configs/models/minigpt4_vicuna0.yaml @@ -0,0 +1,32 @@ +model: + arch: minigpt4 + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + freeze_qformer: True + + # Q-Former + num_query_token: 32 + + # generation configs + prompt: "" + + llama_model: "please set this value to the path of vicuna model" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/minigpt_v2.yaml b/minigpt4/configs/models/minigpt_v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d85d203b58e859d026b85565b34ad6724951eac --- /dev/null +++ b/minigpt4/configs/models/minigpt_v2.yaml @@ -0,0 +1,31 @@ +model: + arch: minigpt_v2 + + # vit encoder + image_size: 448 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: 
"fp16" + freeze_vit: True + + # generation configs + prompt: "" + + llama_model: "please set this value to the path of llama2-chat-7b" + lora_r: 64 + lora_alpha: 16 + + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + eval: + name: "blip2_image_eval" + image_size: 448 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/conversation/__init__.py b/minigpt4/conversation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..9c27c785db4a36f6cad16e6c209eeac8e606044b --- /dev/null +++ b/minigpt4/conversation/conversation.py @@ -0,0 +1,238 @@ +import argparse +import time +from threading import Thread +from PIL import Image + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer +from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer + +import dataclasses +from enum import auto, Enum +from typing import List, Tuple, Any + +from minigpt4.common.registry import registry + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + # system_img: List[Image.Image] = [] + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + + skip_next: bool = False + conv_id: Any = None + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = 
[self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + message + seps[i % 2] + else: + ret += role + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + # system_img=self.system_img, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + conv_id=self.conv_id) + + def dict(self): + return { + "system": self.system, + # "system_img": self.system_img, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + "conv_id": self.conv_id, + } + + +class StoppingCriteriaSub(StoppingCriteria): + + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop):])).item(): + return True + + return False + + +CONV_VISION_Vicuna0 = Conversation( + system="Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions.", + roles=("Human: ", "Assistant: "), + messages=[], + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +CONV_VISION_LLama2 = Conversation( + system="Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. 
Please answer my questions.", + roles=("[INST] ", " [/INST] "), + messages=[], + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="", +) + + + +class Chat: + def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None): + self.device = device + self.model = model + self.vis_processor = vis_processor + + if stopping_criteria is not None: + self.stopping_criteria = stopping_criteria + else: + stop_words_ids = [torch.tensor([2]).to(self.device)] + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + + def ask(self, text, conv): + if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ + and conv.messages[-1][1][-6:] == '</Img>': # last message is image. + conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) + else: + conv.append_message(conv.roles[0], text) + + def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000): + conv.append_message(conv.roles[1], None) + embs = self.get_context_emb(conv, img_list) + + current_max_len = embs.shape[1] + max_new_tokens + if current_max_len - max_length > 0: + print('Warning: The number of tokens in current conversation exceeds the max length. 
' + 'The model will not see the contexts outside the range.') + begin_idx = max(0, current_max_len - max_length) + embs = embs[:, begin_idx:] + + generation_kwargs = dict( + inputs_embeds=embs, + max_new_tokens=max_new_tokens, + stopping_criteria=self.stopping_criteria, + num_beams=num_beams, + do_sample=True, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + temperature=temperature, + ) + return generation_kwargs + + def answer(self, conv, img_list, **kargs): + generation_dict = self.answer_prepare(conv, img_list, **kargs) + + output_token = self.model.llama_model.generate(**generation_dict)[0] + output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True) + + output_text = output_text.split('###')[0] # remove the stop sign '###' + output_text = output_text.split('Assistant:')[-1].strip() + + conv.messages[-1][1] = output_text + return output_text, output_token.cpu().numpy() + + def stream_answer(self, conv, img_list, **kargs): + generation_kwargs = self.answer_prepare(conv, img_list, **kargs) + streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True) + generation_kwargs['streamer'] = streamer + thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs) + thread.start() + return streamer + + def encode_img(self, img_list): + image = img_list[0] + img_list.pop(0) + if isinstance(image, str): # is a image path + raw_image = Image.open(image).convert('RGB') + image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) + elif isinstance(image, Image.Image): + raw_image = image + image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) + elif isinstance(image, torch.Tensor): + if len(image.shape) == 3: + image = image.unsqueeze(0) + image = image.to(self.device) + + image_emb, _ = self.model.encode_img(image) + img_list.append(image_emb) + + def upload_img(self, image, conv, img_list): + 
conv.append_message(conv.roles[0], "<Img><ImageHere></Img>") + img_list.append(image) + msg = "Received." + + return msg + + def get_context_emb(self, conv, img_list): + prompt = conv.get_prompt() + prompt_segs = prompt.split('<ImageHere>') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.model.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids + # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + print('debug device: ', self.device) + print('debug model device: ', self.model.device) + seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens] + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + diff --git a/minigpt4/datasets/__init__.py b/minigpt4/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/minigpt4/datasets/builders/__init__.py b/minigpt4/datasets/builders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d0964063f145c6b119c78460aed69bcc4dfa4c1 --- /dev/null +++ b/minigpt4/datasets/builders/__init__.py @@ -0,0 +1,72 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config +from minigpt4.datasets.builders.image_text_pair_builder import ( + CCSBUBuilder, + LaionBuilder, + CCSBUAlignBuilder +) +from minigpt4.common.registry import registry + +__all__ = [ + "CCSBUBuilder", + "LaionBuilder", + "CCSBUAlignBuilder" +] + + +def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): + """ + Example + + >>> dataset = load_dataset("coco_caption", cfg=None) + >>> splits = dataset.keys() + >>> print([len(dataset[split]) for split in splits]) + + """ + if cfg_path is None: + cfg = None + else: + cfg = load_dataset_config(cfg_path) + + try: + builder = registry.get_builder_class(name)(cfg) + except TypeError: + print( + f"Dataset {name} not found. Available datasets:\n" + + ", ".join([str(k) for k in dataset_zoo.get_names()]) + ) + exit(1) + + if vis_path is not None: + if data_type is None: + # use default data type in the config + data_type = builder.config.data_type + + assert ( + data_type in builder.config.build_info + ), f"Invalid data_type {data_type} for {name}." 
+ + builder.config.build_info.get(data_type).storage = vis_path + + dataset = builder.build_datasets() + return dataset + + +class DatasetZoo: + def __init__(self) -> None: + self.dataset_zoo = { + k: list(v.DATASET_CONFIG_DICT.keys()) + for k, v in sorted(registry.mapping["builder_name_mapping"].items()) + } + + def get_names(self): + return list(self.dataset_zoo.keys()) + + +dataset_zoo = DatasetZoo() diff --git a/minigpt4/datasets/builders/base_dataset_builder.py b/minigpt4/datasets/builders/base_dataset_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..4b607e3c0a8abaa6b1ccbc711e27ff3755f5ec11 --- /dev/null +++ b/minigpt4/datasets/builders/base_dataset_builder.py @@ -0,0 +1,236 @@ +""" + This file is from + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os +import shutil +import warnings + +from omegaconf import OmegaConf +import torch.distributed as dist +from torchvision.datasets.utils import download_url + +import minigpt4.common.utils as utils +from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process +from minigpt4.common.registry import registry +from minigpt4.processors.base_processor import BaseProcessor + + + +class BaseDatasetBuilder: + train_dataset_cls, eval_dataset_cls = None, None + + def __init__(self, cfg=None): + super().__init__() + + if cfg is None: + # help to create datasets from default config. 
+ self.config = load_dataset_config(self.default_config_path()) + elif isinstance(cfg, str): + self.config = load_dataset_config(cfg) + else: + # when called from task.build_dataset() + self.config = cfg + + self.data_type = self.config.data_type + + self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def build_processors(self): + vis_proc_cfg = self.config.get("vis_processor") + txt_proc_cfg = self.config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + + self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) + self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + + self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) + self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) + + @staticmethod + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else None + ) + + @classmethod + def default_config_path(cls, type="default"): + return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) + + def _download_data(self): + self._download_ann() + self._download_vis() + + def _download_ann(self): + """ + Download annotation files if necessary. 
+ All the vision-language datasets should have annotations of unified format. + + storage_path can be: + (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. + (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. + + Local annotation paths should be relative. + """ + anns = self.config.build_info.annotations + + splits = anns.keys() + + cache_root = registry.get_path("cache_root") + + for split in splits: + info = anns[split] + + urls, storage_paths = info.get("url", None), info.storage + + if isinstance(urls, str): + urls = [urls] + if isinstance(storage_paths, str): + storage_paths = [storage_paths] + + assert len(urls) == len(storage_paths) + + for url_or_filename, storage_path in zip(urls, storage_paths): + # if storage_path is relative, make it full by prefixing with cache_root. + if not os.path.isabs(storage_path): + storage_path = os.path.join(cache_root, storage_path) + + dirname = os.path.dirname(storage_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + if os.path.isfile(url_or_filename): + src, dst = url_or_filename, storage_path + if not os.path.exists(dst): + shutil.copyfile(src=src, dst=dst) + else: + logging.info("Using existing file {}.".format(dst)) + else: + if os.path.isdir(storage_path): + # if only dirname is provided, suffix with basename of URL. + raise ValueError( + "Expecting storage_path to be a file path, got directory {}".format( + storage_path + ) + ) + else: + filename = os.path.basename(storage_path) + + download_url(url=url_or_filename, root=dirname, filename=filename) + + def _download_vis(self): + + storage_path = self.config.build_info.get(self.data_type).storage + storage_path = utils.get_cache_path(storage_path) + + if not os.path.exists(storage_path): + warnings.warn( + f""" + The specified path {storage_path} for visual inputs does not exist. 
+ Please provide a correct path to the visual inputs or + refer to datasets/download_scripts/README.md for downloading instructions. + """ + ) + + def build(self): + """ + Create by split datasets inheriting torch.utils.data.Datasets. + + # build() can be dataset-specific. Overwrite to customize. + """ + self.build_processors() + + build_info = self.config.build_info + + ann_info = build_info.annotations + vis_info = build_info.get(self.data_type) + + datasets = dict() + for split in ann_info.keys(): + if split not in ["train", "val", "test"]: + continue + + is_train = split == "train" + + # processors + vis_processor = ( + self.vis_processors["train"] + if is_train + else self.vis_processors["eval"] + ) + text_processor = ( + self.text_processors["train"] + if is_train + else self.text_processors["eval"] + ) + + # annotation path + ann_paths = ann_info.get(split).storage + if isinstance(ann_paths, str): + ann_paths = [ann_paths] + + abs_ann_paths = [] + for ann_path in ann_paths: + if not os.path.isabs(ann_path): + ann_path = utils.get_cache_path(ann_path) + abs_ann_paths.append(ann_path) + ann_paths = abs_ann_paths + + # visual data storage path + vis_path = os.path.join(vis_info.storage, split) + + if not os.path.isabs(vis_path): + # vis_path = os.path.join(utils.get_cache_path(), vis_path) + vis_path = utils.get_cache_path(vis_path) + + if not os.path.exists(vis_path): + warnings.warn("storage path {} does not exist.".format(vis_path)) + + # create datasets + dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls + datasets[split] = dataset_cls( + vis_processor=vis_processor, + text_processor=text_processor, + ann_paths=ann_paths, + vis_root=vis_path, + ) + + return datasets + + +def load_dataset_config(cfg_path): + cfg = OmegaConf.load(cfg_path).datasets + cfg = cfg[list(cfg.keys())[0]] + + return cfg diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py new file mode 
100644 index 0000000000000000000000000000000000000000..e5d66b8f63d4489b1e6fa7954cb80b5e7344f473 --- /dev/null +++ b/minigpt4/datasets/builders/image_text_pair_builder.py @@ -0,0 +1,105 @@ +import os +import logging +import warnings + +from minigpt4.common.registry import registry +from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from minigpt4.datasets.datasets.laion_dataset import LaionDataset +from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset + + +@registry.register_builder("cc_sbu") +class CCSBUBuilder(BaseDatasetBuilder): + train_dataset_cls = CCSBUDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("laion") +class LaionBuilder(BaseDatasetBuilder): + train_dataset_cls = LaionDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("cc_sbu_align") +class 
CCSBUAlignBuilder(BaseDatasetBuilder): + train_dataset_cls = CCSBUAlignDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/cc_sbu/align.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + storage_path = build_info.storage + + datasets = dict() + + if not os.path.exists(storage_path): + warnings.warn("storage path {} does not exist.".format(storage_path)) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_paths=[os.path.join(storage_path, 'filter_cap.json')], + vis_root=os.path.join(storage_path, 'image'), + ) + + return datasets diff --git a/minigpt4/datasets/data_utils.py b/minigpt4/datasets/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6497fd4389295d11b1c19f6927aba7ac658d1d --- /dev/null +++ b/minigpt4/datasets/data_utils.py @@ -0,0 +1,196 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import gzip +import logging +import os +import random as rnd +import tarfile +import zipfile +import random +from typing import List +from tqdm import tqdm + +import decord +from decord import VideoReader +import webdataset as wds +import numpy as np +import torch +from torch.utils.data.dataset import IterableDataset + +from minigpt4.common.registry import registry +from minigpt4.datasets.datasets.base_dataset import ConcatDataset + + +decord.bridge.set_bridge("torch") +MAX_INT = registry.get("MAX_INT") + + +class ChainDataset(wds.DataPipeline): + r"""Dataset for chaining multiple :class:`DataPipeline` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. + + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + def __init__(self, datasets: List[wds.DataPipeline]) -> None: + super().__init__() + self.datasets = datasets + self.prob = [] + self.names = [] + for dataset in self.datasets: + if hasattr(dataset, 'name'): + self.names.append(dataset.name) + else: + self.names.append('Unknown') + if hasattr(dataset, 'sample_ratio'): + self.prob.append(dataset.sample_ratio) + else: + self.prob.append(1) + logging.info("A data pipeline does not define sample_ratio; defaulting to 1.") + + def __iter__(self): + datastreams = [iter(dataset) for dataset in self.datasets] + while True: + select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] + yield next(select_datastream) + + +def apply_to_sample(f, sample): + if len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + 
+ return [_apply(item) for item in x] + else: + return x + + return _apply(sample) + + +def move_to_cuda(sample): + def _move_to_cuda(tensor): + return tensor.cuda() + + return apply_to_sample(_move_to_cuda, sample) + + +def prepare_sample(samples, cuda_enabled=True): + if cuda_enabled: + samples = move_to_cuda(samples) + + # TODO fp16 support + + return samples + + +def reorg_datasets_by_split(datasets): + """ + Organizes datasets by split. + + Args: + datasets: dict of torch.utils.data.Dataset objects by name. + + Returns: + Dict of datasets by split {split_name: List[Datasets]}. + """ + # if len(datasets) == 1: + # return datasets[list(datasets.keys())[0]] + # else: + reorg_datasets = dict() + + # reorganize by split + for _, dataset in datasets.items(): + for split_name, dataset_split in dataset.items(): + if split_name not in reorg_datasets: + reorg_datasets[split_name] = [dataset_split] + else: + reorg_datasets[split_name].append(dataset_split) + + return reorg_datasets + + +def concat_datasets(datasets): + """ + Concatenates multiple datasets into a single dataset. + + It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not support + generic IterableDataset because that requires creating separate samplers. + + Currently only supports concatenating training datasets and assumes that validation and testing + have only a single dataset each. This is because metrics should not be computed on the concatenated + datasets. + + Args: + datasets: dict of torch.utils.data.Dataset objects by split. + + Returns: + Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, + "val" and "test" remain the same. + + If the input training datasets contain both map-style and DataPipeline datasets, returns + a tuple, where the first element is a concatenated map-style dataset and the second + element is a chained DataPipeline dataset. 
+ + """ + # concatenate datasets in the same split + for split_name in datasets: + if split_name != "train": + assert ( + len(datasets[split_name]) == 1 + ), "Do not support multiple {} datasets.".format(split_name) + datasets[split_name] = datasets[split_name][0] + else: + iterable_datasets, map_datasets = [], [] + for dataset in datasets[split_name]: + if isinstance(dataset, wds.DataPipeline): + logging.info( + "Dataset {} is IterableDataset, can't be concatenated.".format( + dataset + ) + ) + iterable_datasets.append(dataset) + elif isinstance(dataset, IterableDataset): + raise NotImplementedError( + "Do not support concatenation of generic IterableDataset." + ) + else: + map_datasets.append(dataset) + + # if len(iterable_datasets) > 0: + # concatenate map-style datasets and iterable-style datasets separately + if len(iterable_datasets) > 1: + chained_datasets = ( + ChainDataset(iterable_datasets) + ) + elif len(iterable_datasets) == 1: + chained_datasets = iterable_datasets[0] + else: + chained_datasets = None + + concat_datasets = ( + ConcatDataset(map_datasets) if len(map_datasets) > 0 else None + ) + + train_datasets = concat_datasets, chained_datasets + train_datasets = tuple([x for x in train_datasets if x is not None]) + train_datasets = ( + train_datasets[0] if len(train_datasets) == 1 else train_datasets + ) + + datasets[split_name] = train_datasets + + return datasets + diff --git a/minigpt4/datasets/datasets/__init__.py b/minigpt4/datasets/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/minigpt4/datasets/datasets/base_dataset.py b/minigpt4/datasets/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2a8d0e21370129c0182cddc427eb293bbe5982 --- /dev/null +++ b/minigpt4/datasets/datasets/base_dataset.py @@ -0,0 +1,68 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import json +from typing import Iterable + +from torch.utils.data import Dataset, ConcatDataset +from torch.utils.data.dataloader import default_collate + + +class BaseDataset(Dataset): + def __init__( + self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] + ): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.annotation = [] + for ann_path in ann_paths: + self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __len__(self): + return len(self.annotation) + + def collater(self, samples): + return default_collate(samples) + + def set_processors(self, vis_processor, text_processor): + self.vis_processor = vis_processor + self.text_processor = text_processor + + def _add_instance_ids(self, key="instance_id"): + for idx, ann in enumerate(self.annotation): + ann[key] = str(idx) + + +class ConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__(datasets) + + def collater(self, samples): + # TODO For now only supports datasets with same underlying collater implementations + + all_keys = set() + for s in samples: + all_keys.update(s) + + shared_keys = all_keys + for s in samples: + shared_keys = shared_keys & set(s.keys()) + + samples_shared_keys = [] + for s in samples: + samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) + + return self.datasets[0].collater(samples_shared_keys) diff --git a/minigpt4/datasets/datasets/caption_datasets.py b/minigpt4/datasets/datasets/caption_datasets.py new file mode 100644 index 
0000000000000000000000000000000000000000..78bab668d34c8a28917af171700d43dbb20f3926 --- /dev/null +++ b/minigpt4/datasets/datasets/caption_datasets.py @@ -0,0 +1,85 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from PIL import Image + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "caption": ann["caption"], + "image": sample["image"], + } + ) + + +class CaptionDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + img_file = '{:0>12}.jpg'.format(ann["image_id"]) + image_path = os.path.join(self.vis_root, img_file) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + return { + "image": image, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + } + + +class CaptionEvalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + return { + "image": image, + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } diff --git a/minigpt4/datasets/datasets/cc_sbu_dataset.py b/minigpt4/datasets/datasets/cc_sbu_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..80b658d97ad47052653cecf25daeb512793bfc7b --- /dev/null +++ b/minigpt4/datasets/datasets/cc_sbu_dataset.py @@ -0,0 +1,47 @@ +import os +from PIL import Image +import webdataset as wds +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + +class CCSBUDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "image": sample[0], + "answer": self.text_processor(sample[1]["caption"]), + } + + +class CCSBUAlignDataset(CaptionDataset): + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + img_file = '{}.jpg'.format(ann["image_id"]) + image_path = os.path.join(self.vis_root, img_file) + image 
= Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = ann["caption"] + + return { + "image": image, + "answer": caption, + "image_id": self.img_ids[ann["image_id"]], + } \ No newline at end of file diff --git a/minigpt4/datasets/datasets/dataloader_utils.py b/minigpt4/datasets/datasets/dataloader_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8eaa3a58b0ad42ca7937fb51b46e53511cc3cd0c --- /dev/null +++ b/minigpt4/datasets/datasets/dataloader_utils.py @@ -0,0 +1,162 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import time +import random +import torch +from minigpt4.datasets.data_utils import move_to_cuda +from torch.utils.data import DataLoader + + +class MultiIterLoader: + """ + A simple wrapper for iterating over multiple iterators. + + Args: + loaders (List[Loader]): List of Iterator loaders. + ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. + """ + + def __init__(self, loaders, ratios=None): + # assert all loaders has __next__ method + for loader in loaders: + assert hasattr( + loader, "__next__" + ), "Loader {} has no __next__ method.".format(loader) + + if ratios is None: + ratios = [1.0] * len(loaders) + else: + assert len(ratios) == len(loaders) + ratios = [float(ratio) / sum(ratios) for ratio in ratios] + + self.loaders = loaders + self.ratios = ratios + + def __next__(self): + # random sample from each loader by ratio + loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] + return next(self.loaders[loader_idx]) + + +class PrefetchLoader(object): + """ + Modified from https://github.com/ChenRocks/UNITER. 
+ + overlap compute and cuda data transfer + (copied and then modified from nvidia apex) + """ + + def __init__(self, loader): + self.loader = loader + self.stream = torch.cuda.Stream() + + def __iter__(self): + loader_it = iter(self.loader) + self.preload(loader_it) + batch = self.next(loader_it) + while batch is not None: + is_tuple = isinstance(batch, tuple) + if is_tuple: + task, batch = batch + + if is_tuple: + yield task, batch + else: + yield batch + batch = self.next(loader_it) + + def __len__(self): + return len(self.loader) + + def preload(self, it): + try: + self.batch = next(it) + except StopIteration: + self.batch = None + return + # if record_stream() doesn't work, another option is to make sure + # device inputs are created on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, + # device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, + # device='cuda') + # Need to make sure the memory allocated for next_* is not still in use + # by the main stream at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.batch = move_to_cuda(self.batch) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this + # side stream. 
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + def next(self, it): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is not None: + record_cuda_stream(batch) + self.preload(it) + return batch + + def __getattr__(self, name): + method = self.loader.__getattribute__(name) + return method + + +def record_cuda_stream(batch): + if isinstance(batch, torch.Tensor): + batch.record_stream(torch.cuda.current_stream()) + elif isinstance(batch, (list, tuple)): + for t in batch: + record_cuda_stream(t) + elif isinstance(batch, dict): + for t in batch.values(): + record_cuda_stream(t) + else: + pass + + +class IterLoader: + """ + A wrapper that converts a DataLoader into an infinite iterator. + + Modified from: + https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py + """ + + def __init__(self, dataloader: DataLoader, use_distributed: bool = False): + self._dataloader = dataloader + self.iter_loader = iter(self._dataloader) + self._use_distributed = use_distributed + self._epoch = 0 + + @property + def epoch(self) -> int: + return self._epoch + + def __next__(self): + try: + data = next(self.iter_loader) + except StopIteration: + self._epoch += 1 + if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: + self._dataloader.sampler.set_epoch(self._epoch) + time.sleep(2) # Prevent possible deadlock during epoch transition + self.iter_loader = iter(self._dataloader) + data = next(self.iter_loader) + + return data + + def __iter__(self): + return self + + def __len__(self): + return len(self._dataloader) diff --git a/minigpt4/datasets/datasets/laion_dataset.py b/minigpt4/datasets/datasets/laion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3ce873a44bcc675a8b5b50d2aff0b8c542ac26 
--- /dev/null +++ b/minigpt4/datasets/datasets/laion_dataset.py @@ -0,0 +1,31 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import webdataset as wds +from minigpt4.datasets.datasets.base_dataset import BaseDataset + + +class LaionDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "image": sample[0], + "answer": self.text_processor(sample[1]["caption"]), + } + diff --git a/minigpt4/models/Qformer.py b/minigpt4/models/Qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e71b12375e10511858a9c505dc795181e6ce5603 --- /dev/null +++ b/minigpt4/models/Qformer.py @@ -0,0 +1,1216 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = 
nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, 
self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + 
"bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + 
hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): 
+ hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = 
self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + 
[BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need` by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + For cross-attention, the configuration needs :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, cast to the model's dtype (:obj:`self.dtype`). + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 
+ use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. 
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None 
else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + 
query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored + (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + sequence_output = outputs[0][:, query_length:, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..bc01b56181aa81554efbe9df10ab3678a1c7bb86 --- /dev/null +++ b/minigpt4/models/__init__.py @@ -0,0 +1,202 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import torch +from omegaconf import OmegaConf + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import BaseModel +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.minigpt4 import MiniGPT4 +from minigpt4.models.minigpt_v2 import MiniGPTv2 +from minigpt4.processors.base_processor import BaseProcessor + + +__all__ = [ + "load_model", + "BaseModel", + "MiniGPTBase", + "MiniGPT4", + "MiniGPTv2" +] + + +def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): + """ + Load supported models. + + To list all available models and types in registry: + >>> from minigpt4.models import model_zoo + >>> print(model_zoo) + + Args: + name (str): name of the model. + model_type (str): type of the model. + is_eval (bool): whether the model is in eval mode. Default: False. + device (str): device to use. Default: "cpu". + checkpoint (str): path or URL to checkpoint. Default: None. + The checkpoint is expected to have the same state_dict keys as the model. + + Returns: + model (torch.nn.Module): model. + """ + + model = registry.get_model_class(name).from_pretrained(model_type=model_type) + + if checkpoint is not None: + model.load_checkpoint(checkpoint) + + if is_eval: + model.eval() + + if device == "cpu": + model = model.float() + + return model.to(device) + + +def load_preprocess(config): + """ + Load preprocessor configs and construct preprocessors. + + If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. + + Args: + config (dict): preprocessor configs. 
+ + Returns: + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + + Key is "train" or "eval" for processors used in training and evaluation respectively. + """ + + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else BaseProcessor() + ) + + vis_processors = dict() + txt_processors = dict() + + vis_proc_cfg = config.get("vis_processor") + txt_proc_cfg = config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + else: + vis_train_cfg = None + vis_eval_cfg = None + + vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) + vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + else: + txt_train_cfg = None + txt_eval_cfg = None + + txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) + txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) + + return vis_processors, txt_processors + + +def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): + """ + Load model and its related preprocessors. + + List all available models and types in registry: + >>> from minigpt4.models import model_zoo + >>> print(model_zoo) + + Args: + name (str): name of the model. + model_type (str): type of the model. + is_eval (bool): whether the model is in eval mode. Default: False. + device (str): device to use. Default: "cpu". + + Returns: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. 
+ """ + model_cls = registry.get_model_class(name) + + # load model + model = model_cls.from_pretrained(model_type=model_type) + + if is_eval: + model.eval() + + # load preprocess + cfg = OmegaConf.load(model_cls.default_config_path(model_type)) + if cfg is not None: + preprocess_cfg = cfg.preprocess + + vis_processors, txt_processors = load_preprocess(preprocess_cfg) + else: + vis_processors, txt_processors = None, None + logging.info( + f"""No default preprocess for model {name} ({model_type}). + This can happen if the model is not finetuned on downstream datasets, + or it is not intended for direct use without finetuning. + """ + ) + + if device == "cpu" or device == torch.device("cpu"): + model = model.float() + + return model.to(device), vis_processors, txt_processors + + +class ModelZoo: + """ + A utility class to create string representation of available model architectures and types. + + >>> from minigpt4.models import model_zoo + >>> # list all available models + >>> print(model_zoo) + >>> # show total number of models + >>> print(len(model_zoo)) + """ + + def __init__(self) -> None: + self.model_zoo = { + k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) + for k, v in registry.mapping["model_name_mapping"].items() + } + + def __str__(self) -> str: + return ( + "=" * 50 + + "\n" + + f"{'Architectures':<30} {'Types'}\n" + + "=" * 50 + + "\n" + + "\n".join( + [ + f"{name:<30} {', '.join(types)}" + for name, types in self.model_zoo.items() + ] + ) + ) + + def __iter__(self): + return iter(self.model_zoo.items()) + + def __len__(self): + return sum([len(v) for v in self.model_zoo.values()]) + + +model_zoo = ModelZoo() diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd1d63679ccd0717c39dd96727dca9a85b549c57 --- /dev/null +++ b/minigpt4/models/base_model.py @@ -0,0 +1,248 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import logging +import contextlib + +from omegaconf import OmegaConf +import numpy as np +import torch +import torch.nn as nn +from transformers import BertTokenizer, LlamaTokenizer +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_int8_training, +) + +from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized +from minigpt4.common.utils import get_abs_path, is_url +from minigpt4.models.eva_vit import create_eva_vit_g + + + +class BaseModel(nn.Module): + """Base class for models.""" + + def __init__(self): + super().__init__() + + @property + def device(self): + return list(self.parameters())[-1].device + + def load_checkpoint(self, url_or_filename): + """ + Load from a finetuned checkpoint. + + This should expect no mismatch in the model keys and the checkpoint keys. + """ + + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + if "model" in checkpoint.keys(): + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + msg = self.load_state_dict(state_dict, strict=False) + + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + @classmethod + def from_pretrained(cls, model_type): + """ + Build a pretrained model from default configuration file, specified by model_type. + + Args: + - model_type (str): model type, specifying architecture and checkpoints. 
+ + Returns: + - model (nn.Module): pretrained or finetuned model, depending on the configuration. + """ + model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model + model = cls.from_config(model_cfg) + + return model + + @classmethod + def default_config_path(cls, model_type): + assert ( + model_type in cls.PRETRAINED_MODEL_CONFIG_DICT + ), "Unknown model type {}".format(model_type) + return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) + + def load_checkpoint_from_config(self, cfg, **kwargs): + """ + Load checkpoint as specified in the config file. + + If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. + When loading the pretrained model, each task-specific architecture may define its + own load_from_pretrained() method. + """ + load_finetuned = cfg.get("load_finetuned", True) + if load_finetuned: + finetune_path = cfg.get("finetuned", None) + assert ( + finetune_path is not None + ), "Found load_finetuned is True, but finetune_path is None." + self.load_checkpoint(url_or_filename=finetune_path) + else: + # load pre-trained weights + pretrain_path = cfg.get("pretrained", None) + assert ( + pretrain_path is not None + ), "Found load_finetuned is False, but pretrain_path is None." 
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) + + def before_evaluation(self, **kwargs): + pass + + def show_n_params(self, return_str=True): + tot = 0 + for p in self.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return "{:.1f}M".format(tot / 1e6) + else: + return "{:.1f}K".format(tot / 1e3) + else: + return tot + + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + @classmethod + def init_vision_encoder( + cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze + ): + logging.info('Loading VIT') + + assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" + if not freeze: + precision = "fp32" # fp16 is not for training + + visual_encoder = create_eva_vit_g( + img_size, drop_path_rate, use_grad_checkpoint, precision + ) + + ln_vision = LayerNorm(visual_encoder.num_features) + + if freeze: + for name, param in visual_encoder.named_parameters(): + param.requires_grad = False + visual_encoder = visual_encoder.eval() + visual_encoder.train = disabled_train + for name, param in ln_vision.named_parameters(): + param.requires_grad = False + ln_vision = ln_vision.eval() + ln_vision.train = disabled_train + logging.info("freeze vision encoder") + + logging.info('Loading VIT Done') + return visual_encoder, ln_vision + + def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0, + lora_target_modules=["q_proj","v_proj"], **lora_kargs): + logging.info('Loading LLAMA') + llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False) + llama_tokenizer.pad_token = "$$" + + if low_resource: + llama_model = 
LlamaForCausalLM.from_pretrained( + llama_model_path, + torch_dtype=torch.float16, + load_in_8bit=True, + device_map={'': low_res_device} + ) + else: + llama_model = LlamaForCausalLM.from_pretrained( + llama_model_path, + torch_dtype=torch.float16, + ) + + if lora_r > 0: + llama_model = prepare_model_for_int8_training(llama_model) + loraconfig = LoraConfig( + r=lora_r, + bias="none", + task_type="CAUSAL_LM", + target_modules=lora_target_modules, + **lora_kargs + ) + llama_model = get_peft_model(llama_model, loraconfig) + + llama_model.print_trainable_parameters() + + else: + for name, param in llama_model.named_parameters(): + param.requires_grad = False + logging.info('Loading LLAMA Done') + return llama_model, llama_tokenizer + + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + msg = self.load_state_dict(state_dict, strict=False) + + # logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + + + + diff --git a/minigpt4/models/eva_vit.py b/minigpt4/models/eva_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..7fcc63a74049f1faf65c99943ef94f72383ca3f5 --- /dev/null +++ b/minigpt4/models/eva_vit.py 
@@ -0,0 +1,442 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + +from minigpt4.common.dist_utils import download_cached_file + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + **kwargs + } + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commented out to match the original BERT implementation + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = 
torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + 
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class 
VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + use_mean_pooling=True, init_scale=0.001, use_checkpoint=False): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + self.use_checkpoint = use_checkpoint + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) +# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) +# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None +# self.head = nn.Linear(embed_dim, num_classes) 
if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) +# if isinstance(self.head, nn.Linear): +# trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() +# if isinstance(self.head, nn.Linear): +# self.head.weight.data.mul_(init_scale) +# self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, rel_pos_bias) + else: + x = blk(x, rel_pos_bias) + return x +# x = self.norm(x) + +# if self.fc_norm is not None: +# t = x[:, 1:, :] +# return self.fc_norm(t.mean(1)) +# else: +# return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) +# x = 
self.head(x) + return x + + def get_intermediate_layers(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + features.append(x) + + return features + + +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'].float() + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, 
nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + +# if isinstance(l, (nn.MultiheadAttention, Attention)): +# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: +# tensor = getattr(l, attr) +# if tensor is not None: +# tensor.data = tensor.data.half() + + model.apply(_convert_weights_to_fp16) + + +def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"): + model = VisionTransformer( + img_size=img_size, + patch_size=14, + use_mean_pooling=False, + embed_dim=1408, + depth=39, + num_heads=1408//88, + mlp_ratio=4.3637, + qkv_bias=True, + drop_path_rate=drop_path_rate, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + use_checkpoint=use_checkpoint, + ) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" + cached_file = download_cached_file( + url, check_hash=False, progress=True + ) + state_dict = torch.load(cached_file, map_location="cpu") + interpolate_pos_embed(model,state_dict) + + incompatible_keys = model.load_state_dict(state_dict, strict=False) +# print(incompatible_keys) + + if precision == "fp16": +# model.to("cuda") + convert_weights_to_fp16(model) + return model \ No newline at end of file diff --git a/minigpt4/models/minigpt4.py b/minigpt4/models/minigpt4.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e4798bb9713467b0ddac2dcec3cb1681c6418d --- /dev/null +++ b/minigpt4/models/minigpt4.py @@ -0,0 +1,195 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import disabled_train +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel + + +@registry.register_model("minigpt4") +class MiniGPT4(MiniGPTBase): + """ + 
MiniGPT-4 model + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml", + "pretrain_llama2": "configs/models/minigpt4_llama2.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + has_qformer=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + ): + super().__init__( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + ) + + self.has_qformer = has_qformer + if self.has_qformer: + print('Loading Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features, freeze_qformer + ) + self.load_from_pretrained(url_or_filename=q_former_model) # load q-former weights here + + img_f_dim = self.Qformer.config.hidden_size + print('Loading Q-Former Done') + else: + img_f_dim = self.visual_encoder.num_features * 4 + print('Do not use Q-Former here.') + + self.llama_proj = nn.Linear( + img_f_dim, self.llama_model.config.hidden_size + ) + + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt] + self.prompt_list = [prompt_template.format(p) for p in filted_prompts] + print('Load {} training 
prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = [] + + @classmethod + def init_Qformer(cls, num_query_token, vision_width, freeze): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = 2 + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + + Qformer.cls = None + Qformer.bert.embeddings.word_embeddings = None + Qformer.bert.embeddings.position_embeddings = None + for layer in Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + if freeze: + for name, param in Qformer.named_parameters(): + param.requires_grad = False + Qformer = Qformer.eval() + Qformer.train = disabled_train + query_tokens.requires_grad = False + logging.info("freeze Qformer") + + return Qformer, query_tokens + + def encode_img(self, image): + device = image.device + + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + if self.has_qformer: + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + else: + image_embeds = image_embeds[:, 1:, :] + bs, pn, hs = image_embeds.shape + image_embeds = image_embeds.view(bs, int(pn / 
4), int(hs * 4)) + + inputs_llama = self.llama_proj(image_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + has_qformer = cfg.get("has_qformer", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + device_8bit = cfg.get("device_8bit", 0) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 32) + end_sym = cfg.get("end_sym", '\n') + + model = cls( + vit_model=vit_model, + q_former_model=q_former_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + has_qformer=has_qformer, + freeze_qformer=freeze_qformer, + num_query_token=num_query_token, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py new file 
mode 100644 index 0000000000000000000000000000000000000000..77c919a4e9a11e09ca1299bf4a2643e42a1440da --- /dev/null +++ b/minigpt4/models/minigpt_base.py @@ -0,0 +1,401 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import BaseModel + + + +class MiniGPTBase(BaseModel): + """ + Base class for MiniGPT-4 and MiniGPT-v2 + """ + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + llama_model="", + max_txt_len=32, + max_context_len=3800, + prompt_template="", + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + lora_r=0, # lora_r=0 means lora is not used + lora_target_modules=["q_proj", "v_proj"], + lora_alpha=16, + lora_dropout=0.05, + ): + super().__init__() + + self.llama_model, self.llama_tokenizer = self.init_llm( + llama_model_path=llama_model, + low_resource=low_resource, + low_res_device=device_8bit, + lora_r=lora_r, + lora_target_modules=lora_target_modules, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit + ) + + self.max_txt_len = max_txt_len + self.max_context_len = max_context_len + self.end_sym = end_sym + + self.prompt_template = prompt_template + self.prompt_list = [] + + def vit_to_cpu(self): + self.ln_vision.to("cpu") + self.ln_vision.float() + self.visual_encoder.to("cpu") + self.visual_encoder.float() + + def get_context_emb(self, prompt, img_list): + device = img_list[0].device + prompt_segs = prompt.split('<ImageHere>') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." 
+ seg_tokens = [ + self.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] + + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None): + if prompts is None or len(prompts) == 0: + # prompts is not provided, just return the original image embedding + return img_embeds, atts_img + elif img_embeds is None: + # prompt is provided but there is no image embedding. return the prompt embedding in right padding + self.llama_tokenizer.padding_side = "right" + prompt_tokens = self.llama_tokenizer( + prompts, + return_tensors="pt", + padding="longest", + add_special_tokens=False + ).to(self.device) + prompt_embeds = self.embed_tokens(prompt_tokens.input_ids) + atts_prompt = prompt_tokens.attention_mask + return prompt_embeds, atts_prompt + else: + # return the multi-modal embedding in right padding + emb_lists = [] + if isinstance(prompts, str): + prompts = [prompts] * len(img_embeds) + + for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)): + pn = each_img_embed.shape[-2] + if lengths is not None: + each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1]) + each_img_embed = each_img_embed[:lengths[idx] * pn] + p_segs = each_prompt.split('<ImageHere>') + interleave_emb = [] + for idx, seg in enumerate(p_segs[:-1]): + p_tokens = self.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1)) + wrapped_emb = torch.cat(interleave_emb, dim=1) + p_tokens = self.llama_tokenizer( + 
add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1) + emb_lists.append(wrapped_emb) + + emb_lens = [emb.shape[1] for emb in emb_lists] + pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) + + max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len + wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone() + wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) + + for i, emb in enumerate(emb_lists): + length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len + wrapped_embs[i, :length] = emb[:, :length] + wrapped_atts[i, :length] = 1 + return wrapped_embs, wrapped_atts + + def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): + """ + Concatenate the batched input embedding and batched output embedding together. + Both the input and the output embedding should be right padded. 
+ """ + input_lens = [] + cat_embs = [] + cat_atts = [] + for i in range(input_embs.size(0)): + input_len = input_atts[i].sum() + input_lens.append(input_len) + cat_embs.append( + torch.cat([ + input_embs[i][:input_len], + output_embs[i], + input_embs[i][input_len:] + ]) + ) + cat_atts.append( + torch.cat([ + input_atts[i][:input_len], + output_atts[i], + input_atts[i][input_len:] + ]) + ) + cat_embs = torch.stack(cat_embs) + cat_atts = torch.stack(cat_atts) + return cat_embs, cat_atts, input_lens + + def tokenize_conversation(self, conv_q, conv_a): + """concatenate conversation and make sure the model is only trained to regress the answer""" + + to_regress_token_ids_list = [] + targets_list = [] + + batch_size = len(conv_q) + for batch_idx in range(batch_size): + questions, answers = conv_q[batch_idx], conv_a[batch_idx] + questions = [self.llama_tokenizer(q, + return_tensors="pt", + add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it + answers = [self.llama_tokenizer(q, + return_tensors="pt", + add_special_tokens=False).to(self.device) for q in answers] + cur_id = [] + cur_target = [] + for i in range(len(questions)): + cur_id.append(answers[i].input_ids) + cur_target.append(answers[i].input_ids) + cur_id.append(questions[i].input_ids) + cur_target.append(torch.ones_like(questions[i].input_ids) * -100) + + cur_id.append(answers[-1].input_ids) + cur_target.append(answers[-1].input_ids) + + cur_id = torch.cat(cur_id, dim=1) + cur_target = torch.cat(cur_target, dim=1) + to_regress_token_ids_list.append(cur_id) + targets_list.append(cur_target) + + max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len) + to_regress_token_ids = torch.ones([batch_size, max_len], + dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id + targets = torch.ones([batch_size, max_len], + dtype=cur_id.dtype, device=self.device) * -100 + for batch_idx in 
range(batch_size): + cur_len = to_regress_token_ids_list[batch_idx].shape[1] + to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len] + targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len] + + to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int) + + return to_regress_token_ids, to_regress_token_attn, targets + + def preparing_embedding(self, samples): + ### prepare input tokens + if 'image' in samples: + img_embeds, img_atts = self.encode_img(samples["image"]) + else: + img_embeds = img_atts = None + + if 'conv_q' in samples: + # handling conversation datasets + conv_q, conv_a = samples['conv_q'], samples['conv_a'] + + connect_sym = samples['connect_sym'][0] + conv_q = [q.split(connect_sym) for q in conv_q] + conv_a = [a.split(connect_sym) for a in conv_a] + + conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q] + + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q]) + regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a) + + else: + if "instruction_input" in samples: + instruction = samples["instruction_input"] + elif self.prompt_list: + instruction = random.choice(self.prompt_list) + else: + instruction = None + + if self.chat_template: + instruction = [self.prompt_template.format(instruct) for instruct in instruction] + + if 'length' in samples: + # the input is an image train (like videos) + bsz, pn, hs = img_embeds.shape + img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length']) + else: + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction) + + ### prepare target tokens + self.llama_tokenizer.padding_side = "right" + text = [t + self.end_sym for t in samples["answer"]] + + regress_tokens = self.llama_tokenizer( + text,
return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False + ).to(self.device) + + regress_token_ids = regress_tokens.input_ids + regress_atts = regress_tokens.attention_mask + part_targets = regress_token_ids.masked_fill( + regress_token_ids == self.llama_tokenizer.pad_token_id, -100 + ) + + regress_embeds = self.embed_tokens(regress_token_ids) + + return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets + + def forward(self, samples, reduction='mean'): + # prepare the embedding to condition and the embedding to regress + cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \ + self.preparing_embedding(samples) + + # concat the embedding to condition and the embedding to regress + inputs_embeds, attention_mask, input_lens = \ + self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) + + # get bos token embedding + bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id + bos_embeds = self.embed_tokens(bos) + bos_atts = cond_atts[:, :1] + + # add bos token at the beginning + inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_atts, attention_mask], dim=1) + + # assemble the final targets + targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], + dtype=torch.long).to(self.device).fill_(-100) + + for i, target in enumerate(part_targets): + targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos + + with self.maybe_autocast(): + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + reduction=reduction + ) + loss = outputs.loss + + return {"loss": loss} + + def embed_tokens(self, token_ids): + if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model + embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) + else: + embeds =
self.llama_model.base_model.embed_tokens(token_ids) + return embeds + + + @torch.no_grad() + def generate( + self, + images, + texts, + num_beams=1, + max_new_tokens=20, + min_length=1, + top_p=0.9, + repetition_penalty=1, + length_penalty=1, + temperature=1, + do_sample=False, + stop_words_ids=[2], + ): + ''' + function for generate test use + ''' + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( + stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) + + img_embeds, atts_img = self.encode_img(images.to(self.device)) + image_lists = [[image_emb[None]] for image_emb in img_embeds] + + batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] + + batch_size = len(batch_embs) + max_len = max([emb.shape[1] for emb in batch_embs]) + emb_dim = batch_embs[0].shape[2] + dtype = batch_embs[0].dtype + device = batch_embs[0].device + + embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) + attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) + for i, emb in enumerate(batch_embs): + emb_len = emb.shape[1] + embs[i, -emb_len:] = emb[0] + attn_mask[i, -emb_len:] = 1 + + with self.maybe_autocast(): + outputs = self.llama_model.generate( + inputs_embeds=embs, + attention_mask=attn_mask, + max_new_tokens=max_new_tokens, + num_beams=num_beams, + length_penalty=length_penalty, + temperature=temperature, + do_sample=do_sample, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty + # stopping_criteria=stopping_criteria, + ) + + answers = [] + for output_token in outputs: + if output_token[0] == 0: + output_token = output_token[1:] + output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_texts = output_texts.split('</s>')[0] # remove the stop sign </s> + output_texts = output_texts.replace("<s>", "") + output_texts = output_texts.split(r'[/INST]')[-1].strip() + answers.append(output_texts) + + return answers + +
@torch.no_grad() + def multi_select(self, images, texts, answers, num_cand=None): + all_losses = [] + for answer in answers: + choice_samples = { + 'image': images, + 'instruction_input': texts, + 'answer': answer + } + loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1) + all_losses.append(loss) + torch.cuda.empty_cache() + all_losses = torch.cat(all_losses, dim=-1) + if num_cand is not None: + for i in range(all_losses.shape[0]): + all_losses[i, num_cand[i]:] = 9999 + output_class_ranks = torch.argsort(all_losses, dim=-1) + return output_class_ranks.tolist() \ No newline at end of file diff --git a/minigpt4/models/minigpt_v2.py b/minigpt4/models/minigpt_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..a046b0baff41db50477e35904af9bcad5baa619c --- /dev/null +++ b/minigpt4/models/minigpt_v2.py @@ -0,0 +1,139 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import disabled_train +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel + + +@registry.register_model("minigpt_v2") +class MiniGPTv2(MiniGPTBase): + """ + MiniGPT-v2 model + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain": "configs/models/minigpt_v2.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=448, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + llama_model="", + prompt_template='[INST] {} [/INST]', + max_txt_len=300, + end_sym='\n', + lora_r=64, + lora_target_modules=["q_proj", "v_proj"], + lora_alpha=16, + lora_dropout=0.05, + chat_template=False, + use_grad_checkpoint_llm=False, + max_context_len=3800, + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. 
+ ): + super().__init__( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + max_txt_len=max_txt_len, + max_context_len=max_context_len, + end_sym=end_sym, + prompt_template=prompt_template, + low_resource=low_resource, + device_8bit=device_8bit, + lora_r=lora_r, + lora_target_modules=lora_target_modules, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + img_f_dim = self.visual_encoder.num_features * 4 + self.llama_proj = nn.Linear( + img_f_dim, self.llama_model.config.hidden_size + ) + self.chat_template = chat_template + + if use_grad_checkpoint_llm: + self.llama_model.gradient_checkpointing_enable() + + def encode_img(self, image): + device = image.device + + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + image_embeds = image_embeds[:, 1:, :] + bs, pn, hs = image_embeds.shape + image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4)) + + inputs_llama = self.llama_proj(image_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + img_size = cfg.get("image_size") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + low_resource = cfg.get("low_resource", False) + + prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]') + max_txt_len = cfg.get("max_txt_len", 300) + end_sym = cfg.get("end_sym", '\n') + + lora_r = cfg.get("lora_r", 64) + lora_alpha = cfg.get("lora_alpha", 16) + chat_template = 
cfg.get("chat_template", False) + + use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False) + max_context_len = cfg.get("max_context_len", 3800) + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + low_resource=low_resource, + end_sym=end_sym, + lora_r=lora_r, + lora_alpha=lora_alpha, + chat_template=chat_template, + use_grad_checkpoint_llm=use_grad_checkpoint_llm, + max_context_len=max_context_len, + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model diff --git a/minigpt4/models/modeling_llama.py b/minigpt4/models/modeling_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2802050ff5fe65b7bbef22a4b648c4109c90bc --- /dev/null +++ b/minigpt4/models/modeling_llama.py @@ -0,0 +1,111 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC +from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig + + +class LlamaForCausalLM(LlamaForCausalLMOrig): + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if reduction == "none": + loss = loss.view(logits.size(0), -1).mean(1) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git 
a/minigpt4/processors/__init__.py b/minigpt4/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e560eaa15f3266dbc1ffbca70bdc791901737a60 --- /dev/null +++ b/minigpt4/processors/__init__.py @@ -0,0 +1,33 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.processors.base_processor import BaseProcessor +from minigpt4.processors.blip_processors import ( + Blip2ImageTrainProcessor, + Blip2ImageEvalProcessor, + BlipCaptionProcessor, +) + +from minigpt4.common.registry import registry + +__all__ = [ + "BaseProcessor", + "Blip2ImageTrainProcessor", + "Blip2ImageEvalProcessor", + "BlipCaptionProcessor", +] + + +def load_processor(name, cfg=None): + """ + Example + + >>> processor = load_processor("alpro_video_train", cfg=None) + """ + processor = registry.get_processor_class(name).from_config(cfg) + + return processor diff --git a/minigpt4/processors/base_processor.py b/minigpt4/processors/base_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..39b33cdf8fcd97cfd3e4a5fbece6593357af9d41 --- /dev/null +++ b/minigpt4/processors/base_processor.py @@ -0,0 +1,26 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from omegaconf import OmegaConf + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + return cls() + + def build(self, **kwargs): + cfg = OmegaConf.create(kwargs) + + return self.from_config(cfg) diff --git a/minigpt4/processors/blip_processors.py b/minigpt4/processors/blip_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..fd26160ec96a8458cdac083d19c19695937a7a62 --- /dev/null +++ b/minigpt4/processors/blip_processors.py @@ -0,0 +1,141 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import re + +from minigpt4.common.registry import registry +from minigpt4.processors.base_processor import BaseProcessor +from minigpt4.processors.randaugment import RandomAugment +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +class BlipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean, std) + + +@registry.register_processor("blip_caption") +class BlipCaptionProcessor(BaseProcessor): + def __init__(self, prompt="", max_words=50): + self.prompt = prompt + self.max_words = max_words + + def __call__(self, caption): + caption = self.prompt + self.pre_caption(caption) + + return caption + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = 
OmegaConf.create() + + prompt = cfg.get("prompt", "") + max_words = cfg.get("max_words", 50) + + return cls(prompt=prompt, max_words=max_words) + + def pre_caption(self, caption): + caption = re.sub( + r"([.!\"()*#:;~])", + " ", + caption.lower(), + ) + caption = re.sub( + r"\s{2,}", + " ", + caption, + ) + caption = caption.rstrip("\n") + caption = caption.strip(" ") + + # truncate caption + caption_words = caption.split(" ") + if len(caption_words) > self.max_words: + caption = " ".join(caption_words[: self.max_words]) + + return caption + + +@registry.register_processor("blip2_image_train") +class Blip2ImageTrainProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + image_size, + scale=(min_scale, max_scale), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) + + +@registry.register_processor("blip2_image_eval") +class Blip2ImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg 
= OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls(image_size=image_size, mean=mean, std=std) \ No newline at end of file diff --git a/minigpt4/processors/randaugment.py b/minigpt4/processors/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..7034a49ad5fc63b97910790017432617ff4c6d7b --- /dev/null +++ b/minigpt4/processors/randaugment.py @@ -0,0 +1,398 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import cv2 +import numpy as np + +import torch + + +## aug functions +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + """ + same output as PIL.ImageOps.autocontrast + """ + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + """ + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + """ + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + 
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + """ + like PIL, rotate by degree, not radians + """ + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + """ + same output as PIL.ImageOps.solarize + """ + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + """ + same output as PIL.ImageEnhance.Color + """ + ## implementation according to PIL definition, quite slow + # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] + # out = blend(degenerate, img, factor) + # M = ( + # np.eye(3) * factor + # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1.
- factor) + # )[np.newaxis, np.newaxis, :] + M = np.float32( + [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] + ) * factor + np.float32([[0.114], [0.587], [0.299]]) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = ( + np.array([(el - mean) * factor + mean for el in range(256)]) + .clip(0, 255) + .astype(np.uint8) + ) + out = table[img] + return out + + +def brightness_func(img, factor): + """ + same output as PIL.ImageEnhance.Brightness + """ + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + """ + The differences between this result and PIL are all on the 4 boundaries; the center + areas are the same + """ + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + """ + same output as
PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def posterize_func(img, bits): + """ + same output as PIL.ImageOps.posterize + """ + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +### level to args +def enhance_level_to_args(MAX_LEVEL): + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1,) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * float(translate_const) + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / 
MAX_LEVEL) * 256) + return (level,) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level,) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + "Identity": identity_func, + "AutoContrast": autocontrast_func, + "Equalize": equalize_func, + "Rotate": rotate_func, + "Solarize": solarize_func, + "Color": color_func, + "Contrast": contrast_func, + "Brightness": brightness_func, + "Sharpness": sharpness_func, + "ShearX": shear_x_func, + "TranslateX": translate_x_func, + "TranslateY": translate_y_func, + "Posterize": posterize_func, + "ShearY": shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 +replace_value = (128, 128, 128) +arg_dict = { + "Identity": none_level_to_args, + "AutoContrast": none_level_to_args, + "Equalize": none_level_to_args, + "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), + "Solarize": solarize_level_to_args(MAX_LEVEL), + "Color": enhance_level_to_args(MAX_LEVEL), + "Contrast": enhance_level_to_args(MAX_LEVEL), + "Brightness": enhance_level_to_args(MAX_LEVEL), + "Sharpness": enhance_level_to_args(MAX_LEVEL), + "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), + "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "Posterize": posterize_level_to_args(MAX_LEVEL), + "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def 
get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img + + +class VideoRandomAugment(object): + def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): + self.N = N + self.M = M + self.p = p + self.tensor_in_tensor_out = tensor_in_tensor_out + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N, replace=False) + return [(op, self.M) for op in sampled_ops] + + def __call__(self, frames): + assert ( + frames.shape[-1] == 3 + ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." + + if self.tensor_in_tensor_out: + frames = frames.numpy().astype(np.uint8) + + num_frames = frames.shape[0] + + ops = num_frames * [self.get_random_ops()] + apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] + + frames = torch.stack( + list(map(self._aug, frames, ops, apply_or_not)), dim=0 + ).float() + + return frames + + def _aug(self, img, ops, apply_or_not): + for i, (name, level) in enumerate(ops): + if not apply_or_not[i]: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return torch.from_numpy(img) + + +if __name__ == "__main__": + a = RandomAugment() + img = np.random.randn(32, 32, 3) + a(img) diff --git a/minigpt4/runners/__init__.py b/minigpt4/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64e7a4d643a8b5a1714687f42d43347a94b72373 --- /dev/null +++ b/minigpt4/runners/__init__.py @@ -0,0 +1,10 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.runners.runner_base import RunnerBase + +__all__ = ["RunnerBase"] diff --git a/minigpt4/runners/runner_base.py b/minigpt4/runners/runner_base.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb5706fdb19222f51dd032ed723179f9ab19fc9 --- /dev/null +++ b/minigpt4/runners/runner_base.py @@ -0,0 +1,658 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import json +import logging +import os +import time +from pathlib import Path + +import torch +import torch.distributed as dist +import webdataset as wds +from minigpt4.common.dist_utils import ( + download_cached_file, + get_rank, + get_world_size, + is_main_process, + main_process, +) +from minigpt4.common.registry import registry +from minigpt4.common.utils import is_url +from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset +from minigpt4.datasets.datasets.dataloader_utils import ( + IterLoader, + MultiIterLoader, + PrefetchLoader, +) +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler + + +@registry.register_runner("runner_base") +class RunnerBase: + """ + A runner class to train and evaluate a model given a task and datasets. + + The runner uses pytorch distributed data parallel by default. Future release + will support other distributed frameworks. 
+ """ + + def __init__(self, cfg, task, model, datasets, job_id): + self.config = cfg + self.job_id = job_id + + self.task = task + self.datasets = datasets + + self._model = model + + self._wrapped_model = None + self._device = None + self._optimizer = None + self._scaler = None + self._dataloaders = None + self._lr_sched = None + + self.start_epoch = 0 + + # self.setup_seeds() + self.setup_output_dir() + + @property + def device(self): + if self._device is None: + self._device = torch.device(self.config.run_cfg.device) + + return self._device + + @property + def use_distributed(self): + return self.config.run_cfg.distributed + + @property + def model(self): + """ + A property to get the DDP-wrapped model on the device. + """ + # move model to device + if self._model.device != self.device: + self._model = self._model.to(self.device) + + # distributed training wrapper + if self.use_distributed: + if self._wrapped_model is None: + self._wrapped_model = DDP( + self._model, device_ids=[self.config.run_cfg.gpu] + ) + else: + self._wrapped_model = self._model + + return self._wrapped_model + + @property + def optimizer(self): + # TODO make optimizer class and configurations + if self._optimizer is None: + num_parameters = 0 + p_wd, p_non_wd = [], [] + for n, p in self.model.named_parameters(): + if not p.requires_grad: + continue # frozen weights + print(n) + if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: + p_non_wd.append(p) + else: + p_wd.append(p) + num_parameters += p.data.nelement() + logging.info("number of trainable parameters: %d" % num_parameters) + optim_params = [ + { + "params": p_wd, + "weight_decay": float(self.config.run_cfg.weight_decay), + }, + {"params": p_non_wd, "weight_decay": 0}, + ] + beta2 = self.config.run_cfg.get("beta2", 0.999) + self._optimizer = torch.optim.AdamW( + optim_params, + lr=float(self.config.run_cfg.init_lr), + weight_decay=float(self.config.run_cfg.weight_decay), + betas=(0.9, beta2), + ) + + return self._optimizer + + 
@property + def scaler(self): + amp = self.config.run_cfg.get("amp", False) + + if amp: + if self._scaler is None: + self._scaler = torch.cuda.amp.GradScaler() + + return self._scaler + + @property + def lr_scheduler(self): + """ + A property to get and create learning rate scheduler by split just in need. + """ + if self._lr_sched is None: + lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched) + + # max_epoch = self.config.run_cfg.max_epoch + max_epoch = self.max_epoch + # min_lr = self.config.run_cfg.min_lr + min_lr = self.min_lr + # init_lr = self.config.run_cfg.init_lr + init_lr = self.init_lr + + # optional parameters + decay_rate = self.config.run_cfg.get("lr_decay_rate", None) + warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1) + warmup_steps = self.config.run_cfg.get("warmup_steps", 0) + iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None) + + if iters_per_epoch is None: + try: + iters_per_epoch = len(self.dataloaders['train']) + except (AttributeError, TypeError): + iters_per_epoch = 10000 + + self._lr_sched = lr_sched_cls( + optimizer=self.optimizer, + max_epoch=max_epoch, + iters_per_epoch=iters_per_epoch, + min_lr=min_lr, + init_lr=init_lr, + decay_rate=decay_rate, + warmup_start_lr=warmup_start_lr, + warmup_steps=warmup_steps, + ) + + return self._lr_sched + + @property + def dataloaders(self) -> dict: + """ + A property to get and create dataloaders by split just in need. + + If no train_dataset_ratio is provided, concatenate map-style datasets and + chain wds.DataPipe datasets separately. Training set becomes a tuple + (ConcatDataset, ChainDataset), both are optional but at least one of them is + required. The resultant ConcatDataset and ChainDataset will be sampled evenly. + + If train_dataset_ratio is provided, create a MultiIterLoader to sample + each dataset by ratios during training. + + Currently do not support multiple datasets for validation and test. 
+ + Returns: + dict: {split_name: (tuples of) dataloader} + """ + if self._dataloaders is None: + + # concatenate map-style datasets and chain wds.DataPipe datasets separately + # training set becomes a tuple (ConcatDataset, ChainDataset), both are + # optional but at least one of them is required. The resultant ConcatDataset + # and ChainDataset will be sampled evenly. + logging.info( + "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)." + ) + + datasets = reorg_datasets_by_split(self.datasets) + self.datasets = datasets + # self.datasets = concat_datasets(datasets) + + # print dataset statistics after concatenation/chaining + for split_name in self.datasets: + if isinstance(self.datasets[split_name], tuple) or isinstance( + self.datasets[split_name], list + ): + # mixed wds.DataPipeline and torch.utils.data.Dataset + num_records = sum( + [ + len(d) + if not type(d) in [wds.DataPipeline, ChainDataset] + else 0 + for d in self.datasets[split_name] + ] + ) + + else: + if hasattr(self.datasets[split_name], "__len__"): + # a single map-style dataset + num_records = len(self.datasets[split_name]) + else: + # a single wds.DataPipeline + num_records = -1 + logging.info( + "Only a single wds.DataPipeline dataset, no __len__ attribute." 
+ ) + + if num_records >= 0: + logging.info( + "Loaded {} records for {} split from the dataset.".format( + num_records, split_name + ) + ) + + # create dataloaders + split_names = sorted(self.datasets.keys()) + + datasets = [self.datasets[split] for split in split_names] + is_trains = [split in self.train_splits for split in split_names] + + batch_sizes = [ + self.config.run_cfg.batch_size_train + if split == "train" + else self.config.run_cfg.batch_size_eval + for split in split_names + ] + + collate_fns = [] + for dataset in datasets: + if isinstance(dataset, tuple) or isinstance(dataset, list): + collate_fns.append([getattr(d, "collater", None) for d in dataset]) + else: + collate_fns.append(getattr(dataset, "collater", None)) + + dataloaders = self.create_loaders( + datasets=datasets, + num_workers=self.config.run_cfg.num_workers, + batch_sizes=batch_sizes, + is_trains=is_trains, + collate_fns=collate_fns, + ) + + self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)} + + return self._dataloaders + + @property + def cuda_enabled(self): + return self.device.type == "cuda" + + @property + def max_epoch(self): + return int(self.config.run_cfg.max_epoch) + + @property + def log_freq(self): + log_freq = self.config.run_cfg.get("log_freq", 50) + return int(log_freq) + + @property + def init_lr(self): + return float(self.config.run_cfg.init_lr) + + @property + def min_lr(self): + return float(self.config.run_cfg.min_lr) + + @property + def accum_grad_iters(self): + return int(self.config.run_cfg.get("accum_grad_iters", 1)) + + @property + def valid_splits(self): + valid_splits = self.config.run_cfg.get("valid_splits", []) + + if len(valid_splits) == 0: + logging.info("No validation splits found.") + + return valid_splits + + @property + def test_splits(self): + test_splits = self.config.run_cfg.get("test_splits", []) + + return test_splits + + @property + def train_splits(self): + train_splits = self.config.run_cfg.get("train_splits", []) + + if 
len(train_splits) == 0: + logging.info("Empty train splits.") + + return train_splits + + @property + def evaluate_only(self): + """ + Set to True to skip training. + """ + return self.config.run_cfg.evaluate + + @property + def use_dist_eval_sampler(self): + return self.config.run_cfg.get("use_dist_eval_sampler", True) + + @property + def resume_ckpt_path(self): + return self.config.run_cfg.get("resume_ckpt_path", None) + + @property + def train_loader(self): + train_dataloader = self.dataloaders["train"] + + return train_dataloader + + def setup_output_dir(self): + lib_root = Path(registry.get_path("library_root")) + + output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id + result_dir = output_dir / "result" + + output_dir.mkdir(parents=True, exist_ok=True) + result_dir.mkdir(parents=True, exist_ok=True) + + registry.register_path("result_dir", str(result_dir)) + registry.register_path("output_dir", str(output_dir)) + + self.result_dir = result_dir + self.output_dir = output_dir + + def train(self): + start_time = time.time() + best_agg_metric = 0 + best_epoch = 0 + + self.log_config() + + # resume from checkpoint if specified + if not self.evaluate_only and self.resume_ckpt_path is not None: + self._load_checkpoint(self.resume_ckpt_path) + + for cur_epoch in range(self.start_epoch, self.max_epoch): + # training phase + if not self.evaluate_only: + logging.info("Start training") + train_stats = self.train_epoch(cur_epoch) + self.log_stats(split_name="train", stats=train_stats) + + # evaluation phase + if len(self.valid_splits) > 0: + for split_name in self.valid_splits: + logging.info("Evaluating on {}.".format(split_name)) + + val_log = self.eval_epoch( + split_name=split_name, cur_epoch=cur_epoch + ) + if val_log is not None: + if is_main_process(): + assert ( + "agg_metrics" in val_log + ), "No agg_metrics found in validation log." 
+ + agg_metrics = val_log["agg_metrics"] + if agg_metrics > best_agg_metric and split_name == "val": + best_epoch, best_agg_metric = cur_epoch, agg_metrics + + self._save_checkpoint(cur_epoch, is_best=True) + + val_log.update({"best_epoch": best_epoch}) + self.log_stats(val_log, split_name) + + else: + # if no validation split is provided, we just save the checkpoint at the end of each epoch. + if not self.evaluate_only: + self._save_checkpoint(cur_epoch, is_best=False) + + if self.evaluate_only: + break + + if self.config.run_cfg.distributed: + dist.barrier() + + # testing phase + test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch + self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Training time {}".format(total_time_str)) + + def evaluate(self, cur_epoch="best", skip_reload=False): + test_logs = dict() + + if len(self.test_splits) > 0: + for split_name in self.test_splits: + test_logs[split_name] = self.eval_epoch( + split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload + ) + + return test_logs + + def train_epoch(self, epoch): + # train + self.model.train() + + return self.task.train_epoch( + epoch=epoch, + model=self.model, + data_loader=self.train_loader, + optimizer=self.optimizer, + scaler=self.scaler, + lr_scheduler=self.lr_scheduler, + cuda_enabled=self.cuda_enabled, + log_freq=self.log_freq, + accum_grad_iters=self.accum_grad_iters, + ) + + @torch.no_grad() + def eval_epoch(self, split_name, cur_epoch, skip_reload=False): + """ + Evaluate the model on a given split. + + Args: + split_name (str): name of the split to evaluate on. + cur_epoch (int or str): current epoch, or "best" to evaluate the best checkpoint. + skip_reload (bool): whether to skip reloading the best checkpoint. + During training, we will reload the best checkpoint for validation. 
+ During testing, we will use provided weights and skip reloading the best checkpoint. + """ + data_loader = self.dataloaders.get(split_name, None) + assert data_loader, "data_loader for split {} is None.".format(split_name) + + # TODO In validation, you need to compute loss as well as metrics + # TODO consider moving to model.before_evaluation() + model = self.unwrap_dist_model(self.model) + if not skip_reload and cur_epoch == "best": + model = self._reload_best_model(model) + model.eval() + + self.task.before_evaluation( + model=model, + dataset=self.datasets[split_name], + ) + results = self.task.evaluation(model, data_loader) + + if results is not None: + return self.task.after_evaluation( + val_result=results, + split_name=split_name, + epoch=cur_epoch, + ) + + def unwrap_dist_model(self, model): + if self.use_distributed: + return model.module + else: + return model + + def create_loaders( + self, + datasets, + num_workers, + batch_sizes, + is_trains, + collate_fns, + dataset_ratios=None, + ): + """ + Create dataloaders for training and validation. + """ + + def _create_loader(dataset, num_workers, bsz, is_train, collate_fn): + # create a single dataloader for each split + if isinstance(dataset, ChainDataset) or isinstance( + dataset, wds.DataPipeline + ): + # wds.WebDataset instances are chained together + # webdataset.DataPipeline has its own sampler and collate_fn + loader = iter( + DataLoader( + dataset, + batch_size=bsz, + num_workers=num_workers, + pin_memory=True, + ) + ) + else: + # map-style datasets are concatenated together + # setup distributed sampler + if self.use_distributed: + sampler = DistributedSampler( + dataset, + shuffle=is_train, + num_replicas=get_world_size(), + rank=get_rank(), + ) + if not self.use_dist_eval_sampler: + # e.g. 
retrieval evaluation + sampler = sampler if is_train else None + else: + sampler = None + + loader = DataLoader( + dataset, + batch_size=bsz, + num_workers=num_workers, + pin_memory=True, + sampler=sampler, + shuffle=sampler is None and is_train, + collate_fn=collate_fn, + drop_last=True if is_train else False, + ) + loader = PrefetchLoader(loader) + + if is_train: + loader = IterLoader(loader, use_distributed=self.use_distributed) + + return loader + + loaders = [] + + for dataset, bsz, is_train, collate_fn in zip( + datasets, batch_sizes, is_trains, collate_fns + ): + if isinstance(dataset, list) or isinstance(dataset, tuple): + if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None: + dataset_ratios = [d.sample_ratio for d in dataset] + loader = MultiIterLoader( + loaders=[ + _create_loader(d, num_workers, bsz, is_train, collate_fn[i]) + for i, d in enumerate(dataset) + ], + ratios=dataset_ratios, + ) + else: + loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn) + + loaders.append(loader) + + return loaders + + @main_process + def _save_checkpoint(self, cur_epoch, is_best=False): + """ + Save the checkpoint at the current epoch. 
+ """ + model_no_ddp = self.unwrap_dist_model(self.model) + param_grad_dic = { + k: v.requires_grad for (k, v) in model_no_ddp.named_parameters() + } + state_dict = model_no_ddp.state_dict() + for k in list(state_dict.keys()): + if k in param_grad_dic.keys() and not param_grad_dic[k]: + # delete parameters that do not require gradient + del state_dict[k] + save_obj = { + "model": state_dict, + "optimizer": self.optimizer.state_dict(), + "config": self.config.to_dict(), + "scaler": self.scaler.state_dict() if self.scaler else None, + "epoch": cur_epoch, + } + save_to = os.path.join( + self.output_dir, + "checkpoint_{}.pth".format("best" if is_best else cur_epoch), + ) + logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to)) + torch.save(save_obj, save_to) + + def _reload_best_model(self, model): + """ + Load the best checkpoint for evaluation. + """ + checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth") + + logging.info("Loading checkpoint from {}.".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + try: + model.load_state_dict(checkpoint["model"]) + except RuntimeError as e: + logging.warning( + """ + Key mismatch when loading checkpoint. This is expected if only part of the model is saved. + Trying to load the model with strict=False. + """ + ) + model.load_state_dict(checkpoint["model"], strict=False) + return model + + def _load_checkpoint(self, url_or_filename): + """ + Resume from a checkpoint. 
+ """ + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location=self.device) + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location=self.device) + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False) + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if self.scaler and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + self.start_epoch = checkpoint["epoch"] + 1 + logging.info("Resume checkpoint from {}".format(url_or_filename)) + + @main_process + def log_stats(self, stats, split_name): + if isinstance(stats, dict): + log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}} + with open(os.path.join(self.output_dir, "log.txt"), "a") as f: + f.write(json.dumps(log_stats) + "\n") + elif isinstance(stats, list): + pass + + @main_process + def log_config(self): + with open(os.path.join(self.output_dir, "log.txt"), "a") as f: + f.write(json.dumps(self.config.to_dict(), indent=4) + "\n") diff --git a/minigpt4/tasks/__init__.py b/minigpt4/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab1fb1c8289535cf9397bb9805c0cba3666ad26f --- /dev/null +++ b/minigpt4/tasks/__init__.py @@ -0,0 +1,26 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.common.registry import registry +from minigpt4.tasks.base_task import BaseTask +from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask + + +def setup_task(cfg): + assert "task" in cfg.run_cfg, "Task name must be provided." 
+ + task_name = cfg.run_cfg.task + task = registry.get_task_class(task_name).setup_task(cfg=cfg) + assert task is not None, "Task {} not properly registered.".format(task_name) + + return task + + +__all__ = [ + "BaseTask", + "ImageTextPretrainTask", +] diff --git a/minigpt4/tasks/base_task.py b/minigpt4/tasks/base_task.py new file mode 100644 index 0000000000000000000000000000000000000000..7ceee96bdf520f8d730651e815defd83b7ecfebb --- /dev/null +++ b/minigpt4/tasks/base_task.py @@ -0,0 +1,286 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import torch +import torch.distributed as dist +from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized +from minigpt4.common.logger import MetricLogger, SmoothedValue +from minigpt4.common.registry import registry +from minigpt4.datasets.data_utils import prepare_sample + + +class BaseTask: + def __init__(self, **kwargs): + super().__init__() + + self.inst_id_key = "instance_id" + + @classmethod + def setup_task(cls, **kwargs): + return cls() + + def build_model(self, cfg): + model_config = cfg.model_cfg + + model_cls = registry.get_model_class(model_config.arch) + return model_cls.from_config(model_config) + + def build_datasets(self, cfg): + """ + Build a dictionary of datasets, keyed by dataset name; each value holds that dataset's splits (e.g. 'train'). + Datasets and annotations are downloaded automatically if they do not exist. + + Args: + cfg (common.config.Config): configuration whose datasets_cfg lists the datasets to build. + + Returns: + dict: Dictionary mapping each dataset name to its torch.utils.data.Dataset objects by split. + """ + + datasets = dict() + + datasets_config = cfg.datasets_cfg + + assert len(datasets_config) > 0, "At least one dataset has to be specified." 
+ + for name in datasets_config: + dataset_config = datasets_config[name] + + builder = registry.get_builder_class(name)(dataset_config) + dataset = builder.build_datasets() + + dataset['train'].name = name + if 'sample_ratio' in dataset_config: + dataset['train'].sample_ratio = dataset_config.sample_ratio + + datasets[name] = dataset + + return datasets + + def train_step(self, model, samples): + loss = model(samples)["loss"] + return loss + + def valid_step(self, model, samples): + raise NotImplementedError + + def before_evaluation(self, model, dataset, **kwargs): + model.before_evaluation(dataset=dataset, task_type=type(self)) + + def after_evaluation(self, **kwargs): + pass + + def inference_step(self): + raise NotImplementedError + + def evaluation(self, model, data_loader, cuda_enabled=True): + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation" + # TODO make it configurable + print_freq = 10 + + results = [] + + for samples in metric_logger.log_every(data_loader, print_freq, header): + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + + eval_output = self.valid_step(model=model, samples=samples) + results.extend(eval_output) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return results + + def train_epoch( + self, + epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + cuda_enabled=False, + log_freq=50, + accum_grad_iters=1, + ): + return self._train_inner_loop( + epoch=epoch, + iters_per_epoch=lr_scheduler.iters_per_epoch, + model=model, + data_loader=data_loader, + optimizer=optimizer, + scaler=scaler, + lr_scheduler=lr_scheduler, + log_freq=log_freq, + cuda_enabled=cuda_enabled, + accum_grad_iters=accum_grad_iters, + ) + + def train_iters( + self, + epoch, + start_iters, + iters_per_inner_epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + cuda_enabled=False, + log_freq=50, + accum_grad_iters=1, + ): + return self._train_inner_loop( + epoch=epoch, + 
start_iters=start_iters, + iters_per_epoch=iters_per_inner_epoch, + model=model, + data_loader=data_loader, + optimizer=optimizer, + scaler=scaler, + lr_scheduler=lr_scheduler, + log_freq=log_freq, + cuda_enabled=cuda_enabled, + accum_grad_iters=accum_grad_iters, + ) + + def _train_inner_loop( + self, + epoch, + iters_per_epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + start_iters=None, + log_freq=50, + cuda_enabled=False, + accum_grad_iters=1, + ): + """ + An inner training loop compatible with both epoch-based and iter-based training. + + When using epoch-based, training stops after one epoch; when using iter-based, + training stops after #iters_per_epoch iterations. + """ + use_amp = scaler is not None + + if not hasattr(data_loader, "__next__"): + # convert to iterator if not already + data_loader = iter(data_loader) + + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) + + # if iter-based runner, schedule lr based on inner epoch. + logging.info( + "Start training epoch {}, {} iters per inner epoch.".format( + epoch, iters_per_epoch + ) + ) + header = "Train: data epoch: [{}]".format(epoch) + if start_iters is None: + # epoch-based runner + inner_epoch = epoch + else: + # In iter-based runner, we schedule the learning rate based on iterations. + inner_epoch = start_iters // iters_per_epoch + header = header + "; inner epoch [{}]".format(inner_epoch) + + for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): + # if using iter-based runner, we stop after iters_per_epoch iterations. 
+ if i >= iters_per_epoch: + break + + samples = next(data_loader) + + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + samples.update( + { + "epoch": inner_epoch, + "num_iters_per_epoch": iters_per_epoch, + "iters": i, + } + ) + + lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) + + with torch.cuda.amp.autocast(enabled=use_amp): + loss = self.train_step(model=model, samples=samples) + + # after_train_step() + if use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + # update gradients every accum_grad_iters iterations + if (i + 1) % accum_grad_iters == 0: + if use_amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # after train_epoch() + # gather the stats from all processes + metric_logger.synchronize_between_processes() + logging.info("Averaged stats: " + str(metric_logger.global_avg())) + return { + k: "{:.3f}".format(meter.global_avg) + for k, meter in metric_logger.meters.items() + } + + @staticmethod + def save_result(result, result_dir, filename, remove_duplicate=""): + import json + + result_file = os.path.join( + result_dir, "%s_rank%d.json" % (filename, get_rank()) + ) + final_result_file = os.path.join(result_dir, "%s.json" % filename) + + json.dump(result, open(result_file, "w")) + + if is_dist_avail_and_initialized(): + dist.barrier() + + if is_main_process(): + logging.warning("rank %d starts merging results." 
% get_rank()) + # combine results from all processes + result = [] + + for rank in range(get_world_size()): + result_file = os.path.join( + result_dir, "%s_rank%d.json" % (filename, rank) + ) + res = json.load(open(result_file, "r")) + result += res + + if remove_duplicate: + result_new = [] + id_list = [] + for res in result: + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + json.dump(result, open(final_result_file, "w")) + print("result file saved to %s" % final_result_file) + + return final_result_file diff --git a/minigpt4/tasks/image_text_pretrain.py b/minigpt4/tasks/image_text_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe8ec83a5dc95ee26a36e457feb394d18b7cd17 --- /dev/null +++ b/minigpt4/tasks/image_text_pretrain.py @@ -0,0 +1,18 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.common.registry import registry +from minigpt4.tasks.base_task import BaseTask + + +@registry.register_task("image_text_pretrain") +class ImageTextPretrainTask(BaseTask): + def __init__(self): + super().__init__() + + def evaluation(self, model, data_loader, cuda_enabled=True): + pass diff --git a/prompts/alignment.txt b/prompts/alignment.txt new file mode 100644 index 0000000000000000000000000000000000000000..38ae75a9cee293861f06544cbff6fdc4aa941d85 --- /dev/null +++ b/prompts/alignment.txt @@ -0,0 +1,4 @@ + Describe this image in detail. + Take a look at this image and describe what you notice. + Please provide a detailed description of the picture. + Could you describe the contents of this image for me? 
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a90cb3ff111a9e099654ca88162d81b2735192ac
--- /dev/null
+++ b/train.py
@@ -0,0 +1,103 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import argparse
+import os
+import random
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+
+import minigpt4.tasks as tasks
+from minigpt4.common.config import Config
+from minigpt4.common.dist_utils import get_rank, init_distributed_mode
+from minigpt4.common.logger import setup_logger
+from minigpt4.common.optims import (
+    LinearWarmupCosineLRScheduler,
+    LinearWarmupStepLRScheduler,
+)
+from minigpt4.common.registry import registry
+from minigpt4.common.utils import now
+
+# imports modules for registration
+from minigpt4.datasets.builders import *
+from minigpt4.models import *
+from minigpt4.processors import *
+from minigpt4.runners import *
+from minigpt4.tasks import *
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Training")
+
+    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+    parser.add_argument(
+        "--options",
+        nargs="+",
+        help="override some settings in the used config, the key-value pair "
+        "in xxx=yyy format will be merged into config file (deprecate), "
+        "change to --cfg-options instead.",
+    )
+
+    args = parser.parse_args()
+    # if 'LOCAL_RANK' not in os.environ:
+    #     os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    return args
+
+
+def setup_seeds(config):
+    seed = config.run_cfg.seed + get_rank()
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    cudnn.benchmark = False
+    cudnn.deterministic = True
+
+
+def get_runner_class(cfg):
+    """
+    Get runner class from config. Default to epoch-based runner.
+    """
+    runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
+
+    return runner_cls
+
+
+def main():
+    # allow auto-dl completes on main process without timeout when using NCCL backend.
+    # os.environ["NCCL_BLOCKING_WAIT"] = "1"
+
+    # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
+    job_id = now()
+
+    cfg = Config(parse_args())
+
+    init_distributed_mode(cfg.run_cfg)
+
+    setup_seeds(cfg)
+
+    # set after init_distributed_mode() to only log on master.
+    setup_logger()
+
+    cfg.pretty_print()
+
+    task = tasks.setup_task(cfg)
+    datasets = task.build_datasets(cfg)
+    model = task.build_model(cfg)
+
+    runner = get_runner_class(cfg)(
+        cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
+    )
+    runner.train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f3981b8ac9a8e08c14b56b0ace366edcee5f88a9
--- /dev/null
+++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
@@ -0,0 +1,55 @@
+model:
+  arch: minigpt4
+  model_type: pretrain_llama2
+
+
+datasets:
+  laion:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+    sample_ratio: 115
+  cc_sbu:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+    sample_ratio: 14
+
+
+run:
+  task: image_text_pretrain
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 1e-4
+  min_lr: 8e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 4
+  batch_size_train: 64
+  batch_size_eval: 64
+  num_workers: 4
+  warmup_steps: 5000
+  iters_per_epoch: 5000
+
+  seed: 42
+  output_dir: "output/minigpt4_stage1_pretrain"
+
+  amp: True
+  resume_ckpt_path: null
+
+  evaluate: False
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
\ No newline at end of file
diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/train_configs/minigpt4_llama2_stage2_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fa2b578ed86a433682d1c7e620f90ab3dd3e2209
--- /dev/null
+++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml
@@ -0,0 +1,50 @@
+model:
+  arch: minigpt4
+  model_type: pretrain_llama2
+
+  max_txt_len: 160
+  end_sym: ""
+  prompt_path: "prompts/alignment.txt"
+  prompt_template: '[INST] {} [/INST] '
+  ckpt: '/path/to/stage1/checkpoint/'
+
+
+datasets:
+  cc_sbu_align:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+
+run:
+  task: image_text_pretrain
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 3e-5
+  min_lr: 1e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 5
+  iters_per_epoch: 200
+  batch_size_train: 12
+  batch_size_eval: 12
+  num_workers: 4
+  warmup_steps: 200
+
+  seed: 42
+  output_dir: "output/minigpt4_stage2_finetune"
+
+  amp: True
+  resume_ckpt_path: null
+
+  evaluate: False
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
\ No newline at end of file
diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..be87b77cd2df109685f36c11d6ed7813f47a17c9
--- /dev/null
+++ b/train_configs/minigpt4_stage1_pretrain.yaml
@@ -0,0 +1,55 @@
+model:
+  arch: minigpt4
+  model_type: pretrain_vicuna0
+
+
+datasets:
+  laion:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+    sample_ratio: 115
+  cc_sbu:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+    sample_ratio: 14
+
+
+run:
+  task: image_text_pretrain
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 1e-4
+  min_lr: 8e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 4
+  batch_size_train: 64
+  batch_size_eval: 64
+  num_workers: 4
+  warmup_steps: 5000
+  iters_per_epoch: 5000
+
+  seed: 42
+  output_dir: "output/minigpt4_stage1_pretrain"
+
+  amp: True
+  resume_ckpt_path: null
+
+  evaluate: False
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
\ No newline at end of file
diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..404dfd6a0552c31cf6c59e3ee82333676407933c
--- /dev/null
+++ b/train_configs/minigpt4_stage2_finetune.yaml
@@ -0,0 +1,50 @@
+model:
+  arch: minigpt4
+  model_type: pretrain_vicuna0
+
+  max_txt_len: 160
+  end_sym: "###"
+  prompt_path: "prompts/alignment.txt"
+  prompt_template: '###Human: {} ###Assistant: '
+  ckpt: '/path/to/stage1/checkpoint/'
+
+
+datasets:
+  cc_sbu_align:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+
+run:
+  task: image_text_pretrain
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 3e-5
+  min_lr: 1e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 5
+  iters_per_epoch: 200
+  batch_size_train: 12
+  batch_size_eval: 12
+  num_workers: 4
+  warmup_steps: 200
+
+  seed: 42
+  output_dir: "output/minigpt4_stage2_finetune"
+
+  amp: True
+  resume_ckpt_path: null
+
+  evaluate: False
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
\ No newline at end of file
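
The `run` sections of the training configs above select `lr_sched: "linear_warmup_cosine_lr"` and set `init_lr`, `min_lr`, `warmup_lr`, `warmup_steps`, `iters_per_epoch`, and `max_epoch`. To give a feel for how those values interact, here is a minimal sketch of a warmup-plus-cosine schedule. It assumes a linear ramp from `warmup_lr` to `init_lr` over `warmup_steps`, followed by cosine decay to `min_lr` over the remaining steps. `lr_at_step` is a hypothetical helper written for illustration; it is not the repository's `LinearWarmupCosineLRScheduler`, whose exact behavior may differ. The defaults mirror the stage-two finetune config.

```python
import math


def lr_at_step(
    step,
    *,
    init_lr=3e-5,
    min_lr=1e-5,
    warmup_lr=1e-6,
    warmup_steps=200,
    iters_per_epoch=200,
    max_epoch=5,
):
    """Hypothetical helper: learning rate at a given global step.

    Assumption: linear warmup, then cosine decay over the steps that remain
    after warmup. This is an illustrative sketch, not the repository's scheduler.
    """
    total_steps = iters_per_epoch * max_epoch
    if step < warmup_steps:
        # Linear ramp from warmup_lr toward init_lr.
        return warmup_lr + (init_lr - warmup_lr) * step / max(warmup_steps, 1)
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    # Cosine decay from init_lr (progress = 0) down to min_lr (progress = 1).
    return min_lr + (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    # Print the schedule at a few points of the stage-two run (1000 steps in total).
    for step in (0, 100, 200, 600, 1000):
        print(step, lr_at_step(step))
```

Passing the stage-one values (`init_lr=1e-4`, `min_lr=8e-5`, `warmup_steps=5000`, `iters_per_epoch=5000`, `max_epoch=4`) shows that stage one decays over a much narrower range than stage two under the same assumptions.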