Upload 17 files
- BLIP_CODEOWNERS.txt +2 -0
- BLIP_CODE_OF_CONDUCT.md +105 -0
- BLIP_LICENSE.txt +12 -0
- BLIP_README.md +114 -0
- BLIP_SECURITY.md +7 -0
- BLIP_cog.yaml.txt +17 -0
- BLIP_demo.ipynb.txt +0 -0
- BLIP_eval_nocaps.py +118 -0
- BLIP_eval_retrieval_video.py +250 -0
- BLIP_predict.py +98 -0
- BLIP_pretrain.py +173 -0
- BLIP_requirements.txt +4 -0
- BLIP_train_caption.py +206 -0
- BLIP_train_nlvr.py +213 -0
- BLIP_train_retrieval.py +345 -0
- BLIP_train_vqa.py +202 -0
- BLIP_utils.py +278 -0
BLIP_CODEOWNERS.txt
ADDED
@@ -0,0 +1,2 @@
+# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
+#ECCN:Open Source
BLIP_CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,105 @@
+# Salesforce Open Source Community Code of Conduct
+
+## About the Code of Conduct
+
+Equality is a core value at Salesforce. We believe a diverse and inclusive
+community fosters innovation and creativity, and are committed to building a
+culture where everyone feels included.
+
+Salesforce open-source projects are committed to providing a friendly, safe, and
+welcoming environment for all, regardless of gender identity and expression,
+sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
+race, age, religion, level of experience, education, socioeconomic status, or
+other similar personal characteristics.
+
+The goal of this code of conduct is to specify a baseline standard of behavior so
+that people with different social values and communication styles can work
+together effectively, productively, and respectfully in our open source community.
+It also establishes a mechanism for reporting issues and resolving conflicts.
+
+All questions and reports of abusive, harassing, or otherwise unacceptable behavior
+in a Salesforce open-source project may be reported by contacting the Salesforce
+Open Source Conduct Committee at ossconduct@salesforce.com.
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of gender
+identity and expression, sexual orientation, disability, physical appearance,
+body size, ethnicity, nationality, race, age, religion, level of experience, education,
+socioeconomic status, or other similar personal characteristics.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy toward other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+advances
+* Personal attacks, insulting/derogatory comments, or trolling
+* Public or private harassment
+* Publishing, or threatening to publish, others' private information—such as
+a physical or electronic address—without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+professional setting
+* Advocating for or encouraging any of the above behaviors
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned with this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project email
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the Salesforce Open Source Conduct Committee
+at ossconduct@salesforce.com. All complaints will be reviewed and investigated
+and will result in a response that is deemed necessary and appropriate to the
+circumstances. The committee is obligated to maintain confidentiality with
+regard to the reporter of an incident. Further details of specific enforcement
+policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership and the Salesforce Open Source Conduct
+Committee.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
+version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
+It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
+[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
+
+This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
+
+[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
+[golang-coc]: https://golang.org/conduct
+[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
+[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
+[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
BLIP_LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
+Copyright (c) 2022, Salesforce.com, Inc.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
BLIP_README.md
ADDED
@@ -0,0 +1,114 @@
+## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
+
+<img src="BLIP.gif" width="700">
+
+This is the PyTorch code of the <a href="https://arxiv.org/abs/2201.12086">BLIP paper</a> [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
+To install the dependencies, run <pre>pip install -r requirements.txt</pre>
+
+Catalog:
+- [x] Inference demo
+- [x] Pre-trained and finetuned checkpoints
+- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
+- [x] Pre-training code
+- [x] Zero-shot video-text retrieval
+- [x] Download of bootstrapped pre-training datasets
+
+
+### Inference demo:
+Run our interactive demo using the [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
+The demo includes code for:
+1. Image captioning
+2. Open-ended visual question answering
+3. Multimodal / unimodal feature extraction
+4. Image-text matching
+
+Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
+
+A Replicate web demo and Docker image are also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)
+
+### Pre-trained checkpoints:
+Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
+--- | :---: | :---: | :---:
+14M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth">Download</a>| - | -
+129M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth">Download</a>
+
+### Finetuned checkpoints:
+Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
+--- | :---: | :---: | :---:
+Image-Text Retrieval (COCO) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth">Download</a>
+Image-Text Retrieval (Flickr30k) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth">Download</a>
+Image Captioning (COCO) | - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth">Download</a> |
+VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth">Download</a> | -
+NLVR2 | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth">Download</a>| - | -
+
+
+### Image-Text Retrieval:
+1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
+2. To evaluate the finetuned BLIP model on COCO, run:
+<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
+--config ./configs/retrieval_coco.yaml \
+--output_dir output/retrieval_coco \
+--evaluate</pre>
+3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
+<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
+--config ./configs/retrieval_coco.yaml \
+--output_dir output/retrieval_coco </pre>
+
+### Image-Text Captioning:
+1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
+2. To evaluate the finetuned BLIP model on COCO, run:
+<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate</pre>
+3. To evaluate the finetuned BLIP model on NoCaps, generate results with (evaluation needs to be performed on the official server):
+<pre>python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py </pre>
+4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
+<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py </pre>
+
+### VQA:
+1. Download the VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
+2. To evaluate the finetuned BLIP model, generate results with (evaluation needs to be performed on the official server):
+<pre>python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate</pre>
+3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
+<pre>python -m torch.distributed.run --nproc_per_node=16 train_vqa.py </pre>
+
+### NLVR2:
+1. Download the NLVR2 dataset from the original website, and set 'image_root' in configs/nlvr.yaml.
+2. To evaluate the finetuned BLIP model, run:
+<pre>python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate</pre>
+3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
+<pre>python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py </pre>
+
+### Finetune with ViT-L:
+In order to finetune a model with ViT-L, simply change the config file to set 'vit' as large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). <a href="https://github.com/facebookresearch/fairscale">Gradient checkpointing</a> can also be activated in the config file to reduce GPU memory usage.
+
+### Pre-train:
+1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
+2. In configs/pretrain.yaml, set 'train_file' as the paths of the json files.
+3. Pre-train the model using 8 A100 GPUs:
+<pre>python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain </pre>
+
+### Zero-shot video-text retrieval:
+1. Download the MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
+2. Install [decord](https://github.com/dmlc/decord) with <pre>pip install decord</pre>
+3. To perform zero-shot evaluation, run:
+<pre>python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py</pre>
+
+### Pre-training datasets download:
+We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.
+
+Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
+--- | :---: | :---: | :---:
+CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
+LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
+
+### Citation
+If you find this code to be useful for your research, please consider citing.
+<pre>
+@inproceedings{li2022blip,
+      title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
+      author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
+      year={2022},
+      booktitle={ICML},
+}</pre>
+
+### Acknowledgement
+The implementation of BLIP relies on resources from <a href="https://github.com/salesforce/ALBEF">ALBEF</a>, <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for their open-sourcing.
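Note on the Pre-train section of the README above: it describes the annotation format only in prose. As a minimal sketch of producing a training json of that shape (the file paths and the output name pretrain_ann.json are illustrative, not part of this upload):

    # build_pretrain_json.py -- illustrative helper, not part of the BLIP files above
    import json

    # each entry pairs an image path with its caption, matching the README's format
    samples = [
        {'image': '/data/images/0001.jpg', 'caption': 'a dog playing in the park'},
        {'image': '/data/images/0002.jpg', 'caption': 'a plate of food on a table'},
    ]

    with open('pretrain_ann.json', 'w') as f:
        json.dump(samples, f)  # a list of {'image': ..., 'caption': ...} dicts

The resulting path(s) would then be listed under 'train_file' in configs/pretrain.yaml, per step 2 of that section.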
BLIP_SECURITY.md
ADDED
@@ -0,0 +1,7 @@
+## Security
+
+Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
+as soon as it is discovered. This library limits its runtime dependencies in
+order to reduce the total cost of ownership as much as can be, but all consumers
+should remain vigilant and have their security stakeholders review all third-party
+products (3PP) like this one and their dependencies.
BLIP_cog.yaml.txt
ADDED
@@ -0,0 +1,17 @@
+build:
+  gpu: true
+  cuda: "11.1"
+  python_version: "3.8"
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_packages:
+    - "ipython==7.30.1"
+    - "torchvision==0.11.1"
+    - "torch==1.10.0"
+    - "timm==0.4.12"
+    - "transformers==4.15.0"
+    - "fairscale==0.4.4"
+    - "pycocoevalcap==1.2"
+
+predict: "predict.py:Predictor"
BLIP_demo.ipynb.txt
ADDED
The diff for this file is too large to render.
See raw diff
BLIP_eval_nocaps.py
ADDED
@@ -0,0 +1,118 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip import blip_decoder
+import utils
+from data import create_dataset, create_sampler, create_loader
+from data.utils import save_result
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, config):
+    # evaluate
+    model.eval()
+
+    metric_logger = utils.MetricLogger(delimiter=" ")
+    header = 'Evaluation:'
+    print_freq = 10
+
+    result = []
+    for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
+
+        image = image.to(device)
+
+        captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
+                                  min_length=config['min_length'], repetition_penalty=1.1)
+
+        for caption, img_id in zip(captions, image_id):
+            result.append({"image_id": img_id.item(), "caption": caption})
+
+    return result
+
+
+def main(args, config):
+    utils.init_distributed_mode(args)
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    cudnn.benchmark = True
+
+    #### Dataset ####
+    print("Creating captioning dataset")
+    val_dataset, test_dataset = create_dataset('nocaps', config)
+
+    if args.distributed:
+        num_tasks = utils.get_world_size()
+        global_rank = utils.get_rank()
+        samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
+    else:
+        samplers = [None,None]
+
+    val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
+                                            batch_size=[config['batch_size']]*2,num_workers=[4,4],
+                                            is_trains=[False, False], collate_fns=[None,None])
+
+    #### Model ####
+    print("Creating model")
+    model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+                         prompt=config['prompt'])
+
+    model = model.to(device)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    val_result = evaluate(model_without_ddp, val_loader, device, config)
+    val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
+    test_result = evaluate(model_without_ddp, test_loader, device, config)
+    test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', default='./configs/nocaps.yaml')
+    parser.add_argument('--output_dir', default='output/NoCaps')
+    parser.add_argument('--device', default='cuda')
+    parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+    parser.add_argument('--distributed', default=True, type=bool)
+    args = parser.parse_args()
+
+    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+    args.result_dir = os.path.join(args.output_dir, 'result')
+
+    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+    main(args, config)
BLIP_eval_retrieval_video.py
ADDED
@@ -0,0 +1,250 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_retrieval import blip_retrieval
+import utils
+from data.video_dataset import VideoDataset
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, tokenizer, device, config):
+    # test
+    model.eval()
+
+    metric_logger = utils.MetricLogger(delimiter=" ")
+    header = 'Evaluation:'
+
+    print('Computing features for evaluation...')
+    start_time = time.time()
+
+    texts = data_loader.dataset.text
+    num_text = len(texts)
+    text_bs = 256
+    text_ids = []
+    text_embeds = []
+    text_atts = []
+    for i in range(0, num_text, text_bs):
+        text = texts[i: min(num_text, i+text_bs)]
+        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
+        text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+        text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
+        text_embeds.append(text_embed)
+        text_ids.append(text_input.input_ids)
+        text_atts.append(text_input.attention_mask)
+
+    text_embeds = torch.cat(text_embeds,dim=0)
+    text_ids = torch.cat(text_ids,dim=0)
+    text_atts = torch.cat(text_atts,dim=0)
+    text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
+
+    video_feats = []
+    video_embeds = []
+    for video, video_id in data_loader:
+
+        B,N,C,W,H = video.size()
+        video = video.view(-1,C,W,H)
+        video = video.to(device,non_blocking=True)
+        video_feat = model.visual_encoder(video)
+        video_embed = model.vision_proj(video_feat[:,0,:])
+        video_embed = video_embed.view(B,N,-1).mean(dim=1)
+        video_embed = F.normalize(video_embed,dim=-1)
+
+        video_feat = video_feat.view(B,-1,video_feat.shape[-1])
+        video_feats.append(video_feat.cpu())
+        video_embeds.append(video_embed)
+
+    video_feats = torch.cat(video_feats,dim=0)
+    video_embeds = torch.cat(video_embeds,dim=0)
+
+    sims_matrix = video_embeds @ text_embeds.t()
+    score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
+
+    num_tasks = utils.get_world_size()
+    rank = utils.get_rank()
+    step = sims_matrix.size(0)//num_tasks + 1
+    start = rank*step
+    end = min(sims_matrix.size(0),start+step)
+
+    for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+
+        encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
+        encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
+        output = model.text_encoder(text_ids[topk_idx],
+                                    attention_mask = text_atts[topk_idx],
+                                    encoder_hidden_states = encoder_output,
+                                    encoder_attention_mask = encoder_att,
+                                    return_dict = True,
+                                    )
+        score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+        score_matrix_v2t[start+i,topk_idx] = score + topk_sim
+
+    sims_matrix = sims_matrix.t()
+    score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
+
+    step = sims_matrix.size(0)//num_tasks + 1
+    start = rank*step
+    end = min(sims_matrix.size(0),start+step)
+
+    for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+
+        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+        encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
+        encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
+        output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
+                                    attention_mask = text_atts[start+i].repeat(config['k_test'],1),
+                                    encoder_hidden_states = encoder_output,
+                                    encoder_attention_mask = encoder_att,
+                                    return_dict = True,
+                                    )
+        score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+        score_matrix_t2v[start+i,topk_idx] = score + topk_sim
+
+    if args.distributed:
+        dist.barrier()
+        torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
+        torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Evaluation time {}'.format(total_time_str))
+
+    return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
+
+
+
+@torch.no_grad()
+def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
+
+    #Video->Text
+    ranks = np.zeros(scores_v2t.shape[0])
+    for index,score in enumerate(scores_v2t):
+        inds = np.argsort(score)[::-1]
+        ranks[index] = np.where(inds == vid2txt[index])[0][0]
+
+    # Compute metrics
+    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+    #Text->Video
+    ranks = np.zeros(scores_t2v.shape[0])
+
+    for index,score in enumerate(scores_t2v):
+        inds = np.argsort(score)[::-1]
+        ranks[index] = np.where(inds == txt2vmg[index])[0][0]
+
+    mdR = np.median(ranks+1)
+
+    # Compute metrics
+    vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+    vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+    vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+    tr_mean = (tr1 + tr5 + tr10) / 3
+    vr_mean = (vr1 + vr5 + vr10) / 3
+    r_mean = (tr_mean + vr_mean) / 2
+
+    eval_result = {'txt_r1': tr1,
+                   'txt_r5': tr5,
+                   'txt_r10': tr10,
+                   'txt_r_mean': tr_mean,
+                   'vid_r1': vr1,
+                   'vid_r5': vr5,
+                   'vid_r10': vr10,
+                   'vid_r_mean': vr_mean,
+                   'vid_mdR': mdR,
+                   'r_mean': r_mean}
+    return eval_result
+
+
+
+
+def main(args, config):
+    utils.init_distributed_mode(args)
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    cudnn.benchmark = True
+
+    #### Dataset ####
+    print("Creating retrieval dataset")
+    test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
+                                max_img_size=config['image_size'], frm_sampling_strategy='uniform')
+
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=config['batch_size'],
+        num_workers=4,
+        pin_memory=True,
+        drop_last=False,
+        shuffle=False,
+    )
+
+    #### Model ####
+    print("Creating model")
+    model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
+
+    model = model.to(device)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    score_v2t, score_t2v, = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
+
+    if utils.is_main_process():
+
+        test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
+        print(test_result)
+
+        log_stats = {**{f'{k}': v for k, v in test_result.items()},}
+        with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
+            f.write(json.dumps(log_stats) + "\n")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
+    parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
+    parser.add_argument('--device', default='cuda')
+    parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+    parser.add_argument('--distributed', default=True, type=bool)
+    args = parser.parse_args()
+
+    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+    main(args, config)
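Note on BLIP_eval_retrieval_video.py above: the retrieval metrics in itm_eval reduce to recall@K and median rank over a score matrix. A minimal standalone sketch of that computation on toy data (independent of the BLIP code, for orientation only):

    import numpy as np

    def recall_at_k(scores, gt_index, ks=(1, 5, 10)):
        # scores: (num_queries, num_candidates); gt_index[i] is the correct candidate for query i
        ranks = np.array([np.where(np.argsort(s)[::-1] == gt_index[i])[0][0]
                          for i, s in enumerate(scores)])
        recalls = {k: 100.0 * np.mean(ranks < k) for k in ks}
        return recalls, np.median(ranks + 1)  # median rank, 1-indexed as in itm_eval

    toy_scores = np.random.rand(4, 4)            # toy score matrix
    print(recall_at_k(toy_scores, np.arange(4)))  # recall@1/5/10 and median rank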
BLIP_predict.py
ADDED
@@ -0,0 +1,98 @@
+"""
+Download the weights in ./checkpoints beforehand for fast inference
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
+"""
+
+from pathlib import Path
+
+from PIL import Image
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import cog
+
+from models.blip import blip_decoder
+from models.blip_vqa import blip_vqa
+from models.blip_itm import blip_itm
+
+
+class Predictor(cog.Predictor):
+    def setup(self):
+        self.device = "cuda:0"
+
+        self.models = {
+            'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth',
+                                             image_size=384, vit='base'),
+            'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth',
+                                                  image_size=480, vit='base'),
+            'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
+                                            image_size=384, vit='base')
+        }
+
+    @cog.input(
+        "image",
+        type=Path,
+        help="input image",
+    )
+    @cog.input(
+        "task",
+        type=str,
+        default='image_captioning',
+        options=['image_captioning', 'visual_question_answering', 'image_text_matching'],
+        help="Choose a task.",
+    )
+    @cog.input(
+        "question",
+        type=str,
+        default=None,
+        help="Type question for the input image for visual question answering task.",
+    )
+    @cog.input(
+        "caption",
+        type=str,
+        default=None,
+        help="Type caption for the input image for image text matching task.",
+    )
+    def predict(self, image, task, question, caption):
+        if task == 'visual_question_answering':
+            assert question is not None, 'Please type a question for visual question answering task.'
+        if task == 'image_text_matching':
+            assert caption is not None, 'Please type a caption for image text matching task.'
+
+        im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device)
+        model = self.models[task]
+        model.eval()
+        model = model.to(self.device)
+
+        if task == 'image_captioning':
+            with torch.no_grad():
+                caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5)
+                return 'Caption: ' + caption[0]
+
+        if task == 'visual_question_answering':
+            with torch.no_grad():
+                answer = model(im, question, train=False, inference='generate')
+                return 'Answer: ' + answer[0]
+
+        # image_text_matching
+        itm_output = model(im, caption, match_head='itm')
+        itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
+        itc_score = model(im, caption, match_head='itc')
+        return f'The image and text is matched with a probability of {itm_score.item():.4f}.\n' \
+               f'The image feature and text feature has a cosine similarity of {itc_score.item():.4f}.'
+
+
+def load_image(image, image_size, device):
+    raw_image = Image.open(str(image)).convert('RGB')
+
+    w, h = raw_image.size
+
+    transform = transforms.Compose([
+        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+    ])
+    image = transform(raw_image).unsqueeze(0).to(device)
+    return image
BLIP_pretrain.py
ADDED
@@ -0,0 +1,173 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_pretrain import blip_pretrain
+import utils
+from utils import warmup_lr_schedule, step_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+
+def train(model, data_loader, optimizer, epoch, device, config):
+    # train
+    model.train()
+
+    metric_logger = utils.MetricLogger(delimiter=" ")
+    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
+    metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+    metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+    metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+
+    header = 'Train Epoch: [{}]'.format(epoch)
+    print_freq = 50
+
+    if config['laion_path']:
+        data_loader.dataset.reload_laion(epoch)
+
+    data_loader.sampler.set_epoch(epoch)
+
+    for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+        if epoch==0:
+            warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
+
+        optimizer.zero_grad()
+
+        image = image.to(device,non_blocking=True)
+
+        # ramp up alpha in the first 2 epochs
+        alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader)))
+
+        loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha)
+        loss = loss_ita + loss_itm + loss_lm
+
+        loss.backward()
+        optimizer.step()
+
+        metric_logger.update(loss_ita=loss_ita.item())
+        metric_logger.update(loss_itm=loss_itm.item())
+        metric_logger.update(loss_lm=loss_lm.item())
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger.global_avg())
+    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+def main(args, config):
+    utils.init_distributed_mode(args)
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    cudnn.benchmark = True
+
+    #### Dataset ####
+    print("Creating dataset")
+    datasets = [create_dataset('pretrain', config, min_scale=0.2)]
+    print('number of training samples: %d'%len(datasets[0]))
+
+    num_tasks = utils.get_world_size()
+    global_rank = utils.get_rank()
+    samplers = create_sampler(datasets, [True], num_tasks, global_rank)
+
+    data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
+
+    #### Model ####
+    print("Creating model")
+    model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
+                          vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'])
+
+    model = model.to(device)
+
+    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+    start_epoch = 0
+    if args.checkpoint:
+        checkpoint = torch.load(args.checkpoint, map_location='cpu')
+        state_dict = checkpoint['model']
+        model.load_state_dict(state_dict)
+
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        start_epoch = checkpoint['epoch']+1
+        print('resume checkpoint from %s'%args.checkpoint)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    print("Start training")
+    start_time = time.time()
+    for epoch in range(start_epoch, config['max_epoch']):
+
+        step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate'])
+
+        train_stats = train(model, data_loader, optimizer, epoch, device, config)
+        if utils.is_main_process():
+            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                         'epoch': epoch,
+                        }
+            save_obj = {
+                'model': model_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'config': config,
+                'epoch': epoch,
+            }
+            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
+
+            with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+
+        dist.barrier()
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', default='./configs/pretrain.yaml')
+    parser.add_argument('--output_dir', default='output/Pretrain')
+    parser.add_argument('--checkpoint', default='')
+    parser.add_argument('--evaluate', action='store_true')
+    parser.add_argument('--device', default='cuda')
+    parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+    parser.add_argument('--distributed', default=True, type=bool)
+    args = parser.parse_args()
+
+    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+    main(args, config)
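Note on BLIP_pretrain.py above: the line commented "ramp up alpha in the first 2 epochs" scales the weight config['alpha'] linearly from 0 to its configured value over the first two epochs of training. A tiny illustrative sketch of just that schedule (alpha_max and steps_per_epoch are made-up numbers, not values from the repo config):

    alpha_max = 0.4          # illustrative stand-in for config['alpha']
    steps_per_epoch = 1000   # illustrative stand-in for len(data_loader)
    for epoch in (0, 1, 2):
        for i in (0, 500, 999):
            alpha = alpha_max * min(1, (epoch * steps_per_epoch + i) / (2 * steps_per_epoch))
            print(epoch, i, round(alpha, 3))
    # alpha grows from 0.0 toward alpha_max across epochs 0-1 and saturates at alpha_max afterwards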
BLIP_requirements.txt
ADDED
@@ -0,0 +1,4 @@
+timm==0.4.12
+transformers==4.15.0
+fairscale==0.4.4
+pycocoevalcap
BLIP_train_caption.py
ADDED
@@ -0,0 +1,206 @@
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from models.blip import blip_decoder
|
26 |
+
import utils
|
27 |
+
from utils import cosine_lr_schedule
|
28 |
+
from data import create_dataset, create_sampler, create_loader
|
29 |
+
from data.utils import save_result, coco_caption_eval
|
30 |
+
|
31 |
+
def train(model, data_loader, optimizer, epoch, device):
|
32 |
+
# train
|
33 |
+
model.train()
|
34 |
+
|
35 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
36 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
37 |
+
metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
|
38 |
+
header = 'Train Caption Epoch: [{}]'.format(epoch)
|
39 |
+
print_freq = 50
|
40 |
+
|
41 |
+
for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
42 |
+
image = image.to(device)
|
43 |
+
|
44 |
+
loss = model(image, caption)
|
45 |
+
|
46 |
+
optimizer.zero_grad()
|
47 |
+
loss.backward()
|
48 |
+
optimizer.step()
|
49 |
+
|
50 |
+
metric_logger.update(loss=loss.item())
|
51 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
52 |
+
|
53 |
+
# gather the stats from all processes
|
54 |
+
metric_logger.synchronize_between_processes()
|
55 |
+
print("Averaged stats:", metric_logger.global_avg())
|
56 |
+
return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
57 |
+
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
def evaluate(model, data_loader, device, config):
|
61 |
+
# evaluate
|
62 |
+
model.eval()
|
63 |
+
|
64 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
65 |
+
header = 'Caption generation:'
|
66 |
+
print_freq = 10
|
67 |
+
|
    result = []

    for image, image_id in metric_logger.log_every(data_loader, print_freq, header):

        image = image.to(device)

        captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
                                  min_length=config['min_length'])

        for caption, img_id in zip(captions, image_id):
            result.append({"image_id": img_id.item(), "caption": caption})

    return result


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating captioning dataset")
    train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([train_dataset, val_dataset, test_dataset], [True, False, False], num_tasks, global_rank)
    else:
        samplers = [None, None, None]

    train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], samplers,
                                                          batch_size=[config['batch_size']]*3, num_workers=[4,4,4],
                                                          is_trains=[True, False, False], collate_fns=[None,None,None])

    #### Model ####
    print("Creating model")
    model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                         vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
                         prompt=config['prompt'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    best = 0
    best_epoch = 0

    print("Start training")
    start_time = time.time()
    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device)

        val_result = evaluate(model_without_ddp, val_loader, device, config)
        val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')

        test_result = evaluate(model_without_ddp, test_loader, device, config)
        test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')

        if utils.is_main_process():
            coco_val = coco_caption_eval(config['coco_gt_root'], val_result_file, 'val')
            coco_test = coco_caption_eval(config['coco_gt_root'], test_result_file, 'test')

            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
                             **{f'test_{k}': v for k, v in coco_test.eval.items()},
                            }
                with open(os.path.join(args.output_dir, "evaluate.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
            else:
                save_obj = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': config,
                    'epoch': epoch,
                }

                if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
                    best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
                    best_epoch = epoch
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))

                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in coco_val.eval.items()},
                             **{f'test_{k}': v for k, v in coco_test.eval.items()},
                             'epoch': epoch,
                             'best_epoch': best_epoch,
                            }
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        if args.evaluate:
            break
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/caption_coco.yaml')
    parser.add_argument('--output_dir', default='output/Caption_coco')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
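Note (illustrative aside, not part of the uploaded file): every entry point in this upload declares `--distributed` with `type=bool`, and argparse applies `bool()` to the raw string, so `--distributed False` still evaluates to True; in practice the flag is only turned off when `init_distributed_mode()` finds no RANK/WORLD_SIZE environment. A minimal demonstration, plus one conventional workaround (the workaround is an assumption for illustration, not what the script above does):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--distributed', default=True, type=bool)
# bool('False') is truthy, so this prints True even though the user asked for False.
print(parser.parse_args(['--distributed', 'False']).distributed)

# A common alternative: paired store_true/store_false flags sharing one destination.
safe = argparse.ArgumentParser()
safe.add_argument('--distributed', dest='distributed', action='store_true')
safe.add_argument('--no-distributed', dest='distributed', action='store_false')
safe.set_defaults(distributed=True)
print(safe.parse_args(['--no-distributed']).distributed)   # False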
BLIP_train_nlvr.py
ADDED
@@ -0,0 +1,213 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import json
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from models.blip_nlvr import blip_nlvr

import utils
from utils import cosine_lr_schedule, warmup_lr_schedule
from data import create_dataset, create_sampler, create_loader

def train(model, data_loader, optimizer, epoch, device, config):
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50
    step_size = 10

    for i, (image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        images = torch.cat([image0, image1], dim=0)
        images, targets = images.to(device), targets.to(device)

        loss = model(images, text, targets=targets, train=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(loss=loss.item())

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(model, data_loader, device, config):
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")

    header = 'Evaluation:'
    print_freq = 50

    for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = torch.cat([image0, image1], dim=0)
        images, targets = images.to(device), targets.to(device)

        prediction = model(images, text, targets=targets, train=False)

        _, pred_class = prediction.max(1)
        accuracy = (targets == pred_class).sum() / targets.size(0)

        metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating dataset")
    datasets = create_dataset('nlvr', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True, False, False], num_tasks, global_rank)
    else:
        samplers = [None, None, None]

    batch_size = [config['batch_size_train'], config['batch_size_test'], config['batch_size_test']]
    train_loader, val_loader, test_loader = create_loader(datasets, samplers, batch_size=batch_size,
                                                          num_workers=[4,4,4], is_trains=[True, False, False],
                                                          collate_fns=[None,None,None])

    #### Model ####
    print("Creating model")
    model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
                      vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    print("Start training")
    start_time = time.time()
    best = 0
    best_epoch = 0

    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device, config)

        val_stats = evaluate(model, val_loader, device, config)
        test_stats = evaluate(model, test_loader, device, config)

        if utils.is_main_process():
            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                            }
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

            else:
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in val_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                             'epoch': epoch,
                            }

                if float(val_stats['acc']) > best:
                    save_obj = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'config': config,
                        'epoch': epoch,
                    }
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                    best = float(val_stats['acc'])
                    best_epoch = epoch

                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
        if args.evaluate:
            break

        dist.barrier()

    if utils.is_main_process():
        with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
            f.write("best epoch: %d"%best_epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nlvr.yaml')
    parser.add_argument('--output_dir', default='output/NLVR')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
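Reload sketch (illustrative, not part of the uploaded file): the script above saves its best validation checkpoint as a dict with 'model', 'optimizer', 'config' and 'epoch' keys. Assuming models.blip_nlvr is importable and the stored config carries the same keys used in main(), the weights could be brought back for offline evaluation roughly like this; the output path is an assumption matching the script's default.

import os
import torch
from models.blip_nlvr import blip_nlvr

ckpt = torch.load(os.path.join('output/NLVR', 'checkpoint_best.pth'), map_location='cpu')
cfg = ckpt['config']

# Rebuild the architecture the same way main() does, then load the saved weights.
model = blip_nlvr(pretrained='', image_size=cfg['image_size'], vit=cfg['vit'],
                  vit_grad_ckpt=cfg['vit_grad_ckpt'], vit_ckpt_layer=cfg['vit_ckpt_layer'])
model.load_state_dict(ckpt['model'])
model.eval()
print('loaded checkpoint from epoch', ckpt['epoch'])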
BLIP_train_retrieval.py
ADDED
@@ -0,0 +1,345 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip_retrieval import blip_retrieval
import utils
from utils import cosine_lr_schedule
from data import create_dataset, create_sampler, create_loader


def train(model, data_loader, optimizer, epoch, device, config):
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    for i, (image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image = image.to(device, non_blocking=True)
        idx = idx.to(device, non_blocking=True)

        if epoch > 0:
            alpha = config['alpha']
        else:
            alpha = config['alpha']*min(1, i/len(data_loader))

        loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
        loss = loss_ita + loss_itm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss_itm=loss_itm.item())
        metric_logger.update(loss_ita=loss_ita.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluation(model, data_loader, device, config):
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')
    start_time = time.time()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i+text_bs)]
        text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
        text_output = model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text')
        text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    text_ids[:,0] = model.tokenizer.enc_token_id

    image_feats = []
    image_embeds = []
    for image, img_id in data_loader:
        image = image.to(device)
        image_feat = model.visual_encoder(image)
        image_embed = model.vision_proj(image_feat[:,0,:])
        image_embed = F.normalize(image_embed, dim=-1)

        image_feats.append(image_feat.cpu())
        image_embeds.append(image_embed)

    image_feats = torch.cat(image_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    sims_matrix = image_embeds @ text_embeds.t()
    score_matrix_i2t = torch.full((len(data_loader.dataset.image), len(texts)), -100.0).to(device)

    num_tasks = utils.get_world_size()
    rank = utils.get_rank()
    step = sims_matrix.size(0)//num_tasks + 1
    start = rank*step
    end = min(sims_matrix.size(0), start+step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)

        encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        output = model.text_encoder(text_ids[topk_idx],
                                    attention_mask=text_atts[topk_idx],
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                   )
        score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
        score_matrix_i2t[start+i, topk_idx] = score + topk_sim

    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full((len(texts), len(data_loader.dataset.image)), -100.0).to(device)

    step = sims_matrix.size(0)//num_tasks + 1
    start = rank*step
    end = min(sims_matrix.size(0), start+step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):

        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
        encoder_output = image_feats[topk_idx].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
                                    attention_mask=text_atts[start+i].repeat(config['k_test'],1),
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                   )
        score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
        score_matrix_t2i[start+i, topk_idx] = score + topk_sim

    if args.distributed:
        dist.barrier()
        torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):

    # Images->Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        # Score
        rank = 1e20
        for i in img2txt[index]:
            tmp = np.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    # Text->Images
    ranks = np.zeros(scores_t2i.shape[0])

    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == txt2img[index])[0][0]

    # Compute metrics
    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {'txt_r1': tr1,
                   'txt_r5': tr5,
                   'txt_r10': tr10,
                   'txt_r_mean': tr_mean,
                   'img_r1': ir1,
                   'img_r5': ir5,
                   'img_r10': ir10,
                   'img_r_mean': ir_mean,
                   'r_mean': r_mean}
    return eval_result

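A tiny worked example of the ranking metric above (illustrative only, not part of the uploaded file): with two images and four texts, where texts 0-1 belong to image 0 and texts 2-3 to image 1, itm_eval() reports recall@K as the percentage of queries whose ground-truth match appears in the top K candidates. The snippet assumes itm_eval() as defined above is in scope (e.g. it is run inside this module).

import numpy as np

# Toy similarity scores: rows are queries, columns are candidates.
scores_i2t = np.array([[0.9, 0.2, 0.1, 0.0],     # image 0 ranks its own text 0 first
                       [0.1, 0.8, 0.3, 0.7]])    # image 1 ranks text 1 first, its own text 3 only second
scores_t2i = np.array([[0.9, 0.1], [0.6, 0.4],   # every text ranks its own image first
                       [0.2, 0.8], [0.3, 0.5]])
img2txt = {0: [0, 1], 1: [2, 3]}                 # ground-truth captions per image
txt2img = {0: 0, 1: 0, 2: 1, 3: 1}               # ground-truth image per caption

print(itm_eval(scores_i2t, scores_t2i, txt2img, img2txt))
# txt_r1 = 50.0 (image 1's best matching caption only reaches rank 2), img_r1 = 100.0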
def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating retrieval dataset")
    train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
    else:
        samplers = [None, None, None]

    train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], samplers,
                                                          batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
                                                          num_workers=[4,4,4],
                                                          is_trains=[True, False, False],
                                                          collate_fns=[None,None,None])

    #### Model ####
    print("Creating model")
    model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                           vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
                           queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    best = 0
    best_epoch = 0

    print("Start training")
    start_time = time.time()

    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device, config)

        score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, device, config)
        score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)

        if utils.is_main_process():

            val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
            print(val_result)

            if val_result['r_mean'] > best:
                save_obj = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': config,
                    'epoch': epoch,
                }
                torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                best = val_result['r_mean']
                best_epoch = epoch

                test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
                print(test_result)

            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
                             **{f'test_{k}': v for k, v in test_result.items()},
                            }
                with open(os.path.join(args.output_dir, "evaluate.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
            else:
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in val_result.items()},
                             **{f'test_{k}': v for k, v in test_result.items()},
                             'epoch': epoch,
                             'best_epoch': best_epoch,
                            }
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        if args.evaluate:
            break

        dist.barrier()
        torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
    parser.add_argument('--output_dir', default='output/Retrieval_flickr')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
BLIP_train_vqa.py
ADDED
@@ -0,0 +1,202 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from models.blip_vqa import blip_vqa
import utils
from utils import cosine_lr_schedule
from data import create_dataset, create_sampler, create_loader
from data.vqa_dataset import vqa_collate_fn
from data.utils import save_result


def train(model, data_loader, optimizer, epoch, device):
    # train
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    for i, (image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image, weights = image.to(device, non_blocking=True), weights.to(device, non_blocking=True)

        loss = model(image, question, answer, train=True, n=n, weights=weights)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluation(model, data_loader, device, config):
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Generate VQA test result:'
    print_freq = 50

    result = []

    if config['inference'] == 'rank':
        answer_list = data_loader.dataset.answer_list
        answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
        answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id

    for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image = image.to(device, non_blocking=True)

        if config['inference'] == 'generate':
            answers = model(image, question, train=False, inference='generate')

            for answer, ques_id in zip(answers, question_id):
                ques_id = int(ques_id.item())
                result.append({"question_id": ques_id, "answer": answer})

        elif config['inference'] == 'rank':
            answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])

            for ques_id, answer_id in zip(question_id, answer_ids):
                result.append({"question_id": int(ques_id.item()), "answer": answer_list[answer_id]})

    return result


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating vqa datasets")
    datasets = create_dataset('vqa', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
    else:
        samplers = [None, None]

    train_loader, test_loader = create_loader(datasets, samplers,
                                              batch_size=[config['batch_size_train'], config['batch_size_test']],
                                              num_workers=[4,4], is_trains=[True, False],
                                              collate_fns=[vqa_collate_fn, None])
    #### Model ####
    print("Creating model")
    model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
                     vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    best = 0
    best_epoch = 0

    print("Start training")
    start_time = time.time()
    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device)

        else:
            break

        if utils.is_main_process():
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                        }
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            save_obj = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))

        dist.barrier()

    vqa_result = evaluation(model_without_ddp, test_loader, device, config)
    result_file = save_result(vqa_result, args.result_dir, 'vqa_result')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/vqa.yaml')
    parser.add_argument('--output_dir', default='output/VQA')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
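Consumption sketch (illustrative, not part of the uploaded file): evaluation() above emits a list of {"question_id", "answer"} records and hands it to save_result(), which lives in data/utils.py and is not shown here. Assuming it writes that list as a JSON file named after the 'vqa_result' argument under args.result_dir (an assumption), the predictions could be reloaded like this:

import json
import os

# Hypothetical path: assumes save_result(vqa_result, args.result_dir, 'vqa_result')
# produced a JSON file of [{"question_id": ..., "answer": ...}, ...] records.
result_file = os.path.join('output/VQA/result', 'vqa_result.json')

with open(result_file) as f:
    predictions = json.load(f)

# Index answers by question id for easy lookup or submission formatting.
answer_by_qid = {p['question_id']: p['answer'] for p in predictions}
print(len(answer_by_qid), 'answers loaded')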
BLIP_utils.py
ADDED
@@ -0,0 +1,278 @@
import math
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate"""
    lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warmup the learning rate"""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate"""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

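Note (illustrative, not part of the uploaded file): the schedules above simply overwrite the learning rate of every param group once per epoch (or per step for the warmup). A minimal check of the cosine decay, assuming this file is importable as `utils` the way the training scripts import it:

import torch
from utils import cosine_lr_schedule   # assumption: this module is importable as `utils`

dummy = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([dummy], lr=1e-5)
for epoch in range(5):
    cosine_lr_schedule(opt, epoch, max_epoch=5, init_lr=1e-5, min_lr=0)
    print(epoch, opt.param_groups[0]['lr'])
# epoch 0 -> 1.00e-05, epoch 1 -> ~9.05e-06, ..., epoch 4 -> ~9.55e-07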
24 |
+
from collections import defaultdict, deque
|
25 |
+
import datetime
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.distributed as dist
|
29 |
+
|
30 |
+
class SmoothedValue(object):
|
31 |
+
"""Track a series of values and provide access to smoothed values over a
|
32 |
+
window or the global series average.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, window_size=20, fmt=None):
|
36 |
+
if fmt is None:
|
37 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
38 |
+
self.deque = deque(maxlen=window_size)
|
39 |
+
self.total = 0.0
|
40 |
+
self.count = 0
|
41 |
+
self.fmt = fmt
|
42 |
+
|
43 |
+
def update(self, value, n=1):
|
44 |
+
self.deque.append(value)
|
45 |
+
self.count += n
|
46 |
+
self.total += value * n
|
47 |
+
|
48 |
+
def synchronize_between_processes(self):
|
49 |
+
"""
|
50 |
+
Warning: does not synchronize the deque!
|
51 |
+
"""
|
52 |
+
if not is_dist_avail_and_initialized():
|
53 |
+
return
|
54 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
55 |
+
dist.barrier()
|
56 |
+
dist.all_reduce(t)
|
57 |
+
t = t.tolist()
|
58 |
+
self.count = int(t[0])
|
59 |
+
self.total = t[1]
|
60 |
+
|
61 |
+
@property
|
62 |
+
def median(self):
|
63 |
+
d = torch.tensor(list(self.deque))
|
64 |
+
return d.median().item()
|
65 |
+
|
66 |
+
@property
|
67 |
+
def avg(self):
|
68 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
69 |
+
return d.mean().item()
|
70 |
+
|
71 |
+
@property
|
72 |
+
def global_avg(self):
|
73 |
+
return self.total / self.count
|
74 |
+
|
75 |
+
@property
|
76 |
+
def max(self):
|
77 |
+
return max(self.deque)
|
78 |
+
|
79 |
+
@property
|
80 |
+
def value(self):
|
81 |
+
return self.deque[-1]
|
82 |
+
|
83 |
+
def __str__(self):
|
84 |
+
return self.fmt.format(
|
85 |
+
median=self.median,
|
86 |
+
avg=self.avg,
|
87 |
+
global_avg=self.global_avg,
|
88 |
+
max=self.max,
|
89 |
+
value=self.value)
|
90 |
+
|
91 |
+
|
92 |
+
class MetricLogger(object):
|
93 |
+
def __init__(self, delimiter="\t"):
|
94 |
+
self.meters = defaultdict(SmoothedValue)
|
95 |
+
self.delimiter = delimiter
|
96 |
+
|
97 |
+
def update(self, **kwargs):
|
98 |
+
for k, v in kwargs.items():
|
99 |
+
if isinstance(v, torch.Tensor):
|
100 |
+
v = v.item()
|
101 |
+
assert isinstance(v, (float, int))
|
102 |
+
self.meters[k].update(v)
|
103 |
+
|
104 |
+
def __getattr__(self, attr):
|
105 |
+
if attr in self.meters:
|
106 |
+
return self.meters[attr]
|
107 |
+
if attr in self.__dict__:
|
108 |
+
return self.__dict__[attr]
|
109 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
110 |
+
type(self).__name__, attr))
|
111 |
+
|
112 |
+
def __str__(self):
|
113 |
+
loss_str = []
|
114 |
+
for name, meter in self.meters.items():
|
115 |
+
loss_str.append(
|
116 |
+
"{}: {}".format(name, str(meter))
|
117 |
+
)
|
118 |
+
return self.delimiter.join(loss_str)
|
119 |
+
|
120 |
+
def global_avg(self):
|
121 |
+
loss_str = []
|
122 |
+
for name, meter in self.meters.items():
|
123 |
+
loss_str.append(
|
124 |
+
"{}: {:.4f}".format(name, meter.global_avg)
|
125 |
+
)
|
126 |
+
return self.delimiter.join(loss_str)
|
127 |
+
|
128 |
+
def synchronize_between_processes(self):
|
129 |
+
for meter in self.meters.values():
|
130 |
+
meter.synchronize_between_processes()
|
131 |
+
|
132 |
+
def add_meter(self, name, meter):
|
133 |
+
self.meters[name] = meter
|
134 |
+
|
135 |
+
def log_every(self, iterable, print_freq, header=None):
|
136 |
+
i = 0
|
137 |
+
if not header:
|
138 |
+
header = ''
|
139 |
+
start_time = time.time()
|
140 |
+
end = time.time()
|
141 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
142 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
143 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
144 |
+
log_msg = [
|
145 |
+
header,
|
146 |
+
'[{0' + space_fmt + '}/{1}]',
|
147 |
+
'eta: {eta}',
|
148 |
+
'{meters}',
|
149 |
+
'time: {time}',
|
150 |
+
'data: {data}'
|
151 |
+
]
|
152 |
+
if torch.cuda.is_available():
|
153 |
+
log_msg.append('max mem: {memory:.0f}')
|
154 |
+
log_msg = self.delimiter.join(log_msg)
|
155 |
+
MB = 1024.0 * 1024.0
|
156 |
+
for obj in iterable:
|
157 |
+
data_time.update(time.time() - end)
|
158 |
+
yield obj
|
159 |
+
iter_time.update(time.time() - end)
|
160 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
161 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
162 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
163 |
+
if torch.cuda.is_available():
|
164 |
+
print(log_msg.format(
|
165 |
+
i, len(iterable), eta=eta_string,
|
166 |
+
meters=str(self),
|
167 |
+
time=str(iter_time), data=str(data_time),
|
168 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
169 |
+
else:
|
170 |
+
print(log_msg.format(
|
171 |
+
i, len(iterable), eta=eta_string,
|
172 |
+
meters=str(self),
|
173 |
+
time=str(iter_time), data=str(data_time)))
|
174 |
+
i += 1
|
175 |
+
end = time.time()
|
176 |
+
total_time = time.time() - start_time
|
177 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
178 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
179 |
+
header, total_time_str, total_time / len(iterable)))
|
180 |
+
|
181 |
+
|
182 |
+
class AttrDict(dict):
|
183 |
+
def __init__(self, *args, **kwargs):
|
184 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
185 |
+
self.__dict__ = self
|
186 |
+
|
187 |
+
|
188 |
+
def compute_acc(logits, label, reduction='mean'):
|
189 |
+
ret = (torch.argmax(logits, dim=1) == label).float()
|
190 |
+
if reduction == 'none':
|
191 |
+
return ret.detach()
|
192 |
+
elif reduction == 'mean':
|
193 |
+
return ret.mean().item()
|
194 |
+
|
195 |
+
def compute_n_params(model, return_str=True):
|
196 |
+
tot = 0
|
197 |
+
for p in model.parameters():
|
198 |
+
w = 1
|
199 |
+
for x in p.shape:
|
200 |
+
w *= x
|
201 |
+
tot += w
|
202 |
+
if return_str:
|
203 |
+
if tot >= 1e6:
|
204 |
+
return '{:.1f}M'.format(tot / 1e6)
|
205 |
+
else:
|
206 |
+
return '{:.1f}K'.format(tot / 1e3)
|
207 |
+
else:
|
208 |
+
return tot
|
209 |
+
|
210 |
+
def setup_for_distributed(is_master):
|
211 |
+
"""
|
212 |
+
This function disables printing when not in master process
|
213 |
+
"""
|
214 |
+
import builtins as __builtin__
|
215 |
+
builtin_print = __builtin__.print
|
216 |
+
|
217 |
+
def print(*args, **kwargs):
|
218 |
+
force = kwargs.pop('force', False)
|
219 |
+
if is_master or force:
|
220 |
+
builtin_print(*args, **kwargs)
|
221 |
+
|
222 |
+
__builtin__.print = print
|
223 |
+
|
224 |
+
|
225 |
+
def is_dist_avail_and_initialized():
|
226 |
+
if not dist.is_available():
|
227 |
+
return False
|
228 |
+
if not dist.is_initialized():
|
229 |
+
return False
|
230 |
+
return True
|
231 |
+
|
232 |
+
|
233 |
+
def get_world_size():
|
234 |
+
if not is_dist_avail_and_initialized():
|
235 |
+
return 1
|
236 |
+
return dist.get_world_size()
|
237 |
+
|
238 |
+
|
239 |
+
def get_rank():
|
240 |
+
if not is_dist_avail_and_initialized():
|
241 |
+
return 0
|
242 |
+
return dist.get_rank()
|
243 |
+
|
244 |
+
|
245 |
+
def is_main_process():
|
246 |
+
return get_rank() == 0
|
247 |
+
|
248 |
+
|
249 |
+
def save_on_master(*args, **kwargs):
|
250 |
+
if is_main_process():
|
251 |
+
torch.save(*args, **kwargs)
|
252 |
+
|
253 |
+
|
254 |
+
def init_distributed_mode(args):
|
255 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
256 |
+
args.rank = int(os.environ["RANK"])
|
257 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
258 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
259 |
+
elif 'SLURM_PROCID' in os.environ:
|
260 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
261 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
262 |
+
else:
|
263 |
+
print('Not using distributed mode')
|
264 |
+
args.distributed = False
|
265 |
+
return
|
266 |
+
|
267 |
+
args.distributed = True
|
268 |
+
|
269 |
+
torch.cuda.set_device(args.gpu)
|
270 |
+
args.dist_backend = 'nccl'
|
271 |
+
print('| distributed init (rank {}, word {}): {}'.format(
|
272 |
+
args.rank, args.world_size, args.dist_url), flush=True)
|
273 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
274 |
+
world_size=args.world_size, rank=args.rank)
|
275 |
+
torch.distributed.barrier()
|
276 |
+
setup_for_distributed(args.rank == 0)
|
277 |
+
|
278 |
+
|
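Usage sketch (illustrative, not part of the uploaded file): MetricLogger and SmoothedValue also work without any distributed setup, because synchronize_between_processes() is a no-op when torch.distributed is not initialized. A toy local loop, assuming this module is importable as `utils`:

from utils import MetricLogger, SmoothedValue   # assumption: importable as `utils`

logger = MetricLogger(delimiter="  ")
logger.add_meter('loss', SmoothedValue(window_size=10, fmt='{value:.4f}'))
for step in logger.log_every(range(100), print_freq=20, header='Demo:'):
    # Pretend the loss shrinks as training progresses.
    logger.update(loss=1.0 / (step + 1))
print("Averaged:", logger.global_avg())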