diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..eec0272e28ace9b0debcba546480e2ace1fe2f20 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.mov filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.pdf filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5edeb811648e63ed12f0ed009b0118b1d055d414 --- /dev/null +++ b/.gitignore @@ -0,0 +1,66 @@ +# local # +tmp*/ +cache/* +*/cache*/ +tmp*.py +tmp* +*pickle +data/ + +# Zip Files/Packages # +*.7z +*.dmg +*.gz +*.iso +*.jar +*.rar +*.tar +*.zip + +# Logs and databases # +*.log +*.sql +*.sqlite +.ipynb_checkpoints/ +*.swp +*.vscode/ +*.idea/ +*.pyc +__pycache__ +slurm*out + +# OS files # +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + + +.vim-arsync +scratch.norg +sync_to_red.sh + +anno/ +wandb/ +logs/ +accelerate_config/ +*.pth +hf_* + +# local folders +MODELS +DATAS +SAVED +EXPERIMENTS +REMOTE_HF +TEST + +test_results +test_training +test_hdfs.py +magic_video_outputs/llava* +magic_video_outputs +pllava_video_outputs/ \ No newline at end of file diff --git a/DATA.md b/DATA.md new file mode 100644 index 0000000000000000000000000000000000000000..cf7783763fa362e2c8e57fe78fbf354c3261485d --- /dev/null +++ b/DATA.md @@ -0,0 +1,124 @@ +# Data +## Instruction Training Data + + + +For training, we leveraged the video instruction tuning data from [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2). + +#### 1. Download json annotation files from huggingface. +[![Dataset meta](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-VideoChat2%20IT-blue)](https://huggingface.co/datasets/OpenGVLab/VideoChat2-IT) + + + +#### 2. Download the raw videos from the following links. +The video directories can be found in tasks/train/instruction_data.py. You can also change them to your own saved paths. + +- [VideoChat](https://github.com/OpenGVLab/InternVideo/tree/main/Data/instruction_data): Based on [InternVid](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid), download the processed version directly [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/videochat2_conversation_videos.zip) +- [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data) +- [Kinetics-710](https://github.com/OpenGVLab/UniFormerV2/blob/main/DATASET.md), download Kinetics 400/600/700 [here](https://openxlab.org.cn/datasets?keywords=kinetics). +- [SthSthV2](https://developer.qualcomm.com/software/ai-datasets/something-something): Option candidates were generated from [UMT](https://github.com/OpenGVLab/unmasked_teacher) top-20 predictions. +- [NExTQA](https://github.com/doc-doc/NExT-QA) +- [CLEVRER](https://clevrer.csail.mit.edu/) +- [WebVid](https://maxbain.com/webvid-dataset/) +- [YouCook2](https://youcook2.eecs.umich.edu/), download the processed version [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/youcook_split_videos.zip). +- [TextVR](https://github.com/callsys/textvr) +- [TGIF](https://github.com/YunseokJANG/tgif-qa) +- [EgoQA](https://ego4d-data.org/), download the processed version [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/egoqa_split_videos.zip). + +#### 3. We also provide our processed json annotation files here. + +[![Dataset meta](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-magic%5Fjsons-blue)](https://huggingface.co/datasets/cathyxl/magic_jsons) + + + + +## Evaluation Data & Others +Follow this section to obtain the evaluation open resources. + +### VCGBench + +We refer to the VideoChatGPT video question answering evaluation as VCGBench in this repo. We followed the original [repo](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main) to prepare the evaluation data. + +### MVBench +We follow the original [Videochat2 repo](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2) in setting up the MVBench Evaluation. You can also find helpful resources at their [huggingface repo](https://huggingface.co/datasets/OpenGVLab/MVBench) + + +### Videoqabench +We refer to all other video question answering benchmarks as videoqabench in this repo. They are mainly prepared folloing the original repos. Each listed: +1. [MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) & [MSRVTT](https://github.com/xudejing/video-question-answering) + +3. [Activity Net](https://github.com/MILVLG/activitynet-qa/tree/master) +4. [TGIF](https://github.com/raingo/TGIF-Release/tree/master) + +Also other fantastic repo intergrating these benchmarks are helpful in the process of setting up the evaluation data: +- [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main) +- [VideoLlava](https://github.com/PKU-YuanGroup/Video-LLaVA/tree/main/videollava) +- [IG-VLM](https://github.com/imagegridworth/IG-VLM/tree/main) + + + +### Recaptioning +#### Inter4k + +This is a dataset with 1000 samples of high resolution videos. We prepare the data folloing the instructions from their [official website](https://alexandrosstergiou.github.io/datasets/Inter4K/index.html) + +#### Extending Reacptioning +The recaptioning part is designed to be extendable. + +inference script [tasks/eval/recaption/pllava_recaption.py](tasks/eval/recaption/pllava_recaption.py) would use a dataset class [RecaptionDataset](tasks/eval/recaption/__init__.py#L197). The detailed information is kept in the data_list_info attribute as: +``` +data_list_info = OrderedDict({ + # "Panda70M": OrderedDict( + # json_relpath="Panda70M/annotations.json", + # prefix="DATAS/Recaption/Panda70M/videos", + # data_type="video", + # bound=False, + # key_rename_map={ + # # 'caption': 'hint', + # }, + # name_key='video_name', + # postfix=('mp4', 'mkv', 'webm'), + # recaption_type=RecaptionSample, + # ), # don't has start & end + "Inter4K": OrderedDict( + json_relpath="Inter4K/annotations.json", + prefix="DATAS/Recaption/Inter4K/60fps/UHD", + data_type="video", + bound=False, + key_rename_map={ + # 'caption': 'hint', + }, + name_key='video_name', + postfix=('mp4', 'mkv', 'webm'), + recaption_type=CaptionSample, + ), # don't has start & end + }) +``` +It contains the path to a annotation json file where there is a list and each item of the list is a sample waiting for captioning. For example, the Inter4K/annotations.json is like: +```json +[ + { + "video_name": "973" + }, + ... +] +``` +and the directory DATAS/Recaption/Inter4K/60fps/UHD would look like: +``` +$ ls DATAS/Recaption/Inter4K/60fps/UHD +1.mp4 134.mp4 170.mp4 .... +``` + +Naively, only the video is needed when captioning directly, therefore the annotation file only needs to contain the names of each video under the "prefix" directory. + +Extending a dataset for captioning would consist of the folloing steps: +1. have all the videos downloaded +2. construct a annotation.json file with sepecific format. +3. configure the recaption dataset [here](tasks/eval/recaption/__init__.py#L197), where you would need to determine: + - json_relpath: the annotation relative path + - prefix: root directory for videos + - postfix: a list containing all the file extensions for these videos + +The other options are experimental, so stick with the default setting as in Inter4k. The recommended length of video is around 5-20 seconds. + +p.s. "bound" is to make sure the video pass to the model doesn't have scene transition or so. This part wasn't tested, so set the bound to false and make sure the original videos files are single clip of a video. But always feel free to discover and contribute to PLLaVA! \ No newline at end of file diff --git a/README.md b/README.md index d4d6aabb57207c1a5fd98dfc453134d181f91b46..c14552fd3f0c263ff23e5acb724216e40fc56a58 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,383 @@ --- -title: Pllava 7b Demo -emoji: 🌖 +title: Plava 7b Demo +emoji: 👁 colorFrom: blue -colorTo: blue +colorTo: yellow sdk: gradio -sdk_version: 4.28.3 +sdk_version: 4.27.0 app_file: app.py pinned: false --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +
+ +

PLLaVA: Parameter-free LLaVA Extension from Images to Videos for Video Dense Captioning

