arnavkartikeya commited on
Commit
be02b2a
1 Parent(s): 3575076

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
CODEOWNERS ADDED
@@ -0,0 +1,2 @@
 
 
1
+ # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2
+ #ECCN:Open Source
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Salesforce Open Source Community Code of Conduct
2
+
3
+ ## About the Code of Conduct
4
+
5
+ Equality is a core value at Salesforce. We believe a diverse and inclusive
6
+ community fosters innovation and creativity, and are committed to building a
7
+ culture where everyone feels included.
8
+
9
+ Salesforce open-source projects are committed to providing a friendly, safe, and
10
+ welcoming environment for all, regardless of gender identity and expression,
11
+ sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
12
+ race, age, religion, level of experience, education, socioeconomic status, or
13
+ other similar personal characteristics.
14
+
15
+ The goal of this code of conduct is to specify a baseline standard of behavior so
16
+ that people with different social values and communication styles can work
17
+ together effectively, productively, and respectfully in our open source community.
18
+ It also establishes a mechanism for reporting issues and resolving conflicts.
19
+
20
+ All questions and reports of abusive, harassing, or otherwise unacceptable behavior
21
+ in a Salesforce open-source project may be reported by contacting the Salesforce
22
+ Open Source Conduct Committee at ossconduct@salesforce.com.
23
+
24
+ ## Our Pledge
25
+
26
+ In the interest of fostering an open and welcoming environment, we as
27
+ contributors and maintainers pledge to making participation in our project and
28
+ our community a harassment-free experience for everyone, regardless of gender
29
+ identity and expression, sexual orientation, disability, physical appearance,
30
+ body size, ethnicity, nationality, race, age, religion, level of experience, education,
31
+ socioeconomic status, or other similar personal characteristics.
32
+
33
+ ## Our Standards
34
+
35
+ Examples of behavior that contributes to creating a positive environment
36
+ include:
37
+
38
+ * Using welcoming and inclusive language
39
+ * Being respectful of differing viewpoints and experiences
40
+ * Gracefully accepting constructive criticism
41
+ * Focusing on what is best for the community
42
+ * Showing empathy toward other community members
43
+
44
+ Examples of unacceptable behavior by participants include:
45
+
46
+ * The use of sexualized language or imagery and unwelcome sexual attention or
47
+ advances
48
+ * Personal attacks, insulting/derogatory comments, or trolling
49
+ * Public or private harassment
50
+ * Publishing, or threatening to publish, others' private information—such as
51
+ a physical or electronic address—without explicit permission
52
+ * Other conduct which could reasonably be considered inappropriate in a
53
+ professional setting
54
+ * Advocating for or encouraging any of the above behaviors
55
+
56
+ ## Our Responsibilities
57
+
58
+ Project maintainers are responsible for clarifying the standards of acceptable
59
+ behavior and are expected to take appropriate and fair corrective action in
60
+ response to any instances of unacceptable behavior.
61
+
62
+ Project maintainers have the right and responsibility to remove, edit, or
63
+ reject comments, commits, code, wiki edits, issues, and other contributions
64
+ that are not aligned with this Code of Conduct, or to ban temporarily or
65
+ permanently any contributor for other behaviors that they deem inappropriate,
66
+ threatening, offensive, or harmful.
67
+
68
+ ## Scope
69
+
70
+ This Code of Conduct applies both within project spaces and in public spaces
71
+ when an individual is representing the project or its community. Examples of
72
+ representing a project or community include using an official project email
73
+ address, posting via an official social media account, or acting as an appointed
74
+ representative at an online or offline event. Representation of a project may be
75
+ further defined and clarified by project maintainers.
76
+
77
+ ## Enforcement
78
+
79
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
80
+ reported by contacting the Salesforce Open Source Conduct Committee
81
+ at ossconduct@salesforce.com. All complaints will be reviewed and investigated
82
+ and will result in a response that is deemed necessary and appropriate to the
83
+ circumstances. The committee is obligated to maintain confidentiality with
84
+ regard to the reporter of an incident. Further details of specific enforcement
85
+ policies may be posted separately.
86
+
87
+ Project maintainers who do not follow or enforce the Code of Conduct in good
88
+ faith may face temporary or permanent repercussions as determined by other
89
+ members of the project's leadership and the Salesforce Open Source Conduct
90
+ Committee.
91
+
92
+ ## Attribution
93
+
94
+ This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
95
+ version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
96
+ It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
97
+ [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
98
+
99
+ This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
100
+
101
+ [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
102
+ [golang-coc]: https://golang.org/conduct
103
+ [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
104
+ [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
105
+ [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
LICENSE.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2022, Salesforce.com, Inc.
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5
+
6
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7
+
8
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9
+
10
+ * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11
+
12
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,12 +1,116 @@
1
- ---
2
- title: SCRIPture Final
3
- emoji: 😻
4
- colorFrom: pink
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 3.9.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
2
+
3
+ ## Announcement: BLIP is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications!
4
+
5
+ <img src="BLIP.gif" width="700">
6
+
7
+ This is the PyTorch code of the <a href="https://arxiv.org/abs/2201.12086">BLIP paper</a> [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
8
+ To install the dependencies, run <pre/>pip install -r requirements.txt</pre>
9
+
10
+ Catalog:
11
+ - [x] Inference demo
12
+ - [x] Pre-trained and finetuned checkpoints
13
+ - [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
14
+ - [x] Pre-training code
15
+ - [x] Zero-shot video-text retrieval
16
+ - [x] Download of bootstrapped pre-training datasets
17
+
18
+
19
+ ### Inference demo:
20
+ Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
21
+ The demo includes code for:
22
+ 1. Image captioning
23
+ 2. Open-ended visual question answering
24
+ 3. Multimodal / unimodal feature extraction
25
+ 4. Image-text matching
26
+
27
+ Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
28
+
29
+ Replicate web demo and Docker image is also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)
30
+
31
+ ### Pre-trained checkpoints:
32
+ Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
33
+ --- | :---: | :---: | :---:
34
+ 14M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth">Download</a>| - | -
35
+ 129M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth">Download</a>
36
+
37
+ ### Finetuned checkpoints:
38
+ Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
39
+ --- | :---: | :---: | :---:
40
+ Image-Text Retrieval (COCO) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth">Download</a>
41
+ Image-Text Retrieval (Flickr30k) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth">Download</a>
42
+ Image Captioning (COCO) | - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth">Download</a> |
43
+ VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth">Download</a> | -
44
+ NLVR2 | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth">Download</a>| - | -
45
+
46
+
47
+ ### Image-Text Retrieval:
48
+ 1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
49
+ 2. To evaluate the finetuned BLIP model on COCO, run:
50
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
51
+ --config ./configs/retrieval_coco.yaml \
52
+ --output_dir output/retrieval_coco \
53
+ --evaluate</pre>
54
+ 3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
55
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
56
+ --config ./configs/retrieval_coco.yaml \
57
+ --output_dir output/retrieval_coco </pre>
58
+
59
+ ### Image-Text Captioning:
60
+ 1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
61
+ 2. To evaluate the finetuned BLIP model on COCO, run:
62
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate</pre>
63
+ 3. To evaluate the finetuned BLIP model on NoCaps, generate results with: (evaluation needs to be performed on official server)
64
+ <pre>python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py </pre>
65
+ 4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
66
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py </pre>
67
+
68
+ ### VQA:
69
+ 1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
70
+ 2. To evaluate the finetuned BLIP model, generate results with: (evaluation needs to be performed on official server)
71
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate</pre>
72
+ 3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
73
+ <pre>python -m torch.distributed.run --nproc_per_node=16 train_vqa.py </pre>
74
+
75
+ ### NLVR2:
76
+ 1. Download NLVR2 dataset from the original websites, and set 'image_root' in configs/nlvr.yaml.
77
+ 2. To evaluate the finetuned BLIP model, run
78
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate</pre>
79
+ 3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
80
+ <pre>python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py </pre>
81
+
82
+ ### Finetune with ViT-L:
83
+ In order to finetune a model with ViT-L, simply change the config file to set 'vit' as large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). <a href="https://github.com/facebookresearch/fairscale">Gradient checkpoint</a> can also be activated in the config file to reduce GPU memory usage.
84
+
85
+ ### Pre-train:
86
+ 1. Prepare training json files where each json file contains a list. Each item in the list is a dictonary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
87
+ 2. In configs/pretrain.yaml, set 'train_file' as the paths for the json files .
88
+ 3. Pre-train the model using 8 A100 GPUs:
89
+ <pre>python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain </pre>
90
+
91
+ ### Zero-shot video-text retrieval:
92
+ 1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
93
+ 2. Install [decord](https://github.com/dmlc/decord) with <pre>pip install decord</pre>
94
+ 3. To perform zero-shot evaluation, run
95
+ <pre>python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py</pre>
96
+
97
+ ### Pre-training datasets download:
98
+ We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictonary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.
99
+
100
+ Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
101
+ --- | :---: | :---: | :---:
102
+ CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
103
+ LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
104
+
105
+ ### Citation
106
+ If you find this code to be useful for your research, please consider citing.
107
+ <pre>
108
+ @inproceedings{li2022blip,
109
+ title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
110
+ author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
111
+ year={2022},
112
+ booktitle={ICML},
113
+ }</pre>
114
+
115
+ ### Acknowledgement
116
+ The implementation of BLIP relies on resources from <a href="https://github.com/salesforce/ALBEF">ALBEF</a>, <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for their open-sourcing.
SECURITY.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ ## Security
2
+
3
+ Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4
+ as soon as it is discovered. This library limits its runtime dependencies in
5
+ order to reduce the total cost of ownership as much as can be, but all consumers
6
+ should remain vigilant and have their security stakeholders review all third-party
7
+ products (3PP) like this one and their dependencies.
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import requests
3
+ import torch
4
+ from torchvision import transforms
5
+ import os
6
+ from torchvision.transforms.functional import InterpolationMode
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.image as mpimg
9
+ import cohere
10
+ import gradio as gr
11
+ import string
12
+
13
+ def cap(t):
14
+ indices = []
15
+ tem = ""
16
+ for j in range(len(t)):
17
+ if t[j] == "." or t[j] == "!" or t[j] == "?":
18
+ if j+2 < len(t):
19
+ indices.append(j+2)
20
+ for j in range(len(t)):
21
+ if j in indices:
22
+ tem += t[j].upper()
23
+ else:
24
+ tem += t[j]
25
+ return tem
26
+ def processing(s):
27
+ #create a string[] that holds every sentence
28
+ arr = []
29
+ temp = ""
30
+ fin = ""
31
+ for i in range(len(s)):
32
+ temp += s[i]
33
+ if s[i] == "\n":
34
+ arr.append(temp)
35
+ temp = ""
36
+ if i == len(s)-1:
37
+ arr.append(temp)
38
+ for i in arr:
39
+ t = i
40
+ t = t.strip()
41
+ temp = ""
42
+ #make the first element of the string be the first alpha character
43
+ ind = 0
44
+ for j in range(len(t)):
45
+ if t[j].isalpha():
46
+ ind = j
47
+ break
48
+ t = t[ind:]
49
+ t = t.capitalize()
50
+ # capitalize all words after punctuation
51
+ t = cap(t)
52
+ #remove some punctuation
53
+ t = t.replace("(", "")
54
+ t = t.replace(")", "")
55
+ t = t.replace("&", "")
56
+ t = t.replace("#", "")
57
+ t = t.replace("_", "")
58
+
59
+ #remove punctuation if it is not following an alpha character
60
+ temp = ""
61
+ for j in range(len(t)):
62
+ if t[j] in string.punctuation:
63
+ if t[j-1] not in string.punctuation:
64
+ temp += t[j]
65
+ else:
66
+ temp += t[j]
67
+ fin += temp + "\n"
68
+ #find the last punctuation in fin and return everything before that
69
+ ind = 0
70
+ for i in range(len(fin)):
71
+ if fin[i] == "." or fin[i] == "?" or fin[i] == "!":
72
+ ind = i
73
+ if(ind != 0 and ind != len(fin) - 1):
74
+ return fin[:ind+1]
75
+ else:
76
+ return fin
77
+
78
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
+
80
+ from models.blip import blip_decoder
81
+
82
+ image_size = 384
83
+ transform = transforms.Compose([
84
+ transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
85
+ transforms.ToTensor(),
86
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
87
+ ])
88
+
89
+ model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
90
+
91
+ model = blip_decoder(pretrained=model_url, image_size=384, vit='large')
92
+ model.eval()
93
+ model = model.to(device)
94
+
95
+
96
+ from models.blip_vqa import blip_vqa
97
+
98
+ image_size_vq = 480
99
+ transform_vq = transforms.Compose([
100
+ transforms.Resize((image_size_vq,image_size_vq),interpolation=InterpolationMode.BICUBIC),
101
+ transforms.ToTensor(),
102
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
103
+ ])
104
+
105
+ model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
106
+
107
+ model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
108
+ model_vq.eval()
109
+ model_vq = model_vq.to(device)
110
+
111
+
112
+
113
+ def inference(raw_image, model_n, question="", strategy=""):
114
+ if model_n == 'Image Captioning':
115
+ image = transform(raw_image).unsqueeze(0).to(device)
116
+ with torch.no_grad():
117
+ if strategy == "Beam search":
118
+ caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
119
+ else:
120
+ caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
121
+ return 'caption: '+caption[0]
122
+
123
+ else:
124
+ image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
125
+ with torch.no_grad():
126
+ answer = model_vq(image_vq, question, train=False, inference='generate')
127
+ return 'answer: '+answer[0]
128
+
129
+ #get caption for a single iamge
130
+ def get_caption(image_path):
131
+ img = Image.open(image_path)
132
+ return inference(img, "Image Captioning")[9:]
133
+
134
+ def display(image_path):
135
+ img = mpimg.imread(image_path)
136
+ img = Image.open(image_path)
137
+ plt.imshow(img)
138
+ print("Caption: " + get_caption(image_path))
139
+
140
+ #returns a dictionary with key -> img_path and value -> caption
141
+ def get_captions(img_directory, print_status=True):
142
+ #key is img path, value is the caption
143
+ captions = {}
144
+ length = 0
145
+ for file in os.listdir(img_directory):
146
+ length+=1
147
+ count = 0
148
+ for file in os.listdir(img_directory):
149
+ f = os.path.join(img_directory, file)
150
+ captions[f] = inference(Image.open(f), "Image Captioning")
151
+ if print_status:
152
+ print("Images complete:", str(count) + "/" + str(length))
153
+ print("Caption:", captions[f])
154
+ return captions
155
+ #writes dictionary to file, key and value seperated by ':'
156
+ def write_to_file(filename, caption_dict):
157
+ with open(filename, "w") as file:
158
+ for i in caption_dict:
159
+ file.write(i + ":" + caption_dict[i])
160
+ file.close()
161
+
162
+ # Text to Image API
163
+
164
+ import requests
165
+ import base64
166
+
167
+
168
+ #add max tokens a slider
169
+
170
+ def make_image_and_story(prompt):
171
+ if(prompt is None or prompt == ""):
172
+ host = 'https://dev.paint.cohere.ai/txt2img'
173
+ response = requests.post(host, json={'prompt': 'Random monster', 'n_samples' : 1, 'n_iter' : 1})
174
+
175
+ # decode image
176
+ imageBytes = base64.b64decode(response.json()['image']) #decode
177
+
178
+ # save to disk
179
+ f = open("sample.png", "wb")
180
+ f.write(imageBytes)
181
+ f.close()
182
+
183
+ caption = get_caption("sample.png")
184
+
185
+ co = cohere.Client('SD5vY3pwFrA0bBNTnIpp4N02sWhK4vd7mkkcrpXS')
186
+ response = co.generate(prompt=caption, model ='aeb523c3-a79c-48ba-9274-a12ac07492a2-ft', max_tokens=80)
187
+
188
+ return Image.open("sample.png"), processing(response.generations[0].text)
189
+ else:
190
+ host = 'https://dev.paint.cohere.ai/txt2img'
191
+ response = requests.post(host, json={'prompt': prompt+", epic", 'n_samples' : 1, 'n_iter' : 1})
192
+
193
+ # decode image
194
+ imageBytes = base64.b64decode(response.json()['image']) #decode
195
+
196
+ # save to disk
197
+ f = open("sample.png", "wb")
198
+ f.write(imageBytes)
199
+ f.close()
200
+
201
+ caption = get_caption("sample.png")
202
+ caption += " " + prompt
203
+
204
+ co = cohere.Client('SD5vY3pwFrA0bBNTnIpp4N02sWhK4vd7mkkcrpXS')
205
+ response = co.generate(prompt=caption, model ='aeb523c3-a79c-48ba-9274-a12ac07492a2-ft', max_tokens=80)
206
+
207
+ return Image.open("sample.png"), processing(response.generations[0].text)
208
+
209
+
210
+ gr.Interface(fn=make_image_and_story, inputs="text", outputs=["image","text"],title='Fantasy Creature Generator').launch();
cog.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build:
2
+ gpu: true
3
+ cuda: "11.1"
4
+ python_version: "3.8"
5
+ system_packages:
6
+ - "libgl1-mesa-glx"
7
+ - "libglib2.0-0"
8
+ python_packages:
9
+ - "ipython==7.30.1"
10
+ - "torchvision==0.11.1"
11
+ - "torch==1.10.0"
12
+ - "timm==0.4.12"
13
+ - "transformers==4.15.0"
14
+ - "fairscale==0.4.4"
15
+ - "pycocoevalcap==1.2"
16
+
17
+ predict: "predict.py:Predictor"
configs/bert_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30522,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
configs/caption_coco.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/coco/images/'
2
+ ann_root: 'annotation'
3
+ coco_gt_root: 'annotation/coco_gt'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
7
+
8
+ # size of vit model; base or large
9
+ vit: 'base'
10
+ vit_grad_ckpt: False
11
+ vit_ckpt_layer: 0
12
+ batch_size: 32
13
+ init_lr: 1e-5
14
+
15
+ # vit: 'large'
16
+ # vit_grad_ckpt: True
17
+ # vit_ckpt_layer: 5
18
+ # batch_size: 16
19
+ # init_lr: 2e-6
20
+
21
+ image_size: 384
22
+
23
+ # generation configs
24
+ max_length: 20
25
+ min_length: 5
26
+ num_beams: 3
27
+ prompt: 'a picture of '
28
+
29
+ # optimizer
30
+ weight_decay: 0.05
31
+ min_lr: 0
32
+ max_epoch: 5
33
+
configs/med_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
configs/nlvr.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/NLVR2/'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
6
+
7
+ #size of vit model; base or large
8
+ vit: 'base'
9
+ batch_size_train: 16
10
+ batch_size_test: 64
11
+ vit_grad_ckpt: False
12
+ vit_ckpt_layer: 0
13
+ max_epoch: 15
14
+
15
+ image_size: 384
16
+
17
+ # optimizer
18
+ weight_decay: 0.05
19
+ init_lr: 3e-5
20
+ min_lr: 0
21
+
configs/nocaps.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/nocaps/'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
6
+
7
+ vit: 'base'
8
+ batch_size: 32
9
+
10
+ image_size: 384
11
+
12
+ max_length: 20
13
+ min_length: 5
14
+ num_beams: 3
15
+ prompt: 'a picture of '
configs/pretrain.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
2
+ '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
3
+ ]
4
+ laion_path: ''
5
+
6
+ # size of vit model; base or large
7
+ vit: 'base'
8
+ vit_grad_ckpt: False
9
+ vit_ckpt_layer: 0
10
+
11
+ image_size: 224
12
+ batch_size: 75
13
+
14
+ queue_size: 57600
15
+ alpha: 0.4
16
+
17
+ # optimizer
18
+ weight_decay: 0.05
19
+ init_lr: 3e-4
20
+ min_lr: 1e-6
21
+ warmup_lr: 1e-6
22
+ lr_decay_rate: 0.9
23
+ max_epoch: 20
24
+ warmup_steps: 3000
25
+
26
+
27
+
configs/retrieval_coco.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/coco/images/'
2
+ ann_root: 'annotation'
3
+ dataset: 'coco'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
7
+
8
+ # size of vit model; base or large
9
+
10
+ vit: 'base'
11
+ batch_size_train: 32
12
+ batch_size_test: 64
13
+ vit_grad_ckpt: True
14
+ vit_ckpt_layer: 4
15
+ init_lr: 1e-5
16
+
17
+ # vit: 'large'
18
+ # batch_size_train: 16
19
+ # batch_size_test: 32
20
+ # vit_grad_ckpt: True
21
+ # vit_ckpt_layer: 12
22
+ # init_lr: 5e-6
23
+
24
+ image_size: 384
25
+ queue_size: 57600
26
+ alpha: 0.4
27
+ k_test: 256
28
+ negative_all_rank: True
29
+
30
+ # optimizer
31
+ weight_decay: 0.05
32
+ min_lr: 0
33
+ max_epoch: 6
34
+
configs/retrieval_flickr.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/flickr30k/'
2
+ ann_root: 'annotation'
3
+ dataset: 'flickr'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
7
+
8
+ # size of vit model; base or large
9
+
10
+ vit: 'base'
11
+ batch_size_train: 32
12
+ batch_size_test: 64
13
+ vit_grad_ckpt: True
14
+ vit_ckpt_layer: 4
15
+ init_lr: 1e-5
16
+
17
+ # vit: 'large'
18
+ # batch_size_train: 16
19
+ # batch_size_test: 32
20
+ # vit_grad_ckpt: True
21
+ # vit_ckpt_layer: 10
22
+ # init_lr: 5e-6
23
+
24
+ image_size: 384
25
+ queue_size: 57600
26
+ alpha: 0.4
27
+ k_test: 128
28
+ negative_all_rank: False
29
+
30
+ # optimizer
31
+ weight_decay: 0.05
32
+ min_lr: 0
33
+ max_epoch: 6
34
+
configs/retrieval_msrvtt.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
6
+
7
+ # size of vit model; base or large
8
+ vit: 'base'
9
+ batch_size: 64
10
+ k_test: 128
11
+ image_size: 384
12
+ num_frm_test: 8
configs/vqa.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
2
+ vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
3
+ train_files: ['vqa_train','vqa_val','vg_qa']
4
+ ann_root: 'annotation'
5
+
6
+ # set pretrained as a file path or an url
7
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
8
+
9
+ # size of vit model; base or large
10
+ vit: 'base'
11
+ batch_size_train: 16
12
+ batch_size_test: 32
13
+ vit_grad_ckpt: False
14
+ vit_ckpt_layer: 0
15
+ init_lr: 2e-5
16
+
17
+ image_size: 480
18
+
19
+ k_test: 128
20
+ inference: 'rank'
21
+
22
+ # optimizer
23
+ weight_decay: 0.05
24
+ min_lr: 0
25
+ max_epoch: 10
data.txt ADDED
The diff for this file is too large to render. See raw diff
data/__init__.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from torchvision import transforms
4
+ from torchvision.transforms.functional import InterpolationMode
5
+
6
+ from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
7
+ from data.nocaps_dataset import nocaps_eval
8
+ from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
9
+ from data.vqa_dataset import vqa_dataset
10
+ from data.nlvr_dataset import nlvr_dataset
11
+ from data.pretrain_dataset import pretrain_dataset
12
+ from transform.randaugment import RandomAugment
13
+
14
+ def create_dataset(dataset, config, min_scale=0.5):
15
+
16
+ normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
17
+
18
+ transform_train = transforms.Compose([
19
+ transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
20
+ transforms.RandomHorizontalFlip(),
21
+ RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
22
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
23
+ transforms.ToTensor(),
24
+ normalize,
25
+ ])
26
+ transform_test = transforms.Compose([
27
+ transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
28
+ transforms.ToTensor(),
29
+ normalize,
30
+ ])
31
+
32
+ if dataset=='pretrain':
33
+ dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
34
+ return dataset
35
+
36
+ elif dataset=='caption_coco':
37
+ train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
38
+ val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
39
+ test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
40
+ return train_dataset, val_dataset, test_dataset
41
+
42
+ elif dataset=='nocaps':
43
+ val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
44
+ test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
45
+ return val_dataset, test_dataset
46
+
47
+ elif dataset=='retrieval_coco':
48
+ train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
49
+ val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
50
+ test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
51
+ return train_dataset, val_dataset, test_dataset
52
+
53
+ elif dataset=='retrieval_flickr':
54
+ train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
55
+ val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
56
+ test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
57
+ return train_dataset, val_dataset, test_dataset
58
+
59
+ elif dataset=='vqa':
60
+ train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
61
+ train_files = config['train_files'], split='train')
62
+ test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
63
+ return train_dataset, test_dataset
64
+
65
+ elif dataset=='nlvr':
66
+ train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
67
+ val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
68
+ test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
69
+ return train_dataset, val_dataset, test_dataset
70
+
71
+
72
+ def create_sampler(datasets, shuffles, num_tasks, global_rank):
73
+ samplers = []
74
+ for dataset,shuffle in zip(datasets,shuffles):
75
+ sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
76
+ samplers.append(sampler)
77
+ return samplers
78
+
79
+
80
+ def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
81
+ loaders = []
82
+ for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
83
+ if is_train:
84
+ shuffle = (sampler is None)
85
+ drop_last = True
86
+ else:
87
+ shuffle = False
88
+ drop_last = False
89
+ loader = DataLoader(
90
+ dataset,
91
+ batch_size=bs,
92
+ num_workers=n_worker,
93
+ pin_memory=True,
94
+ sampler=sampler,
95
+ shuffle=shuffle,
96
+ collate_fn=collate_fn,
97
+ drop_last=drop_last,
98
+ )
99
+ loaders.append(loader)
100
+ return loaders
101
+
data/coco_karpathy_dataset.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from torch.utils.data import Dataset
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from PIL import Image
8
+
9
+ from data.utils import pre_caption
10
+
11
+ class coco_karpathy_train(Dataset):
12
+ def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
13
+ '''
14
+ image_root (string): Root directory of images (e.g. coco/images/)
15
+ ann_root (string): directory to store the annotation file
16
+ '''
17
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
18
+ filename = 'coco_karpathy_train.json'
19
+
20
+ download_url(url,ann_root)
21
+
22
+ self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
23
+ self.transform = transform
24
+ self.image_root = image_root
25
+ self.max_words = max_words
26
+ self.prompt = prompt
27
+
28
+ self.img_ids = {}
29
+ n = 0
30
+ for ann in self.annotation:
31
+ img_id = ann['image_id']
32
+ if img_id not in self.img_ids.keys():
33
+ self.img_ids[img_id] = n
34
+ n += 1
35
+
36
+ def __len__(self):
37
+ return len(self.annotation)
38
+
39
+ def __getitem__(self, index):
40
+
41
+ ann = self.annotation[index]
42
+
43
+ image_path = os.path.join(self.image_root,ann['image'])
44
+ image = Image.open(image_path).convert('RGB')
45
+ image = self.transform(image)
46
+
47
+ caption = self.prompt+pre_caption(ann['caption'], self.max_words)
48
+
49
+ return image, caption, self.img_ids[ann['image_id']]
50
+
51
+
52
+ class coco_karpathy_caption_eval(Dataset):
53
+ def __init__(self, transform, image_root, ann_root, split):
54
+ '''
55
+ image_root (string): Root directory of images (e.g. coco/images/)
56
+ ann_root (string): directory to store the annotation file
57
+ split (string): val or test
58
+ '''
59
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
60
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
61
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
62
+
63
+ download_url(urls[split],ann_root)
64
+
65
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
66
+ self.transform = transform
67
+ self.image_root = image_root
68
+
69
+ def __len__(self):
70
+ return len(self.annotation)
71
+
72
+ def __getitem__(self, index):
73
+
74
+ ann = self.annotation[index]
75
+
76
+ image_path = os.path.join(self.image_root,ann['image'])
77
+ image = Image.open(image_path).convert('RGB')
78
+ image = self.transform(image)
79
+
80
+ img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
81
+
82
+ return image, int(img_id)
83
+
84
+
85
+ class coco_karpathy_retrieval_eval(Dataset):
86
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
87
+ '''
88
+ image_root (string): Root directory of images (e.g. coco/images/)
89
+ ann_root (string): directory to store the annotation file
90
+ split (string): val or test
91
+ '''
92
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
93
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
94
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
95
+
96
+ download_url(urls[split],ann_root)
97
+
98
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
99
+ self.transform = transform
100
+ self.image_root = image_root
101
+
102
+ self.text = []
103
+ self.image = []
104
+ self.txt2img = {}
105
+ self.img2txt = {}
106
+
107
+ txt_id = 0
108
+ for img_id, ann in enumerate(self.annotation):
109
+ self.image.append(ann['image'])
110
+ self.img2txt[img_id] = []
111
+ for i, caption in enumerate(ann['caption']):
112
+ self.text.append(pre_caption(caption,max_words))
113
+ self.img2txt[img_id].append(txt_id)
114
+ self.txt2img[txt_id] = img_id
115
+ txt_id += 1
116
+
117
+ def __len__(self):
118
+ return len(self.annotation)
119
+
120
+ def __getitem__(self, index):
121
+
122
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
123
+ image = Image.open(image_path).convert('RGB')
124
+ image = self.transform(image)
125
+
126
+ return image, index
data/flickr30k_dataset.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from torch.utils.data import Dataset
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from PIL import Image
8
+
9
+ from data.utils import pre_caption
10
+
11
+ class flickr30k_train(Dataset):
12
+ def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
13
+ '''
14
+ image_root (string): Root directory of images (e.g. flickr30k/)
15
+ ann_root (string): directory to store the annotation file
16
+ '''
17
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
18
+ filename = 'flickr30k_train.json'
19
+
20
+ download_url(url,ann_root)
21
+
22
+ self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
23
+ self.transform = transform
24
+ self.image_root = image_root
25
+ self.max_words = max_words
26
+ self.prompt = prompt
27
+
28
+ self.img_ids = {}
29
+ n = 0
30
+ for ann in self.annotation:
31
+ img_id = ann['image_id']
32
+ if img_id not in self.img_ids.keys():
33
+ self.img_ids[img_id] = n
34
+ n += 1
35
+
36
+ def __len__(self):
37
+ return len(self.annotation)
38
+
39
+ def __getitem__(self, index):
40
+
41
+ ann = self.annotation[index]
42
+
43
+ image_path = os.path.join(self.image_root,ann['image'])
44
+ image = Image.open(image_path).convert('RGB')
45
+ image = self.transform(image)
46
+
47
+ caption = self.prompt+pre_caption(ann['caption'], self.max_words)
48
+
49
+ return image, caption, self.img_ids[ann['image_id']]
50
+
51
+
52
+ class flickr30k_retrieval_eval(Dataset):
53
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
54
+ '''
55
+ image_root (string): Root directory of images (e.g. flickr30k/)
56
+ ann_root (string): directory to store the annotation file
57
+ split (string): val or test
58
+ '''
59
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
60
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
61
+ filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
62
+
63
+ download_url(urls[split],ann_root)
64
+
65
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
66
+ self.transform = transform
67
+ self.image_root = image_root
68
+
69
+ self.text = []
70
+ self.image = []
71
+ self.txt2img = {}
72
+ self.img2txt = {}
73
+
74
+ txt_id = 0
75
+ for img_id, ann in enumerate(self.annotation):
76
+ self.image.append(ann['image'])
77
+ self.img2txt[img_id] = []
78
+ for i, caption in enumerate(ann['caption']):
79
+ self.text.append(pre_caption(caption,max_words))
80
+ self.img2txt[img_id].append(txt_id)
81
+ self.txt2img[txt_id] = img_id
82
+ txt_id += 1
83
+
84
+ def __len__(self):
85
+ return len(self.annotation)
86
+
87
+ def __getitem__(self, index):
88
+
89
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
90
+ image = Image.open(image_path).convert('RGB')
91
+ image = self.transform(image)
92
+
93
+ return image, index
data/nlvr_dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+
5
+ from torch.utils.data import Dataset
6
+ from torchvision.datasets.utils import download_url
7
+
8
+ from PIL import Image
9
+
10
+ from data.utils import pre_caption
11
+
12
+ class nlvr_dataset(Dataset):
13
+ def __init__(self, transform, image_root, ann_root, split):
14
+ '''
15
+ image_root (string): Root directory of images
16
+ ann_root (string): directory to store the annotation file
17
+ split (string): train, val or test
18
+ '''
19
+ urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
20
+ 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
21
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
22
+ filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
23
+
24
+ download_url(urls[split],ann_root)
25
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
26
+
27
+ self.transform = transform
28
+ self.image_root = image_root
29
+
30
+
31
+ def __len__(self):
32
+ return len(self.annotation)
33
+
34
+
35
+ def __getitem__(self, index):
36
+
37
+ ann = self.annotation[index]
38
+
39
+ image0_path = os.path.join(self.image_root,ann['images'][0])
40
+ image0 = Image.open(image0_path).convert('RGB')
41
+ image0 = self.transform(image0)
42
+
43
+ image1_path = os.path.join(self.image_root,ann['images'][1])
44
+ image1 = Image.open(image1_path).convert('RGB')
45
+ image1 = self.transform(image1)
46
+
47
+ sentence = pre_caption(ann['sentence'], 40)
48
+
49
+ if ann['label']=='True':
50
+ label = 1
51
+ else:
52
+ label = 0
53
+
54
+ words = sentence.split(' ')
55
+
56
+ if 'left' not in words and 'right' not in words:
57
+ if random.random()<0.5:
58
+ return image0, image1, sentence, label
59
+ else:
60
+ return image1, image0, sentence, label
61
+ else:
62
+ if random.random()<0.5:
63
+ return image0, image1, sentence, label
64
+ else:
65
+ new_words = []
66
+ for word in words:
67
+ if word=='left':
68
+ new_words.append('right')
69
+ elif word=='right':
70
+ new_words.append('left')
71
+ else:
72
+ new_words.append(word)
73
+
74
+ sentence = ' '.join(new_words)
75
+ return image1, image0, sentence, label
76
+
77
+
78
+
data/nocaps_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from torch.utils.data import Dataset
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from PIL import Image
8
+
9
+ class nocaps_eval(Dataset):
10
+ def __init__(self, transform, image_root, ann_root, split):
11
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
12
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
13
+ filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
14
+
15
+ download_url(urls[split],ann_root)
16
+
17
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
18
+ self.transform = transform
19
+ self.image_root = image_root
20
+
21
+ def __len__(self):
22
+ return len(self.annotation)
23
+
24
+ def __getitem__(self, index):
25
+
26
+ ann = self.annotation[index]
27
+
28
+ image_path = os.path.join(self.image_root,ann['image'])
29
+ image = Image.open(image_path).convert('RGB')
30
+ image = self.transform(image)
31
+
32
+ return image, int(ann['img_id'])
data/pretrain_dataset.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ from torch.utils.data import Dataset
6
+
7
+ from PIL import Image
8
+ from PIL import ImageFile
9
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
10
+ Image.MAX_IMAGE_PIXELS = None
11
+
12
+ from data.utils import pre_caption
13
+ import os,glob
14
+
15
+ class pretrain_dataset(Dataset):
16
+ def __init__(self, ann_file, laion_path, transform):
17
+
18
+ self.ann_pretrain = []
19
+ for f in ann_file:
20
+ print('loading '+f)
21
+ ann = json.load(open(f,'r'))
22
+ self.ann_pretrain += ann
23
+
24
+ self.laion_path = laion_path
25
+ if self.laion_path:
26
+ self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
27
+
28
+ print('loading '+self.laion_files[0])
29
+ with open(self.laion_files[0],'r') as f:
30
+ self.ann_laion = json.load(f)
31
+
32
+ self.annotation = self.ann_pretrain + self.ann_laion
33
+ else:
34
+ self.annotation = self.ann_pretrain
35
+
36
+ self.transform = transform
37
+
38
+
39
+ def reload_laion(self, epoch):
40
+ n = epoch%len(self.laion_files)
41
+ print('loading '+self.laion_files[n])
42
+ with open(self.laion_files[n],'r') as f:
43
+ self.ann_laion = json.load(f)
44
+
45
+ self.annotation = self.ann_pretrain + self.ann_laion
46
+
47
+
48
+ def __len__(self):
49
+ return len(self.annotation)
50
+
51
+ def __getitem__(self, index):
52
+
53
+ ann = self.annotation[index]
54
+
55
+ image = Image.open(ann['image']).convert('RGB')
56
+ image = self.transform(image)
57
+ caption = pre_caption(ann['caption'],30)
58
+
59
+ return image, caption
data/utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import os
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ import utils
9
+
10
+ def pre_caption(caption,max_words=50):
11
+ caption = re.sub(
12
+ r"([.!\"()*#:;~])",
13
+ ' ',
14
+ caption.lower(),
15
+ )
16
+ caption = re.sub(
17
+ r"\s{2,}",
18
+ ' ',
19
+ caption,
20
+ )
21
+ caption = caption.rstrip('\n')
22
+ caption = caption.strip(' ')
23
+
24
+ #truncate caption
25
+ caption_words = caption.split(' ')
26
+ if len(caption_words)>max_words:
27
+ caption = ' '.join(caption_words[:max_words])
28
+
29
+ return caption
30
+
31
+ def pre_question(question,max_ques_words=50):
32
+ question = re.sub(
33
+ r"([.!\"()*#:;~])",
34
+ '',
35
+ question.lower(),
36
+ )
37
+ question = question.rstrip(' ')
38
+
39
+ #truncate question
40
+ question_words = question.split(' ')
41
+ if len(question_words)>max_ques_words:
42
+ question = ' '.join(question_words[:max_ques_words])
43
+
44
+ return question
45
+
46
+
47
+ def save_result(result, result_dir, filename, remove_duplicate=''):
48
+ result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
49
+ final_result_file = os.path.join(result_dir, '%s.json'%filename)
50
+
51
+ json.dump(result,open(result_file,'w'))
52
+
53
+ dist.barrier()
54
+
55
+ if utils.is_main_process():
56
+ # combine results from all processes
57
+ result = []
58
+
59
+ for rank in range(utils.get_world_size()):
60
+ result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
61
+ res = json.load(open(result_file,'r'))
62
+ result += res
63
+
64
+ if remove_duplicate:
65
+ result_new = []
66
+ id_list = []
67
+ for res in result:
68
+ if res[remove_duplicate] not in id_list:
69
+ id_list.append(res[remove_duplicate])
70
+ result_new.append(res)
71
+ result = result_new
72
+
73
+ json.dump(result,open(final_result_file,'w'))
74
+ print('result file saved to %s'%final_result_file)
75
+
76
+ return final_result_file
77
+
78
+
79
+
80
+ from pycocotools.coco import COCO
81
+ from pycocoevalcap.eval import COCOEvalCap
82
+ from torchvision.datasets.utils import download_url
83
+
84
+ def coco_caption_eval(coco_gt_root, results_file, split):
85
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
86
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
87
+ filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
88
+
89
+ download_url(urls[split],coco_gt_root)
90
+ annotation_file = os.path.join(coco_gt_root,filenames[split])
91
+
92
+ # create coco object and coco_result object
93
+ coco = COCO(annotation_file)
94
+ coco_result = coco.loadRes(results_file)
95
+
96
+ # create coco_eval object by taking coco and coco_result
97
+ coco_eval = COCOEvalCap(coco, coco_result)
98
+
99
+ # evaluate on a subset of images by setting
100
+ # coco_eval.params['image_id'] = coco_result.getImgIds()
101
+ # please remove this line when evaluating the full validation set
102
+ # coco_eval.params['image_id'] = coco_result.getImgIds()
103
+
104
+ # evaluate results
105
+ # SPICE will take a few minutes the first time, but speeds up due to caching
106
+ coco_eval.evaluate()
107
+
108
+ # print output evaluation scores
109
+ for metric, score in coco_eval.eval.items():
110
+ print(f'{metric}: {score:.3f}')
111
+
112
+ return coco_eval
data/video_dataset.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from torchvision.datasets.utils import download_url
3
+
4
+ from PIL import Image
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import decord
9
+ from decord import VideoReader
10
+ import json
11
+ import os
12
+ from data.utils import pre_caption
13
+
14
+ decord.bridge.set_bridge("torch")
15
+
16
+ class ImageNorm(object):
17
+ """Apply Normalization to Image Pixels on GPU
18
+ """
19
+ def __init__(self, mean, std):
20
+ self.mean = torch.tensor(mean).view(1, 3, 1, 1)
21
+ self.std = torch.tensor(std).view(1, 3, 1, 1)
22
+
23
+ def __call__(self, img):
24
+
25
+ if torch.max(img) > 1 and self.mean.max() <= 1:
26
+ img.div_(255.)
27
+ return img.sub_(self.mean).div_(self.std)
28
+
29
+ def load_jsonl(filename):
30
+ with open(filename, "r") as f:
31
+ return [json.loads(l.strip("\n")) for l in f.readlines()]
32
+
33
+
34
+ class VideoDataset(Dataset):
35
+
36
+ def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
37
+ '''
38
+ image_root (string): Root directory of video
39
+ ann_root (string): directory to store the annotation file
40
+ '''
41
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
42
+ filename = 'msrvtt_test.jsonl'
43
+
44
+ download_url(url,ann_root)
45
+ self.annotation = load_jsonl(os.path.join(ann_root,filename))
46
+
47
+ self.num_frm = num_frm
48
+ self.frm_sampling_strategy = frm_sampling_strategy
49
+ self.max_img_size = max_img_size
50
+ self.video_root = video_root
51
+ self.video_fmt = video_fmt
52
+ self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
53
+
54
+ self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
55
+ self.txt2video = [i for i in range(len(self.annotation))]
56
+ self.video2txt = self.txt2video
57
+
58
+
59
+ def __len__(self):
60
+ return len(self.annotation)
61
+
62
+ def __getitem__(self, index):
63
+
64
+ ann = self.annotation[index]
65
+
66
+ video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
67
+
68
+ vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
69
+
70
+ video = self.img_norm(vid_frm_array.float())
71
+
72
+ return video, ann['clip_name']
73
+
74
+
75
+
76
+ def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
77
+ try:
78
+ if not height or not width:
79
+ vr = VideoReader(video_path)
80
+ else:
81
+ vr = VideoReader(video_path, width=width, height=height)
82
+
83
+ vlen = len(vr)
84
+
85
+ if start_time or end_time:
86
+ assert fps > 0, 'must provide video fps if specifying start and end time.'
87
+
88
+ start_idx = min(int(start_time * fps), vlen)
89
+ end_idx = min(int(end_time * fps), vlen)
90
+ else:
91
+ start_idx, end_idx = 0, vlen
92
+
93
+ if self.frm_sampling_strategy == 'uniform':
94
+ frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
95
+ elif self.frm_sampling_strategy == 'rand':
96
+ frame_indices = sorted(random.sample(range(vlen), self.num_frm))
97
+ elif self.frm_sampling_strategy == 'headtail':
98
+ frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
99
+ frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
100
+ frame_indices = frame_indices_head + frame_indices_tail
101
+ else:
102
+ raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
103
+
104
+ raw_sample_frms = vr.get_batch(frame_indices)
105
+ except Exception as e:
106
+ return None
107
+
108
+ raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
109
+
110
+ return raw_sample_frms
data/vqa_dataset.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from PIL import Image
5
+
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from data.utils import pre_question
9
+
10
+ from torchvision.datasets.utils import download_url
11
+
12
+ class vqa_dataset(Dataset):
13
+ def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
14
+ self.split = split
15
+
16
+ self.transform = transform
17
+ self.vqa_root = vqa_root
18
+ self.vg_root = vg_root
19
+
20
+ if split=='train':
21
+ urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
22
+ 'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
23
+ 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
24
+
25
+ self.annotation = []
26
+ for f in train_files:
27
+ download_url(urls[f],ann_root)
28
+ self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
29
+ else:
30
+ download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
31
+ self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
32
+
33
+ download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
34
+ self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
35
+
36
+
37
+ def __len__(self):
38
+ return len(self.annotation)
39
+
40
+ def __getitem__(self, index):
41
+
42
+ ann = self.annotation[index]
43
+
44
+ if ann['dataset']=='vqa':
45
+ image_path = os.path.join(self.vqa_root,ann['image'])
46
+ elif ann['dataset']=='vg':
47
+ image_path = os.path.join(self.vg_root,ann['image'])
48
+
49
+ image = Image.open(image_path).convert('RGB')
50
+ image = self.transform(image)
51
+
52
+ if self.split == 'test':
53
+ question = pre_question(ann['question'])
54
+ question_id = ann['question_id']
55
+ return image, question, question_id
56
+
57
+
58
+ elif self.split=='train':
59
+
60
+ question = pre_question(ann['question'])
61
+
62
+ if ann['dataset']=='vqa':
63
+ answer_weight = {}
64
+ for answer in ann['answer']:
65
+ if answer in answer_weight.keys():
66
+ answer_weight[answer] += 1/len(ann['answer'])
67
+ else:
68
+ answer_weight[answer] = 1/len(ann['answer'])
69
+
70
+ answers = list(answer_weight.keys())
71
+ weights = list(answer_weight.values())
72
+
73
+ elif ann['dataset']=='vg':
74
+ answers = [ann['answer']]
75
+ weights = [0.2]
76
+
77
+ return image, question, answers, weights
78
+
79
+
80
+ def vqa_collate_fn(batch):
81
+ image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
82
+ for image, question, answer, weights in batch:
83
+ image_list.append(image)
84
+ question_list.append(question)
85
+ weight_list += weights
86
+ answer_list += answer
87
+ n.append(len(answer))
88
+ return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
eval_nocaps.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ from torch.utils.data import DataLoader
24
+
25
+ from models.blip import blip_decoder
26
+ import utils
27
+ from data import create_dataset, create_sampler, create_loader
28
+ from data.utils import save_result
29
+
30
+ @torch.no_grad()
31
+ def evaluate(model, data_loader, device, config):
32
+ # evaluate
33
+ model.eval()
34
+
35
+ metric_logger = utils.MetricLogger(delimiter=" ")
36
+ header = 'Evaluation:'
37
+ print_freq = 10
38
+
39
+ result = []
40
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
41
+
42
+ image = image.to(device)
43
+
44
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
45
+ min_length=config['min_length'], repetition_penalty=1.1)
46
+
47
+ for caption, img_id in zip(captions, image_id):
48
+ result.append({"image_id": img_id.item(), "caption": caption})
49
+
50
+ return result
51
+
52
+
53
+ def main(args, config):
54
+ utils.init_distributed_mode(args)
55
+
56
+ device = torch.device(args.device)
57
+
58
+ # fix the seed for reproducibility
59
+ seed = args.seed + utils.get_rank()
60
+ torch.manual_seed(seed)
61
+ np.random.seed(seed)
62
+ random.seed(seed)
63
+ cudnn.benchmark = True
64
+
65
+ #### Dataset ####
66
+ print("Creating captioning dataset")
67
+ val_dataset, test_dataset = create_dataset('nocaps', config)
68
+
69
+ if args.distributed:
70
+ num_tasks = utils.get_world_size()
71
+ global_rank = utils.get_rank()
72
+ samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
73
+ else:
74
+ samplers = [None,None]
75
+
76
+ val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
77
+ batch_size=[config['batch_size']]*2,num_workers=[4,4],
78
+ is_trains=[False, False], collate_fns=[None,None])
79
+
80
+ #### Model ####
81
+ print("Creating model")
82
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
83
+ prompt=config['prompt'])
84
+
85
+ model = model.to(device)
86
+
87
+ model_without_ddp = model
88
+ if args.distributed:
89
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
90
+ model_without_ddp = model.module
91
+
92
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
93
+ val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
94
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
95
+ test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
96
+
97
+
98
+ if __name__ == '__main__':
99
+ parser = argparse.ArgumentParser()
100
+ parser.add_argument('--config', default='./configs/nocaps.yaml')
101
+ parser.add_argument('--output_dir', default='output/NoCaps')
102
+ parser.add_argument('--device', default='cuda')
103
+ parser.add_argument('--seed', default=42, type=int)
104
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
105
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
106
+ parser.add_argument('--distributed', default=True, type=bool)
107
+ args = parser.parse_args()
108
+
109
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
110
+
111
+ args.result_dir = os.path.join(args.output_dir, 'result')
112
+
113
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
114
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
115
+
116
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
117
+
118
+ main(args, config)
eval_retrieval_video.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ from torch.utils.data import DataLoader
24
+
25
+ from models.blip_retrieval import blip_retrieval
26
+ import utils
27
+ from data.video_dataset import VideoDataset
28
+
29
+
30
+ @torch.no_grad()
31
+ def evaluation(model, data_loader, tokenizer, device, config):
32
+ # test
33
+ model.eval()
34
+
35
+ metric_logger = utils.MetricLogger(delimiter=" ")
36
+ header = 'Evaluation:'
37
+
38
+ print('Computing features for evaluation...')
39
+ start_time = time.time()
40
+
41
+ texts = data_loader.dataset.text
42
+ num_text = len(texts)
43
+ text_bs = 256
44
+ text_ids = []
45
+ text_embeds = []
46
+ text_atts = []
47
+ for i in range(0, num_text, text_bs):
48
+ text = texts[i: min(num_text, i+text_bs)]
49
+ text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
50
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
51
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
52
+ text_embeds.append(text_embed)
53
+ text_ids.append(text_input.input_ids)
54
+ text_atts.append(text_input.attention_mask)
55
+
56
+ text_embeds = torch.cat(text_embeds,dim=0)
57
+ text_ids = torch.cat(text_ids,dim=0)
58
+ text_atts = torch.cat(text_atts,dim=0)
59
+ text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
60
+
61
+ video_feats = []
62
+ video_embeds = []
63
+ for video, video_id in data_loader:
64
+
65
+ B,N,C,W,H = video.size()
66
+ video = video.view(-1,C,W,H)
67
+ video = video.to(device,non_blocking=True)
68
+ video_feat = model.visual_encoder(video)
69
+ video_embed = model.vision_proj(video_feat[:,0,:])
70
+ video_embed = video_embed.view(B,N,-1).mean(dim=1)
71
+ video_embed = F.normalize(video_embed,dim=-1)
72
+
73
+ video_feat = video_feat.view(B,-1,video_feat.shape[-1])
74
+ video_feats.append(video_feat.cpu())
75
+ video_embeds.append(video_embed)
76
+
77
+ video_feats = torch.cat(video_feats,dim=0)
78
+ video_embeds = torch.cat(video_embeds,dim=0)
79
+
80
+ sims_matrix = video_embeds @ text_embeds.t()
81
+ score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
82
+
83
+ num_tasks = utils.get_world_size()
84
+ rank = utils.get_rank()
85
+ step = sims_matrix.size(0)//num_tasks + 1
86
+ start = rank*step
87
+ end = min(sims_matrix.size(0),start+step)
88
+
89
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
90
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
91
+
92
+ encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
93
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
94
+ output = model.text_encoder(text_ids[topk_idx],
95
+ attention_mask = text_atts[topk_idx],
96
+ encoder_hidden_states = encoder_output,
97
+ encoder_attention_mask = encoder_att,
98
+ return_dict = True,
99
+ )
100
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
101
+ score_matrix_v2t[start+i,topk_idx] = score + topk_sim
102
+
103
+ sims_matrix = sims_matrix.t()
104
+ score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
105
+
106
+ step = sims_matrix.size(0)//num_tasks + 1
107
+ start = rank*step
108
+ end = min(sims_matrix.size(0),start+step)
109
+
110
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
111
+
112
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
113
+ encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
114
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
115
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
116
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
117
+ encoder_hidden_states = encoder_output,
118
+ encoder_attention_mask = encoder_att,
119
+ return_dict = True,
120
+ )
121
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
122
+ score_matrix_t2v[start+i,topk_idx] = score + topk_sim
123
+
124
+ if args.distributed:
125
+ dist.barrier()
126
+ torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
127
+ torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
128
+
129
+ total_time = time.time() - start_time
130
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
131
+ print('Evaluation time {}'.format(total_time_str))
132
+
133
+ return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
134
+
135
+
136
+
137
+ @torch.no_grad()
138
+ def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
139
+
140
+ #Video->Text
141
+ ranks = np.zeros(scores_v2t.shape[0])
142
+ for index,score in enumerate(scores_v2t):
143
+ inds = np.argsort(score)[::-1]
144
+ ranks[index] = np.where(inds == vid2txt[index])[0][0]
145
+
146
+ # Compute metrics
147
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
148
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
149
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
150
+
151
+ #Text->Video
152
+ ranks = np.zeros(scores_t2v.shape[0])
153
+
154
+ for index,score in enumerate(scores_t2v):
155
+ inds = np.argsort(score)[::-1]
156
+ ranks[index] = np.where(inds == txt2vmg[index])[0][0]
157
+
158
+ mdR = np.median(ranks+1)
159
+
160
+ # Compute metrics
161
+ vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
162
+ vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
163
+ vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
164
+
165
+ tr_mean = (tr1 + tr5 + tr10) / 3
166
+ vr_mean = (vr1 + vr5 + vr10) / 3
167
+ r_mean = (tr_mean + vr_mean) / 2
168
+
169
+ eval_result = {'txt_r1': tr1,
170
+ 'txt_r5': tr5,
171
+ 'txt_r10': tr10,
172
+ 'txt_r_mean': tr_mean,
173
+ 'vid_r1': vr1,
174
+ 'vid_r5': vr5,
175
+ 'vid_r10': vr10,
176
+ 'vid_r_mean': vr_mean,
177
+ 'vid_mdR': mdR,
178
+ 'r_mean': r_mean}
179
+ return eval_result
180
+
181
+
182
+
183
+
184
+ def main(args, config):
185
+ utils.init_distributed_mode(args)
186
+
187
+ device = torch.device(args.device)
188
+
189
+ # fix the seed for reproducibility
190
+ seed = args.seed + utils.get_rank()
191
+ torch.manual_seed(seed)
192
+ np.random.seed(seed)
193
+ random.seed(seed)
194
+ cudnn.benchmark = True
195
+
196
+ #### Dataset ####
197
+ print("Creating retrieval dataset")
198
+ test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
199
+ max_img_size=config['image_size'], frm_sampling_strategy='uniform')
200
+
201
+ test_loader = DataLoader(
202
+ test_dataset,
203
+ batch_size=config['batch_size'],
204
+ num_workers=4,
205
+ pin_memory=True,
206
+ drop_last=False,
207
+ shuffle=False,
208
+ )
209
+
210
+ #### Model ####
211
+ print("Creating model")
212
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
213
+
214
+ model = model.to(device)
215
+
216
+ model_without_ddp = model
217
+ if args.distributed:
218
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
219
+ model_without_ddp = model.module
220
+
221
+ score_v2t, score_t2v, = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
222
+
223
+ if utils.is_main_process():
224
+
225
+ test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
226
+ print(test_result)
227
+
228
+ log_stats = {**{f'{k}': v for k, v in test_result.items()},}
229
+ with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
230
+ f.write(json.dumps(log_stats) + "\n")
231
+
232
+
233
+ if __name__ == '__main__':
234
+ parser = argparse.ArgumentParser()
235
+ parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
236
+ parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
237
+ parser.add_argument('--device', default='cuda')
238
+ parser.add_argument('--seed', default=42, type=int)
239
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
240
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
241
+ parser.add_argument('--distributed', default=True, type=bool)
242
+ args = parser.parse_args()
243
+
244
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
245
+
246
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
247
+
248
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
249
+
250
+ main(args, config)
imagecaptioning.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import requests
3
+ import torch
4
+ from torchvision import transforms
5
+ import os
6
+ from torchvision.transforms.functional import InterpolationMode
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.image as mpimg
9
+
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+ from models.blip import blip_decoder
13
+
14
+ image_size = 384
15
+ transform = transforms.Compose([
16
+ transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
19
+ ])
20
+
21
+ model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
22
+
23
+ model = blip_decoder(pretrained=model_url, image_size=384, vit='large')
24
+ model.eval()
25
+ model = model.to(device)
26
+
27
+
28
+ from models.blip_vqa import blip_vqa
29
+
30
+ image_size_vq = 480
31
+ transform_vq = transforms.Compose([
32
+ transforms.Resize((image_size_vq,image_size_vq),interpolation=InterpolationMode.BICUBIC),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
35
+ ])
36
+
37
+ model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
38
+
39
+ model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
40
+ model_vq.eval()
41
+ model_vq = model_vq.to(device)
42
+
43
+
44
+
45
+ def inference(raw_image, model_n, question="", strategy=""):
46
+ if model_n == 'Image Captioning':
47
+ image = transform(raw_image).unsqueeze(0).to(device)
48
+ with torch.no_grad():
49
+ if strategy == "Beam search":
50
+ caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
51
+ else:
52
+ caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
53
+ return 'caption: '+caption[0]
54
+
55
+ else:
56
+ image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
57
+ with torch.no_grad():
58
+ answer = model_vq(image_vq, question, train=False, inference='generate')
59
+ return 'answer: '+answer[0]
60
+
61
+ #get caption for a single iamge
62
+ def get_caption(image_path):
63
+ img = Image.open(image_path)
64
+ return inference(img, "Image Captioning")[9:]
65
+
66
+ def display(image_path):
67
+ img = mpimg.imread(image_path)
68
+ img = Image.open(image_path)
69
+ plt.imshow(img)
70
+ print("Caption: " + get_caption(image_path))
71
+
72
+ #returns a dictionary with key -> img_path and value -> caption
73
+ def get_captions(img_directory, print_status=True):
74
+ #key is img path, value is the caption
75
+ captions = {}
76
+ length = 0
77
+ for file in os.listdir(img_directory):
78
+ length+=1
79
+ count = 0
80
+ for file in os.listdir(img_directory):
81
+ f = os.path.join(img_directory, file)
82
+ captions[f] = inference(Image.open(f), "Image Captioning")
83
+ if print_status:
84
+ print("Images complete:", str(count) + "/" + str(length))
85
+ print("Caption:", captions[f])
86
+ return captions
87
+ #writes dictionary to file, key and value seperated by ':'
88
+ def write_to_file(filename, caption_dict):
89
+ with open(filename, "w") as file:
90
+ for i in caption_dict:
91
+ file.write(i + ":" + caption_dict[i])
92
+ file.close()
93
+
logo.png ADDED
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (179 Bytes). View file
models/__pycache__/blip.cpython-38.pyc ADDED
Binary file (6.98 kB). View file
models/__pycache__/blip_vqa.cpython-38.pyc ADDED
Binary file (4.91 kB). View file
models/__pycache__/med.cpython-38.pyc ADDED
Binary file (28.2 kB). View file
models/__pycache__/vit.cpython-38.pyc ADDED
Binary file (12.4 kB). View file
models/blip.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ from models.vit import VisionTransformer, interpolate_pos_embed
12
+ from models.med import BertConfig, BertModel, BertLMHeadModel
13
+ from transformers import BertTokenizer
14
+
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+
19
+ import os
20
+ from urllib.parse import urlparse
21
+ from timm.models.hub import download_cached_file
22
+
23
+ class BLIP_Base(nn.Module):
24
+ def __init__(self,
25
+ med_config = 'configs/med_config.json',
26
+ image_size = 224,
27
+ vit = 'base',
28
+ vit_grad_ckpt = False,
29
+ vit_ckpt_layer = 0,
30
+ ):
31
+ """
32
+ Args:
33
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
34
+ image_size (int): input image size
35
+ vit (str): model size of vision transformer
36
+ """
37
+ super().__init__()
38
+
39
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
40
+ self.tokenizer = init_tokenizer()
41
+ med_config = BertConfig.from_json_file(med_config)
42
+ med_config.encoder_width = vision_width
43
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
44
+
45
+
46
+ def forward(self, image, caption, mode):
47
+
48
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
49
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
50
+
51
+ if mode=='image':
52
+ # return image features
53
+ image_embeds = self.visual_encoder(image)
54
+ return image_embeds
55
+
56
+ elif mode=='text':
57
+ # return text features
58
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
59
+ return_dict = True, mode = 'text')
60
+ return text_output.last_hidden_state
61
+
62
+ elif mode=='multimodal':
63
+ # return multimodel features
64
+ image_embeds = self.visual_encoder(image)
65
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
66
+
67
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
68
+ output = self.text_encoder(text.input_ids,
69
+ attention_mask = text.attention_mask,
70
+ encoder_hidden_states = image_embeds,
71
+ encoder_attention_mask = image_atts,
72
+ return_dict = True,
73
+ )
74
+ return output.last_hidden_state
75
+
76
+
77
+
78
+ class BLIP_Decoder(nn.Module):
79
+ def __init__(self,
80
+ med_config = 'configs/med_config.json',
81
+ image_size = 384,
82
+ vit = 'base',
83
+ vit_grad_ckpt = False,
84
+ vit_ckpt_layer = 0,
85
+ prompt = 'a picture of ',
86
+ ):
87
+ """
88
+ Args:
89
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
90
+ image_size (int): input image size
91
+ vit (str): model size of vision transformer
92
+ """
93
+ super().__init__()
94
+
95
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
96
+ self.tokenizer = init_tokenizer()
97
+ med_config = BertConfig.from_json_file(med_config)
98
+ med_config.encoder_width = vision_width
99
+ self.text_decoder = BertLMHeadModel(config=med_config)
100
+
101
+ self.prompt = prompt
102
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
103
+
104
+
105
+ def forward(self, image, caption):
106
+
107
+ image_embeds = self.visual_encoder(image)
108
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
109
+
110
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
111
+
112
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
113
+
114
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
115
+ decoder_targets[:,:self.prompt_length] = -100
116
+
117
+ decoder_output = self.text_decoder(text.input_ids,
118
+ attention_mask = text.attention_mask,
119
+ encoder_hidden_states = image_embeds,
120
+ encoder_attention_mask = image_atts,
121
+ labels = decoder_targets,
122
+ return_dict = True,
123
+ )
124
+ loss_lm = decoder_output.loss
125
+
126
+ return loss_lm
127
+
128
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
129
+ image_embeds = self.visual_encoder(image)
130
+
131
+ if not sample:
132
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
133
+
134
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
135
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
136
+
137
+ prompt = [self.prompt] * image.size(0)
138
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
139
+ input_ids[:,0] = self.tokenizer.bos_token_id
140
+ input_ids = input_ids[:, :-1]
141
+
142
+ if sample:
143
+ #nucleus sampling
144
+ outputs = self.text_decoder.generate(input_ids=input_ids,
145
+ max_length=max_length,
146
+ min_length=min_length,
147
+ do_sample=True,
148
+ top_p=top_p,
149
+ num_return_sequences=1,
150
+ eos_token_id=self.tokenizer.sep_token_id,
151
+ pad_token_id=self.tokenizer.pad_token_id,
152
+ repetition_penalty=1.1,
153
+ **model_kwargs)
154
+ else:
155
+ #beam search
156
+ outputs = self.text_decoder.generate(input_ids=input_ids,
157
+ max_length=max_length,
158
+ min_length=min_length,
159
+ num_beams=num_beams,
160
+ eos_token_id=self.tokenizer.sep_token_id,
161
+ pad_token_id=self.tokenizer.pad_token_id,
162
+ repetition_penalty=repetition_penalty,
163
+ **model_kwargs)
164
+
165
+ captions = []
166
+ for output in outputs:
167
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
168
+ captions.append(caption[len(self.prompt):])
169
+ return captions
170
+
171
+
172
+ def blip_decoder(pretrained='',**kwargs):
173
+ model = BLIP_Decoder(**kwargs)
174
+ if pretrained:
175
+ model,msg = load_checkpoint(model,pretrained)
176
+ assert(len(msg.missing_keys)==0)
177
+ return model
178
+
179
+ def blip_feature_extractor(pretrained='',**kwargs):
180
+ model = BLIP_Base(**kwargs)
181
+ if pretrained:
182
+ model,msg = load_checkpoint(model,pretrained)
183
+ assert(len(msg.missing_keys)==0)
184
+ return model
185
+
186
+ def init_tokenizer():
187
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
188
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
189
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
190
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
191
+ return tokenizer
192
+
193
+
194
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
195
+
196
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
197
+ if vit=='base':
198
+ vision_width = 768
199
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
200
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
201
+ drop_path_rate=0 or drop_path_rate
202
+ )
203
+ elif vit=='large':
204
+ vision_width = 1024
205
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
206
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
207
+ drop_path_rate=0.1 or drop_path_rate
208
+ )
209
+ return visual_encoder, vision_width
210
+
211
+ def is_url(url_or_filename):
212
+ parsed = urlparse(url_or_filename)
213
+ return parsed.scheme in ("http", "https")
214
+
215
+ def load_checkpoint(model,url_or_filename):
216
+ if is_url(url_or_filename):
217
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
218
+ checkpoint = torch.load(cached_file, map_location='cpu')
219
+ elif os.path.isfile(url_or_filename):
220
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
221
+ else:
222
+ raise RuntimeError('checkpoint url or path is invalid')
223
+
224
+ state_dict = checkpoint['model']
225
+
226
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
227
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
228
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
229
+ model.visual_encoder_m)
230
+ for key in model.state_dict().keys():
231
+ if key in state_dict.keys():
232
+ if state_dict[key].shape!=model.state_dict()[key].shape:
233
+ del state_dict[key]
234
+
235
+ msg = model.load_state_dict(state_dict,strict=False)
236
+ print('load checkpoint from %s'%url_or_filename)
237
+ return model,msg
238
+
models/blip_itm.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.med import BertConfig, BertModel
2
+ from transformers import BertTokenizer
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+
8
+ from models.blip import create_vit, init_tokenizer, load_checkpoint
9
+
10
+ class BLIP_ITM(nn.Module):
11
+ def __init__(self,
12
+ med_config = 'configs/med_config.json',
13
+ image_size = 384,
14
+ vit = 'base',
15
+ vit_grad_ckpt = False,
16
+ vit_ckpt_layer = 0,
17
+ embed_dim = 256,
18
+ ):
19
+ """
20
+ Args:
21
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
22
+ image_size (int): input image size
23
+ vit (str): model size of vision transformer
24
+ """
25
+ super().__init__()
26
+
27
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
28
+ self.tokenizer = init_tokenizer()
29
+ med_config = BertConfig.from_json_file(med_config)
30
+ med_config.encoder_width = vision_width
31
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
32
+
33
+ text_width = self.text_encoder.config.hidden_size
34
+
35
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
36
+ self.text_proj = nn.Linear(text_width, embed_dim)
37
+
38
+ self.itm_head = nn.Linear(text_width, 2)
39
+
40
+
41
+ def forward(self, image, caption, match_head='itm'):
42
+
43
+ image_embeds = self.visual_encoder(image)
44
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
45
+
46
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
47
+ return_tensors="pt").to(image.device)
48
+
49
+
50
+ if match_head=='itm':
51
+ output = self.text_encoder(text.input_ids,
52
+ attention_mask = text.attention_mask,
53
+ encoder_hidden_states = image_embeds,
54
+ encoder_attention_mask = image_atts,
55
+ return_dict = True,
56
+ )
57
+ itm_output = self.itm_head(output.last_hidden_state[:,0,:])
58
+ return itm_output
59
+
60
+ elif match_head=='itc':
61
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
62
+ return_dict = True, mode = 'text')
63
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
64
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
65
+
66
+ sim = image_feat @ text_feat.t()
67
+ return sim
68
+
69
+
70
+ def blip_itm(pretrained='',**kwargs):
71
+ model = BLIP_ITM(**kwargs)
72
+ if pretrained:
73
+ model,msg = load_checkpoint(model,pretrained)
74
+ assert(len(msg.missing_keys)==0)
75
+ return model
76
+
models/blip_nlvr.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.med import BertConfig
2
+ from models.nlvr_encoder import BertModel
3
+ from models.vit import interpolate_pos_embed
4
+ from models.blip import create_vit, init_tokenizer, is_url
5
+
6
+ from timm.models.hub import download_cached_file
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from transformers import BertTokenizer
12
+ import numpy as np
13
+
14
+ class BLIP_NLVR(nn.Module):
15
+ def __init__(self,
16
+ med_config = 'configs/med_config.json',
17
+ image_size = 480,
18
+ vit = 'base',
19
+ vit_grad_ckpt = False,
20
+ vit_ckpt_layer = 0,
21
+ ):
22
+ """
23
+ Args:
24
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
25
+ image_size (int): input image size
26
+ vit (str): model size of vision transformer
27
+ """
28
+ super().__init__()
29
+
30
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
31
+ self.tokenizer = init_tokenizer()
32
+ med_config = BertConfig.from_json_file(med_config)
33
+ med_config.encoder_width = vision_width
34
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
35
+
36
+ self.cls_head = nn.Sequential(
37
+ nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
38
+ nn.ReLU(),
39
+ nn.Linear(self.text_encoder.config.hidden_size, 2)
40
+ )
41
+
42
+ def forward(self, image, text, targets, train=True):
43
+
44
+ image_embeds = self.visual_encoder(image)
45
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
46
+ image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
47
+
48
+ text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
49
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
50
+
51
+ output = self.text_encoder(text.input_ids,
52
+ attention_mask = text.attention_mask,
53
+ encoder_hidden_states = [image0_embeds,image1_embeds],
54
+ encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
55
+ image_atts[image0_embeds.size(0):]],
56
+ return_dict = True,
57
+ )
58
+ hidden_state = output.last_hidden_state[:,0,:]
59
+ prediction = self.cls_head(hidden_state)
60
+
61
+ if train:
62
+ loss = F.cross_entropy(prediction, targets)
63
+ return loss
64
+ else:
65
+ return prediction
66
+
67
+ def blip_nlvr(pretrained='',**kwargs):
68
+ model = BLIP_NLVR(**kwargs)
69
+ if pretrained:
70
+ model,msg = load_checkpoint(model,pretrained)
71
+ print("missing keys:")
72
+ print(msg.missing_keys)
73
+ return model
74
+
75
+
76
+ def load_checkpoint(model,url_or_filename):
77
+ if is_url(url_or_filename):
78
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
79
+ checkpoint = torch.load(cached_file, map_location='cpu')
80
+ elif os.path.isfile(url_or_filename):
81
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
82
+ else:
83
+ raise RuntimeError('checkpoint url or path is invalid')
84
+ state_dict = checkpoint['model']
85
+
86
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
87
+
88
+ for key in list(state_dict.keys()):
89
+ if 'crossattention.self.' in key:
90
+ new_key0 = key.replace('self','self0')
91
+ new_key1 = key.replace('self','self1')
92
+ state_dict[new_key0] = state_dict[key]
93
+ state_dict[new_key1] = state_dict[key]
94
+ elif 'crossattention.output.dense.' in key:
95
+ new_key0 = key.replace('dense','dense0')
96
+ new_key1 = key.replace('dense','dense1')
97
+ state_dict[new_key0] = state_dict[key]
98
+ state_dict[new_key1] = state_dict[key]
99
+
100
+ msg = model.load_state_dict(state_dict,strict=False)
101
+ print('load checkpoint from %s'%url_or_filename)
102
+ return model,msg
103
+
models/blip_pretrain.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ from models.med import BertConfig, BertModel, BertLMHeadModel
9
+ from transformers import BertTokenizer
10
+ import transformers
11
+ transformers.logging.set_verbosity_error()
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+
17
+ from models.blip import create_vit, init_tokenizer, load_checkpoint
18
+
19
+ class BLIP_Pretrain(nn.Module):
20
+ def __init__(self,
21
+ med_config = 'configs/bert_config.json',
22
+ image_size = 224,
23
+ vit = 'base',
24
+ vit_grad_ckpt = False,
25
+ vit_ckpt_layer = 0,
26
+ embed_dim = 256,
27
+ queue_size = 57600,
28
+ momentum = 0.995,
29
+ ):
30
+ """
31
+ Args:
32
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
33
+ image_size (int): input image size
34
+ vit (str): model size of vision transformer
35
+ """
36
+ super().__init__()
37
+
38
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
39
+
40
+ if vit=='base':
41
+ checkpoint = torch.hub.load_state_dict_from_url(
42
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
43
+ map_location="cpu", check_hash=True)
44
+ state_dict = checkpoint["model"]
45
+ msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
46
+ elif vit=='large':
47
+ from timm.models.helpers import load_custom_pretrained
48
+ from timm.models.vision_transformer import default_cfgs
49
+ load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
50
+
51
+ self.tokenizer = init_tokenizer()
52
+ encoder_config = BertConfig.from_json_file(med_config)
53
+ encoder_config.encoder_width = vision_width
54
+ self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
55
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
56
+
57
+ text_width = self.text_encoder.config.hidden_size
58
+
59
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
60
+ self.text_proj = nn.Linear(text_width, embed_dim)
61
+
62
+ self.itm_head = nn.Linear(text_width, 2)
63
+
64
+ # create momentum encoders
65
+ self.visual_encoder_m, vision_width = create_vit(vit,image_size)
66
+ self.vision_proj_m = nn.Linear(vision_width, embed_dim)
67
+ self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
68
+ self.text_proj_m = nn.Linear(text_width, embed_dim)
69
+
70
+ self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
71
+ [self.vision_proj,self.vision_proj_m],
72
+ [self.text_encoder,self.text_encoder_m],
73
+ [self.text_proj,self.text_proj_m],
74
+ ]
75
+ self.copy_params()
76
+
77
+ # create the queue
78
+ self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
79
+ self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
80
+ self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
81
+
82
+ self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
83
+ self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
84
+
85
+ self.queue_size = queue_size
86
+ self.momentum = momentum
87
+ self.temp = nn.Parameter(0.07*torch.ones([]))
88
+
89
+ # create the decoder
90
+ decoder_config = BertConfig.from_json_file(med_config)
91
+ decoder_config.encoder_width = vision_width
92
+ self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
93
+ self.text_decoder.resize_token_embeddings(len(self.tokenizer))
94
+ tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
95
+
96
+
97
+ def forward(self, image, caption, alpha):
98
+ with torch.no_grad():
99
+ self.temp.clamp_(0.001,0.5)
100
+
101
+ image_embeds = self.visual_encoder(image)
102
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
103
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
104
+
105
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
106
+ return_tensors="pt").to(image.device)
107
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
108
+ return_dict = True, mode = 'text')
109
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
110
+
111
+ # get momentum features
112
+ with torch.no_grad():
113
+ self._momentum_update()
114
+ image_embeds_m = self.visual_encoder_m(image)
115
+ image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
116
+ image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
117
+
118
+ text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
119
+ return_dict = True, mode = 'text')
120
+ text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
121
+ text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
122
+
123
+ sim_i2t_m = image_feat_m @ text_feat_all / self.temp
124
+ sim_t2i_m = text_feat_m @ image_feat_all / self.temp
125
+
126
+ sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
127
+ sim_targets.fill_diagonal_(1)
128
+
129
+ sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
130
+ sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
131
+
132
+ sim_i2t = image_feat @ text_feat_all / self.temp
133
+ sim_t2i = text_feat @ image_feat_all / self.temp
134
+
135
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
136
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
137
+
138
+ loss_ita = (loss_i2t+loss_t2i)/2
139
+
140
+ self._dequeue_and_enqueue(image_feat_m, text_feat_m)
141
+
142
+ ###============== Image-text Matching ===================###
143
+ encoder_input_ids = text.input_ids.clone()
144
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
145
+
146
+ # forward the positve image-text pair
147
+ bs = image.size(0)
148
+ output_pos = self.text_encoder(encoder_input_ids,
149
+ attention_mask = text.attention_mask,
150
+ encoder_hidden_states = image_embeds,
151
+ encoder_attention_mask = image_atts,
152
+ return_dict = True,
153
+ )
154
+ with torch.no_grad():
155
+ weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
156
+ weights_t2i.fill_diagonal_(0)
157
+ weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
158
+ weights_i2t.fill_diagonal_(0)
159
+
160
+ # select a negative image for each text
161
+ image_embeds_neg = []
162
+ for b in range(bs):
163
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
164
+ image_embeds_neg.append(image_embeds[neg_idx])
165
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
166
+
167
+ # select a negative text for each image
168
+ text_ids_neg = []
169
+ text_atts_neg = []
170
+ for b in range(bs):
171
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
172
+ text_ids_neg.append(encoder_input_ids[neg_idx])
173
+ text_atts_neg.append(text.attention_mask[neg_idx])
174
+
175
+ text_ids_neg = torch.stack(text_ids_neg,dim=0)
176
+ text_atts_neg = torch.stack(text_atts_neg,dim=0)
177
+
178
+ text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
179
+ text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
180
+
181
+ image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
182
+ image_atts_all = torch.cat([image_atts,image_atts],dim=0)
183
+
184
+ output_neg = self.text_encoder(text_ids_all,
185
+ attention_mask = text_atts_all,
186
+ encoder_hidden_states = image_embeds_all,
187
+ encoder_attention_mask = image_atts_all,
188
+ return_dict = True,
189
+ )
190
+
191
+ vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
192
+ vl_output = self.itm_head(vl_embeddings)
193
+
194
+ itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
195
+ dim=0).to(image.device)
196
+ loss_itm = F.cross_entropy(vl_output, itm_labels)
197
+
198
+ ##================= LM ========================##
199
+ decoder_input_ids = text.input_ids.clone()
200
+ decoder_input_ids[:,0] = self.tokenizer.bos_token_id
201
+ decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
202
+
203
+ decoder_output = self.text_decoder(decoder_input_ids,
204
+ attention_mask = text.attention_mask,
205
+ encoder_hidden_states = image_embeds,
206
+ encoder_attention_mask = image_atts,
207
+ labels = decoder_targets,
208
+ return_dict = True,
209
+ )
210
+
211
+ loss_lm = decoder_output.loss
212
+ return loss_ita, loss_itm, loss_lm
213
+
214
+
215
+
216
+ @torch.no_grad()
217
+ def copy_params(self):
218
+ for model_pair in self.model_pairs:
219
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
220
+ param_m.data.copy_(param.data) # initialize
221
+ param_m.requires_grad = False # not update by gradient
222
+
223
+
224
+ @torch.no_grad()
225
+ def _momentum_update(self):
226
+ for model_pair in self.model_pairs:
227
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
228
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
229
+
230
+
231
+ @torch.no_grad()
232
+ def _dequeue_and_enqueue(self, image_feat, text_feat):
233
+ # gather keys before updating queue
234
+ image_feats = concat_all_gather(image_feat)
235
+ text_feats = concat_all_gather(text_feat)
236
+
237
+ batch_size = image_feats.shape[0]
238
+
239
+ ptr = int(self.queue_ptr)
240
+ assert self.queue_size % batch_size == 0 # for simplicity
241
+
242
+ # replace the keys at ptr (dequeue and enqueue)
243
+ self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
244
+ self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
245
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
246
+
247
+ self.queue_ptr[0] = ptr
248
+
249
+
250
+ def blip_pretrain(**kwargs):
251
+ model = BLIP_Pretrain(**kwargs)
252
+ return model
253
+
254
+
255
+ @torch.no_grad()
256
+ def concat_all_gather(tensor):
257
+ """
258
+ Performs all_gather operation on the provided tensors.
259
+ *** Warning ***: torch.distributed.all_gather has no gradient.
260
+ """
261
+ tensors_gather = [torch.ones_like(tensor)
262
+ for _ in range(torch.distributed.get_world_size())]
263
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
264
+
265
+ output = torch.cat(tensors_gather, dim=0)
266
+ return output
267
+
268
+
269
+ from typing import List
270
+ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
271
+ uninitialized_encoder_weights: List[str] = []
272
+ if decoder.__class__ != encoder.__class__:
273
+ logger.info(
274
+ f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
275
+ )
276
+
277
+ def tie_encoder_to_decoder_recursively(
278
+ decoder_pointer: nn.Module,
279
+ encoder_pointer: nn.Module,
280
+ module_name: str,
281
+ uninitialized_encoder_weights: List[str],
282
+ skip_key: str,
283
+ depth=0,
284
+ ):
285
+ assert isinstance(decoder_pointer, nn.Module) and isinstance(
286
+ encoder_pointer, nn.Module
287
+ ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
288
+ if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
289
+ assert hasattr(encoder_pointer, "weight")
290
+ encoder_pointer.weight = decoder_pointer.weight
291
+ if hasattr(decoder_pointer, "bias"):
292
+ assert hasattr(encoder_pointer, "bias")
293
+ encoder_pointer.bias = decoder_pointer.bias
294
+ print(module_name+' is tied')
295
+ return
296
+
297
+ encoder_modules = encoder_pointer._modules
298
+ decoder_modules = decoder_pointer._modules
299
+ if len(decoder_modules) > 0:
300
+ assert (
301
+ len(encoder_modules) > 0
302
+ ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
303
+
304
+ all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
305
+ encoder_layer_pos = 0
306
+ for name, module in decoder_modules.items():
307
+ if name.isdigit():
308
+ encoder_name = str(int(name) + encoder_layer_pos)
309
+ decoder_name = name
310
+ if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
311
+ encoder_modules
312
+ ) != len(decoder_modules):
313
+ # this can happen if the name corresponds to the position in a list module list of layers
314
+ # in this case the decoder has added a cross-attention that the encoder does not have
315
+ # thus skip this step and subtract one layer pos from encoder
316
+ encoder_layer_pos -= 1
317
+ continue
318
+ elif name not in encoder_modules:
319
+ continue
320
+ elif depth > 500:
321
+ raise ValueError(
322
+ "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
323
+ )
324
+ else:
325
+ decoder_name = encoder_name = name
326
+ tie_encoder_to_decoder_recursively(
327
+ decoder_modules[decoder_name],
328
+ encoder_modules[encoder_name],
329
+ module_name + "/" + name,
330
+ uninitialized_encoder_weights,
331
+ skip_key,
332
+ depth=depth + 1,
333
+ )
334
+ all_encoder_weights.remove(module_name + "/" + encoder_name)
335
+
336
+ uninitialized_encoder_weights += list(all_encoder_weights)
337
+
338
+ # tie weights recursively
339
+ tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
models/blip_retrieval.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.med import BertConfig, BertModel
2
+ from transformers import BertTokenizer
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+
8
+ from models.blip import create_vit, init_tokenizer, load_checkpoint
9
+
10
+ class BLIP_Retrieval(nn.Module):
11
+ def __init__(self,
12
+ med_config = 'configs/med_config.json',
13
+ image_size = 384,
14
+ vit = 'base',
15
+ vit_grad_ckpt = False,
16
+ vit_ckpt_layer = 0,
17
+ embed_dim = 256,
18
+ queue_size = 57600,
19
+ momentum = 0.995,
20
+ negative_all_rank = False,
21
+ ):
22
+ """
23
+ Args:
24
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
25
+ image_size (int): input image size
26
+ vit (str): model size of vision transformer
27
+ """
28
+ super().__init__()
29
+
30
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
31
+ self.tokenizer = init_tokenizer()
32
+ med_config = BertConfig.from_json_file(med_config)
33
+ med_config.encoder_width = vision_width
34
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
35
+
36
+ text_width = self.text_encoder.config.hidden_size
37
+
38
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
39
+ self.text_proj = nn.Linear(text_width, embed_dim)
40
+
41
+ self.itm_head = nn.Linear(text_width, 2)
42
+
43
+ # create momentum encoders
44
+ self.visual_encoder_m, vision_width = create_vit(vit,image_size)
45
+ self.vision_proj_m = nn.Linear(vision_width, embed_dim)
46
+ self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
47
+ self.text_proj_m = nn.Linear(text_width, embed_dim)
48
+
49
+ self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
50
+ [self.vision_proj,self.vision_proj_m],
51
+ [self.text_encoder,self.text_encoder_m],
52
+ [self.text_proj,self.text_proj_m],
53
+ ]
54
+ self.copy_params()
55
+
56
+ # create the queue
57
+ self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
58
+ self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
59
+ self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
60
+ self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
61
+
62
+ self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
63
+ self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
64
+
65
+ self.queue_size = queue_size
66
+ self.momentum = momentum
67
+ self.temp = nn.Parameter(0.07*torch.ones([]))
68
+
69
+ self.negative_all_rank = negative_all_rank
70
+
71
+
72
+ def forward(self, image, caption, alpha, idx):
73
+ with torch.no_grad():
74
+ self.temp.clamp_(0.001,0.5)
75
+
76
+ image_embeds = self.visual_encoder(image)
77
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
78
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
79
+
80
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
81
+ return_tensors="pt").to(image.device)
82
+
83
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
84
+ return_dict = True, mode = 'text')
85
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
86
+
87
+ ###============== Image-text Contrastive Learning ===================###
88
+ idx = idx.view(-1,1)
89
+ idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
90
+ pos_idx = torch.eq(idx, idx_all).float()
91
+ sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
92
+
93
+ # get momentum features
94
+ with torch.no_grad():
95
+ self._momentum_update()
96
+ image_embeds_m = self.visual_encoder_m(image)
97
+ image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
98
+ image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
99
+
100
+ text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
101
+ return_dict = True, mode = 'text')
102
+ text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
103
+ text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
104
+
105
+ sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
106
+ sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
107
+
108
+ sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
109
+ sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
110
+
111
+ sim_i2t = image_feat @ text_feat_m_all / self.temp
112
+ sim_t2i = text_feat @ image_feat_m_all / self.temp
113
+
114
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
115
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
116
+
117
+ loss_ita = (loss_i2t+loss_t2i)/2
118
+
119
+ idxs = concat_all_gather(idx)
120
+ self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
121
+
122
+ ###============== Image-text Matching ===================###
123
+ encoder_input_ids = text.input_ids.clone()
124
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
125
+
126
+ # forward the positve image-text pair
127
+ bs = image.size(0)
128
+ output_pos = self.text_encoder(encoder_input_ids,
129
+ attention_mask = text.attention_mask,
130
+ encoder_hidden_states = image_embeds,
131
+ encoder_attention_mask = image_atts,
132
+ return_dict = True,
133
+ )
134
+
135
+
136
+ if self.negative_all_rank:
137
+ # compute sample similarity
138
+ with torch.no_grad():
139
+ mask = torch.eq(idx, idxs.t())
140
+
141
+ image_feat_world = concat_all_gather(image_feat)
142
+ text_feat_world = concat_all_gather(text_feat)
143
+
144
+ sim_i2t = image_feat @ text_feat_world.t() / self.temp
145
+ sim_t2i = text_feat @ image_feat_world.t() / self.temp
146
+
147
+ weights_i2t = F.softmax(sim_i2t,dim=1)
148
+ weights_i2t.masked_fill_(mask, 0)
149
+
150
+ weights_t2i = F.softmax(sim_t2i,dim=1)
151
+ weights_t2i.masked_fill_(mask, 0)
152
+
153
+ image_embeds_world = all_gather_with_grad(image_embeds)
154
+
155
+ # select a negative image (from all ranks) for each text
156
+ image_embeds_neg = []
157
+ for b in range(bs):
158
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
159
+ image_embeds_neg.append(image_embeds_world[neg_idx])
160
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
161
+
162
+ # select a negative text (from all ranks) for each image
163
+ input_ids_world = concat_all_gather(encoder_input_ids)
164
+ att_mask_world = concat_all_gather(text.attention_mask)
165
+
166
+ text_ids_neg = []
167
+ text_atts_neg = []
168
+ for b in range(bs):
169
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
170
+ text_ids_neg.append(input_ids_world[neg_idx])
171
+ text_atts_neg.append(att_mask_world[neg_idx])
172
+
173
+ else:
174
+ with torch.no_grad():
175
+ mask = torch.eq(idx, idx.t())
176
+
177
+ sim_i2t = image_feat @ text_feat.t() / self.temp
178
+ sim_t2i = text_feat @ image_feat.t() / self.temp
179
+
180
+ weights_i2t = F.softmax(sim_i2t,dim=1)
181
+ weights_i2t.masked_fill_(mask, 0)
182
+
183
+ weights_t2i = F.softmax(sim_t2i,dim=1)
184
+ weights_t2i.masked_fill_(mask, 0)
185
+
186
+ # select a negative image (from same rank) for each text
187
+ image_embeds_neg = []
188
+ for b in range(bs):
189
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
190
+ image_embeds_neg.append(image_embeds[neg_idx])
191
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
192
+
193
+ # select a negative text (from same rank) for each image
194
+ text_ids_neg = []
195
+ text_atts_neg = []
196
+ for b in range(bs):
197
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
198
+ text_ids_neg.append(encoder_input_ids[neg_idx])
199
+ text_atts_neg.append(text.attention_mask[neg_idx])
200
+
201
+ text_ids_neg = torch.stack(text_ids_neg,dim=0)
202
+ text_atts_neg = torch.stack(text_atts_neg,dim=0)
203
+
204
+ text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
205
+ text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
206
+
207
+ image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
208
+ image_atts_all = torch.cat([image_atts,image_atts],dim=0)
209
+
210
+ output_neg = self.text_encoder(text_ids_all,
211
+ attention_mask = text_atts_all,
212
+ encoder_hidden_states = image_embeds_all,
213
+ encoder_attention_mask = image_atts_all,
214
+ return_dict = True,
215
+ )
216
+
217
+
218
+ vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
219
+ vl_output = self.itm_head(vl_embeddings)
220
+
221
+ itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
222
+ dim=0).to(image.device)
223
+ loss_itm = F.cross_entropy(vl_output, itm_labels)
224
+
225
+ return loss_ita, loss_itm
226
+
227
+
228
+ @torch.no_grad()
229
+ def copy_params(self):
230
+ for model_pair in self.model_pairs:
231
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
232
+ param_m.data.copy_(param.data) # initialize
233
+ param_m.requires_grad = False # not update by gradient
234
+
235
+
236
+ @torch.no_grad()
237
+ def _momentum_update(self):
238
+ for model_pair in self.model_pairs:
239
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
240
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
241
+
242
+
243
+ @torch.no_grad()
244
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
245
+ # gather keys before updating queue
246
+ image_feats = concat_all_gather(image_feat)
247
+ text_feats = concat_all_gather(text_feat)
248
+
249
+
250
+ batch_size = image_feats.shape[0]
251
+
252
+ ptr = int(self.ptr_queue)
253
+ assert self.queue_size % batch_size == 0 # for simplicity
254
+
255
+ # replace the keys at ptr (dequeue and enqueue)
256
+ self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
257
+ self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
258
+ self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
259
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
260
+
261
+ self.ptr_queue[0] = ptr
262
+
263
+
264
+ def blip_retrieval(pretrained='',**kwargs):
265
+ model = BLIP_Retrieval(**kwargs)
266
+ if pretrained:
267
+ model,msg = load_checkpoint(model,pretrained)
268
+ print("missing keys:")
269
+ print(msg.missing_keys)
270
+ return model
271
+
272
+
273
+ @torch.no_grad()
274
+ def concat_all_gather(tensor):
275
+ """
276
+ Performs all_gather operation on the provided tensors.
277
+ *** Warning ***: torch.distributed.all_gather has no gradient.
278
+ """
279
+ tensors_gather = [torch.ones_like(tensor)
280
+ for _ in range(torch.distributed.get_world_size())]
281
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
282
+
283
+ output = torch.cat(tensors_gather, dim=0)
284
+ return output
285
+
286
+
287
+ class GatherLayer(torch.autograd.Function):
288
+ """
289
+ Gather tensors from all workers with support for backward propagation:
290
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
291
+ """
292
+
293
+ @staticmethod
294
+ def forward(ctx, x):
295
+ output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
296
+ torch.distributed.all_gather(output, x)
297
+ return tuple(output)
298
+
299
+ @staticmethod
300
+ def backward(ctx, *grads):
301
+ all_gradients = torch.stack(grads)
302
+ torch.distributed.all_reduce(all_gradients)
303
+ return all_gradients[torch.distributed.get_rank()]
304
+
305
+
306
+ def all_gather_with_grad(tensors):
307
+ """
308
+ Performs all_gather operation on the provided tensors.
309
+ Graph remains connected for backward grad computation.
310
+ """
311
+ # Queue the gathered tensors
312
+ world_size = torch.distributed.get_world_size()
313
+ # There is no need for reduction in the single-proc case
314
+ if world_size == 1:
315
+ return tensors
316
+
317
+ tensor_all = GatherLayer.apply(tensors)
318
+
319
+ return torch.cat(tensor_all, dim=0)
models/blip_vqa.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.med import BertConfig, BertModel, BertLMHeadModel
2
+ from models.blip import create_vit, init_tokenizer, load_checkpoint
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from transformers import BertTokenizer
8
+ import numpy as np
9
+
10
+ class BLIP_VQA(nn.Module):
11
+ def __init__(self,
12
+ med_config = 'configs/med_config.json',
13
+ image_size = 480,
14
+ vit = 'base',
15
+ vit_grad_ckpt = False,
16
+ vit_ckpt_layer = 0,
17
+ ):
18
+ """
19
+ Args:
20
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
21
+ image_size (int): input image size
22
+ vit (str): model size of vision transformer
23
+ """
24
+ super().__init__()
25
+
26
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
27
+ self.tokenizer = init_tokenizer()
28
+
29
+ encoder_config = BertConfig.from_json_file(med_config)
30
+ encoder_config.encoder_width = vision_width
31
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
32
+
33
+ decoder_config = BertConfig.from_json_file(med_config)
34
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
35
+
36
+
37
+ def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
38
+
39
+ image_embeds = self.visual_encoder(image)
40
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
41
+
42
+ question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
43
+ return_tensors="pt").to(image.device)
44
+ question.input_ids[:,0] = self.tokenizer.enc_token_id
45
+
46
+ if train:
47
+ '''
48
+ n: number of answers for each question
49
+ weights: weight for each answer
50
+ '''
51
+ answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
52
+ answer.input_ids[:,0] = self.tokenizer.bos_token_id
53
+ answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
54
+
55
+ question_output = self.text_encoder(question.input_ids,
56
+ attention_mask = question.attention_mask,
57
+ encoder_hidden_states = image_embeds,
58
+ encoder_attention_mask = image_atts,
59
+ return_dict = True)
60
+
61
+ question_states = []
62
+ question_atts = []
63
+ for b, n in enumerate(n):
64
+ question_states += [question_output.last_hidden_state[b]]*n
65
+ question_atts += [question.attention_mask[b]]*n
66
+ question_states = torch.stack(question_states,0)
67
+ question_atts = torch.stack(question_atts,0)
68
+
69
+ answer_output = self.text_decoder(answer.input_ids,
70
+ attention_mask = answer.attention_mask,
71
+ encoder_hidden_states = question_states,
72
+ encoder_attention_mask = question_atts,
73
+ labels = answer_targets,
74
+ return_dict = True,
75
+ reduction = 'none',
76
+ )
77
+
78
+ loss = weights * answer_output.loss
79
+ loss = loss.sum()/image.size(0)
80
+
81
+ return loss
82
+
83
+
84
+ else:
85
+ question_output = self.text_encoder(question.input_ids,
86
+ attention_mask = question.attention_mask,
87
+ encoder_hidden_states = image_embeds,
88
+ encoder_attention_mask = image_atts,
89
+ return_dict = True)
90
+
91
+ if inference=='generate':
92
+ num_beams = 3
93
+ question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
94
+ question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
95
+ model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
96
+
97
+ bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
98
+
99
+ outputs = self.text_decoder.generate(input_ids=bos_ids,
100
+ max_length=10,
101
+ min_length=1,
102
+ num_beams=num_beams,
103
+ eos_token_id=self.tokenizer.sep_token_id,
104
+ pad_token_id=self.tokenizer.pad_token_id,
105
+ **model_kwargs)
106
+
107
+ answers = []
108
+ for output in outputs:
109
+ answer = self.tokenizer.decode(output, skip_special_tokens=True)
110
+ answers.append(answer)
111
+ return answers
112
+
113
+ elif inference=='rank':
114
+ max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
115
+ answer.input_ids, answer.attention_mask, k_test)
116
+ return max_ids
117
+
118
+
119
+
120
+ def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
121
+
122
+ num_ques = question_states.size(0)
123
+ start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
124
+
125
+ start_output = self.text_decoder(start_ids,
126
+ encoder_hidden_states = question_states,
127
+ encoder_attention_mask = question_atts,
128
+ return_dict = True,
129
+ reduction = 'none')
130
+ logits = start_output.logits[:,0,:] # first token's logit
131
+
132
+ # topk_probs: top-k probability
133
+ # topk_ids: [num_question, k]
134
+ answer_first_token = answer_ids[:,1]
135
+ prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
136
+ topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
137
+
138
+ # answer input: [num_question*k, answer_len]
139
+ input_ids = []
140
+ input_atts = []
141
+ for b, topk_id in enumerate(topk_ids):
142
+ input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
143
+ input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
144
+ input_ids = torch.cat(input_ids,dim=0)
145
+ input_atts = torch.cat(input_atts,dim=0)
146
+
147
+ targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
148
+
149
+ # repeat encoder's output for top-k answers
150
+ question_states = tile(question_states, 0, k)
151
+ question_atts = tile(question_atts, 0, k)
152
+
153
+ output = self.text_decoder(input_ids,
154
+ attention_mask = input_atts,
155
+ encoder_hidden_states = question_states,
156
+ encoder_attention_mask = question_atts,
157
+ labels = targets_ids,
158
+ return_dict = True,
159
+ reduction = 'none')
160
+
161
+ log_probs_sum = -output.loss
162
+ log_probs_sum = log_probs_sum.view(num_ques,k)
163
+
164
+ max_topk_ids = log_probs_sum.argmax(dim=1)
165
+ max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
166
+
167
+ return max_ids
168
+
169
+
170
+ def blip_vqa(pretrained='',**kwargs):
171
+ model = BLIP_VQA(**kwargs)
172
+ if pretrained:
173
+ model,msg = load_checkpoint(model,pretrained)
174
+ # assert(len(msg.missing_keys)==0)
175
+ return model
176
+
177
+
178
+ def tile(x, dim, n_tile):
179
+ init_dim = x.size(dim)
180
+ repeat_idx = [1] * x.dim()
181
+ repeat_idx[dim] = n_tile
182
+ x = x.repeat(*(repeat_idx))
183
+ order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
184
+ return torch.index_select(x, dim, order_index.to(x.device))
185
+
186
+
models/med.py ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
models/nlvr_encoder.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ from torch import Tensor, device, dtype, nn
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ import torch.nn.functional as F
13
+
14
+ from transformers.activations import ACT2FN
15
+ from transformers.file_utils import (
16
+ ModelOutput,
17
+ )
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPastAndCrossAttentions,
20
+ BaseModelOutputWithPoolingAndCrossAttentions,
21
+ CausalLMOutputWithCrossAttentions,
22
+ MaskedLMOutput,
23
+ MultipleChoiceModelOutput,
24
+ NextSentencePredictorOutput,
25
+ QuestionAnsweringModelOutput,
26
+ SequenceClassifierOutput,
27
+ TokenClassifierOutput,
28
+ )
29
+ from transformers.modeling_utils import (
30
+ PreTrainedModel,
31
+ apply_chunking_to_forward,
32
+ find_pruneable_heads_and_indices,
33
+ prune_linear_layer,
34
+ )
35
+ from transformers.utils import logging
36
+ from transformers.models.bert.configuration_bert import BertConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class BertEmbeddings(nn.Module):
43
+ """Construct the embeddings from word and position embeddings."""
44
+
45
+ def __init__(self, config):
46
+ super().__init__()
47
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
48
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
49
+
50
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
51
+ # any TensorFlow checkpoint file
52
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
53
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
54
+
55
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
56
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
57
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
58
+
59
+ self.config = config
60
+
61
+ def forward(
62
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
63
+ ):
64
+ if input_ids is not None:
65
+ input_shape = input_ids.size()
66
+ else:
67
+ input_shape = inputs_embeds.size()[:-1]
68
+
69
+ seq_length = input_shape[1]
70
+
71
+ if position_ids is None:
72
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
73
+
74
+ if inputs_embeds is None:
75
+ inputs_embeds = self.word_embeddings(input_ids)
76
+
77
+ embeddings = inputs_embeds
78
+
79
+ if self.position_embedding_type == "absolute":
80
+ position_embeddings = self.position_embeddings(position_ids)
81
+ embeddings += position_embeddings
82
+ embeddings = self.LayerNorm(embeddings)
83
+ embeddings = self.dropout(embeddings)
84
+ return embeddings
85
+
86
+
87
+ class BertSelfAttention(nn.Module):
88
+ def __init__(self, config, is_cross_attention):
89
+ super().__init__()
90
+ self.config = config
91
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
92
+ raise ValueError(
93
+ "The hidden size (%d) is not a multiple of the number of attention "
94
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
95
+ )
96
+
97
+ self.num_attention_heads = config.num_attention_heads
98
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
99
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
100
+
101
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
102
+ if is_cross_attention:
103
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
104
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
105
+ else:
106
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
107
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
108
+
109
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
110
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
111
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
112
+ self.max_position_embeddings = config.max_position_embeddings
113
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
114
+ self.save_attention = False
115
+
116
+ def save_attn_gradients(self, attn_gradients):
117
+ self.attn_gradients = attn_gradients
118
+
119
+ def get_attn_gradients(self):
120
+ return self.attn_gradients
121
+
122
+ def save_attention_map(self, attention_map):
123
+ self.attention_map = attention_map
124
+
125
+ def get_attention_map(self):
126
+ return self.attention_map
127
+
128
+ def transpose_for_scores(self, x):
129
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
130
+ x = x.view(*new_x_shape)
131
+ return x.permute(0, 2, 1, 3)
132
+
133
+ def forward(
134
+ self,
135
+ hidden_states,
136
+ attention_mask=None,
137
+ head_mask=None,
138
+ encoder_hidden_states=None,
139
+ encoder_attention_mask=None,
140
+ past_key_value=None,
141
+ output_attentions=False,
142
+ ):
143
+ mixed_query_layer = self.query(hidden_states)
144
+
145
+ # If this is instantiated as a cross-attention module, the keys
146
+ # and values come from an encoder; the attention mask needs to be
147
+ # such that the encoder's padding tokens are not attended to.
148
+ is_cross_attention = encoder_hidden_states is not None
149
+
150
+ if is_cross_attention:
151
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
152
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
153
+ attention_mask = encoder_attention_mask
154
+ elif past_key_value is not None:
155
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
156
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
157
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
158
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
159
+ else:
160
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
161
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
162
+
163
+ query_layer = self.transpose_for_scores(mixed_query_layer)
164
+
165
+ past_key_value = (key_layer, value_layer)
166
+
167
+ # Take the dot product between "query" and "key" to get the raw attention scores.
168
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
169
+
170
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
171
+ seq_length = hidden_states.size()[1]
172
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
173
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
174
+ distance = position_ids_l - position_ids_r
175
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
176
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
177
+
178
+ if self.position_embedding_type == "relative_key":
179
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
180
+ attention_scores = attention_scores + relative_position_scores
181
+ elif self.position_embedding_type == "relative_key_query":
182
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
183
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
184
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
185
+
186
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
187
+ if attention_mask is not None:
188
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
189
+ attention_scores = attention_scores + attention_mask
190
+
191
+ # Normalize the attention scores to probabilities.
192
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
193
+
194
+ if is_cross_attention and self.save_attention:
195
+ self.save_attention_map(attention_probs)
196
+ attention_probs.register_hook(self.save_attn_gradients)
197
+
198
+ # This is actually dropping out entire tokens to attend to, which might
199
+ # seem a bit unusual, but is taken from the original Transformer paper.
200
+ attention_probs_dropped = self.dropout(attention_probs)
201
+
202
+ # Mask heads if we want to
203
+ if head_mask is not None:
204
+ attention_probs_dropped = attention_probs_dropped * head_mask
205
+
206
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
207
+
208
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
209
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
210
+ context_layer = context_layer.view(*new_context_layer_shape)
211
+
212
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
213
+
214
+ outputs = outputs + (past_key_value,)
215
+ return outputs
216
+
217
+
218
+ class BertSelfOutput(nn.Module):
219
+ def __init__(self, config, twin=False, merge=False):
220
+ super().__init__()
221
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
222
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
+ if twin:
224
+ self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
225
+ self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
226
+ else:
227
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
228
+ if merge:
229
+ self.act = ACT2FN[config.hidden_act]
230
+ self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
231
+ self.merge = True
232
+ else:
233
+ self.merge = False
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ if type(hidden_states) == list:
237
+ hidden_states0 = self.dense0(hidden_states[0])
238
+ hidden_states1 = self.dense1(hidden_states[1])
239
+ if self.merge:
240
+ #hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))
241
+ hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1))
242
+ else:
243
+ hidden_states = (hidden_states0+hidden_states1)/2
244
+ else:
245
+ hidden_states = self.dense(hidden_states)
246
+ hidden_states = self.dropout(hidden_states)
247
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
248
+ return hidden_states
249
+
250
+
251
+ class BertAttention(nn.Module):
252
+ def __init__(self, config, is_cross_attention=False, layer_num=-1):
253
+ super().__init__()
254
+ if is_cross_attention:
255
+ self.self0 = BertSelfAttention(config, is_cross_attention)
256
+ self.self1 = BertSelfAttention(config, is_cross_attention)
257
+ else:
258
+ self.self = BertSelfAttention(config, is_cross_attention)
259
+ self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6))
260
+ self.pruned_heads = set()
261
+
262
+ def prune_heads(self, heads):
263
+ if len(heads) == 0:
264
+ return
265
+ heads, index = find_pruneable_heads_and_indices(
266
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
267
+ )
268
+
269
+ # Prune linear layers
270
+ self.self.query = prune_linear_layer(self.self.query, index)
271
+ self.self.key = prune_linear_layer(self.self.key, index)
272
+ self.self.value = prune_linear_layer(self.self.value, index)
273
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
274
+
275
+ # Update hyper params and store pruned heads
276
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
277
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
278
+ self.pruned_heads = self.pruned_heads.union(heads)
279
+
280
+ def forward(
281
+ self,
282
+ hidden_states,
283
+ attention_mask=None,
284
+ head_mask=None,
285
+ encoder_hidden_states=None,
286
+ encoder_attention_mask=None,
287
+ past_key_value=None,
288
+ output_attentions=False,
289
+ ):
290
+ if type(encoder_hidden_states)==list:
291
+ self_outputs0 = self.self0(
292
+ hidden_states,
293
+ attention_mask,
294
+ head_mask,
295
+ encoder_hidden_states[0],
296
+ encoder_attention_mask[0],
297
+ past_key_value,
298
+ output_attentions,
299
+ )
300
+ self_outputs1 = self.self1(
301
+ hidden_states,
302
+ attention_mask,
303
+ head_mask,
304
+ encoder_hidden_states[1],
305
+ encoder_attention_mask[1],
306
+ past_key_value,
307
+ output_attentions,
308
+ )
309
+ attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states)
310
+
311
+ outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them
312
+ else:
313
+ self_outputs = self.self(
314
+ hidden_states,
315
+ attention_mask,
316
+ head_mask,
317
+ encoder_hidden_states,
318
+ encoder_attention_mask,
319
+ past_key_value,
320
+ output_attentions,
321
+ )
322
+ attention_output = self.output(self_outputs[0], hidden_states)
323
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
324
+ return outputs
325
+
326
+
327
+ class BertIntermediate(nn.Module):
328
+ def __init__(self, config):
329
+ super().__init__()
330
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
331
+ if isinstance(config.hidden_act, str):
332
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
333
+ else:
334
+ self.intermediate_act_fn = config.hidden_act
335
+
336
+ def forward(self, hidden_states):
337
+ hidden_states = self.dense(hidden_states)
338
+ hidden_states = self.intermediate_act_fn(hidden_states)
339
+ return hidden_states
340
+
341
+
342
+ class BertOutput(nn.Module):
343
+ def __init__(self, config):
344
+ super().__init__()
345
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
346
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
347
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
348
+
349
+ def forward(self, hidden_states, input_tensor):
350
+ hidden_states = self.dense(hidden_states)
351
+ hidden_states = self.dropout(hidden_states)
352
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
353
+ return hidden_states
354
+
355
+
356
+ class BertLayer(nn.Module):
357
+ def __init__(self, config, layer_num):
358
+ super().__init__()
359
+ self.config = config
360
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
361
+ self.seq_len_dim = 1
362
+ self.attention = BertAttention(config)
363
+ self.layer_num = layer_num
364
+ if self.config.add_cross_attention:
365
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
366
+ self.intermediate = BertIntermediate(config)
367
+ self.output = BertOutput(config)
368
+
369
+ def forward(
370
+ self,
371
+ hidden_states,
372
+ attention_mask=None,
373
+ head_mask=None,
374
+ encoder_hidden_states=None,
375
+ encoder_attention_mask=None,
376
+ past_key_value=None,
377
+ output_attentions=False,
378
+ mode=None,
379
+ ):
380
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
381
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
382
+ self_attention_outputs = self.attention(
383
+ hidden_states,
384
+ attention_mask,
385
+ head_mask,
386
+ output_attentions=output_attentions,
387
+ past_key_value=self_attn_past_key_value,
388
+ )
389
+ attention_output = self_attention_outputs[0]
390
+
391
+ outputs = self_attention_outputs[1:-1]
392
+ present_key_value = self_attention_outputs[-1]
393
+
394
+ if mode=='multimodal':
395
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
396
+ cross_attention_outputs = self.crossattention(
397
+ attention_output,
398
+ attention_mask,
399
+ head_mask,
400
+ encoder_hidden_states,
401
+ encoder_attention_mask,
402
+ output_attentions=output_attentions,
403
+ )
404
+ attention_output = cross_attention_outputs[0]
405
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
406
+ layer_output = apply_chunking_to_forward(
407
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
408
+ )
409
+ outputs = (layer_output,) + outputs
410
+
411
+ outputs = outputs + (present_key_value,)
412
+
413
+ return outputs
414
+
415
+ def feed_forward_chunk(self, attention_output):
416
+ intermediate_output = self.intermediate(attention_output)
417
+ layer_output = self.output(intermediate_output, attention_output)
418
+ return layer_output
419
+
420
+
421
+ class BertEncoder(nn.Module):
422
+ def __init__(self, config):
423
+ super().__init__()
424
+ self.config = config
425
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
426
+ self.gradient_checkpointing = False
427
+
428
+ def forward(
429
+ self,
430
+ hidden_states,
431
+ attention_mask=None,
432
+ head_mask=None,
433
+ encoder_hidden_states=None,
434
+ encoder_attention_mask=None,
435
+ past_key_values=None,
436
+ use_cache=None,
437
+ output_attentions=False,
438
+ output_hidden_states=False,
439
+ return_dict=True,
440
+ mode='multimodal',
441
+ ):
442
+ all_hidden_states = () if output_hidden_states else None
443
+ all_self_attentions = () if output_attentions else None
444
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
445
+
446
+ next_decoder_cache = () if use_cache else None
447
+
448
+ for i in range(self.config.num_hidden_layers):
449
+ layer_module = self.layer[i]
450
+ if output_hidden_states:
451
+ all_hidden_states = all_hidden_states + (hidden_states,)
452
+
453
+ layer_head_mask = head_mask[i] if head_mask is not None else None
454
+ past_key_value = past_key_values[i] if past_key_values is not None else None
455
+
456
+ if self.gradient_checkpointing and self.training:
457
+
458
+ if use_cache:
459
+ logger.warn(
460
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
461
+ )
462
+ use_cache = False
463
+
464
+ def create_custom_forward(module):
465
+ def custom_forward(*inputs):
466
+ return module(*inputs, past_key_value, output_attentions)
467
+
468
+ return custom_forward
469
+
470
+ layer_outputs = torch.utils.checkpoint.checkpoint(
471
+ create_custom_forward(layer_module),
472
+ hidden_states,
473
+ attention_mask,
474
+ layer_head_mask,
475
+ encoder_hidden_states,
476
+ encoder_attention_mask,
477
+ mode=mode,
478
+ )
479
+ else:
480
+ layer_outputs = layer_module(
481
+ hidden_states,
482
+ attention_mask,
483
+ layer_head_mask,
484
+ encoder_hidden_states,
485
+ encoder_attention_mask,
486
+ past_key_value,
487
+ output_attentions,
488
+ mode=mode,
489
+ )
490
+
491
+ hidden_states = layer_outputs[0]
492
+ if use_cache:
493
+ next_decoder_cache += (layer_outputs[-1],)
494
+ if output_attentions:
495
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
496
+
497
+ if output_hidden_states:
498
+ all_hidden_states = all_hidden_states + (hidden_states,)
499
+
500
+ if not return_dict:
501
+ return tuple(
502
+ v
503
+ for v in [
504
+ hidden_states,
505
+ next_decoder_cache,
506
+ all_hidden_states,
507
+ all_self_attentions,
508
+ all_cross_attentions,
509
+ ]
510
+ if v is not None
511
+ )
512
+ return BaseModelOutputWithPastAndCrossAttentions(
513
+ last_hidden_state=hidden_states,
514
+ past_key_values=next_decoder_cache,
515
+ hidden_states=all_hidden_states,
516
+ attentions=all_self_attentions,
517
+ cross_attentions=all_cross_attentions,
518
+ )
519
+
520
+
521
+ class BertPooler(nn.Module):
522
+ def __init__(self, config):
523
+ super().__init__()
524
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
525
+ self.activation = nn.Tanh()
526
+
527
+ def forward(self, hidden_states):
528
+ # We "pool" the model by simply taking the hidden state corresponding
529
+ # to the first token.
530
+ first_token_tensor = hidden_states[:, 0]
531
+ pooled_output = self.dense(first_token_tensor)
532
+ pooled_output = self.activation(pooled_output)
533
+ return pooled_output
534
+
535
+
536
+ class BertPredictionHeadTransform(nn.Module):
537
+ def __init__(self, config):
538
+ super().__init__()
539
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
540
+ if isinstance(config.hidden_act, str):
541
+ self.transform_act_fn = ACT2FN[config.hidden_act]
542
+ else:
543
+ self.transform_act_fn = config.hidden_act
544
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
545
+
546
+ def forward(self, hidden_states):
547
+ hidden_states = self.dense(hidden_states)
548
+ hidden_states = self.transform_act_fn(hidden_states)
549
+ hidden_states = self.LayerNorm(hidden_states)
550
+ return hidden_states
551
+
552
+
553
+ class BertLMPredictionHead(nn.Module):
554
+ def __init__(self, config):
555
+ super().__init__()
556
+ self.transform = BertPredictionHeadTransform(config)
557
+
558
+ # The output weights are the same as the input embeddings, but there is
559
+ # an output-only bias for each token.
560
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
561
+
562
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
563
+
564
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
565
+ self.decoder.bias = self.bias
566
+
567
+ def forward(self, hidden_states):
568
+ hidden_states = self.transform(hidden_states)
569
+ hidden_states = self.decoder(hidden_states)
570
+ return hidden_states
571
+
572
+
573
+ class BertOnlyMLMHead(nn.Module):
574
+ def __init__(self, config):
575
+ super().__init__()
576
+ self.predictions = BertLMPredictionHead(config)
577
+
578
+ def forward(self, sequence_output):
579
+ prediction_scores = self.predictions(sequence_output)
580
+ return prediction_scores
581
+
582
+
583
+ class BertPreTrainedModel(PreTrainedModel):
584
+ """
585
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
586
+ models.
587
+ """
588
+
589
+ config_class = BertConfig
590
+ base_model_prefix = "bert"
591
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
592
+
593
+ def _init_weights(self, module):
594
+ """ Initialize the weights """
595
+ if isinstance(module, (nn.Linear, nn.Embedding)):
596
+ # Slightly different from the TF version which uses truncated_normal for initialization
597
+ # cf https://github.com/pytorch/pytorch/pull/5617
598
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
599
+ elif isinstance(module, nn.LayerNorm):
600
+ module.bias.data.zero_()
601
+ module.weight.data.fill_(1.0)
602
+ if isinstance(module, nn.Linear) and module.bias is not None:
603
+ module.bias.data.zero_()
604
+
605
+
606
+ class BertModel(BertPreTrainedModel):
607
+ """
608
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
609
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
610
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
611
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
612
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
613
+ input to the forward pass.
614
+ """
615
+
616
+ def __init__(self, config, add_pooling_layer=True):
617
+ super().__init__(config)
618
+ self.config = config
619
+
620
+ self.embeddings = BertEmbeddings(config)
621
+
622
+ self.encoder = BertEncoder(config)
623
+
624
+ self.pooler = BertPooler(config) if add_pooling_layer else None
625
+
626
+ self.init_weights()
627
+
628
+
629
+ def get_input_embeddings(self):
630
+ return self.embeddings.word_embeddings
631
+
632
+ def set_input_embeddings(self, value):
633
+ self.embeddings.word_embeddings = value
634
+
635
+ def _prune_heads(self, heads_to_prune):
636
+ """
637
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
638
+ class PreTrainedModel
639
+ """
640
+ for layer, heads in heads_to_prune.items():
641
+ self.encoder.layer[layer].attention.prune_heads(heads)
642
+
643
+
644
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
645
+ """
646
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
647
+
648
+ Arguments:
649
+ attention_mask (:obj:`torch.Tensor`):
650
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
651
+ input_shape (:obj:`Tuple[int]`):
652
+ The shape of the input to the model.
653
+ device: (:obj:`torch.device`):
654
+ The device of the input to the model.
655
+
656
+ Returns:
657
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
658
+ """
659
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
660
+ # ourselves in which case we just need to make it broadcastable to all heads.
661
+ if attention_mask.dim() == 3:
662
+ extended_attention_mask = attention_mask[:, None, :, :]
663
+ elif attention_mask.dim() == 2:
664
+ # Provided a padding mask of dimensions [batch_size, seq_length]
665
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
666
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
667
+ if is_decoder:
668
+ batch_size, seq_length = input_shape
669
+
670
+ seq_ids = torch.arange(seq_length, device=device)
671
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
672
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
673
+ # causal and attention masks must have same type with pytorch version < 1.3
674
+ causal_mask = causal_mask.to(attention_mask.dtype)
675
+
676
+ if causal_mask.shape[1] < attention_mask.shape[1]:
677
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
678
+ causal_mask = torch.cat(
679
+ [
680
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
681
+ causal_mask,
682
+ ],
683
+ axis=-1,
684
+ )
685
+
686
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
687
+ else:
688
+ extended_attention_mask = attention_mask[:, None, None, :]
689
+ else:
690
+ raise ValueError(
691
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
692
+ input_shape, attention_mask.shape
693
+ )
694
+ )
695
+
696
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
697
+ # masked positions, this operation will create a tensor which is 0.0 for
698
+ # positions we want to attend and -10000.0 for masked positions.
699
+ # Since we are adding it to the raw scores before the softmax, this is
700
+ # effectively the same as removing these entirely.
701
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
702
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
703
+ return extended_attention_mask
704
+
705
+ def forward(
706
+ self,
707
+ input_ids=None,
708
+ attention_mask=None,
709
+ position_ids=None,
710
+ head_mask=None,
711
+ inputs_embeds=None,
712
+ encoder_embeds=None,
713
+ encoder_hidden_states=None,
714
+ encoder_attention_mask=None,
715
+ past_key_values=None,
716
+ use_cache=None,
717
+ output_attentions=None,
718
+ output_hidden_states=None,
719
+ return_dict=None,
720
+ is_decoder=False,
721
+ mode='multimodal',
722
+ ):
723
+ r"""
724
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
725
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
726
+ the model is configured as a decoder.
727
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
728
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
729
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
730
+ - 1 for tokens that are **not masked**,
731
+ - 0 for tokens that are **masked**.
732
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
733
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
734
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
735
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
736
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
737
+ use_cache (:obj:`bool`, `optional`):
738
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
739
+ decoding (see :obj:`past_key_values`).
740
+ """
741
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
742
+ output_hidden_states = (
743
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
744
+ )
745
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
746
+
747
+ if is_decoder:
748
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
749
+ else:
750
+ use_cache = False
751
+
752
+ if input_ids is not None and inputs_embeds is not None:
753
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
754
+ elif input_ids is not None:
755
+ input_shape = input_ids.size()
756
+ batch_size, seq_length = input_shape
757
+ device = input_ids.device
758
+ elif inputs_embeds is not None:
759
+ input_shape = inputs_embeds.size()[:-1]
760
+ batch_size, seq_length = input_shape
761
+ device = inputs_embeds.device
762
+ elif encoder_embeds is not None:
763
+ input_shape = encoder_embeds.size()[:-1]
764
+ batch_size, seq_length = input_shape
765
+ device = encoder_embeds.device
766
+ else:
767
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
768
+
769
+ # past_key_values_length
770
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
771
+
772
+ if attention_mask is None:
773
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
774
+
775
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
776
+ # ourselves in which case we just need to make it broadcastable to all heads.
777
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
778
+ device, is_decoder)
779
+
780
+ # If a 2D or 3D attention mask is provided for the cross-attention
781
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
782
+ if encoder_hidden_states is not None:
783
+ if type(encoder_hidden_states) == list:
784
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
785
+ else:
786
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
787
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
788
+
789
+ if type(encoder_attention_mask) == list:
790
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
791
+ elif encoder_attention_mask is None:
792
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
793
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
794
+ else:
795
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
796
+ else:
797
+ encoder_extended_attention_mask = None
798
+
799
+ # Prepare head mask if needed
800
+ # 1.0 in head_mask indicate we keep the head
801
+ # attention_probs has shape bsz x n_heads x N x N
802
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
803
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
804
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
805
+
806
+ if encoder_embeds is None:
807
+ embedding_output = self.embeddings(
808
+ input_ids=input_ids,
809
+ position_ids=position_ids,
810
+ inputs_embeds=inputs_embeds,
811
+ past_key_values_length=past_key_values_length,
812
+ )
813
+ else:
814
+ embedding_output = encoder_embeds
815
+
816
+ encoder_outputs = self.encoder(
817
+ embedding_output,
818
+ attention_mask=extended_attention_mask,
819
+ head_mask=head_mask,
820
+ encoder_hidden_states=encoder_hidden_states,
821
+ encoder_attention_mask=encoder_extended_attention_mask,
822
+ past_key_values=past_key_values,
823
+ use_cache=use_cache,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ mode=mode,
828
+ )
829
+ sequence_output = encoder_outputs[0]
830
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
831
+
832
+ if not return_dict:
833
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
834
+
835
+ return BaseModelOutputWithPoolingAndCrossAttentions(
836
+ last_hidden_state=sequence_output,
837
+ pooler_output=pooled_output,
838
+ past_key_values=encoder_outputs.past_key_values,
839
+ hidden_states=encoder_outputs.hidden_states,
840
+ attentions=encoder_outputs.attentions,
841
+ cross_attentions=encoder_outputs.cross_attentions,
842
+ )
843
+
models/vit.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
predict.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download the weights in ./checkpoints beforehand for fast inference
3
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
4
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth
5
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
6
+ """
7
+
8
+ from pathlib import Path
9
+
10
+ from PIL import Image
11
+ import torch
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import InterpolationMode
14
+ import cog
15
+
16
+ from models.blip import blip_decoder
17
+ from models.blip_vqa import blip_vqa
18
+ from models.blip_itm import blip_itm
19
+
20
+
21
+ class Predictor(cog.Predictor):
22
+ def setup(self):
23
+ self.device = "cuda:0"
24
+
25
+ self.models = {
26
+ 'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth',
27
+ image_size=384, vit='base'),
28
+ 'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth',
29
+ image_size=480, vit='base'),
30
+ 'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
31
+ image_size=384, vit='base')
32
+ }
33
+
34
+ @cog.input(
35
+ "image",
36
+ type=Path,
37
+ help="input image",
38
+ )
39
+ @cog.input(
40
+ "task",
41
+ type=str,
42
+ default='image_captioning',
43
+ options=['image_captioning', 'visual_question_answering', 'image_text_matching'],
44
+ help="Choose a task.",
45
+ )
46
+ @cog.input(
47
+ "question",
48
+ type=str,
49
+ default=None,
50
+ help="Type question for the input image for visual question answering task.",
51
+ )
52
+ @cog.input(
53
+ "caption",
54
+ type=str,
55
+ default=None,
56
+ help="Type caption for the input image for image text matching task.",
57
+ )
58
+ def predict(self, image, task, question, caption):
59
+ if task == 'visual_question_answering':
60
+ assert question is not None, 'Please type a question for visual question answering task.'
61
+ if task == 'image_text_matching':
62
+ assert caption is not None, 'Please type a caption for mage text matching task.'
63
+
64
+ im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device)
65
+ model = self.models[task]
66
+ model.eval()
67
+ model = model.to(self.device)
68
+
69
+ if task == 'image_captioning':
70
+ with torch.no_grad():
71
+ caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5)
72
+ return 'Caption: ' + caption[0]
73
+
74
+ if task == 'visual_question_answering':
75
+ with torch.no_grad():
76
+ answer = model(im, question, train=False, inference='generate')
77
+ return 'Answer: ' + answer[0]
78
+
79
+ # image_text_matching
80
+ itm_output = model(im, caption, match_head='itm')
81
+ itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
82
+ itc_score = model(im, caption, match_head='itc')
83
+ return f'The image and text is matched with a probability of {itm_score.item():.4f}.\n' \
84
+ f'The image feature and text feature has a cosine similarity of {itc_score.item():.4f}.'
85
+
86
+
87
+ def load_image(image, image_size, device):
88
+ raw_image = Image.open(str(image)).convert('RGB')
89
+
90
+ w, h = raw_image.size
91
+
92
+ transform = transforms.Compose([
93
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
94
+ transforms.ToTensor(),
95
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
96
+ ])
97
+ image = transform(raw_image).unsqueeze(0).to(device)
98
+ return image
pretrain.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ from torch.utils.data import DataLoader
24
+
25
+ from models.blip_pretrain import blip_pretrain
26
+ import utils
27
+ from utils import warmup_lr_schedule, step_lr_schedule
28
+ from data import create_dataset, create_sampler, create_loader
29
+
30
+ def train(model, data_loader, optimizer, epoch, device, config):
31
+ # train
32
+ model.train()
33
+
34
+ metric_logger = utils.MetricLogger(delimiter=" ")
35
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
36
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
37
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
38
+ metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
39
+
40
+ header = 'Train Epoch: [{}]'.format(epoch)
41
+ print_freq = 50
42
+
43
+ if config['laion_path']:
44
+ data_loader.dataset.reload_laion(epoch)
45
+
46
+ data_loader.sampler.set_epoch(epoch)
47
+
48
+ for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
49
+
50
+ if epoch==0:
51
+ warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
52
+
53
+ optimizer.zero_grad()
54
+
55
+ image = image.to(device,non_blocking=True)
56
+
57
+ # ramp up alpha in the first 2 epochs
58
+ alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader)))
59
+
60
+ loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha)
61
+ loss = loss_ita + loss_itm + loss_lm
62
+
63
+ loss.backward()
64
+ optimizer.step()
65
+
66
+ metric_logger.update(loss_ita=loss_ita.item())
67
+ metric_logger.update(loss_itm=loss_itm.item())
68
+ metric_logger.update(loss_lm=loss_lm.item())
69
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
70
+
71
+
72
+ # gather the stats from all processes
73
+ metric_logger.synchronize_between_processes()
74
+ print("Averaged stats:", metric_logger.global_avg())
75
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
76
+
77
+
78
+ def main(args, config):
79
+ utils.init_distributed_mode(args)
80
+
81
+ device = torch.device(args.device)
82
+
83
+ # fix the seed for reproducibility
84
+ seed = args.seed + utils.get_rank()
85
+ torch.manual_seed(seed)
86
+ np.random.seed(seed)
87
+ random.seed(seed)
88
+ cudnn.benchmark = True
89
+
90
+ #### Dataset ####
91
+ print("Creating dataset")
92
+ datasets = [create_dataset('pretrain', config, min_scale=0.2)]
93
+ print('number of training samples: %d'%len(datasets[0]))
94
+
95
+ num_tasks = utils.get_world_size()
96
+ global_rank = utils.get_rank()
97
+ samplers = create_sampler(datasets, [True], num_tasks, global_rank)
98
+
99
+ data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
100
+
101
+ #### Model ####
102
+ print("Creating model")
103
+ model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
104
+ vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'])
105
+
106
+ model = model.to(device)
107
+
108
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
109
+
110
+ start_epoch = 0
111
+ if args.checkpoint:
112
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
113
+ state_dict = checkpoint['model']
114
+ model.load_state_dict(state_dict)
115
+
116
+ optimizer.load_state_dict(checkpoint['optimizer'])
117
+ start_epoch = checkpoint['epoch']+1
118
+ print('resume checkpoint from %s'%args.checkpoint)
119
+
120
+ model_without_ddp = model
121
+ if args.distributed:
122
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
123
+ model_without_ddp = model.module
124
+
125
+ print("Start training")
126
+ start_time = time.time()
127
+ for epoch in range(start_epoch, config['max_epoch']):
128
+
129
+ step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate'])
130
+
131
+ train_stats = train(model, data_loader, optimizer, epoch, device, config)
132
+ if utils.is_main_process():
133
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
134
+ 'epoch': epoch,
135
+ }
136
+ save_obj = {
137
+ 'model': model_without_ddp.state_dict(),
138
+ 'optimizer': optimizer.state_dict(),
139
+ 'config': config,
140
+ 'epoch': epoch,
141
+ }
142
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
143
+
144
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
145
+ f.write(json.dumps(log_stats) + "\n")
146
+
147
+ dist.barrier()
148
+
149
+ total_time = time.time() - start_time
150
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
151
+ print('Training time {}'.format(total_time_str))
152
+
153
+
154
+ if __name__ == '__main__':
155
+ parser = argparse.ArgumentParser()
156
+ parser.add_argument('--config', default='./configs/pretrain.yaml')
157
+ parser.add_argument('--output_dir', default='output/Pretrain')
158
+ parser.add_argument('--checkpoint', default='')
159
+ parser.add_argument('--evaluate', action='store_true')
160
+ parser.add_argument('--device', default='cuda')
161
+ parser.add_argument('--seed', default=42, type=int)
162
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
163
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
164
+ parser.add_argument('--distributed', default=True, type=bool)
165
+ args = parser.parse_args()
166
+
167
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
168
+
169
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
170
+
171
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
172
+
173
+ main(args, config)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
1
+ timm==0.4.12
2
+ transformers==4.15.0
3
+ fairscale==0.4.4
4
+ pycocoevalcap
5
+ Pillow
6
+ pandas
7
+ torch
8
+ torchvision
9
+ cohere
10
+ gradio