Spaces (Build error)
AK391 committed
Commit • 794924b
1 Parent(s): bbbc401

Files changed:
- LICENSE.txt +12 -0
- configs/caption_coco.yaml +33 -0
- configs/med_config.json +21 -0
- configs/nlvr.yaml +21 -0
- configs/nocaps.yaml +15 -0
- configs/pretrain.yaml +27 -0
- configs/retrieval_coco.yaml +34 -0
- configs/retrieval_flickr.yaml +34 -0
- configs/vqa.yaml +25 -0
- data/__init__.py +101 -0
- data/coco_karpathy_dataset.py +126 -0
- data/flickr30k_dataset.py +93 -0
- data/nlvr_dataset.py +78 -0
- data/nocaps_dataset.py +32 -0
- data/pretrain_dataset.py +59 -0
- data/utils.py +112 -0
- data/vqa_dataset.py +88 -0
- eval_nocaps.py +118 -0
- models/__init__.py +0 -0
- models/blip.py +238 -0
- models/blip_nlvr.py +103 -0
- models/blip_pretrain.py +339 -0
- models/blip_retrieval.py +322 -0
- models/blip_vqa.py +186 -0
- models/med.py +955 -0
- models/nlvr_encoder.py +843 -0
- models/vit.py +305 -0
- pretrain.py +173 -0
- train_caption.py +206 -0
- train_nlvr.py +213 -0
- train_retrieval.py +345 -0
- train_vqa.py +202 -0
- transform/randaugment.py +340 -0
- utils.py +278 -0
LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
Copyright (c) 2022, Salesforce.com, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
configs/caption_coco.yaml
ADDED
@@ -0,0 +1,33 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
coco_gt_root: 'annotation/coco_gt'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'

# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
batch_size: 32
init_lr: 1e-5

# vit: 'large'
# vit_grad_ckpt: True
# vit_ckpt_layer: 5
# batch_size: 16
# init_lr: 2e-6

image_size: 384

# generation configs
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 5
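
For reference, these YAML files are read into plain Python dicts by the training and evaluation scripts. A minimal loading sketch (illustration only: it uses PyYAML's safe_load, while the scripts in this commit import ruamel_yaml as yaml):

import yaml  # stand-in for the repo's `import ruamel_yaml as yaml`

# Parse the captioning config into a dict and read a few fields.
with open('configs/caption_coco.yaml', 'r') as f:
    config = yaml.safe_load(f)

print(config['vit'], config['image_size'])   # 'base' 384
# Depending on the YAML loader, numbers written like '1e-5' may come back as
# strings, so an explicit cast is a safe way to use them:
init_lr = float(config['init_lr'])
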
configs/med_config.json
ADDED
@@ -0,0 +1,21 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30524,
  "encoder_width": 768,
  "add_cross_attention": true
}
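
This JSON configures the text transformer (the "MED" mixture of encoder-decoder defined in models/med.py). The vocab_size of 30524 is BERT's 30522 plus the two special tokens ([DEC] and [ENC]) added by init_tokenizer() in models/blip.py. A small sketch of how the file is consumed; models/blip.py does the same via BertConfig.from_json_file:

from transformers import BertConfig  # models/blip.py imports this class through models.med

# Keys that are not standard BERT options (e.g. encoder_width) simply become
# attributes on the config object that the custom BertModel/BertLMHeadModel
# in models/med.py can consult.
med_config = BertConfig.from_json_file('configs/med_config.json')
print(med_config.hidden_size)           # 768
print(med_config.encoder_width)         # 768 (models/blip.py overrides this with the ViT width)
print(med_config.add_cross_attention)   # True
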
configs/nlvr.yaml
ADDED
@@ -0,0 +1,21 @@
image_root: '/export/share/datasets/vision/NLVR2/'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'

#size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 64
vit_grad_ckpt: False
vit_ckpt_layer: 0
max_epoch: 15

image_size: 384

# optimizer
weight_decay: 0.05
init_lr: 3e-5
min_lr: 0
configs/nocaps.yaml
ADDED
@@ -0,0 +1,15 @@
image_root: '/export/share/datasets/vision/nocaps/'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'

vit: 'base'
batch_size: 32

image_size: 384

max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
configs/pretrain.yaml
ADDED
@@ -0,0 +1,27 @@
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
             '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
            ]
laion_path: ''

# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0

image_size: 224
batch_size: 75

queue_size: 57600
alpha: 0.4

# optimizer
weight_decay: 0.05
init_lr: 3e-4
min_lr: 1e-6
warmup_lr: 1e-6
lr_decay_rate: 0.9
max_epoch: 20
warmup_steps: 3000
configs/retrieval_coco.yaml
ADDED
@@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
dataset: 'coco'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

# size of vit model; base or large

vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5

# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 12
# init_lr: 5e-6

image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 256
negative_all_rank: True

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6
configs/retrieval_flickr.yaml
ADDED
@@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/flickr30k/'
ann_root: 'annotation'
dataset: 'flickr'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'

# size of vit model; base or large

vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5

# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 10
# init_lr: 5e-6

image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 128
negative_all_rank: False

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6
configs/vqa.yaml
ADDED
@@ -0,0 +1,25 @@
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
train_files: ['vqa_train','vqa_val','vg_qa']
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'

# size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 32
vit_grad_ckpt: False
vit_ckpt_layer: 0
init_lr: 2e-5

image_size: 480

k_test: 128
inference: 'rank'

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 10
data/__init__.py
ADDED
@@ -0,0 +1,101 @@
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment

def create_dataset(dataset, config, min_scale=0.5):

    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(config['image_size'], scale=(min_scale, 1.0), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        RandomAugment(2, 5, isPIL=True, augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize((config['image_size'], config['image_size']), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset=='pretrain':
        dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
        return dataset

    elif dataset=='caption_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset=='nocaps':
        val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return val_dataset, test_dataset

    elif dataset=='retrieval_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset=='retrieval_flickr':
        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset=='vqa':
        train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
                                    train_files = config['train_files'], split='train')
        test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
        return train_dataset, test_dataset

    elif dataset=='nlvr':
        train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'], 'train')
        val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
        samplers.append(sampler)
    return samplers


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
        if is_train:
            shuffle = (sampler is None)
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )
        loaders.append(loader)
    return loaders
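
As a rough usage sketch, the train_*.py scripts added in this commit wire these helpers together along the following lines (single-process case shown; the config file, local dataset paths, and PyYAML loader are assumptions for illustration):

import yaml  # the scripts themselves use ruamel_yaml
from data import create_dataset, create_loader

config = yaml.safe_load(open('configs/caption_coco.yaml'))

# Build the three splits for COCO captioning.
train_ds, val_ds, test_ds = create_dataset('caption_coco', config)

# Without torch.distributed, the samplers can simply be None.
train_loader, val_loader, test_loader = create_loader(
    [train_ds, val_ds, test_ds],
    samplers=[None, None, None],
    batch_size=[config['batch_size']] * 3,
    num_workers=[4, 4, 4],
    is_trains=[True, False, False],
    collate_fns=[None, None, None],
)

# Each training batch is (image tensor, list of prompted captions, image-id tensor).
image, caption, img_id = next(iter(train_loader))
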
data/coco_karpathy_dataset.py
ADDED
@@ -0,0 +1,126 @@
import os
import json

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption

class coco_karpathy_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
        filename = 'coco_karpathy_train.json'

        download_url(url, ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filename), 'r'))
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        caption = self.prompt + pre_caption(ann['caption'], self.max_words)

        return image, caption, self.img_ids[ann['image_id']]


class coco_karpathy_caption_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'}

        download_url(urls[split], ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]

        return image, int(img_id)


class coco_karpathy_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'}

        download_url(urls[split], ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
        self.transform = transform
        self.image_root = image_root

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, index
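
The classes above expect the Karpathy-split annotation JSONs to contain entries shaped like the hypothetical ones below (field names are taken from the code; the values are invented):

# Hypothetical annotation entries, for illustration only.
train_entry = {
    'image': 'train2014/COCO_train2014_000000318556.jpg',
    'caption': 'a very clean and well decorated empty bathroom',
    'image_id': '318556',
}
retrieval_eval_entry = {
    'image': 'val2014/COCO_val2014_000000184613.jpg',
    'caption': ['several', 'captions', 'per image', 'go in', 'this list'],
}

# coco_karpathy_caption_eval derives a numeric id from the file name:
name = retrieval_eval_entry['image'].split('/')[-1]   # 'COCO_val2014_000000184613.jpg'
img_id = int(name.strip('.jpg').split('_')[-1])       # 184613
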
data/flickr30k_dataset.py
ADDED
@@ -0,0 +1,93 @@
import os
import json

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption

class flickr30k_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        image_root (string): Root directory of images (e.g. flickr30k/)
        ann_root (string): directory to store the annotation file
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
        filename = 'flickr30k_train.json'

        download_url(url, ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filename), 'r'))
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        caption = self.prompt + pre_caption(ann['caption'], self.max_words)

        return image, caption, self.img_ids[ann['image_id']]


class flickr30k_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        image_root (string): Root directory of images (e.g. flickr30k/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
        filenames = {'val': 'flickr30k_val.json', 'test': 'flickr30k_test.json'}

        download_url(urls[split], ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
        self.transform = transform
        self.image_root = image_root

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, index
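
A tiny worked example (invented data) of the retrieval index that flickr30k_retrieval_eval and coco_karpathy_retrieval_eval build: every caption gets a global text id, img2txt maps an image index to its caption ids, and txt2img maps a caption back to its image:

# Standalone sketch of the index-building loop above.
annotation = [
    {'image': 'img_a.jpg', 'caption': ['a dog runs', 'a brown dog']},
    {'image': 'img_b.jpg', 'caption': ['a red car', 'a parked car']},
]

text, image, txt2img, img2txt = [], [], {}, {}
txt_id = 0
for img_id, ann in enumerate(annotation):
    image.append(ann['image'])
    img2txt[img_id] = []
    for caption in ann['caption']:
        text.append(caption.lower())
        img2txt[img_id].append(txt_id)
        txt2img[txt_id] = img_id
        txt_id += 1

print(img2txt)  # {0: [0, 1], 1: [2, 3]}
print(txt2img)  # {0: 0, 1: 0, 2: 1, 3: 1}
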
data/nlvr_dataset.py
ADDED
@@ -0,0 +1,78 @@
import os
import json
import random

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption

class nlvr_dataset(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        '''
        image_root (string): Root directory of images
        ann_root (string): directory to store the annotation file
        split (string): train, val or test
        '''
        urls = {'train': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
                'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
        filenames = {'train': 'nlvr_train.json', 'val': 'nlvr_dev.json', 'test': 'nlvr_test.json'}

        download_url(urls[split], ann_root)
        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))

        self.transform = transform
        self.image_root = image_root


    def __len__(self):
        return len(self.annotation)


    def __getitem__(self, index):

        ann = self.annotation[index]

        image0_path = os.path.join(self.image_root, ann['images'][0])
        image0 = Image.open(image0_path).convert('RGB')
        image0 = self.transform(image0)

        image1_path = os.path.join(self.image_root, ann['images'][1])
        image1 = Image.open(image1_path).convert('RGB')
        image1 = self.transform(image1)

        sentence = pre_caption(ann['sentence'], 40)

        if ann['label']=='True':
            label = 1
        else:
            label = 0

        words = sentence.split(' ')

        if 'left' not in words and 'right' not in words:
            if random.random()<0.5:
                return image0, image1, sentence, label
            else:
                return image1, image0, sentence, label
        else:
            if random.random()<0.5:
                return image0, image1, sentence, label
            else:
                new_words = []
                for word in words:
                    if word=='left':
                        new_words.append('right')
                    elif word=='right':
                        new_words.append('left')
                    else:
                        new_words.append(word)

                sentence = ' '.join(new_words)
                return image1, image0, sentence, label
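
The training-time augmentation in __getitem__ randomly swaps the two NLVR2 images; when the sentence mentions 'left' or 'right' it also swaps those words so the label stays consistent. The word swap in isolation:

# Standalone sketch of the left/right swap used above.
def swap_left_right(sentence):
    new_words = []
    for word in sentence.split(' '):
        if word == 'left':
            new_words.append('right')
        elif word == 'right':
            new_words.append('left')
        else:
            new_words.append(word)
    return ' '.join(new_words)

print(swap_left_right('the dog on the left is larger'))
# -> 'the dog on the right is larger'
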
data/nocaps_dataset.py
ADDED
@@ -0,0 +1,32 @@
import os
import json

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

class nocaps_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
        filenames = {'val': 'nocaps_val.json', 'test': 'nocaps_test.json'}

        download_url(urls[split], ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, int(ann['img_id'])
data/pretrain_dataset.py
ADDED
@@ -0,0 +1,59 @@
import json
import os
import random

from torch.utils.data import Dataset

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

from data.utils import pre_caption
import os,glob

class pretrain_dataset(Dataset):
    def __init__(self, ann_file, laion_path, transform):

        self.ann_pretrain = []
        for f in ann_file:
            print('loading '+f)
            ann = json.load(open(f,'r'))
            self.ann_pretrain += ann

        self.laion_path = laion_path
        if self.laion_path:
            self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))

            print('loading '+self.laion_files[0])
            with open(self.laion_files[0],'r') as f:
                self.ann_laion = json.load(f)

            self.annotation = self.ann_pretrain + self.ann_laion
        else:
            self.annotation = self.ann_pretrain

        self.transform = transform


    def reload_laion(self, epoch):
        n = epoch%len(self.laion_files)
        print('loading '+self.laion_files[n])
        with open(self.laion_files[n],'r') as f:
            self.ann_laion = json.load(f)

        self.annotation = self.ann_pretrain + self.ann_laion


    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image = Image.open(ann['image']).convert('RGB')
        image = self.transform(image)
        caption = pre_caption(ann['caption'],30)

        return image, caption
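
A sketch of how reload_laion could be used to rotate LAION shards between epochs. This is illustrative only; the actual loop lives in pretrain.py (added in this commit but not reproduced here), and the paths below are hypothetical:

from torchvision import transforms
from data.pretrain_dataset import pretrain_dataset

dataset = pretrain_dataset(
    ann_file=['annotation/coco_karpathy_train.json'],  # hypothetical local path
    laion_path='laion/',                               # directory of *.json shards
    transform=transforms.ToTensor(),
)

for epoch in range(3):
    if dataset.laion_path:
        dataset.reload_laion(epoch)  # swaps in shard number epoch % num_shards
    # ... rebuild/refresh the DataLoader and train for one epoch ...
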
data/utils.py
ADDED
@@ -0,0 +1,112 @@
import re
import json
import os

import torch
import torch.distributed as dist

import utils

def pre_caption(caption,max_words=50):
    caption = re.sub(
        r"([.!\"()*#:;~])",
        ' ',
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    #truncate caption
    caption_words = caption.split(' ')
    if len(caption_words)>max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption

def pre_question(question,max_ques_words=50):
    question = re.sub(
        r"([.!\"()*#:;~])",
        '',
        question.lower(),
    )
    question = question.rstrip(' ')

    #truncate question
    question_words = question.split(' ')
    if len(question_words)>max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question


def save_result(result, result_dir, filename, remove_duplicate=''):
    result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json'%filename)

    json.dump(result,open(result_file,'w'))

    dist.barrier()

    if utils.is_main_process():
        # combine results from all processes
        result = []

        for rank in range(utils.get_world_size()):
            result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
            res = json.load(open(result_file,'r'))
            result += res

        if remove_duplicate:
            result_new = []
            id_list = []
            for res in result:
                if res[remove_duplicate] not in id_list:
                    id_list.append(res[remove_duplicate])
                    result_new.append(res)
            result = result_new

        json.dump(result,open(final_result_file,'w'))
        print('result file saved to %s'%final_result_file)

    return final_result_file


from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from torchvision.datasets.utils import download_url

def coco_caption_eval(coco_gt_root, results_file, split):
    urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
            'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
    filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}

    download_url(urls[split],coco_gt_root)
    annotation_file = os.path.join(coco_gt_root,filenames[split])

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # evaluate on a subset of images by setting
    # coco_eval.params['image_id'] = coco_result.getImgIds()
    # please remove this line when evaluating the full validation set
    # coco_eval.params['image_id'] = coco_result.getImgIds()

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f'{metric}: {score:.3f}')

    return coco_eval
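
A quick illustration of the text cleaning above (assumes the repo's dependencies are installed, since importing data.utils also pulls in the root-level utils module and the pycocotools/pycocoevalcap packages):

from data.utils import pre_caption, pre_question

# pre_caption lowercases, replaces punctuation with spaces, collapses
# whitespace, and truncates to max_words.
print(pre_caption('A man RIDING a horse!!! (on the beach)', max_words=30))
# -> 'a man riding a horse on the beach'

# pre_question lowercases and strips the same punctuation set outright.
print(pre_question('Is this a "real" question?'))
# -> 'is this a real question?'
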
data/vqa_dataset.py
ADDED
@@ -0,0 +1,88 @@
import os
import json
import random
from PIL import Image

import torch
from torch.utils.data import Dataset
from data.utils import pre_question

from torchvision.datasets.utils import download_url

class vqa_dataset(Dataset):
    def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
        self.split = split

        self.transform = transform
        self.vqa_root = vqa_root
        self.vg_root = vg_root

        if split=='train':
            urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
                    'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
                    'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}

            self.annotation = []
            for f in train_files:
                download_url(urls[f],ann_root)
                self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
        else:
            download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
            self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))

            download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
            self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))


    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        if ann['dataset']=='vqa':
            image_path = os.path.join(self.vqa_root,ann['image'])
        elif ann['dataset']=='vg':
            image_path = os.path.join(self.vg_root,ann['image'])

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        if self.split == 'test':
            question = pre_question(ann['question'])
            question_id = ann['question_id']
            return image, question, question_id


        elif self.split=='train':

            question = pre_question(ann['question'])

            if ann['dataset']=='vqa':
                answer_weight = {}
                for answer in ann['answer']:
                    if answer in answer_weight.keys():
                        answer_weight[answer] += 1/len(ann['answer'])
                    else:
                        answer_weight[answer] = 1/len(ann['answer'])

                answers = list(answer_weight.keys())
                weights = list(answer_weight.values())

            elif ann['dataset']=='vg':
                answers = [ann['answer']]
                weights = [0.2]

            return image, question, answers, weights


def vqa_collate_fn(batch):
    image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
    for image, question, answer, weights in batch:
        image_list.append(image)
        question_list.append(question)
        weight_list += weights
        answer_list += answer
        n.append(len(answer))
    return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
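
A toy illustration of what vqa_collate_fn produces for a batch of two training samples (the tensors, questions, and answers are made up):

import torch
from data.vqa_dataset import vqa_collate_fn

batch = [
    (torch.zeros(3, 480, 480), 'what is the dog doing', ['running', 'playing'], [0.6, 0.4]),
    (torch.zeros(3, 480, 480), 'how many people are there', ['2'], [1.0]),
]

images, questions, answers, weights, n = vqa_collate_fn(batch)
print(images.shape)  # torch.Size([2, 3, 480, 480])
print(answers)       # ['running', 'playing', '2']  (answers flattened across the batch)
print(weights)       # tensor([0.6000, 0.4000, 1.0000])
print(n)             # [2, 1]  -> number of candidate answers per question
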
eval_nocaps.py
ADDED
@@ -0,0 +1,118 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip import blip_decoder
import utils
from data import create_dataset, create_sampler, create_loader
from data.utils import save_result

@torch.no_grad()
def evaluate(model, data_loader, device, config):
    # evaluate
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'
    print_freq = 10

    result = []
    for image, image_id in metric_logger.log_every(data_loader, print_freq, header):

        image = image.to(device)

        captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
                                  min_length=config['min_length'], repetition_penalty=1.1)

        for caption, img_id in zip(captions, image_id):
            result.append({"image_id": img_id.item(), "caption": caption})

    return result


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating captioning dataset")
    val_dataset, test_dataset = create_dataset('nocaps', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
    else:
        samplers = [None,None]

    val_loader, test_loader = create_loader([val_dataset, test_dataset], samplers,
                                            batch_size=[config['batch_size']]*2, num_workers=[4,4],
                                            is_trains=[False, False], collate_fns=[None,None])

    #### Model ####
    print("Creating model")
    model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                         prompt=config['prompt'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    val_result = evaluate(model_without_ddp, val_loader, device, config)
    val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
    test_result = evaluate(model_without_ddp, test_loader, device, config)
    test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nocaps.yaml')
    parser.add_argument('--output_dir', default='output/NoCaps')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
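
For orientation, each entry collected in evaluate() and written by save_result() has the shape shown below (the caption text is invented):

# Hypothetical contents of output/NoCaps/result/val_rank0.json:
example_result = [
    {"image_id": 4501, "caption": "a group of people standing around a food truck"},
    {"image_id": 4502, "caption": "a brown dog laying on a wooden floor"},
]
# save_result() writes one <split>_rank<k>.json per process, then the main
# process merges them (dropping duplicate image_id entries) into
# result/val.json and result/test.json.
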
models/__init__.py
ADDED
File without changes
models/blip.py
ADDED
@@ -0,0 +1,238 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import warnings
warnings.filterwarnings("ignore")

from models.vit import VisionTransformer, interpolate_pos_embed
from models.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file

class BLIP_Base(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)


    def forward(self, image, caption, mode):

        assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
        text = self.tokenizer(caption, return_tensors="pt").to(image.device)

        if mode=='image':
            # return image features
            image_embeds = self.visual_encoder(image)
            return image_embeds

        elif mode=='text':
            # return text features
            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                            return_dict = True, mode = 'text')
            return text_output.last_hidden_state

        elif mode=='multimodal':
            # return multimodel features
            image_embeds = self.visual_encoder(image)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

            text.input_ids[:,0] = self.tokenizer.enc_token_id
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )
            return output.last_hidden_state



class BLIP_Decoder(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 prompt = 'a picture of ',
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_decoder = BertLMHeadModel(config=med_config)

        self.prompt = prompt
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1


    def forward(self, image, caption):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)

        text.input_ids[:,0] = self.tokenizer.bos_token_id

        decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
        decoder_targets[:,:self.prompt_length] = -100

        decoder_output = self.text_decoder(text.input_ids,
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds,
                                           encoder_attention_mask = image_atts,
                                           labels = decoder_targets,
                                           return_dict = True,
                                          )
        loss_lm = decoder_output.loss

        return loss_lm

    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
        image_embeds = self.visual_encoder(image)

        if not sample:
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}

        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
        input_ids[:,0] = self.tokenizer.bos_token_id
        input_ids = input_ids[:, :-1]

        if sample:
            #nucleus sampling
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 do_sample=True,
                                                 top_p=top_p,
                                                 num_return_sequences=1,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=1.1,
                                                 **model_kwargs)
        else:
            #beam search
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 num_beams=num_beams,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=repetition_penalty,
                                                 **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            captions.append(caption[len(self.prompt):])
        return captions


def blip_decoder(pretrained='',**kwargs):
    model = BLIP_Decoder(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        assert(len(msg.missing_keys)==0)
    return model

def blip_feature_extractor(pretrained='',**kwargs):
    model = BLIP_Base(**kwargs)
    if pretrained:
        model,msg = load_checkpoint(model,pretrained)
        assert(len(msg.missing_keys)==0)
    return model

def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):

    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit=='base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                          )
    elif vit=='large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0.1 or drop_path_rate
                                          )
    return visual_encoder, vision_width

def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape!=model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s'%url_or_filename)
    return model, msg
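
A minimal single-image captioning sketch assembled from blip_decoder and the nocaps config above, mirroring what eval_nocaps.py does per batch. 'demo.jpg' is a placeholder for any local image, and PyYAML stands in for the ruamel_yaml import used by the scripts:

import torch
import yaml
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from models.blip import blip_decoder

config = yaml.safe_load(open('configs/nocaps.yaml'))
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Same test-time preprocessing as data/__init__.py.
transform = transforms.Compose([
    transforms.Resize((config['image_size'], config['image_size']),
                      interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])
image = transform(Image.open('demo.jpg').convert('RGB')).unsqueeze(0).to(device)

# 'pretrained' may be a URL or a local .pth path, exactly as in the configs.
model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'],
                     vit=config['vit'], prompt=config['prompt'])
model.eval()
model = model.to(device)

with torch.no_grad():
    captions = model.generate(image, sample=False, num_beams=config['num_beams'],
                              max_length=config['max_length'], min_length=config['min_length'])
print(captions[0])
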
models/blip_nlvr.py
ADDED
@@ -0,0 +1,103 @@
from models.med import BertConfig
from models.nlvr_encoder import BertModel
from models.vit import interpolate_pos_embed
from models.blip import create_vit, init_tokenizer, is_url

from timm.models.hub import download_cached_file

import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np

class BLIP_NLVR(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 480,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        self.cls_head = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
            nn.ReLU(),
            nn.Linear(self.text_encoder.config.hidden_size, 2)
        )

    def forward(self, image, text, targets, train=True):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))

        text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
        text.input_ids[:,0] = self.tokenizer.enc_token_id

        output = self.text_encoder(text.input_ids,
                                   attention_mask = text.attention_mask,
|