+ +[Lin Xu](https://scholar.google.com/citations?user=_Gu69coAAAAJ), [Yilin Zhao](https://ermu2001.github.io/me.io/), [Daquan Zhou](https://scholar.google.com/citations?user=DdCAbWwAAAAJ), [Zhijie Lin](https://scholar.google.com/citations?user=xXMj6_EAAAAJ), [See-Kiong Ng](https://scholar.google.com/citations?user=_wsommYAAAAJ), [Jiashi Feng](https://scholar.google.com.sg/citations?user=Q8iay0gAAAAJ&hl=en) + +
+ + + +**Project Page: [PLLaVA](https://pllava.github.io/)** + +[![arXiv](https://img.shields.io/badge/arXiv-2404.16994-b31b1b.svg)](https://arxiv.org/abs/2404.16994) +[![YouTube Video](https://img.shields.io/badge/YouTube-Video-red)](https://www.youtube.com/watch?v=nAEje8tu18U) +[![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-34b) + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-activitynet)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-msrvtt-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-msvd-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-question-answering-on-mvbench)](https://paperswithcode.com/sota/video-question-answering-on-mvbench?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-tgif-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-tgif-qa?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-4)](https://paperswithcode.com/sota/video-based-generative-performance-4?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-3)](https://paperswithcode.com/sota/video-based-generative-performance-3?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance)](https://paperswithcode.com/sota/video-based-generative-performance?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-2)](https://paperswithcode.com/sota/video-based-generative-performance-2?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-1)](https://paperswithcode.com/sota/video-based-generative-performance-1?p=pllava-parameter-free-llava-extension-from-1) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-5)](https://paperswithcode.com/sota/video-based-generative-performance-5?p=pllava-parameter-free-llava-extension-from-1) + + + + + + + +![]() +
+ + + +
+ +
+
+ + + + + + +## Overview + +Welcome to PLLAVA! + +The primary purpose of this repository is to support research and the development of prototype models. It is designed to facilitate ease of experimentation and enable a clear overview of results. Please note that this repo is currently undergoing development and reconstruction. + +It's important to mention that we have not optimized the response speed of the application or the frontend logic. Our goal is to maintain simplicity, clarity, and ease of development, making it accessible for both researchers and students. If you have suggestions or want to enhance the application's performance, please feel free to contact us or contribute to the project. + + +We've briefly introduce our work in section [PLLAVA](#%EF%B8%8F-pllava). For more details, feel free to read our paper. Check out section [Usage](#hammer-usage) to start using this repo. If you felt our works interesting, please star us, your support is all we want. If you find our work helpful, feel free to [cite](#page_facing_up-citation) us directly. + +## :fire: Updates + +- **2024/4/24**: Release: + - We are releasing our code/models/datasets. + +## 🏖️ PLLAVA +
+ + + +
+ + +### Abstract + +Vision-language pre-training (VLP) has significantly elevated performance across a range of vision-language applications. Yet, the pre-training process for video-related tasks demands an exceptionally high degree of computational and data resources. This paper investigates a straightforward, highly efficient, and resource-light approach to adapting an existing image-language pre-training model for video data. Our preliminary experiments reveal that directly fine-tuning pre-trained image-language models with multiple frames on video datasets leads to performance saturation or even a drop in caption-related tasks. Besides, it is also vulnerable to prompts and tends to provide short descriptions. We conducted a deep analysis and observed that the performance saturation and the vulnerability might be related to the dominant patches that exist in some single video patches. We then propose a simple pooling strategy to smooth the feature distribution along the temporal dimension and thus reduce the dominant impacts from some extreme tokens. The new model is termed Pooling LLaVA, or PLLaVA in short. With the proposed pooling strategy, we achieve new state-of-the-art performance on all evaluated datasets. Notably, on the recent popular Video ChatGPT benchmark, PLLaVA achieves a score of 3.48 out of 5 on average of five evaluated dimensions, which is the new state-of-the-art score on the leaderboard and is 0.31 higher than the previous SOTA results from GPT4V (IG-VLM). On the latest multi-choice benchmark MVBench, PLLaVA achieves 58.1% accuracy on average across 20 sub-tasks, which is the new state-of-the-art result and is 14.5% higher than GPT4V (IG-VLM). + +
+ + +### SEARCHING FOR OPTIMAL POOLING STRATEGY +There are two dimensions for the pooling strategy: the spatial dimension and the temporal dimension. We empirically found that reducing the spatial dimension with a larger temporal dimension could lead to better model performance, compared to reducing the temporal dimension directly. + +
+ + +### STATE-OF-THE-ART PERFORMANCE +We compare the performance of PLLAVA with recent popular methods over both question-answer and captioning datasets. The results are shown below. + +
+ +## :hammer: Usage + +This section provides guidance on how to run, train, and evaluate our models. + +### Install +First, you will need to set up the environment and download some pre-trained weights. + +This repo is built up using [transformers](https://github.com/huggingface/transformers) for model construction along with [accelerate](https://github.com/huggingface/accelerate) for distributed training. Follow the instructions to install the needed environment. + +1. Above all, the following environment set up is for python 3.10. If you choose to use conda for environment setup, we recommend creating the virtual environment with: +```bash +conda create -n pllava python=3.10 +``` + +1. Firstly, install [pytorch](https://pytorch.org/) from the official website. The code runs on torch 2.2.1, cu118 or cu122. Select the version that suits your drive version. + +``` +torch 2.2.1+cu118 +torchaudio 2.2.1+cu118 +torchvision 0.17.1+cu118 +``` + +If your driver version is higher than cu121, you could probably try installing with the following scripts: +```bash +pip install -r requirements.txt +``` + +Otherwise, you would need to install a torch for your server first, then install the other packages: +```bash +pip install -r requirements.torch.txt # decide your own requirements, (this is for cu11), or install torch directly following the official website. +pip install -r requirements.no_torch.txt # install the following +``` + +1. Prepare the model. +We prefer to have huggingface models explicitly downloaded to a MODELS directory. However, if you are familiar with huggingface-hub usage, feel free to organize the model yourself. +``` +python python_scripts/hf.py +``` + +Here are some detailed information of the obtained models: + + +| Model | Link | Initialized From | +| ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------- | +| pllava-7b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-7b) | [llava-hf/llava-v1.6-vicuna-7b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf) | +| pllava-13b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-13b) | [llava-hf/llava-v1.6-vicuna-13b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) | +| pllava-34b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-34b) | [llava-hf/llava-v1.6-34b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) | + +The model directory should look like this, where you would only need the corresponding model's weights and directory. + +``` +$ tree MODELS +MODELS +|-- pllava-13b +| |-- added_tokens.json +| |-- config.json +| |-- generation_config.json +| |-- model-00001-of-00006.safetensors +| |-- model-00002-of-00006.safetensors +| |-- model-00003-of-00006.safetensors +| |-- model-00004-of-00006.safetensors +| |-- model-00005-of-00006.safetensors +| |-- model-00006-of-00006.safetensors +| |-- model.safetensors.index.json +| |-- preprocessor_config.json +| |-- processor_config.json +| |-- special_tokens_map.json +| |-- tokenizer.json +| |-- tokenizer.model +| `-- tokenizer_config.json +|-- pllava-34b +| |-- added_tokens.json +| |-- config.json +| |-- generation_config.json +| |-- model-00001-of-00015.safetensors +| |-- model-00002-of-00015.safetensors +| |-- model-00003-of-00015.safetensors +| |-- model-00004-of-00015.safetensors +| |-- model-00005-of-00015.safetensors +| |-- model-00006-of-00015.safetensors +| |-- model-00007-of-00015.safetensors +| |-- model-00008-of-00015.safetensors +| |-- model-00009-of-00015.safetensors +| |-- model-00010-of-00015.safetensors +| |-- model-00011-of-00015.safetensors +| |-- model-00012-of-00015.safetensors +| |-- model-00013-of-00015.safetensors +| |-- model-00014-of-00015.safetensors +| |-- model-00015-of-00015.safetensors +| |-- model.safetensors-deprecated +| |-- model.safetensors.index.json +| |-- preprocessor_config.json +| |-- processor_config.json +| |-- special_tokens_map.json +| |-- tokenizer.json +| |-- tokenizer.model +| `-- tokenizer_config.json +|-- pllava-7b + |-- added_tokens.json + |-- config.json + |-- generation_config.json + |-- model-00001-of-00003.safetensors + |-- model-00002-of-00003.safetensors + |-- model-00003-of-00003.safetensors + |-- model.safetensors.index.json + |-- preprocessor_config.json + |-- processor_config.json + |-- special_tokens_map.json + |-- tokenizer.json + |-- tokenizer.model + `-- tokenizer_config.json +``` + +With the above steps, you should be able to proceed on with the following usages. + +### Run Application + +To run our models, make sure you have downloaded a model pretrained weights from the huggingface spaces. Then, run the following scripts with the corresponding path input. Since we are only training with lora and the projector, the model to be run are determined with: + +- **model_dir**: model directory, one with config.json as compatible with transformers. This refers to the base model's directory, for example "llava-hf/llava-v1.6-vicuna-7b-hf"/"ermu2001/pllava-7b"/"MODELS/pllava-7b". (default to: MODELS/plave-7b) +- **weights_dir**: your weights directory. could be the same as model_dir, but if you have a weights directory for the lora weights, you should set this weights_dir to that directory to load the lora weights. This directory should be local. Also, it would need to contain a config.json file within. (default to: ${model_dir}). + +```bash +model_dir="model directory" +weights_dir="weights directory" +bash scripts/demo.sh ${model_dir} ${weights_dir} +``` + +Now check out the application demo and try play with PLLAVA! + +### Train + +Follow the following steps to reproduce our results or train your own variant: + +#### 1. Data Preparation + +To train our model from a starting Image-aligned Vision LLM, you would need to download the data first. Our data set up is mainly based on the original Videochat2's training data. Check out [Instruction Data](./DATA.md) to prepare the instruction training data. Ideally, setting up a root data directory and alter the code [here](./tasks/train/instruction_data.py#L6) would accomodate the data for training most smoothly. + +#### 2. Start Training + +Now you're only a few step away from starting the training. Follow the instructions: + +##### Setup Accelerator + +Customize a accelerate training config. For example, a simple config using multiple gpus with no distribution strategy (only torch DDP) would look like: + +```yaml +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +``` + +Check out out the [Accelerate](https://huggingface.co/docs/accelerate/index) documents for more details. + +##### Overwatch the training configuration + +Next, you should go over a basic training configuration of the training process in [here](tasks/train/config_pllava_nframe.py). Then passing this file as the first arg to the [training script](tasks/train/train_pllava_nframe_accel.py) would utilize every arguments in the file. You can customize some of the hyper parameters for your own training process by passing them in the format of "key" "value" pair in the following arguments. A example training scripts could be find [here](scripts/train_pllava.sh). + +We recommand customize a [configuration](tasks/train/config_pllava_nframe.py) to set up a customized training! + +With the above steps, you would be able to start the training process. The output would be well organized in the output directory, each a qualified model directory to pass in to demo as weights_dir, since we are only saveing the lora weights and projector weights to avoide redundancy. + +### Evaluation + +This section mainly introduce how to reproduce the evaluation or evaluate your own model. + +#### Set up Evaluation Data + +Make sure you set up the "DATAS" directory as in [DATA.md](DATA.md), then you would be able to run the inference with fortune! The evaluation data directory of DATAS would look like: + +``` +DATAS/: +DATAS/VideoQA: +DATAS/VideoQA/TGIF_QA: + test_a.json + test_q.json +DATAS/VideoQA/TGIF_QA/videos: + tumblr_m4387mGrlc1r6m5e8o1_250.gif + ... +DATAS/VideoQA/TGIF_QA/videos_mp4: + tumblr_m4387mGrlc1r6m5e8o1_250.mp4 + ... +DATAS/VideoQA/TGIF_QA/video_gif: + tumblr_m4387mGrlc1r6m5e8o1_250.gif + ... +DATAS/VideoQA/MSVD_Zero_Shot_QA: + test_a.json + test_q.json +DATAS/VideoQA/MSVD_Zero_Shot_QA/videos: + -4wsuPCjDBc_5_15.avi +DATAS/VideoQA/MSVD_Zero_Shot_QA/msvd_qa: +DATAS/VideoQA/ActivityNet: + test_a.json + test_q.json +DATAS/VideoQA/ActivityNet/all_test: + v_--tFD65KaK4.mp4 + ... +DATAS/VideoQA/MSRVTT_Zero_Shot_QA: + test_a.json + test_q.json +DATAS/VideoQA/MSRVTT_Zero_Shot_QA/videos: +DATAS/VideoQA/MSRVTT_Zero_Shot_QA/videos/all: + video0.mp4 + ... + +DATAS/MVBench: + ... + +DATAS/Recaption/Inter4K: + annotations.json +DATAS/Recaption/Inter4K/60fps: +DATAS/Recaption/Inter4K/60fps/UHD: + 1.mp4 + ... + +``` + +#### Start Evaluate + +Once you have construted the evaluation data, you can start the evaluation as in [here](scripts/eval.sh). This script is for evaluating 7B/13B models. As pllava-34b model uses a slightly different prompting, it is evaluated with this [script](scripts/eval_yiprompt.sh). + +``` +bash scripts/eval.sh +``` + +Same as running the demo, you would need to determine the model_dir and weights_dir to evaluate the model. Feel free to comment out some commands and produce partial evaluation. + +#### Overwatch the Results + +The evaluation results would be shown to you with our results gallery demo: + +```bash +bash scripts/gallery.sh +``` + +Feel free to use the compare version to compare differnt models' results or use the single gallery version to check out one model's results. They are basically the same. Check out this [script](scripts/gallery.sh) for more details + +#### For Captioning and Recaptioning +Follow instructions at [DATA.md](DATA.md#extending-reacptioning) and you can extend the recaptioning data with a few steps. + +Feel free to point out high quality dataset of videos, we would proceed on doing captioning on those datasets. + + +# :page_facing_up: Citation + +If you find this project useful in your research, please consider cite: + +```BibTeX +@misc{xu2024pllava, + title={PLLaVA : Parameter-free LLaVA Extension from Images to Videos for Video Dense Captioning}, + author={Lin Xu and Yilin Zhao and Daquan Zhou and Zhijie Lin and See Kiong Ng and Jiashi Feng}, + year={2024}, + eprint={2404.16994}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +# :dizzy: Acknowledgement + +This code base is mainly built upon [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2). SALUTE. + +We would also like to recognize and commend the following open source projects, thank you for your great contribution to the open source community: + +- [LLaVA](https://github.com/haotian-liu/LLaVA): Fantastic Open Source Image LLM Model. +- [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main): Great Evaluation Benchmarking Framework. +- [VideoLlava](https://github.com/PKU-YuanGroup/Video-LLaVA/tree/main/videollava):Video LLM repo with helpful resources. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b171e461da7f087cbb50b4651ec8e180d7a21978 --- /dev/null +++ b/app.py @@ -0,0 +1,18 @@ +import sys +from huggingface_hub import snapshot_download +snapshot_download( + 'ermu2001/pllava-7b', + local_dir='MODELS/pllava-7b', + repo_type='model', + local_dir_use_symlinks=True, +) + +sys.argv.extend([ + "--pretrained_model_name_or_path", "MODELS/pllava-7b", + "--num_frames", "16", + "--use_lora", + "--weight_dir", "MODELS/pllava-7b", + "--lora_alpha", "4", + "--conv_mode", "plain", +]) +import tasks.eval.demo.pllava_demo \ No newline at end of file diff --git a/assert/data.png b/assert/data.png new file mode 100644 index 0000000000000000000000000000000000000000..6350365b1c8c71e46baadfe5541d8026e48ddc7d --- /dev/null +++ b/assert/data.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72bd5fa48454bfcb6ee1c5b26c3baffd2397502a27bb666860069f0a5755a51b +size 223788 diff --git a/assert/logo.png b/assert/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..7ad2c1ba7bb0a3a08609b9f2154ca4e8e9a49633 --- /dev/null +++ b/assert/logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df1ae4a260b20b749eaaef02d9bad7057cbba958fff92e23e28d1d3b91224668 +size 1319221 diff --git a/assert/module.png b/assert/module.png new file mode 100644 index 0000000000000000000000000000000000000000..6de1783c3f318eaf83498b50e87616a56235ecf6 --- /dev/null +++ b/assert/module.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7933116caeb3552590bc80c543f37456261dcb9984d75a6f81555f4d38ccfa65 +size 226479 diff --git a/assert/performance.png b/assert/performance.png new file mode 100644 index 0000000000000000000000000000000000000000..cfec2e8302cdf5f4d85b0601e2164e10ec2ff8d3 --- /dev/null +++ b/assert/performance.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9bced5f433da0a6424d8bd1bd776f6cb16407ae94d5cf2fbc09ba09e407c37ac +size 106315 diff --git a/assert/teaser.jpg b/assert/teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6fed202577485c339afe2d1f0f30da8419ce3fe --- /dev/null +++ b/assert/teaser.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f204476020f3995d37a5f7c5b341f8eb739cbb0b5e1e529a8c4e722e5976de54 +size 372098 diff --git a/assert/zeroshot.png b/assert/zeroshot.png new file mode 100644 index 0000000000000000000000000000000000000000..5fc5f31a52ad61ec3ac427f016c81c26114642dd --- /dev/null +++ b/assert/zeroshot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6ee8e95e824759b2f93d63db9c4c57f81775576c8b2932b875dd4176b702dab +size 147256 diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6944e662eba412d2d94931541e04184f0ac31c33 --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,158 @@ +import torch +from torch.utils.data import ConcatDataset, DataLoader +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset + + +def get_media_type(dataset_config): + if len(dataset_config) == 3 and dataset_config[2] == "video": + return "video" + elif dataset_config[-1] == "only_video": + return "only_video" + else: + return "image" + + +def create_dataset(dataset_type, config): + if "clip" in config.model.get("vit_model", 'vit'): + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + else: + vision_enc_name = config.model.vision_encoder.name + if "swin" in vision_enc_name or "vit" in vision_enc_name: + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif "beit" in vision_enc_name: + mean = (0.5, 0.5, 0.5) # for all beit model except IN1K finetuning + std = (0.5, 0.5, 0.5) + elif "clip" in vision_enc_name: + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + else: + raise ValueError + + normalize = transforms.Normalize(mean, std) + + # loaded images and videos are torch.Tensor of torch.uint8 format, + # ordered as (T, 1 or 3, H, W) where T=1 for image + type_transform = transforms.Lambda(lambda x: x.float().div(255.0)) + + if config.inputs.video_input.random_aug: + aug_transform = transforms.RandAugment() + else: + aug_transform = transforms.Lambda(lambda x: x) + + train_transform = transforms.Compose( + [ + aug_transform, + transforms.RandomResizedCrop( + config.inputs.image_res, + scale=(0.5, 1.0), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.RandomHorizontalFlip(), + type_transform, + normalize, + ] + ) + test_transform = transforms.Compose( + [ + transforms.Resize( + (config.inputs.image_res, config.inputs.image_res), + interpolation=InterpolationMode.BICUBIC, + ), + type_transform, + normalize, + ] + ) + + video_reader_type = config.inputs.video_input.get("video_reader_type", "decord") + video_only_dataset_kwargs_train = dict( + video_reader_type=video_reader_type, + sample_type=config.inputs.video_input.sample_type, + num_frames=config.inputs.video_input.num_frames, + num_tries=3, # false tolerance + ) + + if dataset_type == "pt_train": + raise ValueError("NOT PRETRAINING YET") + elif dataset_type in ["it_train"]: + # convert to list of lists + train_files = ( + [config.train_file] if isinstance(config.train_file[0], str) else config.train_file + ) + train_media_types = sorted(list({get_media_type(e) for e in train_files})) + + train_datasets = [] + for m in train_media_types: + dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset + # dataset of the same media_type will be mixed in a single Dataset object + _train_files = [e for e in train_files if get_media_type(e) == m] + + datasets = [] + for train_file in _train_files: + dataset_kwargs = dict( + ann_file=train_file, + transform=train_transform, + mm_alone=config.preprocess.get("mm_alone", True), + add_second_msg=config.preprocess.get("add_second_msg", True), + skip_short_sample=config.preprocess.get("skip_short_sample", False), + clip_transform=config.preprocess.get("clip_transform", False), + random_shuffle=config.preprocess.get("random_shuffle", True), + system=config.preprocess.get("system", ""), + role=config.preprocess.get('roles', ("Human", "Assistant")), + end_signal=config.preprocess.get('end_signal', "###"), + begin_signal=config.preprocess.get('begin_signal', ""), + ) + if m == "video": + video_only_dataset_kwargs_train.update({ + "start_token": config.model.get("start_token", ""), + }) + dataset_kwargs.update(video_only_dataset_kwargs_train) + if "tgif" in train_file[1]: + video_only_dataset_kwargs_train.update({ + "video_reader_type": "gif" + }) + dataset_kwargs.update(video_only_dataset_kwargs_train) + elif "webvid" in train_file[1]: + video_only_dataset_kwargs_train.update({ + "video_reader_type": "hdfs" + }) + else: + video_only_dataset_kwargs_train.update({ + "video_reader_type": "decord" + }) + dataset_kwargs.update(video_only_dataset_kwargs_train) + datasets.append(dataset_cls(**dataset_kwargs)) + dataset = ConcatDataset(datasets) + train_datasets.append(dataset) + return train_datasets + + +def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): + loaders = [] + for dataset, sampler, bs, n_worker, is_train, collate_fn in zip( + datasets, samplers, batch_size, num_workers, is_trains, collate_fns + ): + if is_train: + shuffle = sampler is None + drop_last = True + else: + shuffle = False + drop_last = False + loader = DataLoader( + dataset, + batch_size=bs, + num_workers=n_worker, + pin_memory=False, + sampler=sampler, + shuffle=shuffle, + collate_fn=collate_fn, + drop_last=drop_last, + persistent_workers=True if n_worker > 0 else False, + ) + loaders.append(loader) + return loaders + diff --git a/dataset/base_dataset.py b/dataset/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..36a8aab3befaddb6fd4724687390d6f3879b5837 --- /dev/null +++ b/dataset/base_dataset.py @@ -0,0 +1,108 @@ +import logging +import os +import json +import random +from torch.utils.data import Dataset +import time +from dataset.utils import load_image_from_path + +try: + from petrel_client.client import Client + has_client = True +except ImportError: + has_client = False + +logger = logging.getLogger(__name__) + + +class ImageVideoBaseDataset(Dataset): + """Base class that implements the image and video loading methods""" + + media_type = "video" + + def __init__(self): + assert self.media_type in ["image", "video", "only_video"] + self.data_root = None + self.anno_list = ( + None # list(dict), each dict contains {"image": str, # image or video path} + ) + self.transform = None + self.video_reader = None + self.num_tries = None + + self.client = None + if has_client: + self.client = Client('~/petreloss.conf') + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def get_anno(self, index): + """obtain the annotation for one media (video or image) + + Args: + index (int): The media index. + + Returns: dict. + - "image": the filename, video also use "image". + - "caption": The caption for this file. + + """ + anno = self.anno_list[index] + if self.data_root is not None: + anno["image"] = os.path.join(self.data_root, anno["image"]) + return anno + + def load_and_transform_media_data(self, index, data_path): + if self.media_type == "image": + return self.load_and_transform_media_data_image(index, data_path, clip_transform=self.clip_transform) + else: + return self.load_and_transform_media_data_video(index, data_path, clip_transform=self.clip_transform) + + def load_and_transform_media_data_image(self, index, data_path, clip_transform=False): + image = load_image_from_path(data_path, client=self.client) + if not clip_transform: + image = self.transform(image) + return image, index + + def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None, clip_transform=False): + for _ in range(self.num_tries): + try: + max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1 + if "webvid" in data_path: + hdfs_dir="hdfs://harunava/home/byte_ailab_us_cvg/user/weimin.wang/videogen_data/webvid_data/10M_full_train" + video_name = os.path.basename(data_path) + video_id, extension = os.path.splitext(video_name) + ind_file = os.path.join(hdfs_dir, self.keys_indexfile[video_id]) + frames, frame_indices, fps = self.video_reader(ind_file, video_id, self.num_frames, self.sample_type, + max_num_frames=max_num_frames, client=self.client, clip=clip) + else: + frames, frame_indices, fps = self.video_reader( + data_path, self.num_frames, self.sample_type, + max_num_frames=max_num_frames, client=self.client, clip=clip + ) + except Exception as e: + logger.warning( + f"Caught exception {e} when loading video {data_path}, " + f"randomly sample a new video as replacement" + ) + index = random.randint(0, len(self) - 1) + ann = self.get_anno(index) + data_path = ann["image"] + continue + # shared aug for video frames + if not clip_transform: + frames = self.transform(frames) + if return_fps: + sec = [str(round(f / fps, 1)) for f in frame_indices] + return frames, index, sec + else: + return frames, index + else: + raise RuntimeError( + f"Failed to fetch video after {self.num_tries} tries. " + f"This might indicate that you have many corrupted videos." + ) diff --git a/dataset/it_dataset.py b/dataset/it_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dd67b7d7e1168361c7e3f80bdd55d3d4797a3499 --- /dev/null +++ b/dataset/it_dataset.py @@ -0,0 +1,206 @@ +import logging +import os +import json +import sqlite3 +import random +from os.path import basename + +import numpy as np +import datetime + +from dataset.base_dataset import ImageVideoBaseDataset +from dataset.video_utils import VIDEO_READER_FUNCS + +logger = logging.getLogger(__name__) +IMAGE_TOKEN="" + +class ITImgTrainDataset(ImageVideoBaseDataset): + media_type = "image" + + def __init__( + self, ann_file, transform, + system="", role=("Human", "Assistant"), + mm_alone=True, + add_second_msg=True, + start_token="", end_token="", + random_shuffle=True, # if True, shuffle the QA list ##xl:????? why need random shuffle + begin_signal=None, + end_signal=None, + clip_transform=False, + skip_short_sample=False, + ): + super().__init__() + self.mm_alone = mm_alone + self.clip_transform = clip_transform + if len(ann_file) == 3 and ann_file[2] == "video": + self.media_type = "video" + else: + self.media_type = "image" + self.label_file, self.data_root = ann_file[:2] + + logger.info('Load json file') + with open(self.label_file, 'r') as f: + self.anno = json.load(f) + self.num_examples = len(self.anno) + self.transform = transform + annos = [] + for ann in self.anno: + filename = ann['video'] if 'video' in ann else ann['image'] + if self.media_type =='video' and "webvid" in self.data_root: + video_id, extension = os.path.splitext(os.path.basename(filename)) + if video_id not in self.keys_indexfile: + pass + else: + annos.append(ann) + else: + + if filename is None or filename=="None": + pass + else: + if os.path.exists(os.path.join(self.data_root, filename)): + annos.append(ann) + else: + ... + self.anno = annos + self.num_examples = len(self.anno) + + + # prompt parameters + if system: + assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token." + # currently not support add start_token and end_token in the system, since the msg should be added properly + self.begin_signal = [begin_signal for _ in role] if isinstance(begin_signal, str) else begin_signal + self.end_signal = [end_signal for _ in role] if isinstance(end_signal, str) else end_signal + self.start_token = start_token + self.end_token = end_token + self.system = system + self.role = role + self.random_shuffle = random_shuffle + # instruction location and number + logger.info(f"Random shuffle: {self.random_shuffle}") + + def get_anno(self, index): + filename = self.anno[index][self.media_type] + qa = self.anno[index]["QA"] + + if "start" in self.anno[index] and "end" in self.anno[index]: + anno = { + "image": os.path.join(self.data_root, filename), "qa": qa, + "start": self.anno[index]["start"], "end": self.anno[index]["end"], + } + else: + anno = {"image": os.path.join(self.data_root, filename), "qa": qa} + return anno + + def __len__(self): + return self.num_examples + + def process_qa(self, qa, msg=""): + cur_instruction = "" + # randomly shuffle qa for conversation + if self.random_shuffle and len(qa) > 1: + random.shuffle(qa) + if "i" in qa[0].keys() and qa[0]["i"] != "": + cur_instruction = qa[0]["i"] + self.end_signal[0] + + conversation = self.system + # add instruction as system message + if cur_instruction: + conversation += cur_instruction + + # rstrip() for the extra " " in msg + if self.mm_alone: + conversation += ( + self.begin_signal[0] + self.role[0] + + self.start_token + self.end_token + msg.rstrip() + self.end_signal[0] + ) + + for i, sentence in enumerate(qa): + q = self.start_token + self.end_token+"\n"+ qa[0]["q"] if (not self.mm_alone) and (i == 0) else sentence["q"] + a = sentence["a"] + if q != "": + conversation += (self.begin_signal[0] + self.role[0] + q + self.end_signal[1]) + else: + # no question, often in caption dataset + pass + conversation += (self.begin_signal[0] + self.role[1] + a + self.end_signal[1]) + + + if cur_instruction: + cur_instruction += qa[0]["q"] + return conversation, cur_instruction.strip() + + def __getitem__(self, index): + try: + ann = self.get_anno(index) + image, index = self.load_and_transform_media_data_image(index, ann["image"], clip_transform=self.clip_transform) + conversation, instruction = self.process_qa(ann["qa"]) + return image, conversation, instruction, index + except Exception as e: + logger.warning(f"Caught exception {e} when loading image {ann['image']}") + index = np.random.randint(0, len(self)) + return self.__getitem__(index) + + +class ITVidTrainDataset(ITImgTrainDataset): + media_type = "video" + + def __init__( + self, ann_file, transform, + num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3, + mm_alone=True, + system="", role=("Human", "Assistant"), + start_token="", + add_second_msg=True, + random_shuffle=True, + begin_signal=None, + end_signal=None, + clip_transform=False, + skip_short_sample=False, + + ): + # "id index file for webvid" + if "webvid" in ann_file[1]: + with open("/mnt/bn/dq-storage-ckpt/xulin/datasets/videos/webvid_10m/keys_indexfile.json") as f: + self.keys_indexfile = json.load(f) # the correponding index file for each webvid id + + super().__init__( + ann_file, transform, + system=system, role=role, + mm_alone=mm_alone, + start_token=start_token, end_token=end_token, + random_shuffle=random_shuffle, + begin_signal=begin_signal, + end_signal=end_signal, + clip_transform=clip_transform, + skip_short_sample=skip_short_sample, + ) + self.num_frames = num_frames + self.video_reader_type = video_reader_type + self.video_reader = VIDEO_READER_FUNCS[video_reader_type] + self.sample_type = sample_type + self.num_tries = num_tries + self.add_second_msg = add_second_msg + + logger.info(f"Use {video_reader_type} for data in {ann_file}") + if add_second_msg: + logger.info(f"Add second message: The video contains X frames sampled at T seconds.") + + def __getitem__(self, index): + try: + ann = self.get_anno(index) + + msg = "" + clip = None + if "start" in ann and "end" in ann: + clip = [ann["start"], ann["end"]] + video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip, clip_transform=self.clip_transform) + if self.add_second_msg: + # " " should be added in the start and end + msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. " + conversation, instruction = self.process_qa(ann["qa"], msg) + return video, conversation, instruction, index + except Exception as e: + logger.warning(f"Caught exception {e} when loading video {ann['image']}") + index = np.random.randint(0, len(self)) + return self.__getitem__(index) \ No newline at end of file diff --git a/dataset/utils.py b/dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77dea14908e4914a2cdd5e91658bd89e02ab1105 --- /dev/null +++ b/dataset/utils.py @@ -0,0 +1,41 @@ +from utils.distributed import is_main_process, get_rank, get_world_size +import io +import json +import re +import numpy as np +from os.path import join +from tqdm import trange +from PIL import Image +from PIL import ImageFile +from torchvision.transforms import PILToTensor +ImageFile.LOAD_TRUNCATED_IMAGES = True +Image.MAX_IMAGE_PIXELS = None + + +def load_image_from_path(image_path, client): + if image_path.startswith('s3') or image_path.startswith('p2'): + value = client.Get(image_path) + img_bytes = np.frombuffer(value, dtype=np.uint8) + buff = io.BytesIO(img_bytes) + image = Image.open(buff).convert('RGB') + else: + image = Image.open(image_path).convert('RGB') # PIL Image + image = PILToTensor()(image).unsqueeze(0) # (1, C, H, W), torch.uint8 + return image + +def pre_text(text, max_l=None, pre_text=True): + if pre_text: + text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower()) + text = text.replace('-', ' ').replace('/', ' ').replace('', 'person') + + text = re.sub(r"\s{2,}", ' ', text) + text = text.rstrip('\n').strip(' ') + + if max_l: # truncate + words = text.split(' ') + if len(words) > max_l: + text = ' '.join(words[:max_l]) + else: + pass + return text + diff --git a/dataset/video_utils.py b/dataset/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11103d7ca67a4aa36f1c1a3ddd154f3dae5d3a78 --- /dev/null +++ b/dataset/video_utils.py @@ -0,0 +1,214 @@ +""" +Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py +""" +import random +import io +import os +import av +import cv2 +import decord +import imageio +from decord import VideoReader + +# from dataloader import KVReader +import torch +import numpy as np +import math +# import tensorflow as tf +decord.bridge.set_bridge("torch") + +import logging +logger = logging.getLogger(__name__) + +def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float: + """ + Converts a present time with the given time base and start_pts offset to seconds. + + Returns: + time_in_seconds (float): The corresponding time in seconds. + + https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64 + """ + if pts == math.inf: + return math.inf + + return int(pts - start_pts) * time_base + + +def get_pyav_video_duration(video_reader): + video_stream = video_reader.streams.video[0] + video_duration = pts_to_secs( + video_stream.duration, + video_stream.time_base, + video_stream.start_time + ) + return float(video_duration) + + +def get_frame_indices_by_fps(): + pass + + +def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): + if sample in ["rand", "middle"]: # uniform sampling + acc_samples = min(num_frames, vlen) + # split the video into `acc_samples` intervals, and sample from each interval. + intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) + ranges = [] + for idx, interv in enumerate(intervals[:-1]): + ranges.append((interv, intervals[idx + 1] - 1)) + if sample == 'rand': + try: + frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] + except: + frame_indices = np.random.permutation(vlen)[:acc_samples] + frame_indices.sort() + frame_indices = list(frame_indices) + elif fix_start is not None: + frame_indices = [x[0] + fix_start for x in ranges] + elif sample == 'middle': + frame_indices = [(x[0] + x[1]) // 2 for x in ranges] + else: + raise NotImplementedError + + if len(frame_indices) < num_frames: # padded with last frame + padded_frame_indices = [frame_indices[-1]] * num_frames + padded_frame_indices[:len(frame_indices)] = frame_indices + frame_indices = padded_frame_indices + elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps + output_fps = float(sample[3:]) + duration = float(vlen) / input_fps + delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents + frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) + frame_indices = np.around(frame_seconds * input_fps).astype(int) + frame_indices = [e for e in frame_indices if e < vlen] + if max_num_frames > 0 and len(frame_indices) > max_num_frames: + frame_indices = frame_indices[:max_num_frames] + # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames) + else: + raise ValueError + return frame_indices + + +def read_frames_av( + video_path, num_frames, sample='rand', fix_start=None, + max_num_frames=-1, client=None, clip=None, + ): + reader = av.open(video_path) + frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)] + vlen = len(frames) + duration = get_pyav_video_duration(reader) + fps = vlen / float(duration) + frame_indices = get_frame_indices( + num_frames, vlen, sample=sample, fix_start=fix_start, + input_fps=fps, max_num_frames=max_num_frames + ) + frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8 + frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 + return frames, frame_indices, fps + + +def read_frames_gif( + video_path, num_frames, sample='rand', fix_start=None, + max_num_frames=-1, client=None, clip=None, + ): + if video_path.startswith('s3') or video_path.startswith('p2'): + video_bytes = client.get(video_path) + gif = imageio.get_reader(io.BytesIO(video_bytes)) + else: + gif = imageio.get_reader(video_path) + vlen = len(gif) + frame_indices = get_frame_indices( + num_frames, vlen, sample=sample, fix_start=fix_start, + max_num_frames=max_num_frames + ) + frames = [] + for index, frame in enumerate(gif): + # for index in frame_idxs: + if index in frame_indices: + frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) + frame = torch.from_numpy(frame).byte() + # # (H x W x C) to (C x H x W) + frame = frame.permute(2, 0, 1) + frames.append(frame) + frames = torch.stack(frames) # .float() / 255 + + return frames, frame_indices, 25. # for tgif + + +def read_frames_hdfs(ind_file, vid, num_frames, sample='rand',fix_start=None, + max_num_frames=-1, client=None, clip=None): + _context_features = {'title': tf.io.FixedLenFeature([], dtype=tf.string)} + _sequence_features = {'data': tf.io.FixedLenSequenceFeature([], dtype=tf.string)} + num_parallel_reader = 1 + filename, extension = os.path.splitext(ind_file) + reader = KVReader(filename, num_parallel_reader) + key = vid + values = reader.read_many([key]) + item = values[0] + contexts, sequences = tf.io.parse_single_sequence_example( + serialized=item, + context_features=_context_features, + sequence_features=_sequence_features) + + # text = contexts['title'].numpy().decode("utf-8") + rawframes = sequences['data'] + vlen = len(rawframes) + sample="rand" + + frame_indices = get_frame_indices(num_frames, vlen, sample=sample, + fix_start=fix_start, + max_num_frames=max_num_frames) + def read_image(raw_data): + return tf.image.decode_jpeg(raw_data, channels=3, dct_method='INTEGER_ACCURATE').numpy() + + frames = [] + for index, frame in enumerate(rawframes): + if index in frame_indices: + frame = read_image(frame) + frame = torch.as_tensor(frame) + frames.append(frame) + + frames = torch.stack(frames) + # print("in hdfs========>",frames[0]) + frames = frames.permute(0, 3, 1, 2) + return frames, frame_indices, 25 # don't know the fps for index + + +def read_frames_decord( + video_path, num_frames, sample='rand', fix_start=None, + max_num_frames=-1, client=None, clip=None + ): + if video_path.startswith('s3') or video_path.startswith('p2'): + video_bytes = client.get(video_path) + video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1) + else: + video_reader = VideoReader(video_path, num_threads=1) + vlen = len(video_reader) + fps = video_reader.get_avg_fps() + duration = vlen / float(fps) + + if clip: + start, end = clip + duration = end - start + vlen = int(duration * fps) + start_index = int(start * fps) + + frame_indices = get_frame_indices( + num_frames, vlen, sample=sample, fix_start=fix_start, + input_fps=fps, max_num_frames=max_num_frames + ) + if clip: + frame_indices = [f + start_index for f in frame_indices] + + frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 + frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 + return frames, frame_indices, float(fps) + + +VIDEO_READER_FUNCS = { + 'av': read_frames_av, + 'decord': read_frames_decord, + 'gif': read_frames_gif, + 'hdfs': read_frames_hdfs, +} diff --git a/docs/PoolLLaVA_Report.pdf b/docs/PoolLLaVA_Report.pdf new file mode 100644 index 0000000000000000000000000000000000000000..7eb1089a479bda55dc186c374510c4c90b23272f --- /dev/null +++ b/docs/PoolLLaVA_Report.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b9f175bd915cdc6f9791a95149992fde1f48ebfffa6c8bff9e6365b7186c57d +size 3850702 diff --git a/example/1917.mp4 b/example/1917.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1eed1efb72a9efa597eb1416bf8f505ff90df33c --- /dev/null +++ b/example/1917.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99f5f2a10985964ddc0555a8fa12b9d41f130b49ad62879a9e150d91834e93d5 +size 1535936 diff --git a/example/bear.jpg b/example/bear.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cffbe5ccb6baa239a3a970655dac89cd31e73274 --- /dev/null +++ b/example/bear.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:286b3a5693322edf01870a561e35016ed46a7cb4b9194c58e2f3526eab1f9efc +size 376329 diff --git a/example/cooking.mp4 b/example/cooking.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..53d1d0c7d483ed555d35126bff74a66528dd1929 --- /dev/null +++ b/example/cooking.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a1395530cc13c0441ae99ce66477f533f6009ebdb913064aec91e38eaf3b8e9 +size 876622 diff --git a/example/dog.png b/example/dog.png new file mode 100644 index 0000000000000000000000000000000000000000..0fbd8e56d7b4c4c1c637684f3282441500dbdc19 --- /dev/null +++ b/example/dog.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:919b6e24d3cc7d7998181029fb76e94d8149e6a9d2c4930445fa217f6715716d +size 562829 diff --git a/example/jesse_dance.mp4 b/example/jesse_dance.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1663b109ac8eb2cdacd5312e94c982bdc62a2250 --- /dev/null +++ b/example/jesse_dance.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1fc41c6ebae0692726ea56b33ba711f21186fd4203ac54cd43a5cd898be4350 +size 1221420 diff --git a/example/working.mp4 b/example/working.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5a801e0f820eaccd4422bf1767e7541e16ee0ff4 --- /dev/null +++ b/example/working.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09372cdb6b0ea272868b4469d5067674670a948962f1236196e8f23e1f7ce764 +size 4718899 diff --git a/example/yoga.mp4 b/example/yoga.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3b4271609367609234a3d509f4a8252bf2a194b1 --- /dev/null +++ b/example/yoga.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74b65d9bec7f83e487b7f923076c01d476dd2ef7ed83928a696ab6f88c7751b7 +size 776184 diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/pllava/__init__.py b/models/pllava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9a9a6b2d8ce0c839df55e4cc70cb5ed9b37e6a --- /dev/null +++ b/models/pllava/__init__.py @@ -0,0 +1,55 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = {"configuration_pllava": ["PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "PllavaConfig"]} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_pllava"] = [ + "PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", + "PllavaForConditionalGeneration", + "PllavaPreTrainedModel", + ] + _import_structure["processing_pllava"] = ["PllavaProcessor"] + + +if TYPE_CHECKING: + from .configuration_pllava import PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, PllavaConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_pllava import ( + PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, + PllavaForConditionalGeneration, + PllavaPreTrainedModel, + ) + from .processing_pllava import PllavaProcessor + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/models/pllava/configuration_pllava.py b/models/pllava/configuration_pllava.py new file mode 100644 index 0000000000000000000000000000000000000000..6c429ce7120a5c184768aae58dce8c8c379985b5 --- /dev/null +++ b/models/pllava/configuration_pllava.py @@ -0,0 +1,149 @@ +# coding=utf-8 +# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Llava model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from transformers.models.auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + +PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json", +} + + +class PllavaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an + Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llava-9B. + + e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`LlavaVisionConfig`, *optional*): + Custom vision config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the CLIP backbone. + vision_feature_layer (`int`, *optional*, defaults to -2): + The index of the layer to select the vision feature. + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`] + + Example: + + ```python + >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a Llava llava-1.5-7b style configuration + >>> configuration = LlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the llava-1.5-7b style configuration + >>> model = LlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llava" + is_composition = False + + def __init__( + self, + vision_config=None, + text_config=None, + ignore_index=-100, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + vocab_size=32000, + pooling_method='avg', + pooling_shape=(8, 16, 16), + frame_shape=(24, 24), # llava 1.5 pretrained frame shape + num_frames=1, # llava 1.5 pretrained frame shape + use_pooling=True, + gradient_checkpointing=False, + **kwargs, + ): + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.vocab_size = vocab_size + self.use_pooling = use_pooling + self.gradient_checkpointing = gradient_checkpointing + + self.vision_config = vision_config + + self.pooling_method = pooling_method # should be in 'max', 'avg' + self.pooling_shape = pooling_shape # + self.frame_shape = frame_shape # + self.num_frames = num_frames + if isinstance(self.vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + self.vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + self.vocab_size = self.vocab_size + + self.text_config = text_config + + if isinstance(self.text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + self.vocab_size = self.text_config.vocab_size + self.text_config.gradient_checkpointing = self.gradient_checkpointing + + elif text_config is None: + tmp_config = {"_attn_implementation":"flash_attention_2", + "gradient_checkpointing": self.gradient_checkpointing} + self.text_config = CONFIG_MAPPING["llama"](**tmp_config) + self.text_config.gradient_checkpointing = self.gradient_checkpointing + # self.text_config["_attn_implementation"]="flash_attention_2" # xl: temporal hard code + + + super().__init__(**kwargs) diff --git a/models/pllava/convert_pllava_weights_to_hf.py b/models/pllava/convert_pllava_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..68085e559a09da8c2142fefc42b2ddcdf5fe295a --- /dev/null +++ b/models/pllava/convert_pllava_weights_to_hf.py @@ -0,0 +1 @@ +# Not yet \ No newline at end of file diff --git a/models/pllava/modeling_pllava.py b/models/pllava/modeling_pllava.py new file mode 100644 index 0000000000000000000000000000000000000000..a420eb60cc4871f0f26b27fddabb949c501d2e57 --- /dev/null +++ b/models/pllava/modeling_pllava.py @@ -0,0 +1,626 @@ +# coding=utf-8 +# Copyright 2023 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Llava model.""" +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import math + +import torch +import torch.utils.checkpoint +from torch import nn +import os +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ModelOutput +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.auto import AutoModel, AutoModelForCausalLM +import einops + +from .configuration_pllava import PllavaConfig +import pickle + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlavaConfig" + +PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "", + "", + "", + # See all Llava models at https://huggingface.co/models?filter=llava +] + + +@dataclass +# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava +class PllavaCausalLMOutputWithPast(ModelOutput): + """ + Base class for Llava causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +class PllavaMultiModalProjector(nn.Module): + supported_highres = ['pad_crop_four', 'slide', ] + def __init__(self, config: PllavaConfig): + super().__init__() + self.use_pooling = config.use_pooling + self.frame_shape=config.frame_shape + self.num_frames = config.num_frames + self.pooling_shape = config.pooling_shape + + self.pooling = nn.AdaptiveAvgPool3d(config.pooling_shape) + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + + def convert_Fembeddings2video(self, input, num_videos, frame_shape): + input = einops.rearrange(input, + '(num_videos num_frames) (h w) embed_dims -> num_videos embed_dims num_frames h w', + num_videos=num_videos, h=frame_shape[0]) + return input + + def convert_video2Fembeddings(self, input): + input = einops.rearrange(input, 'num_videos embed_dims num_frames h w -> (num_videos num_frames) (h w) embed_dims ', ) + return input + + def convert_video2MMembeddings(self, input): + input = einops.rearrange(input, 'num_videos embed_dims num_frames h w -> num_videos (num_frames h w) embed_dims ', ) + return input + + def forward(self, image_features, media_type, batch_size=None, num_videos=None): + frame_shape = self.frame_shape + num_frames = self.num_frames + assert media_type in ( 'video', 'image'), f'only image or video, but got media_type {media_type}' + hidden_states = image_features + + if media_type == 'image': + hidden_states = hidden_states.repeat(num_frames, 1, 1) + + total_frames, spatial_seqlen, embed_dims = hidden_states.shape + #TODO: temporal code, should ensure num_frames == total frames in data loading later + if total_frames < num_frames and self.use_pooling: # + multiplier = int(num_frames/total_frames)+1 + hidden_states= hidden_states.repeat_interleave(multiplier, dim=0)[:num_frames] + total_frames, spatial_seqlen, embed_dims = hidden_states.shape + + assert total_frames % num_frames == 0 + assert frame_shape[0] * frame_shape[1] == spatial_seqlen + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states_videos = self.convert_Fembeddings2video(hidden_states, num_videos * batch_size, frame_shape) + hidden_states_videos = self.pooling(hidden_states_videos) + hidden_states = einops.rearrange(hidden_states_videos, 'batch_size_num_videos embed_dims num_frames h w -> batch_size_num_videos num_frames (h w) embed_dims', ) + hidden_states = einops.rearrange(hidden_states, 'batch_size_num_videos num_frames hw embed_dims -> batch_size_num_videos (num_frames hw) embed_dims ') + return hidden_states + + + +PLLAVA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlavaConfig`] or [`LlavaVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + PLLAVA_START_DOCSTRING, +) +class PllavaPreTrainedModel(PreTrainedModel): + config_class = PllavaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlavaVisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + # important: this ported version of Llava isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + # if isinstance(module, (nn.Linear, nn.Conv2d)): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.bias is not None: + # module.bias.data.zero_() + + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + elif isinstance(module, PllavaMultiModalProjector): + # module.register_embed.data.normal_(mean=0.0, std=std) + if self.config.register: + module.register_embed.data.zero_() + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA or not. + """ + return self.language_model._supports_sdpa + + +PLLAVA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`CLIPImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """The LLAVA model which consists of a vision backbone and a language model.""", + PLLAVA_START_DOCSTRING, +) +class PllavaForConditionalGeneration(PllavaPreTrainedModel): + def __init__(self, config: PllavaConfig): + super().__init__(config) + self.config = config + self.vision_tower = AutoModel.from_config(config.vision_config) + self.multi_modal_projector = PllavaMultiModalProjector(config) + self.vocab_size = config.vocab_size + # self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="flash_attention_2") + self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="eager") + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.text_config.pad_token_id + assert self.pad_token_id is not None, 'provide the model with pad_token_id, this would be used to arranging new embedings' + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + num_images, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length + batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling + image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + image_to_overwrite &= image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device) + + # # somthing really weird here. + # temp1 = (image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device)) & image_to_overwrite + # # this is for right padding + # temp2 = (image_to_overwrite.cumsum(-1) <= num_special_image_tokens.max() * num_image_patches - nb_image_pad[:, None]) & image_to_overwrite + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." + ) + + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids + + @add_start_docstrings_to_model_forward(PLLAVA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=PllavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + media_type: str = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PllavaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "\nUSER: What's the content of the image?\nASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if inputs_embeds is None: + # 1. Extra the input embeddings + no_img_input_ids = torch.where(input_ids!=self.config.image_token_index, input_ids, self.pad_token_id) # some model used up all the embeddings + inputs_embeds = self.get_input_embeddings()(no_img_input_ids) + batch_size = inputs_embeds.shape[0] + # 2. Merge text and images + if pixel_values is not None and input_ids.shape[1] != 1: + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] # ( b, img_seqlen, embed_dim) + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + raise ValueError("not implemented") + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + + image_features = self.multi_modal_projector(selected_image_feature, + media_type, + batch_size=batch_size, + num_videos=pixel_values.shape[0]//self.config.num_frames//batch_size,) + + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + if labels is None: + labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long) + else: + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_seqlen = first_layer_past_key_value.shape[-1] + 1 + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return PllavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + else: + cache_length = past_length = past_key_values[0][0].shape[2] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + elif self.config.image_token_index in input_ids: + input_ids = input_ids[:, input_ids.shape[1] - 1 :] + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + media_type = kwargs.get('media_type', None) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "media_type": media_type, + } + ) + return model_inputs + + def _reorder_cache(self, *args, **kwargs): + return self.language_model._reorder_cache(*args, **kwargs) diff --git a/models/pllava/processing_pllava.py b/models/pllava/processing_pllava.py new file mode 100644 index 0000000000000000000000000000000000000000..8f1211f0170b918628a5d4720c07478af2c18f35 --- /dev/null +++ b/models/pllava/processing_pllava.py @@ -0,0 +1,292 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Llava. +""" + + +import itertools +from typing import List, Optional, Union +import PIL.Image +import numpy as np + +from transformers import AutoTokenizer +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ( + ImageInput, + make_list_of_images, + valid_images, + infer_channel_dimension_format, + to_numpy_array, + get_image_size, + ChannelDimension, +) +from transformers.image_processing_utils import get_size_dict +from transformers.image_utils import PILImageResampling +from transformers.processing_utils import ProcessorMixin +from transformers.image_transforms import resize, pad, PaddingMode, to_channel_dimension_format, get_resize_output_image_size +from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from transformers.utils import TensorType + + +class PllavaProcessor(ProcessorMixin): + r""" + Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor. + + [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information. + + Args: + image_processor ([`CLIPImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "CLIPImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor=None, tokenizer=None, + shortest_edge=336, + longest_edge=762, + center_pad=False): + self.shortest_edge = shortest_edge + self.longest_edge = longest_edge + self.center_pad = center_pad + super().__init__(image_processor, tokenizer) + + def resize_crop_longshort(self, videos: list[list[np.ndarray]], input_data_format): + video_spatial_sizes = [get_image_size(images[0], input_data_format) for images in videos] + long_short_rates = [max(size) / min(size) for size in video_spatial_sizes] + min_long_short_rate = min(long_short_rates) + min_long_short_video_idx = long_short_rates.index(min_long_short_rate) + + clip_resolution = self.image_processor.size['shortest_edge'] + out_video_spatial_size = video_spatial_sizes[min_long_short_video_idx] + out_videos_short_edge = max(min(size) for size in video_spatial_sizes) + resize_longest_edge = max(max(size) for size in video_spatial_sizes) + resize_longest_edge = min(640, resize_longest_edge) + out_videos_short_edge = min(out_videos_short_edge, int(resize_longest_edge / min_long_short_rate)) + out_videos_short_edge = max(out_videos_short_edge, clip_resolution) + + + if out_video_spatial_size[0] > out_video_spatial_size[1]: # h > w: + out_video_spatial_size = (int(out_videos_short_edge * min_long_short_rate), out_videos_short_edge ) + else: + out_video_spatial_size = ( out_videos_short_edge, int(out_videos_short_edge * min_long_short_rate) ) + videos = [ + [self.resize(frame, input_data_format=input_data_format, shortest_edge=out_videos_short_edge, longest_edge=9999) for frame in frames] + for frames in videos + ] + out_videos = [] + for frames in videos: + out_frames = [] + video_spatial_size = get_image_size(frames[0], input_data_format) + assert min(video_spatial_size) == out_videos_short_edge + overhead = (max(video_spatial_size) - max(out_video_spatial_size)) // 2 + slice_start, slice_end = overhead // 2, overhead // 2 + max(out_video_spatial_size) + hslice, wslice = (slice(slice_start, slice_end), slice(None, None)) if video_spatial_size[0] > video_spatial_size[1] \ + else (slice(None, None), slice(slice_start, slice_end)) # h > w + for frame in frames: + if input_data_format == ChannelDimension.FIRST: + out_frames.append(frame[..., hslice, wslice]) + elif input_data_format == ChannelDimension.LAST: + out_frames.append(frame[..., hslice, wslice, :]) + out_videos.append(out_frames) + + return out_videos + + @staticmethod + def _compute_num_blocks_and_overlaps(input_shape, resolution): + input_shape = np.array(input_shape) + resolution = np.array(resolution) + assert input_shape.max() >= resolution + num_blocks = np.ceil(input_shape / resolution).astype(np.int32).tolist() + overlaps = [0 if size % resolution==0 + else int(np.floor((resolution - size % resolution) / (num_block - 1))) for num_block, size in zip(num_blocks, input_shape)] + return num_blocks, overlaps + + def resize( + self, + image: np.ndarray, + resample: PILImageResampling = PILImageResampling.BICUBIC, # type: ignore + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + shortest_edge: int = None, + longest_edge: int = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + shortest_edge = getattr(self, 'shortest_edge', None) if shortest_edge is None else shortest_edge + longest_edge = getattr(self, 'longest_edge', None) if longest_edge is None else longest_edge + default_to_square = False + output_size = get_resize_output_image_size( + image, + size=shortest_edge, + default_to_square=default_to_square, + max_size=longest_edge, + input_data_format=input_data_format, + ) + clip_resolution = self.image_processor.size['shortest_edge'] + if min(output_size) < clip_resolution: + output_size = get_resize_output_image_size( + image, + size=shortest_edge, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + center_pad = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + data=dict() + if images is not None: + if isinstance(images, list) and isinstance(images[0], PIL.Image.Image): + videos = [images] # one video + else: + videos = images + + pixel_values_list = [] + videos = [[to_numpy_array(image) for image in make_list_of_images(images)] for images in videos] + # images = [self.resize(image, ) if min(get_image_size(image, input_data_format)) < clip_resolution else image for image in images] + input_data_format = infer_channel_dimension_format(videos[0][0]) + videos = self.resize_crop_longshort(videos, input_data_format) + + for images in videos: + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + center_pad = center_pad if center_pad is not None else self.center_pad + if center_pad: + images = [self.pad_to_square(image, 0, input_data_format, input_data_format) for image in images] + + pixel_values = self.image_processor(images, return_tensors='np')["pixel_values"] + pixel_values_list.append(pixel_values) + + pixel_values = np.concatenate(pixel_values_list) + data.update(pixel_values=pixel_values) + + else: + data.update(pixel_values = None) + + if text is not None: + text_inputs = self.tokenizer( + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + ) + data.update(**text_inputs) + return BatchFeature(data, tensor_type=return_tensors) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/python_scripts/hf.py b/python_scripts/hf.py new file mode 100644 index 0000000000000000000000000000000000000000..83929d1738b85594130d80e99072e7abaa63bfaf --- /dev/null +++ b/python_scripts/hf.py @@ -0,0 +1,80 @@ +import os.path as osp +import os +import re +import multiprocessing +import functools +import huggingface_hub +from huggingface_hub import snapshot_download + + +def upload(repo_id, local_dir, path_in_repo, repo_type, token): + huggingface_hub.upload_folder( + repo_id=repo_id, + folder_path=local_dir, + path_in_repo=path_in_repo, + token=token, + repo_type=repo_type + ) + +def download(repo_id, local_dir, repo_type, token, filter_re=None): + files = huggingface_hub.list_repo_files(repo_id, repo_type=repo_type, token=token) + if filter_re is not None: + files = [file for file in files if re.search(filter_re, file) is not None] + pool = multiprocessing.Pool(8) + download_func = functools.partial( + huggingface_hub.hf_hub_download, + repo_id, + repo_type=repo_type, + local_dir=local_dir, + local_dir_use_symlinks=True, + token=token + ) + pool.map(download_func, files) + print(f'downloaded files {files}') + + +def upload_file(repo_id, file_path, repo_type, token): + huggingface_hub.upload_file( + repo_id=repo_id, + path_or_fileobj=file_path, + path_in_repo=file_path, + token=token, + repo_type=repo_type, + ) + +if __name__ == '__main__': + read_token = '...' + write_token = '...' + repo_id = '...' + local_dir = '...' + repo_type = '...' + + + # ############# + # # Examples on most simple hf usage + # # downlaod + # filters = [] + # for filter_re in filters: + # download(repo_id, + # local_dir, + # repo_type, + # filter_re) + + # # upload + # upload(repo_id, local_dir, local_dir, repo_type, write_token) + # ############# + + # download models + repo_ids = [ + 'ermu2001/pllava-7b', + 'ermu2001/pllava-13b', + ] + for repo_id in repo_ids: + local_dir = repo_id.replace('ermu2001', 'MODELS') + snapshot_download( + repo_id, + local_dir=local_dir, + repo_type='model', + local_dir_use_symlinks=True, + token=read_token, + ) \ No newline at end of file diff --git a/requirements.no_torch.txt b/requirements.no_torch.txt new file mode 100644 index 0000000000000000000000000000000000000000..307cc3971f513ed7e81ae7f122060941eea2dc00 --- /dev/null +++ b/requirements.no_torch.txt @@ -0,0 +1,244 @@ +absl-py==2.1.0 +accelerate==0.26.1 +addict==2.4.0 +aiofiles==23.2.1 +aliyun-python-sdk-core==2.15.0 +aliyun-python-sdk-kms==2.16.2 +altair==5.2.0 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==4.3.0 +anykeystore==0.2 +apex==0.9.10.dev0 +appdirs==1.4.4 +argcomplete==3.2.3 +attrs==23.2.0 +av==10.0.0 +beautifulsoup4==4.12.3 +blessed==1.20.0 +blessings==1.7 +boto3==1.34.63 +botocore==1.34.63 +Brotli==1.1.0 +cachetools==5.3.3 +certifi==2024.2.2 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +colorama==0.4.6 +contourpy==1.2.0 +crcmod==1.7 +cryptacular==1.6.2 +cryptography==42.0.5 +cycler==0.12.1 +dacite==1.7.0 +decorator==4.4.2 +decord==0.6.0 +deepspeed==0.14.0 +defusedxml==0.7.1 +Deprecated==1.2.14 +dill==0.3.8 +distro==1.9.0 +dnspython==2.6.1 +docker-pycreds==0.4.0 +einops==0.6.1 +exceptiongroup==1.2.0 +fastapi==0.110.0 +ffmpeg==1.4 +ffmpy==0.3.2 +fiftyone==0.23.6 +fiftyone-brain==0.16.1 +fiftyone_db==1.1.2 +filelock==3.9.0 +flash-attn==2.5.6 +fonttools==4.49.0 +fsspec==2024.2.0 +ftfy==6.1.3 +future==1.0.0 +fvcore==0.1.5.post20221221 +gdown==5.1.0 +gitdb==4.0.11 +GitPython==3.1.42 +glob2==0.7 +google-auth==2.28.2 +google-auth-oauthlib==1.2.0 +gpustat==1.1.1 +gradio==4.21.0 +gradio_client==0.12.0 +graphql-core==3.2.3 +greenlet==3.0.3 +grpcio==1.62.1 +h11==0.14.0 +h2==4.1.0 +hjson==3.1.0 +hpack==4.0.0 +httpcore==1.0.4 +httpx==0.27.0 +huggingface-hub==0.21.4 +humanize==4.9.0 +hupper==1.12.1 +Hypercorn==0.16.0 +hyperframe==6.0.1 +idna==3.6 +idscheck==2.3.0 +imageio==2.27.0 +imageio-ffmpeg==0.4.9 +importlib_metadata==7.0.2 +importlib_resources==6.3.0 +inflate64==1.0.0 +iopath==0.1.10 +Jinja2==3.1.2 +jmespath==0.10.0 +joblib==1.3.2 +jsonlines==4.0.0 +jsonschema==4.21.1 +jsonschema-specifications==2023.12.1 +kaleido==0.2.1 +kiwisolver==1.4.5 +lazy_loader==0.3 +Markdown==3.6 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.3 +mdurl==0.1.2 +mmcv-full==1.7.2 +model-index==0.1.11 +mongoengine==0.24.2 +motor==3.3.2 +moviepy==1.0.3 +mpmath==1.3.0 +multivolumefile==0.2.3 +networkx==3.2.1 +ninja==1.11.1.1 +numpy +oauthlib==3.2.2 +omegaconf==2.3.0 +openai==1.14.0 +opencv-python==4.9.0.80 +opencv-python-headless==4.9.0.80 +opendatalab==0.0.10 +openmim==0.3.9 +openxlab==0.0.36 +ordered-set==4.1.0 +orjson==3.9.15 +oss2==2.17.0 +packaging==24.0 +pandas==1.5.3 +PasteDeploy==3.1.0 +pathtools==0.1.2 +pbkdf2==1.3 +peft==0.10.0 +pillow==10.2.0 +plaster==1.1.2 +plaster-pastedeploy==1.0.1 +platformdirs==4.2.0 +plotly==5.20.0 +portalocker==2.8.2 +pprintpp==0.4.0 +priority==2.0.0 +proglog==0.1.10 +protobuf==4.23.4 +psutil==5.9.4 +py-cpuinfo==9.0.0 +py7zr==0.21.0 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pybcj==1.0.2 +pycparser==2.21 +pycryptodome==3.20.0 +pycryptodomex==3.20.0 +pydantic==2.6.4 +pydantic_core==2.16.3 +pydub==0.25.1 +Pygments==2.17.2 +pymongo==4.6.2 +pynvml==11.5.0 +pyparsing==3.1.2 +pyppmd==1.1.0 +pyramid==2.0.2 +pyramid-mailer==0.15.1 +PySocks==1.7.1 +python-dateutil==2.9.0.post0 +python-multipart==0.0.9 +python3-openid==3.2.0 +pytz==2023.4 +PyYAML==6.0 +pyzstd==0.15.9 +rarfile==4.1 +referencing==0.33.0 +regex==2023.12.25 +repoze.sendmail==4.4.1 +requests==2.28.2 +requests-oauthlib==1.4.0 +retrying==1.3.4 +rich==13.4.2 +rpds-py==0.18.0 +rsa==4.9 +ruff==0.3.2 +s3transfer==0.10.1 +safetensors==0.4.2 +scikit-image==0.22.0 +scikit-learn==1.4.1.post1 +scipy==1.10.1 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==1.42.0 +setproctitle==1.3.3 +shellingham==1.5.4 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +sortedcontainers==2.4.0 +soupsieve==2.5 +SQLAlchemy==2.0.28 +sse-starlette==0.10.3 +sseclient-py==1.8.0 +starlette==0.36.3 +strawberry-graphql==0.138.1 +sympy==1.12 +tabulate==0.9.0 +taskgroup==0.0.0a4 +tenacity==8.2.3 +tensorboard==2.15.1 +tensorboard-data-server==0.7.2 +tensorboardX==2.6.2.2 +termcolor==2.3.0 +texttable==1.7.0 +threadpoolctl==3.3.0 +tifffile==2024.2.12 +timm==0.6.12 +tokenizers==0.15.2 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.1 +tqdm==4.65.2 +transaction==4.0 +transformers==4.37.1 +translationstring==1.4 +triton==2.2.0 +typer==0.9.0 +typing_extensions==4.8.0 +tzdata==2024.1 +tzlocal==5.2 +universal-analytics-python3==1.1.1 +urllib3==1.26.18 +uvicorn==0.28.0 +velruse==1.1.1 +venusian==3.1.0 +voxel51-eta==0.12.6 +wandb==0.14.0 +wcwidth==0.2.13 +WebOb==1.8.7 +websockets==11.0.3 +Werkzeug==3.0.1 +wrapt==1.16.0 +wsproto==1.2.0 +WTForms==3.1.2 +wtforms-recaptcha==0.3.2 +xmltodict==0.13.0 +yacs==0.1.8 +yapf==0.40.2 +zipp==3.18.1 +zope.deprecation==5.0 +zope.interface==6.2 +zope.sqlalchemy==3.1 diff --git a/requirements.torch.txt b/requirements.torch.txt new file mode 100644 index 0000000000000000000000000000000000000000..75367ad5ca53ff03cc399347237e3f565f9dee34 --- /dev/null +++ b/requirements.torch.txt @@ -0,0 +1,4 @@ +--index-url https://download.pytorch.org/whl/cu118 +torch==2.2.1 +torchaudio==2.2.1 +torchvision==0.17.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..54656e56be266fb3a4a6a4769be4e0f83874c2bf --- /dev/null +++ b/requirements.txt @@ -0,0 +1,246 @@ +absl-py==2.1.0 +accelerate==0.26.1 +addict==2.4.0 +aiofiles==23.2.1 +aliyun-python-sdk-core==2.15.0 +aliyun-python-sdk-kms==2.16.2 +altair==5.2.0 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==4.3.0 +anykeystore==0.2 +apex==0.9.10.dev0 +appdirs==1.4.4 +argcomplete==3.2.3 +attrs==23.2.0 +av==10.0.0 +beautifulsoup4==4.12.3 +blessed==1.20.0 +blessings==1.7 +boto3==1.34.63 +botocore==1.34.63 +Brotli==1.1.0 +cachetools==5.3.3 +certifi==2024.2.2 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +colorama==0.4.6 +contourpy==1.2.0 +crcmod==1.7 +cryptacular==1.6.2 +cryptography==42.0.5 +cycler==0.12.1 +dacite==1.7.0 +decorator==4.4.2 +decord==0.6.0 +deepspeed==0.14.0 +defusedxml==0.7.1 +Deprecated==1.2.14 +dill==0.3.8 +distro==1.9.0 +dnspython==2.6.1 +docker-pycreds==0.4.0 +einops==0.6.1 +exceptiongroup==1.2.0 +fastapi==0.110.0 +ffmpeg==1.4 +ffmpy==0.3.2 +fiftyone==0.23.6 +fiftyone-brain==0.16.1 +fiftyone_db==1.1.2 +filelock==3.9.0 +fonttools==4.49.0 +fsspec==2024.2.0 +ftfy==6.1.3 +future==1.0.0 +fvcore==0.1.5.post20221221 +gdown==5.1.0 +gitdb==4.0.11 +GitPython==3.1.42 +glob2==0.7 +google-auth==2.28.2 +google-auth-oauthlib==1.2.0 +gpustat==1.1.1 +gradio==4.21.0 +gradio_client==0.12.0 +graphql-core==3.2.3 +greenlet==3.0.3 +grpcio==1.62.1 +h11==0.14.0 +h2==4.1.0 +hjson==3.1.0 +hpack==4.0.0 +httpcore==1.0.4 +httpx==0.27.0 +huggingface-hub==0.21.4 +humanize==4.9.0 +hupper==1.12.1 +Hypercorn==0.16.0 +hyperframe==6.0.1 +idna==3.6 +idscheck==2.3.0 +imageio==2.27.0 +imageio-ffmpeg==0.4.9 +importlib_metadata==7.0.2 +importlib_resources==6.3.0 +inflate64==1.0.0 +iopath==0.1.10 +Jinja2==3.1.2 +jmespath==0.10.0 +joblib==1.3.2 +jsonlines==4.0.0 +jsonschema==4.21.1 +jsonschema-specifications==2023.12.1 +kaleido==0.2.1 +kiwisolver==1.4.5 +lazy_loader==0.3 +Markdown==3.6 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.3 +mdurl==0.1.2 +mmcv-full==1.7.2 +model-index==0.1.11 +mongoengine==0.24.2 +motor==3.3.2 +moviepy==1.0.3 +mpmath==1.3.0 +multivolumefile==0.2.3 +networkx==3.2.1 +ninja==1.11.1.1 +numpy==1.23.5 +oauthlib==3.2.2 +omegaconf==2.3.0 +openai==1.14.0 +opencv-python==4.9.0.80 +opencv-python-headless==4.9.0.80 +opendatalab==0.0.10 +openmim==0.3.9 +openxlab==0.0.36 +ordered-set==4.1.0 +orjson==3.9.15 +oss2==2.17.0 +packaging==24.0 +pandas==1.5.3 +PasteDeploy==3.1.0 +pathtools==0.1.2 +pbkdf2==1.3 +peft==0.10.0 +pillow==10.2.0 +plaster==1.1.2 +plaster-pastedeploy==1.0.1 +platformdirs==4.2.0 +plotly==5.20.0 +portalocker==2.8.2 +pprintpp==0.4.0 +priority==2.0.0 +proglog==0.1.10 +protobuf==4.23.4 +psutil==5.9.4 +py-cpuinfo==9.0.0 +py7zr==0.21.0 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pybcj==1.0.2 +pycparser==2.21 +pycryptodome==3.20.0 +pycryptodomex==3.20.0 +pydantic==2.6.4 +pydantic_core==2.16.3 +pydub==0.25.1 +Pygments==2.17.2 +pymongo==4.6.2 +pynvml==11.5.0 +pyparsing==3.1.2 +pyppmd==1.1.0 +pyramid==2.0.2 +pyramid-mailer==0.15.1 +PySocks==1.7.1 +python-dateutil==2.9.0.post0 +python-multipart==0.0.9 +python3-openid==3.2.0 +pytz==2023.4 +PyYAML==6.0 +pyzstd==0.15.9 +rarfile==4.1 +referencing==0.33.0 +regex==2023.12.25 +repoze.sendmail==4.4.1 +requests==2.28.2 +requests-oauthlib==1.4.0 +retrying==1.3.4 +rich==13.4.2 +rpds-py==0.18.0 +rsa==4.9 +ruff==0.3.2 +s3transfer==0.10.1 +safetensors==0.4.2 +scikit-image==0.22.0 +scikit-learn==1.4.1.post1 +scipy==1.10.1 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==1.42.0 +setproctitle==1.3.3 +shellingham==1.5.4 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +sortedcontainers==2.4.0 +soupsieve==2.5 +SQLAlchemy==2.0.28 +sse-starlette==0.10.3 +sseclient-py==1.8.0 +starlette==0.36.3 +strawberry-graphql==0.138.1 +sympy==1.12 +tabulate==0.9.0 +taskgroup==0.0.0a4 +tenacity==8.2.3 +tensorboard==2.15.1 +tensorboard-data-server==0.7.2 +tensorboardX==2.6.2.2 +termcolor==2.3.0 +texttable==1.7.0 +threadpoolctl==3.3.0 +tifffile==2024.2.12 +timm==0.6.12 +tokenizers==0.15.2 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.1 +torch==2.2.1 +torchaudio==2.2.1 +torchvision==0.17.1 +tqdm==4.65.2 +transaction==4.0 +transformers +translationstring==1.4 +triton==2.2.0 +typer==0.9.0 +typing_extensions==4.8.0 +tzdata==2024.1 +tzlocal==5.2 +universal-analytics-python3==1.1.1 +urllib3==1.26.18 +uvicorn==0.28.0 +velruse==1.1.1 +venusian==3.1.0 +voxel51-eta==0.12.6 +wandb==0.14.0 +wcwidth==0.2.13 +WebOb==1.8.7 +websockets==11.0.3 +Werkzeug==3.0.1 +wrapt==1.16.0 +wsproto==1.2.0 +WTForms==3.1.2 +wtforms-recaptcha==0.3.2 +xmltodict==0.13.0 +yacs==0.1.8 +yapf==0.40.2 +zipp==3.18.1 +zope.deprecation==5.0 +zope.interface==6.2 +zope.sqlalchemy==3.1 diff --git a/scripts/accel_config_deepspeed_zero2.yaml b/scripts/accel_config_deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee8d5e49ae4c5d253ba8c1ea0ffe7b729b905cfd --- /dev/null +++ b/scripts/accel_config_deepspeed_zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 8 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accel_config_deepspeed_zero3_offload.yaml b/scripts/accel_config_deepspeed_zero3_offload.yaml new file mode 100644 index 0000000000000000000000000000000000000000..436357c30fc3ca74e68eded9495fef8b3b244f22 --- /dev/null +++ b/scripts/accel_config_deepspeed_zero3_offload.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 2 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accel_config_deepspeed_zero3_offload_multinode.yaml b/scripts/accel_config_deepspeed_zero3_offload_multinode.yaml new file mode 100644 index 0000000000000000000000000000000000000000..333b4f18e6e540b162c9846d7632c64d6c8827e0 --- /dev/null +++ b/scripts/accel_config_deepspeed_zero3_offload_multinode.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 2 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_process_ip: fdbd:dc61:18:8::20 +main_process_port: 6876 +main_training_function: main +mixed_precision: bf16 +num_machines: 2 +num_processes: 16 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml b/scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..333b4f18e6e540b162c9846d7632c64d6c8827e0 --- /dev/null +++ b/scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 2 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_process_ip: fdbd:dc61:18:8::20 +main_process_port: 6876 +main_training_function: main +mixed_precision: bf16 +num_machines: 2 +num_processes: 16 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml b/scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f2c57be497505189415cf0ffbf98af21516f676a --- /dev/null +++ b/scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 2 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 1 +main_process_ip: fdbd:dc61:18:8::20 +main_process_port: 6876 +main_training_function: main +mixed_precision: bf16 +num_machines: 2 +num_processes: 16 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml b/scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0583d16a02f966f66c74e02edb80da970f6dceee --- /dev/null +++ b/scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 16 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accel_config_multigpu.yaml b/scripts/accel_config_multigpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dbe0dc7b6ade744eca906c95a06c018f21cac09f --- /dev/null +++ b/scripts/accel_config_multigpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: 2,3,4,5 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accel_config_multinode.yaml b/scripts/accel_config_multinode.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b437201b4d6b27f339756bc44061c9e3f568c50c --- /dev/null +++ b/scripts/accel_config_multinode.yaml @@ -0,0 +1,18 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 1 +main_process_ip: 10.193.16.150 +main_process_port: 6784 +main_training_function: main +mixed_precision: bf16 +num_machines: 2 +num_processes: 16 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accel_config_singlegpu.yaml b/scripts/accel_config_singlegpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cda636385ae4afb7425dbb4ed6c2630ec42b6c70 --- /dev/null +++ b/scripts/accel_config_singlegpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +gpu_ids: '0' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/demo.sh b/scripts/demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..5b6dfd2f00f4f463b91dc3911efdc66b2c8b97f0 --- /dev/null +++ b/scripts/demo.sh @@ -0,0 +1,32 @@ +model_dir=${1:-"MODELS/pllava-7b"} +weight_dir=${2:-"${model_dir}"} +num_frames=16 +lora_alpha=4 + +echo Running DEMO from model_dir: ${model_dir} +echo Running DEMO from weights_dir: ${weight_dir} +echo Running DEMO On Devices: ${CUDA_VISIBLE_DEVICES} + + +# # 34B Need to Use dispatch for this large. +# CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} python -m tasks.eval.demo.pllava_demo \ +# --pretrained_model_name_or_path ${model_dir} \ +# --num_frames ${num_frames} \ +# --use_lora \ +# --weight_dir ${weight_dir} \ +# --lora_alpha ${lora_alpha} \ +# --conv_mode eval_vcg_llava_next \ +# --use_multi_gpus \ + + +# 7B and 13B, There are problem if Model was split around A100 40G... Probably because some unkown bug in accelerate dispatch +CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1"} python -m tasks.eval.demo.pllava_demo \ + --pretrained_model_name_or_path ${model_dir} \ + --num_frames ${num_frames} \ + --use_lora \ + --weight_dir ${weight_dir} \ + --lora_alpha ${lora_alpha} \ + --conv_mode plain \ + --use_multi_gpus + + diff --git a/scripts/eval.sh b/scripts/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..db91cdde75b56e1c15c0bfc1e76b0d1764d08ac2 --- /dev/null +++ b/scripts/eval.sh @@ -0,0 +1,104 @@ +# export CUDA_VISIBLE_DEVICES=2,6,7 +export OPENAI_API_KEY=... +num_frames=16 +test_ratio=1 + +# 13b, uses offload thus saving the full model +model_dir=MODELS/pllava-13b +weight_dir=MODELS/pllava-13b +SAVE_DIR=test_results/test_pllava_13b +lora_alpha=4 +conv_mode=eval_vcgbench +python -m tasks.eval.vcgbench.pllava_eval_vcgbench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/vcgbench \ + --num_frames ${num_frames} \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --weight_dir ${weight_dir} \ + --pooling_shape 16-12-12 \ + --test_ratio ${test_ratio} \ + --conv_mode ${conv_mode} + +conv_mode=eval_mvbench +python -m tasks.eval.mvbench.pllava_eval_mvbench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/mvbench \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --num_frames ${num_frames} \ + --weight_dir ${weight_dir} \ + --pooling_shape 16-12-12 \ + --conv_mode ${conv_mode} + +onv_mode=eval_videoqabench +python -m tasks.eval.videoqabench.pllava_eval_videoqabench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/videoqabench \ + --num_frames ${num_frames} \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --weight_dir ${weight_dir} \ + --test_ratio ${test_ratio} \ + --conv_mode ${conv_mode} + + +conv_mode=eval_recaption +python -m tasks.eval.recaption.pllava_recaption \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/recaption \ + --num_frames ${num_frames} \ + --use_lora \ + --weight_dir ${weight_dir} \ + --lora_alpha ${lora_alpha} \ + --test_ratio ${test_ratio} \ + --conv_mode ${conv_mode} + + +model_dir=MODELS/pllava-7b +weight_dir=MODELS/pllava-7b +SAVE_DIR=test_results/test_pllava_7b +lora_alpha=4 + +conv_mode=eval_vcgbench +python -m tasks.eval.vcgbench.pllava_eval_vcgbench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/vcgbench \ + --num_frames ${num_frames} \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --weight_dir ${weight_dir} \ + --pooling_shape 16-12-12 \ + --test_ratio ${test_ratio} + + +conv_mode=eval_mvbench +python -m tasks.eval.mvbench.pllava_eval_mvbench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/mvbench \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --num_frames ${num_frames} \ + --weight_dir ${weight_dir} \ + --pooling_shape 16-12-12 + + +onv_mode=eval_videoqabench +python -m tasks.eval.videoqabench.pllava_eval_videoqabench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/videoqabench \ + --num_frames ${num_frames} \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --weight_dir ${weight_dir} \ + --test_ratio ${test_ratio} + +conv_mode=eval_recaption +python -m tasks.eval.recaption.pllava_recaption \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/recaption \ + --num_frames ${num_frames} \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --weight_dir ${weight_dir} \ + --test_ratio ${test_ratio} \ No newline at end of file diff --git a/scripts/eval_yiprompt.sh b/scripts/eval_yiprompt.sh new file mode 100644 index 0000000000000000000000000000000000000000..0307017c9d314133a2a2071d2b418a782ddc8a2d --- /dev/null +++ b/scripts/eval_yiprompt.sh @@ -0,0 +1,53 @@ +# export CUDA_VISIBLE_DEVICES=0,3,4,5,6,7 +export OPENAI_API_KEY=... +num_frames=16 +test_ratio=200 + +model_dir=MODELS/pllava-34b +weight_dir=MODELS/pllava-34b +SAVE_DIR=test_results/test_pllava_34b +lora_alpha=4 +conv_mode=eval_vcg_llavanext +python -m tasks.eval.vcgbench.pllava_eval_vcgbench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/vcgbench \ + --num_frames ${num_frames} \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --weight_dir ${weight_dir} \ + --pooling_shape 16-12-12 \ + --test_ratio ${test_ratio} \ + --conv_mode $conv_mode + +conv_mode=eval_mvbench_llavanext +python -m tasks.eval.mvbench.pllava_eval_mvbench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/mvbench \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --num_frames ${num_frames} \ + --weight_dir ${weight_dir} \ + --pooling_shape 16-12-12 \ + --conv_mode $conv_mode + +conv_mode=eval_videoqa_llavanext +python -m tasks.eval.videoqabench.pllava_eval_videoqabench \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/videoqabench \ + --num_frames ${num_frames} \ + --use_lora \ + --lora_alpha ${lora_alpha} \ + --weight_dir ${weight_dir} \ + --test_ratio ${test_ratio} \ + --conv_mode ${conv_mode} + +conv_mode=eval_recaption_llavanext +python -m tasks.eval.recaption.pllava_recaption \ + --pretrained_model_name_or_path ${model_dir} \ + --save_path ${SAVE_DIR}/recaption \ + --num_frames ${num_frames} \ + --use_lora \ + --weight_dir ${weight_dir} \ + --lora_alpha ${lora_alpha} \ + --test_ratio ${test_ratio} \ + --conv_mode $conv_mode diff --git a/scripts/gallery.sh b/scripts/gallery.sh new file mode 100644 index 0000000000000000000000000000000000000000..862898a40b8a98405922b89e0d1ce166f6b42e0b --- /dev/null +++ b/scripts/gallery.sh @@ -0,0 +1,11 @@ +export OPENAI_API_KEY=... +SAVE_DIR=${1:-"test_results"} + +# # gallery view +# python -m tasks.eval.show_gallery \ +# --root_dir ${SAVE_DIR} + +# # compare view +python -m tasks.eval.demo.show_compare \ + --root_dir ${SAVE_DIR} + diff --git a/scripts/train_pllava.sh b/scripts/train_pllava.sh new file mode 100644 index 0000000000000000000000000000000000000000..3c7b2c23bc7dd9699fcc6027752d7ce9dbaf826c --- /dev/null +++ b/scripts/train_pllava.sh @@ -0,0 +1,34 @@ +echo "PYTHONPATH: ${PYTHONPATH}" +which_python=$(which python) +echo "which python: ${which_python}" +export PYTHONPATH=${PYTHONPATH}:${which_python} +export PYTHONPATH=${PYTHONPATH}:. +echo "PYTHONPATH: ${PYTHONPATH}" + +OUTPUT_DIR=./pllava_video_outputs/test_train_7b_reconstruct + +# # Naive Env +# rm -rf ${OUTPUT_DIR} +pooling_shape=(16,12,12) +accelerate launch --main_process_port 6876 --config_file scripts/accel_config_multigpu.yaml tasks/train/train_pllava_nframe_accel.py \ + tasks/train/config_pllava_nframe.py \ + output_dir ${OUTPUT_DIR} \ + train_corpus videochat2_video \ + save_steps 10000 \ + num_workers 8 \ + num_frames 16 \ + model.pooling_method avg \ + model.repo_id llava-hf/llava-v1.6-vicuna-7b-hf \ + model.use_lora True \ + model.pooling_shape $pooling_shape \ + optimizer.lr 2e-5 \ + scheduler.epochs 3 \ + scheduler.warmup_ratio 0.2 \ + scheduler.min_lr_multi 0.25 \ + scheduler.is_videochat2_custom True \ + preprocess.mm_alone False \ + preprocess.random_shuffle False \ + preprocess.add_second_msg False \ + train_corpus videochat2_instruction_debug + + \ No newline at end of file diff --git a/scripts/train_pllava_13b.sh b/scripts/train_pllava_13b.sh new file mode 100644 index 0000000000000000000000000000000000000000..ba23997cbbb77b268fa2a2766a00d352ffbd6f85 --- /dev/null +++ b/scripts/train_pllava_13b.sh @@ -0,0 +1,50 @@ +echo "PYTHONPATH: ${PYTHONPATH}" +which_python=$(which python) +echo "which python: ${which_python}" +export PYTHONPATH=${PYTHONPATH}:${which_python} +export PYTHONPATH=${PYTHONPATH}:. +echo "PYTHONPATH: ${PYTHONPATH}" + +OUTPUT_DIR=./pllava_video_outputs/pllava_13b + + +pooling_shape=(16,12,12) +num_save_samples=80000 +num_gpus=8 +full_batch_size=128 +batch_size=8 +save_steps=$[$num_save_samples/($batch_size*$num_gpus)] +ckpt_steps=$[$save_steps/10] +gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)] +echo $batch_size +echo $gradient_accumulation_steps +repo_id=llava-hf/llava-v1.6-vicuna-13b-hf +accelerate launch --main_process_port 6876 --config_file scripts/accel_config_deepspeed_zero3_offload.yaml tasks/train/train_pllava_nframe_accel.py \ + tasks/train/config_pllava_nframe.py \ + output_dir ${OUTPUT_DIR} \ + train_corpus videochat2_instruction_debug \ + save_steps $save_steps \ + ckpt_steps $ckpt_steps \ + num_workers 8 \ + num_frames 16 \ + gradient_accumulation_steps $gradient_accumulation_steps \ + batch_size $batch_size \ + deepspeed True \ + model.pooling_method avg \ + model.use_lora True \ + model.use_pooling True \ + model.repo_id $repo_id \ + gradient_checkpointing True \ + preprocess.center_pad False \ + preprocess.clip_transform False \ + optimizer.lr 2e-5 \ + scheduler.epochs 3 \ + scheduler.warmup_ratio 0.2 \ + scheduler.min_lr_multi 0.25 \ + model.pooling_shape $pooling_shape \ + scheduler.is_videochat2_custom True \ + preprocess.mm_alone False \ + preprocess.random_shuffle False \ + preprocess.add_second_msg False + + diff --git a/scripts/train_pllava_34b.sh b/scripts/train_pllava_34b.sh new file mode 100644 index 0000000000000000000000000000000000000000..2c167e34dd7a5b0bbe776d784af1894d8b1830d4 --- /dev/null +++ b/scripts/train_pllava_34b.sh @@ -0,0 +1,50 @@ +echo "PYTHONPATH: ${PYTHONPATH}" +which_python=$(which python) +echo "which python: ${which_python}" +export PYTHONPATH=${PYTHONPATH}:${which_python} +export PYTHONPATH=${PYTHONPATH}:. +echo "PYTHONPATH: ${PYTHONPATH}" + +machine_rank=${1:-"0"} # machine rank + +OUTPUT_DIR=./pllava_video_outputs/pllava_34b_videchat2-video + +pooling_shape=(16,12,12) +num_save_samples=80000 +num_gpus=8 +full_batch_size=128 +batch_size=4 +save_steps=$[$num_save_samples/($batch_size*$num_gpus)] +ckpt_steps=$[$save_steps/10] +gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)] +echo $batch_size +echo $gradient_accumulation_steps +repo_id=llava-hf/llava-v1.6-34b-hf +accelerate launch --main_process_port 6876 --config_file scripts/accel_config_deepspeed_zero3_offload.yaml tasks/train/train_pllava_nframe_accel.py \ + tasks/train/config_pllava_nframe_yiprompt.py \ + output_dir ${OUTPUT_DIR} \ + train_corpus videochat2_instruction_debug \ + save_steps $save_steps \ + ckpt_steps $ckpt_steps \ + num_workers 8 \ + num_frames 16 \ + deepspeed True \ + gradient_accumulation_steps $gradient_accumulation_steps \ + batch_size $batch_size \ + model.pooling_method avg \ + model.use_lora True \ + model.use_pooling True \ + model.repo_id $repo_id \ + gradient_checkpointing True \ + preprocess.center_pad False \ + preprocess.clip_transform True \ + optimizer.lr 2e-5 \ + scheduler.epochs 3 \ + scheduler.warmup_ratio 0.2 \ + scheduler.min_lr_multi 0.25 \ + model.pooling_shape $pooling_shape \ + scheduler.is_videochat2_custom True \ + preprocess.image_token_index 64002 \ + preprocess.mm_alone False \ + preprocess.random_shuffle False \ + preprocess.add_second_msg False diff --git a/scripts/train_pllava_7b.sh b/scripts/train_pllava_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..f21cad8869e90727b2836af987ffd0e00972ceef --- /dev/null +++ b/scripts/train_pllava_7b.sh @@ -0,0 +1,49 @@ +echo "PYTHONPATH: ${PYTHONPATH}" +which_python=$(which python) +echo "which python: ${which_python}" +export PYTHONPATH=${PYTHONPATH}:${which_python} +export PYTHONPATH=${PYTHONPATH}:. +echo "PYTHONPATH: ${PYTHONPATH}" + +OUTPUT_DIR=./pllava_video_outputs/test_train_7b_reconstruct + +pooling_shape=(16,12,12) +num_save_samples=80000 +num_gpus=8 +full_batch_size=128 +batch_size=8 +save_steps=$[$num_save_samples/($batch_size*$num_gpus)] +ckpt_steps=$[$save_steps/10] +gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)] +echo $batch_size +echo $gradient_accumulation_steps +repo_id=llava-hf/llava-v1.6-vicuna-7b-hf +accelerate launch --main_process_port 6876 --config_file scripts/accel_config_multigpu.yaml tasks/train/train_pllava_nframe_accel.py \ + tasks/train/config_pllava_nframe.py \ + output_dir ${OUTPUT_DIR} \ + train_corpus videochat2_instruction_debug \ + save_steps $save_steps \ + ckpt_steps $ckpt_steps \ + num_workers 8 \ + num_frames 16 \ + gradient_accumulation_steps $gradient_accumulation_steps \ + batch_size $batch_size \ + model.pooling_method avg \ + model.use_lora True \ + model.use_pooling True \ + model.repo_id $repo_id \ + gradient_checkpointing True \ + preprocess.center_pad False \ + preprocess.clip_transform False \ + optimizer.lr 2e-5 \ + scheduler.epochs 3 \ + scheduler.warmup_ratio 0.2 \ + scheduler.min_lr_multi 0.25 \ + model.pooling_shape $pooling_shape \ + scheduler.is_videochat2_custom True \ + preprocess.mm_alone False \ + preprocess.random_shuffle False \ + preprocess.add_second_msg False + + + diff --git a/tasks/eval/demo/__init__.py b/tasks/eval/demo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b18087bf6f52339838e51b5dad0e1a1ab17f43cc --- /dev/null +++ b/tasks/eval/demo/__init__.py @@ -0,0 +1,15 @@ +import gradio as gr +from gradio.themes.utils import colors, fonts, sizes + + +pllava_theme = gr.themes.Monochrome( + text_size="sm", + spacing_size="sm", + primary_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"), + secondary_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"), + neutral_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"), +).set( + background_fill_primary_dark='*primary_950', + background_fill_secondary_dark='*neutral_950' +) + diff --git a/tasks/eval/demo/pllava_demo.py b/tasks/eval/demo/pllava_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..6126db37b766e25665a8c4a7a7a4f10ab7958bb7 --- /dev/null +++ b/tasks/eval/demo/pllava_demo.py @@ -0,0 +1,261 @@ +from argparse import ArgumentParser +import copy +import gradio as gr +from gradio.themes.utils import colors, fonts, sizes + +from utils.easydict import EasyDict +from tasks.eval.model_utils import load_pllava +from tasks.eval.eval_utils import ( + ChatPllava, + conv_plain_v1, + Conversation, + conv_templates +) +from tasks.eval.demo import pllava_theme + +SYSTEM="""You are Pllava, a large vision-language assistant. +You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language. +Follow the instructions carefully and explain your answers in detail based on the provided video. +""" +INIT_CONVERSATION: Conversation = conv_plain_v1.copy() + + +# ======================================== +# Model Initialization +# ======================================== +def init_model(args): + + print('Initializing PLLaVA') + model, processor = load_pllava( + args.pretrained_model_name_or_path, args.num_frames, + use_lora=args.use_lora, + weight_dir=args.weight_dir, + lora_alpha=args.lora_alpha, + use_multi_gpus=args.use_multi_gpus) + if not args.use_multi_gpus: + model = model.to('cuda') + chat = ChatPllava(model, processor) + return chat + + +# ======================================== +# Gradio Setting +# ======================================== +def gradio_reset(chat_state, img_list): + if chat_state is not None: + chat_state = INIT_CONVERSATION.copy() + if img_list is not None: + img_list = [] + return ( + None, + gr.update(value=None, interactive=True), + gr.update(value=None, interactive=True), + gr.update(placeholder='Please upload your video first', interactive=False), + gr.update(value="Upload & Start Chat", interactive=True), + chat_state, + img_list + ) + + +def upload_img(gr_img, gr_video, chat_state=None, num_segments=None, img_list=None): + print(gr_img, gr_video) + chat_state = INIT_CONVERSATION.copy() if chat_state is None else chat_state + img_list = [] if img_list is None else img_list + + if gr_img is None and gr_video is None: + return None, None, gr.update(interactive=True),gr.update(interactive=True, placeholder='Please upload video/image first!'), chat_state, None + if gr_video: + llm_message, img_list, chat_state = chat.upload_video(gr_video, chat_state, img_list, num_segments) + return ( + gr.update(interactive=True), + gr.update(interactive=True), + gr.update(interactive=True, placeholder='Type and press Enter'), + gr.update(value="Start Chatting", interactive=False), + chat_state, + img_list, + ) + if gr_img: + llm_message, img_list,chat_state = chat.upload_img(gr_img, chat_state, img_list) + return ( + gr.update(interactive=True), + gr.update(interactive=True), + gr.update(interactive=True, placeholder='Type and press Enter'), + gr.update(value="Start Chatting", interactive=False), + chat_state, + img_list + ) + + +def gradio_ask(user_message, chatbot, chat_state, system): + if len(user_message) == 0: + return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state + chat_state = chat.ask(user_message, chat_state, system) + chatbot = chatbot + [[user_message, None]] + return '', chatbot, chat_state + + +def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): + llm_message, llm_message_token, chat_state = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=200, num_beams=num_beams, temperature=temperature) + llm_message = llm_message.replace("", "") # handle + chatbot[-1][1] = llm_message + print(chat_state) + print(f"Answer: {llm_message}") + return chatbot, chat_state, img_list + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + default='llava-hf/llava-1.5-7b-hf' + ) + parser.add_argument( + "--num_frames", + type=int, + required=True, + default=4, + ) + parser.add_argument( + "--use_lora", + action='store_true' + ) + parser.add_argument( + "--use_multi_gpus", + action='store_true' + ) + parser.add_argument( + "--weight_dir", + type=str, + required=False, + default=None, + ) + parser.add_argument( + "--conv_mode", + type=str, + required=False, + default=None, + ) + parser.add_argument( + "--lora_alpha", + type=int, + required=False, + default=None, + ) + parser.add_argument( + "--server_port", + type=int, + required=False, + default=7868, + ) + args = parser.parse_args() + return args + + +title = """

PLLAVA

""" +description = ( + """

+ # PLLAVA! +

+ - Upload A Video + - Press Upload + - Start Chatting + """ +) + +args = parse_args() + +model_description = f""" + # MODEL INFO + - pretrained_model_name_or_path:{args.pretrained_model_name_or_path} + - use_lora:{args.use_lora} + - weight_dir:{args.weight_dir} +""" + +# with gr.Blocks(title="InternVideo-VideoChat!",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo: +with gr.Blocks(title="PLLaVA", + theme=pllava_theme, + css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo: + gr.Markdown(title) + gr.Markdown(description) + gr.Markdown(model_description) + with gr.Row(): + with gr.Column(scale=0.5, visible=True) as video_upload: + # with gr.Column(elem_id="image", scale=0.5) as img_part: + with gr.Tab("Video", elem_id='video_tab'): + up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360) + with gr.Tab("Image", elem_id='image_tab'): + up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload", height=360) + upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") + clear = gr.Button("Restart") + + # num_segments = gr.Slider( + # minimum=8, + # maximum=64, + # value=8, + # step=1, + # interactive=True, + # label="Video Segments", + # ) + + with gr.Column(visible=True) as input_raws: + system_string = gr.Textbox(SYSTEM, interactive=True, label='system') + num_beams = gr.Slider( + minimum=1, + maximum=5, + value=1, + step=1, + interactive=True, + label="beam search numbers", + ) + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + chat_state = gr.State() + img_list = gr.State() + chatbot = gr.Chatbot(elem_id="chatbot",label='Conversation') + with gr.Row(): + with gr.Column(scale=0.7): + text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False) + with gr.Column(scale=0.15, min_width=0): + run = gr.Button("💭Send") + with gr.Column(scale=0.15, min_width=0): + clear = gr.Button("🔄Clear") + + with gr.Row(): + examples = gr.Examples( + examples=[ + ['example/jesse_dance.mp4', 'What is the man doing?'], + ['example/yoga.mp4', 'What is the woman doing?'], + ['example/cooking.mp4', 'Describe the background, characters and the actions in the provided video.'], + # ['example/cooking.mp4', 'What is happening in the video?'], + ['example/working.mp4', 'Describe the background, characters and the actions in the provided video.'], + ['example/1917.mp4', 'Describe the background, characters and the actions in the provided video.'], + ], + inputs=[up_video, text_input] + ) + + + chat = init_model(args) + INIT_CONVERSATION = conv_templates[args.conv_mode] + upload_button.click(upload_img, [up_image, up_video, chat_state], [up_image, up_video, text_input, upload_button, chat_state, img_list]) + + text_input.submit(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then( + gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] + ) + run.click(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then( + gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] + ) + run.click(lambda: "", None, text_input) + clear.click(gradio_reset, [chat_state, img_list], [chatbot, up_image, up_video, text_input, upload_button, chat_state, img_list], queue=False) + +# demo.queue(max_size=5) +demo.launch(share=True,server_port=args.server_port) +# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True) diff --git a/tasks/eval/demo/show_compare.py b/tasks/eval/demo/show_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..d7accf685a3db3e7428f6a861d1e53028b5b216a --- /dev/null +++ b/tasks/eval/demo/show_compare.py @@ -0,0 +1,124 @@ + + +import argparse +import json +import os +import os.path as osp +import gradio as gr +import numpy as np + +from tasks.eval.recaption import load_results as load_results_recaption +from tasks.eval.mvbench import load_results as load_results_mvbench +from tasks.eval.vcgbench import load_results as load_results_vcgbench +from tasks.eval.videoqabench import load_results as load_results_videoqabench +from tasks.eval.demo import pllava_theme + + +load_results_funcs = [ + load_results_recaption, + load_results_mvbench, + load_results_vcgbench, + load_results_videoqabench, +] + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--root_dir', + required=True, + ) + args = parser.parse_args() + return args + +args = parse_args() +root_dir = args.root_dir + +def show(result_list_first, result_list_second, result_index): + sample2index_second = {} + + for i, result in enumerate(result_list_second): + if 'video_path' not in result: + continue + + question = result['question'] if 'question' in result else '' + video_path = result['video_path'] + samplehash = question + '--' +video_path + sample2index_second[samplehash] = i + + info = result_list_first[result_index] + info_str_first = json.dumps(info, indent=4, ensure_ascii=False) + video_path = info['video_path'] + question = info['question'] if 'question' in info else '' + samplehash = question + '--' +video_path + if samplehash in sample2index_second: + info = result_list_second[sample2index_second[samplehash]] + info_str_second = json.dumps(info, indent=4, ensure_ascii=False) + else: + info_str_second = f"NO {video_path} IN THE SECOND RESULT DIR" + return video_path, info_str_first, info_str_second + +def reload_results_dirs(): + result_dirs = [] + # load result dir paths + for dirpath, dirnames, filenames in os.walk(args.root_dir): + if len(dirnames) == 0 and len(filenames) != 0: + result_dirs.append(dirpath) + return gr.Dropdown(result_dirs, value=result_dirs[0]) + +def reload_results(result_dir): + # if isinstance(result_dir, list): + # result_dir = result_dir[0] + + if result_dir is None or not osp.exists(result_dir): + return None + + for fn in load_results_funcs: + result_list = fn(result_dir) + if result_list is not None: + np.random.shuffle(result_list) + break + result_index = gr.Slider(0, len(result_list), step=1) + + return result_list, result_index + + + +with gr.Blocks(title="PLLAVA RESULTS", theme=pllava_theme) as demo: + result_list_first = gr.State() + result_list_second = gr.State() + + with gr.Row(): + with gr.Column(): + gr.Markdown("# Showing off Model's Outputs.") + gr.Markdown( + "You can find all our results, including:\n" + "1. results of Captioned Inter4k\n" + "2. results of Different Benchmark inference outputs.\n" + "Choose a directory to see the different output variant.\n" + "You can also choose secondary directory (as long as they are from the same dataset.) to compare on the results.\n" + ) + + with gr.Row(): + with gr.Column(): + show_video = gr.Video(interactive=False) + + with gr.Column(): + button_reload = gr.Button(value='Reload From The Evaluation/Inference Root Directory') + result_index = gr.Slider(0, 0, step=1, label="Index") + + result_dir_first = gr.Dropdown(label='Test Result Path') + info_first = gr.Text(interactive=False, label='Detailed Output Information') + result_dir_second = gr.Dropdown(label='Test Result Path') + info_second = gr.Text(interactive=False, label='Detailed Output Information') + + + button_reload.click(reload_results_dirs, [], [result_dir_first]) + button_reload.click(reload_results_dirs, [], [result_dir_second]) + result_dir_first.change(reload_results, [result_dir_first], [result_list_first, result_index]) + result_dir_second.change(reload_results, [result_dir_second], [result_list_second, result_index]) + result_index.change(show, [result_list_first, result_list_second, result_index], [show_video, info_first, info_second]) + demo.load(reload_results_dirs, [], [result_dir_first]) + demo.load(reload_results_dirs, [], [result_dir_second]) + +demo.launch(share=True) \ No newline at end of file diff --git a/tasks/eval/demo/show_gallery.py b/tasks/eval/demo/show_gallery.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc7725f5f37eab84deb6c8071d7e7895579964d --- /dev/null +++ b/tasks/eval/demo/show_gallery.py @@ -0,0 +1,94 @@ + + +import argparse +import json +import os +import os.path as osp +import gradio as gr + +from tasks.eval.recaption import load_results as load_results_recaption +from tasks.eval.mvbench import load_results as load_results_mvbench +from tasks.eval.vcgbench import load_results as load_results_vcgbench +from tasks.eval.videoqabench import load_results as load_results_videoqabench + +load_results_funcs = [ + load_results_recaption, + load_results_mvbench, + load_results_vcgbench, + load_results_videoqabench, +] + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--root_dir', + required=True, + ) + args = parser.parse_args() + return args + +args = parse_args() +root_dir = args.root_dir + +def show(result_list, result_index): + info = result_list[result_index] + video_path = info['video_path'] + info_str = json.dumps(info, indent=4) + return video_path, info_str + +def reload_results_dirs(): + result_dirs = [] + # load result dir paths + for dirpath, dirnames, filenames in os.walk(args.root_dir): + if len(dirnames) == 0 and len(filenames) != 0: + result_dirs.append(dirpath) + return gr.Dropdown(result_dirs, value=result_dirs[0]) + +def reload_results(result_dir): + # if isinstance(result_dir, list): + # result_dir = result_dir[0] + + if result_dir is None or not osp.exists(result_dir): + return None + + for fn in load_results_funcs: + result_list = fn(result_dir) + if result_list is not None: + break + + result_index = gr.Slider(0, len(result_list), step=1) + + return result_list, result_index + +with gr.Blocks() as demo: + result_list = gr.State() + + with gr.Row(): + gr.Markdown("# Showing of what has came out.") + + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown(f"### From Saved Results Directory {args.root_dir}") + + with gr.Column(scale=2): + result_dir = gr.Dropdown(label='Test Result Path') + button_reload = gr.Button(value='Reload From The Evaluation/Inference Root Directory') + + + + with gr.Row(): + with gr.Column(): + show_video = gr.Video(interactive=False) + + with gr.Column(): + result_index = gr.Slider(0, 0, step=1, label="Index") + info = gr.Text(interactive=False, label='Detailed Output Information') + + + button_reload.click(reload_results_dirs, [], [result_dir]) + result_dir.change(reload_results, [result_dir], [result_list, result_index]) + result_index.change(show, [result_list, result_index], [show_video, info]) + demo.load(reload_results_dirs, [], [result_dir]) + +demo.launch(share=True) \ No newline at end of file diff --git a/tasks/eval/eval_utils.py b/tasks/eval/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3aabad1c2a33dec15ba4997d8b7f004519c05376 --- /dev/null +++ b/tasks/eval/eval_utils.py @@ -0,0 +1,517 @@ +import copy +import itertools +import re +import os +import json +from enum import auto, Enum +import dataclasses +from typing import Any, List + +from PIL import Image +import cv2 +import imageio +import numpy as np +import torch +from torch.utils.data import Dataset +import torchvision.transforms as T +from torchvision.transforms.functional import InterpolationMode +from moviepy.editor import VideoFileClip + + +from decord import VideoReader, cpu # This is Terrible, if you have this line of import in front of torch, will cause model.to(device) to hang +from transformers import StoppingCriteria, StoppingCriteriaList +from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection + +from utils.easydict import EasyDict + +IMAGE_TOKEN = "" +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + +class MultiModalConvStyle(Enum): + """Different separator style.""" + MM_ALONE = 'mm_alone' + MM_INTERLEAF = 'mm_inferleaf' + +def dump_json(obj_serializable ,save_dir_path, json_file_name): + os.makedirs(save_dir_path, exist_ok=True) + save_path = os.path.join(save_dir_path, json_file_name) + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(obj_serializable, f, indent=4, ensure_ascii=False, ) + +def load_json(load_dir_path, json_file_name): + + load_path = os.path.join(load_dir_path, json_file_name) + if not os.path.exists(load_path): + return None + with open(load_path, 'r', encoding='utf-8') as f: + obj_serializable = json.load(f) + return obj_serializable + + + +@dataclasses.dataclass +class Conversation(EasyDict): + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + sep: List[str] + mm_token: str + + mm_style: MultiModalConvStyle = MultiModalConvStyle.MM_INTERLEAF + pre_query_prompt: str=None + post_query_prompt: str=None + answer_prompt: str=None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if isinstance(self.sep, str): + self.sep = [self.sep for _ in self.roles] + + def get_prompt(self): + sep = [self.sep for _ in self.roles] if isinstance(self.sep, str) else self.sep # if only one sep given, then both sep are the sames + sep = dict(zip(self.roles, sep)) + ret = self.system + sep[self.roles[0]] if self.system != "" else "" + for i, (role, message) in enumerate(self.messages): + # if is last msg(the prompt for assistant), if answer prompt exists, no sep added + if i+1 == len(self.messages): + if role != self.roles[-1]: # last role is not the model + ret += role + message + sep[role] + self.roles[-1] + else: + ret += role + message + else: + ret += role + message + sep[role] + return ret + # def get_prompt_multichoice(self): + # pass + def user_query(self, query=None, pre_query_prompt=None, post_query_prompt=None, is_mm=False, num_mm_token=1): + if post_query_prompt is not None: + query = f"{query} {post_query_prompt}" + + if pre_query_prompt is not None: + query = f"{pre_query_prompt} {query}" + role = self.roles[0] + # TODO: remove the num_mm_token and hack the self.mm_token outside + if is_mm: + mm_str = num_mm_token*self.mm_token[:-1] + self.mm_token[-1] + if self.mm_style == MultiModalConvStyle.MM_ALONE: + self._append_message(role, mm_str) + elif self.mm_style == MultiModalConvStyle.MM_INTERLEAF: + if self.mm_token not in query: + query = f'{mm_str} {query}' + self._append_message(role, query) + + def assistant_response(self, response, pre_query_prompt=None, post_query_prompt=None): + if post_query_prompt is not None: + response = f"{response} {post_query_prompt}" + + if pre_query_prompt is not None: + response = f"{post_query_prompt} {response}" + + role = self.roles[1] + self._append_message(role, response) + + def _append_message(self, role, message): + message = '' if message is None else message + self.messages.append([role, message]) + + def copy(self): + return copy.deepcopy(self) + +conv_video_chatgpt_v1 = Conversation( + system="You are Video-ChatGPT, a large vision-language assistant. " + "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail based on the provided video.", + roles=("USER:", "ASSISTANT:"), + messages=[], + sep=[" ",""], + mm_token='', + mm_style=MultiModalConvStyle.MM_INTERLEAF, +) + + +conv_plain_v1 = Conversation( + system="", + roles=("USER:", "ASSISTANT:"), + messages=[], + sep=(" ", ""), + mm_token='' +) + +# Attention to the roles[0] "USER: " has a space! +conv_eval_vcg = Conversation( + system="You are Video-ChatGPT, a large vision-language assistant. " + "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail based on the provided video.", + roles=("USER: ", "ASSISTANT:"), + messages=[], + sep=[" ",""], + mm_token='\n', + mm_style=MultiModalConvStyle.MM_ALONE, +) + +conv_eval_vcg_llavanext = Conversation( + system="You are Video-ChatGPT, a large vision-language assistant. " + "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail based on the provided video.", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + messages=[], + sep=["<|im_end|>\n","<|im_end|>\n"], + mm_token='\n', + mm_style=MultiModalConvStyle.MM_ALONE, +) + +SYSTEM_MVBENCH="Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n" +conv_eval_mvbench = Conversation( + system=SYSTEM_MVBENCH, + roles=("USER: ", "ASSISTANT:"), + messages=[], + sep=[" ",""], + mm_token='\n', + mm_style=MultiModalConvStyle.MM_ALONE, +) +conv_eval_mvbench_llavanext = Conversation( + system="You are Video-ChatGPT, a large vision-language assistant. " + "You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail based on the provided video.", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + messages=[], + sep=["<|im_end|>\n","<|im_end|>\n"], + mm_token='\n', + mm_style=MultiModalConvStyle.MM_ALONE, +) + + +conv_eval_videoqabench = Conversation( + system="", + roles=("USER: ", "ASSISTANT:"), + messages=[], + sep=[" ",""], + mm_token='\n', + mm_style=MultiModalConvStyle.MM_INTERLEAF, + pre_query_prompt="The input consists of a sequence of key frames from a video. Answer the question concisely first and followed by significant events, characters, or objects that appear throughout the frames. Question:", + post_query_prompt="\n", + answer_prompt='\nAnswer: In the video,' +) + +conv_eval_videoqa_llavanext = Conversation( + system="<|im_start|>system\nAnswer the question.", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + messages=[], + sep=["<|im_end|>\n","<|im_end|>\n"], + mm_token='\n', + mm_style=MultiModalConvStyle.MM_INTERLEAF, + pre_query_prompt="The input consists of a sequence of key frames from a video. Answer the question concisely first and followed by significant events, characters, or objects that appear throughout the frames. Question:", + post_query_prompt="\n", + answer_prompt='\nAnswer: In the video,' +) + + +SYSTEM_RECAPTION="""You are a powerful Video Magic ChatBot, a large vision-language assistant. +You are able to understand the video content that the user provides and assist the user in a video recaptioning task. +The user will provide you with the video and maybe some extra noisy information to help you out. Make use of the information in a proper way to be competent for the recaption job +### INSTRUCTIONS: +1. Follow the user's instruction. +2. Be critical yet believe in yourself. +""" +conv_eval_recaption = Conversation( + system=SYSTEM_RECAPTION, + roles=("USER: ", "ASSISTANT:"), + messages=[], + sep=[" ",""], + mm_token='\n', + mm_style=MultiModalConvStyle.MM_ALONE, +) + + +conv_eval_recaption_llavanext = Conversation( + system=SYSTEM_RECAPTION, + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + messages=[], + sep=["<|im_end|>\n","<|im_end|>\n"], + mm_token='\n', + mm_style=MultiModalConvStyle.MM_ALONE, +) + + +conv_templates = { + "plain": conv_plain_v1, + "eval_vcgbench": conv_eval_vcg, + "eval_vcg_llavanext": conv_eval_vcg_llavanext, + "eval_mvbench": conv_eval_mvbench, + "eval_mvbench_llavanext": conv_eval_mvbench_llavanext, + "eval_videoqabench": conv_eval_videoqabench, + "eval_videoqa_llavanext": conv_eval_videoqa_llavanext, + "eval_recaption": conv_eval_recaption, + "eval_recaption_llavanext": conv_eval_recaption_llavanext, +} + + +class EvalDataset(Dataset): + + def __init__(self, num_segments, test_ratio=None): + super().__init__() + self.num_segments = num_segments + self.test_ratio = test_ratio + self.decord_method = { + 'video': self.read_video, + 'gif': self.read_clip_gif, + 'frame': self.read_frame, + } + + def __getitem__(self, index) -> Any: + raise NotImplementedError('') + + def __str__(self): + len_list = {} + option_list = {} + for data in self.data_list: + if data['task_type'] not in len_list: + len_list[data['task_type']] = 0 + len_list[data['task_type']] += 1 + if data['task_type'] not in option_list: + option_list[data['task_type']] = 0 + option_list[data['task_type']] += len(data['data']['candidates']) + + correct = 0 + total = 0 + res = f"There are {len(self.data_list)} videos as follow:\n" + for k, v in len_list.items(): + correct += len_list[k] + total += option_list[k] + res += f"{v} for {k} ({option_list[k]} options => {len_list[k]/option_list[k]*100:.2f}%)\n" + correct = correct + 1 / option_list[k] + res += f"Total random accuracy: {correct/total*100:.2f}%" + return res.rstrip() + + def __len__(self): + return len(self.data_list) + + def get_index(self, bound, fps, max_frame, first_idx=0): + if bound: + start, end = bound[0], bound[1] + else: + start, end = -100000, 100000 + start_idx = max(first_idx, round(start * fps)) + end_idx = min(round(end * fps), max_frame) + seg_size = float(end_idx - start_idx) / self.num_segments + frame_indices = np.array([ + int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) + for idx in range(self.num_segments) + ]) + return frame_indices + + def read_video(self, video_path, bound=None): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=4) + max_frame = len(vr) - 1 + fps = float(vr.get_avg_fps()) + + images_group = list() + frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()) + images_group.append(img) + return images_group + + def read_gif(self, video_path, bound=None, fps=25): + gif = imageio.get_reader(video_path) + max_frame = len(gif) - 1 + + images_group = list() + frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) + for index, frame in enumerate(gif): + if index in frame_indices: + img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) + img = Image.fromarray(img) + images_group.append(img) + if len(images_group) == len(frame_indices): + break + + # might be some really short videos in the gif datasets + if len(images_group) < self.num_segments: + multiplier = int(self.num_segments/len(images_group)) + 1 + images_group = [image for _ in range(multiplier) for image in images_group][:self.num_segments] + assert len(images_group) == self.num_segments + + return images_group + + def read_clip_gif(self, video_path, bound=None, fps=25): + gif = VideoFileClip(video_path) + frames = gif.iter_frames() + max_frame = gif.reader.nframes - 1 + images_group = list() + frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) + for index, frame in enumerate(frames): + if index in frame_indices: + img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) + img = Image.fromarray(img) + images_group.append(img) + + # might be some really short videos in the gif datasets + if len(images_group) < self.num_segments: + multiplier = int(self.num_segments/len(images_group)) + 1 + images_group = [image for _ in range(multiplier) for image in images_group][:self.num_segments] + assert len(images_group) == self.num_segments + + return images_group + + def read_frame(self, video_path, bound=None, fps=3): + max_frame = len(os.listdir(video_path)) + images_group = list() + frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1 + for frame_index in frame_indices: + img = Image.open(os.path.join(video_path, f"{frame_index:05d}.jpg")) + images_group.append(img) + return images_group + + def set_rank_and_world_size(self, rank, world_size): + self.rank = rank + self.world_size = world_size + # self.data_list = self.data_list[::200] # debug + if self.test_ratio is None: + self.data_list = self.data_list[rank::world_size] + else: + np.random.RandomState(42).shuffle(self.data_list) + if isinstance(self.test_ratio, float): + num_samples = int(len(self.data_list) * self.test_ratio) + else: + num_samples = int(self.test_ratio) + self.data_list = self.data_list[rank:num_samples:world_size] + + +class ChatPllava: + print_res=True + do_sample=False + def __init__(self, model, processor): + self.model = model + self.processor = processor + + def ask(self, text, conv: Conversation, system): + conv.system = system + conv.user_query(text, ) + return conv + + def answer(self, conv: Conversation, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.0, length_penalty=1, temperature=1.0): + torch.cuda.empty_cache() + prompt = conv.get_prompt() + if prompt.count(conv.mm_token) < len(img_list): + diff_mm_num = len(img_list) - prompt.count(conv.mm_token) + for i in range(diff_mm_num): + conv.user_query("", is_mm=True) + prompt = conv.get_prompt() + + inputs = self.processor(text=prompt, images=img_list, return_tensors="pt") + if inputs['pixel_values'] is None: + inputs.pop('pixel_values') + inputs = inputs.to(self.model.device) + + with torch.no_grad(): + output_token = self.model.generate(**inputs, media_type='video', + do_sample=self.do_sample,max_new_tokens=max_new_tokens, num_beams=num_beams, min_length=min_length, + top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, temperature=temperature, + ) # dont need to long for the choice. + output_text = self.processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + if self.print_res: + print('###PROMPT: ', prompt) + print('###LM OUTPUT TEXT', output_text) + # <|im_start|> encode and then decode would extend a space at folloing, this is insane... + if conv.roles[-1] == "<|im_start|>assistant\n": + split_tag = "<|im_start|> assistant\n" + else: + split_tag = conv.roles[-1] + output_text = output_text.split(split_tag)[-1].rstrip(conv.sep[1]) + conv.assistant_response(output_text) + return output_text, output_token.cpu().numpy(), conv + + + def get_index(self, num_frames, num_segments): + seg_size = float(num_frames - 1) / num_segments + start = int(seg_size / 2) + offsets = np.array([ + start + int(np.round(seg_size * idx)) for idx in range(num_segments) + ]) + return offsets + + def load_video(self, video_path, num_segments=8, return_msg=False): + vr = VideoReader(video_path, ctx=cpu(0)) + num_frames = len(vr) + frame_indices = self.get_index(num_frames, num_segments) + + duration = len(vr) // vr.get_avg_fps() + index = np.linspace(0, len(vr)-1, num=int(duration)) + buffer = vr.get_batch(index).asnumpy() + # transform + + images_group = list() + for frame in buffer: + img = Image.fromarray(frame) + images_group.append(img) + images_group = list() + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()) + images_group.append(img) + if return_msg: + fps = float(vr.get_avg_fps()) + sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices]) + # " " should be added in the start and end + msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds." + return images_group, msg + else: + return images_group + + def upload_video(self, image, conv: Conversation, img_list: list[list], num_segments=None): + num_segments = self.model.config.num_frames if num_segments is None else num_segments + if isinstance(image, str): # is a image path + vid, msg = self.load_video(image, num_segments=num_segments, return_msg=True) + else: + raise NotImplementedError + print("Input video shape:", len(vid), *vid[0].size) + img_list.append(vid) + conv.user_query("", is_mm=True) + msg = "Received." + # self.conv.append_message(self.conv.roles[1], msg) + return msg, img_list, conv + + def upload_img(self, image, conv, img_list): + assert False + img = image#Image.open(image)#.convert('RGB') + transform = T.Compose( + [ + T.Resize( + (224, 224), interpolation=InterpolationMode.BICUBIC + ), + T.ToTensor(), + T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) + + img = transform(img).unsqueeze(0).unsqueeze(0).cuda() + image_emb, _ = self.model.encode_img(img, "Observe the image and answer the question.") + img_list.append(image_emb) + conv.messages.append([ + conv.roles[0], + f"\n" + ]) + msg = "Received." + # self.conv.append_message(self.conv.roles[1], msg) + return msg,img_list, conv + +class StoppingCriteriaSub(StoppingCriteria): + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop):])).item(): + return True + return False diff --git a/tasks/eval/model_utils.py b/tasks/eval/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b9396464ee5e42da2c5032ec2fa892c9d0c3efc7 --- /dev/null +++ b/tasks/eval/model_utils.py @@ -0,0 +1,172 @@ + +import torch +import os +from peft import get_peft_model, LoraConfig, TaskType +from safetensors import safe_open +from peft import PeftModel +from tasks.eval.eval_utils import Conversation +from models.pllava import PllavaProcessor, PllavaForConditionalGeneration, PllavaConfig +from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map,load_checkpoint_in_model +from accelerate.utils import get_balanced_memory + +from transformers import StoppingCriteria +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.tokenizer = tokenizer + self.start_len = None + self.input_ids = input_ids + + def __call__( + self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + if self.start_len is None: + self.start_len = self.input_ids.shape[1] + return False + else: + outputs = self.tokenizer.batch_decode( + output_ids[:, self.start_len:], skip_special_tokens=True + ) + flag = True + for output in outputs: + for keyword in self.keywords: + if keyword not in output: + flag = False + return False + return flag + + +def load_pllava(repo_id, num_frames, use_lora=False, weight_dir=None, lora_alpha=32, use_multi_gpus=False, pooling_shape=(16,12,12)): + kwargs = { + 'num_frames': num_frames, + } + # print("===============>pooling_shape", pooling_shape) + if num_frames == 0: + kwargs.update(pooling_shape=(0,12,12)) # produce a bug if ever usen the pooling projector + config = PllavaConfig.from_pretrained( + repo_id if not use_lora else weight_dir, + pooling_shape=pooling_shape, + **kwargs, + ) + + with torch.no_grad(): + model = PllavaForConditionalGeneration.from_pretrained(repo_id, config=config, torch_dtype=torch.bfloat16) + + try: + processor = PllavaProcessor.from_pretrained(repo_id) + except Exception as e: + processor = PllavaProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf') + + # config lora + if use_lora and weight_dir is not None: + print("Use lora") + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, target_modules=["q_proj", "v_proj"], + r=128, lora_alpha=lora_alpha, lora_dropout=0. + ) + print("Lora Scaling:", lora_alpha/128) + model.language_model = get_peft_model(model.language_model, peft_config) + assert weight_dir is not None, "pass a folder to your lora weight" + print("Finish use lora") + + # load weights + if weight_dir is not None: + state_dict = {} + save_fnames = os.listdir(weight_dir) + if "model.safetensors" in save_fnames: + use_full = False + for fn in save_fnames: + if fn.startswith('model-0'): + use_full=True + break + else: + use_full= True + + if not use_full: + print("Loading weight from", weight_dir, "model.safetensors") + with safe_open(f"{weight_dir}/model.safetensors", framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + else: + print("Loading weight from", weight_dir) + for fn in save_fnames: + if fn.startswith('model-0'): + with safe_open(f"{weight_dir}/{fn}", framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + if 'model' in state_dict.keys(): + msg = model.load_state_dict(state_dict['model'], strict=False) + else: + msg = model.load_state_dict(state_dict, strict=False) + print(msg) + # dispatch model weight + if use_multi_gpus: + max_memory = get_balanced_memory( + model, + max_memory=None, + no_split_module_classes=["LlamaDecoderLayer"], + dtype='bfloat16', + low_zero=False, + ) + + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=["LlamaDecoderLayer"], + dtype='bfloat16' + ) + + dispatch_model(model, device_map=device_map) + print(model.hf_device_map) + + model = model.eval() + + return model, processor + + +def load_adapters(model, adapter_model_name_or_paths): + + for adapter_model_name_or_path in adapter_model_name_or_paths: + if not isinstance(model, PeftModel): + model = PeftModel.from_pretrained(model, adapter_model_name_or_path, adapter_model_name_or_path) + else: + model.load_adapter(adapter_model_name_or_path, adapter_model_name_or_path) + + return model + + +def pllava_answer(conv: Conversation, model, processor, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.0, length_penalty=1, temperature=1.0, stop_criteria_keywords=None, print_res=False): + # torch.cuda.empty_cache() + prompt = conv.get_prompt() + inputs = processor(text=prompt, images=img_list, return_tensors="pt") + if inputs['pixel_values'] is None: + inputs.pop('pixel_values') + inputs = inputs.to(model.device) + + # set up stopping criteria + if stop_criteria_keywords is not None: + stopping_criteria = [KeywordsStoppingCriteria(stop_criteria_keywords, processor.tokenizer, inputs["input_ids"])] + else: + stopping_criteria= None + + with torch.no_grad(): + output_token = model.generate(**inputs, media_type='video', + do_sample=do_sample, max_new_tokens=max_new_tokens, num_beams=num_beams, min_length=min_length, + top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, temperature=temperature, + stopping_criteria=stopping_criteria,) + output_text = processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + if "###" in output_text: + output_text = "###".join(output_text.split('###')[:-1]) # remove the stop sign '###' + if print_res: # debug usage + print('### PROMPTING LM WITH: ', prompt) + print('### LM OUTPUT TEXT: ', output_text) + if conv.roles[-1] == "<|im_start|>assistant\n": + split_tag = "<|im_start|> assistant\n" + else: + split_tag = conv.roles[-1] + output_text = output_text.split(split_tag)[-1].rstrip(conv.sep if isinstance(conv.sep, str) else conv.sep[1]).strip() + conv.messages[-1][1] = output_text + return output_text, conv + diff --git a/tasks/eval/mvbench/__init__.py b/tasks/eval/mvbench/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f7df338b72f28d273360beda9cd45ce24262fc --- /dev/null +++ b/tasks/eval/mvbench/__init__.py @@ -0,0 +1,173 @@ +import os +import json +from tasks.eval.eval_utils import ( + dump_json, + load_json, + EvalDataset, +) + + +def check_ans(pred, gt): + flag = False + + pred_list = pred.lower().split(' ') + pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:]) + gt_list = gt.lower().split(' ') + gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:]) + if gt_content[-1] == '.': + gt_content = gt_content[:-1] + + if not any([c in pred_option for c in 'abcdefgABCDEFG']): + print(f"model doesn't follow instructions: {pred}") + elif pred_option.replace('.', '') in gt_option: + flag = True + elif gt_option in pred_option: + flag = True + + return flag + +def save_results(result_list, save_path): + + final_res, acc_dict = {}, {} + correct, total = 0, 0 + for res in result_list: + task_type = res['task_type'] + if task_type not in acc_dict: + acc_dict[task_type] = [0, 0] # correct, total + acc_dict[task_type][1] += 1 + total += 1 + pred = res['pred'] + gt = res['gt'] + if check_ans(pred=pred, gt=gt): + acc_dict[task_type][0] += 1 + correct += 1 + + for k, v in acc_dict.items(): + final_res[k] = v[0] / v[1] * 100 + correct += v[0] + total += v[1] + final_res['Avg'] = correct / total * 100 + + all_results = { + "acc_dict": acc_dict, + "result_list": result_list + } + dump_json(all_results, save_path, 'all_results.json') + dump_json(final_res, save_path, 'upload_leaderboard.json') + +def load_results(save_path): + all_results = load_json(save_path, 'all_results.json') + if all_results is not None: + result_list = all_results['result_list'] + else: + result_list = None + # json_data = load_json(save_path, 'all_results.json')['result_list'] + return result_list + +class MVBenchDataset(EvalDataset): + data_list_info = { + # "task_type (sub task name)": ("json file name", "image/video prefix", "data_type", "bound") + "Action Sequence": ("action_sequence.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end + "Action Prediction": ("action_prediction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end + "Action Antonym": ("action_antonym.json", "DATAS/MVBench/video/ssv2_video/", "video", False), + "Fine-grained Action": ("fine_grained_action.json", "DATAS/MVBench/video/Moments_in_Time_Raw/videos/", "video", False), + "Unexpected Action": ("unexpected_action.json", "DATAS/MVBench/video/FunQA_test/test/", "video", False), + "Object Existence": ("object_existence.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), + "Object Interaction": ("object_interaction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end + "Object Shuffle": ("object_shuffle.json", "DATAS/MVBench/video/perception/videos/", "video", False), + "Moving Direction": ("moving_direction.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), + "Action Localization": ("action_localization.json", "DATAS/MVBench/video/sta/sta_video/", "video", True), # has start & end + "Scene Transition": ("scene_transition.json", "DATAS/MVBench/video/scene_qa/video/", "video", False), + "Action Count": ("action_count.json", "DATAS/MVBench/video/perception/videos/", "video", False), + "Moving Count": ("moving_count.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), + "Moving Attribute": ("moving_attribute.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), + "State Change": ("state_change.json", "DATAS/MVBench/video/perception/videos/", "video", False), + "Fine-grained Pose": ("fine_grained_pose.json", "DATAS/MVBench/video/nturgbd/", "video", False), + "Character Order": ("character_order.json", "DATAS/MVBench/video/perception/videos/", "video", False), + "Egocentric Navigation": ("egocentric_navigation.json", "DATAS/MVBench/video/vlnqa/", "video", False), + "Episodic Reasoning": ("episodic_reasoning.json", "DATAS/MVBench/video/tvqa/frames_fps3_hq/", "frame", True), # has start & end, read frame + "Counterfactual Inference": ("counterfactual_inference.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), + } + data_dir = "DATAS/MVBench/json" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + data_list_info = self.data_list_info + data_dir = self.data_dir + + self.data_list = [] + for k, v in data_list_info.items(): + with open(os.path.join(data_dir, v[0]), 'r') as f: + json_data = json.load(f) + for data in json_data: + self.data_list.append({ + 'task_type': k, + 'prefix': v[1], + 'data_type': v[2], + 'bound': v[3], + 'data': data + }) + # self.data_list = self.data_list[:100] # for debug + self.decord_method = { + 'video': self.read_video, + 'gif': self.read_gif, + 'frame': self.read_frame, + } + + # # transform + # crop_size = resolution + # scale_size = resolution + # input_mean = [0.48145466, 0.4578275, 0.40821073] + # input_std = [0.26862954, 0.26130258, 0.27577711] + # self.transform = T.Compose([ + # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC), + # GroupCenterCrop(crop_size), + # Stack(), + # ToTorchFormatTensor(), + # GroupNormalize(input_mean, input_std) + # ]) + + def __getitem__(self, idx): + question, answer = self.qa_template(self.data_list[idx]['data']) + task_type = self.data_list[idx]['task_type'] + decord_method = self.decord_method[self.data_list[idx]['data_type']] + bound = None + if self.data_list[idx]['bound']: + bound = ( + self.data_list[idx]['data']['start'], + self.data_list[idx]['data']['end'], + ) + video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video']) + + + # images_group = decord_method(video_path, bound) + try: # might be problem with decord + images_group = decord_method(video_path, bound) + except Exception as e: + print(f'error decoding {video_path}') + task_type = 'error_reading_video' + images_group = None + + return { + 'video_path': video_path, + 'video_pils': images_group, # some might use the original pils and do their own transforms + 'question': question, + 'answer': answer, + 'task_type': task_type, + } + + + def qa_template(self, data): + question = f"Question: {data['question']}\n" + question += "Options:\n" + answer = data['answer'] + answer_idx = -1 + for idx, c in enumerate(data['candidates']): + question += f"({chr(ord('A') + idx)}) {c}\n" + if c == answer: + answer_idx = idx + question = question.rstrip() + answer = f"({chr(ord('A') + answer_idx)}) {answer}" + return question, answer + diff --git a/tasks/eval/mvbench/pllava_eval_mvbench.py b/tasks/eval/mvbench/pllava_eval_mvbench.py new file mode 100644 index 0000000000000000000000000000000000000000..117785e9b860b965a1549ba64163de516cf9748f --- /dev/null +++ b/tasks/eval/mvbench/pllava_eval_mvbench.py @@ -0,0 +1,278 @@ + +import functools +import itertools +import logging +from tqdm import tqdm +from PIL import Image +from multiprocessing import Pool +import multiprocessing as mp +from argparse import ArgumentParser +import numpy as np + +import torch +import torchvision + +from decord import VideoReader, cpu +import transformers + + +from tasks.eval.model_utils import load_pllava, pllava_answer +from tasks.eval.eval_utils import conv_templates +from tasks.eval.mvbench import ( + MVBenchDataset, + check_ans, + save_results, + load_results, +) + +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +RESOLUTION = 672 # + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + default='llava-hf/llava-1.5-7b-hf' + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + default='"./test_results/test_llava_mvbench"' + ) + parser.add_argument( + "--num_frames", + type=int, + required=True, + default=4, + ) + parser.add_argument( + "--use_lora", + action='store_true' + ) + parser.add_argument( + "--lora_alpha", + type=int, + required=False, + default=32, + ) + parser.add_argument( + "--weight_dir", + type=str, + required=False, + default=None, + ) + parser.add_argument( + "--conv_mode", + type=str, + required=False, + default='eval_mvbench', + ) + parser.add_argument( + "--pooling_shape", + type=str, + required=False, + default=None, + ) + args = parser.parse_args() + return args + +def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, pooling_shape=(16,12,12)): + # remind that, once the model goes larger (30B+) may cause the memory to be heavily used up. Even Tearing Nodes. + model, processor = load_pllava(pretrained_model_name_or_path, num_frames=num_frames, use_lora=use_lora, weight_dir=weight_dir, lora_alpha=lora_alpha, pooling_shape=pooling_shape) + logger.info('done loading llava') + + # position embedding + model = model.to(torch.device(rank)) + model = model.eval() + + dataset = MVBenchDataset(num_segments=num_frames) + dataset.set_rank_and_world_size(rank, world_size) + return model, processor, dataset + +def infer_mvbench( + model, + processor, + data_sample, + conv_mode, + pre_query_prompt=None, # add in the head of question + post_query_prompt=None, # add in the end of question + answer_prompt=None, # add in the begining of answer + return_prompt=None, # add in the begining of return message + print_res=False, + ): + video_list = data_sample["video_pils"] + conv = conv_templates[conv_mode].copy() + conv.user_query(data_sample['question'], pre_query_prompt, post_query_prompt, is_mm=True) + if answer_prompt is not None: + conv.assistant_response(answer_prompt) + + llm_message, conv = pllava_answer( + conv=conv, + model=model, + processor=processor, + img_list=video_list, + max_new_tokens=100, + do_sample=False, + print_res=print_res + ) + + if answer_prompt is not None: + llm_message = ''.join(llm_message.split(answer_prompt)[1:]) + + if return_prompt is not None: + llm_message = return_prompt + llm_message + + return llm_message + +def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"): + def get_index(num_frames, num_segments): + seg_size = float(num_frames - 1) / num_segments + start = int(seg_size / 2) + offsets = np.array([ + start + int(np.round(seg_size * idx)) for idx in range(num_segments) + ]) + return offsets + + def load_video(video_path, num_segments=8, return_msg=False, num_frames=4, resolution=336): + transforms = torchvision.transforms.Resize(size=resolution) + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + num_frames = len(vr) + frame_indices = get_index(num_frames, num_segments) + images_group = list() + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()) + images_group.append(transforms(img)) + if return_msg: + fps = float(vr.get_avg_fps()) + sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices]) + # " " should be added in the start and end + msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds." + return images_group, msg + else: + return images_group + + if num_frames != 0: + vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION) + else: + vid, msg = None, 'num_frames is 0, not inputing image' + img_list = vid + conv = conv_templates[conv_mode].copy() + conv.user_query("Describe the video in details.", is_mm=True) + llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True) + +def run(rank, args, world_size): + if rank != 0: + transformers.utils.logging.set_verbosity_error() + logger.setLevel(transformers.logging.ERROR) + + print_res = False + conv_mode= args.conv_mode + pre_query_prompt = None + post_query_prompt = "\nOnly give the best option." + if args.pooling_shape is not None: + pooling_shape=tuple([int(x) for x in args.pooling_shape.split("-")]) + + logger.info(f'loading model and constructing dataset to gpu {rank}...') + model, processor, dataset = load_model_and_dataset(rank, + world_size, + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + num_frames=args.num_frames, + use_lora=args.use_lora, + lora_alpha=args.lora_alpha, + weight_dir=args.weight_dir, + pooling_shape=pooling_shape) + logger.info(f'done model and dataset...') + logger.info('constructing dataset...') + logger.info('single test...') + + vid_path = "./example/yoga.mp4" + # vid_path = "./example/jesse_dance.mp4" + if rank == 0: + single_test(model, + processor, + vid_path, + num_frames=args.num_frames, + conv_mode=args.conv_mode) + logger.info('single test done...') + tbar = tqdm(total=len(dataset)) + + correct = 0 + total = 0 + result_list = [] + acc_dict = {} + done_count = 0 + + for example in dataset: + task_type = example['task_type'] + if task_type not in acc_dict: + acc_dict[task_type] = [0, 0] # correct, total + acc_dict[task_type][1] += 1 + total += 1 + pred = infer_mvbench( + model, + processor, + example, + conv_mode=conv_mode, + pre_query_prompt=pre_query_prompt, + post_query_prompt=post_query_prompt, + answer_prompt="Best option:(", + return_prompt='(', + print_res=print_res, + ) + gt = example['answer'] + result_list.append({ + 'pred': pred, + 'gt': gt, + 'task_type': task_type, + 'video_path': example['video_path'], + 'question': example['question'], + + }) + if check_ans(pred=pred, gt=gt): + acc_dict[task_type][0] += 1 + correct += 1 + if rank == 0: + tbar.update(len(result_list) - done_count, ) + tbar.set_description_str( + f"One Chunk--Task Type: {task_type}, Chunk Part Acc: {acc_dict[task_type][0] / acc_dict[task_type][1] * 100 :.2f}%;" + f" Chunk Total Acc: {correct / total * 100 :.2f}%" + ) + done_count = len(result_list) + return result_list + +def main(): + multiprocess=True + mp.set_start_method('spawn') + args = parse_args() + save_path = args.save_path + json_data = load_results(save_path) + if json_data is None: + if multiprocess: + logger.info(f'started benchmarking, saving to: {save_path}') + n_gpus = torch.cuda.device_count() + # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + world_size = n_gpus + with Pool(world_size) as pool: + func = functools.partial(run, args=args, world_size=world_size) + result_lists = pool.map(func, range(world_size)) + + logger.info('finished running') + result_list = [ res for res in itertools.chain(*result_lists)] + else: + result_list = run(0, world_size=1, args=args) # debug + + else: + logger.info(f'loaded results from {save_path}') + result_list = json_data + save_results(result_list, save_path) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tasks/eval/recaption/__init__.py b/tasks/eval/recaption/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e68d57bf16c6ca58666e4eac512957baad2938fe --- /dev/null +++ b/tasks/eval/recaption/__init__.py @@ -0,0 +1,293 @@ +from functools import partial +import os +import json +from typing import OrderedDict + +import tqdm +import torch +from PIL import Image +import ast +import numpy as np +from multiprocessing import Pool + +from decord import VideoReader, cpu + +import os +from tasks.eval.eval_utils import ( + dump_json, + load_json, + EvalDataset, +) +from dataclasses import dataclass +from openai import OpenAI +from utils.easydict import EasyDict +client = OpenAI( + # This is the default and can be omitted + api_key=os.environ.get("OPENAI_API_KEY"), +) + +task_type2chatgpt_contents = OrderedDict({ + "Panda70M": { + "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for video captioning. " + "Your task is to compare the predicted captioning with a provided hint (which is usually a ground truth caption provided by human labor or autmated captioning pipeline)." + "You should determine if they match meaningfully, logically and precisely. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer.", + "user": """Please evaluate the following video-based Captioning pair:\n\n""" + """Caption: {caption}\n""" + """Predicted Caption: {pred}\n\n""" + """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """ + """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING.""" + """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """ + """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}.""" + }, +}) + +# Follow the instructions carefully and be helpful and precise with your answer. + +def check_ans_recaption(pred, gt, task_type, model="gpt-3.5-turbo-0125"): + try: + # Compute the temporal understanding score + user_input = task_type2chatgpt_contents[task_type]['user'] + user_input = user_input.format(caption=gt, pred=pred) + completion = client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": task_type2chatgpt_contents[task_type]['system'], + }, + { + "role": "user", + "content": user_input, + } + ] + ) + # Convert response to a Python dictionary. + # response_message = completion["choices"][0]["message"]["content"] + response_message = completion.choices[0].message.content + num_tokens_openai = completion.usage.total_tokens + response_dict = ast.literal_eval(response_message) + pred = response_dict['pred'] + score = response_dict['score'] + if not pred in ('yes', 'no') or not isinstance(score, (int, float)): + raise ValueError(f"{model} doesn't follow") + flag = pred == 'yes' + except Exception as e: + import traceback + traceback.print_exc() + flag, score, num_tokens_openai = False, 0, 0 + print( + f"GPT cannot deal with:\n" + f"--pred: {pred}\n" + f"--gt: {gt}\n" + f"--gpt responded: {response_message}\n" + "--will assign flag=False and score=0" + ) + print(f"Dumb Answer in {task_type}") + return flag, score, num_tokens_openai + +def chatgpt_eval(res, model="gpt-3.5-turbo-0125"): + pred = res['pred'] + gt = res['caption'] + task_type = res['task_type'] + correct, score, num_tokens_openai = check_ans_recaption(pred=pred, gt=gt,task_type=task_type, model=model) # acc is bool, score is given by chatgpt + # update the scores in result_list for this sample + res['score'] = score + res['correct'] = correct + res['num_tokens_openai'] = num_tokens_openai + return res + +def save_results(result_list, save_path, model="gpt-3.5-turbo-0125"): + dump_json(result_list, save_path, 'inference_results.json') + with Pool(7) as pool: + func = partial(chatgpt_eval, model=model) + result_list = [ res for res in tqdm.tqdm(pool.imap_unordered(func, result_list), total=len(result_list), desc='Language Chat Model Automated Evaluation...')] + + # result_list = [chatgpt_eval(res, model=model) for res in result_list] + + final_res, acc_dict = {}, {} + correct, total, total_score = 0, 0, 0 + for i, res in enumerate(result_list): + task_type = res['task_type'] + if task_type not in acc_dict: + acc_dict[task_type] = { + 'correct': 0, + 'total': 0, + 'score': 0, + } # correct, total + acc_dict[task_type]['total'] += 1 + acc_dict[task_type]['correct'] += res['correct'] + acc_dict[task_type]['score'] += res['score'] + + for k, v in acc_dict.items(): + final_res[k] = { + 'acc': v['correct'] / v['total'] * 100, + 'score': v['score'] / v['total'] + } + correct += v['correct'] + total += v['total'] + total_score += v['score'] + + final_res['Avg_Acc'] = correct / total * 100 + final_res['Avg_Score'] = total_score / total + + all_results = { + "acc_dict": acc_dict, + "result_list": result_list + } + dump_json(all_results, save_path, f'final_results-{model}.json') + dump_json(final_res, save_path, 'upload_leaderboard.json') + +def load_results(save_path, model="gpt-3.5-turbo-0125"): + result_list = load_json(save_path, f'final_results-{model}.json') + if result_list is not None: + result_list = result_list['result_list'] + + if result_list is None: + result_list = load_json(save_path, 'inference_results.json') + + return result_list + +class CaptionSample(EasyDict): + def get_info(self): + return {} + +class RecaptionSample(EasyDict): + caption: str + def get_info(self): + # template = ("""To facilitate success in the task, I'll offer hints from the automated image captioning pipeline's output on the frames. """ + # """Please note that this information may contain noise but remains descriptive.""" + # """Presented below are the noisy details:\n""" + # """Hint: {hint}\n""" + # """The hint comprises noisy captions generated for certain frames in the video. """ + # """Please refrain from disclosing the original hints provided; instead, provide rewritten accurate information.""") + # hint = template.format(hint=self.hint,) + return { + "noisy_caption": self.caption + } + +class RecaptionSampleWithMatchingScore(EasyDict): + caption: str + matching_score: float + + def get_info(self): + # template = ("""To facilitate success in the task, I'll offer hints from the automated image captioning pipeline's output on the frames. """ + # """Please note that this information may contain noise but remains descriptive.""" + # """Presented below are the noisy details:\n""" + # """Hint: {hint}\n""" + # """Matching Score: {matching_score:.02f}\n""" + # """The hint comprises noisy captions generated for certain frames in the video. """ + # """Matching scores indicate the likelihood of these captions matching the original frames.\n""" + # """Please refrain from disclosing the original hints provided; instead, provide rewritten accurate information.""" + # ) + + # hint = template.format(hint=self.hint, + # matching_score=self.matching_score) + info = { + "noisy_caption": self.caption, + "matching_score": self.matching_score, + } + # by far, might use some prompting. + return info + +class RecaptionDataset(EvalDataset): + data_dir = "DATAS/Recaption" + data_list_info = OrderedDict({ + # "Panda70M": OrderedDict( + # json_relpath="Panda70M/annotations.json", + # prefix="DATAS/Recaption/Panda70M/videos", + # data_type="video", + # bound=False, + # key_rename_map={ + # # 'caption': 'hint', + # }, + # name_key='video_name', + # postfix=('mp4', 'mkv', 'webm'), + # recaption_type=RecaptionSample, + # ), # don't has start & end + "Inter4K": OrderedDict( + json_relpath="Inter4K/annotations.json", + prefix="DATAS/Recaption/Inter4K/60fps/UHD", + data_type="video", + bound=False, + key_rename_map={ + # 'caption': 'hint', + }, + name_key='video_name', + postfix=('mp4', 'mkv', 'webm'), + recaption_type=CaptionSample, + ), # don't has start & end + }) + + def __init__(self, *args, **kwargs): + # recaption's test_ratio should shuffle the dataset + test_ratio = kwargs.pop('test_ratio', None) + super().__init__(*args, **kwargs) + self.test_ratio = test_ratio + test_ratio = 1. if test_ratio is None else test_ratio + data_list_info = self.data_list_info + data_dir = self.data_dir + + self.data_list = [] + for k, v in data_list_info.items(): + with open(os.path.join(data_dir, v['json_relpath']), 'r') as f: + annotation_json_data = json.load(f) + + indexs = list(range(len(annotation_json_data))) + np.random.RandomState(42).shuffle(indexs) + num_samples = int(len(indexs) * test_ratio) if 0 < test_ratio <= 1 else int(test_ratio) + indexs = indexs[:num_samples] + for i in indexs: + annotation_data = annotation_json_data[i] + for key_old, key_new in v['key_rename_map'].items(): + # temporary renameing the keys + value = annotation_data.pop(key_old) + annotation_data[key_new] = value + + data = dict(annotation_data) + self.data_list.append({ + 'task_type': k, + 'data': data, + }) + + def __getitem__(self, idx): + task_type = self.data_list[idx]['task_type'] + decord_method = self.decord_method[self.data_list_info[task_type]['data_type']] + bound = None + + if self.data_list_info[task_type]['bound']: + bound = ( + self.data_list[idx]['data']['start'], + self.data_list[idx]['data']['end'], + ) + video_name_key = self.data_list_info[task_type]['name_key'] + video_name = self.data_list[idx]['data'][video_name_key] + + video_postfixs = self.data_list_info[task_type]['postfix'] + video_paths = [] + for p in video_postfixs: + video_path = os.path.join(self.data_list_info[task_type]['prefix'], video_name + '.' + p) + if os.path.exists(video_path): + video_paths.append(video_path) + assert len(video_paths) > 0, f'no video named {video_name}' + # video_filename = self.data_list[idx]['data'][video_name_key] + video_postfix + video_path = video_paths[0] + images_group = decord_method(video_path, bound) + + sample = self.data_list_info[task_type]['recaption_type'](**self.data_list[idx]['data'],) + info = sample.get_info() + + return { + 'video_pils': images_group, # some might use the original pils and do their own transforms + 'video_path': video_path, + 'info': info, + 'sample': sample, + 'task_type': task_type, + } + + + diff --git a/tasks/eval/recaption/pllava_recaption.py b/tasks/eval/recaption/pllava_recaption.py new file mode 100644 index 0000000000000000000000000000000000000000..8530b8ee181a5db7acaad6332734486d79c9b516 --- /dev/null +++ b/tasks/eval/recaption/pllava_recaption.py @@ -0,0 +1,294 @@ + +import functools +import itertools +import json +import logging +from tqdm import tqdm +from PIL import Image +from multiprocessing import Pool +from argparse import ArgumentParser +import multiprocessing as mp + + + +import numpy as np +import torch + +import torchvision + +import transformers +from decord import VideoReader, cpu + +from tasks.eval.model_utils import load_pllava, pllava_answer +from tasks.eval.eval_utils import conv_templates + +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +IMAGE_TOKEN='' +from tasks.eval.recaption import ( + RecaptionDataset, + load_results, + save_results, +) +RESOLUTION = 672 # + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + default='llava-hf/llava-1.5-7b-hf' + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + default='"./test_results/test_llava_mvbench"' + ) + parser.add_argument( + "--num_frames", + type=int, + required=True, + default=4, + ) + parser.add_argument( + "--use_lora", + action='store_true' + ) + parser.add_argument( + "--lora_alpha", + type=int, + required=False, + default=32, + ) + parser.add_argument( + "--weight_dir", + type=str, + required=False, + default=None, + ) + parser.add_argument( + "--eval_model", + type=str, + required=False, + default="gpt-3.5-turbo-0125", + ) + parser.add_argument( + '--test_ratio', + type=float, + required=False, + default=None + ) + parser.add_argument( + "--conv_mode", + type=str, + required=False, + default='eval_videoqabench', + ) + args = parser.parse_args() + return args + +def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, test_ratio): + # remind that, once the model goes larger (30B+) may cause the memory to be heavily used up. Even Tearing Nodes. + model, processor = load_pllava(pretrained_model_name_or_path, num_frames=num_frames, use_lora=use_lora, lora_alpha=lora_alpha, weight_dir=weight_dir) + logger.info('done loading llava') + # position embedding + model = model.to(torch.device(rank)) + model = model.eval() + + dataset = RecaptionDataset(test_ratio=test_ratio, num_segments=num_frames) + dataset.set_rank_and_world_size(rank, world_size) + return model, processor, dataset + +def infer_recaption( + model, + processor, + data_sample, + conv_mode, + pre_query_prompt=None, # add in the head of question + post_query_prompt=None, # add in the end of question + answer_prompt=None, # add in the begining of answer + return_prompt=None, # add in the begining of return message + print_res=False, + ): + video_list = data_sample["video_pils"] + conv = conv_templates[conv_mode].copy() + # info = data_sample['info'] + query = ( + "You are to assist me in accomplishing a task about the input video. Reply to me with a precise yet detailed response. For how you would succeed in the recaptioning task, read the following Instructions section and Then, make your response with a elaborate paragraph.\n" + "# Instructions\n" + "1. Avoid providing over detailed information such as color, counts of any objects as you are terrible regarding observing these details\n" + "2. Instead, you should carefully go over the provided video and reason about key information about the overall video\n" + "3. If you are not sure about something, do not include it in you response.\n" + "# Task\n" + "Describe the background, characters and the actions in the provided video.\n" + ) + conv.user_query(query, pre_query_prompt, post_query_prompt, is_mm=True) + if answer_prompt is not None: + conv.assistant_response(answer_prompt) + + llm_message, conv = pllava_answer( + conv=conv, + model=model, + processor=processor, + img_list=video_list, + max_new_tokens=400, + num_beams=1, + do_sample=False, + print_res=print_res + ) + + if answer_prompt is not None: + llm_message = ''.join(llm_message.split(answer_prompt)[1:]) + + if return_prompt is not None: + llm_message = return_prompt + llm_message + + return llm_message, query + +def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"): + def get_index(num_frames, num_segments): + seg_size = float(num_frames - 1) / num_segments + start = int(seg_size / 2) + offsets = np.array([ + start + int(np.round(seg_size * idx)) for idx in range(num_segments) + ]) + return offsets + + def load_video(video_path, num_segments=8, return_msg=False, num_frames=4, resolution=336): + transforms = torchvision.transforms.Resize(size=resolution) + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + num_frames = len(vr) + frame_indices = get_index(num_frames, num_segments) + images_group = list() + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()) + images_group.append(transforms(img)) + if return_msg: + fps = float(vr.get_avg_fps()) + sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices]) + # " " should be added in the start and end + msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds." + return images_group, msg + else: + return images_group + + if num_frames != 0: + vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION) + else: + vid, msg = None, 'num_frames is 0, not inputing image' + img_list = vid + + conv = conv_templates[conv_mode].copy() + conv.user_query("Describe the video in details.", is_mm=True) + llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True) + +def run(rank, args, world_size): + if rank != 0: + transformers.utils.logging.set_verbosity_error() + logger.setLevel(transformers.logging.ERROR) + + print_res = True + conv_mode= args.conv_mode + pre_query_prompt = None + post_query_prompt = None + + # pre_query_prompt = ("""Assist me in detailing the background, characters, and actions depicted in the provided video.\n""") + # post_query_prompt = ("""My apologies for any lack of precision; there may be errors in the supplementary information provided.\n""" + # """You are encouraged to be discerning and perceptive, paying attention to the minutest details, """ + # """and to furnish a detailed yet precise description using eloquent language.""") + + logger.info(f'loading model and constructing dataset to gpu {rank}...') + model, processor, dataset = load_model_and_dataset(rank, + world_size, + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + num_frames=args.num_frames, + use_lora=args.use_lora, + lora_alpha=args.lora_alpha, + weight_dir=args.weight_dir, + test_ratio=args.test_ratio) + logger.info(f'done model and dataset...') + logger.info('constructing dataset...') + logger.info('single test...') + vid_path = "./example/yoga.mp4" + # vid_path = "./example/jesse_dance.mp4" + if rank == 0: + single_test(model, processor, vid_path, num_frames=args.num_frames) + logger.info('single test done...') + tbar = tqdm(total=len(dataset)) + logger.info('single test...') + + result_list = [] + done_count = 0 + for example in dataset: + task_type = example['task_type'] + if task_type in dataset.data_list_info: + pred, query = infer_recaption( + model, + processor, + example, + conv_mode=conv_mode, + pre_query_prompt=pre_query_prompt, + post_query_prompt=post_query_prompt, + print_res=print_res, + ) + + infos = {k: v for k, v in example['sample'].items() if isinstance(v, (str, float, int))} + res = { + 'pred': pred, + 'task_type': task_type, + 'video_path': example['video_path'], + 'query': query, + **infos + } + else: + raise NotImplementedError(f'not implemented task type {task_type}') + # res = chatgpt_eval(res) + result_list.append(res) + if rank == 0: + tbar.update(len(result_list) - done_count, ) + tbar.set_description_str( + f"One Chunk--Task Type: {task_type}-" + f"pred: {pred[:min(15, len(pred))]}......" + ) + done_count = len(result_list) + return result_list + +def main(): + multiprocess=True + mp.set_start_method('spawn') + args = parse_args() + save_path = args.save_path + eval_model = args.eval_model + logger.info(f'trying loading results from {save_path}') + result_list = load_results(save_path, model=args.eval_model) + + if result_list is None: + if multiprocess: + + logger.info(f'started benchmarking, saving to: {save_path}') + n_gpus = torch.cuda.device_count() + # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + world_size = n_gpus + with Pool(world_size) as pool: + func = functools.partial(run, args=args, world_size=world_size) + # func = functools.partial(run, world_size=world_size, model=model, dataset=dataset, result_list=[], acc_dict={}) + result_lists = pool.map(func, range(world_size)) + + logger.info('finished running') + + result_list = [ res for res in itertools.chain(*result_lists)] + else: + result_list = run(0, world_size=1, args=args) # debug + else: + logger.info(f'loaded results from {save_path}') + + save_results(result_list, save_path, model=eval_model) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tasks/eval/recaption/show_recaption.py b/tasks/eval/recaption/show_recaption.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3c00a0a775d485ff71c48dff2ea0d16ff446ec --- /dev/null +++ b/tasks/eval/recaption/show_recaption.py @@ -0,0 +1,52 @@ + +import argparse +import gradio as gr + +from tasks.eval.recaption import load_results +import json + +# example = videogallery().example_inputs() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--save_path', + required=True, + ) + args = parser.parse_args() + return args + + +args = parse_args() +result_list = load_results(args.save_path) + + +def show(result_index, ): + info = result_list[result_index] + video_path = info['video_path'] + info_str = json.dumps(info, indent=4) + return video_path, info_str + + + +from tasks.eval.recaption import load_results + +with gr.Blocks() as demo: + gr.Markdown("# Showing of what has came out.") + gr.Markdown(f"From Saved Results {args.save_path}") + with gr.Row(): + with gr.Column(1): + show_video = gr.Video(interactive=False) + + with gr.Column(): + result_index = gr.Slider(0, len(result_list), step=1) + info = gr.Text(interactive=False) + + result_index.change(show, [result_index], [show_video, info]) + + + + + +demo.launch(share=True) diff --git a/tasks/eval/vcgbench/__init__.py b/tasks/eval/vcgbench/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bad2cdd0b7bdf89806b75eebeaf757b45c5a7a9e --- /dev/null +++ b/tasks/eval/vcgbench/__init__.py @@ -0,0 +1,397 @@ +import ast +import os +import json +from typing import OrderedDict +from multiprocessing import Pool +from functools import partial + +import tqdm + +from tasks.eval.eval_utils import ( + dump_json, + load_json, + EvalDataset, +) + +from openai import OpenAI +client = OpenAI( + # This is the default and can be omitted + api_key=os.environ.get("OPENAI_API_KEY"), +) + +sub_task_type2chatgpt_contents = OrderedDict({ + # general ones + 'temporal': { + "system": "You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the temporal consistency between the predicted answer and the correct answer. The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\n" + "- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\n" + "- Evaluate the temporal accuracy of the prediction compared to the answer.", + "user": "Please evaluate the following video-based question-answer pair:\n\n" + "Question: {question}\n" + "Correct Answer: {answer}\n" + "Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {{'score': 4.8}}." + }, + "context": { + "system": "You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\n" + "- The predicted answer must capture the main themes and sentiments of the video.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Provide your evaluation of the contextual understanding of the prediction compared to the answer.", + "user": "Please evaluate the following video-based question-answer pair:\n\n" + "Question: {question}\n" + "Correct Answer: {answer}\n" + "Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is contextual understanding score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {{'score': 4.8}}." + }, + 'detailed_orientation': { + "system": "You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Check if the predicted answer covers all major points from the video. The response should not leave out any key aspects.\n" + "- Evaluate whether the predicted answer includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity.", + "user": "Please evaluate the following video-based question-answer pair:\n\n" + "Question: {question}\n" + "Correct Answer: {answer}\n" + "Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the detail orientation score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {{'score': 4.8}}." + , + }, + "correctness": { + "system": "You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should not contain any misinterpretations or misinformation.\n" + "- The predicted answer must be factually accurate and align with the video content.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the factual accuracy of the prediction compared to the answer.", + "user": "Please evaluate the following video-based question-answer pair:\n\n" + "Question: {question}\n" + "Correct Answer: {answer}\n" + "Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {{'score': 4.8}}." + + }, + "consistency": { + "system": "You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. " + "You will be given two very similar questions, a common answer common to both the questions and predicted answers for the two questions ." + "Your task is to compare the predicted answers for two very similar question, with a common correct answer and determine if they are consistent. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\n" + "- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\n" + "- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\n" + "- Evaluate the consistency of the two predicted answers compared to the correct answer.", + "user":"Please evaluate the following video-based question-answer pair:\n\n" + "Question 1: {question}\n" + "Question 2: {question1}\n" + "Correct Answer: {answer}\n" + "Predicted Answer to Question 1: {pred}\n" + "Predicted Answer to Question 2: {pred1}\n\n" + "Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest level of consistency. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the consistency score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {{'score': 4.8}}." + + }, +}) + +SYSTEM_VCGBENCH=""" +You are Video-ChatGPT, a large vision-language assistant. +You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language. +Follow the instructions carefully and explain your answers in detail based on the provided video. +""" + +def check_ans(gt, pred, question, sub_task_type, question1=None, pred1=None, model="gpt-3.5-turbo-0125"): + # # dummy + # print('-' * 10 + f'pred: {pred}') + # print('-' * 10 + f'gt: {gt}') + try: + # Compute the temporal understanding score + user_input = sub_task_type2chatgpt_contents[sub_task_type]['user'] + if question1 is not None and pred1 is not None: + assert sub_task_type == 'consistency', 'consistency has two answers' + user_input = user_input.format(question=question, answer=gt, pred=pred, pred1=pred1, question1=question1) + else: + user_input = user_input.format(question=question, answer=gt, pred=pred) + completion = client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": sub_task_type2chatgpt_contents[sub_task_type]['system'], + }, + { + "role": "user", + "content": user_input, + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion.choices[0].message.content + response_dict = ast.literal_eval(response_message) + flag, score = response_dict['score'] > 3, response_dict['score'] + except Exception as e: + import traceback + traceback.print_exc() + flag, score = False, 0 + print( + f"GPT cannot deal with:\n" + f"--pred: {pred},\n" + f"--gt: {gt}\n" + f"--gpt responded: {response_message}\n" + "--will assign flag=False and score=0" + ) + print(f"Dumb Answer in {sub_task_type}") + return flag, score + +def chatgpt_eval(res, model="gpt-3.5-turbo-0125"): + pred = res['pred'] + gt = res['gt'] + question=res['question'] + task_type = res['task_type'] + if task_type == 'generic_qa': + # eval three sub tasks for generic + for sub_task_type in ('context', 'detailed_orientation', 'correctness'): + if pred=="": + print("no pred") + score = 0 + else: + acc, score = check_ans(gt=gt, pred=pred, question=question, sub_task_type=sub_task_type, model=model) # acc is bool, score is given by chatgpt + # update the scores in result_list for this sample + res['scores'] = res.get('scores', {}) + res['scores'][sub_task_type] = score + elif task_type == 'temporal_qa': # only do temporal eval for temporal_qa + sub_task_type = 'temporal' + if pred=="": + print("no pred") + score = 0 + else: + acc, score = check_ans(gt=gt, pred=pred, question=question, sub_task_type=sub_task_type, model=model) # acc is bool, score is given by chatgpt + # update the scores in result_list for this sample + res['scores'] = res.get('scores', {}) + res['scores'][sub_task_type] = score + elif task_type == 'consistency_qa': # only do consistency eval for consistency_qa + sub_task_type = 'consistency' + assert 'pred1' in res and 'question1' in res, 'two questions and preds' + pred1 = res['pred1'] + question1 = res['question1'] + if pred=="" or pred1=="": + print("no pred") + score = 0 + else: + acc, score = check_ans( + gt=gt, pred=pred, pred1=pred1, question=question, question1=question1, + sub_task_type=sub_task_type, model=model) # acc is bool, score is given by chatgpt + # update the scores in result_list for this sample + res['scores'] = res.get('scores', {}) + res['scores'][sub_task_type] = score + else: + raise NotImplementedError(f'not implemented task type for {task_type}') + + return res + +def save_results(result_list, save_path, model="gpt-3.5-turbo-0125"): + dump_json(result_list, save_path, 'inference_results.json') + with Pool(7) as pool: + # result_list = pool.map(partial(chatgpt_eval, model=model), result_list) + func = partial(chatgpt_eval, model=model) + result_list = [ res for res in tqdm.tqdm(pool.imap_unordered(func, result_list), total=len(result_list), desc='Language Chat Model Automated Evaluation...')] + + final_res, acc_dict = {}, {} + correct, total, total_score = 0, 0, 0 + for i, res in enumerate(result_list): + task_type = res['task_type'] + for sub_task_type, score in res['scores'].items(): + if sub_task_type not in acc_dict: + acc_dict[sub_task_type] = { + 'correct': 0, + 'total': 0, + 'score': 0, + } # correct, total + correct = score > 3 + acc_dict[sub_task_type]['total'] += 1 + acc_dict[sub_task_type]['correct'] += correct + acc_dict[sub_task_type]['score'] += score + + for k, v in acc_dict.items(): + final_res[k] = { + 'acc': v['correct'] / v['total'] * 100, + 'score': v['score'] / v['total'] + } + correct += v['correct'] + total += v['total'] + total_score += v['score'] + final_res['Avg_Acc'] = correct / total * 100 + final_res['Avg_Score'] = total_score / total + + all_results = { + "acc_dict": acc_dict, + "result_list": result_list + } + result_post =f"-{model}" + dump_json(all_results, save_path, f'final_results{result_post}.json') + dump_json(final_res, save_path, f'upload_leaderboard{result_post}.json') + +def load_results(save_path, model="gpt-3.5-turbo-0125"): + + result_list = load_json(save_path, f'final_results-{model}.json') + if result_list is not None: + result_list = result_list['result_list'] + + if result_list is None: + result_list = load_json(save_path, 'inference_results.json') + + return result_list + +class VideoChatGPTBenchDataset(EvalDataset): + data_dir = "DATAS/VCGBench" + data_list_info = OrderedDict({ + "generic_qa": OrderedDict( + json_relpath="Zero_Shot_QA/Benchmarking_QA/generic_qa.json", + prefix="DATAS/VCGBench/Videos/Benchmarking", + data_type="video", + bound=False, + question_key='Q', + answer_key='A', + name_key='video_name', + postfix=('mp4', 'mkv'), + ), + "temporal_qa": OrderedDict( + json_relpath="Zero_Shot_QA/Benchmarking_QA/temporal_qa.json", + prefix="DATAS/VCGBench/Videos/Benchmarking", + data_type="video", + bound=False, + question_key='Q', + answer_key='A', + name_key='video_name', + postfix=('mp4', 'mkv'), + ), # don't has start & end + "consistency_qa": OrderedDict( + # consistency is quite different in evaluating, and also awkward, hold to later. + json_relpath="Zero_Shot_QA/Benchmarking_QA/consistency_qa.json", + prefix="DATAS/VCGBench/Videos/Benchmarking", + data_type="video", + bound=False, + question_key=('Q1', 'Q2'), + answer_key='A', + name_key='video_name', + postfix=('mp4', 'mkv'), + ), + }) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + data_list_info = self.data_list_info + data_dir = self.data_dir + + self.data_list = [] + for k, v in data_list_info.items(): + with open(os.path.join(data_dir, v['json_relpath']), 'r') as f: + json_data = json.load(f) + for data in json_data: + self.data_list.append({ + 'task_type': k, + 'data': data, + **v, # all the infos + }) + # self.data_list = self.data_list[:10] # for debug + # random.shuffle(self.data_list) # for debug + self.decord_method = { + 'video': self.read_video, + 'gif': self.read_gif, + 'frame': self.read_frame, + } + # # transform + # crop_size = resolution + # scale_size = resolution + # input_mean = [0.48145466, 0.4578275, 0.40821073] + # input_std = [0.26862954, 0.26130258, 0.27577711] + # self.transform = T.Compose([ + # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC), + # GroupCenterCrop(crop_size), + # Stack(), + # ToTorchFormatTensor(), + # GroupNormalize(input_mean, input_std) + # ]) + + def __getitem__(self, idx): + task_type = self.data_list[idx]['task_type'] + video_name_key = self.data_list[idx]['name_key'] + video_name = self.data_list[idx]['data'][video_name_key] + video_postfixs = self.data_list[idx]['postfix'] + + if self.num_segments != 0: + video_paths = [] + for p in video_postfixs: + video_path = os.path.join(self.data_list[idx]['prefix'], video_name + '.' + p) + if os.path.exists(video_path): + video_paths.append(video_path) + assert len(video_paths) > 0, f'no video named {video_name}' + # video_filename = self.data_list[idx]['data'][video_name_key] + video_postfix + video_path = video_paths[0] + decord_method = self.decord_method[self.data_list[idx]['data_type']] + bound = None + if self.data_list[idx]['bound']: + bound = ( + self.data_list[idx]['data']['start'], + self.data_list[idx]['data']['end'], + ) + images_group = decord_method(video_path, bound) + else: + # zero frame, no image + images_group = None + + data = { + 'video_path': video_path, + 'video_pils': images_group, # some might use the original pils and do their own transforms + 'task_type': task_type, + } + + + answer_key = self.data_list[idx]['answer_key'] + question_key = self.data_list[idx]['question_key'] + + if task_type == 'consistency_qa' and isinstance(question_key, tuple): + question=self.data_list[idx]['data'][question_key[0]] + question1=self.data_list[idx]['data'][question_key[1]] + answer=self.data_list[idx]['data'][answer_key] + + data.update({ + 'question': question, + 'question1': question1, + 'answer': answer, + }) + elif isinstance(question_key, str): + question=self.data_list[idx]['data'][question_key] + answer=self.data_list[idx]['data'][answer_key] + data.update({ + 'question': question, + 'answer': answer, + }) + else: + raise ValueError('') + + return data diff --git a/tasks/eval/vcgbench/pllava_eval_vcgbench.py b/tasks/eval/vcgbench/pllava_eval_vcgbench.py new file mode 100644 index 0000000000000000000000000000000000000000..7182a85cf2610fdf1a999bfc1e90b6fcf3fdee43 --- /dev/null +++ b/tasks/eval/vcgbench/pllava_eval_vcgbench.py @@ -0,0 +1,306 @@ + +import functools +import itertools +import logging +from tqdm import tqdm +from PIL import Image +from multiprocessing import Pool +import multiprocessing as mp +from argparse import ArgumentParser +import numpy as np + +import torch +import torchvision + +from decord import VideoReader, cpu +import transformers + + +from tasks.eval.model_utils import load_pllava, pllava_answer +from tasks.eval.eval_utils import conv_templates +from tasks.eval.vcgbench import ( + VideoChatGPTBenchDataset, + save_results, + load_results, +) + +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +RESOLUTION = 672 # + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + default='llava-hf/llava-1.5-7b-hf' + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + default='"./test_results/test_llava_mvbench"' + ) + parser.add_argument( + "--num_frames", + type=int, + required=True, + default=4, + ) + parser.add_argument( + "--use_lora", + action='store_true' + ) + parser.add_argument( + "--lora_alpha", + type=int, + required=False, + default=32, + ) + parser.add_argument( + "--weight_dir", + type=str, + required=False, + default=None, + ) + parser.add_argument( + "--eval_model", + type=str, + required=False, + default="gpt-3.5-turbo-0125", + ) + parser.add_argument( + "--conv_mode", + type=str, + required=False, + default='eval_vcgbench', + ) + parser.add_argument( + "--test_ratio", + required=False, + default=None, + ) + parser.add_argument( + "--pooling_shape", + type=str, + required=False, + default=None, + ) + args = parser.parse_args() + return args + +def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, test_ratio, pooling_shape=(16,12,12)): + # remind that, once the model goes larger (30B+) may cause the memory to be heavily used up. Even Tearing Nodes., + model, processor = load_pllava(pretrained_model_name_or_path, num_frames=num_frames, use_lora=use_lora, weight_dir=weight_dir, lora_alpha=lora_alpha, pooling_shape=pooling_shape) + logger.info('done loading llava') + # position embedding + model = model.to(torch.device(rank)) + model = model.eval() + + dataset = VideoChatGPTBenchDataset(num_segments=num_frames, test_ratio=test_ratio) + dataset.set_rank_and_world_size(rank, world_size) + return model, processor, dataset + +def infer_vcgbench( + model, + processor, + data_sample, + conv_mode, + pre_query_prompt=None, # add in the head of question + post_query_prompt=None, # add in the end of question + print_res=False, + ): + video_list = data_sample["video_pils"] + conv = conv_templates[conv_mode].copy() + conv.user_query(data_sample['question'], pre_query_prompt, post_query_prompt, is_mm=True) + stop_criteria_keywords=["###","USER"] + + llm_message, conv = pllava_answer( + conv=conv, + model=model, + processor=processor, + img_list=video_list, + max_new_tokens=100, + do_sample=False, + print_res=print_res, + stop_criteria_keywords=stop_criteria_keywords + ) + + + return llm_message + +def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"): + def get_index(num_frames, num_segments): + seg_size = float(num_frames - 1) / num_segments + start = int(seg_size / 2) + offsets = np.array([ + start + int(np.round(seg_size * idx)) for idx in range(num_segments) + ]) + return offsets + + def load_video(video_path, num_segments=8, return_msg=False, num_frames=4, resolution=336): + transforms = torchvision.transforms.Resize(size=resolution) + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + num_frames = len(vr) + frame_indices = get_index(num_frames, num_segments) + images_group = list() + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()) + images_group.append(transforms(img)) + if return_msg: + fps = float(vr.get_avg_fps()) + sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices]) + # " " should be added in the start and end + msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds." + return images_group, msg + else: + return images_group + + if num_frames != 0: + vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION) + else: + vid, msg = None, 'num_frames is 0, not inputing image' + img_list = vid + conv = conv_templates[conv_mode].copy() + conv.user_query("Describe the video in details.", is_mm=True) + llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True) + +def run(rank, args, world_size,start_rank=0): + if rank != 0: + transformers.utils.logging.set_verbosity_error() + logger.setLevel(transformers.logging.ERROR) + print_res = True + conv_mode= args.conv_mode + pre_query_prompt = None + post_query_prompt = None + + + logger.info(f"CONV_MODE: {conv_mode}") + + logger.info(f'loading model and constructing dataset to gpu {rank}...') + if args.pooling_shape is not None: + pooling_shape=tuple([int(x) for x in args.pooling_shape.split("-")]) + model, processor, dataset = load_model_and_dataset(rank, + world_size, + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + num_frames=args.num_frames, + use_lora=args.use_lora, + weight_dir=args.weight_dir, + lora_alpha=args.lora_alpha, + test_ratio=args.test_ratio, + pooling_shape=pooling_shape) + logger.info(f'done model and dataset...') + logger.info('constructing dataset...') + logger.info('single test...') + vid_path = "./example/yoga.mp4" + if rank == 0: + single_test(model, + processor, + vid_path, + num_frames=args.num_frames, + conv_mode=args.conv_mode) + logger.info('single test done...') + tbar = tqdm(total=len(dataset)) + + result_list = [] + done_count = 0 + for example in dataset: + task_type = example['task_type'] + gt = example['answer'] + if task_type == 'consistency_qa': + assert 'question' in example and 'question1' in example, 'two questions' + pred = infer_vcgbench( + model, + processor, + example, + conv_mode=conv_mode, + pre_query_prompt=pre_query_prompt, + post_query_prompt=post_query_prompt, + print_res=print_res, + ) + # inference the other question + example['question'], example['question1'] = example['question1'], example['question'] + pred1 = infer_vcgbench( + model, + processor, + example, + conv_mode=conv_mode, + pre_query_prompt=pre_query_prompt, + post_query_prompt=post_query_prompt, + print_res=print_res, + ) + res = { + 'pred': pred, + 'pred1': pred1, + 'gt': gt, + 'video': example['video_path'], + 'task_type': task_type, + 'question': example['question'], + 'question1': example['question1'], + } + elif task_type in dataset.data_list_info: + pred = infer_vcgbench( + model, + processor, + example, + conv_mode=conv_mode, + pre_query_prompt=pre_query_prompt, + post_query_prompt=post_query_prompt, + print_res=print_res, + ) + res = { + 'pred': pred, + 'gt': gt, + 'video_path': example['video_path'], + 'question': example['question'], + 'task_type': task_type, + } + else: + raise NotImplementedError(f'not implemented task type {task_type}') + + result_list.append(res) + if rank == 0: + tbar.update(len(result_list) - done_count, ) + tbar.set_description_str( + f"One Chunk--Task Type: {task_type}-" + f"gt: {gt[:min(15, len(gt))]}......--pred: {pred[:min(15, len(gt))]}......" + ) + done_count = len(result_list) + return result_list + +def main(): + multiprocess=True + mp.set_start_method('spawn') + args = parse_args() + save_path = args.save_path + eval_model = args.eval_model + result_list = load_results(save_path) + start_rank=0 + + if result_list is None: + if multiprocess: + logger.info(f'started benchmarking, saving to: {save_path}') + n_gpus = torch.cuda.device_count() + # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + world_size = n_gpus + with Pool(world_size) as pool: + func = functools.partial(run, args=args, world_size=world_size, start_rank=start_rank) + result_lists = pool.map(func, range(world_size)) + + logger.info('finished running') + result_list = [ res for res in itertools.chain(*result_lists)] + else: + result_list = run(0, world_size=1, args=args) # debug + + else: + logger.info(f'loaded results from {save_path}') + + save_results(result_list, save_path, model=eval_model) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tasks/eval/vcgbench/show_vcg.py b/tasks/eval/vcgbench/show_vcg.py new file mode 100644 index 0000000000000000000000000000000000000000..d1848f469b2b6dcb9fa884dc52212a91ab233cbc --- /dev/null +++ b/tasks/eval/vcgbench/show_vcg.py @@ -0,0 +1,45 @@ + +import argparse +import gradio as gr + +from tasks.eval.vcgbench import load_results +import json + +# example = videogallery().example_inputs() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--save_path', + required=True, + ) + args = parser.parse_args() + return args + + +args = parse_args() +result_list = load_results(args.save_path) + + +def show(result_index, ): + info = result_list[result_index] + video_path = info['video_path'] + info_str = json.dumps(info, indent=4) + return video_path, info_str + +with gr.Blocks() as demo: + gr.Markdown( + f"# Showing The Results from {args.save_path}" + ) + with gr.Row(): + with gr.Column(): + show_video = gr.Video(interactive=False) + + with gr.Column(): + result_index = gr.Slider(0, len(result_list), step=1) + info = gr.Text(interactive=False) + + result_index.change(show, [result_index], [show_video, info]) + +demo.launch(share=True) diff --git a/tasks/eval/videoqabench/__init__.py b/tasks/eval/videoqabench/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..541495d41da665f28951f701147e02b0126295a2 --- /dev/null +++ b/tasks/eval/videoqabench/__init__.py @@ -0,0 +1,348 @@ +from functools import partial +import os +import json +from typing import OrderedDict + +import tqdm +import torch +from PIL import Image +import ast +import numpy as np +from multiprocessing import Pool + +from decord import VideoReader, cpu + +import os +from tasks.eval.eval_utils import ( + dump_json, + load_json, + EvalDataset, +) +from dataclasses import dataclass +from openai import OpenAI +client = OpenAI( + # This is the default and can be omitted + api_key=os.environ.get("OPENAI_API_KEY"), +) + +task_type2chatgpt_contents = OrderedDict({ + "MSVD_QA": { + "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer.", + "user": """Please evaluate the following video-based question-answer pair:\n\n""" + """Question: {question}\n""" + """Correct Answer: {answer}\n""" + """Predicted Answer: {pred}\n\n""" + """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """ + """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING.""" + """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """ + """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}.""" + }, + "MSRVTT_QA": { + "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer.", + "user": """Please evaluate the following video-based question-answer pair:\n\n""" + """Question: {question}\n""" + """Correct Answer: {answer}\n""" + """Predicted Answer: {pred}\n\n""" + """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """ + """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING.""" + """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """ + """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}.""" + # """Make sure you only response with text that Follows Python syntax. For example, your response should look like this: {'pred': 'yes', 'score': 4.8}.""" + }, + "ActivityNet": { + "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer.", + "user": """Please evaluate the following video-based question-answer pair:\n\n""" + """Question: {question}\n""" + """Correct Answer: {answer}\n""" + """Predicted Answer: {pred}\n\n""" + """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """ + """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING.""" + """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """ + """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}.""" + # """Make sure you only response with text that Follows Python syntax. For example, your response should look like this: {'pred': 'yes', 'score': 4.8}.""" + }, + "TGIF_QA": { + "system": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer.", + "user": """Please evaluate the following video-based question-answer pair:\n\n""" + """Question: {question}\n""" + """Correct Answer: {answer}\n""" + """Predicted Answer: {pred}\n\n""" + """Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. """ + """Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING.""" + """DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. """ + """For example, your response should look like this: {{'pred': 'yes', 'score': 4.8}}.""" + # """Make sure you only response with text that Follows Python syntax. For example, your response should look like this: {'pred': 'yes', 'score': 4.8}.""" + }, +}) + +# Follow the instructions carefully and be helpful and precise with your answer. + +def check_ans_qa(question, pred, gt, task_type, model="gpt-3.5-turbo-0125"): + try: + # Compute the temporal understanding score + user_input = task_type2chatgpt_contents[task_type]['user'] + user_input = user_input.format(question=question, answer=gt, pred=pred) + completion = client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": task_type2chatgpt_contents[task_type]['system'], + }, + { + "role": "user", + "content": user_input, + } + ] + ) + # Convert response to a Python dictionary. + # response_message = completion["choices"][0]["message"]["content"] + response_message = completion.choices[0].message.content + response_dict = ast.literal_eval(response_message) + pred = response_dict['pred'] + score = response_dict['score'] + if not pred in ('yes', 'no') or not isinstance(score, (int, float)): + raise ValueError(f"{model} doesn't follow") + flag = pred == 'yes' + except Exception as e: + import traceback + traceback.print_exc() + flag, score = False, 0 + print( + f"GPT cannot deal with:\n" + f"--pred: {pred}\n" + f"--gt: {gt}\n" + f"--gpt responded: {response_message}\n" + "--will assign flag=False and score=0" + ) + print(f"Dumb Answer in {task_type}") + return flag, score + +def chatgpt_eval(res, model="gpt-3.5-turbo-0125"): + pred = res['pred'] + gt = res['gt'] + question=res['question'] + task_type = res['task_type'] + correct, score = check_ans_qa(question=question, pred=pred, gt=gt,task_type=task_type, model=model) # acc is bool, score is given by chatgpt + # update the scores in result_list for this sample + res['score'] = score + res['correct'] = correct + return res + +def save_results(result_list, save_path, model="gpt-3.5-turbo-0125"): + dump_json(result_list, save_path, 'inference_results.json') + with Pool(7) as pool: + func = partial(chatgpt_eval, model=model) + result_list = [ res for res in tqdm.tqdm(pool.imap_unordered(func, result_list), total=len(result_list), desc='Language Chat Model Automated Evaluation...')] + + # result_list = pool.map(partial(chatgpt_eval, model=model), result_list) + # result_list = [chatgpt_eval(res, model=model) for res in result_list] + + final_res, acc_dict = {}, {} + correct, total, total_score = 0, 0, 0 + for i, res in enumerate(result_list): + task_type = res['task_type'] + if task_type not in acc_dict: + acc_dict[task_type] = { + 'correct': 0, + 'total': 0, + 'score': 0, + } # correct, total + acc_dict[task_type]['total'] += 1 + acc_dict[task_type]['correct'] += res['correct'] + acc_dict[task_type]['score'] += res['score'] + + for k, v in acc_dict.items(): + final_res[k] = { + 'acc': v['correct'] / v['total'] * 100, + 'score': v['score'] / v['total'] + } + correct += v['correct'] + total += v['total'] + total_score += v['score'] + + final_res['Avg_Acc'] = correct / total * 100 + final_res['Avg_Score'] = total_score / total + + all_results = { + "acc_dict": acc_dict, + "result_list": result_list + } + dump_json(all_results, save_path, 'all_results.json') + dump_json(final_res, save_path, 'upload_leaderboard.json') + +def load_results(save_path): + json_data = load_json(save_path, 'inference_results.json') + return json_data + +@dataclass +class OpenendQASample(): + question: str + answer: str + + + +class VideoQABenchDataset(EvalDataset): + data_dir = "DATAS/VideoQA" + data_list_info = OrderedDict({ + "MSVD_QA": OrderedDict( + q_json_relpath="MSVD_Zero_Shot_QA/test_q.json", + a_json_relpath="MSVD_Zero_Shot_QA/test_a.json", + prefix="DATAS/VideoQA/MSVD_Zero_Shot_QA/videos", + data_type="video", + bound=False, + question_key='question', + answer_key='answer', + name_key='video_name', + postfix=('avi',), + ), + "MSRVTT_QA": OrderedDict( + q_json_relpath="MSRVTT_Zero_Shot_QA/test_q.json", + a_json_relpath="MSRVTT_Zero_Shot_QA/test_a.json", + prefix="DATAS/VideoQA/MSRVTT_Zero_Shot_QA/videos/all", + data_type="video", + bound=False, + question_key='question', + answer_key='answer', + name_key='video_name', + postfix=('mp4', ), + ), # don't has start & end + "ActivityNet": OrderedDict( + q_json_relpath="ActivityNet/test_q.json", + a_json_relpath="ActivityNet/test_a.json", + prefix="DATAS/VideoQA/ActivityNet/all_test", + data_type="video", + bound=False, + question_key='question', + answer_key='answer', + name_key='video_name', + postfix=('mp4', 'mkv', 'webm'), + ), # don't has start & end + "TGIF_QA": OrderedDict( + q_json_relpath="TGIF_QA/test_q.json", + a_json_relpath="TGIF_QA/test_a.json", + prefix="DATAS/VideoQA/TGIF_QA/tgif_videos", + data_type="gif", + bound=False, + question_key='question', + answer_key='answer', + name_key='video_name', + postfix=('gif',), + ), # don't has start & end + + }) + + def __init__(self, *args, **kwargs): + # test_ratio for videoqa is for each sub dataset + test_ratio = kwargs.pop('test_ratio', None) + kwargs['test_ratio'] = None + test_datasets = kwargs.pop('test_datasets', None) + super().__init__(*args, **kwargs) + test_ratio = 1 if test_ratio is None else test_ratio + self.test_ratio = test_ratio + if test_datasets is not None: + data_list_info = {k:v for k,v in self.data_list_info.items() if k in test_datasets} + else: + data_list_info = self.data_list_info + data_dir = self.data_dir + + self.data_list = [] + for k, v in data_list_info.items(): + with open(os.path.join(data_dir, v['q_json_relpath']), 'r') as f: + quesions_json_data = json.load(f) + with open(os.path.join(data_dir, v['a_json_relpath']), 'r') as f: + answers_json_data = json.load(f) + + indexs = list(range(len(quesions_json_data))) + np.random.RandomState(42).shuffle(indexs) + num_samples = int(len(indexs) * self.test_ratio) if 0 < self.test_ratio <= 1 else int(self.test_ratio) + indexs = indexs[:num_samples] + for i in indexs: + question_data = quesions_json_data[i] + answer_data = answers_json_data[i] + data = {} + # why do we have anet's video name not in the original json file??? + if k == "ActivityNet": + question_data['video_name'] = 'v_' + question_data['video_name'] + data.update(**question_data) + data.update(**answer_data) + self.data_list.append({ + 'task_type': k, + 'data': data, + **v, # all the infos + }) + print(len(self.data_list)) + + def __len__(self): + return len(self.data_list) + + + def __getitem__(self, idx): + decord_method = self.decord_method[self.data_list[idx]['data_type']] + bound = None + if self.data_list[idx]['bound']: + bound = ( + self.data_list[idx]['data']['start'], + self.data_list[idx]['data']['end'], + ) + video_name_key = self.data_list[idx]['name_key'] + video_name = self.data_list[idx]['data'][video_name_key] + + video_postfixs = self.data_list[idx]['postfix'] + video_paths = [] + for p in video_postfixs: + video_path = os.path.join(self.data_list[idx]['prefix'], video_name + '.' + p) + if os.path.exists(video_path): + video_paths.append(video_path) + assert len(video_paths) > 0, f'no video named {video_name}' + # video_filename = self.data_list[idx]['data'][video_name_key] + video_postfix + video_path = video_paths[0] + images_group = decord_method(video_path, bound) + + question_key = self.data_list[idx]['question_key'] + answer_key = self.data_list[idx]['answer_key'] + sample = OpenendQASample( + question=self.data_list[idx]['data'][question_key], + answer=self.data_list[idx]['data'][answer_key] + ) + question, answer = self.qa_template(sample) + + return { + 'video_pils': images_group, # some might use the original pils and do their own transforms + 'question': question, + 'video_path': video_path, + 'answer': answer, + 'task_type': self.data_list[idx]['task_type'] + } + + def qa_template(self, data: OpenendQASample): + answer = data.answer + question = data.question + # by far, might use some prompting. + return question, answer + + diff --git a/tasks/eval/videoqabench/pllava_eval_videoqabench.py b/tasks/eval/videoqabench/pllava_eval_videoqabench.py new file mode 100644 index 0000000000000000000000000000000000000000..a028c0a95ac4817f70dbb8119d9d190d30326875 --- /dev/null +++ b/tasks/eval/videoqabench/pllava_eval_videoqabench.py @@ -0,0 +1,304 @@ + +import functools +import itertools +import logging +from tqdm import tqdm +from PIL import Image +from multiprocessing import Pool +from argparse import ArgumentParser +import multiprocessing as mp + + + +import numpy as np +import torch + +import torchvision + +import transformers +from decord import VideoReader, cpu + +from tasks.eval.model_utils import load_pllava, pllava_answer +from tasks.eval.eval_utils import conv_templates + +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +IMAGE_TOKEN='' +from tasks.eval.videoqabench import ( + VideoQABenchDataset, + load_results, + save_results, +) +RESOLUTION = 672 # +VIDEOQA_DATASETS=["MSVD_QA","MSRVTT_QA", "ActivityNet","TGIF_QA"] +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + default='llava-hf/llava-1.5-7b-hf' + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + default='"./test_results/test_llava_mvbench"' + ) + parser.add_argument( + "--num_frames", + type=int, + required=True, + default=4, + ) + parser.add_argument( + "--use_lora", + action='store_true' + ) + parser.add_argument( + "--lora_alpha", + type=int, + required=False, + default=32, + ) + parser.add_argument( + "--max_new_tokens", + type=int, + required=False, + default=100, + ) + parser.add_argument( + "--weight_dir", + type=str, + required=False, + default=None, + ) + parser.add_argument( + "--eval_model", + type=str, + required=False, + default="gpt-3.5-turbo-0125", + ) + parser.add_argument( + '--test_ratio', + type=float, + required=False, + default=1 + ) + parser.add_argument( + "--conv_mode", + type=str, + required=False, + default='eval_videoqabench', + ) + parser.add_argument( + "--test_datasets", + type=str, + required=False, + default='MSVD_QA', + ) + args = parser.parse_args() + return args + +def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, test_ratio, test_datasets): + # remind that, once the model goes larger (30B+) may cause the memory to be heavily used up. Even Tearing Nodes. + model, processor = load_pllava(pretrained_model_name_or_path, num_frames=num_frames, use_lora=use_lora, lora_alpha=lora_alpha, weight_dir=weight_dir) + logger.info('done loading llava') + # position embedding + model = model.to(torch.device(rank)) + model = model.eval() + + dataset = VideoQABenchDataset(test_ratio=test_ratio, test_datasets=test_datasets, num_segments=num_frames) + dataset.set_rank_and_world_size(rank, world_size) + return model, processor, dataset + +def infer_videoqabench( + model, + processor, + data_sample, + conv_mode, + pre_query_prompt=None, # add in the head of question + post_query_prompt=None, # add in the end of question + answer_prompt=None, # add in the begining of answer + return_prompt=None, # add in the begining of return message + print_res=False, + max_new_tokens=100, + ): + video_list = data_sample["video_pils"] + conv = conv_templates[conv_mode].copy() + + pre_query_prompt=conv.pre_query_prompt + post_query_prompt=conv.post_query_prompt + answer_prompt=conv.answer_prompt + + conv.user_query(data_sample['question'], pre_query_prompt, post_query_prompt, is_mm=True) + if answer_prompt is not None: + conv.assistant_response(answer_prompt) + + llm_message, conv = pllava_answer( + conv=conv, + model=model, + processor=processor, + img_list=video_list, + max_new_tokens=max_new_tokens, + do_sample=False, + print_res=print_res, + ) + + if answer_prompt is not None: + llm_message = ''.join(llm_message.split(answer_prompt.strip("\n"))[1:]).strip() + + if return_prompt is not None: + llm_message = return_prompt + llm_message + + return llm_message + +def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"): + def get_index(num_frames, num_segments): + seg_size = float(num_frames - 1) / num_segments + start = int(seg_size / 2) + offsets = np.array([ + start + int(np.round(seg_size * idx)) for idx in range(num_segments) + ]) + return offsets + + def load_video(video_path, num_segments=8, return_msg=False, num_frames=4, resolution=336): + transforms = torchvision.transforms.Resize(size=resolution) + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + num_frames = len(vr) + frame_indices = get_index(num_frames, num_segments) + images_group = list() + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()) + images_group.append(transforms(img)) + if return_msg: + fps = float(vr.get_avg_fps()) + sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices]) + # " " should be added in the start and end + msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds." + return images_group, msg + else: + return images_group + + if num_frames != 0: + vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION) + else: + vid, msg = None, 'num_frames is 0, not inputing image' + img_list = vid + + conv = conv_templates[conv_mode].copy() + conv.user_query("Describe the video in details.", is_mm=True) + llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True) + +def run(rank, args, world_size): + if rank != 0: + transformers.utils.logging.set_verbosity_error() + logger.setLevel(transformers.logging.ERROR) + + print_res = True + conv_mode= args.conv_mode + pre_query_prompt = None + post_query_prompt = None + # pre_query_prompt = "Answer the question with a single word or phrase." + + logger.info(f'loading model and constructing dataset to gpu {rank}...') + test_datasets = [x for x in args.test_datasets.split("-") if x in VIDEOQA_DATASETS] + assert len(test_datasets)>=1 + + model, processor, dataset = load_model_and_dataset(rank, + world_size, + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + num_frames=args.num_frames, + use_lora=args.use_lora, + lora_alpha=args.lora_alpha, + weight_dir=args.weight_dir, + test_ratio=args.test_ratio, + test_datasets=test_datasets) + logger.info(f'done model and dataset...') + logger.info('constructing dataset...') + logger.info('single test...') + vid_path = "./example/yoga.mp4" + # vid_path = "./example/jesse_dance.mp4" + if rank == 0: + single_test(model, processor, vid_path, num_frames=args.num_frames, conv_mode=args.conv_mode) + logger.info('single test done...') + tbar = tqdm(total=len(dataset)) + logger.info('single test...') + + result_list = [] + done_count = 0 + for example in dataset: + task_type = example['task_type'] + gt = example['answer'] + if task_type in dataset.data_list_info: + pred = infer_videoqabench( + model, + processor, + example, + conv_mode=conv_mode, + pre_query_prompt=pre_query_prompt, + post_query_prompt=post_query_prompt, + print_res=print_res, + max_new_tokens=args.max_new_tokens, + ) + + infos = { + 'question': example['question'], + 'video_path': example['video_path'] + } + res = { + 'pred': pred, + 'gt': gt, + 'task_type': task_type, + **infos + } + else: + raise NotImplementedError(f'not implemented task type {task_type}') + # res = chatgpt_eval(res) + result_list.append(res) + if rank == 0: + tbar.update(len(result_list) - done_count, ) + tbar.set_description_str( + f"One Chunk--Task Type: {task_type}-" + f"gt: {gt[:min(15, len(gt))]}......--pred: {pred[:min(15, len(gt))]}......" + ) + done_count = len(result_list) + return result_list + +def main(): + multiprocess=True + mp.set_start_method('spawn') + args = parse_args() + save_path = args.save_path + eval_model = args.eval_model + logger.info(f'trying loading results from {save_path}') + result_list = load_results(save_path) + + if result_list is None: + if multiprocess: + + logger.info(f'started benchmarking, saving to: {save_path}') + n_gpus = torch.cuda.device_count() + # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + world_size = n_gpus + with Pool(world_size) as pool: + func = functools.partial(run, args=args, world_size=world_size) + # func = functools.partial(run, world_size=world_size, model=model, dataset=dataset, result_list=[], acc_dict={}) + result_lists = pool.map(func, range(world_size)) + + logger.info('finished running') + + result_list = [ res for res in itertools.chain(*result_lists)] + else: + result_list = run(0, world_size=1, args=args) # debug + else: + logger.info(f'loaded results from {save_path}') + + save_results(result_list, save_path, model=eval_model) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tasks/shared_utils.py b/tasks/shared_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a5dec719791e801f8636cf29eb34bf965ab35846 --- /dev/null +++ b/tasks/shared_utils.py @@ -0,0 +1,36 @@ +import copy +import logging +import os +import os.path as osp +from os.path import join + +import torch +from torch.utils.data import ConcatDataset, DataLoader + +from utils.optimizer import create_optimizer +from utils.scheduler import create_scheduler + +logger = logging.getLogger(__name__) + + +def get_media_types(datasources): + """get the media types for for all the dataloaders. + + Args: + datasources (List): List of dataloaders or datasets. + + Returns: List. The media_types. + + """ + if isinstance(datasources[0], DataLoader): + datasets = [dataloader.dataset for dataloader in datasources] + else: + datasets = datasources + media_types = [ + dataset.datasets[0].media_type + if isinstance(dataset, ConcatDataset) + else dataset.media_type + for dataset in datasets + ] + + return media_types diff --git a/tasks/train/config_pllava_nframe.py b/tasks/train/config_pllava_nframe.py new file mode 100644 index 0000000000000000000000000000000000000000..b80ac33155504b002ca182cb3fef4f556ef655a3 --- /dev/null +++ b/tasks/train/config_pllava_nframe.py @@ -0,0 +1,135 @@ +from tasks.train.instruction_data import * + +# ========================= data ========================== +# train_corpus = "videochat2_instruction" +train_corpus = "videochat2_instruction_full" + +train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation +test_file = dict() +test_types = [] +num_workers = 8 +save_steps=10000 +ckpt_steps=1000 +stop_key = None +deepspeed=False +# ========================= input ========================== +num_frames = 16 +num_frames_test = 1 +batch_size = 1 +gradient_accumulation_steps=16 +max_txt_l = 512 +max_train_steps=None +pre_text = False +gradient_checkpointing=False +inputs = dict( + image_res=336, + video_input=dict( + num_frames="${num_frames}", + sample_type="rand", + num_frames_test="${num_frames_test}", + sample_type_test="middle", + random_aug=False, + ), + max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), + batch_size=dict(image="${batch_size}", video="${batch_size}"), + batch_size_test=dict(image="${batch_size}", video="${batch_size}"), +) + +# ========================= model ========================== +model = dict( + repo_id="llava-hf/llava-v1.6-vicuna-7b-hf", + pretrained_path=None, + load_from_origin=False, + origin_vision="", + origin_llm="", + vision_encoder=dict( + name="vit_l14", # somehow need this to tell the dataset the mean std of pretrained model + ), + torch_dtype='bfloat16', + freeze_projector=False, + freeze_lm=True, + freeze_vision_tower=True, + lora_target_modules=["q_proj", "v_proj"], # for llama/mistral/gemma + use_lora=True, + lora_r=128, + lora_alpha=32, + lora_dropout=0.05, + num_frames="${num_frames}", + pooling_method='avg', + use_pooling=True, + frame_shape=(24,24), + pooling_shape=(16,8,8), +) +preprocess = dict( + system="", + mm_alone=True, + random_shuffle=True, + add_second_msg=True, + roles=['USER:', 'ASSISTANT:'], + end_signal=(' ', ''), + begin_signal='', + dataset_image_placeholder='', + dataset_video_placeholder='', + image_token_index=32000, + max_txt_l = "${max_txt_l}", + ignore_index=-100, # same as torch softmax ignore index + center_pad=False, + longest_edge=762, + shortest_edge=336, + clip_transform=False, + num_frames="${num_frames}", +) + + +optimizer = dict( + opt="adamW", + lr=2e-5, + opt_betas=[0.9, 0.999], # default + weight_decay=0.02, + max_grad_norm=-1, # requires a positive float, use -1 to disable + # use a different lr for some modules, e.g., larger lr for new modules + different_lr=dict(enable=False, module_names=[], lr=1e-3), +) + +# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) +# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) +scheduler = dict( + is_videochat2_custom=False, + sched="cosine", + epochs=2, + warmup_ratio=0.2, + min_lr_multi=0.25) + +evaluate = False +deep_fusion = False +evaluation = dict( + eval_frame_ensemble="concat", # [concat, max, mean, lse] + eval_x_only=False, + k_test=128, + eval_offload=True, # offload gpu tensors to cpu to save memory. +) + +fp16 = True +gradient_checkpointing = True + +# ========================= wandb ========================== +wandb = dict( + enable=False, + entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init + project="videochat2", # setup in your command line +) +dist_url = "env://" +device = "cuda" +mode = "it" + +# ========================= others ========================== +output_dir = None # output dir +resume = False # if True, load optimizer and scheduler states as well +debug = False +log_freq = 5 +metric_window_size=10 # window size for metric +seed = 42 +report_to='tensorboard' +save_latest = True +auto_resume = True +pretrained_path = "" # path to pretrained model weights, for resume only? diff --git a/tasks/train/config_pllava_nframe_yiprompt.py b/tasks/train/config_pllava_nframe_yiprompt.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea7adeb7714ec9c60a6374bbbfc5dcf190c0e61 --- /dev/null +++ b/tasks/train/config_pllava_nframe_yiprompt.py @@ -0,0 +1,135 @@ +from tasks.train.instruction_data import * + +# ========================= data ========================== +# train_corpus = "videochat2_instruction" +train_corpus = "videochat2_instruction_full" + +train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation +test_file = dict() +test_types = [] +num_workers = 8 +save_steps=10000 +ckpt_steps=1000 +stop_key = None +deepspeed=False +highres=None +# ========================= input ========================== +num_frames = 16 +num_frames_test = 1 +batch_size = 1 +gradient_accumulation_steps=16 +max_txt_l = 512 +max_train_steps=None +pre_text = False +gradient_checkpointing=False +inputs = dict( + image_res=336, + video_input=dict( + num_frames="${num_frames}", + sample_type="rand", + num_frames_test="${num_frames_test}", + sample_type_test="middle", + random_aug=False, + ), + max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), + batch_size=dict(image="${batch_size}", video="${batch_size}"), + batch_size_test=dict(image="${batch_size}", video="${batch_size}"), +) + +model = dict( + repo_id="llava-hf/llava-1.5-7b-hf", + pretrained_path=None, + load_from_origin=False, + origin_vision="", + origin_llm="", + vision_encoder=dict( + name="vit_l14", # somehow need this to tell the dataset the mean std of pretrained model + ), + torch_dtype='bfloat16', + freeze_projector=False, + freeze_lm=True, + freeze_vision_tower=True, + lora_target_modules=["q_proj", "v_proj"], # for llama/mistral/gemma + use_lora=True, + lora_r=128, + lora_alpha=32, + lora_dropout=0.05, + num_frames="${num_frames}", + pooling_method='avg', + use_pooling=True, + frame_shape=(24,24), + pooling_shape=(16,8,8), +) +preprocess = dict( + system="", + mm_alone=True, + image_token_index=64002, + random_shuffle=True, + add_second_msg=True, + roles=['<|im_start|>user\n', '<|im_start|>assistant\n'], + end_signal=('<|im_end|>\n', '<|im_end|>\n'), + begin_signal='', + dataset_image_placeholder='', + dataset_video_placeholder='', + max_txt_l = "${max_txt_l}", + ignore_index=-100, # same as torch softmax ignore index + center_pad=False, + longest_edge=762, + shortest_edge=336, + clip_transform=False, + num_frames="${num_frames}", +) + + +optimizer = dict( + opt="adamW", + lr=2e-5, + opt_betas=[0.9, 0.999], # default + weight_decay=0.02, + max_grad_norm=-1, # requires a positive float, use -1 to disable + # use a different lr for some modules, e.g., larger lr for new modules + different_lr=dict(enable=False, module_names=[], lr=1e-3), +) + +# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) +# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) +scheduler = dict( + is_videochat2_custom=False, + sched="cosine", + epochs=2, + warmup_ratio=0.2, + min_lr_multi=0.25) + +evaluate = False +deep_fusion = False +evaluation = dict( + eval_frame_ensemble="concat", # [concat, max, mean, lse] + eval_x_only=False, + k_test=128, + eval_offload=True, # offload gpu tensors to cpu to save memory. +) + +fp16 = True +gradient_checkpointing = True + +# ========================= wandb ========================== +wandb = dict( + enable=False, + entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init + project="videochat2", # setup in your command line +) +dist_url = "env://" +device = "cuda" +mode = "it" + +# ========================= others ========================== +output_dir = None # output dir +resume = False # if True, load optimizer and scheduler states as well +debug = False +log_freq = 5 +metric_window_size=10 # window size for metric +seed = 42 +report_to='tensorboard' +save_latest = True +auto_resume = True +pretrained_path = "" # path to pretrained model weights, for resume only? diff --git a/tasks/train/instruction_data.py b/tasks/train/instruction_data.py new file mode 100644 index 0000000000000000000000000000000000000000..58ac2c8bd5c037b6cf4f37a4862c34c75b36757b --- /dev/null +++ b/tasks/train/instruction_data.py @@ -0,0 +1,271 @@ +import os as __os # add "__" if not want to be exported +from copy import deepcopy as __deepcopy +import itertools as __itertools + +data_root = "DATAS/TRAIN_TEST" +anno_root_it = f"{data_root}/magic_jsons" + +# ============== pretraining datasets================= +available_corpus = dict( + # image + # caption_coco=[ + # f"{anno_root_it}/image/caption/coco/train.json", + # f"{data_root}/images/coco", + # ], + # caption_llava=[ + # f"{anno_root_it}/image/caption/llava/train.json", + # f"{data_root}/images/coco", + # ], + # caption_minigpt4=[ + # f"{anno_root_it}/image/caption/minigpt4/train.json", + # f"{data_root}/images/minigpt4_align/image", + # ], + # caption_paragraph_captioning=[ + # f"{anno_root_it}/image/caption/paragraph_captioning/train.json", + # f"{data_root}/images/m3it/image-paragraph-captioning", + # ], + # caption_textcaps=[ + # f"{anno_root_it}/image/caption/textcaps/train.json", + # f"{data_root}/images/textcaps", + # ], + # classification_imagenet=[ + # f"{anno_root_it}/image/classification/imagenet/train.json", + # f"{data_root}/images/m3it/imagenet", + # ], + # classification_coco_itm=[ + # f"{anno_root_it}/image/classification/coco_itm/train.json", + # f"{data_root}/images/coco", + # ], + # conversation_llava=[ + # f"{anno_root_it}/image/conversation/llava/train.json", + # f"{data_root}/images/coco", + # ], + # reasoning_clevr=[ + # f"{anno_root_it}/image/reasoning/clevr/train.json", + # f"{data_root}/images/m3it/clevr", + # ], + # reasoning_visual_mrc=[ + # f"{anno_root_it}/image/reasoning/visual_mrc/train.json", + # f"{data_root}/images/m3it/visual_mrc", + # ], + # reasoning_llava=[ + # f"{anno_root_it}/image/reasoning/llava/train.json", + # f"{data_root}/images/coco", + # ], + # vqa_vqav2=[ + # f"{anno_root_it}/image/vqa/vqav2/train.json", + # f"{data_root}/images/m3it/vqav2", + # ], + # vqa_gqa=[ + # f"{anno_root_it}/image/vqa/gqa/train.json", + # f"{data_root}/images/gqa/images", + # ], + # vqa_okvqa=[ + # f"{anno_root_it}/image/vqa/okvqa/train.json", + # f"{data_root}/images/m3it/okvqa", + # ], + # vqa_a_okvqa=[ + # f"{anno_root_it}/image/vqa/a_okvqa/train.json", + # f"{data_root}/images/m3it/a_okvqa", + # ], + # vqa_viquae=[ + # f"{anno_root_it}/image/vqa/viquae/train.json", + # f"{data_root}/images/viquae_images", + # ], + # vqa_ocr_vqa=[ + # f"{anno_root_it}/image/vqa/ocr_vqa/train.json", + # f"{data_root}/images/ocr_vqa/images", + # ], + # vqa_text_vqa=[ + # f"{anno_root_it}/image/vqa/text_vqa/train.json", + # f"{data_root}/images/textvqa", + # ], + # vqa_st_vqa=[ + # f"{anno_root_it}/image/vqa/st_vqa/train.json", + # f"{data_root}/images/m3it/st-vqa", + # ], + # vqa_docvqa=[ + # f"{anno_root_it}/image/vqa/docvqa/train.json", + # f"{data_root}/images/docvqa", + # ], + # origin_llava=[ + # f"{anno_root_it}/image/origin_llava/train.json", + # f"{data_root}/images", + # ], + # video + caption_textvr=[ + f"{anno_root_it}/video/caption/textvr/train.json", + f"{data_root}/videos/TextVR", + "video" + ], + caption_videochat=[ + f"{anno_root_it}/video/caption/videochat/train.json", + f"{data_root}/videos/webvid_10m", + "video" + ], # not ready, need to read from hdfs + caption_webvid=[ + f"{anno_root_it}/video/caption/webvid/train.json", + f"{data_root}/videos/webvid_10m", + "video" + ], # not ready, need to read from hdfs + caption_youcook2=[ + f"{anno_root_it}/video/caption/youcook2/train.json", + f"{data_root}/videos/YouCook2/split_videos", + "video" + ], + classification_k710=[ + f"{anno_root_it}/video/classification/k710/train.json", + f"{data_root}/videos/kinetics", + "video" + ], + classification_ssv2=[ + f"{anno_root_it}/video/classification/ssv2/train.json", + f"{data_root}/videos/20bn-something-something-v2", + "video" + ], + conversation_videochat1=[ + f"{anno_root_it}/video/conversation/videochat1/train.json", + f"{data_root}/videos/webvid_10m", + "video" + ],# not ready, need to read from hdfs + conversation_videochat2=[ + f"{anno_root_it}/video/conversation/videochat2/train.json", + f"{data_root}/videos/InternVid-10M-FLT/videos", + "video" + ], + conversation_videochatgpt=[ + f"{anno_root_it}/video/conversation/videochatgpt/train.json", + f"{data_root}/videos/AVideo_ChatGPT", + "video" + ], + reasoning_next_qa=[ + f"{anno_root_it}/video/reasoning/next_qa/train.json", + f"{data_root}/videos/NExTVideo", + "video" + ], + reasoning_clevrer_qa=[ + f"{anno_root_it}/video/reasoning/clevrer_qa/train.json", + f"{data_root}/videos/CLEVRER", + "video" + ], + reasoning_clevrer_mc=[ + f"{anno_root_it}/video/reasoning/clevrer_mc/train.json", + f"{data_root}/videos/CLEVRER", + "video" + ], + vqa_ego_qa=[ + f"{anno_root_it}/video/vqa/ego_qa/train.json", + f"{data_root}/videos/ego4d_data/split_videos", + "video" + ], + vqa_tgif_frame_qa=[ + f"{anno_root_it}/video/vqa/tgif_frame_qa/train.json", + f"{data_root}/videos/tgif", + "video" + ], + vqa_tgif_transition_qa=[ + f"{anno_root_it}/video/vqa/tgif_transition_qa/train.json", + f"{data_root}/videos/tgif", + "video" + ], + vqa_webvid_qa=[ + f"{anno_root_it}/video/vqa/webvid_qa/train.json", + f"{data_root}/videos/webvid_10m", + "video" + ],# not ready, need to read from hdfs + origin_videochatgpt=[ + f"{anno_root_it}/video/origin_videochatgpt/train.json", + f"{data_root}/videos/Video_ChatGPT", + "video" + ], +) + + + +available_corpus["videochat2_instruction_full"] = [ + available_corpus["caption_coco"], + available_corpus["caption_llava"], + available_corpus["caption_minigpt4"], + available_corpus["caption_paragraph_captioning"], + available_corpus["caption_textcaps"], + available_corpus["classification_imagenet"], + available_corpus["classification_coco_itm"], + available_corpus["conversation_llava"], + available_corpus["reasoning_clevr"], + available_corpus["reasoning_visual_mrc"], + available_corpus["reasoning_llava"], + available_corpus["vqa_vqav2"], + available_corpus["vqa_gqa"], + available_corpus["vqa_okvqa"], + available_corpus["vqa_a_okvqa"], + available_corpus["vqa_viquae"], + available_corpus["vqa_ocr_vqa"], + available_corpus["vqa_text_vqa"], + available_corpus["vqa_st_vqa"], + available_corpus["vqa_docvqa"], + available_corpus["caption_textvr"], + available_corpus["caption_youcook2"], + available_corpus["classification_k710"], + available_corpus["classification_ssv2"], + available_corpus["conversation_videochat2"], + available_corpus["conversation_videochatgpt"], + available_corpus["reasoning_next_qa"], + available_corpus["reasoning_clevrer_qa"], + available_corpus["reasoning_clevrer_mc"], + available_corpus["vqa_ego_qa"], + available_corpus["vqa_tgif_frame_qa"], + available_corpus["vqa_tgif_transition_qa"], + available_corpus["conversation_videochat1"], + available_corpus["vqa_webvid_qa"], + available_corpus["caption_videochat"], + available_corpus["caption_webvid"], +] + +available_corpus["videochat2_video"] = [ + available_corpus["caption_textvr"], + available_corpus["caption_youcook2"], + available_corpus["classification_k710"], + available_corpus["classification_ssv2"], + available_corpus["conversation_videochat2"], + available_corpus["conversation_videochatgpt"], + available_corpus["reasoning_next_qa"], + available_corpus["reasoning_clevrer_qa"], + available_corpus["reasoning_clevrer_mc"], + available_corpus["vqa_ego_qa"], + available_corpus["vqa_tgif_frame_qa"], + available_corpus["vqa_tgif_transition_qa"], + available_corpus["conversation_videochat1"], + available_corpus["vqa_webvid_qa"], + available_corpus["caption_videochat"], + available_corpus["caption_webvid"], +] + + + + +# ============== for debug================= +available_corpus["videochat2_instruction_debug"] = [ + # available_corpus["caption_minigpt4"], + available_corpus["caption_textvr"], + # available_corpus["vqa_ego_qa"], + # available_corpus["classification_k710"], + # available_corpus["reasoning_next_qa"], + # available_corpus["caption_textvr"], + # available_corpus["caption_youcook2"], + + # available_corpus["caption_textcaps"], # realistic caption foucsing in real life text + # available_corpus["caption_textvr"], # good realistic captioning, also focusing on text +] + + +if __name__ == '__main__': + print(len(list( + __itertools.chain( + available_corpus['conversation_data'], + available_corpus['reasoning_data'], + available_corpus['conversation_videochat2'], + available_corpus['caption_data'], + available_corpus['classification_data'], + ) + ))) + print(len(available_corpus['videochat2_instruction_full'])) \ No newline at end of file diff --git a/tasks/train/train_pllava_nframe_accel.py b/tasks/train/train_pllava_nframe_accel.py new file mode 100644 index 0000000000000000000000000000000000000000..9f02309d20ac3629f6f5382b697404d1ffcba96e --- /dev/null +++ b/tasks/train/train_pllava_nframe_accel.py @@ -0,0 +1,545 @@ +import datetime +import gc +import time +import os +import os.path as osp +import re +import itertools +import functools +import random +import math +import shutil +from typing import Optional, Union + +import torch +import numpy as np +from safetensors import safe_open + +import logging +from accelerate.logging import get_logger +from accelerate import Accelerator, DistributedType +from accelerate.utils import set_seed +from peft import get_peft_model, LoraConfig, TaskType + + +from dataset import create_dataset, create_loader +from tasks.shared_utils import get_media_types +from utils.basic_utils import (MetricLogger, SmoothedValue, setup_seed) +from utils.config_utils import setup_main +from transformers.utils import TensorType + +from tasks.shared_utils import create_optimizer, create_scheduler +import copy +from transformers import ( + DataCollatorWithPadding, + get_scheduler, + AutoModel, + AutoModelForCausalLM + ) +from models.pllava import PllavaConfig, PllavaForConditionalGeneration, PllavaProcessor + +# logger = logging.getLogger(__name__) +IMAGE_TOKEN='' + +logger = get_logger(__name__) + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + print(name, 'no ignore status') + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +def get_state_maybe_zero_3(named_params, keys_to_match=["lora_","multi_modal_projector"]): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} + return to_return + +def setup_dataloaders(config, mode="pt", collate_fn=None): + # train datasets, create a list of data loaders + logger.info(f"Creating dataset for {mode}") + train_datasets = create_dataset(f"{mode}_train", config) + + media_types = get_media_types(train_datasets) + samplers = [None] * len(media_types) + + train_loaders = create_loader( + train_datasets, + samplers, + batch_size=[config.inputs.batch_size[k] for k in media_types], + num_workers=[config.num_workers] * len(media_types), + is_trains=[True] * len(media_types), + collate_fns=[collate_fn] * len(media_types), + ) # [0] + + return train_loaders, media_types + + +def setup_model( + config, find_unused_parameters=False +): + if config.model.torch_dtype in ('bfloat16', 'float16', 'float32'): + torch_dtype = eval(f'torch.{config.model.torch_dtype}') + else: + torch_dtype = config.model.torch_dtype + logger.info("Creating model") + + processor = PllavaProcessor.from_pretrained(config.model.repo_id, + padding_side='right', + center_pad=config.preprocess.center_pad, + ) + + + model_config = PllavaConfig.from_pretrained(config.model.repo_id, + torch_dtype=torch_dtype, + num_frames=config.model.num_frames, + pooling_method=config.model.pooling_method, + image_token_index=config.preprocess.image_token_index, + frame_shape=config.model.frame_shape, + pooling_shape=config.model.pooling_shape, + use_pooling=config.model.use_pooling, + gradient_checkpointing=config.gradient_checkpointing, + ) + print("====>gradient_checkpointing",model_config.gradient_checkpointing) + + model = PllavaForConditionalGeneration.from_pretrained(config.model.repo_id, config=model_config, torch_dtype=torch_dtype) + + if config.model.load_from_origin: + with torch.no_grad(): + lm_model = AutoModelForCausalLM.from_pretrained(config.model.origin_llm, torch_dtype=torch_dtype, device_map="cpu",) + with torch.no_grad(): + clip = AutoModel.from_pretrained(config.model.origin_vision, torch_dtype=torch_dtype, device_map="cpu",) + msg = model.vision_tower.load_state_dict(clip.state_dict(), strict=False) + # print(msg) + msg = model.language_model.load_state_dict(lm_model.state_dict(), strict=False) + print(msg) + + + if config.model.freeze_lm: + logger.info("freezing parameters in model.language_model") + for p in model.language_model.parameters(): + p.requires_grad = False + + if config.model.freeze_projector: + logger.info("freezing parameters in model.multi_modal_projector") + for p in model.multi_modal_projector.parameters(): + p.requires_grad = False + + if config.model.freeze_vision_tower: + logger.info("freezing parameters in model.vision_tower") + for p in model.vision_tower.parameters(): + p.requires_grad = False + + if config.model.use_lora: + logger.info("getting LoRA Language Model") + kwargs = {} + if config.model.lora_target_modules is not None and len(config.model.lora_target_modules) > 0: + kwargs.update({"target_modules": config.model.lora_target_modules}) + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, + r=config.model.lora_r, lora_alpha=config.model.lora_alpha, lora_dropout=config.model.lora_dropout, + **kwargs + ) + model.language_model = get_peft_model(model.language_model, peft_config) + model.language_model.print_trainable_parameters() + + if config.model.pretrained_path is not None and not config.deepspeed: + logger.info("======> loading pretrained weights from " + str(config.model.pretrained_path)) + state_dict = {} + save_fnames = os.listdir(config.model.pretrained_path) + if "model.safetensors" in save_fnames: + print("Loading weight from", config.model.pretrained_path, "model.safetensors") + with safe_open(f"{config.model.pretrained_path}/model.safetensors", framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + else: + print("Loading weight from", config.model.pretrained_path) + for fn in save_fnames: + if fn.startswith('model-0000'): + with safe_open(f"{config.model.pretrained_path}/{fn}", framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + if 'model' in state_dict.keys(): + msg = model.load_state_dict(state_dict['model'], strict=False) + else: + msg = model.load_state_dict(state_dict, strict=False) + logger.info(msg) + logger.info("=====> Finish loading") + + return model, processor + +def setup_optimizer_and_scheduler(config, model): + optimizer = create_optimizer(config.optimizer, model) # do you want to filter bias and bn? + if config.scheduler.is_videochat2_custom: + scheduler = create_scheduler(config.scheduler, optimizer) + else: + scheduler=None + + return optimizer, scheduler + +class RandomMappingIterator(): + # a random iter through the multiple mapping style dataloaders + def __init__(self, train_loaders, media_types, resume_step=0): + self.train_loaders = train_loaders + self.media_types = media_types + self.total_num_samples = sum(len(train_loader) for train_loader in self.train_loaders) + self.weights = [len(loader) / self.total_num_samples for loader in train_loaders] + self.resume_step = resume_step + if resume_step != 0: + self.total_num_samples= self.total_num_samples-resume_step + # remove corresponding iters from each loader + + + def __iter__(self): + train_loaders = self.train_loaders + iters = [iter(train_loader) for train_loader in train_loaders] + + media_types = copy.deepcopy(self.media_types) + weights = copy.deepcopy(self.weights) + while len(iters) > 0: + index = np.random.choice(list(range(len(iters))), p=weights, replace=True) + try: + batch = next(iters[index]) + except StopIteration as e: + iters.pop(index) + media_types.pop(index) + weights.pop(index) + total = sum(weights) + weights = [w/total for w in weights] + continue + + media_type = media_types[index] + yield media_type, batch + + def __len__(self): + return self.total_num_samples + +def split_and_record_separators(input_string, separators) -> list: + texts = [input_string] + for sep in separators: + new_texts = [] + for text in texts: + if sep not in text: + new_texts.append(text) + else: + split_strings = text.split(sep) + joint_strings = [t for pair in zip(split_strings[:-1], itertools.repeat(sep)) for t in pair ] + split_strings[-1:] + new_texts.extend(joint_strings) + texts = new_texts + return texts + +def preprocess( + batch, + args, + processor, + collate_fn, + dtype=torch.bfloat16, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, +): + tokenizer = processor.tokenizer + # tokenization for training + max_length = args.max_txt_l + input_list, images = [], [] + for sample in batch: + image, tex, instruction, index = sample # (nframe, 3, h, w), (0-255) + num_img = image.shape[0] + tex = tex.replace(args.dataset_video_placeholder, IMAGE_TOKEN).replace(args.dataset_image_placeholder, IMAGE_TOKEN) + seps = [role for role in args.roles] + segs = split_and_record_separators(tex, seps) + input_ids, labels, attention_mask = [], [], [] + + for i, seg in enumerate(segs): + seg_ignore = False if seg == seps[1] else \ + (True if i == 0 or seg in seps else seg_ignore) # not ignoring assistant, changing in sepecific situations + current_ignore = True if seg in seps else seg_ignore # serve for only this one iteration + seg_input_ids = tokenizer.encode(seg, add_special_tokens=True if i==0 else False) # only add bos token + seg_labels = [args.ignore_index] * len(seg_input_ids) if current_ignore else seg_input_ids + seg_attention_mask = [1] * len(seg_input_ids) # do attend + input_ids.extend(seg_input_ids) + labels.extend(seg_labels) + attention_mask.extend(seg_attention_mask) + + pad_length = max_length - len(input_ids) + labels = labels[:max_length] + attention_mask = attention_mask[:max_length] + input_ids=input_ids[:max_length] + + labels = labels + [args.ignore_index] * pad_length # padding doesn't take care of labels. do the padding here + input_ids = input_ids + [tokenizer.pad_token_id] * pad_length + attention_mask = attention_mask + [0]*pad_length + sample_input = { + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': attention_mask, + } + input_list.append(sample_input) + images.append(image if image.ndim==4 else image.unsqueeze(0)) # made 4 dim for image, remain 4 dim for video + + inputs = collate_fn(input_list) + + # interpolate frames if the total frame is smaller than needed + for i, video in enumerate(images): + if video.shape[0] < args.num_frames: + multiplier = int(args.num_frames/video.shape[0]) + 1 + video = video.repeat_interleave(multiplier, dim=0)[:args.num_frames] + images[i] = video + assert video.shape[0] == args.num_frames + if args.clip_transform: + multimodal_features = processor(images=images) + inputs.update(**multimodal_features) + else: + inputs["pixel_values"] = torch.concat(images) # already processed to features in dataset get item + + + return inputs + +def main(config): + accelerator_log_kwargs=dict( + log_with=config.report_to, + project_dir=config.output_dir + ) + + accelerator = Accelerator( + gradient_accumulation_steps=config.gradient_accumulation_steps, + **accelerator_log_kwargs + ) + logger.info(f"train_file: {config.train_file}") + model, processor = setup_model( + config, + find_unused_parameters=True, + ) + if accelerator.is_main_process: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + + collate_fn = DataCollatorWithPadding(tokenizer=processor.tokenizer, padding='max_length', max_length=config.max_txt_l, return_tensors='pt',) + collate_fn = functools.partial(preprocess, args=config.preprocess, processor=processor, collate_fn=collate_fn) + train_loaders, train_media_types = setup_dataloaders(config, mode=config.mode, collate_fn=collate_fn) + num_steps_per_epoch = math.ceil(sum(len(d) for d in train_loaders) / config.gradient_accumulation_steps) + # load optimizer and custom scheduler + config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs + config.scheduler.num_warmup_steps = math.ceil(config.scheduler.num_training_steps * config.scheduler.warmup_ratio) + optimizer, lr_scheduler = setup_optimizer_and_scheduler(config, model) + # if not set customized scheduler, default hf scheduler + overrode_max_train_steps = False + if config.max_train_steps is None: + config.max_train_steps = config.scheduler.epochs * num_steps_per_epoch + overrode_max_train_steps = True + if lr_scheduler is None: + lr_scheduler = get_scheduler( + name=config.scheduler.sched, + optimizer=optimizer, + num_warmup_steps=config.scheduler.num_warmup_steps, + num_training_steps=config.max_train_steps + if overrode_max_train_steps + else config.max_train_steps * accelerator.num_processes, + ) + model, optimizer, lr_scheduler, *train_loaders = accelerator.prepare( + model, optimizer, lr_scheduler, *train_loaders + ) + + if hasattr(config, 'seed'): + set_seed(config.seed) + + experiment_config = { # include all the important hyperparam + 'num_frames': config.num_frames, + 'max_txt_l': config.max_txt_l, + 'batch_size': config.batch_size, + } + + model.train() + + start_epoch = 0 + num_batches = sum(len(loader) for loader in train_loaders) + global_step = start_epoch * num_batches # the steps before divided by accumulation + if osp.exists(config.output_dir): + subfolders = os.listdir(config.output_dir) + sample_saving = False + for subfolder in subfolders: + if subfolder.endswith("M"): + sample_saving = True + if sample_saving: + ckpt_paths = [subfolder for subfolder in subfolders if re.match(r'ckpt_resume_[\d.]+M$', subfolder) is not None] + ckpt_iters = [float(re.findall(r'[\d.]+', x)[0]) for x in ckpt_paths] + else: + ckpt_paths = [subfolder for subfolder in subfolders if re.match("ckpt_[^\d]+", subfolder) is not None] + ckpt_iters = [int(s.split(re.match("ckpt_[^\d]+", s).group())[-1]) for s in ckpt_paths] + + + resume_cur_epoch_step=0 + if len(ckpt_iters) > 0: + resume_iter = max(ckpt_iters) + ckpt_path = osp.join(config.output_dir, ckpt_paths[ckpt_iters.index(resume_iter)]) + accelerator.print(f"Resumed from checkpoint: {ckpt_path}") + accelerator.load_state(ckpt_path) + if sample_saving: + resume_iter = int(resume_iter*1e6/(config.batch_size*accelerator.state.num_processes)) + + if "epoch" in ckpt_path: + start_epoch = int(resume_iter) + 1 + resume_cur_epoch_step = 0 + global_step = start_epoch * num_batches + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + # num_finish_smaple = int(max_ckpt_num) * config.gradient_accumulation_steps + start_epoch = resume_iter // num_batches + global_step = resume_iter + resume_cur_epoch_step = resume_iter - start_epoch * num_batches + accelerator.print(f"Resume from epoch {start_epoch}, steps{resume_cur_epoch_step}") + + + + # TensorBoard cannot log Enums, need the raw value + accelerator.init_trackers("train_pllava_nframe", experiment_config) + start_time = time.time() + + + + logger.info(f"Start training {str(start_time)}, from start_epoch-{start_epoch}, step-{resume_cur_epoch_step}") + + # skip the first `n` batches in the dataloader when resuming from a checkpoint + active_train_loaders = train_loaders + if resume_cur_epoch_step > 0: + active_train_loaders = [] + total_dta_num = sum(len(train_loader) for train_loader in train_loaders) + for train_loader in train_loaders: + skip_batch_num = int((resume_cur_epoch_step/total_dta_num)*len(train_loader)) + skipped_train_loader = accelerator.skip_first_batches(train_loader, num_batches=skip_batch_num) + active_train_loaders.append(skipped_train_loader) + + media_types = get_media_types(active_train_loaders) + train_loader = RandomMappingIterator(active_train_loaders, media_types) + + for epoch in range(start_epoch, config.scheduler.epochs): + if not config.evaluate: + gc.collect() + torch.cuda.empty_cache() + metric_logger = MetricLogger(delimiter=" ") + loss_names = ["loss"] + for name in loss_names: + for m in media_types: + metric_logger.add_meter( + f"{m}-{name}", SmoothedValue(window=config.metric_window_size, fmt="{value:.4f}") + ) + + header = f"Train Epoch: [{epoch}]" + log_freq = config.log_freq + + iterator = metric_logger.log_every(train_loader, log_freq, header) + mini_batch_losses = [] + + for i, (media_type, inputs) in enumerate(iterator): # video/image, conversation, instruction, index + + with accelerator.accumulate(model): + + inputs['media_type'] = media_type + response = model(**inputs) + loss = response.loss + mini_batch_losses.append(loss.detach().item()) + optimizer.zero_grad() + accelerator.backward(loss) + if config.optimizer.max_grad_norm > 0: + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm) + optimizer.step() + lr_scheduler.step() + # # logging + for name in loss_names: + value = loss + value = value if isinstance(value, float) else value.item() + metric_logger.update(**{f"{media_type}-{name}": value}) + global_step += 1 + resume_num_samples = global_step * config.batch_size * accelerator.state.num_processes/1e6 + + # save small global step checkpoint in case of breakdown + if global_step % config.ckpt_steps == 0: + accelerator.save_state(output_dir=osp.join(config.output_dir, f"ckpt_resume_{resume_num_samples:.4f}M")) + if accelerator.is_main_process: + for fn in os.listdir(config.output_dir): + if "resume" in fn and fn != f"ckpt_resume_{resume_num_samples:.4f}M": + shutil.rmtree(osp.join(config.output_dir, fn)) + + if global_step % config.save_steps == 0: + logger.info(f"global_step {global_step}") + with torch.no_grad(): + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + if not config.deepspeed: + save_state_dict = {k:v for k,v in accelerator.get_state_dict(model).items() if "lora_" in k or "multi_modal_projector" in k} + else: + save_state_dict = accelerator.get_state_dict(model) + unwrapped_model.save_pretrained(osp.join(config.output_dir, f"pretrained_step{resume_num_samples:.4f}M"), + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=save_state_dict) + processor.save_pretrained(osp.join(config.output_dir, f"pretrained_step{resume_num_samples:.4f}M")) + + if global_step % log_freq == 0: + logs = metric_logger.get_global_avg_dict() + logs.update({ + "step_loss_no_smoothing": accelerator.gather_for_metrics(loss).mean().item(), + "epoch": epoch, + "step": global_step, + "lr": lr_scheduler.get_last_lr()[0], + }) + accelerator.log(logs, step=global_step,) + if accelerator.sync_gradients: + mini_batch_loss = torch.tensor(mini_batch_losses, device='cuda') + accelerator.log({"mini_batch_loss": accelerator.gather_for_metrics(mini_batch_loss).mean().item()}, + step=global_step) + mini_batch_losses = [] + + + if config.debug and global_step % 20 == 0: + logger.info("debug mode, break training loop") + break + + if config.debug and global_step % (2 * log_freq + 3) == 0: + logger.info("debug mode, break training loop") + break + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + logger.info(f"Averaged stats: {metric_logger.global_avg()}") + logger.info(f"Epoch {epoch}") + with torch.no_grad(): + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + if not config.deepspeed: + save_state_dict = {k:v for k,v in accelerator.get_state_dict(model).items() if "lora_" in k or "multi_modal_projector" in k} + else: + save_state_dict = accelerator.get_state_dict(model) + unwrapped_model.save_pretrained(osp.join(config.output_dir, f"pretrained_epoch{epoch:02d}"), + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=save_state_dict) + processor.save_pretrained(osp.join(config.output_dir, f"pretrained_step{epoch:02d}")) + accelerator.save_state(output_dir=osp.join(config.output_dir, f"ckpt_epoch{epoch:02d}")) + + + if config.evaluate: + break + + accelerator.end_training() + accelerator.wait_for_everyone() + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info(f"Training time {total_time_str}") + logger.info(f"Checkpoints and Logs saved at {config.output_dir}") + + + +if __name__ == "__main__": + cfg = setup_main() + print(cfg) + main(cfg) diff --git a/utils/basic_utils.py b/utils/basic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fb453d35c852741bf1ad6dfe27e604d9fef6557e --- /dev/null +++ b/utils/basic_utils.py @@ -0,0 +1,286 @@ +import numpy as np +import io +import os +import json +import logging +import random +import time +from collections import defaultdict, deque +import datetime +from pathlib import Path +from typing import List, Union + +import torch +import torch.distributed as dist +from .distributed import is_dist_avail_and_initialized + + +logger = logging.getLogger(__name__) + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], + dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + if meter.count == 0: # skip empty meter + loss_str.append( + "{}: {}".format(name, "No data") + ) + else: + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + if meter.count == 0: + loss_str.append( + "{}: {}".format(name, "No data") + ) + else: + loss_str.append( + "{}: {:.4f}".format(name, meter.global_avg) + ) + return self.delimiter.join(loss_str) + + def get_global_avg_dict(self, prefix=""): + """include a separator (e.g., `/`, or "_") at the end of `prefix`""" + d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()} + return d + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, log_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % log_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + logger.info(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + res_mem=torch.cuda.max_memory_reserved() / MB, + )) + else: + logger.info(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def compute_acc(logits, label, reduction='mean'): + ret = (torch.argmax(logits, dim=1) == label).float() + if reduction == 'none': + return ret.detach() + elif reduction == 'mean': + return ret.mean().item() + + +def compute_n_params(model, return_str=True): + tot = 0 + for p in model.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return '{:.1f}M'.format(tot / 1e6) + else: + return '{:.1f}K'.format(tot / 1e3) + else: + return tot + + +def setup_seed(seed): + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + +def remove_files_if_exist(file_paths): + for fp in file_paths: + if os.path.isfile(fp): + os.remove(fp) + + +def save_json(data, filename, save_pretty=False, sort_keys=False): + with open(filename, "w") as f: + if save_pretty: + f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) + else: + json.dump(data, f) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +def flat_list_of_lists(l): + """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" + return [item for sublist in l for item in sublist] + + +def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]): + """ + Args: + root: path to the directory to start search files + suffix: any str as suffix, or can match multiple such strings + when input is List[str]. + Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`] + Example 2, e.g., use a `*` in the `suffix`: `START*.jpg.`. + """ + if isinstance(suffix, str): + suffix = [suffix, ] + filepaths = flat_list_of_lists( + [list(Path(root).rglob(f"*{e}")) for e in suffix]) + return filepaths + + +def match_key_and_shape(state_dict1, state_dict2): + keys1 = set(state_dict1.keys()) + keys2 = set(state_dict2.keys()) + print(f"keys1 - keys2: {keys1 - keys2}") + print(f"keys2 - keys1: {keys2 - keys1}") + + mismatch = 0 + for k in list(keys1): + if state_dict1[k].shape != state_dict2[k].shape: + print( + f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}") + mismatch += 1 + print(f"mismatch {mismatch}") + + +def merge_dicts(list_dicts): + merged_dict = list_dicts[0].copy() + for i in range(1, len(list_dicts)): + merged_dict.update(list_dicts[i]) + return merged_dict diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..63f9ef375b37daa6926f2259502913e38f22e6e2 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import argparse +import ast +import json +import os +import os.path as osp +import re +import shutil +import sys +import tempfile +from copy import deepcopy +from importlib import import_module + +import yaml + +from .easydict import EasyDict + +__all__ = ["Config", "pretty_text"] + + +BASE_KEY = "_base_" +# BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"} +BASE_CONFIG = {} + +cfg = None + + +class Config(object): + """config""" + + @classmethod + def pretty_text(cls, cfg: dict, indent=2) -> str: + """format dict to a string + + Args: + cfg (EasyDict): the params. + + Returns: The string to display. + + """ + msg = "{\n" + for i, (k, v) in enumerate(cfg.items()): + if isinstance(v, dict): + v = cls.pretty_text(v, indent + 4) + spaces = " " * indent + msg += spaces + "{}: {}".format(k, v) + if i == len(cfg) - 1: + msg += " }" + else: + msg += "\n" + return msg + + @classmethod + def dump(cls, cfg, savepath=None): + """dump cfg to `json` file. + + Args: + cfg (dict): The dict to dump. + savepath (str): The filepath to save the dumped dict. + + Returns: TODO + + """ + if savepath is None: + savepath = osp.join(cfg.WORKSPACE, "config.json") + json.dump(cfg, open(savepath, "w"), indent=2) + + @classmethod + def get_config(cls, default_config: dict = None): + """get a `Config` instance. + + Args: + default_config (dict): The default config. `default_config` will be overrided + by config file `--cfg`, `--cfg` will be overrided by commandline args. + + Returns: an EasyDict. + """ + global cfg + if cfg is not None: + return cfg + + # define arg parser. + parser = argparse.ArgumentParser() + # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str) + parser.add_argument( + "config_file", help="the configuration file to load. support: .yaml, .json, .py" + ) + parser.add_argument( + "opts", + default=None, + nargs="*", + help="overrided configs. List. Format: 'key1 name1 key2 name2'", + ) + args = parser.parse_args() + + cfg = EasyDict(BASE_CONFIG) + if osp.isfile(args.config_file): + cfg_from_file = cls.from_file(args.config_file) + cfg = merge_a_into_b(cfg_from_file, cfg) + cfg = cls.merge_list(cfg, args.opts) + cfg = eval_dict_leaf(cfg) + + # update some keys to make them show at the last + for k in BASE_CONFIG: + cfg[k] = cfg.pop(k) + return cfg + + @classmethod + def from_file(cls, filepath: str) -> EasyDict: + """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`. + + Args: + filepath (str): The config file path. + + Returns: TODO + + """ + filepath = osp.abspath(osp.expanduser(filepath)) + if not osp.isfile(filepath): + raise IOError(f"File does not exist: {filepath}") + if filepath.endswith(".py"): + with tempfile.TemporaryDirectory() as temp_config_dir: + + shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config")) + sys.path.insert(0, temp_config_dir) + mod = import_module("tmp_config." + osp.splitext(osp.basename(filepath))[0]) + # mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = { + name: value + for name, value in mod.__dict__.items() + if not name.startswith("__") + } + for k in list(sys.modules.keys()): + if "tmp_config" in k: + del sys.modules[k] + elif filepath.endswith((".yml", ".yaml")): + cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader) + elif filepath.endswith(".json"): + cfg_dict = json.load(open(filepath, "r")) + else: + raise IOError("Only py/yml/yaml/json type are supported now!") + + cfg_text = filepath + "\n" + with open(filepath, "r") as f: + cfg_text += f.read() + + if BASE_KEY in cfg_dict: # load configs in `BASE_KEY` + cfg_dir = osp.dirname(filepath) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = ( + base_filename if isinstance(base_filename, list) else [base_filename] + ) + + cfg_dict_list = list() + for f in base_filename: + _cfg_dict = Config.from_file(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + + base_cfg_dict = dict() + for c in cfg_dict_list: + if len(base_cfg_dict.keys() & c.keys()) > 0: + raise KeyError("Duplicate key is not allowed among bases") + base_cfg_dict.update(c) + + cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict) + + return EasyDict(cfg_dict) + + @classmethod + def merge_list(cls, cfg, opts: list): + """merge commandline opts. + + Args: + cfg: (dict): The config to be merged. + opts (list): The list to merge. Format: [key1, name1, key2, name2,...]. + The keys can be nested. For example, ["a.b", v] will be considered + as `dict(a=dict(b=v))`. + + Returns: dict. + + """ + assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}" + for i in range(0, len(opts), 2): + full_k, v = opts[i], opts[i + 1] + keys = full_k.split(".") + sub_d = cfg + for i, k in enumerate(keys): + if not hasattr(sub_d, k): + raise ValueError(f"The key {k} not exist in the config. Full key:{full_k}") + if i != len(keys) - 1: + sub_d = sub_d[k] + else: + sub_d[k] = v + return cfg + + +def merge_a_into_b(a, b, inplace=False): + """The values in a will override values in b. + + Args: + a (dict): source dict. + b (dict): target dict. + + Returns: dict. recursively merge dict a into dict b. + + """ + if not inplace: + b = deepcopy(b) + for key in a: + if key in b: + if isinstance(a[key], dict) and isinstance(b[key], dict): + b[key] = merge_a_into_b(a[key], b[key], inplace=True) + else: + b[key] = a[key] + else: + b[key] = a[key] + return b + + +def eval_dict_leaf(d, orig_dict=None): + """eval values of dict leaf. + + Args: + d (dict): The dict to eval. + + Returns: dict. + + """ + if orig_dict is None: + orig_dict = d + for k, v in d.items(): + if not isinstance(v, dict): + d[k] = eval_string(v, orig_dict) + else: + eval_dict_leaf(v, orig_dict) + return d + + +def eval_string(string, d): + """automatically evaluate string to corresponding types. + + For example: + not a string -> return the original input + '0' -> 0 + '0.2' -> 0.2 + '[0, 1, 2]' -> [0,1,2] + 'eval(1+2)' -> 3 + 'eval(range(5))' -> [0,1,2,3,4] + '${a}' -> d.a + + + + Args: + string (str): The value to evaluate. + d (dict): The + + Returns: the corresponding type + + """ + if not isinstance(string, str): + return string + # if len(string) > 1 and string[0] == "[" and string[-1] == "]": + # return eval(string) + if string[0:5] == "eval(": + return eval(string[5:-1]) + + s0 = string + s1 = re.sub(r"\${(.*)}", r"d.\1", s0) + if s1 != s0: + while s1 != s0: + s0 = s1 + s1 = re.sub(r"\${(.*)}", r"d.\1", s0) + return eval(s1) + + try: + v = ast.literal_eval(string) + except: + v = string + return v diff --git a/utils/config_utils.py b/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72e31c7c922e811e62e2b92e708ab087651c40c2 --- /dev/null +++ b/utils/config_utils.py @@ -0,0 +1,60 @@ +import logging +import os +import sys +from os.path import dirname, join + +from utils.config import Config +from utils.distributed import init_distributed_mode, is_main_process +from utils.logger import setup_logger + +logger = logging.getLogger(__name__) + + +def setup_config(): + """Conbine yaml config and command line config with OmegaConf. + Also converts types, e.g., `'None'` (str) --> `None` (None) + """ + config = Config.get_config() + if config.debug: + config.wandb.enable = False + return config + + +def setup_evaluate_config(config): + """setup evaluation default settings, e.g., disable wandb""" + assert config.evaluate + config.wandb.enable = False + if config.output_dir is None: + config.output_dir = join(dirname(config.pretrained_path), "eval") + return config + + +def setup_output_dir(output_dir, excludes=["code"]): + """ensure not overwritting an exisiting/non-empty output dir""" + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=False) + else: + existing_dirs_files = os.listdir(output_dir) # list + remaining = set(existing_dirs_files) - set(excludes) + remaining = [e for e in remaining if "slurm" not in e] + remaining = [e for e in remaining if ".out" not in e] + # assert len(remaining) == 0, f"remaining dirs or files: {remaining}" + logger.warn(f"remaining dirs or files: {remaining}") + + +def setup_main(): + """ + Setup config, logger, output_dir, etc. + Shared for pretrain and all downstream tasks. + """ + config = setup_config() + if hasattr(config, "evaluate") and config.evaluate: + config = setup_evaluate_config(config) + init_distributed_mode(config) + + if is_main_process(): + setup_output_dir(config.output_dir, excludes=["code"]) + setup_logger(output=config.output_dir, color=True, name="vindlu") + logger.info(f"config: {Config.pretty_text(config)}") + Config.dump(config, os.path.join(config.output_dir, "config.json")) + return config diff --git a/utils/distributed.py b/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..780417ec19767ec8b820bec13a0f030b64e2177e --- /dev/null +++ b/utils/distributed.py @@ -0,0 +1,162 @@ +import os +import torch +import torch.distributed as dist +import logging + + +logger = logging.getLogger(__name__) + + +def setup_for_distributed(is_master): + import warnings + + builtin_warn = warnings.warn + + def warn(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_warn(*args, **kwargs) + + # Log warnings only once + warnings.warn = warn + warnings.simplefilter("once", UserWarning) + + if not is_master: + logging.disable() + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def is_port_in_use(port): + import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(('localhost', port)) == 0 + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + # job started by torch.distributed.launch + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + # local rank on the current node / global rank + local_rank = int(os.environ['SLURM_LOCALID']) + global_rank = int(os.environ['SLURM_PROCID']) + # number of processes / GPUs per node + world_size = int(os.environ["SLURM_NNODES"]) * \ + int(os.environ["SLURM_TASKS_PER_NODE"][0]) + + print(world_size) + + args.rank = global_rank + args.gpu = local_rank + args.world_size = world_size + else: + logger.info('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + + if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node + dist_port = int(args.dist_url.split(":")[-1]) + while is_port_in_use(dist_port): + dist_port += 10 + args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)]) + print(args.dist_url) + + logger.info('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url)) + if "SLURM_JOB_ID" in os.environ: + logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}") + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +# Copyright (c) Facebook, Inc. and its affiliates. +# copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + dist.all_reduce(all_gradients) + return all_gradients[dist.get_rank()] + + +# copied from megavlt +def gather_tensor_along_batch_with_backward(tensor, dim=0): + world_size = get_world_size() + + if world_size < 2: + return tensor + + tensor_list = GatherLayer.apply(tensor) + tensor_list = torch.cat(tensor_list, dim=dim) + return tensor_list + + +@torch.no_grad() +def gather_tensor_along_batch(tensor, dim=0): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + world_size = get_world_size() + + if world_size < 2: + return tensor + + with torch.no_grad(): + tensor_list = [] + + for _ in range(world_size): + tensor_list.append(torch.zeros_like(tensor)) + + dist.all_gather(tensor_list, tensor) + tensor_list = torch.cat(tensor_list, dim=dim) + return tensor_list diff --git a/utils/easydict.py b/utils/easydict.py new file mode 100644 index 0000000000000000000000000000000000000000..241aca41c9f1b0677be4bf6070c077fa24501816 --- /dev/null +++ b/utils/easydict.py @@ -0,0 +1,149 @@ +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> map(attrgetter('x'), d.bar) + [1, 3] + >>> map(attrgetter('y'), d.bar) + [2, 4] + >>> d = EasyDict() + >>> d.keys() + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> o.items() + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'a' + """ + + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + if hasattr(self, k): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f3164ae7251e1f0006173c4f409c0901742048d6 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,263 @@ +# from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py +# Copyright (c) Facebook, Inc. and its affiliates. + +import functools +import logging +import os +import sys +import time +import wandb +from typing import Any, Dict, Union + +import torch +from .distributed import get_rank, is_main_process +from termcolor import colored + + +def log_dict_to_wandb(log_dict, step, prefix=""): + """include a separator `/` at the end of `prefix`""" + if not is_main_process(): + return + + log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()} + wandb.log(log_dict, step) + + +def setup_wandb(config): + if not (config.wandb.enable and is_main_process()): + return + + run = wandb.init( + config=config, + project=config.wandb.project, + entity=config.wandb.entity, + name=os.path.basename(config.output_dir), + reinit=True + ) + return run + + +def setup_output_folder(save_dir: str, folder_only: bool = False): + """Sets up and returns the output file where the logs will be placed + based on the configuration passed. Usually "save_dir/logs/log_.txt". + If env.log_dir is passed, logs will be directly saved in this folder. + Args: + folder_only (bool, optional): If folder should be returned and not the file. + Defaults to False. + Returns: + str: folder or file path depending on folder_only flag + """ + log_filename = "train_" + log_filename += time.strftime("%Y_%m_%dT%H_%M_%S") + log_filename += ".log" + + log_folder = os.path.join(save_dir, "logs") + + if not os.path.exists(log_folder): + os.path.mkdirs(log_folder) + + if folder_only: + return log_folder + + log_filename = os.path.join(log_folder, log_filename) + + return log_filename + + +def setup_logger( + output: str = None, + color: bool = True, + name: str = "mmf", + disable: bool = False, + clear_handlers=True, + *args, + **kwargs, +): + """ + Initialize the MMF logger and set its verbosity level to "INFO". + Outside libraries shouldn't call this in case they have set there + own logging handlers and setup. If they do, and don't want to + clear handlers, pass clear_handlers options. + The initial version of this function was taken from D2 and adapted + for MMF. + Args: + output (str): a file name or a directory to save log. + If ends with ".txt" or ".log", assumed to be a file name. + Default: Saved to file + color (bool): If false, won't log colored logs. Default: true + name (str): the root module name of this logger. Defaults to "mmf". + disable: do not use + clear_handlers (bool): If false, won't clear existing handlers. + Returns: + logging.Logger: a logger + """ + if disable: + return None + logger = logging.getLogger(name) + logger.propagate = False + + logging.captureWarnings(True) + warnings_logger = logging.getLogger("py.warnings") + + plain_formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(name)s : %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + + distributed_rank = get_rank() + handlers = [] + + logging_level = logging.INFO + # logging_level = logging.DEBUG + + if distributed_rank == 0: + logger.setLevel(logging_level) + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging_level) + if color: + formatter = ColorfulFormatter( + colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + warnings_logger.addHandler(ch) + handlers.append(ch) + + # file logging: all workers + if output is None: + output = setup_output_folder() + + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "train.log") + if distributed_rank > 0: + filename = filename + f".rank{distributed_rank}" + os.makedirs(os.path.dirname(filename), exist_ok=True) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging_level) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + warnings_logger.addHandler(fh) + handlers.append(fh) + + # Slurm/FB output, only log the main process + # save_dir = get_mmf_env(key="save_dir") + if "train.log" not in filename and distributed_rank == 0: + filename = os.path.join(output, "train.log") + sh = logging.StreamHandler(_cached_log_stream(filename)) + sh.setLevel(logging_level) + sh.setFormatter(plain_formatter) + logger.addHandler(sh) + warnings_logger.addHandler(sh) + handlers.append(sh) + + logger.info(f"Logging to: {filename}") + + # Remove existing handlers to add MMF specific handlers + if clear_handlers: + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + # Now, add our handlers. + logging.basicConfig(level=logging_level, handlers=handlers) + + return logger + + +def setup_very_basic_config(color=True): + plain_formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(name)s : %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.INFO) + if color: + formatter = ColorfulFormatter( + colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + # Setup a minimal configuration for logging in case something tries to + # log a message even before logging is setup by MMF. + logging.basicConfig(level=logging.INFO, handlers=[ch]) + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + return open(filename, "a") + + +# ColorfulFormatter is adopted from Detectron2 and adapted for MMF +class ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def formatMessage(self, record): + log = super().formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +class TensorboardLogger: + def __init__(self, log_folder="./logs", iteration=0): + # This would handle warning of missing tensorboard + from torch.utils.tensorboard import SummaryWriter + + self.summary_writer = None + self._is_master = is_main_process() + # self.timer = Timer() + self.log_folder = log_folder + + if self._is_master: + # current_time = self.timer.get_time_hhmmss(None, format=self.time_format) + current_time = time.strftime("%Y-%m-%dT%H:%M:%S") + # self.timer.get_time_hhmmss(None, format=self.time_format) + tensorboard_folder = os.path.join( + self.log_folder, f"tensorboard_{current_time}" + ) + self.summary_writer = SummaryWriter(tensorboard_folder) + + def __del__(self): + if getattr(self, "summary_writer", None) is not None: + self.summary_writer.close() + + def _should_log_tensorboard(self): + if self.summary_writer is None or not self._is_master: + return False + else: + return True + + def add_scalar(self, key, value, iteration): + if not self._should_log_tensorboard(): + return + + self.summary_writer.add_scalar(key, value, iteration) + + def add_scalars(self, scalar_dict, iteration): + if not self._should_log_tensorboard(): + return + + for key, val in scalar_dict.items(): + self.summary_writer.add_scalar(key, val, iteration) + + def add_histogram_for_model(self, model, iteration): + if not self._should_log_tensorboard(): + return + + for name, param in model.named_parameters(): + np_param = param.clone().cpu().data.numpy() + self.summary_writer.add_histogram(name, np_param, iteration) diff --git a/utils/optimizer.py b/utils/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..679483b72556c83d6ff19bc51fe4db41c656b56d --- /dev/null +++ b/utils/optimizer.py @@ -0,0 +1,133 @@ +""" Optimizer Factory w/ Custom Weight Decay +Hacked together by / Copyright 2020 Ross Wightman +""" +import re +import torch +from torch import optim as optim +from utils.distributed import is_main_process +import logging +logger = logging.getLogger(__name__) +try: + from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD + has_apex = True +except ImportError: + has_apex = False + + +def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True): + named_param_tuples = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")): + named_param_tuples.append([name, param, 0]) + elif name in no_decay_list: + named_param_tuples.append([name, param, 0]) + else: + named_param_tuples.append([name, param, weight_decay]) + return named_param_tuples + + +def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr): + """use lr=diff_lr for modules named found in diff_lr_names, + otherwise use lr=default_lr + + Args: + named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module + diff_lr_names: List(str) + diff_lr: float + default_lr: float + Returns: + named_param_tuples_with_lr: List([name, param, weight_decay, lr]) + """ + named_param_tuples_with_lr = [] + logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}") + for name, p, wd in named_param_tuples_or_model: + use_diff_lr = False + for diff_name in diff_lr_names: + # if diff_name in name: + if re.search(diff_name, name) is not None: + logger.info(f"param {name} use different_lr: {diff_lr}") + use_diff_lr = True + break + + named_param_tuples_with_lr.append( + [name, p, wd, diff_lr if use_diff_lr else default_lr] + ) + + if is_main_process(): + for name, _, wd, diff_lr in named_param_tuples_with_lr: + logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}") + + return named_param_tuples_with_lr + + +def create_optimizer_params_group(named_param_tuples_with_lr): + """named_param_tuples_with_lr: List([name, param, weight_decay, lr])""" + group = {} + for name, p, wd, lr in named_param_tuples_with_lr: + if wd not in group: + group[wd] = {} + if lr not in group[wd]: + group[wd][lr] = [] + group[wd][lr].append(p) + + optimizer_params_group = [] + for wd, lr_groups in group.items(): + for lr, p in lr_groups.items(): + optimizer_params_group.append(dict( + params=p, + weight_decay=wd, + lr=lr + )) + logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}") + return optimizer_params_group + + +def create_optimizer(args, model, filter_bias_and_bn=True): + opt_lower = args.opt.lower() + weight_decay = args.weight_decay + # check for modules that requires different lr + if hasattr(args, "different_lr") and args.different_lr.enable: + diff_lr_module_names = args.different_lr.module_names + diff_lr = args.different_lr.lr + else: + diff_lr_module_names = [] + diff_lr = None + + no_decay = {} + if hasattr(model, 'no_weight_decay'): + no_decay = model.no_weight_decay() + named_param_tuples = add_weight_decay( + model, weight_decay, no_decay, filter_bias_and_bn) + named_param_tuples = add_different_lr( + named_param_tuples, diff_lr_module_names, diff_lr, args.lr) + parameters = create_optimizer_params_group(named_param_tuples) + + if 'fused' in opt_lower: + assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' + + opt_args = dict(lr=args.lr, weight_decay=weight_decay) + if hasattr(args, 'opt_eps') and args.opt_eps is not None: + opt_args['eps'] = args.opt_eps + if hasattr(args, 'opt_betas') and args.opt_betas is not None: + opt_args['betas'] = args.opt_betas + if hasattr(args, 'opt_args') and args.opt_args is not None: + opt_args.update(args.opt_args) + + opt_split = opt_lower.split('_') + opt_lower = opt_split[-1] + if opt_lower == 'sgd' or opt_lower == 'nesterov': + opt_args.pop('eps', None) + optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) + elif opt_lower == 'momentum': + opt_args.pop('eps', None) + optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) + elif opt_lower == 'adam': + optimizer = optim.Adam(parameters, **opt_args) + elif opt_lower == 'adamw': + optimizer = optim.AdamW(parameters, **opt_args) + else: + assert False and "Invalid optimizer" + raise ValueError + return optimizer diff --git a/utils/scheduler.py b/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d050fb0d95d8213651b36558668df969694d73 --- /dev/null +++ b/utils/scheduler.py @@ -0,0 +1,56 @@ +""" Scheduler Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch.optim import Optimizer +import math +from torch.optim.lr_scheduler import LambdaLR + + +def create_scheduler(args, optimizer): + lr_scheduler = None + if args.sched == 'cosine': + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.num_training_steps, + num_cycles=0.5, + min_lr_multi=args.min_lr_multi + ) + return lr_scheduler + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, + num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1 +): + """ + Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py + + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + min_lr_multi (`float`, *optional*, defaults to 0): + The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps))) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch)