Upload 32 files
Browse files- .gitattributes +2 -0
- dataset_card/coco_caption.md +41 -0
- dataset_card/imgs/coco_caption.png +3 -0
- dataset_card/protein_function.md +0 -0
- docs/Makefile +20 -0
- docs/_static/Confusing-Pictures.jpg +0 -0
- docs/_static/architecture.png +0 -0
- docs/_static/logo_final.png +0 -0
- docs/_static/merlion.png +3 -0
- docs/benchmark.rst +348 -0
- docs/build_docs.sh +101 -0
- docs/conf.py +56 -0
- docs/getting_started.rst +233 -0
- docs/index.rst +46 -0
- docs/intro.rst +99 -0
- docs/make.bat +35 -0
- docs/requirements.txt +7 -0
- docs/tutorial.configs.rst +172 -0
- docs/tutorial.datasets.rst +424 -0
- docs/tutorial.evaluation.rst +40 -0
- docs/tutorial.models.rst +245 -0
- docs/tutorial.processors.rst +233 -0
- docs/tutorial.rst +13 -0
- docs/tutorial.tasks.rst +184 -0
- docs/tutorial.training-example.rst +145 -0
- examples/blip2_itm.py +520 -0
- examples/blip2_predict_func.py +178 -0
- examples/blip2_predict_func_concat.py +193 -0
- examples/blip2_predict_func_concat_pretrain.py +197 -0
- examples/blip2_predict_func_concat_timesplit.py +166 -0
- examples/blip2_predict_names.py +247 -0
- examples/predict_test.sh +14 -0
- examples/predict_train.sh +14 -0
.gitattributes
CHANGED
@@ -39,3 +39,5 @@ data/go1.4-basic.obo filter=lfs diff=lfs merge=lfs -text
|
|
39 |
data/swissprot_exp/train_exp_prompt_bp_new.csv filter=lfs diff=lfs merge=lfs -text
|
40 |
data/swissprot_exp/train_exp_prompt_cc_new.csv filter=lfs diff=lfs merge=lfs -text
|
41 |
data/swissprot_exp/train_exp_prompt_mf_new.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
39 |
data/swissprot_exp/train_exp_prompt_bp_new.csv filter=lfs diff=lfs merge=lfs -text
|
40 |
data/swissprot_exp/train_exp_prompt_cc_new.csv filter=lfs diff=lfs merge=lfs -text
|
41 |
data/swissprot_exp/train_exp_prompt_mf_new.csv filter=lfs diff=lfs merge=lfs -text
|
42 |
+
dataset_card/imgs/coco_caption.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
docs/_static/merlion.png filter=lfs diff=lfs merge=lfs -text
|
dataset_card/coco_caption.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
![Samples from the COCO Caption dataset (Image credit: "https://arxiv.org/pdf/1504.00325.pdf").](imgs/coco_caption.png)(Samples from the COCO Caption dataset. Image credit: "https://arxiv.org/pdf/1504.00325.pdf")
|
2 |
+
|
3 |
+
# Microsoft COCO Dataset (Captioning)
|
4 |
+
|
5 |
+
## Description
|
6 |
+
[Microsoft COCO Captions dataset](https://github.com/tylin/coco-caption) contains over one and a half million captions describing over 330,000 images. For the training and validation images, five independent human generated captions are be provided for each image.
|
7 |
+
|
8 |
+
## Task
|
9 |
+
|
10 |
+
(from https://paperswithcode.com/task/image-captioning)
|
11 |
+
|
12 |
+
**Image captioning** is the task of describing the content of an image in words. This task lies at the intersection of computer vision and natural language processing. Most image captioning systems use an encoder-decoder framework, where an input image is encoded into an intermediate representation of the information in the image, and then decoded into a descriptive text sequence.
|
13 |
+
|
14 |
+
## Metrics
|
15 |
+
Models are typically evaluated according to a [BLEU](https://aclanthology.org/P02-1040/) or [CIDER](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.pdf) metric.
|
16 |
+
|
17 |
+
## Leaderboard
|
18 |
+
|
19 |
+
(Ranked by BLEU-4)
|
20 |
+
|
21 |
+
| Rank | Model | BLEU-4 | CIDEr | METEOR | SPICE | Resources |
|
22 |
+
| ---- | :-----: | :----: | :---: | :----: | :---: | :----------------------------------------------------------------------------------------------------------------------------------------------: |
|
23 |
+
| 1 | OFA | 44.9 | 154.9 | 32.5 | 26.6 | [paper](https://arxiv.org/abs/2202.03052), [code](https://github.com/OFA-Sys/OFA) |
|
24 |
+
| 2 | LEMON | 42.6 | 145.5 | 31.4 | 25.5 | [paper]() |
|
25 |
+
| 3 | CoCa | 40.9 | 143.6 | 33.9 | 24.7 | [paper](https://arxiv.org/pdf/2205.01917.pdf) |
|
26 |
+
| 4 | SimVLM | 40.6 | 143.3 | 33.7 | 25.4 | [paper](https://openreview.net/pdf?id=GUrhfTuf_3) |
|
27 |
+
| 5 | VinVL | 41.0 | 140.9 | 31.1 | 25.2 | [paper](https://arxiv.org/pdf/2101.00529v2.pdf), [code](https://github.com/microsoft/Oscar) |
|
28 |
+
| 6 | OSCAR | 40.7 | 140.0 | 30.6 | 24.5 | [paper](https://arxiv.org/pdf/2004.06165v5.pdf), [code](https://github.com/microsoft/Oscar) |
|
29 |
+
| 7 | BLIP | 40.4 | 136.7 | 31.4 | 24.3 | [paper](https://arxiv.org/pdf/2201.12086.pdf), [code](https://github.com/salesforce/BLIP), [demo](https://huggingface.co/spaces/Salesforce/BLIP) |
|
30 |
+
| 8 | M^2 | 39.1 | 131.2 | 29.2 | 22.6 | [paper](https://arxiv.org/pdf/1912.08226v2.pdf), [code](https://github.com/aimagelab/meshed-memory-transformer) |
|
31 |
+
| 9 | BUTD | 36.5 | 113.5 | 27.0 | 20.3 | [paper](https://arxiv.org/abs/1707.07998?context=cs), [code](https://github.com/peteanderson80/bottom-up-attention) |
|
32 |
+
| 10 | ClipCap | 32.2 | 108.4 | 27.1 | 20.1 | [paper](https://arxiv.org/pdf/2111.09734v1.pdf), [code](https://github.com/rmokady/clip_prefix_caption) |
|
33 |
+
|
34 |
+
## Auto-Downloading
|
35 |
+
|
36 |
+
```
|
37 |
+
cd lavis/datasets/download_scripts && python download_coco.py
|
38 |
+
```
|
39 |
+
|
40 |
+
## References
|
41 |
+
"Microsoft COCO Captions: Data Collection and Evaluation Server", Xinlei Chen, Hao Fang, Tsung-Yi Lin, Ramakrishna Vedantam, Saurabh Gupta, Piotr Dollar, C. Lawrence Zitnick
|
dataset_card/imgs/coco_caption.png
ADDED
Git LFS Details
|
dataset_card/protein_function.md
ADDED
File without changes
|
docs/Makefile
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Minimal makefile for Sphinx documentation
|
2 |
+
#
|
3 |
+
|
4 |
+
# You can set these variables from the command line, and also
|
5 |
+
# from the environment for the first two.
|
6 |
+
SPHINXOPTS ?=
|
7 |
+
SPHINXBUILD ?= sphinx-build
|
8 |
+
SOURCEDIR = source
|
9 |
+
BUILDDIR = build
|
10 |
+
|
11 |
+
# Put it first so that "make" without argument is like "make help".
|
12 |
+
help:
|
13 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
14 |
+
|
15 |
+
.PHONY: help Makefile
|
16 |
+
|
17 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
18 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
19 |
+
%: Makefile
|
20 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
docs/_static/Confusing-Pictures.jpg
ADDED
docs/_static/architecture.png
ADDED
docs/_static/logo_final.png
ADDED
docs/_static/merlion.png
ADDED
Git LFS Details
|
docs/benchmark.rst
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Benchmark
|
2 |
+
############
|
3 |
+
|
4 |
+
We provide scripts for evaluating and training models on task datasets. The following benchmark results are included for reference.
|
5 |
+
|
6 |
+
|
7 |
+
ALBEF
|
8 |
+
*******
|
9 |
+
.. list-table::
|
10 |
+
:widths: 30 80 20
|
11 |
+
|
12 |
+
* - **Pretraining**
|
13 |
+
- COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
14 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/pretrain.sh>`__
|
15 |
+
* -
|
16 |
+
- Visual Genome (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_vg.py>`__)
|
17 |
+
-
|
18 |
+
* -
|
19 |
+
- SBU (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_sbu.py>`__)
|
20 |
+
-
|
21 |
+
* -
|
22 |
+
- CC3M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py>`__)
|
23 |
+
-
|
24 |
+
* -
|
25 |
+
- CC12M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py>`__)
|
26 |
+
-
|
27 |
+
|
28 |
+
.. list-table::
|
29 |
+
:widths: 30 40 20 20 20 30 30
|
30 |
+
:header-rows: 1
|
31 |
+
|
32 |
+
* -
|
33 |
+
- **Retrieval**
|
34 |
+
- **R1**
|
35 |
+
- **R5**
|
36 |
+
- **R10**
|
37 |
+
- **Training**
|
38 |
+
- **Evaluation**
|
39 |
+
* - TR
|
40 |
+
- COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
41 |
+
- 77.6
|
42 |
+
- 94.1
|
43 |
+
- 97.2
|
44 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_coco_retrieval_albef.sh>`__
|
45 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_coco_retrieval.sh>`__
|
46 |
+
* - IR
|
47 |
+
- COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
48 |
+
- 61.0
|
49 |
+
- 84.5
|
50 |
+
- 90.7
|
51 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_coco_retrieval_albef.sh>`__
|
52 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_coco_retrieval.sh>`__
|
53 |
+
* - TR
|
54 |
+
- Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
|
55 |
+
- 77.6
|
56 |
+
- 94.1
|
57 |
+
- 97.2
|
58 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_flickr30k_retrieval_albef.sh>`__
|
59 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_flickr30k_retrieval.sh>`__
|
60 |
+
* - IR
|
61 |
+
- Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
|
62 |
+
- 61.0
|
63 |
+
- 84.5
|
64 |
+
- 90.7
|
65 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_flickr30k_retrieval_albef.sh>`__
|
66 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_flickr30k_retrieval.sh>`__
|
67 |
+
|
68 |
+
|
69 |
+
.. list-table::
|
70 |
+
:widths: 20 20 20 20 20
|
71 |
+
:header-rows: 1
|
72 |
+
|
73 |
+
* - **VQA**
|
74 |
+
- **test-dev**
|
75 |
+
- **test-std/test**
|
76 |
+
- **Training**
|
77 |
+
- **Evaluation**
|
78 |
+
* - VQAv2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
79 |
+
- 76.35
|
80 |
+
- 76.54
|
81 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_vqa_albef.sh>`__
|
82 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/test_albef_vqa.sh>`__
|
83 |
+
* - OKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
84 |
+
- NA
|
85 |
+
- 54.7
|
86 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_okvqa_albef.sh>`__
|
87 |
+
- NA
|
88 |
+
* - AOKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
89 |
+
- 54.5
|
90 |
+
- NA
|
91 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_aokvqa_albef.sh>`__
|
92 |
+
- NA
|
93 |
+
|
94 |
+
|
95 |
+
.. list-table::
|
96 |
+
:widths: 20 20 20 20 20
|
97 |
+
:header-rows: 1
|
98 |
+
|
99 |
+
* - **Multimodal Classification**
|
100 |
+
- **val**
|
101 |
+
- **test**
|
102 |
+
- **Training**
|
103 |
+
- **Evaluation**
|
104 |
+
* - SNLI-VE (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
105 |
+
- 80.60
|
106 |
+
- 81.04
|
107 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_ve_albef.sh>`__
|
108 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_albef_ve.sh>`__
|
109 |
+
* - NLVR2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
110 |
+
- 82.47
|
111 |
+
- 82.91
|
112 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_nlvr_albef.sh>`__
|
113 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_albef_nlvr.sh>`__
|
114 |
+
|
115 |
+
BLIP
|
116 |
+
*******
|
117 |
+
.. list-table::
|
118 |
+
:widths: 30 80 20
|
119 |
+
|
120 |
+
* - **Pretraining (14M)**
|
121 |
+
- COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
122 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/pretrain.sh>`__
|
123 |
+
* -
|
124 |
+
- Visual Genome (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_vg.py>`__)
|
125 |
+
-
|
126 |
+
* -
|
127 |
+
- SBU (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_sbu.py>`__)
|
128 |
+
-
|
129 |
+
* -
|
130 |
+
- CC3M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py>`__)
|
131 |
+
-
|
132 |
+
* -
|
133 |
+
- CC12M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py>`__)
|
134 |
+
-
|
135 |
+
|
136 |
+
.. list-table::
|
137 |
+
:widths: 30 40 20 20 20 30 30
|
138 |
+
:header-rows: 1
|
139 |
+
|
140 |
+
* - **Tasks**
|
141 |
+
- **Retrieval**
|
142 |
+
- **R1**
|
143 |
+
- **R5**
|
144 |
+
- **R10**
|
145 |
+
- **Training**
|
146 |
+
- **Evaluation**
|
147 |
+
* - TR
|
148 |
+
- COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
149 |
+
- 82.0
|
150 |
+
- 95.8
|
151 |
+
- 98.1
|
152 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_coco.sh>`__
|
153 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_coco.sh>`__
|
154 |
+
* - IR
|
155 |
+
- COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
156 |
+
- 64.5
|
157 |
+
- 86.0
|
158 |
+
- 91.7
|
159 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_coco.sh>`__
|
160 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_coco.sh>`__
|
161 |
+
* - TR
|
162 |
+
- Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
|
163 |
+
- 96.9
|
164 |
+
- 99.9
|
165 |
+
- 100.0
|
166 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_flickr.sh>`__
|
167 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_flickr.sh>`__
|
168 |
+
* - IR
|
169 |
+
- Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
|
170 |
+
- 87.5
|
171 |
+
- 97.6
|
172 |
+
- 98.9
|
173 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_flickr.sh>`__
|
174 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_flickr.sh>`__
|
175 |
+
|
176 |
+
|
177 |
+
.. list-table::
|
178 |
+
:widths: 20 20 20 20 20
|
179 |
+
:header-rows: 1
|
180 |
+
|
181 |
+
* - **VQA**
|
182 |
+
- **test-dev**
|
183 |
+
- **test-std/test**
|
184 |
+
- **Training**
|
185 |
+
- **Evaluation**
|
186 |
+
* - VQAv2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
187 |
+
- 78.23
|
188 |
+
- 78.29
|
189 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_vqa_albef.sh>`__
|
190 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/test_albef_vqa.sh>`__
|
191 |
+
* - OKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
192 |
+
- NA
|
193 |
+
- 55.4
|
194 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_okvqa.sh>`__
|
195 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_okvqa.sh>`__
|
196 |
+
* - AOKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
197 |
+
- 56.2
|
198 |
+
- 50.1
|
199 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_aokvqa.sh>`__
|
200 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_aokvqa.sh>`__
|
201 |
+
|
202 |
+
|
203 |
+
.. list-table::
|
204 |
+
:widths: 20 20 20 20 20 20
|
205 |
+
:header-rows: 1
|
206 |
+
|
207 |
+
* - **Image Captioning**
|
208 |
+
- **BLEU@4**
|
209 |
+
- **CIDEr**
|
210 |
+
- **SPICE**
|
211 |
+
- **Training**
|
212 |
+
- **Evaluation**
|
213 |
+
* - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
214 |
+
- 39.9
|
215 |
+
- 133.5
|
216 |
+
- 23.7
|
217 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_caption_coco.sh>`__
|
218 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_coco_cap.sh>`__
|
219 |
+
* - NoCaps (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_nocaps.py>`__)
|
220 |
+
- 31.9
|
221 |
+
- 109.1
|
222 |
+
- 14.7
|
223 |
+
- NA
|
224 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_nocaps.sh>`__
|
225 |
+
|
226 |
+
|
227 |
+
.. list-table::
|
228 |
+
:widths: 20 20 20 20 20
|
229 |
+
:header-rows: 1
|
230 |
+
|
231 |
+
* - **Multimodal Classification**
|
232 |
+
- **val**
|
233 |
+
- **test**
|
234 |
+
- **Training**
|
235 |
+
- **Evaluation**
|
236 |
+
* - NLVR2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
237 |
+
- 82.48
|
238 |
+
- 83.25
|
239 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_nlvr.sh>`__
|
240 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_nlvr.sh>`__
|
241 |
+
|
242 |
+
CLIP
|
243 |
+
*******
|
244 |
+
.. list-table::
|
245 |
+
:widths: 30 40 20 20 20 30
|
246 |
+
:header-rows: 1
|
247 |
+
|
248 |
+
* - **Tasks**
|
249 |
+
- **Retrieval (Zero-shot)**
|
250 |
+
- **R1**
|
251 |
+
- **R5**
|
252 |
+
- **R10**
|
253 |
+
- **Evaluation**
|
254 |
+
* - TR
|
255 |
+
- COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
256 |
+
- 57.2
|
257 |
+
- 80.5
|
258 |
+
- 87.8
|
259 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_coco.sh>`__
|
260 |
+
* - IR
|
261 |
+
- COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
|
262 |
+
- 36.5
|
263 |
+
- 60.8
|
264 |
+
- 71.0
|
265 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_coco.sh>`__
|
266 |
+
* - TR
|
267 |
+
- Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
|
268 |
+
- 86.5
|
269 |
+
- 98.0
|
270 |
+
- 99.1
|
271 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_flickr.sh>`__
|
272 |
+
* - IR
|
273 |
+
- Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
|
274 |
+
- 67.0
|
275 |
+
- 88.9
|
276 |
+
- 93.3
|
277 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_flickr.sh>`__
|
278 |
+
|
279 |
+
.. list-table::
|
280 |
+
:widths: 20 20 20
|
281 |
+
:header-rows: 1
|
282 |
+
|
283 |
+
* - **Multimodal Classification**
|
284 |
+
- **val**
|
285 |
+
- **Evaluation**
|
286 |
+
* - ImageNet
|
287 |
+
- 76.5
|
288 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_zs_imnet.sh>`__
|
289 |
+
|
290 |
+
|
291 |
+
ALPRO
|
292 |
+
*******
|
293 |
+
.. list-table::
|
294 |
+
:widths: 30 40 20 20 20 20 30
|
295 |
+
:header-rows: 1
|
296 |
+
|
297 |
+
* - **Tasks**
|
298 |
+
- **Retrieval**
|
299 |
+
- **R1**
|
300 |
+
- **R5**
|
301 |
+
- **R10**
|
302 |
+
- **Training**
|
303 |
+
- **Evaluation**
|
304 |
+
* - TR
|
305 |
+
- MSRVTT (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_msrvtt.py>`__)
|
306 |
+
- 33.2
|
307 |
+
- 60.5
|
308 |
+
- 71.7
|
309 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_ret.sh>`__
|
310 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_ret.sh>`__
|
311 |
+
* - VR
|
312 |
+
- MSRVTT (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_msrvtt.py>`__)
|
313 |
+
- 33.8
|
314 |
+
- 61.4
|
315 |
+
- 72.7
|
316 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_ret.sh>`__
|
317 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_ret.sh>`__
|
318 |
+
* - TR
|
319 |
+
- DiDeMo (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_didemo.py>`__)
|
320 |
+
- 38.8
|
321 |
+
- 66.4
|
322 |
+
- 76.8
|
323 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_didemo_ret.sh>`__
|
324 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_didemo_ret.sh>`__
|
325 |
+
* - VR
|
326 |
+
- DiDeMo (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_didemo.py>`__)
|
327 |
+
- 36.6
|
328 |
+
- 67.5
|
329 |
+
- 77.9
|
330 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_didemo_ret.sh>`__
|
331 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_didemo_ret.sh>`__
|
332 |
+
|
333 |
+
.. list-table::
|
334 |
+
:widths: 20 20 20 20
|
335 |
+
:header-rows: 1
|
336 |
+
|
337 |
+
* - **Video QA**
|
338 |
+
- **test**
|
339 |
+
- **Training**
|
340 |
+
- **Evaluation**
|
341 |
+
* - MSRVTT
|
342 |
+
- 42.1
|
343 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_qa.sh>`__
|
344 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_qa.sh>`__
|
345 |
+
* - MSVD
|
346 |
+
- 46.0
|
347 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msvd_qa.sh>`__
|
348 |
+
- `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msvd_qa.sh>`__
|
docs/build_docs.sh
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
set -euo pipefail
|
3 |
+
|
4 |
+
# Change to root directory of repo
|
5 |
+
DIRNAME=$(cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
|
6 |
+
cd "${DIRNAME}/.."
|
7 |
+
|
8 |
+
# # Set up virtual environment
|
9 |
+
pip3 install setuptools wheel virtualenv
|
10 |
+
if [ ! -d venv ]; then
|
11 |
+
rm -f venv
|
12 |
+
virtualenv venv
|
13 |
+
fi
|
14 |
+
source venv/bin/activate
|
15 |
+
|
16 |
+
# # Get current git branch & stash unsaved changes
|
17 |
+
GIT_BRANCH=$(git branch --show-current)
|
18 |
+
if [ -z "${GIT_BRANCH}" ]; then
|
19 |
+
GIT_BRANCH="main"
|
20 |
+
fi
|
21 |
+
git stash
|
22 |
+
|
23 |
+
# Set up exit handler to restore git state & delete temp branches
|
24 |
+
# function exit_handler {
|
25 |
+
# git reset --hard
|
26 |
+
# git checkout "${GIT_BRANCH}" --
|
27 |
+
# git stash pop || true
|
28 |
+
# for version in $(git tag --list 'v[0-9]*'); do
|
29 |
+
# branch="${version}_local_docs_only"
|
30 |
+
# if git show-ref --verify --quiet "refs/heads/$branch"; then
|
31 |
+
# git branch -D "$branch"
|
32 |
+
# fi
|
33 |
+
# done
|
34 |
+
# }
|
35 |
+
# trap exit_handler EXIT
|
36 |
+
|
37 |
+
# Clean up build directory and install Sphinx requirements
|
38 |
+
pip3 install -r "${DIRNAME}/requirements.txt"
|
39 |
+
sphinx-build -M clean "${DIRNAME}" "${DIRNAME}/_build"
|
40 |
+
|
41 |
+
# Build API docs for current head
|
42 |
+
export current_version="latest"
|
43 |
+
pip3 install "."
|
44 |
+
sphinx-build -b html "${DIRNAME}" "${DIRNAME}/_build/html/${current_version}" -W --keep-going
|
45 |
+
rm -rf "${DIRNAME}/_build/html/${current_version}/.doctrees"
|
46 |
+
#pip3 uninstall -y omnixai
|
47 |
+
|
48 |
+
# Install all previous released versions
|
49 |
+
# and use them to build the appropriate API docs.
|
50 |
+
# Uninstall after we're done with each one.
|
51 |
+
# versions=()
|
52 |
+
# checkout_files=("${DIRNAME}/*.rst" "lavis" "tutorials" "setup.py")
|
53 |
+
# for version in $(git tag --list 'v[0-9]*'); do
|
54 |
+
# versions+=("$version")
|
55 |
+
# git checkout -b "${version}_local_docs_only"
|
56 |
+
# for f in $(git diff --name-only --diff-filter=A "tags/${version}" "${DIRNAME}/*.rst"); do
|
57 |
+
# git rm "$f"
|
58 |
+
# done
|
59 |
+
# git checkout "tags/${version}" -- "${checkout_files[@]}"
|
60 |
+
# export current_version=${version}
|
61 |
+
# pip3 install ".[all]"
|
62 |
+
# sphinx-build -b html "${DIRNAME}" "${DIRNAME}/_build/html/${current_version}" -W --keep-going
|
63 |
+
# rm -rf "${DIRNAME}/_build/html/${current_version}/.doctrees"
|
64 |
+
# #pip3 uninstall -y omnixai
|
65 |
+
# git reset --hard
|
66 |
+
# git checkout "${GIT_BRANCH}" --
|
67 |
+
# done
|
68 |
+
|
69 |
+
# Determine the latest stable version if there is one
|
70 |
+
# if (( ${#versions[@]} > 0 )); then
|
71 |
+
# stable_hash=$(git rev-list --tags --max-count=1)
|
72 |
+
# stable_version=$(git describe --tags "$stable_hash")
|
73 |
+
# export stable_version
|
74 |
+
# else
|
75 |
+
export stable_version="latest"
|
76 |
+
# fi
|
77 |
+
|
78 |
+
# Create dummy HTML's for the stable version in the base directory
|
79 |
+
while read -r filename; do
|
80 |
+
filename=$(echo "$filename" | sed "s/\.\///")
|
81 |
+
n_sub=$(echo "$filename" | (grep -o "/" || true) | wc -l)
|
82 |
+
prefix=""
|
83 |
+
for (( i=0; i<n_sub; i++ )); do
|
84 |
+
prefix+="../"
|
85 |
+
done
|
86 |
+
url="${prefix}${stable_version}/$filename"
|
87 |
+
mkdir -p "${DIRNAME}/_build/html/$(dirname "$filename")"
|
88 |
+
cat > "${DIRNAME}/_build/html/$filename" <<EOF
|
89 |
+
<!DOCTYPE html>
|
90 |
+
<html>
|
91 |
+
<head>
|
92 |
+
<title>LAVIS Documentation</title>
|
93 |
+
<meta http-equiv = "refresh" content="0; url='$url'" />
|
94 |
+
</head>
|
95 |
+
<body>
|
96 |
+
<p>Please wait while you're redirected to our <a href="$url">documentation</a>.</p>
|
97 |
+
</body>
|
98 |
+
</html>
|
99 |
+
EOF
|
100 |
+
done < <(cd "${DIRNAME}/_build/html/$stable_version" && find . -name "*.html")
|
101 |
+
echo "Finished writing to _build/html."
|
docs/conf.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration file for the Sphinx documentation builder.
|
2 |
+
#
|
3 |
+
# This file only contains a selection of the most common options. For a full
|
4 |
+
# list see the documentation:
|
5 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
6 |
+
|
7 |
+
# -- Path setup --------------------------------------------------------------
|
8 |
+
|
9 |
+
# If extensions (or modules to document with autodoc) are in another directory,
|
10 |
+
# add these directories to sys.path here. If the directory is relative to the
|
11 |
+
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
12 |
+
#
|
13 |
+
# import os
|
14 |
+
# import sys
|
15 |
+
# sys.path.insert(0, os.path.abspath('.'))
|
16 |
+
|
17 |
+
|
18 |
+
# -- Project information -----------------------------------------------------
|
19 |
+
|
20 |
+
project = "LAVIS"
|
21 |
+
copyright = "2022, salesforce.com inc."
|
22 |
+
author = (
|
23 |
+
"Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven C.H. Hoi"
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
# -- General configuration ---------------------------------------------------
|
28 |
+
|
29 |
+
# Add any Sphinx extension module names here, as strings. They can be
|
30 |
+
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
31 |
+
# ones.
|
32 |
+
extensions = ["nbsphinx"]
|
33 |
+
|
34 |
+
# Add any paths that contain templates here, relative to this directory.
|
35 |
+
templates_path = ["_templates"]
|
36 |
+
|
37 |
+
# List of patterns, relative to source directory, that match files and
|
38 |
+
# directories to ignore when looking for source files.
|
39 |
+
# This pattern also affects html_static_path and html_extra_path.
|
40 |
+
exclude_patterns = []
|
41 |
+
|
42 |
+
|
43 |
+
# -- Options for HTML output -------------------------------------------------
|
44 |
+
|
45 |
+
# The theme to use for HTML and HTML Help pages. See the documentation for
|
46 |
+
# a list of builtin themes.
|
47 |
+
#
|
48 |
+
# html_theme = "alabaster"
|
49 |
+
html_theme = "sphinx_rtd_theme"
|
50 |
+
|
51 |
+
# Add any paths that contain custom static files (such as style sheets) here,
|
52 |
+
# relative to this directory. They are copied after the builtin static files,
|
53 |
+
# so a file named "default.css" will overwrite the builtin "default.css".
|
54 |
+
html_static_path = ["_static"]
|
55 |
+
|
56 |
+
# pygments_style = "sphinx"
|
docs/getting_started.rst
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Dataset Zoo
|
2 |
+
##################
|
3 |
+
LAVIS inherently supports a wide variety of common language-vision datasets by providing automatic download scripts to help download and organize these datasets;
|
4 |
+
and implements PyTorch datasets for these datasets. To view supported datasets, use the following code:
|
5 |
+
|
6 |
+
.. code-block:: python
|
7 |
+
|
8 |
+
from lavis.datasets.builders import dataset_zoo
|
9 |
+
dataset_names = dataset_zoo.get_names()
|
10 |
+
print(dataset_names)
|
11 |
+
# ['aok_vqa', 'coco_caption', 'coco_retrieval', 'coco_vqa', 'conceptual_caption_12m',
|
12 |
+
# 'conceptual_caption_3m', 'didemo_retrieval', 'flickr30k', 'imagenet', 'laion2B_multi',
|
13 |
+
# 'msrvtt_caption', 'msrvtt_qa', 'msrvtt_retrieval', 'msvd_caption', 'msvd_qa', 'nlvr',
|
14 |
+
# 'nocaps', 'ok_vqa', 'sbu_caption', 'snli_ve', 'vatex_caption', 'vg_caption', 'vg_vqa']
|
15 |
+
print(len(dataset_names))
|
16 |
+
# 23
|
17 |
+
|
18 |
+
|
19 |
+
Auto-Downloading and Loading Datasets
|
20 |
+
######################################
|
21 |
+
We now take COCO caption dataset as an example to demonstrate how to download and prepare the dataset.
|
22 |
+
|
23 |
+
In ``lavis/datasets/download_scripts/``, we provide tools to download most common public language-vision datasets supported by LAVIS.
|
24 |
+
The COCO caption dataset uses images from COCO dataset. Therefore, we first download COCO images via:
|
25 |
+
|
26 |
+
.. code-block:: bash
|
27 |
+
|
28 |
+
cd lavis/datasets/download_scripts/ && python download_coco.py
|
29 |
+
|
30 |
+
This will automatically download and extract COCO images to the default LAVIS cache location.
|
31 |
+
The default cache location is ``~/.cache/lavis``, defined in ``lavis/configs/default.yaml``.
|
32 |
+
|
33 |
+
After downloading the images, we can use ``load_dataset()`` to obtain the dataset. On the first run, this will automatically download and cache annotation files.
|
34 |
+
|
35 |
+
.. code-block:: python
|
36 |
+
|
37 |
+
from lavis.datasets.builders import load_dataset
|
38 |
+
coco_dataset = load_dataset("coco_caption")
|
39 |
+
|
40 |
+
print(coco_dataset.keys())
|
41 |
+
# dict_keys(['train', 'val', 'test'])
|
42 |
+
|
43 |
+
print(len(coco_dataset["train"]))
|
44 |
+
# 566747
|
45 |
+
|
46 |
+
print(coco_dataset["train"][0])
|
47 |
+
# {'image': <PIL.Image.Image image mode=RGB size=640x480>,
|
48 |
+
# 'text_input': 'A woman wearing a net on her head cutting a cake. ',
|
49 |
+
# 'image_id': 0}
|
50 |
+
|
51 |
+
If you already host a local copy of the dataset, you can pass in the ``vis_path`` argument to change the default location to load images.
|
52 |
+
|
53 |
+
.. code-block:: python
|
54 |
+
|
55 |
+
coco_dataset = load_dataset("coco_caption", vis_path=YOUR_LOCAL_PATH)
|
56 |
+
|
57 |
+
|
58 |
+
Model Zoo
|
59 |
+
####################################
|
60 |
+
LAVIS supports a growing list of pre-trained models for different tasks,
|
61 |
+
datatsets and of varying sizes. Let's get started by viewing the supported models.
|
62 |
+
|
63 |
+
.. code-block:: python
|
64 |
+
|
65 |
+
from lavis.models import model_zoo
|
66 |
+
print(model_zoo)
|
67 |
+
# ==================================================
|
68 |
+
# Architectures Types
|
69 |
+
# ==================================================
|
70 |
+
# albef_classification base, ve
|
71 |
+
# albef_nlvr base
|
72 |
+
# albef_pretrain base
|
73 |
+
# albef_retrieval base, coco, flickr
|
74 |
+
# albef_vqa base, vqav2
|
75 |
+
# alpro_qa base, msrvtt, msvd
|
76 |
+
# alpro_retrieval base, msrvtt, didemo
|
77 |
+
# blip_caption base, base_coco, large, large_coco
|
78 |
+
# blip_classification base
|
79 |
+
# blip_feature_extractor base
|
80 |
+
# blip_nlvr base
|
81 |
+
# blip_pretrain base
|
82 |
+
# blip_retrieval base, coco, flickr
|
83 |
+
# blip_vqa base, vqav2
|
84 |
+
# clip ViT-B-32, ViT-B-16, ViT-L-14, ViT-L-14-336, RN50
|
85 |
+
|
86 |
+
# show total number of support model variants
|
87 |
+
len(model_zoo)
|
88 |
+
# 33
|
89 |
+
|
90 |
+
|
91 |
+
Inference with Pre-trained Models
|
92 |
+
####################################
|
93 |
+
|
94 |
+
Now let's see how to use models in LAVIS to perform inference on example data. We first
|
95 |
+
load a sample image from local.
|
96 |
+
|
97 |
+
.. code-block:: python
|
98 |
+
|
99 |
+
from PIL import Image
|
100 |
+
|
101 |
+
# setup device to use
|
102 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
103 |
+
|
104 |
+
# load sample image
|
105 |
+
raw_image = Image.open("docs/_static/merlion.png").convert("RGB")
|
106 |
+
|
107 |
+
This example image shows `Merlion park <https://en.wikipedia.org/wiki/Merlion>`_ (`image credit <https://theculturetrip.com/asia/singapore/articles/what-exactly-is-singapores-merlion-anyway/>`_), a landmark in Singapore.
|
108 |
+
|
109 |
+
.. image:: _static/merlion.png
|
110 |
+
|
111 |
+
Image Captioning
|
112 |
+
*******************************
|
113 |
+
We now use the BLIP model to generate a caption for the image. To make inference even easier, we also associate each
|
114 |
+
pre-trained model with its preprocessors (transforms), we use ``load_model_and_preprocess()`` with the following arguments:
|
115 |
+
|
116 |
+
- ``name``: The name of the model to load. This could be a pre-trained model, task model, or feature extractor. See ``model_zoo`` for a full list of model names.
|
117 |
+
- ``model_type``: Each architecture has variants trained on different datasets and at different scale. See Types column in ``model_zoo`` for a full list of model types.
|
118 |
+
- ``is_eval``: if `True`, set the model to evaluation mode. This is desired for inference or feature extraction.
|
119 |
+
- ``device``: device to load the model to.
|
120 |
+
|
121 |
+
.. code-block:: python
|
122 |
+
|
123 |
+
from lavis.models import load_model_and_preprocess
|
124 |
+
# loads BLIP caption base model, with finetuned checkpoints on MSCOCO captioning dataset.
|
125 |
+
# this also loads the associated image processors
|
126 |
+
model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=device)
|
127 |
+
|
128 |
+
# preprocess the image
|
129 |
+
# vis_processors stores image transforms for "train" and "eval" (validation / testing / inference)
|
130 |
+
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
|
131 |
+
|
132 |
+
# generate caption
|
133 |
+
model.generate({"image": image})
|
134 |
+
# ['a large fountain spewing water into the air']
|
135 |
+
|
136 |
+
|
137 |
+
You may also load models and their preprocessors separately via ``load_model()`` and ``load_processor()``.
|
138 |
+
In BLIP, you can also generate diverse captions by turning nucleus sampling on.
|
139 |
+
|
140 |
+
.. code-block:: python
|
141 |
+
|
142 |
+
from lavis.processors import load_processor
|
143 |
+
from lavis.models import load_model
|
144 |
+
|
145 |
+
# load image preprocesser used for BLIP
|
146 |
+
vis_processor = load_processor("blip_image_eval").build(image_size=384)
|
147 |
+
model = load_model(name="blip_caption", model_type="base_coco", is_eval=True, device=device)
|
148 |
+
|
149 |
+
image = vis_processor(image).unsqueeze(0).to(device)
|
150 |
+
model.generate({"image": raw_image}, use_nucleus_sampling=True)
|
151 |
+
# one generated random sample: ['some very pretty buildings and some water jets']
|
152 |
+
|
153 |
+
|
154 |
+
Visual question answering (VQA)
|
155 |
+
*******************************
|
156 |
+
BLIP model is able to answer free-form questions about images in natural language.
|
157 |
+
To access the VQA model, simply replace the ``name`` and ``model_type`` arguments
|
158 |
+
passed to ``load_model_and_preprocess()``.
|
159 |
+
|
160 |
+
.. code-block:: python
|
161 |
+
|
162 |
+
from lavis.models import load_model_and_preprocess
|
163 |
+
model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_vqa", model_type="vqav2", is_eval=True, device=device)
|
164 |
+
|
165 |
+
# ask a random question.
|
166 |
+
question = "Which city is this photo taken?"
|
167 |
+
|
168 |
+
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
|
169 |
+
question = txt_processors["eval"](question)
|
170 |
+
|
171 |
+
model.predict_answers(samples={"image": image, "text_input": question}, inference_method="generate")
|
172 |
+
# ['singapore']
|
173 |
+
|
174 |
+
|
175 |
+
Unified Feature Extraction Interface
|
176 |
+
####################################
|
177 |
+
|
178 |
+
LAVIS provides a unified interface to extract multimodal features from each architecture.
|
179 |
+
To extract features, we load the feature extractor variants of each model.
|
180 |
+
The multimodal feature can be used for multimodal classification. The low-dimensional unimodal features can be used to compute cross-modal similarity.
|
181 |
+
|
182 |
+
.. code-block:: python
|
183 |
+
|
184 |
+
from lavis.models import load_model_and_preprocess
|
185 |
+
|
186 |
+
model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_feature_extractor", model_type="base", is_eval=True, device=device)
|
187 |
+
caption = "a large fountain spewing water into the air"
|
188 |
+
|
189 |
+
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
|
190 |
+
text_input = txt_processors["eval"](caption)
|
191 |
+
|
192 |
+
sample = {"image": image, "text_input": [text_input]}
|
193 |
+
|
194 |
+
features_multimodal = model.extract_features(sample)
|
195 |
+
print(features_multimodal.keys())
|
196 |
+
# odict_keys(['image_embeds', 'multimodal_embeds'])
|
197 |
+
print(features_multimodal.multimodal_embeds.shape)
|
198 |
+
# torch.Size([1, 12, 768]), use features_multimodal[:, 0, :] for multimodal classification tasks
|
199 |
+
|
200 |
+
features_image = model.extract_features(sample, mode="image")
|
201 |
+
print(features_image.keys())
|
202 |
+
# odict_keys(['image_embeds', 'image_embeds_proj'])
|
203 |
+
print(features_image.image_embeds.shape)
|
204 |
+
# torch.Size([1, 197, 768])
|
205 |
+
print(features_image.image_embeds_proj.shape)
|
206 |
+
# torch.Size([1, 197, 256])
|
207 |
+
|
208 |
+
features_text = model.extract_features(sample, mode="text")
|
209 |
+
print(features_text.keys())
|
210 |
+
# odict_keys(['text_embeds', 'text_embeds_proj'])
|
211 |
+
print(features_text.text_embeds.shape)
|
212 |
+
# torch.Size([1, 12, 768])
|
213 |
+
print(features_text.text_embeds_proj.shape)
|
214 |
+
# torch.Size([1, 12, 256])
|
215 |
+
|
216 |
+
similarity = features_image.image_embeds_proj[:, 0, :] @ features_text.text_embeds_proj[:, 0, :].t()
|
217 |
+
print(similarity)
|
218 |
+
# tensor([[0.2622]])
|
219 |
+
|
220 |
+
Since LAVIS supports a unified feature extraction interface, minimal changes are necessary to use a different model as feature extractor. For example,
|
221 |
+
to use ALBEF as the feature extractor, one only needs to change the following line:
|
222 |
+
|
223 |
+
.. code-block:: python
|
224 |
+
|
225 |
+
model, vis_processors, txt_processors = load_model_and_preprocess(name="albef_feature_extractor", model_type="base", is_eval=True, device=device)
|
226 |
+
|
227 |
+
Similarly, to use CLIP as feature extractor:
|
228 |
+
|
229 |
+
.. code-block:: python
|
230 |
+
|
231 |
+
model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="base", is_eval=True, device=device)
|
232 |
+
# model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="RN50", is_eval=True, device=device)
|
233 |
+
# model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="ViT-L-14", is_eval=True, device=device)
|
docs/index.rst
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. LAVIS documentation master file, created by
|
2 |
+
sphinx-quickstart on Sun Jul 31 10:32:27 2022.
|
3 |
+
You can adapt this file completely to your liking, but it should at least
|
4 |
+
contain the root `toctree` directive.
|
5 |
+
|
6 |
+
Welcome to LAVIS's documentation!
|
7 |
+
=================================
|
8 |
+
|
9 |
+
.. toctree::
|
10 |
+
:maxdepth: 1
|
11 |
+
:caption: Introduction
|
12 |
+
|
13 |
+
intro
|
14 |
+
|
15 |
+
|
16 |
+
.. toctree::
|
17 |
+
:maxdepth: 1
|
18 |
+
:caption: Getting Started
|
19 |
+
|
20 |
+
getting_started
|
21 |
+
|
22 |
+
|
23 |
+
.. :maxdepth: 1
|
24 |
+
.. :caption: Advanced Training
|
25 |
+
|
26 |
+
.. advanced_training
|
27 |
+
|
28 |
+
|
29 |
+
.. toctree::
|
30 |
+
:maxdepth: 2
|
31 |
+
:caption: Advanced Usage
|
32 |
+
|
33 |
+
benchmark
|
34 |
+
tutorial
|
35 |
+
|
36 |
+
|
37 |
+
.. Documentations
|
38 |
+
.. ===================
|
39 |
+
|
40 |
+
|
41 |
+
Indices and tables
|
42 |
+
==================
|
43 |
+
|
44 |
+
* :ref:`genindex`
|
45 |
+
* :ref:`modindex`
|
46 |
+
* :ref:`search`
|
docs/intro.rst
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
What is LAVIS?
|
2 |
+
####################################
|
3 |
+
|
4 |
+
LAVIS is a Python deep learning library for LAnguage-and-VISion research and applications.
|
5 |
+
It features a unified design to access state-of-the-art foundation language-vision models (`ALBEF <https://arxiv.org/pdf/2107.07651.pdf>`_,
|
6 |
+
`BLIP <https://arxiv.org/pdf/2201.12086.pdf>`_, `ALPRO <https://arxiv.org/pdf/2112.09583.pdf>`_, `CLIP <https://arxiv.org/pdf/2103.00020.pdf>`_), common tasks
|
7 |
+
(retrieval, captioning, visual question answering, multimodal classification etc.) and datasets (COCO, Flickr, Nocaps, Conceptual
|
8 |
+
Commons, SBU, etc.).
|
9 |
+
|
10 |
+
This library aims to provide engineers and researchers with a one-stop solution to rapidly develop models for their specific multimodal
|
11 |
+
scenarios, and benchmark them across standard and customized datasets.
|
12 |
+
|
13 |
+
Key features of LAVIS include:
|
14 |
+
|
15 |
+
- **Modular and Extensible Library Design**: facilitating to easily utilize and repurpose existing modules (datasets, models, preprocessors), also to add new modules.
|
16 |
+
|
17 |
+
- **Easy Off-the-shelf Inference and Feature Extraction**: readily available pre-trained models let you take advantage of state-of-the-art multimodal understanding and generation capabilities on your own data.
|
18 |
+
|
19 |
+
- **Reproducible Model Zoo**: provided training/pre-training recipies to easily replicate and extend state-of-the-art models.
|
20 |
+
|
21 |
+
- **Dataset Zoo and Automatic Downloading Tools**: it can be a hassle to prepare the many language-vision datasets. LAVIS provides automatic downloaing scripts to help prepare a large variety of datasets and their annotations.
|
22 |
+
|
23 |
+
Other features include:
|
24 |
+
|
25 |
+
- **Distributed Training** using multiple GPUs on one machine or across multiple machines.
|
26 |
+
|
27 |
+
- **Web Demo**: try supported models on your own pictures, questions etc.
|
28 |
+
|
29 |
+
- **Leaderboard**: comparing state-of-the-art models across standard datasets.
|
30 |
+
|
31 |
+
- **Dataset Explorer**: help browse and understand language-vision datasets.
|
32 |
+
|
33 |
+
Supported Tasks, Models and Datasets
|
34 |
+
####################################
|
35 |
+
|
36 |
+
The following table shows the supported models and language-vision tasks by LAVIS. Adapting existing models to more tasks is possible and next to come in future releases.
|
37 |
+
|
38 |
+
======================================== =========================== ============================================= ============
|
39 |
+
Tasks Supported Models Supported Datasets Modalities
|
40 |
+
======================================== =========================== ============================================= ============
|
41 |
+
Image-text Pre-training ALBEF, BLIP COCO, VisualGenome, SBU, ConceptualCaptions image, text
|
42 |
+
Image-text Retrieval ALBEF, BLIP, CLIP COCO, Flickr30k image, text
|
43 |
+
Text-image Retrieval ALBEF, BLIP, CLIP COCO, Flickr30k image, text
|
44 |
+
Visual Question Answering ALBEF, BLIP VQAv2, OKVQA, A-OKVQA image, text
|
45 |
+
Image Captioning BLIP COCO, NoCaps image, text
|
46 |
+
Image Classification CLIP ImageNet image
|
47 |
+
Natural Language Visual Reasoning (NLVR) ALBEF, BLIP NLVR2 image, text
|
48 |
+
Visual Entailment (VE) ALBEF SNLI-VE image, text
|
49 |
+
Visual Dialogue BLIP VisDial image, text
|
50 |
+
Video-text Retrieval BLIP, ALPRO MSRVTT, DiDeMo video, text
|
51 |
+
Text-video Retrieval BLIP, ALPRO MSRVTT, DiDeMo video, text
|
52 |
+
Video Question Answering (VideoQA) BLIP, ALPRO MSRVTT, MSVD video, text
|
53 |
+
Video Dialogue VGD-GPT AVSD video, text
|
54 |
+
Multimodal Feature Extraction ALBEF, CLIP, BLIP, ALPRO customized image, text
|
55 |
+
======================================== =========================== ============================================= ============
|
56 |
+
|
57 |
+
Library Design
|
58 |
+
####################################
|
59 |
+
|
60 |
+
.. image:: _static/architecture.png
|
61 |
+
:width: 550
|
62 |
+
|
63 |
+
LAVIS has six key modules.
|
64 |
+
|
65 |
+
- ``lavis.runners`` manages the overall training and evaluation lifecycle. It is also responsible for creating required components lazily as per demand, such as optimizers, learning rate schedulers and dataloaders. Currently ``RunnerBase`` implements epoch-based training and ``RunerIters`` implements iteration-based training.
|
66 |
+
- ``lavis.tasks`` implements concrete training and evaluation logic per task. A task could be, for example, retrieval, captioning, pre-training. The rationale to have an abstraction of task is to accommodate task-specific training and evaluation. For example, evaluating a retrieval model is different from a classification model.
|
67 |
+
- ``lavis.datasets`` is responsible for creating datasets, where ``lavis.datasets.builders`` loads dataset configurations, downloads annotations and returns a dataset object; ``lavis.datasets.datasets`` defines the supported datasets, each is a ``torch.utils.data.Dataset`` instance. We also provide `automatic dataset downloading tools` in ``datasets/download_scripts`` to help prepare common public datasets.
|
68 |
+
- ``lavis.models`` holds definition for the supported models and shared model layers.
|
69 |
+
- ``lavis.processors`` handles preprocessing of text and images/videos before feeding the model. For images and videos, a processor can be thought as transfroms in torchvision; for text input, this may include lowering case, truncation etc.
|
70 |
+
- ``lavis.common`` module contains shared classes and methods used by multiple other modules. For example,
|
71 |
+
|
72 |
+
- ``lavis.common.config`` contains classes to store and manipulate configuration files used by LAVIS. In particular, we use a hierarchical configuration design, to allow highly customizable training and evaluation.
|
73 |
+
- ``lavis.common.registry`` serves as a centralized place to manage modules that share the same functionalities. It allows building datasets, models, tasks, and learning rate schedulers during runtime, by specifying their names as string in the configuration file.
|
74 |
+
- ``lavis.common.optims`` contains definitions of learning rate schedulers.
|
75 |
+
- ``lavis.common.dist_utils`` contains utilities for distributed training and evaluation.
|
76 |
+
- ``lavis.common.utils`` contains miscellaneous utilities, mostly IO-related helper functions.
|
77 |
+
|
78 |
+
|
79 |
+
Installation
|
80 |
+
############
|
81 |
+
1. (Optional) Creating conda environment
|
82 |
+
|
83 |
+
.. code-block:: bash
|
84 |
+
|
85 |
+
conda create -n lavis python=3.8
|
86 |
+
conda activate lavis
|
87 |
+
|
88 |
+
2. Cloning and building from source
|
89 |
+
|
90 |
+
.. code-block:: bash
|
91 |
+
|
92 |
+
git clone https://github.com/salesforce/LAVIS.git
|
93 |
+
cd LAVIS
|
94 |
+
pip install .
|
95 |
+
|
96 |
+
If you would like to develop on LAVIS, you may find it easier to build with editable mode::
|
97 |
+
|
98 |
+
pip install -e .
|
99 |
+
|
docs/make.bat
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@ECHO OFF
|
2 |
+
|
3 |
+
pushd %~dp0
|
4 |
+
|
5 |
+
REM Command file for Sphinx documentation
|
6 |
+
|
7 |
+
if "%SPHINXBUILD%" == "" (
|
8 |
+
set SPHINXBUILD=sphinx-build
|
9 |
+
)
|
10 |
+
set SOURCEDIR=source
|
11 |
+
set BUILDDIR=build
|
12 |
+
|
13 |
+
if "%1" == "" goto help
|
14 |
+
|
15 |
+
%SPHINXBUILD% >NUL 2>NUL
|
16 |
+
if errorlevel 9009 (
|
17 |
+
echo.
|
18 |
+
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
19 |
+
echo.installed, then set the SPHINXBUILD environment variable to point
|
20 |
+
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
21 |
+
echo.may add the Sphinx directory to PATH.
|
22 |
+
echo.
|
23 |
+
echo.If you don't have Sphinx installed, grab it from
|
24 |
+
echo.http://sphinx-doc.org/
|
25 |
+
exit /b 1
|
26 |
+
)
|
27 |
+
|
28 |
+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
29 |
+
goto end
|
30 |
+
|
31 |
+
:help
|
32 |
+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
33 |
+
|
34 |
+
:end
|
35 |
+
popd
|
docs/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GitPython
|
2 |
+
ipykernel
|
3 |
+
nbsphinx==0.8.7
|
4 |
+
pandoc
|
5 |
+
sphinx
|
6 |
+
sphinx_autodoc_typehints
|
7 |
+
sphinx_rtd_theme
|
docs/tutorial.configs.rst
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. _config:
|
2 |
+
|
3 |
+
Training Models on Task Datasets (Commands and Configurations)
|
4 |
+
#################################################################
|
5 |
+
|
6 |
+
LAVIS provides scripts to pre-train and finetune supported models on standard language-vision tasks, stored at ``lavis/run_scripts/``.
|
7 |
+
To replicate the experiments, just run these bash scripts. For example, to train BLIP model on the image-text retrieval task with MSCOCO dataset, we can run
|
8 |
+
|
9 |
+
.. code-block::
|
10 |
+
|
11 |
+
bash run_scripts/blip/train/train_retrieval_coco.sh
|
12 |
+
|
13 |
+
Inside the scripts, we can see
|
14 |
+
|
15 |
+
.. code-block:: bash
|
16 |
+
|
17 |
+
python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/retrieval_coco_ft.yaml
|
18 |
+
|
19 |
+
where we start a pytorch distributed training on 8 GPUs (you may change according to your own hardware setup). The ``--cfg-path`` specifys a `runtime configuration file`, specifying
|
20 |
+
the task, model, dataset and training recipes.
|
21 |
+
|
22 |
+
Available options and their descriptions are as below.
|
23 |
+
|
24 |
+
.. LAVIS executes training and evaluation based on arguments specified in the configuration files. The default model and dataset configurations are defined in ``lavis/configs``. The task-specific configurations are defined in ``lavis/projects``. Task-specific configurations have higher priority over the default configurations.
|
25 |
+
|
26 |
+
.. The following tables provide explanations for the arguments in the configuration files.
|
27 |
+
|
28 |
+
.. list-table::
|
29 |
+
:widths: 30 40
|
30 |
+
:header-rows: 1
|
31 |
+
|
32 |
+
* - Model Configurations
|
33 |
+
- Functionalities
|
34 |
+
* - arch
|
35 |
+
- | name of the model from the model zoo
|
36 |
+
| default: task-dependent
|
37 |
+
* - model_type
|
38 |
+
- | the type of the model (e.g., base)
|
39 |
+
| default: task-dependent
|
40 |
+
* - load_pretrained
|
41 |
+
- | load pretrained weights
|
42 |
+
| default: True (for finetuning task) | False (for pretraining task)
|
43 |
+
* - load_finetuned
|
44 |
+
- | load task-specific finetuned weights
|
45 |
+
| default: False (for finetuning task) | True (for evaluation)
|
46 |
+
* - pretrained
|
47 |
+
- | URL or local path which stores the pretrained model, defined in the default model configuration file
|
48 |
+
| default: task-dependent
|
49 |
+
* - finetuned
|
50 |
+
- | URL or local path which stores the finetuned model, defined in the default model configuration file
|
51 |
+
| default: task-dependent
|
52 |
+
|
53 |
+
.. list-table::
|
54 |
+
:widths: 30 50
|
55 |
+
:header-rows: 1
|
56 |
+
|
57 |
+
* - Dataset Configurations
|
58 |
+
- Functionalities
|
59 |
+
* - vis_processor
|
60 |
+
- | pre-processing of visual input
|
61 |
+
| default: task-dependent
|
62 |
+
* - text_processor
|
63 |
+
- | pre-processing of text input
|
64 |
+
| default: task-dependent
|
65 |
+
* - build_info
|
66 |
+
- | dataset information including the storage location, defined in the default dataset configuration file
|
67 |
+
| default: task-dependent
|
68 |
+
|
69 |
+
.. list-table::
|
70 |
+
:widths: 30 50
|
71 |
+
:header-rows: 1
|
72 |
+
|
73 |
+
* - Runtime Configurations
|
74 |
+
- Functionalities
|
75 |
+
* - task
|
76 |
+
- | name of the task
|
77 |
+
| default: task-dependent
|
78 |
+
* - lr_sched
|
79 |
+
- | learning rate scheduler
|
80 |
+
| default: linear_warmup_cosine_lr
|
81 |
+
* - init_lr
|
82 |
+
- | initial learning rate (after warmup)
|
83 |
+
| default: task-dependent
|
84 |
+
* - min_lr
|
85 |
+
- | final learning rate after decay
|
86 |
+
| default: task-dependent
|
87 |
+
* - warmup_lr
|
88 |
+
- | starting learning rate for warmup
|
89 |
+
| default: init_lr (no warmup)
|
90 |
+
* - lr_decay_rate
|
91 |
+
- | learning rate decay per epoch for step_lr_schedule
|
92 |
+
| default: 0.9
|
93 |
+
* - warmup_steps
|
94 |
+
- | number of steps for learning rate warmup
|
95 |
+
| default: 0
|
96 |
+
* - max_epoch
|
97 |
+
- | total number of training epochs
|
98 |
+
| default: task-dependent
|
99 |
+
* - weight_decay
|
100 |
+
- | weight decay coefficient for the optimizer
|
101 |
+
| default: 0.05
|
102 |
+
* - batch_size_train
|
103 |
+
- | batch size during training
|
104 |
+
| default: task-dependent
|
105 |
+
* - batch_size_eval
|
106 |
+
- | batch size during evaluation
|
107 |
+
| default: task-dependent
|
108 |
+
* - seed
|
109 |
+
- | pseudo-random number generator seed
|
110 |
+
| default: 42
|
111 |
+
* - output_dir
|
112 |
+
- | directory to store logs, results and checkpoints
|
113 |
+
| default: task-dependent
|
114 |
+
* - resume_ckpt_path
|
115 |
+
- | path of the checkpoint to resume training from
|
116 |
+
| default: None
|
117 |
+
* - evaluate
|
118 |
+
- | only perform evaluation without training
|
119 |
+
| default: False
|
120 |
+
* - train_splits
|
121 |
+
- | dataset splits used for training
|
122 |
+
| default: ["train"]
|
123 |
+
* - valid_splits
|
124 |
+
- | dataset splits used for validation
|
125 |
+
| default: ["val"]
|
126 |
+
* - test_splits
|
127 |
+
- | dataset splits used for testing
|
128 |
+
| default: ["test"]
|
129 |
+
* - device
|
130 |
+
- | use cpu or gpu (cuda)
|
131 |
+
| default: cuda
|
132 |
+
* - world_size
|
133 |
+
- | number of processes participating in the job
|
134 |
+
| default: 1
|
135 |
+
* - dist_url
|
136 |
+
- | URL specifying how to initialize the process group
|
137 |
+
| default: "env://"
|
138 |
+
* - distributed
|
139 |
+
- | use distributed training
|
140 |
+
| default: True
|
141 |
+
* - amp
|
142 |
+
- | use automatic mixed precision training
|
143 |
+
| default: False
|
144 |
+
|
145 |
+
.. list-table::
|
146 |
+
:widths: 40 50
|
147 |
+
:header-rows: 1
|
148 |
+
|
149 |
+
* - Text Generation Configurations
|
150 |
+
- Functionalities
|
151 |
+
* - max_len
|
152 |
+
- | maximum number of text tokens to generate
|
153 |
+
| default: 20 (for image captioning)
|
154 |
+
* - min_len
|
155 |
+
- | minimum number of text tokens to generate
|
156 |
+
| default: 5 (for image captioning)
|
157 |
+
* - num_beams
|
158 |
+
- | number of beams to perform beam search
|
159 |
+
| default: 3
|
160 |
+
|
161 |
+
.. list-table::
|
162 |
+
:widths: 40 50
|
163 |
+
:header-rows: 1
|
164 |
+
|
165 |
+
* - Multimodal Retrieval Configurations
|
166 |
+
- Functionalities
|
167 |
+
* - negative_all_rank
|
168 |
+
- | collect negatives from all processes for the image-text matching loss
|
169 |
+
| default: True (for coco)
|
170 |
+
* - k_test
|
171 |
+
- | number of retrieval candidates ranked from contrastive similarity
|
172 |
+
| default: 256 (for coco)
|
docs/tutorial.datasets.rst
ADDED
@@ -0,0 +1,424 @@
1 |
+
Adding Datasets
|
2 |
+
################################################
|
3 |
+
|
4 |
+
This is a tutorial on adding a new dataset using ``lavis.datasets`` module.
|
5 |
+
|
6 |
+
The LAVIS library includes a standard dataset module, which allows customization to add new datasets.
|
7 |
+
The ``lavis.datasets`` module is designed so that any new dataset class can be easily added and adapted from our code base; this involves creating the dataset configuration, then defining and associating new dataset classes.
|
8 |
+
|
9 |
+
In this tutorial, we will replicate the steps to add a dataset class for the `Audio-Visual Scene-Aware Dialogue (AVSD) <https://arxiv.org/pdf/1901.09107.pdf>`_ benchmark for the video-grounded dialogue task.
|
10 |
+
|
11 |
+
Dataset Configuration ``lavis.configs.datasets``
|
12 |
+
**************************************************************
|
13 |
+
|
14 |
+
First, we define the basic configuration for this dataset, including the new dataset name ``avsd_dialogue``, its dataset card, and its data type.
|
15 |
+
We can define any new dataset configuration in ``lavis.configs.datasets``. For instance, under this module, we can set up a configuration file ``avsd/defaults_dial.yaml`` as follows:
|
16 |
+
|
17 |
+
.. code-block:: yaml
|
18 |
+
|
19 |
+
datasets:
|
20 |
+
avsd_dialogue: # name of the dataset builder
|
21 |
+
dataset_card: dataset_card/avsd_dialogue.md # path to the dataset card
|
22 |
+
data_type: features # [images|videos|features] we use features in this case for extracted video features
|
23 |
+
|
24 |
+
build_info:
|
25 |
+
# Be careful not to append minus sign (-) before split to avoid itemizing
|
26 |
+
annotations:
|
27 |
+
train:
|
28 |
+
url: /export/home/data/avsd/train_set4DSTC7-AVSD.json
|
29 |
+
storage: avsd/annotations/train.json
|
30 |
+
val:
|
31 |
+
url: /export/home/data/avsd/valid_set4DSTC7-AVSD.json
|
32 |
+
storage: avsd/annotations/val.json
|
33 |
+
test:
|
34 |
+
url: /export/home/data/avsd/test_set4DSTC7-AVSD.json
|
35 |
+
storage: avsd/annotations/test.json
|
36 |
+
features:
|
37 |
+
storage: /export/home/data/avsd/features/
|
38 |
+
|
39 |
+
|
40 |
+
Dataset Card
|
41 |
+
===============
|
42 |
+
One optional step to set up dataset configuration is defining a dataset card, which contains more details about the dataset such as description, tasks, and metrics.
|
43 |
+
For instance, we can define a dataset card for the AVSD benchmark in ``dataset_card/avsd_dialogue.md``.
|
44 |
+
Depending on the dataset, its dataset card may include the command for auto-downloading data (with Python code defined in ``lavis.datasets.download_scripts``), which automatically downloads the data and stores it in a specific folder.
|
45 |
+
Otherwise, the dataset card should describe the external download instructions from the original data source so that the dataset can be loaded properly.
|
46 |
+
|
47 |
+
One example of a dataset card for the AVSD benchmark is:
|
48 |
+
|
49 |
+
.. code-block:: md
|
50 |
+
|
51 |
+
![Samples from the AVSD dataset (Image credit: "https://arxiv.org/pdf/1901.09107.pdf").](imgs/avsd_dialogue.png)(Samples from the AVSD dataset. Image credit: "https://arxiv.org/pdf/1901.09107.pdf")
|
52 |
+
|
53 |
+
# Audio-Visual Scene-Aware Dialogues (AVSD)
|
54 |
+
|
55 |
+
## Description
|
56 |
+
[Audio-Visual Scene-Aware Dialogues (AVSD)](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) contains more than 10,000 dialogues, each of which is grounded on a unique video. In the test split, for each test sample, 6 reference dialogue responses are provided.
|
57 |
+
|
58 |
+
|
59 |
+
## Task
|
60 |
+
|
61 |
+
(https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge)
|
62 |
+
|
63 |
+
In a **video-grounded dialogue task**, the system must generate responses to user input in the context of a given dialog.
|
64 |
+
This context consists of a dialog history (previous utterances by both user and system) in addition to video and audio information that comprise the scene. The quality of a system’s automatically generated sentences is evaluated using objective measures to determine whether or not the generated responses are natural and informative.
|
65 |
+
|
66 |
+
## Metrics
|
67 |
+
Models are typically evaluated according to [BLEU](https://aclanthology.org/P02-1040/), [CIDER](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.pdf), [METEOR](https://aclanthology.org/W05-0909/), and [ROUGE-L](https://aclanthology.org/W04-1013/) metrics.
|
68 |
+
|
69 |
+
## Leaderboard
|
70 |
+
|
71 |
+
....
|
72 |
+
|
73 |
+
|
74 |
+
## Auto-Downloading
|
75 |
+
|
76 |
+
Please refer to the [benchmark website](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) for instructions to download the dataset.
|
77 |
+
|
78 |
+
|
79 |
+
## References
|
80 |
+
"Audio Visual Scene-Aware Dialog", Huda Alamri, Vincent Cartillier, Abhishek Das, Jue Wang, Anoop Cherian, Irfan Essa, Dhruv Batra, Tim K. Marks, Chiori Hori, Peter Anderson, Stefan Lee, Devi Parikh
|
81 |
+
|
82 |
+
Visual Data Type
|
83 |
+
==============================
|
84 |
+
We currently limit the visual data types to one of three options: ``images``, ``videos``, and ``features``.
|
85 |
+
"Images" and "videos" refer to the raw visual data, which is appropriate for models processing visual data in their original forms (e.g. ViT models).
|
86 |
+
"Features" are visual representations extracted from pretrained models (e.g. CNN models).
|
87 |
+
In this tutorial, the AVSD benchmark consists of video features extracted from 3D-CNN models.
|
88 |
+
|
89 |
+
Build Info
|
90 |
+
==============================
|
91 |
+
Build info refers to the specific locations where data is stored and cached.
|
92 |
+
|
93 |
+
For text annotations (e.g. captioning or dialogues), by default, we include three data splits, namely "train", "val", and "test", typically used in all machine learning projects.
|
94 |
+
For each split, we specify 2 parameters: ``url`` and ``storage``.
|
95 |
+
``url`` can be either an online URL where the dataset can be loaded automatically (e.g. from *googleapis*), or a local directory where data is already downloaded beforehand.
|
96 |
+
``storage`` is the directory where the data will be cached over time, avoiding downloading data repeatedly.
|
97 |
+
|
98 |
+
For visual data annotations, ensure the field name matches the data types defined earlier (e.g. one of "images", "videos", or "features").
|
99 |
+
As visual features are usually large and should be downloaded beforehand, we maintain only a ``storage`` parameter where visual data is cached.
|
100 |
+
|
101 |
+
Dataset ``lavis.datasets.datasets``
|
102 |
+
**************************************************************
|
103 |
+
|
104 |
+
Base Dataset ``lavis.datasets.datasets.base_dataset``
|
105 |
+
=======================================================
|
106 |
+
In this step, we want to define new dataset classes that inherit our base dataset class ``lavis.datasets.datasets.base_dataset``. This base dataset class already defines standard methods such as ``collater``, which uses the default collate function from PyTorch.
|
107 |
+
|
108 |
+
.. code-block:: python
|
109 |
+
|
110 |
+
import json
|
111 |
+
from typing import Iterable
|
112 |
+
|
113 |
+
from torch.utils.data import Dataset, ConcatDataset
|
114 |
+
from torch.utils.data.dataloader import default_collate
|
115 |
+
|
116 |
+
class BaseDataset(Dataset):
|
117 |
+
def __init__(
|
118 |
+
self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
|
119 |
+
):
|
120 |
+
"""
|
121 |
+
vis_root (string): Root directory of images (e.g. coco/images/)
|
122 |
+
ann_root (string): directory to store the annotation file
|
123 |
+
"""
|
124 |
+
self.vis_root = vis_root
|
125 |
+
|
126 |
+
self.annotation = []
|
127 |
+
for ann_path in ann_paths:
|
128 |
+
self.annotation.extend(json.load(open(ann_path, "r")))
|
129 |
+
|
130 |
+
self.vis_processor = vis_processor
|
131 |
+
self.text_processor = text_processor
|
132 |
+
|
133 |
+
self._add_instance_ids()
|
134 |
+
|
135 |
+
def __len__(self):
|
136 |
+
return len(self.annotation)
|
137 |
+
|
138 |
+
def collater(self, samples):
|
139 |
+
return default_collate(samples)
|
140 |
+
|
141 |
+
def set_processors(self, vis_processor, text_processor):
|
142 |
+
self.vis_processor = vis_processor
|
143 |
+
self.text_processor = text_processor
|
144 |
+
|
145 |
+
def _add_instance_ids(self, key="instance_id"):
|
146 |
+
for idx, ann in enumerate(self.annotation):
|
147 |
+
ann[key] = str(idx)
|
148 |
+
|
149 |
+
Any dataset subclass will inherit these methods, and may optionally override them according to the specifications of the dataset.
|
150 |
+
We encourage users not to modify the base dataset class, as any modification will have cascading impacts on all other dataset classes that inherit from it.
|
151 |
+
Instead, the users should independently create new dataset classes to cater to their specific requirements.
|
152 |
+
|
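As a minimal illustration (not part of the LAVIS code base), a new subclass often only needs a ``__getitem__`` that turns one annotation entry into a training sample, while reusing everything else from ``BaseDataset``; the ``caption`` field below is hypothetical.

.. code-block:: python

    from lavis.datasets.datasets.base_dataset import BaseDataset


    class ToyCaptionDataset(BaseDataset):
        # Hypothetical subclass: __len__, collater and _add_instance_ids are
        # inherited from BaseDataset; only __getitem__ is added.
        def __getitem__(self, index):
            ann = self.annotation[index]
            return {
                "text_input": self.text_processor(ann["caption"]),
                "instance_id": ann["instance_id"],
            }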
153 |
+
Dialogue Datasets ``lavis.datasets.datasets.dialogue_datasets``
|
154 |
+
======================================================================
|
155 |
+
|
156 |
+
For example, for the AVSD dataset, we want to define a new dataset subclass ``DialogueDataset`` for dialogue tasks. We can define this dataset class in ``lavis.datasets.datasets.dialogue_datasets`` as follows:
|
157 |
+
|
158 |
+
.. code-block:: python
|
159 |
+
|
160 |
+
import os
|
161 |
+
from collections import OrderedDict
|
162 |
+
|
163 |
+
from lavis.datasets.datasets.base_dataset import BaseDataset
|
164 |
+
|
165 |
+
import json
|
166 |
+
import copy
|
167 |
+
|
168 |
+
class DialogueDataset(BaseDataset):
|
169 |
+
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
|
170 |
+
"""
|
171 |
+
vis_processor (string): visual processor
|
172 |
+
text_processor (string): textual processor
|
173 |
+
vis_root (string): Root directory of images (e.g. coco/images/)
|
174 |
+
ann_paths (list): list of paths to the annotation files
|
175 |
+
"""
|
176 |
+
|
177 |
+
self.vis_root = vis_root
|
178 |
+
|
179 |
+
self.annotation = []
|
180 |
+
for ann_path in ann_paths:
|
181 |
+
dialogs = json.load(open(ann_path, "r"))['dialogs']
|
182 |
+
for dialog in dialogs:
|
183 |
+
all_turns = dialog['dialog']
|
184 |
+
dialogue_context = []
|
185 |
+
for turn in all_turns:
|
186 |
+
dialog_instance = copy.deepcopy(dialog)
|
187 |
+
question = turn['question']
|
188 |
+
answer = turn['answer']
|
189 |
+
|
190 |
+
dialog_instance['dialog'] = copy.deepcopy(dialogue_context)
|
191 |
+
dialog_instance['question'] = question
|
192 |
+
dialog_instance['answer'] = answer
|
193 |
+
self.annotation.append(dialog_instance)
|
194 |
+
dialogue_context.append(turn)
|
195 |
+
|
196 |
+
self.vis_processor = vis_processor
|
197 |
+
self.text_processor = text_processor
|
198 |
+
|
199 |
+
self._add_instance_ids()
|
200 |
+
|
201 |
+
self.img_ids = {}
|
202 |
+
n = 0
|
203 |
+
for ann in self.annotation:
|
204 |
+
img_id = ann["image_id"]
|
205 |
+
if img_id not in self.img_ids.keys():
|
206 |
+
self.img_ids[img_id] = n
|
207 |
+
n += 1
|
208 |
+
|
209 |
+
Class inheritance allows us to define multiple subclasses. For instance, suppose we want another dialogue dataset class that is used only for the test split. We can define another dataset class ``DialogueEvalDataset``, similar to the one above, except that the annotations are processed differently.
|
210 |
+
Typically, in dialogue tasks, during test time only a single test sample is constructed per dialogue (rather than decomposing all dialogue turns into samples, as is done at training time).
|
211 |
+
The dataset class can then be defined as:
|
212 |
+
|
213 |
+
.. code-block:: python
|
214 |
+
|
215 |
+
class DialogueEvalDataset(BaseDataset):
|
216 |
+
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
|
217 |
+
# ...
|
218 |
+
# defined similarly as DialogueDataset above
|
219 |
+
# except for the loading of dialogue annotation data
|
220 |
+
|
221 |
+
self.annotation = []
|
222 |
+
for ann_path in ann_paths:
|
223 |
+
dialogs = json.load(open(ann_path, "r"))['dialogs']
|
224 |
+
for dialog in dialogs:
|
225 |
+
all_turns = dialog['dialog']
|
226 |
+
dialogue_context = all_turns[:-1]
|
227 |
+
last_turn = all_turns[-1]
|
228 |
+
|
229 |
+
question = last_turn['question']
|
230 |
+
answer = last_turn['answer']
|
231 |
+
|
232 |
+
dialog['dialog'] = dialogue_context
|
233 |
+
dialog['question'] = question
|
234 |
+
dialog['answer'] = answer
|
235 |
+
|
236 |
+
self.annotation.append(dialog)
|
237 |
+
|
238 |
+
|
239 |
+
Using class inheritance to define datasets also allows us to develop more fine-grained class implementations, each of which is specifically designed for a benchmark.
|
240 |
+
For instance, under the dialogue-based tasks, we can further define another dataset subclass that is specified for the AVSD dataset.
|
241 |
+
We can define a new class ``AVSDDialDataset`` that further specifies how to load individual samples and collate them according to specific requirements:
|
242 |
+
|
243 |
+
.. code-block:: python
|
244 |
+
|
245 |
+
import os
|
246 |
+
from lavis.datasets.datasets.base_dataset import BaseDataset
|
247 |
+
from lavis.datasets.datasets.dialogue_datasets import DialogueDataset, DialogueEvalDataset
|
248 |
+
|
249 |
+
import torch
|
250 |
+
|
251 |
+
class AVSDDialDataset(DialogueDataset):
|
252 |
+
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
|
253 |
+
|
254 |
+
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
|
255 |
+
|
256 |
+
def __getitem__(self, index):
|
257 |
+
|
258 |
+
ann = self.annotation[index]
|
259 |
+
|
260 |
+
vname = ann["image_id"]
|
261 |
+
|
262 |
+
video = self.vis_processor(self.vis_root, vname)
|
263 |
+
|
264 |
+
dialogue = self.text_processor(ann)
|
265 |
+
|
266 |
+
return {
|
267 |
+
"video_fts": video['video_fts'],
|
268 |
+
"video_token_type_ids": video['token_type_ids'],
|
269 |
+
"input_ids": dialogue['input_ids'],
|
270 |
+
"token_type_ids": dialogue['token_type_ids'],
|
271 |
+
"labels": dialogue['labels'],
|
272 |
+
"image_id": ann["image_id"],
|
273 |
+
"instance_id": ann["instance_id"]
|
274 |
+
}
|
275 |
+
|
276 |
+
def collater(self, samples):
|
277 |
+
|
278 |
+
input_ids, token_type_ids, labels, video_fts, video_token_type_ids = [], [], [], [], []
|
279 |
+
|
280 |
+
for i in samples:
|
281 |
+
input_ids.append(i['input_ids'])
|
282 |
+
token_type_ids.append(i['token_type_ids'])
|
283 |
+
labels.append(i['labels'])
|
284 |
+
video_fts.append(i['video_fts'])
|
285 |
+
video_token_type_ids.append(i['video_token_type_ids'])
|
286 |
+
|
287 |
+
input_ids = self.text_processor.padding(input_ids)
|
288 |
+
|
289 |
+
labels = self.text_processor.padding(labels, -1)
|
290 |
+
video_fts = self.vis_processor.padding(video_fts)
|
291 |
+
|
292 |
+
token_type_ids = self.text_processor.padding(token_type_ids)
|
293 |
+
video_token_type_ids = self.text_processor.padding(video_token_type_ids)
|
294 |
+
token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1)
|
295 |
+
|
296 |
+
attn_mask = self.text_processor.get_attention_mask(input_ids)
|
297 |
+
video_mask = self.vis_processor.get_attention_mask(video_fts)
|
298 |
+
attn_mask = torch.cat([video_mask, attn_mask], dim=1)
|
299 |
+
|
300 |
+
video_labels = torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1 # ignore token indice -1 by default
|
301 |
+
|
302 |
+
labels = torch.cat([video_labels, labels], dim=1)
|
303 |
+
|
304 |
+
samples = {}
|
305 |
+
samples['input_ids'] = input_ids
|
306 |
+
samples['token_type_ids'] = token_type_ids
|
307 |
+
samples['labels'] = labels
|
308 |
+
samples['video_fts'] = video_fts
|
309 |
+
samples['attn_mask'] = attn_mask
|
310 |
+
|
311 |
+
return samples
|
312 |
+
|
313 |
+
Note that in a dataset subclass, if methods such as ``__getitem__`` and ``collater`` are not defined, the same functions from the corresponding superclass will be used.
|
314 |
+
For instance, by default, we always use the collater from the ``BaseDataset`` class to collate data samples.
|
315 |
+
|
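The custom ``collater`` defined above is what a PyTorch ``DataLoader`` should use as its ``collate_fn``. A sketch, assuming ``dataset`` is an ``AVSDDialDataset`` instance that has already been built:

.. code-block:: python

    from torch.utils.data import DataLoader

    # `dataset` is assumed to be an AVSDDialDataset built elsewhere.
    loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        collate_fn=dataset.collater,  # pads and batches samples as shown above
    )

    batch = next(iter(loader))
    # batch contains "input_ids", "token_type_ids", "labels", "video_fts", "attn_mask".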
316 |
+
Dataset Builder ``lavis.datasets.builders``
|
317 |
+
**************************************************************
|
318 |
+
The dataset builder is the data processing module that manages the dataset classes (by training or evaluation split) and associates the specific dataset configurations with these dataset classes.
|
319 |
+
|
320 |
+
Base Dataset Builder ``lavis.datasets.builders.base_dataset_builder``
|
321 |
+
======================================================================
|
322 |
+
|
323 |
+
Note that any new builder class definition should inherit the base dataset builder class ``lavis.datasets.builders.base_dataset_builder``:
|
324 |
+
|
325 |
+
.. code-block:: python
|
326 |
+
|
327 |
+
class BaseDatasetBuilder:
|
328 |
+
train_dataset_cls, eval_dataset_cls = None, None
|
329 |
+
...
|
330 |
+
|
331 |
+
This allows us to standardize the operations of dataset builders across all builder classes. We advise users to carefully review the standard methods defined in the base builder class, including methods such as ``_download_data`` and ``build_datasets``, which download the data and create instances of the dataset classes:
|
332 |
+
|
333 |
+
.. code-block:: python
|
334 |
+
|
335 |
+
class BaseDatasetBuilder:
|
336 |
+
...
|
337 |
+
|
338 |
+
def build_datasets(self):
|
339 |
+
# download, split, etc...
|
340 |
+
# only called on 1 GPU/TPU in distributed
|
341 |
+
|
342 |
+
if is_main_process():
|
343 |
+
self._download_data()
|
344 |
+
|
345 |
+
if is_dist_avail_and_initialized():
|
346 |
+
dist.barrier()
|
347 |
+
|
348 |
+
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
349 |
+
logging.info("Building datasets...")
|
350 |
+
datasets = self.build() # dataset['train'/'val'/'test']
|
351 |
+
|
352 |
+
return datasets
|
353 |
+
|
354 |
+
def _download_data(self):
|
355 |
+
self._download_ann()
|
356 |
+
self._download_vis()
|
357 |
+
|
358 |
+
We encourage users not to modify the implementation of the base dataset builder class as this will affect all existing dataset builder subclasses.
|
359 |
+
|
360 |
+
Dialogue Dataset Builder ``lavis.datasets.builders.dialogue_builder``
|
361 |
+
======================================================================
|
362 |
+
We can define any new builder subclass and associate this builder with the corresponding dataset classes and dataset configurations.
|
363 |
+
For instance, for the AVSD dataset, we can define a builder ``lavis.datasets.builders.dialogue_builder`` for dialogue-based datasets as follows:
|
364 |
+
|
365 |
+
.. code-block:: python
|
366 |
+
|
367 |
+
from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
368 |
+
from lavis.datasets.datasets.avsd_dialogue_datasets import (
|
369 |
+
AVSDDialDataset,
|
370 |
+
AVSDDialEvalDataset
|
371 |
+
)
|
372 |
+
|
373 |
+
from lavis.common.registry import registry
|
374 |
+
|
375 |
+
|
376 |
+
@registry.register_builder("avsd_dialogue")
|
377 |
+
class AVSDDialBuilder(BaseDatasetBuilder):
|
378 |
+
train_dataset_cls = AVSDDialDataset
|
379 |
+
eval_dataset_cls = AVSDDialEvalDataset
|
380 |
+
|
381 |
+
DATASET_CONFIG_DICT = {
|
382 |
+
"default": "configs/datasets/avsd/defaults_dial.yaml"
|
383 |
+
}
|
384 |
+
|
385 |
+
Note that we chose to separately define the parameters ``train_dataset_cls`` and ``eval_dataset_cls`` to accommodate cases where data is processed differently at training and test time.
|
386 |
+
For instance, in captioning tasks, each test sample often includes multiple ground-truth captions, rather than the single ground-truth caption used at training time.
|
387 |
+
If the data processing is the same in both training and test time, the two parameters can be linked to the same dataset class.
|
388 |
+
|
389 |
+
Finally, define ``DATASET_CONFIG_DICT`` to associate the dataset configurations with the assigned dataset classes.
|
390 |
+
|
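Once the builder is registered (see the next section) and the data has been downloaded to the configured storage locations, the datasets can also be built directly in Python through the ``load_dataset`` helper exposed by ``lavis.datasets.builders``. A sketch:

.. code-block:: python

    from lavis.datasets.builders import load_dataset

    # Builds the splits using the default configuration file associated
    # with the "avsd_dialogue" builder registered below.
    avsd = load_dataset("avsd_dialogue")
    print(avsd.keys())          # dict_keys(['train', 'val', 'test'])
    print(len(avsd["train"]))   # number of training samples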
391 |
+
Registering Builder ``lavis.datasets.builders.__init__``
|
392 |
+
======================================================================
|
393 |
+
|
394 |
+
To add a new builder class, first make sure to include the class in ``__init__.py``. For instance, to register the new builder for the AVSD dataset:
|
395 |
+
|
396 |
+
.. code-block:: python
|
397 |
+
|
398 |
+
from lavis.datasets.builders.dialogue_builder import (
|
399 |
+
AVSDDialBuilder
|
400 |
+
)
|
401 |
+
|
402 |
+
__all__ = [
|
403 |
+
...,
|
404 |
+
"AVSDDialBuilder"
|
405 |
+
]
|
406 |
+
|
407 |
+
Assigning Builder
|
408 |
+
======================================================================
|
409 |
+
Note that during data loading and processing, the assigned builder must be referenced by its correct registry name so that it can be loaded properly.
|
410 |
+
For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
|
411 |
+
|
412 |
+
.. code-block:: yaml
|
413 |
+
|
414 |
+
datasets:
|
415 |
+
avsd_dialogue: # name of the dataset builder
|
416 |
+
...
|
417 |
+
# processor configuration
|
418 |
+
...
|
419 |
+
|
420 |
+
Subsequently, any process (e.g. training) should load this configuration file to assign the correct builder, which will then associate the correct dataset classes to construct data samples.
|
421 |
+
|
422 |
+
.. code-block:: sh
|
423 |
+
|
424 |
+
python train.py --cfg-path dialogue_avsd_ft.yaml
|
docs/tutorial.evaluation.rst
ADDED
@@ -0,0 +1,40 @@
1 |
+
Evaluating Pre-trained Models on Task Datasets
|
2 |
+
###############################################
|
3 |
+
LAVIS provides pre-trained and finetuned models for off-the-shelf evaluation on task datasets.
|
4 |
+
Let's now walk through an example of evaluating the BLIP model on the captioning task, using the MSCOCO dataset.
|
5 |
+
|
6 |
+
.. _prep coco:
|
7 |
+
|
8 |
+
Preparing Datasets
|
9 |
+
******************
|
10 |
+
First, let's download the dataset. LAVIS provides `automatic downloading scripts` to help prepare
|
11 |
+
most of the public datasets. To download the MSCOCO dataset, simply run
|
12 |
+
|
13 |
+
.. code-block:: bash
|
14 |
+
|
15 |
+
cd lavis/datasets/download_scripts && python download_coco.py
|
16 |
+
|
17 |
+
This will put the downloaded dataset in the default cache location ``cache`` used by LAVIS.
|
18 |
+
|
19 |
+
If you want to use a different cache location, you can specify it by updating ``cache_root`` in ``lavis/configs/default.yaml``.
|
20 |
+
|
21 |
+
If you have a local copy of the dataset, it is recommended to create a symlink from the cache location to the local copy, e.g.
|
22 |
+
|
23 |
+
.. code-block:: bash
|
24 |
+
|
25 |
+
ln -s /path/to/local/coco cache/coco
|
26 |
+
|
27 |
+
Evaluating pre-trained models
|
28 |
+
******************************
|
29 |
+
|
30 |
+
To evaluate the pre-trained model, simply run
|
31 |
+
|
32 |
+
.. code-block:: bash
|
33 |
+
|
34 |
+
bash run_scripts/blip/eval/eval_coco_cap.sh
|
35 |
+
|
36 |
+
Or, to evaluate the large model variant:
|
37 |
+
|
38 |
+
.. code-block:: bash
|
39 |
+
|
40 |
+
bash run_scripts/blip/eval/eval_coco_cap_large.sh
|
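Besides the task-level evaluation scripts, a single pre-trained checkpoint can also be exercised directly from Python through ``load_model_and_preprocess``. The snippet below is a sketch; the image path is a placeholder.

.. code-block:: python

    import torch
    from PIL import Image
    from lavis.models import load_model_and_preprocess

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load a BLIP captioning model finetuned on MSCOCO, with its processors.
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip_caption", model_type="base_coco", is_eval=True, device=device
    )

    raw_image = Image.open("path/to/your/image.jpg").convert("RGB")  # placeholder path
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    print(model.generate({"image": image}))  # list with one generated caption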
docs/tutorial.models.rst
ADDED
@@ -0,0 +1,245 @@
1 |
+
Adding Models
|
2 |
+
####################################
|
3 |
+
|
4 |
+
This is a tutorial on adding new models using ``lavis.models`` module.
|
5 |
+
|
6 |
+
The LAVIS library includes a standard model module that builds the foundation for many major language-vision models such as `ALBEF <https://arxiv.org/pdf/2107.07651.pdf>`_,
|
7 |
+
`BLIP <https://arxiv.org/pdf/2201.12086.pdf>`_, `ALPRO <https://arxiv.org/pdf/2112.09583.pdf>`_, and `CLIP <https://arxiv.org/pdf/2103.00020.pdf>`_.
|
8 |
+
The ``lavis.models`` module is designed such that any new models can be added and integrated into the LAVIS library, with minimal steps to develop training and testing procedures.
|
9 |
+
In this tutorial, we will replicate the steps to add a GPT-style model specifically for `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
|
10 |
+
|
11 |
+
Base Model ``lavis.models.base_model``
|
12 |
+
**************************************************************
|
13 |
+
|
14 |
+
Note that any new model definition should inherit the base model class ``BaseModel``:
|
15 |
+
|
16 |
+
.. code-block:: python
|
17 |
+
|
18 |
+
from omegaconf import OmegaConf
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
|
25 |
+
from lavis.common.utils import get_abs_path
|
26 |
+
|
27 |
+
class BaseModel(nn.Module):
|
28 |
+
"""Base class for models."""
|
29 |
+
|
30 |
+
def __init__(self):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
def forward_features(self, *args, **kwargs):
|
34 |
+
"""Similar to *forward* but only return features."""
|
35 |
+
raise NotImplementedError
|
36 |
+
|
37 |
+
def load_from_pretrained(self, url_or_filename):
|
38 |
+
raise NotImplementedError
|
39 |
+
|
40 |
+
@classmethod
|
41 |
+
def _from_config(cls, cfg=None, model_type="base"):
|
42 |
+
if not cfg:
|
43 |
+
# useful when building model without a provided configuration file
|
44 |
+
cfg = OmegaConf.load(cls.default_config_path(model_type)).model
|
45 |
+
|
46 |
+
return cls.from_config(cfg)
|
47 |
+
|
48 |
+
@classmethod
|
49 |
+
def from_pretrained(cls, model_type="base"):
|
50 |
+
"""
|
51 |
+
Build a pretrained model from the default configuration file, specified by model_type.
|
52 |
+
"""
|
53 |
+
return cls._from_config(cfg=None, model_type=model_type)
|
54 |
+
|
55 |
+
@property
|
56 |
+
def device(self):
|
57 |
+
return list(self.parameters())[0].device
|
58 |
+
|
59 |
+
@classmethod
|
60 |
+
def default_config_path(cls, model_type="base"):
|
61 |
+
assert (
|
62 |
+
model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
|
63 |
+
), "Unknown model type {}".format(model_type)
|
64 |
+
return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
|
65 |
+
|
66 |
+
def before_evaluation(self, **kwargs):
|
67 |
+
pass
|
68 |
+
|
69 |
+
def show_n_params(self, return_str=True):
|
70 |
+
tot = 0
|
71 |
+
for p in self.parameters():
|
72 |
+
w = 1
|
73 |
+
for x in p.shape:
|
74 |
+
w *= x
|
75 |
+
tot += w
|
76 |
+
if return_str:
|
77 |
+
if tot >= 1e6:
|
78 |
+
return "{:.1f}M".format(tot / 1e6)
|
79 |
+
else:
|
80 |
+
return "{:.1f}K".format(tot / 1e3)
|
81 |
+
else:
|
82 |
+
return tot
|
83 |
+
|
84 |
+
|
85 |
+
In this base model, we already declare and standardize many common methods such as ``_from_config`` and ``from_pretrained``.
|
86 |
+
Inheriting this base model class allows us to standardize operations of models across all model classes while still allowing customizations.
|
87 |
+
We advise users not to change the implementation of the base model class as this will affect all existing model subclasses.
|
88 |
+
|
89 |
+
GPT-style Video-grounded Dialogue Model ``lavis.models.gpt_models.gpt_dialogue``
|
90 |
+
********************************************************************************
|
91 |
+
|
92 |
+
In this step, we can define a new model class, e.g. under ``lavis.models.gpt_models.gpt_dialogue``, for GPT-based dialogue models designed specifically for video-grounded dialogues.
|
93 |
+
Note that we assume the model class inherits from the standard model super class ``GPT2LMHeadModel`` from the ``transformers`` `library <https://huggingface.co/docs/transformers/index>`_.
|
94 |
+
We also enforce integration with the LAVIS framework by inheriting the LAVIS ``BaseModel`` as the secondary superclass.
|
95 |
+
|
96 |
+
.. code-block:: python
|
97 |
+
|
98 |
+
import torch
|
99 |
+
from lavis.common.registry import registry
|
100 |
+
from lavis.models.base_model import BaseModel
|
101 |
+
|
102 |
+
from transformers import GPT2Model, GPT2LMHeadModel
|
103 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
104 |
+
import math
|
105 |
+
import torch
|
106 |
+
import torch.nn as nn
|
107 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
108 |
+
|
109 |
+
@registry.register_model("gpt_dialogue")
|
110 |
+
class GPTDialogue(GPT2LMHeadModel, BaseModel):
|
111 |
+
...
|
112 |
+
|
113 |
+
Next, we can modify the architecture of the model during model initialization to fit the tasks of interest, i.e. video-grounded dialogues.
|
114 |
+
In this case, we want to add additional model parameters for a linear network to transform the video feature representations to the model dimension.
|
115 |
+
|
116 |
+
.. code-block:: python
|
117 |
+
|
118 |
+
class GPTDialogue(GPT2LMHeadModel, BaseModel):
|
119 |
+
|
120 |
+
def __init__(self, config, len_video_ft=4224):
|
121 |
+
|
122 |
+
super().__init__(config)
|
123 |
+
|
124 |
+
self.video_ff = nn.Linear(len_video_ft, config.n_embd)
|
125 |
+
|
126 |
+
# Model parallel
|
127 |
+
self.model_parallel = False
|
128 |
+
self.device_map = None
|
129 |
+
|
130 |
+
# Initialize weights and apply final processing
|
131 |
+
self.post_init()
|
132 |
+
|
133 |
+
Note that for each new model class, we advise redefining the ``from_config`` method which is inherited from the ``BaseModel`` class.
|
134 |
+
As each model usually has its own unique configurations, redefining the method will ensure the model instances are created properly.
|
135 |
+
For instance, ``GPTDialogue`` requires an additional parameter of video feature length (``len_video_ft``) which should be part of the model initialization procedure.
|
136 |
+
Another additional parameter is the number of tokens/words (as we include additional special tokens in the vocabulary for dialogue tasks).
|
137 |
+
|
138 |
+
.. code-block:: python
|
139 |
+
|
140 |
+
class GPTDialogue(GPT2LMHeadModel, BaseModel):
|
141 |
+
...
|
142 |
+
@classmethod
|
143 |
+
def from_config(cls, cfg):
|
144 |
+
model = cls.from_pretrained('gpt2', len_video_ft=cfg['len_video_ft'])
|
145 |
+
model.resize_token_embeddings(cfg['len_tokenizer'])
|
146 |
+
return model
|
147 |
+
|
148 |
+
Other basic methods should also be defined explicitly in the new model class, including the ``forward`` function.
|
149 |
+
For instance, in GPT models for video-grounded dialogue tasks, we want the forward operation to also include the transformation and integration of video features before passing the representations to the Transformer layers.
|
150 |
+
|
151 |
+
.. code-block:: python
|
152 |
+
|
153 |
+
class GPTDialogue(GPT2LMHeadModel, BaseModel):
|
154 |
+
...
|
155 |
+
|
156 |
+
def forward(self, samples,
|
157 |
+
past_key_values=None,
|
158 |
+
position_ids=None,
|
159 |
+
head_mask=None,
|
160 |
+
encoder_hidden_states=None,
|
161 |
+
encoder_attention_mask=None,
|
162 |
+
use_cache=None,
|
163 |
+
output_attentions=None,
|
164 |
+
output_hidden_states=None,
|
165 |
+
return_dict=None):
|
166 |
+
|
167 |
+
input_embs = self.transformer.wte(samples['input_ids'])
|
168 |
+
video_embs = self.video_ff(samples['video_fts'])
|
169 |
+
input_embs = torch.cat([video_embs, input_embs], dim=1)
|
170 |
+
|
171 |
+
transformer_outputs = self.transformer(
|
172 |
+
attention_mask=samples['attn_mask'],
|
173 |
+
token_type_ids=samples['token_type_ids'],
|
174 |
+
inputs_embeds=input_embs,
|
175 |
+
position_ids=position_ids,
|
176 |
+
head_mask=head_mask,
|
177 |
+
encoder_hidden_states=encoder_hidden_states,
|
178 |
+
encoder_attention_mask=encoder_attention_mask,
|
179 |
+
use_cache=use_cache,
|
180 |
+
output_attentions=output_attentions,
|
181 |
+
output_hidden_states=output_hidden_states,
|
182 |
+
return_dict=return_dict,
|
183 |
+
)
|
184 |
+
hidden_states = transformer_outputs[0]
|
185 |
+
|
186 |
+
lm_logits = self.lm_head(hidden_states)
|
187 |
+
...
|
188 |
+
|
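The elided part of ``forward`` above is where a language-modeling loss would typically be computed. The following is a hedged sketch of that step (a standard shifted cross-entropy over the concatenated video and text positions, not necessarily the exact LAVIS implementation); it ignores positions labeled ``-1``, matching the padding convention used by the dataset class.

.. code-block:: python

    from torch.nn import CrossEntropyLoss  # already imported at the top of the file

    # Next-token prediction: shift logits and labels by one position and
    # skip positions whose label is -1 (video tokens and padding).
    loss = None
    if samples.get("labels") is not None:
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = samples["labels"][..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(ignore_index=-1)
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )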
189 |
+
Registering New Model ``lavis.models.__init__``
|
190 |
+
********************************************************************************
|
191 |
+
|
192 |
+
Any new model must be officially registered as part of the ``lavis.models`` module.
|
193 |
+
For instance, to add a model class for GPT-based dialogue models, we can modify the ``__init__.py`` as follows:
|
194 |
+
|
195 |
+
.. code-block:: python
|
196 |
+
|
197 |
+
from lavis.models.gpt_models.gpt_dialogue import GPTDialogue
|
198 |
+
|
199 |
+
__all__ = [
|
200 |
+
...
|
201 |
+
"GPTDialogue"
|
202 |
+
]
|
203 |
+
|
204 |
+
Assigning Model
|
205 |
+
********************************************************************************
|
206 |
+
|
207 |
+
From the above example of a model class, note that we define a ``from_config`` method for the new model class.
|
208 |
+
This method will process a configuration file and pass specific parameters to initialize the model classes properly.
|
209 |
+
To do this, we assign the correct model class registry name in a configuration file.
|
210 |
+
For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
|
211 |
+
|
212 |
+
.. code-block:: yaml
|
213 |
+
|
214 |
+
model:
|
215 |
+
arch: gpt_dialogue # name of the model
|
216 |
+
model_type: base
|
217 |
+
|
218 |
+
|
219 |
+
Subsequently, any processes (e.g. training) should load this configuration file to assign the correct model.
|
220 |
+
|
221 |
+
.. code-block:: sh
|
222 |
+
|
223 |
+
python train.py --cfg-path dialogue_avsd_ft.yaml
|
224 |
+
|
225 |
+
Note that to simplify the model configuration, we only enable two main parameters here: ``arch`` and ``model_type``. ``arch`` refers to the model class registry name, and ``model_type`` is the corresponding model type under this model family.
|
226 |
+
For instance, with ``gpt_dialogue``, we have a model ``base`` which has its own configuration in a separate configuration file e.g. ``gpt_dialogue_base.yaml``:
|
227 |
+
|
228 |
+
.. code-block:: yaml
|
229 |
+
|
230 |
+
model:
|
231 |
+
arch: gpt_dialogue
|
232 |
+
len_tokenizer: 50264 # 50257 tokens from gpt2 default tokenizer + additional special tokens
|
233 |
+
len_video_ft: 4224 # i3d_rgb: 2048 i3d_flow: 2048 vggish: 128
|
234 |
+
|
235 |
+
We can load this configuration and pass the parameters to the above ``from_config`` method to initialize the model accordingly.
|
236 |
+
We advise users to maintain, in the model class definition, a dictionary that contains the default paths to model configurations.
|
237 |
+
By default, the LAVIS framework will look up configurations from each model class's ``PRETRAINED_MODEL_CONFIG_DICT``.
|
238 |
+
|
239 |
+
.. code-block:: python
|
240 |
+
|
241 |
+
class GPTDialogue(GPT2LMHeadModel, BaseModel):
|
242 |
+
PRETRAINED_MODEL_CONFIG_DICT = {
|
243 |
+
"base": "configs/models/gpt_dialogue_base.yaml"
|
244 |
+
}
|
245 |
+
...
|
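With the registry entry and the default configuration path in place, the model can be instantiated programmatically. A sketch, assuming ``configs/models/gpt_dialogue_base.yaml`` exists at the listed path (the ``_from_config``/``from_pretrained`` helpers in ``BaseModel`` perform essentially the same steps):

.. code-block:: python

    from omegaconf import OmegaConf
    from lavis.common.registry import registry

    # Look up the class registered above and build it from its default config.
    model_cls = registry.get_model_class("gpt_dialogue")
    cfg = OmegaConf.load(model_cls.default_config_path("base")).model
    model = model_cls.from_config(cfg)
    print(model.show_n_params())  # parameter count as a human-readable string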
docs/tutorial.processors.rst
ADDED
@@ -0,0 +1,233 @@
1 |
+
Adding Processors
|
2 |
+
################################################
|
3 |
+
|
4 |
+
This is a tutorial on adding new processors using ``lavis.processors`` module.
|
5 |
+
|
6 |
+
The LAVIS library includes a standard processor module that preprocesses data, e.g. image transformations and sequence concatenation.
|
7 |
+
The ``lavis.processors`` module is designed such that new processors can be added to match the requirements of the corresponding models of interest.
|
8 |
+
In this tutorial, we will replicate the steps to add visual and textual processors specifically for `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
|
9 |
+
In addition, we want the processors to include processing features that make the data samples compatible with GPT-style models.
|
10 |
+
|
11 |
+
Base Processor ``lavis.processors.base_processors``
|
12 |
+
*****************************************************
|
13 |
+
|
14 |
+
Note that any new processor definition should inherit the base processor class ``BaseProcessor``:
|
15 |
+
|
16 |
+
.. code-block:: python
|
17 |
+
|
18 |
+
from omegaconf import OmegaConf
|
19 |
+
|
20 |
+
class BaseProcessor:
|
21 |
+
def __init__(self):
|
22 |
+
self.transform = lambda x: x
|
23 |
+
return
|
24 |
+
|
25 |
+
def __call__(self, item):
|
26 |
+
return self.transform(item)
|
27 |
+
|
28 |
+
@classmethod
|
29 |
+
def from_config(cls, cfg=None):
|
30 |
+
return cls()
|
31 |
+
|
32 |
+
def build(self, **kwargs):
|
33 |
+
cfg = OmegaConf.create(kwargs)
|
34 |
+
|
35 |
+
return self.from_config(cfg)
|
36 |
+
|
37 |
+
This allows us to standardize operations of processors across all processor classes while still allowing customization of processors specifically to data and model types.
|
38 |
+
We encourage users not to modify the implementation of the base processor class as this will have an impact on all existing processor subclasses.
|
39 |
+
|
40 |
+
GPT-style Processors ``lavis.processors.gpt_processors``
|
41 |
+
**************************************************************
|
42 |
+
In this step, we can define new processor classes, e.g. under ``lavis.processors.gpt_processors``, for GPT models designed specifically for video-grounded dialogues.
|
43 |
+
First, we want to process video features by defining a ``GPTVideoFeatureProcessor`` class.
|
44 |
+
In this tutorial, we assume video features are extracted beforehand and this processor simply loads the features from ``npy`` files.
|
45 |
+
Other methods that are specifically defined are ``padding`` (which is used by dataset instances to pad multiple video samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models).
|
46 |
+
|
47 |
+
.. code-block:: python
|
48 |
+
|
49 |
+
SPECIAL_TOKENS_DICT = {'bos_token': "<bos>", 'eos_token': "<eos>", 'additional_special_tokens': ["<speaker1>", "<speaker2>", "<video>", "<cap>"], 'pad_token': "<pad>"}
|
50 |
+
...
|
51 |
+
|
52 |
+
@registry.register_processor("gpt_video_ft")
|
53 |
+
class GPTVideoFeatureProcessor(BaseProcessor):
|
54 |
+
def __init__(self, visual_ft, audio_ft):
|
55 |
+
|
56 |
+
self.visual_ft = visual_ft
|
57 |
+
self.audio_ft = audio_ft
|
58 |
+
|
59 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
60 |
+
self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
|
61 |
+
|
62 |
+
def padding(self, seq):
|
63 |
+
padded_seq = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=1.0)
|
64 |
+
return padded_seq
|
65 |
+
|
66 |
+
def get_attention_mask(self, seq):
|
67 |
+
return torch.sum(seq != 1, dim=2) != 0
|
68 |
+
|
69 |
+
def __call__(self, ft_root, vname):
|
70 |
+
all_ft = []
|
71 |
+
|
72 |
+
for ft_name in self.visual_ft:
|
73 |
+
ft_path = os.path.join(ft_root, ft_name, vname)
|
74 |
+
all_ft.append(np.load(ft_path + '.npy'))
|
75 |
+
|
76 |
+
for ft_name in self.audio_ft:
|
77 |
+
ft_path = os.path.join(ft_root, ft_name, vname)
|
78 |
+
all_ft.append(np.load(ft_path + '.npy'))
|
79 |
+
|
80 |
+
min_len = min([len(ft) for ft in all_ft])
|
81 |
+
|
82 |
+
sampled_ft = [ft[:min_len] for ft in all_ft]
|
83 |
+
sampled_ft = np.concatenate(sampled_ft, axis=1)
|
84 |
+
item = {}
|
85 |
+
item['video_fts'] = torch.Tensor(sampled_ft)
|
86 |
+
|
87 |
+
video_type_token = self.tokenizer.convert_tokens_to_ids('<video>')
|
88 |
+
item['token_type_ids'] = torch.Tensor([video_type_token] * len(sampled_ft)).long()
|
89 |
+
|
90 |
+
return item
|
91 |
+
|
92 |
+
@classmethod
|
93 |
+
def from_config(cls, cfg=None):
|
94 |
+
if cfg is None:
|
95 |
+
cfg = OmegaConf.create()
|
96 |
+
|
97 |
+
visual_ft = cfg.get("visual_ft", ["i3d_rgb"])
|
98 |
+
audio_ft = cfg.get("audio_ft", ["vggish"])
|
99 |
+
|
100 |
+
return cls(
|
101 |
+
visual_ft=visual_ft,
|
102 |
+
audio_ft=audio_ft
|
103 |
+
)
|
104 |
+
|
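A processor built this way can also be used on its own. The sketch below assumes feature files such as ``<ft_root>/i3d_rgb/<vname>.npy`` already exist on disk; the feature root and video name are placeholders, and ``GPTVideoFeatureProcessor`` refers to the class defined above.

.. code-block:: python

    from omegaconf import OmegaConf

    # Build the processor from an explicit config (same defaults as from_config).
    cfg = OmegaConf.create({"visual_ft": ["i3d_rgb"], "audio_ft": ["vggish"]})
    video_processor = GPTVideoFeatureProcessor.from_config(cfg)

    # Placeholder feature root and video name.
    item = video_processor("path/to/avsd/features", "some_video_id")
    print(item["video_fts"].shape)        # (num_frames, feature_dim)
    print(item["token_type_ids"].shape)   # (num_frames,)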
105 |
+
Another useful processor class is one that processes dialogue data. Here we can define a ``GPTDialogueProcessor`` class.
|
106 |
+
This processor class receives raw annotations and constructs inputs as a concatenation of input sequences (questions, dialogue contexts, and responses) to facilitate application in GPT models.
|
107 |
+
Other methods that are specifically defined are ``padding`` (which is used by dataset instances to pad multiple sequence samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models).
|
108 |
+
|
109 |
+
.. code-block:: python
|
110 |
+
|
111 |
+
SPECIAL_TOKENS_DICT = {'bos_token': "<bos>", 'eos_token': "<eos>", 'additional_special_tokens': ["<speaker1>", "<speaker2>", "<video>", "<cap>"], 'pad_token': "<pad>"}
|
112 |
+
...
|
113 |
+
|
114 |
+
@registry.register_processor("gpt_dialogue")
|
115 |
+
class GPTDialogueProcessor(BaseProcessor):
|
116 |
+
def __init__(self, max_turns=3, use_caption=True):
|
117 |
+
self.max_turns = max_turns
|
118 |
+
self.use_caption = use_caption
|
119 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
120 |
+
self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
|
121 |
+
|
122 |
+
def sample_sequence(self, caption, history, answer):
|
123 |
+
bos, eos, speaker1, speaker2, cap = self.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-2])
|
124 |
+
instance = {}
|
125 |
+
sequence = [caption] + history + [answer]
|
126 |
+
sequence = [s + [eos] for s in sequence]
|
127 |
+
|
128 |
+
instance["input_ids"] = list(chain(*sequence))
|
129 |
+
instance["token_type_ids"] = [cap] * len(sequence[0]) + [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence[1:]) for _ in s]
|
130 |
+
instance["labels"] = ([-1]*sum(len(s) for s in sequence[:-1])) + sequence[-1]
|
131 |
+
|
132 |
+
assert len(instance["input_ids"])==len(instance["token_type_ids"])
|
133 |
+
assert len(instance["token_type_ids"])==len(instance["labels"])
|
134 |
+
|
135 |
+
for k,v in instance.items():
|
136 |
+
instance[k] = torch.Tensor(v).long()
|
137 |
+
|
138 |
+
return instance
|
139 |
+
|
140 |
+
def padding(self, seq, pad_token=-1):
|
141 |
+
if pad_token==-1: pad_token = self.tokenizer.pad_token_id
|
142 |
+
padded_seq = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=pad_token)
|
143 |
+
return padded_seq
|
144 |
+
|
145 |
+
def get_attention_mask(self, seq, pad_token=-1):
|
146 |
+
if pad_token==-1: pad_token = self.tokenizer.pad_token_id
|
147 |
+
return seq != pad_token
|
148 |
+
|
149 |
+
def __call__(self, ann):
|
150 |
+
if self.use_caption:
|
151 |
+
caption = ' '.join([ann['caption'], ann['summary']])
|
152 |
+
caption = self.tokenizer.encode(caption)
|
153 |
+
else:
|
154 |
+
caption = []
|
155 |
+
|
156 |
+
dial_history = []
|
157 |
+
for turn in ann['dialog'][-self.max_turns:]:
|
158 |
+
dial_history.append(turn['question'])
|
159 |
+
dial_history.append(turn['answer'])
|
160 |
+
dial_history.append(ann['question'])
|
161 |
+
dial_history = [self.tokenizer.encode(t) for t in dial_history]
|
162 |
+
|
163 |
+
answer = self.tokenizer.encode(ann['answer'])
|
164 |
+
|
165 |
+
item = self.sample_sequence(caption, dial_history, answer)
|
166 |
+
|
167 |
+
return item
|
168 |
+
|
169 |
+
@classmethod
|
170 |
+
def from_config(cls, cfg=None):
|
171 |
+
if cfg is None:
|
172 |
+
cfg = OmegaConf.create()
|
173 |
+
|
174 |
+
use_caption = cfg.get("use_caption", True)
|
175 |
+
max_turns = cfg.get("max_turns", 3)
|
176 |
+
|
177 |
+
return cls(max_turns=max_turns, use_caption=use_caption)
|
178 |
+
|
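Similarly, ``GPTDialogueProcessor`` can be exercised on a single annotation dictionary. The toy annotation below only illustrates the fields read by ``__call__`` and assumes the full class definition, including the elided ``SPECIAL_TOKENS`` constant:

.. code-block:: python

    # GPTDialogueProcessor refers to the class defined above.
    text_processor = GPTDialogueProcessor(max_turns=3, use_caption=True)

    # Toy annotation with the fields consumed by __call__.
    ann = {
        "caption": "a person opens a door",
        "summary": "someone enters the room",
        "dialog": [{"question": "what happens first ?", "answer": "a door opens"}],
        "question": "who enters ?",
        "answer": "a person walks in",
    }
    item = text_processor(ann)
    print(item["input_ids"].shape, item["labels"].shape)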
179 |
+
Registering New Processors ``lavis.processors.__init__``
|
180 |
+
**************************************************************
|
181 |
+
|
182 |
+
Finally, any new processor must be officially registered as part of the ``lavis.processors`` module.
|
183 |
+
For instance, to add processor classes for GPT-based dialogue models, including one for dialogue data ``GPTDialogueProcessor`` and one for video features ``GPTVideoFeatureProcessor``, we can modify the ``__init__.py`` as follows:
|
184 |
+
|
185 |
+
.. code-block:: python
|
186 |
+
|
187 |
+
from lavis.processors.gpt_processors import (
|
188 |
+
GPTVideoFeatureProcessor,
|
189 |
+
GPTDialogueProcessor,
|
190 |
+
)
|
191 |
+
|
192 |
+
__all__ = [
|
193 |
+
...
|
194 |
+
# GPT
|
195 |
+
"GPTVideoFeatureProcessor",
|
196 |
+
"GPTDialogueProcessor"
|
197 |
+
]
|
198 |
+
|
199 |
+
Assigning Processors
|
200 |
+
**************************************************************
|
201 |
+
From the above example of processor classes, note that we define a ``from_config`` method for each class.
|
202 |
+
This method will process a configuration file and pass specific parameters e.g. ``max_turns``, ``visual_ft``, to initialize the processor classes properly.
|
203 |
+
To do this, we assign the correct processor registry names in a configuration file.
|
204 |
+
For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
|
205 |
+
|
206 |
+
.. code-block:: yaml
|
207 |
+
|
208 |
+
datasets:
|
209 |
+
avsd_dialogue: # name of the dataset builder
|
210 |
+
vis_processor:
|
211 |
+
train:
|
212 |
+
name: "gpt_video_ft" # name of the visual processor for training data
|
213 |
+
visual_ft: ["i3d_flow", "i3d_rgb"]
|
214 |
+
audio_ft: ["vggish"]
|
215 |
+
eval:
|
216 |
+
name: "gpt_video_ft" # name of the visual processor for evaluation data
|
217 |
+
visual_ft: ["i3d_flow", "i3d_rgb"]
|
218 |
+
audio_ft: ["vggish"]
|
219 |
+
text_processor:
|
220 |
+
train:
|
221 |
+
name: "gpt_dialogue" # name of the textual processor for training data
|
222 |
+
max_turns: 3
|
223 |
+
use_caption: True
|
224 |
+
eval:
|
225 |
+
name: "gpt_dialogue" # name of the textual processor for evaluation data
|
226 |
+
max_turns: 3
|
227 |
+
use_caption: True
|
228 |
+
|
229 |
+
Subsequently, any processes (e.g. training) should load this configuration file to assign the correct processors.
|
230 |
+
|
231 |
+
.. code-block:: sh
|
232 |
+
|
233 |
+
python train.py --cfg-path dialogue_avsd_ft.yaml
|
docs/tutorial.rst
ADDED
@@ -0,0 +1,13 @@
1 |
+
Tutorials
|
2 |
+
==============================
|
3 |
+
|
4 |
+
.. toctree::
|
5 |
+
:maxdepth: 1
|
6 |
+
|
7 |
+
tutorial.evaluation
|
8 |
+
tutorial.training-example
|
9 |
+
tutorial.configs
|
10 |
+
tutorial.datasets
|
11 |
+
tutorial.processors
|
12 |
+
tutorial.models
|
13 |
+
tutorial.tasks
|
docs/tutorial.tasks.rst
ADDED
@@ -0,0 +1,184 @@
1 |
+
Adding Tasks
|
2 |
+
####################################
|
3 |
+
|
4 |
+
This is a tutorial on adding new machine learning tasks using ``lavis.tasks`` module.
|
5 |
+
|
6 |
+
The LAVIS library includes a standard task module that centralizes the model training and evaluation procedure of machine learning tasks.
|
7 |
+
The ``lavis.tasks`` module is designed such that any new task can be added and integrated, catering to any customization of the training and testing procedures.
|
8 |
+
In this tutorial, we will replicate the steps to add a new task into LAVIS for the `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
|
9 |
+
|
10 |
+
Base Task ``lavis.tasks.base_task``
|
11 |
+
********************************************************************************
|
12 |
+
|
13 |
+
Note that any new task definition should inherit the base task class ``BaseTask``:
|
14 |
+
|
15 |
+
.. code-block:: python
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
|
20 |
+
import torch.distributed as dist
|
21 |
+
from lavis.common.dist_utils import get_rank, get_world_size, is_main_process
|
22 |
+
from lavis.common.logger import MetricLogger, SmoothedValue
|
23 |
+
from lavis.common.registry import registry
|
24 |
+
from lavis.datasets.data_utils import prepare_sample
|
25 |
+
|
26 |
+
class BaseTask:
|
27 |
+
def __init__(self, **kwargs):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.inst_id_key = "instance_id"
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def setup_task(cls, **kwargs):
|
34 |
+
return cls()
|
35 |
+
|
36 |
+
def build_model(self, cfg):
|
37 |
+
model_config = cfg.model_cfg
|
38 |
+
|
39 |
+
model_cls = registry.get_model_class(model_config.arch)
|
40 |
+
return model_cls.from_config(model_config)
|
41 |
+
|
42 |
+
def build_datasets(self, cfg):
|
43 |
+
"""
|
44 |
+
Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
|
45 |
+
Download dataset and annotations automatically if not exist.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
cfg (common.config.Config): _description_
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
dict: Dictionary of torch.utils.data.Dataset objects by split.
|
52 |
+
"""
|
53 |
+
|
54 |
+
datasets = dict()
|
55 |
+
|
56 |
+
datasets_config = cfg.datasets_cfg
|
57 |
+
|
58 |
+
assert len(datasets_config) > 0, "At least one dataset has to be specified."
|
59 |
+
|
60 |
+
for name in datasets_config:
|
61 |
+
dataset_config = datasets_config[name]
|
62 |
+
|
63 |
+
builder = registry.get_builder_class(name)(dataset_config)
|
64 |
+
dataset = builder.build_datasets()
|
65 |
+
|
66 |
+
datasets[name] = dataset
|
67 |
+
|
68 |
+
return datasets
|
69 |
+
|
70 |
+
def train_step(self, model, samples):
|
71 |
+
loss = model(samples)["loss"]
|
72 |
+
return loss
|
73 |
+
|
74 |
+
...
|
75 |
+
|
76 |
+
In this base task, we already declare and standardize many common methods such as ``train_step``, ``build_model``, and ``build_datasets``.
|
77 |
+
Inheriting this base task class allows us to standardize operations of tasks across all task classes.
|
78 |
+
We recommend users not change the implementation of the base task class as this will have an impact on all existing task subclasses.
|
79 |
+
|
80 |
+
Dialogue Task ``lavis.tasks.dialogue``
|
81 |
+
********************************************************************************
|
82 |
+
|
83 |
+
In this step, we can define a new task class, e.g. under ``lavis.tasks.dialogue``, for video-grounded dialogues.
|
84 |
+
For instance, we define a new task class ``DialogueTask`` that inherits the super task class ``BaseTask``.
|
85 |
+
|
86 |
+
.. code-block:: python
|
87 |
+
|
88 |
+
import json
|
89 |
+
import os
|
90 |
+
|
91 |
+
from lavis.common.dist_utils import main_process
|
92 |
+
from lavis.common.logger import MetricLogger
|
93 |
+
from lavis.common.registry import registry
|
94 |
+
from lavis.tasks.base_task import BaseTask
|
95 |
+
from lavis.datasets.data_utils import prepare_sample
|
96 |
+
|
97 |
+
import numpy as np
|
98 |
+
|
99 |
+
@registry.register_task("dialogue")
|
100 |
+
class DialogueTask(BaseTask):
|
101 |
+
def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
|
102 |
+
super().__init__()
|
103 |
+
|
104 |
+
self.num_beams = num_beams
|
105 |
+
self.max_len = max_len
|
106 |
+
self.min_len = min_len
|
107 |
+
self.evaluate = evaluate
|
108 |
+
|
109 |
+
self.report_metric = report_metric
|
110 |
+
|
111 |
+
@classmethod
|
112 |
+
def setup_task(cls, cfg):
|
113 |
+
run_cfg = cfg.run_cfg
|
114 |
+
|
115 |
+
num_beams = run_cfg.num_beams
|
116 |
+
max_len = run_cfg.max_len
|
117 |
+
min_len = run_cfg.min_len
|
118 |
+
evaluate = run_cfg.evaluate
|
119 |
+
|
120 |
+
report_metric = run_cfg.get("report_metric", True)
|
121 |
+
|
122 |
+
return cls(
|
123 |
+
num_beams=num_beams,
|
124 |
+
max_len=max_len,
|
125 |
+
min_len=min_len,
|
126 |
+
evaluate=evaluate,
|
127 |
+
report_metric=report_metric,
|
128 |
+
)
|
129 |
+
|
130 |
+
def valid_step(self, model, samples):
|
131 |
+
results = []
|
132 |
+
loss = model(samples)["loss"].item()
|
133 |
+
|
134 |
+
return [loss]
|
135 |
+
...
|
136 |
+
|
137 |
+
Note that for any new task, we advise the users to review carefully the functions implemented within ``BaseTask`` and consider which methods should be modified.
|
138 |
+
For instance, the base task class already contains a standard implementation of model training steps that are common among machine learning steps.
|
139 |
+
Some major methods we want to emphasize and should be customized by each task are the ``valid_step`` and ``evaluation``.
|
140 |
+
These operations were not fully implemented in the base task class due to the differences in evaluation procedures among many machine learning tasks.
|
141 |
+
Another method that should be considered is the ``setup_task`` method.
|
142 |
+
This method will receive configurations that set task-specific parameters to initialize any task instance.
|
143 |
+
|
144 |
+
Registering New Task ``lavis.tasks.__init__``
|
145 |
+
********************************************************************************
|
146 |
+
|
147 |
+
Any new task must be officially registered as part of the ``lavis.tasks`` module. For instance, to add a new task for video-grounded dialogues, we can modify the ``__init__.py`` as follows:
|
148 |
+
|
149 |
+
.. code-block:: python
|
150 |
+
|
151 |
+
from lavis.tasks.dialogue import DialogueTask
|
152 |
+
|
153 |
+
...
|
154 |
+
__all__ = [
|
155 |
+
...
|
156 |
+
"DialogueTask"
|
157 |
+
]
|
158 |
+
|
159 |
+
Assigning Task
|
160 |
+
***************
|
161 |
+
|
162 |
+
From the above example of task class, note that we define a ``setup_task`` method for each task class.
|
163 |
+
This method will process a configuration file and pass specific parameters e.g. ``num_beams`` (for beam search generative tasks during the inference stage), to initialize the task classes properly.
|
164 |
+
To assign and associate any task, we need to specify the correct registry of task classes in a configuration file.
|
165 |
+
For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
|
166 |
+
|
167 |
+
.. code-block:: yaml
|
168 |
+
|
169 |
+
run:
|
170 |
+
task: dialogue # name of the task
|
171 |
+
|
172 |
+
# optimizer
|
173 |
+
...
|
174 |
+
|
175 |
+
max_len: 20
|
176 |
+
min_len: 5
|
177 |
+
num_beams: 3
|
178 |
+
...
|
179 |
+
|
180 |
+
Subsequently, any processes (e.g. training) should load this configuration file to assign the correct task.
|
181 |
+
|
182 |
+
.. code-block:: sh
|
183 |
+
|
184 |
+
python train.py --cfg-path dialogue_avsd_ft.yaml
|
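As a quick illustration of how the ``task`` name in the run config resolves to the registered class, here is a minimal sketch. It is not the library's training entry point; it assumes ``registry.get_task_class`` behaves like the other registry lookups in this tutorial, and the config object is a hand-built stand-in for the parsed ``dialogue_avsd_ft.yaml``.

.. code-block:: python

    # Minimal sketch: resolve the task name from a run config and build the task
    # instance via setup_task(), mirroring what the training process does.
    from omegaconf import OmegaConf

    import lavis.tasks  # noqa: F401  (importing the package registers the built-in tasks)
    from lavis.common.registry import registry

    # Hand-built stand-in for the "run" section of dialogue_avsd_ft.yaml.
    cfg = OmegaConf.create(
        {"run_cfg": {"task": "dialogue", "num_beams": 3, "max_len": 20, "min_len": 5, "evaluate": False}}
    )

    task_cls = registry.get_task_class(cfg.run_cfg.task)  # -> DialogueTask
    task = task_cls.setup_task(cfg=cfg)
    print(type(task).__name__)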
docs/tutorial.training-example.rst
ADDED
@@ -0,0 +1,145 @@
Example on Finetuning BLIP on COCO-Captioning
################################################

To finetune the BLIP model on the COCO caption dataset, first refer to :ref:`prep coco` to prepare the dataset if you have not done so.

To finetune the model, we have prepared a run script for you, which can be run as follows:

.. code-block:: bash

    bash run_scripts/blip/train/train_caption_coco_large.sh

This will finetune the pre-trained BLIP large model into a new model that can be used for captioning.

Deep Dive
**********
Now let's take a closer look at the script and see what it does.

.. code-block:: bash

    python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/caption_coco_large_ft.yaml

As can be seen, the script simply calls :code:`train.py` with PyTorch distributed training enabled.
The :code:`--cfg-path` argument specifies the **runtime config** file to use. The config file is a YAML file that specifies the training parameters, shown as follows:

.. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
    :language: yaml
    :linenos:

The runtime config file is divided into 3 sections:

- :code:`model`: specifies the model architecture and type to use.
- :code:`data`: specifies the dataset to use.
- :code:`run`: specifies the runner arguments, such as tasks, optimizer, learning rate scheduler, etc.

We describe each section in detail below.

Model configurations
=====================

.. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
    :language: yaml
    :linenos:
    :lines: 6-10

The :code:`arch` argument specifies the model architecture to use. In this case, we use the :code:`blip_caption` architecture.
You can find available architectures by inspecting the :code:`model_zoo`.
Once the architecture is specified, the runner will look for the model class registered with that name and try to instantiate a model instance.
In this case, :code:`BlipCaption` is the model registered with the name :code:`blip_caption`.

The registry maintains a mapping from the name string to the model class.
This allows the runner to find the model class dynamically based on the name string from the config file.
The following segment in :code:`lavis/models/blip_models/blip_caption.py` shows how :code:`BlipCaption` is registered with the name string :code:`blip_caption`:

.. literalinclude:: ../lavis/models/blip_models/blip_caption.py
    :language: python
    :linenos:
    :lines: 20-38

One and the same model architecture may be pre-trained or finetuned on different datasets, or have different model configurations.
For example, :code:`BlipCaption` has:

- :code:`base_coco`: pre-trained base BLIP model adapted for COCO captioning finetuning.

- :code:`large_coco`: pre-trained large BLIP model adapted for COCO captioning finetuning.

Therefore, we also need to specify :code:`model_type`. Here we use :code:`large_coco`.
And we set :code:`load_finetuned` to :code:`False` to indicate that we are finetuning the model from the pre-trained weights.
If :code:`load_finetuned` is set to :code:`True`, as it is by default, the model will load weights already finetuned on COCO captioning.

Given the model architecture and type, the library will then look for the default model config for :code:`large_coco` in :code:`lavis/models/blip_models/blip_caption.py`.
As can be seen in the above code snippet, the corresponding config path is stored in :code:`BlipCaption.PRETRAINED_MODEL_CONFIG_DICT`.
The library will then load :code:`lavis/configs/models/blip_caption_large_coco.yaml` as the configuration to build the model.

*Priority of Configs*: Note that the priority of the run config is higher than that of the default model config, meaning that arguments in the run config will override the default model config.
For example, in the default model config, :code:`load_finetuned` is set to :code:`True` by default, while in the run config we set it to :code:`False` and finetune from the pre-trained weights only.


Dataset configurations
=========================

The second section of the config file specifies the dataset(s) to use.

.. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
    :language: yaml
    :linenos:
    :lines: 12-24

We associate each dataset with a :code:`vis_processor` and a :code:`text_processor`, responsible for processing the visual and textual input respectively.
Here we again use the registry mechanism to dynamically load the processor class based on the name string.
For example, :code:`blip_image_train` is the name string for the :code:`BlipImageTrainProcessor` class, which is registered in :code:`lavis/processors/blip_processors.py`.

Similarly, the dataset name string is also registered in the registry, pointing to a dataset builder, the :code:`COCOCapBuilder` class.
By default, the builder will load the default dataset configuration as in :code:`DATASET_CONFIG_DICT`. You may also add new dataset types by adding new entries to the dictionary.

The dataset configuration used here is:

.. literalinclude:: ../lavis/configs/datasets/coco/defaults_cap.yaml
    :language: yaml
    :linenos:
    :lines: 6-28

In this configuration file, we specify the dataset name and mainly its build information.
The build information is divided into two parts: :code:`annotation` and :code:`images`. The annotation files will be automatically downloaded upon loading the dataset for the first time.
The :code:`images` part specifies the image root directory. This is a relative path to the cache directory, which is :code:`cache` by default. If you have a local copy of the dataset, you can specify the path to the local copy by
overwriting the :code:`images` part in the runtime config file. For example, you may alter the run config as below to use your local dataset copy:

.. code:: yaml

    datasets:
      coco_caption: # name of the dataset builder
        vis_processor:
            train:
              name: "blip_image_train"
            eval:
              name: "blip_image_eval"
        text_processor:
            train:
              name: "blip_caption"
              prompt: "a picture of "
            eval:
              name: "blip_caption"
        images:
            YOUR_LOCAL_IMAGE_ROOT_DIR

LAVIS supports using multiple datasets for training. See an example in :code:`lavis/projects/blip/train/pretrain_14m.yaml`.


Runner configurations
=========================
The last section of the config file specifies the arguments for the runner, shown below:

.. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
    :language: yaml
    :linenos:
    :lines: 26-56

Here we specify runner-related arguments, including

- task-specific arguments, such as :code:`task`, :code:`max_len`, :code:`min_len`, etc.
- learning rate schedulers, optimizer;
- distributed training settings;
- logging and checkpointing settings.

Available Configurations
#########################

See :ref:`config` for the full list of available configurations and their descriptions.
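As a sanity check after finetuning (or with the released zoo weights), the snippet below sketches how the captioner can be loaded through :code:`load_model_and_preprocess` and used for inference. It is a minimal sketch rather than part of the run script; the image path is a placeholder, and loading finetuned weights relies on :code:`load_finetuned` defaulting to :code:`True` as described above.

.. code-block:: python

    import torch
    from PIL import Image

    from lavis.models import load_model_and_preprocess

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the captioning model and its matching preprocessors from the model zoo.
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip_caption", model_type="large_coco", is_eval=True, device=device
    )

    # Placeholder image path; replace it with your own file.
    raw_image = Image.open("docs/_static/merlion.png").convert("RGB")
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

    # Generate a caption for the image.
    print(model.generate({"image": image}))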
examples/blip2_itm.py
ADDED
@@ -0,0 +1,520 @@
1 |
+
import re
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
from lavis.models import load_model_and_preprocess
|
7 |
+
from lavis.processors import load_processor
|
8 |
+
from lavis.common.registry import registry
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
|
11 |
+
import numpy as np
|
12 |
+
import pandas as pd
|
13 |
+
import time
|
14 |
+
from fuzzywuzzy import process
|
15 |
+
from multiprocessing import Pool, Queue, Process
|
16 |
+
import difflib
|
17 |
+
import Levenshtein
|
18 |
+
import os
|
19 |
+
# import obonet
|
20 |
+
|
21 |
+
|
22 |
+
def fuzzy_match(texts):
|
23 |
+
text_dict = {}
|
24 |
+
for context in texts:
|
25 |
+
if context not in choices:
|
26 |
+
# txt_dict[txt] = process.extractOne(txt, choices)[0]
|
27 |
+
text_dict[context] = difflib.get_close_matches(context, choices, n=1, cutoff=0.)[0]
|
28 |
+
return text_dict
|
29 |
+
|
30 |
+
|
31 |
+
def txt_map(x, txt_dict):
|
32 |
+
if type(x) == str:
|
33 |
+
x = eval(x)
|
34 |
+
x_ = []
|
35 |
+
for i in x:
|
36 |
+
if i in txt_dict:
|
37 |
+
x_.append(txt_dict[i])
|
38 |
+
else:
|
39 |
+
x_.append(i)
|
40 |
+
return x_
|
41 |
+
|
42 |
+
|
43 |
+
def levenshtein_sim(text, label):
|
44 |
+
all_s = []
|
45 |
+
for x in label:
|
46 |
+
s = 0
|
47 |
+
for y in text:
|
48 |
+
temp = Levenshtein.ratio(x, y)
|
49 |
+
if temp > s:
|
50 |
+
s = temp
|
51 |
+
all_s.append(s)
|
52 |
+
all_s = [round(i, 3) for i in all_s]
|
53 |
+
return all_s
|
54 |
+
|
55 |
+
def func(text, label):
|
56 |
+
all_s = []
|
57 |
+
for x in label:
|
58 |
+
s = 0
|
59 |
+
for y in text:
|
60 |
+
temp = Levenshtein.ratio(x, y)
|
61 |
+
if temp > s:
|
62 |
+
s = temp
|
63 |
+
all_s.append(s)
|
64 |
+
all_s = [round(i, 3) for i in all_s]
|
65 |
+
return all_s
|
66 |
+
|
67 |
+
|
68 |
+
def stage2_output(df_test):
|
69 |
+
config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
|
70 |
+
'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20230924220/checkpoint_5.pth',
|
71 |
+
'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
|
72 |
+
'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
|
73 |
+
'max_protein_len': 600,
|
74 |
+
'max_txt_len': 25}
|
75 |
+
|
76 |
+
model_cls = registry.get_model_class(config['arch'])
|
77 |
+
model = model_cls.from_config(config)
|
78 |
+
model.to(device)
|
79 |
+
model.eval()
|
80 |
+
|
81 |
+
images = df_test['protein'].tolist()
|
82 |
+
n = len(images)
|
83 |
+
bsz = 12
|
84 |
+
iter = n // bsz + 1
|
85 |
+
|
86 |
+
for i in range(iter):
|
87 |
+
image = images[i*bsz: min(n, (i+1)*bsz)]
|
88 |
+
image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
|
89 |
+
|
90 |
+
with model.maybe_autocast():
|
91 |
+
_, _, batch_tokens = model.visual_encoder(image)
|
92 |
+
image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()
|
93 |
+
|
94 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
95 |
+
|
96 |
+
query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
97 |
+
query_output = model.Qformer.bert(
|
98 |
+
query_embeds=query_tokens,
|
99 |
+
encoder_hidden_states=image_embeds,
|
100 |
+
encoder_attention_mask=image_atts,
|
101 |
+
return_dict=True,
|
102 |
+
)
|
103 |
+
|
104 |
+
inputs_opt = model.opt_proj(query_output.last_hidden_state)
|
105 |
+
atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
|
106 |
+
|
107 |
+
model.opt_tokenizer.padding_side = "right"
|
108 |
+
|
109 |
+
text = ['' for i in range(len(image))]
|
110 |
+
opt_tokens = model.opt_tokenizer(
|
111 |
+
text,
|
112 |
+
return_tensors="pt",
|
113 |
+
padding="longest",
|
114 |
+
truncation=True,
|
115 |
+
max_length=model.max_txt_len,
|
116 |
+
).to(device)
|
117 |
+
inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
|
118 |
+
inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
|
119 |
+
attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
|
120 |
+
num_txt = 10
|
121 |
+
return_num_txt = 5
|
122 |
+
with model.maybe_autocast():
|
123 |
+
outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
|
124 |
+
max_length=30,
|
125 |
+
repetition_penalty=5., num_beams=num_txt, eos_token_id=50118,
|
126 |
+
length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
|
127 |
+
output_text = model.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
128 |
+
output_text = [text.strip() for text in output_text]
|
129 |
+
output_text_ = []
|
130 |
+
for i in range(len(image)):
|
131 |
+
output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
|
132 |
+
with open('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), 'a+') as f:
|
133 |
+
for i in range(len(image)):
|
134 |
+
f.write(image[i][1] + "|" + output_text_[i] + '\n')
|
135 |
+
|
136 |
+
|
137 |
+
cat = 'mf'
|
138 |
+
fix = '_mf'
|
139 |
+
if cat == 'bp':
|
140 |
+
fix = '_bp'
|
141 |
+
if cat == 'cc':
|
142 |
+
fix = '_cc'
|
143 |
+
|
144 |
+
# model_pth = {'mf': 'uniprot_swissprot_mf_stage1_epo19.pth', 'bp': 'checkpoint17_GO_swissprot_reviewed_bp_stage1.pth', 'cc': ''}
|
145 |
+
|
146 |
+
# graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
|
147 |
+
|
148 |
+
# setup device to use
|
149 |
+
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
150 |
+
# device = 'cpu'
|
151 |
+
|
152 |
+
### Levenshtein similarity
|
153 |
+
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv'.format(fix), sep='|')[:10000]
|
154 |
+
test['function'] = test['function'].apply(lambda x: x.lower())
|
155 |
+
|
156 |
+
|
157 |
+
if os.path.exists('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix)):
|
158 |
+
os.remove('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix))
|
159 |
+
print("stage 2 predict starting")
|
160 |
+
stage2_output(test)
|
161 |
+
print("stage 2 predict completed")
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), sep='|', header=None, on_bad_lines='warn')
|
166 |
+
df_pred.columns = ['protein', 'function']
|
167 |
+
df_pred = df_pred.drop_duplicates()
|
168 |
+
df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
|
169 |
+
df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])
|
170 |
+
|
171 |
+
test.columns
|
172 |
+
test_g = test.groupby(['protein']).agg({'function': lambda x: list(x)}).reset_index()
|
173 |
+
test_g.columns = ['protein', 'label']
|
174 |
+
|
175 |
+
data = pd.merge(df_pred, test_g, on='protein', how='left')
|
176 |
+
data = data[data['label'].notnull()]
|
177 |
+
|
178 |
+
sim = []
|
179 |
+
for text, label in zip(data['function'].tolist(), data['label'].tolist()):
|
180 |
+
sim.append(func(text, label))
|
181 |
+
|
182 |
+
data['sim'] = sim
|
183 |
+
data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
|
184 |
+
print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
|
185 |
+
# data.to_csv('/home/nilin/LAVIS/predict_{}.csv'.format(cat), index=False, sep='|')
|
186 |
+
|
187 |
+
|
188 |
+
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv'.format(fix), sep='|', usecols=['function', 'GO_label'])
|
189 |
+
test['function'] = test['function'].apply(lambda x: x.lower())
|
190 |
+
test = test.drop_duplicates()
|
191 |
+
test_dict = dict(zip(test['function'], test['GO_label']))
|
192 |
+
val = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/val{}.csv'.format(fix), sep='|', usecols=['function', 'GO_label'])
|
193 |
+
val['function'] = val['function'].apply(lambda x: x.lower())
|
194 |
+
val = val.drop_duplicates()
|
195 |
+
val_dict = dict(zip(val['function'], val['GO_label']))
|
196 |
+
train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/train{}.csv'.format(fix), sep='|', usecols=['function', 'GO_label'])
|
197 |
+
train['function'] = train['function'].apply(lambda x: x.lower())
|
198 |
+
train = train.drop_duplicates()
|
199 |
+
train_dict = dict(zip(train['function'], train['GO_label']))
|
200 |
+
|
201 |
+
|
202 |
+
# go_des = pd.read_csv('/home/nilin/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
|
203 |
+
# # go_des = pd.read_csv('/home/nilin/LAVIS/data/go_descriptions.txt', sep='|', header=None)
|
204 |
+
# go_des.columns = ['GO', 'function']
|
205 |
+
# go_des = go_des[go_des['function'].notnull()]
|
206 |
+
# go_des['function'] = go_des['function'].apply(lambda x: x.lower())
|
207 |
+
# GO_dict = dict(zip(go_des['function'], go_des['GO']))
|
208 |
+
GO_dict = {}
|
209 |
+
GO_dict.update(train_dict)
|
210 |
+
GO_dict.update(val_dict)
|
211 |
+
GO_dict.update(test_dict)
|
212 |
+
choices = list(GO_dict.keys())
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
# data = pd.read_csv('/home/nilin/LAVIS/predict_{}.csv'.format(cat), sep='|')
|
217 |
+
data = data.sort_values(by='protein')
|
218 |
+
data = data.drop_duplicates('protein')
|
219 |
+
# data = data.sample(1000)
|
220 |
+
|
221 |
+
### If a predicted text is not among the GO label terms, map it to the most similar GO label
|
222 |
+
t0 = time.time()
|
223 |
+
txt_dict = {}
|
224 |
+
|
225 |
+
all_txt = []
|
226 |
+
for txt in data['function']:
|
227 |
+
if type(txt) == str:
|
228 |
+
all_txt.extend(eval(txt))
|
229 |
+
else:
|
230 |
+
all_txt.extend(txt)
|
231 |
+
all_txt = list(set(all_txt))
|
232 |
+
|
233 |
+
n = len(all_txt)
|
234 |
+
thread = 20
|
235 |
+
size = int(n/thread)
|
236 |
+
inds = list(range(0, n, size))
|
237 |
+
inds.append(n)
|
238 |
+
all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]
|
239 |
+
|
240 |
+
with Pool(processes=thread) as pool:
|
241 |
+
result = pool.map(fuzzy_match, all_txt_sep)
|
242 |
+
pool.close()
|
243 |
+
pool.join()
|
244 |
+
for d in result:
|
245 |
+
txt_dict.update(d)
|
246 |
+
|
247 |
+
# for txt in all_txt[:10]:
|
248 |
+
# fuzzy_match(txt)
|
249 |
+
|
250 |
+
data['function'] = data['function'].apply(lambda x: txt_map(x, txt_dict))
|
251 |
+
data['function'] = data['function'].apply(lambda x: list(set(x)))
|
252 |
+
print("fuzzy matching time: {}".format(time.time() - t0))
|
253 |
+
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
### Find the generated GO text that not included in the ground truth. Then generate pairs between them.
|
258 |
+
# pair_a, pair_b = [], []
|
259 |
+
# for preds, labels in zip(data['function'], data['label']):
|
260 |
+
# if type(preds) == str:
|
261 |
+
# preds = eval(preds)
|
262 |
+
# if type(labels) == str:
|
263 |
+
# labels = eval(labels)
|
264 |
+
# l = len(labels)
|
265 |
+
# for pred in preds:
|
266 |
+
# if pred not in labels:
|
267 |
+
# pair_a.extend([pred]*l)
|
268 |
+
# pair_b.extend(labels[:])
|
269 |
+
# pair_a = [re.sub('_', ':', GO_dict[i]) for i in pair_a]
|
270 |
+
# pair_b = [re.sub('_', ':', GO_dict[i]) for i in pair_b]
|
271 |
+
# with open('/home/nilin/LAVIS/examples/GO_pair{}.txt'.format(fix), 'w+') as f:
|
272 |
+
# for i, j in zip(pair_a, pair_b):
|
273 |
+
# f.write(i+' '+j+'\n')
|
274 |
+
|
275 |
+
|
276 |
+
# load model
|
277 |
+
model_config = {'arch': 'blip2_protein', 'load_finetuned': False,
|
278 |
+
'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage1/20230922185/checkpoint_15.pth',
|
279 |
+
'finetuned': '', 'num_query_token': 32, 'prompt': '',
|
280 |
+
'model_type': 'pretrain', 'load_pretrained': True, 'freeze_vit': False,
|
281 |
+
'max_protein_len': 512, 'max_txt_len': 25}
|
282 |
+
|
283 |
+
model_cls = registry.get_model_class(model_config['arch'])
|
284 |
+
model = model_cls.from_config(model_config)
|
285 |
+
model = model.to(device)
|
286 |
+
model.eval()
|
287 |
+
|
288 |
+
# evaluate
|
289 |
+
t0 = time.time()
|
290 |
+
proteins = list(data['protein'])
|
291 |
+
txts = list(data['function'])
|
292 |
+
scores = []
|
293 |
+
for seq, txt in zip(proteins, txts):
|
294 |
+
image = [('protein1', seq)]
|
295 |
+
_, _, batch_tokens = model.visual_encoder(image)
|
296 |
+
image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][
|
297 |
+
30].contiguous()
|
298 |
+
|
299 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
300 |
+
|
301 |
+
query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
302 |
+
|
303 |
+
query_output = model.Qformer.bert(
|
304 |
+
query_embeds=query_tokens,
|
305 |
+
encoder_hidden_states=image_embeds,
|
306 |
+
encoder_attention_mask=image_atts,
|
307 |
+
use_cache=True,
|
308 |
+
return_dict=True,
|
309 |
+
)
|
310 |
+
|
311 |
+
image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)
|
312 |
+
|
313 |
+
image_feats_all = concat_all_gather(image_feats)
|
314 |
+
|
315 |
+
if type(txt) == str:
|
316 |
+
txt = eval(txt)
|
317 |
+
length = len(txt)
|
318 |
+
with torch.no_grad():
|
319 |
+
text_tokens = model.tokenizer(
|
320 |
+
txt,
|
321 |
+
padding="max_length",
|
322 |
+
truncation=True,
|
323 |
+
max_length=model.max_txt_len,
|
324 |
+
return_tensors="pt",
|
325 |
+
).to(device)
|
326 |
+
text_output = model.Qformer.bert(
|
327 |
+
text_tokens.input_ids,
|
328 |
+
attention_mask=text_tokens.attention_mask,
|
329 |
+
return_dict=True,
|
330 |
+
)
|
331 |
+
|
332 |
+
text_feat = F.normalize(
|
333 |
+
model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
|
334 |
+
)
|
335 |
+
|
336 |
+
text_feat_all = concat_all_gather(text_feat)
|
337 |
+
sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
|
338 |
+
sim_i2t, _ = sim_q2t.max(-1)
|
339 |
+
# print('sim_i2t: {}'.format(sim_i2t))
|
340 |
+
if length > 1:
|
341 |
+
scores.append(list(sim_i2t.detach().cpu().numpy()))
|
342 |
+
else:
|
343 |
+
scores.append([sim_i2t.item()])
|
344 |
+
print("model evaluate time: {}".format(time.time() - t0))
|
345 |
+
data['score'] = scores
|
346 |
+
|
347 |
+
# precision and recall top-k
|
348 |
+
topk = 2
|
349 |
+
threshould = 0.1
|
350 |
+
labels = []
|
351 |
+
pred_labels = []
|
352 |
+
for l in data['label']:
|
353 |
+
if type(l) == str:
|
354 |
+
l = eval(l)
|
355 |
+
labels.extend(l)
|
356 |
+
|
357 |
+
labels = list(set(labels))
|
358 |
+
total = len(labels)
|
359 |
+
for topk in range(1,7):
|
360 |
+
for threshould in range(1, 25, 1):
|
361 |
+
threshould /= 100
|
362 |
+
filter_txts = []
|
363 |
+
recalls = []
|
364 |
+
precisions = []
|
365 |
+
f1 = []
|
366 |
+
tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
|
367 |
+
for txts, scores, label in zip(data['function'], data['score'], data['label']):
|
368 |
+
if type(label) == str:
|
369 |
+
label = eval(label)
|
370 |
+
txts_ = np.array(txts)
|
371 |
+
scores = np.array(scores)
|
372 |
+
txts = txts_[scores > threshould]
|
373 |
+
if len(txts) < 1:
|
374 |
+
txts = txts_[np.argmax(scores)]
|
375 |
+
scores = scores[scores > threshould]
|
376 |
+
|
377 |
+
l = len(scores)
|
378 |
+
ll = len(label)
|
379 |
+
if l <= topk:
|
380 |
+
filter_txts.append(list(txts))
|
381 |
+
else:
|
382 |
+
ind = np.argpartition(scores, -topk)[-topk:]
|
383 |
+
txts = txts[ind]
|
384 |
+
filter_txts.append(list(txts))
|
385 |
+
l = topk
|
386 |
+
for t in label:
|
387 |
+
if t in txts:
|
388 |
+
tp_dict[t] += 1
|
389 |
+
else:
|
390 |
+
fn_dict[t] += 1
|
391 |
+
for p in txts:
|
392 |
+
if p not in label:
|
393 |
+
if p in fp_dict:
|
394 |
+
fp_dict[p] += 1
|
395 |
+
else:
|
396 |
+
fp_dict[p] = 1
|
397 |
+
pred_labels.extend(txts)
|
398 |
+
p_total = len(set(pred_labels))
|
399 |
+
re, pr = 0., 0.
|
400 |
+
for x in labels:
|
401 |
+
re += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
|
402 |
+
pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x]+1e-8))
|
403 |
+
r = re / total
|
404 |
+
p = pr / total
|
405 |
+
f1 = 2 * p * r / (p + r)
|
406 |
+
print("Topk: {}, threshould: {}, macro_recall: {}, macro_precision: {}, micro_f1: {}".format(topk, threshould, r, p, f1))
|
407 |
+
# num_r = 0
|
408 |
+
# num_p = 0
|
409 |
+
# for x in label:
|
410 |
+
# if x in txts:
|
411 |
+
# num_r += 1
|
412 |
+
# for x in txts:
|
413 |
+
# if x in label:
|
414 |
+
# num_p += 1
|
415 |
+
# recall = num_r/ll
|
416 |
+
# precision = num_p/(l+0.0001)
|
417 |
+
# recalls.append(recall)
|
418 |
+
# precisions.append(precision)
|
419 |
+
# f1.append((2*recall*precision)/(recall+precision+0.0001))
|
420 |
+
#
|
421 |
+
# data['predict'] = filter_txts
|
422 |
+
# data['precision'] = precisions
|
423 |
+
# data['recall'] = recalls
|
424 |
+
# data['f1'] = f1
|
425 |
+
# print("Topk: {}, threshould: {}, macro_recall: {}, macro_precision: {}, micro_f1: {}".format(topk, threshould, round(data['recall'].mean(), 4), round(data['precision'].mean(), 4), round(data['f1'].mean(), 4)))
|
426 |
+
|
427 |
+
|
428 |
+
|
429 |
+
|
430 |
+
|
431 |
+
|
432 |
+
# sim = []
|
433 |
+
# for text, label in zip(data['predict'].tolist(), data['label'].tolist()):
|
434 |
+
# sim.append(levenshtein_sim(text, label))
|
435 |
+
#
|
436 |
+
# data['sim_filter'] = sim
|
437 |
+
# data['avg_score'] = data['sim_filter'].apply(lambda x: round(np.mean(x), 3))
|
438 |
+
|
439 |
+
|
440 |
+
# data['function'] = data['function'].apply(lambda x: eval(re.sub(';', ',', str(x))))
|
441 |
+
# data['label'] = data['label'].apply(lambda x: eval(re.sub(';', ',', str(x))))
|
442 |
+
# data['sim'] = data['sim'].apply(lambda x: eval(re.sub(';', ',', str(x))))
|
443 |
+
#
|
444 |
+
# data['function'] = data['function'].apply(lambda x: re.sub(',', ';', str(x)))
|
445 |
+
# data['label'] = data['label'].apply(lambda x: re.sub(',', ';', str(x)))
|
446 |
+
# data['sim'] = data['sim'].apply(lambda x: re.sub(',', ';', str(x)))
|
447 |
+
# data['predict'] = data['predict'].apply(lambda x: re.sub(',', ';', str(x)))
|
448 |
+
# data['sim_filter'] = data['sim_filter'].apply(lambda x: re.sub(',', ';', str(x)))
|
449 |
+
|
450 |
+
data.to_csv('/cluster/home/wenkai/LAVIS/output/predict_sim{}.csv'.format(fix), sep='|', index=False)
|
451 |
+
# data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_sim{}.csv'.format(fix), sep='|')
|
452 |
+
|
453 |
+
|
454 |
+
|
455 |
+
|
456 |
+
|
457 |
+
|
458 |
+
|
459 |
+
|
460 |
+
#
|
461 |
+
# # example
|
462 |
+
# image = ['MIELKHVTFGYNKKQMVLQDINITIPDGENVGILGESGCGKSTLASLVLGLFKPVKGEIYLSDNAVLTIFQHPLTSFNPDWTIETSLKEALYYYRGLTDNTAQDQLLLQHLSTFELNAQLLTKLPSEVSGGQLQRFNVMRSLLAQPRVLICDEITSNLDVIAEQNVINILKAQTITNLNHFIVISHDLSVLQRLVNRIIVLKDGMIVDDFAIEELFNVDRHPYTKELVQTFSY']
|
463 |
+
# image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
|
464 |
+
#
|
465 |
+
# _, _, batch_tokens = model.visual_encoder(image)
|
466 |
+
# image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][30].contiguous()
|
467 |
+
#
|
468 |
+
# image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
469 |
+
#
|
470 |
+
# query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
471 |
+
#
|
472 |
+
# query_output = model.Qformer.bert(
|
473 |
+
# query_embeds=query_tokens,
|
474 |
+
# encoder_hidden_states=image_embeds,
|
475 |
+
# encoder_attention_mask=image_atts,
|
476 |
+
# use_cache=True,
|
477 |
+
# return_dict=True,
|
478 |
+
# )
|
479 |
+
#
|
480 |
+
# image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)
|
481 |
+
#
|
482 |
+
# image_feats_all = concat_all_gather(image_feats)
|
483 |
+
#
|
484 |
+
# functions = ['transmembrane transporter activity', 'nickel cation transmembrane transporter activity', 'nickel cation binding', 'atp hydrolysis activity', 'atp hydrolysis', 'cadmium binding', 'abc-type nickel transmembrane transporter activity', 'abc-type nickel transporter activity', 'nickel transmembrane transporter activity', 'atp binding']
|
485 |
+
# for text in functions:
|
486 |
+
# with torch.no_grad():
|
487 |
+
# # text = 'flavin adenine dinucleotide binding'
|
488 |
+
# text_tokens = model.tokenizer(
|
489 |
+
# text,
|
490 |
+
# padding="max_length",
|
491 |
+
# truncation=True,
|
492 |
+
# max_length=model.max_txt_len,
|
493 |
+
# return_tensors="pt",
|
494 |
+
# ).to(device)
|
495 |
+
# text_output = model.Qformer.bert(
|
496 |
+
# text_tokens.input_ids,
|
497 |
+
# attention_mask=text_tokens.attention_mask,
|
498 |
+
# return_dict=True,
|
499 |
+
# )
|
500 |
+
#
|
501 |
+
# text_feat = F.normalize(
|
502 |
+
# model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
|
503 |
+
# )
|
504 |
+
#
|
505 |
+
# text_feat_all = concat_all_gather(text_feat)
|
506 |
+
# sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
|
507 |
+
# sim_i2t, _ = sim_q2t.max(-1)
|
508 |
+
# print('sim_i2t: {}'.format(sim_i2t))
|
509 |
+
#
|
510 |
+
# # # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
|
511 |
+
# # sim_t2q = torch.matmul(
|
512 |
+
# # text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
|
513 |
+
# # ).squeeze()
|
514 |
+
# #
|
515 |
+
# # # text-image similarity: aggregate across all query tokens
|
516 |
+
# # sim_t2i, _ = sim_t2q.max(-1)
|
517 |
+
# # print('sim_t2i: {}'.format(sim_t2i))
|
518 |
+
|
519 |
+
|
520 |
+
|
examples/blip2_predict_func.py
ADDED
@@ -0,0 +1,178 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from lavis.models import load_model_and_preprocess
|
8 |
+
from lavis.processors import load_processor
|
9 |
+
from lavis.common.registry import registry
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
|
12 |
+
import numpy as np
|
13 |
+
import pandas as pd
|
14 |
+
import time
|
15 |
+
from fuzzywuzzy import process
|
16 |
+
from multiprocessing import Pool, Queue, Process
|
17 |
+
import difflib
|
18 |
+
import Levenshtein
|
19 |
+
# import obonet
|
20 |
+
|
21 |
+
|
22 |
+
# setup device to use
|
23 |
+
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
24 |
+
# device = 'cpu'
|
25 |
+
|
26 |
+
|
27 |
+
def txt_map(x, txt_dict):
|
28 |
+
if type(x) == str:
|
29 |
+
x = eval(x)
|
30 |
+
x_ = []
|
31 |
+
for i in x:
|
32 |
+
if i in txt_dict:
|
33 |
+
x_.append(txt_dict[i])
|
34 |
+
else:
|
35 |
+
x_.append(i)
|
36 |
+
return x_
|
37 |
+
|
38 |
+
|
39 |
+
def levenshtein_sim(text, label):
|
40 |
+
all_s = []
|
41 |
+
for x in label:
|
42 |
+
s = 0
|
43 |
+
for y in text:
|
44 |
+
temp = Levenshtein.ratio(x, y)
|
45 |
+
if temp > s:
|
46 |
+
s = temp
|
47 |
+
all_s.append(s)
|
48 |
+
all_s = [round(i, 3) for i in all_s]
|
49 |
+
return all_s
|
50 |
+
|
51 |
+
def func(text, label):
|
52 |
+
all_s = []
|
53 |
+
for x in text:
|
54 |
+
s = 0
|
55 |
+
for y in label:
|
56 |
+
temp = Levenshtein.ratio(x, y)
|
57 |
+
if temp > s:
|
58 |
+
s = temp
|
59 |
+
all_s.append(s)
|
60 |
+
all_s = [round(i, 3) for i in all_s]
|
61 |
+
return all_s
|
62 |
+
|
63 |
+
|
64 |
+
def stage2_output(df_test, return_num_txt=1):
|
65 |
+
config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
|
66 |
+
'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20230924220/checkpoint_5.pth',
|
67 |
+
'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
|
68 |
+
'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
|
69 |
+
'max_protein_len': 600,
|
70 |
+
'max_txt_len': 25}
|
71 |
+
|
72 |
+
model_cls = registry.get_model_class(config['arch'])
|
73 |
+
model = model_cls.from_config(config)
|
74 |
+
model.to(device)
|
75 |
+
model.eval()
|
76 |
+
|
77 |
+
images = df_test['protein'].tolist()
|
78 |
+
n = len(images)
|
79 |
+
bsz = 12
|
80 |
+
iter = n // bsz + 1
|
81 |
+
|
82 |
+
for i in range(iter):
|
83 |
+
image = images[i*bsz: min(n, (i+1)*bsz)]
|
84 |
+
image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
|
85 |
+
|
86 |
+
with model.maybe_autocast():
|
87 |
+
_, _, batch_tokens = model.visual_encoder(image)
|
88 |
+
image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()
|
89 |
+
|
90 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
91 |
+
|
92 |
+
query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
93 |
+
query_output = model.Qformer.bert(
|
94 |
+
query_embeds=query_tokens,
|
95 |
+
encoder_hidden_states=image_embeds,
|
96 |
+
encoder_attention_mask=image_atts,
|
97 |
+
return_dict=True,
|
98 |
+
)
|
99 |
+
|
100 |
+
inputs_opt = model.opt_proj(query_output.last_hidden_state)
|
101 |
+
atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
|
102 |
+
|
103 |
+
model.opt_tokenizer.padding_side = "right"
|
104 |
+
|
105 |
+
text = ['' for i in range(len(image))]
|
106 |
+
opt_tokens = model.opt_tokenizer(
|
107 |
+
text,
|
108 |
+
return_tensors="pt",
|
109 |
+
padding="longest",
|
110 |
+
truncation=True,
|
111 |
+
max_length=model.max_txt_len,
|
112 |
+
).to(device)
|
113 |
+
inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
|
114 |
+
inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
|
115 |
+
attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
|
116 |
+
num_txt = 6
|
117 |
+
with model.maybe_autocast():
|
118 |
+
outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
|
119 |
+
max_length=30,
|
120 |
+
repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
|
121 |
+
length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
|
122 |
+
output_text = model.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
123 |
+
output_text = [text.strip() for text in output_text]
|
124 |
+
output_text_ = []
|
125 |
+
for i in range(len(image)):
|
126 |
+
output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
|
127 |
+
with open('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), 'a+') as f:
|
128 |
+
for i in range(len(image)):
|
129 |
+
f.write(image[i][1] + "|" + output_text_[i] + '\n')
|
130 |
+
|
131 |
+
|
132 |
+
cat = 'mf'
|
133 |
+
fix = '_mf'
|
134 |
+
if cat == 'bp':
|
135 |
+
fix = '_bp'
|
136 |
+
if cat == 'cc':
|
137 |
+
fix = '_cc'
|
138 |
+
|
139 |
+
return_num_txt = 1
|
140 |
+
# graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
|
141 |
+
|
142 |
+
### Levenshtein similarity
|
143 |
+
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv'.format(fix), sep='|')
|
144 |
+
test['function'] = test['function'].apply(lambda x: x.lower())
|
145 |
+
|
146 |
+
|
147 |
+
if os.path.exists('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix)):
|
148 |
+
os.remove('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix))
|
149 |
+
print("stage 2 predict starting")
|
150 |
+
stage2_output(test)
|
151 |
+
print("stage 2 predict completed")
|
152 |
+
|
153 |
+
df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), sep='|', header=None, on_bad_lines='warn')
|
154 |
+
df_pred.columns = ['protein', 'function']
|
155 |
+
df_pred = df_pred.drop_duplicates()
|
156 |
+
df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
|
157 |
+
df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])
|
158 |
+
|
159 |
+
test_g = test.groupby(['protein']).agg({'function': lambda x: list(x)}).reset_index()
|
160 |
+
test_g.columns = ['protein', 'label']
|
161 |
+
|
162 |
+
data = pd.merge(df_pred, test_g, on='protein', how='left')
|
163 |
+
data = data[data['label'].notnull()]
|
164 |
+
|
165 |
+
sim = []
|
166 |
+
for text, label in zip(data['function'].tolist(), data['label'].tolist()):
|
167 |
+
sim.append(func(text, label))
|
168 |
+
|
169 |
+
data['sim'] = sim
|
170 |
+
data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
|
171 |
+
data['count'] = data['sim'].apply(lambda x: x.count(1.))
|
172 |
+
print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
|
173 |
+
print("Return texts: {}; Accuracy: {}".format(return_num_txt, data['count'].sum()/(return_num_txt*data.shape[0])))
|
174 |
+
data.to_csv('/cluster/home/wenkai/LAVIS/output/predict_{}.csv'.format(cat), index=False, sep='|')
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
|
examples/blip2_predict_func_concat.py
ADDED
@@ -0,0 +1,193 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from lavis.models import load_model_and_preprocess
|
8 |
+
from lavis.processors import load_processor
|
9 |
+
from lavis.common.registry import registry
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
|
12 |
+
import numpy as np
|
13 |
+
import pandas as pd
|
14 |
+
import time
|
15 |
+
from fuzzywuzzy import process
|
16 |
+
from multiprocessing import Pool, Queue, Process
|
17 |
+
import difflib
|
18 |
+
import Levenshtein
|
19 |
+
|
20 |
+
# import obonet
|
21 |
+
|
22 |
+
|
23 |
+
# setup device to use
|
24 |
+
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
25 |
+
|
26 |
+
|
27 |
+
# device = torch.device("cuda")
|
28 |
+
|
29 |
+
|
30 |
+
def txt_map(x, txt_dict):
|
31 |
+
if type(x) == str:
|
32 |
+
x = eval(x)
|
33 |
+
x_ = []
|
34 |
+
for i in x:
|
35 |
+
if i in txt_dict:
|
36 |
+
x_.append(txt_dict[i])
|
37 |
+
else:
|
38 |
+
x_.append(i)
|
39 |
+
return x_
|
40 |
+
|
41 |
+
|
42 |
+
def levenshtein_sim(text, label):
|
43 |
+
all_s = []
|
44 |
+
for x in label:
|
45 |
+
s = 0
|
46 |
+
for y in text:
|
47 |
+
temp = Levenshtein.ratio(x, y)
|
48 |
+
if temp > s:
|
49 |
+
s = temp
|
50 |
+
all_s.append(s)
|
51 |
+
all_s = [round(i, 3) for i in all_s]
|
52 |
+
return all_s
|
53 |
+
|
54 |
+
|
55 |
+
def func(text, label):
|
56 |
+
all_s = []
|
57 |
+
for x in text:
|
58 |
+
s = 0
|
59 |
+
for y in label:
|
60 |
+
temp = Levenshtein.ratio(x, y)
|
61 |
+
if temp > s:
|
62 |
+
s = temp
|
63 |
+
all_s.append(s)
|
64 |
+
all_s = [round(i, 3) for i in all_s]
|
65 |
+
return all_s
|
66 |
+
|
67 |
+
|
68 |
+
def stage2_output(df_test, return_num_txt=1):
|
69 |
+
config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
|
70 |
+
'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20231029182/checkpoint_0.pth',
|
71 |
+
'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
|
72 |
+
'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
|
73 |
+
'max_protein_len': 600,
|
74 |
+
'max_txt_len': 256}
|
75 |
+
|
76 |
+
model_cls = registry.get_model_class(config['arch'])
|
77 |
+
model = model_cls.from_config(config)
|
78 |
+
model.to(device)
|
79 |
+
model.eval()
|
80 |
+
|
81 |
+
images = df_test['protein'].tolist()
|
82 |
+
n = len(images)
|
83 |
+
bsz = 8
|
84 |
+
iter = n // bsz + 1
|
85 |
+
with open('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), 'a+') as f:
|
86 |
+
for i in range(iter):
|
87 |
+
image = images[i * bsz: min(n, (i + 1) * bsz)]
|
88 |
+
image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
|
89 |
+
|
90 |
+
with model.maybe_autocast():
|
91 |
+
_, _, batch_tokens = model.visual_encoder(image)
|
92 |
+
image_embeds = \
|
93 |
+
model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)[
|
94 |
+
"representations"][model.vis_layers].contiguous()
|
95 |
+
|
96 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
97 |
+
|
98 |
+
query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
99 |
+
query_output = model.Qformer.bert(
|
100 |
+
query_embeds=query_tokens,
|
101 |
+
encoder_hidden_states=image_embeds,
|
102 |
+
encoder_attention_mask=image_atts,
|
103 |
+
return_dict=True,
|
104 |
+
)
|
105 |
+
|
106 |
+
inputs_opt = model.opt_proj(query_output.last_hidden_state)
|
107 |
+
atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
|
108 |
+
|
109 |
+
model.opt_tokenizer.padding_side = "right"
|
110 |
+
|
111 |
+
text = ['' for i in range(len(image))]
|
112 |
+
opt_tokens = model.opt_tokenizer(
|
113 |
+
text,
|
114 |
+
return_tensors="pt",
|
115 |
+
padding="longest",
|
116 |
+
truncation=True,
|
117 |
+
max_length=model.max_txt_len,
|
118 |
+
).to(device)
|
119 |
+
inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
|
120 |
+
inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
|
121 |
+
attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
|
122 |
+
num_txt = 5
|
123 |
+
with model.maybe_autocast():
|
124 |
+
outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=1,
|
125 |
+
max_length=256,
|
126 |
+
repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
|
127 |
+
length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
|
128 |
+
output_text = model.opt_tokenizer.batch_decode(outputs)
|
129 |
+
|
130 |
+
output_text = [re.sub('\t', '', str(x)) for x in output_text]
|
131 |
+
output_text = [text.strip() for text in output_text]
|
132 |
+
output_text_ = []
|
133 |
+
for i in range(len(image)):
|
134 |
+
output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
|
135 |
+
|
136 |
+
for i in range(len(image)):
|
137 |
+
f.write(image[i][1] + "|" + output_text_[i] + '\n')
|
138 |
+
|
139 |
+
|
140 |
+
if __name__=="__main__":
|
141 |
+
split = 'test'
|
142 |
+
cat = 'bp'
|
143 |
+
fix = '_mf'
|
144 |
+
type_fix = ''
|
145 |
+
if cat == 'bp':
|
146 |
+
fix = '_bp'
|
147 |
+
if cat == 'cc':
|
148 |
+
fix = '_cc'
|
149 |
+
|
150 |
+
print(device)
|
151 |
+
return_num_txt = 1
|
152 |
+
# graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
|
153 |
+
|
154 |
+
### Levenshtein similarity
|
155 |
+
print("reading file ...")
|
156 |
+
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split_concat/{}{}.csv'.format(split, fix),
|
157 |
+
usecols=['name', 'protein', 'function'], sep='|')
|
158 |
+
# test['function'] = test['function'].apply(lambda x: x.lower().split('; '))
|
159 |
+
test.columns = ['name', 'protein', 'label']
|
160 |
+
|
161 |
+
if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix)):
|
162 |
+
os.remove('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix))
|
163 |
+
print("stage 2 predict starting")
|
164 |
+
stage2_output(test)
|
165 |
+
print("stage 2 predict completed")
|
166 |
+
|
167 |
+
df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), sep='|',
|
168 |
+
header=None, on_bad_lines='warn')
|
169 |
+
df_pred.columns = ['protein', 'pred']
|
170 |
+
df_pred = df_pred.drop_duplicates()
|
171 |
+
# df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
|
172 |
+
# df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])
|
173 |
+
|
174 |
+
|
175 |
+
data = pd.merge(df_pred, test, on='protein', how='left')
|
176 |
+
data = data[data['label'].notnull()]
|
177 |
+
|
178 |
+
# sim = []
|
179 |
+
# for text, label in zip(data['function'].tolist(), data['label'].tolist()):
|
180 |
+
# sim.append(func(text, label))
|
181 |
+
|
182 |
+
# data['sim'] = sim
|
183 |
+
# data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
|
184 |
+
# data['count'] = data['sim'].apply(lambda x: x.count(1.))
|
185 |
+
# print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
|
186 |
+
# print("Return texts: {}; Accuracy: {}".format(return_num_txt, data['count'].sum()/(return_num_txt*data.shape[0])))
|
187 |
+
data[['name', 'label', 'pred']].to_csv(
|
188 |
+
'/cluster/home/wenkai/LAVIS/output/predict_concat_{}{}{}.csv'.format(split, cat, type_fix), index=False, sep='|')
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
|
examples/blip2_predict_func_concat_pretrain.py
ADDED
@@ -0,0 +1,197 @@
import os
import re

import torch
from PIL import Image

from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
from lavis.common.registry import registry
from torch.nn import functional as F
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
import numpy as np
import pandas as pd
import time
from fuzzywuzzy import process
from multiprocessing import Pool, Queue, Process
import difflib
import Levenshtein

# import obonet


# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

# device = torch.device("cuda")


def txt_map(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    for i in x:
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            x_.append(i)
    return x_


def levenshtein_sim(text, label):
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]
    return all_s


def func(text, label):
    all_s = []
    for x in text:
        s = 0
        for y in label:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]
    return all_s


def stage2_output(df_test, return_num_txt=1):
    config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
              'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20231029182/checkpoint_0.pth',
              'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
              'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
              'max_protein_len': 600,
              'max_txt_len': 256}

    model_cls = registry.get_model_class(config['arch'])
    model = model_cls.from_config(config)
    model.to(device)
    model.eval()

    images = df_test['protein'].tolist()
    n = len(images)
    bsz = 8
    iter = n // bsz + 1
    if n > 0:
        for i in range(iter):
            image = images[i * bsz: min(n, (i + 1) * bsz)]
            image = [('protein{}'.format(i), x) for i, x in enumerate(image)]

            with model.maybe_autocast():
                _, _, batch_tokens = model.visual_encoder(image)
                image_embeds = \
                    model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)[
                        "representations"][model.vis_layers].contiguous()

            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = model.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            inputs_opt = model.opt_proj(query_output.last_hidden_state)
            atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)

            model.opt_tokenizer.padding_side = "right"

            text = ['' for i in range(len(image))]
            opt_tokens = model.opt_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=model.max_txt_len,
            ).to(device)
            inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
            num_txt = 5
            with model.maybe_autocast():
                outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=1,
                                                   max_length=256,
                                                   repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
                                                   length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
            output_text = model.opt_tokenizer.batch_decode(outputs)

            output_text = [re.sub('\t', '', str(x)) for x in output_text]
            output_text = [text.strip() for text in output_text]
            output_text_ = []
            for i in range(len(image)):
                output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))

            f = open('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), 'a+')
            for i in range(len(image)):
                f.write(image[i][1] + "|" + output_text_[i] + '\n')
            f.close()


if __name__=="__main__":
    split = 'test'
    cat = ''
    fix = ''
    type_fix = '_pretrain'
    if cat == 'bp':
        fix = '_bp'
    if cat == 'cc':
        fix = '_cc'

    print(device)
    return_num_txt = 1
    # graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")

    ### Levenshtein similarity
    print("reading file ...")
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/{}_sample10000.csv'.format(split),
                       usecols=['name', 'protein', 'function'], sep='|')
    # test['function'] = test['function'].apply(lambda x: x.lower().split('; '))
    test.columns = ['name', 'protein', 'label']

    if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix)):
        os.remove('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix))
    print("stage 2 predict starting")
    stage2_output(test)
    print("stage 2 predict completed")

    df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), sep='|',
                          header=None, on_bad_lines='warn')
    df_pred.columns = ['protein', 'pred']
    df_pred = df_pred.drop_duplicates()
    # df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
    # df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])

    data = pd.merge(df_pred, test, on='protein', how='left')
    data = data[data['label'].notnull()]

    # sim = []
    # for text, label in zip(data['function'].tolist(), data['label'].tolist()):
    #     sim.append(func(text, label))

    # data['sim'] = sim
    # data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
    # data['count'] = data['sim'].apply(lambda x: x.count(1.))
    # print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
    # print("Return texts: {}; Accuracy: {}".format(return_num_txt, data['count'].sum()/(return_num_txt*data.shape[0])))
    data[['name', 'label', 'pred']].to_csv(
        '/cluster/home/wenkai/LAVIS/output/predict_concat_{}{}{}.csv'.format(split, cat, type_fix), index=False, sep='|')
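Note: in stage2_output above, iter = n // bsz + 1 produces one extra, empty slice whenever n is an exact multiple of bsz; the `if n > 0` check only covers the empty-input case. A hedged batching sketch that avoids the empty slice (illustrative only, not in the uploaded file; `images` and `bsz` as defined above):

# Batching sketch (illustrative only): math.ceil yields exactly the number of
# non-empty batches, so no empty final slice reaches the model.
import math

def batches(items, bsz):
    # iterate over ceil(len/bsz) chunks and skip any empty slice defensively
    for i in range(math.ceil(len(items) / bsz)):
        batch = items[i * bsz:(i + 1) * bsz]
        if batch:
            yield batch

# usage sketch: for image in batches(images, 8): ...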
examples/blip2_predict_func_concat_timesplit.py
ADDED
@@ -0,0 +1,166 @@
import os
import re

import torch
from PIL import Image

from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
from lavis.common.registry import registry
from torch.nn import functional as F
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
import numpy as np
import pandas as pd
import time
from fuzzywuzzy import process
from multiprocessing import Pool, Queue, Process
import difflib
import Levenshtein
# import obonet


# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# device = 'cpu'


def txt_map(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    for i in x:
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            x_.append(i)
    return x_


def levenshtein_sim(text, label):
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]
    return all_s


def func(text, label):
    all_s = []
    for x in text:
        s = 0
        for y in label:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]
    return all_s


def stage2_output(df_test, return_num_txt=1):
    config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
              'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20231007085/checkpoint_19.pth',
              'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
              'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
              'max_protein_len': 600,
              'max_txt_len': 256}

    model_cls = registry.get_model_class(config['arch'])
    model = model_cls.from_config(config)
    model.to(device)
    model.eval()

    images = df_test['protein'].tolist()
    n = len(images)
    bsz = 12
    iter = n // bsz + 1

    for i in range(iter):
        image = images[i*bsz: min(n, (i+1)*bsz)]
        image = [('protein{}'.format(i), x) for i, x in enumerate(image)]

        with model.maybe_autocast():
            _, _, batch_tokens = model.visual_encoder(image)
            image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

        query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_opt = model.opt_proj(query_output.last_hidden_state)
        atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)

        model.opt_tokenizer.padding_side = "right"

        text = ['' for i in range(len(image))]
        opt_tokens = model.opt_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=model.max_txt_len,
        ).to(device)
        inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
        num_txt = 5
        with model.maybe_autocast():
            outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
                                               max_length=256,
                                               repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
                                               length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
        output_text = model.opt_tokenizer.batch_decode(outputs)
        output_text = [re.sub('\t', '', x) for x in output_text]

        output_text = [text.strip() for text in output_text]
        output_text_ = []
        for i in range(len(image)):
            output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
        with open('/cluster/home/wenkai/LAVIS/output/output_timeconcat{}.txt'.format(fix), 'a+') as f:
            for i in range(len(image)):
                f.write(image[i][1] + "|" + output_text_[i] + '\n')


cat = 'mf'
fix = '_mf'
if cat == 'bp':
    fix = '_bp'
if cat == 'cc':
    fix = '_cc'

return_num_txt = 1
# graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")

### Levenshtein similarity
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/review_time_concat/test{}.csv'.format(fix), usecols=['name', 'protein', 'function'], sep='|')
#test['function'] = test['function'].apply(lambda x: x.lower().split('; '))
test.columns = ['name', 'protein', 'label']

if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_timeconcat{}.txt'.format(fix)):
    os.remove('/cluster/home/wenkai/LAVIS/output/output_timeconcat{}.txt'.format(fix))
print("stage 2 predict starting")
stage2_output(test)
print("stage 2 predict completed")

df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_timeconcat{}.txt'.format(fix), sep='|', header=None, on_bad_lines='warn')
df_pred.columns = ['protein', 'pred']
df_pred = df_pred.drop_duplicates()

data = pd.merge(df_pred, test, on='protein', how='left')
data = data[data['label'].notnull()]

data[['name', 'label', 'pred']].to_csv('/cluster/home/wenkai/LAVIS/output/predict_timeconcat_{}.csv'.format(cat), index=False, sep='|')
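Note: the intermediate prediction files written by these scripts are plain 'protein|pred' text lines, and reading them back with on_bad_lines='warn' makes pandas skip (with a warning) any row that does not parse into exactly two pipe-separated fields. A small parsing sketch with hypothetical data, not taken from the uploaded files:

# Parsing sketch (hypothetical data): a row whose generated text contains an
# extra '|' has too many fields and is skipped with a warning.
import io
import pandas as pd

raw = io.StringIO(
    "PROTSEQ1|atp binding; protein kinase activity\n"
    "PROTSEQ2|dna binding|stray pipe in the generation\n"  # skipped with a warning
)
df_pred = pd.read_csv(raw, sep='|', header=None, on_bad_lines='warn')
df_pred.columns = ['protein', 'pred']
print(df_pred)  # only the well-formed row remains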
examples/blip2_predict_names.py
ADDED
@@ -0,0 +1,247 @@
import os
import re

import torch
from PIL import Image

from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
from lavis.common.registry import registry
from torch.nn import functional as F
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
import numpy as np
import pandas as pd
import time
from fuzzywuzzy import process
from multiprocessing import Pool, Queue, Process
import difflib
import Levenshtein
# import obonet


# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# device = 'cpu'


def txt_map(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    for i in x:
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            x_.append(i)
    return x_


def levenshtein_sim(text, label):
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]
    return all_s


def func(text, label):
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]
    return all_s


def stage2_output(df_test):
    config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
              'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20230926091/checkpoint_3.pth',
              'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
              'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
              'max_protein_len': 600,
              'max_txt_len': 25}

    model_cls = registry.get_model_class(config['arch'])
    model = model_cls.from_config(config)
    model.to(device)
    model.eval()

    images = df_test['protein'].tolist()
    n = len(images)
    bsz = 12
    iter = n // bsz + 1

    for i in range(iter):
        image = images[i*bsz: min(n, (i+1)*bsz)]
        image = [('protein{}'.format(i), x) for i, x in enumerate(image)]

        with model.maybe_autocast():
            _, _, batch_tokens = model.visual_encoder(image)
            image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

        query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_opt = model.opt_proj(query_output.last_hidden_state)
        atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)

        model.opt_tokenizer.padding_side = "right"

        text = ['' for i in range(len(image))]
        opt_tokens = model.opt_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=model.max_txt_len,
        ).to(device)
        inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
        num_txt = 5
        return_num_txt = 2
        with model.maybe_autocast():
            outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
                                               max_length=30,
                                               repetition_penalty=5., num_beams=num_txt, eos_token_id=50118,
                                               length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
        output_text = model.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_text = [text.strip() for text in output_text]
        output_text_ = []
        for i in range(len(image)):
            output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
        with open('/cluster/home/wenkai/LAVIS/output/output_names.txt', 'a+') as f:
            for i in range(len(image)):
                f.write(image[i][1] + "|" + output_text_[i] + '\n')


def evaluate_score(data):
    model_config = {'arch': 'blip2_protein', 'load_finetuned': False,
                    'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage1/20230925102/checkpoint_6.pth',
                    'finetuned': '', 'num_query_token': 32, 'prompt': '',
                    'model_type': 'pretrain', 'load_pretrained': True, 'freeze_vit': False,
                    'max_protein_len': 512, 'max_txt_len': 30}

    model_cls = registry.get_model_class(model_config['arch'])
    model = model_cls.from_config(model_config)
    model = model.to(device)
    model.eval()

    # evaluate
    t0 = time.time()
    proteins = list(data['protein'])
    txts = list(data['function'])
    scores = []
    for seq, txt in zip(proteins, txts):
        image = [('protein1', seq)]
        _, _, batch_tokens = model.visual_encoder(image)
        image_embeds = \
            model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][
                30].contiguous()

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

        query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)

        query_output = model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
        )

        image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)

        image_feats_all = concat_all_gather(image_feats)

        if type(txt) == str:
            txt = eval(txt)
        length = len(txt)
        with torch.no_grad():
            text_tokens = model.tokenizer(
                txt,
                padding="max_length",
                truncation=True,
                max_length=model.max_txt_len,
                return_tensors="pt",
            ).to(device)
            text_output = model.Qformer.bert(
                text_tokens.input_ids,
                attention_mask=text_tokens.attention_mask,
                return_dict=True,
            )

            text_feat = F.normalize(
                model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
            )

            text_feat_all = concat_all_gather(text_feat)
            sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
            sim_i2t, _ = sim_q2t.max(-1)
            # print('sim_i2t: {}'.format(sim_i2t))
        if length > 1:
            scores.append(list(sim_i2t.detach().cpu().numpy()))
        else:
            scores.append([sim_i2t.item()])
    print("model evaluate time: {}".format(time.time() - t0))
    data['sim'] = scores
    return data


# graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")

### Levenshtein similarity
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/raw_time_split/reviewed//test.csv', sep='|')
test['function'] = test['function'].apply(lambda x: x.lower())


if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_names.txt'):
    os.remove('/cluster/home/wenkai/LAVIS/output/output_names.txt')
print("stage 2 predict starting")
stage2_output(test)
print("stage 2 predict completed")

df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_names.txt', sep='|', header=None, on_bad_lines='warn')
df_pred.columns = ['protein', 'function']
df_pred = df_pred.drop_duplicates()
df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])

test.columns
test_g = test.groupby(['protein']).agg({'function': lambda x: list(x)}).reset_index()
test_g.columns = ['protein', 'label']

data = pd.merge(df_pred, test_g, on='protein', how='left')
data = data[data['label'].notnull()]

sim = []
for text, label in zip(data['function'].tolist(), data['label'].tolist()):
    sim.append(func(text, label))

data['sim'] = sim
data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
data.to_csv('/cluster/home/wenkai/LAVIS/output/output_names.csv', index=False, sep='|')
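Note: in evaluate_score above, sim_q2t holds one dot product per learned query token and sim_i2t keeps the maximum over those queries, mirroring BLIP-2-style image-text contrastive scoring. A shape-only sketch with random tensors (the 256-dim projection size is an assumption for illustration, not read from the checkpoint):

# Shape sketch (random tensors, not model outputs): one protein yields 32 query
# embeddings; each candidate text is scored by its best-matching query token.
import torch
import torch.nn.functional as F

num_query, dim, num_texts = 32, 256, 4
image_feats = F.normalize(torch.randn(1, num_query, dim), dim=-1)  # one protein
text_feat_all = F.normalize(torch.randn(num_texts, dim), dim=-1)   # candidate texts

sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
sim_i2t, _ = sim_q2t.max(-1)
print(sim_q2t.shape, sim_i2t.shape)  # torch.Size([4, 32]) torch.Size([4])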
examples/predict_test.sh
ADDED
@@ -0,0 +1,14 @@
#!/bin/bash
#SBATCH -J infer_test
#SBATCH -p gpu1
#SBATCH -N 1
#SBATCH -w node[84]
#SBATCH --mem 80G
#SBATCH --gres=gpu:1
#SBATCH --output=log_predict_test.out
#SBATCH --error=log_predict_test.err
#SBATCH --cpus-per-task=8
module load anaconda3/2021.05
source activate LAVIS

python blip2_predict_func_concat_pretrain.py
examples/predict_train.sh
ADDED
@@ -0,0 +1,14 @@
#!/bin/bash
#SBATCH -J infer_cc
#SBATCH -p gpu1
#SBATCH -N 1
#SBATCH -w node[84]
#SBATCH --mem 80G
#SBATCH --gres=gpu:1
#SBATCH --output=log_predict.out
#SBATCH --error=log_predict.err
#SBATCH --cpus-per-task=8
module load anaconda3/2021.05
source activate LAVIS

python blip2_predict_func_concat.py
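Note: both job scripts are standard SLURM batch files; they request one GPU on partition gpu1 (node[84]), activate the LAVIS conda environment, and run the corresponding prediction script. Because the python commands use bare filenames, they would presumably be submitted from the examples/ directory, e.g. with `sbatch predict_test.sh` or `sbatch predict_train.sh`, on a cluster where the referenced modules and environment exist.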