datnguyentien204 committed on
Commit
74bda2d
1 Parent(s): fe7b7dd

Upload folder using huggingface_hub (#1)


- a0758b36b503e31948ad69a85744abca27a58e3e385e38b21692d875dad0d798 (8f0ab5123cd4cb4967a1d887d997dabcdbbabfca)
- d16909ba7788f34f26090a924304028c807d03bbcb38c496b65d34ff2e560d85 (2117908f05a819796dd3497133d2c7267cc572df)

This view is limited to 50 files because it contains too many changes. See the raw diff for the full list.
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ objects.json filter=lfs diff=lfs merge=lfs -text
CODEOWNERS ADDED
@@ -0,0 +1,2 @@
+ # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
+ #ECCN:Open Source
LICENSE.txt ADDED
@@ -0,0 +1,12 @@
+ Copyright (c) 2022, Salesforce.com, Inc.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
__init__.py ADDED
File without changes
cog.yaml ADDED
@@ -0,0 +1,17 @@
+ build:
+   gpu: true
+   cuda: "11.1"
+   python_version: "3.8"
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+   python_packages:
+     - "ipython==7.30.1"
+     - "torchvision==0.11.1"
+     - "torch==1.10.0"
+     - "timm==0.4.12"
+     - "transformers==4.15.0"
+     - "fairscale==0.4.4"
+     - "pycocoevalcap==1.2"
+
+ predict: "predict.py:Predictor"
configs/bert_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30522,
+   "encoder_width": 768,
+   "add_cross_attention": true
+ }
configs/caption_coco.yaml ADDED
@@ -0,0 +1,33 @@
+ image_root: '/export/share/datasets/vision/coco/images/'
+ ann_root: 'annotation'
+ coco_gt_root: 'annotation/coco_gt'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
+
+ # size of vit model; base or large
+ vit: 'base'
+ vit_grad_ckpt: False
+ vit_ckpt_layer: 0
+ batch_size: 32
+ init_lr: 1e-5
+
+ # vit: 'large'
+ # vit_grad_ckpt: True
+ # vit_ckpt_layer: 5
+ # batch_size: 16
+ # init_lr: 2e-6
+
+ image_size: 384
+
+ # generation configs
+ max_length: 20
+ min_length: 5
+ num_beams: 3
+ prompt: 'a picture of '
+
+ # optimizer
+ weight_decay: 0.05
+ min_lr: 0
+ max_epoch: 5
+
configs/med_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30524,
+   "encoder_width": 768,
+   "add_cross_attention": true
+ }
configs/nlvr.yaml ADDED
@@ -0,0 +1,21 @@
+ image_root: '/export/share/datasets/vision/NLVR2/'
+ ann_root: 'annotation'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
+
+ # size of vit model; base or large
+ vit: 'base'
+ batch_size_train: 16
+ batch_size_test: 64
+ vit_grad_ckpt: False
+ vit_ckpt_layer: 0
+ max_epoch: 15
+
+ image_size: 384
+
+ # optimizer
+ weight_decay: 0.05
+ init_lr: 3e-5
+ min_lr: 0
+
configs/nocaps.yaml ADDED
@@ -0,0 +1,15 @@
+ image_root: '/export/share/datasets/vision/nocaps/'
+ ann_root: 'annotation'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
+
+ vit: 'base'
+ batch_size: 32
+
+ image_size: 384
+
+ max_length: 20
+ min_length: 5
+ num_beams: 3
+ prompt: 'a picture of '
configs/pretrain.yaml ADDED
@@ -0,0 +1,27 @@
+ train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
+              '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
+             ]
+ laion_path: ''
+
+ # size of vit model; base or large
+ vit: 'base'
+ vit_grad_ckpt: False
+ vit_ckpt_layer: 0
+
+ image_size: 224
+ batch_size: 75
+
+ queue_size: 57600
+ alpha: 0.4
+
+ # optimizer
+ weight_decay: 0.05
+ init_lr: 3e-4
+ min_lr: 1e-6
+ warmup_lr: 1e-6
+ lr_decay_rate: 0.9
+ max_epoch: 20
+ warmup_steps: 3000
+
configs/retrieval_coco.yaml ADDED
@@ -0,0 +1,34 @@
+ image_root: '/export/share/datasets/vision/coco/images/'
+ ann_root: 'annotation'
+ dataset: 'coco'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+ # size of vit model; base or large
+
+ vit: 'base'
+ batch_size_train: 32
+ batch_size_test: 64
+ vit_grad_ckpt: True
+ vit_ckpt_layer: 4
+ init_lr: 1e-5
+
+ # vit: 'large'
+ # batch_size_train: 16
+ # batch_size_test: 32
+ # vit_grad_ckpt: True
+ # vit_ckpt_layer: 12
+ # init_lr: 5e-6
+
+ image_size: 384
+ queue_size: 57600
+ alpha: 0.4
+ k_test: 256
+ negative_all_rank: True
+
+ # optimizer
+ weight_decay: 0.05
+ min_lr: 0
+ max_epoch: 6
+
configs/retrieval_flickr.yaml ADDED
@@ -0,0 +1,34 @@
+ image_root: '/export/share/datasets/vision/flickr30k/'
+ ann_root: 'annotation'
+ dataset: 'flickr'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
+
+ # size of vit model; base or large
+
+ vit: 'base'
+ batch_size_train: 32
+ batch_size_test: 64
+ vit_grad_ckpt: True
+ vit_ckpt_layer: 4
+ init_lr: 1e-5
+
+ # vit: 'large'
+ # batch_size_train: 16
+ # batch_size_test: 32
+ # vit_grad_ckpt: True
+ # vit_ckpt_layer: 10
+ # init_lr: 5e-6
+
+ image_size: 384
+ queue_size: 57600
+ alpha: 0.4
+ k_test: 128
+ negative_all_rank: False
+
+ # optimizer
+ weight_decay: 0.05
+ min_lr: 0
+ max_epoch: 6
+
configs/retrieval_msrvtt.yaml ADDED
@@ -0,0 +1,12 @@
+ video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
+ ann_root: 'annotation'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+ # size of vit model; base or large
+ vit: 'base'
+ batch_size: 64
+ k_test: 128
+ image_size: 384
+ num_frm_test: 8
configs/vqa.yaml ADDED
@@ -0,0 +1,25 @@
+ vqa_root: '/kaggle/working/vision/mscoco'          # followed by train2014/
+ vg_root: '/kaggle/working/vision/visual-genome'    # followed by image/
+ train_files: ['vqa_train','vqa_val','vg_qa']
+ ann_root: 'annotation'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
+
+ # size of vit model; base or large
+ vit: 'base'
+ batch_size_train: 8
+ batch_size_test: 32
+ vit_grad_ckpt: False
+ vit_ckpt_layer: 0
+ init_lr: 2e-5
+
+ image_size: 480
+
+ k_test: 128
+ inference: 'rank'
+
+ # optimizer
+ weight_decay: 0.05
+ min_lr: 0
+ max_epoch: 10
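As a quick orientation for the config files above, the sketch below shows one plausible way a YAML file such as configs/vqa.yaml could be parsed and handed to the create_dataset helper defined in data/__init__.py further down. The use of yaml.safe_load and the BLIP_main import path are assumptions for illustration; the upstream training scripts are not part of this upload.

# Hypothetical usage sketch (not part of this commit): load a config and build the VQA datasets.
import yaml
from BLIP_main.data import create_dataset   # package name assumed, mirroring the imports used below

with open('configs/vqa.yaml', 'r') as f:
    config = yaml.safe_load(f)   # keys such as image_size, ann_root, vqa_root, train_files, ...

train_dataset, test_dataset = create_dataset('vqa', config)
print(len(train_dataset), len(test_dataset))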
data/__init__.py ADDED
@@ -0,0 +1,101 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+ from BLIP_main.data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
+ from BLIP_main.data.nocaps_dataset import nocaps_eval
+ from BLIP_main.data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
+ from BLIP_main.data.vqa_dataset import vqa_dataset
+ from BLIP_main.data.nlvr_dataset import nlvr_dataset
+ from BLIP_main.data.pretrain_dataset import pretrain_dataset
+ from BLIP_main.transform.randaugment import RandomAugment
+
+ def create_dataset(dataset, config, min_scale=0.5):
+
+     normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+     transform_train = transforms.Compose([
+         transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
+         transforms.RandomHorizontalFlip(),
+         RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
+                                            'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
+         transforms.ToTensor(),
+         normalize,
+     ])
+     transform_test = transforms.Compose([
+         transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
+         transforms.ToTensor(),
+         normalize,
+     ])
+
+     if dataset=='pretrain':
+         dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
+         return dataset
+
+     elif dataset=='caption_coco':
+         train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
+         val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+         test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+         return train_dataset, val_dataset, test_dataset
+
+     elif dataset=='nocaps':
+         val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+         test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+         return val_dataset, test_dataset
+
+     elif dataset=='retrieval_coco':
+         train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
+         val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+         test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+         return train_dataset, val_dataset, test_dataset
+
+     elif dataset=='retrieval_flickr':
+         train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
+         val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+         test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+         return train_dataset, val_dataset, test_dataset
+
+     elif dataset=='vqa':
+         train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
+                                     train_files = config['train_files'], split='train')
+         test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
+         return train_dataset, test_dataset
+
+     elif dataset=='nlvr':
+         train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
+         val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
+         test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
+         return train_dataset, val_dataset, test_dataset
+
+
+ def create_sampler(datasets, shuffles, num_tasks, global_rank):
+     samplers = []
+     for dataset,shuffle in zip(datasets,shuffles):
+         sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
+         samplers.append(sampler)
+     return samplers
+
+
+ def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
+     loaders = []
+     for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
+         if is_train:
+             shuffle = (sampler is None)
+             drop_last = True
+         else:
+             shuffle = False
+             drop_last = False
+         loader = DataLoader(
+             dataset,
+             batch_size=bs,
+             num_workers=n_worker,
+             pin_memory=True,
+             sampler=sampler,
+             shuffle=shuffle,
+             collate_fn=collate_fn,
+             drop_last=drop_last,
+         )
+         loaders.append(loader)
+     return loaders
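The three helpers above appear designed to be chained: create_dataset builds the split datasets, create_sampler is only needed for distributed training, and create_loader wraps everything in DataLoaders. A minimal single-process sketch (config values are placeholders, samplers are simply None when not running distributed):

# Sketch under stated assumptions; not taken from the upstream training scripts.
config = {'image_size': 384, 'image_root': '/path/to/coco/images/',
          'ann_root': 'annotation', 'prompt': 'a picture of '}

train_ds, val_ds, test_ds = create_dataset('caption_coco', config)

train_loader, val_loader, test_loader = create_loader(
    [train_ds, val_ds, test_ds],
    samplers=[None, None, None],
    batch_size=[32, 64, 64],
    num_workers=[4, 4, 4],
    is_trains=[True, False, False],
    collate_fns=[None, None, None])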
data/coco_karpathy_dataset.py ADDED
@@ -0,0 +1,126 @@
+ import os
+ import json
+
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ from PIL import Image
+
+ from BLIP_main.data.utils import pre_caption
+
+ class coco_karpathy_train(Dataset):
+     def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+         '''
+         image_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         '''
+         url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
+         filename = 'coco_karpathy_train.json'
+
+         download_url(url,ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
+         self.transform = transform
+         self.image_root = image_root
+         self.max_words = max_words
+         self.prompt = prompt
+
+         self.img_ids = {}
+         n = 0
+         for ann in self.annotation:
+             img_id = ann['image_id']
+             if img_id not in self.img_ids.keys():
+                 self.img_ids[img_id] = n
+                 n += 1
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.image_root,ann['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         caption = self.prompt+pre_caption(ann['caption'], self.max_words)
+
+         return image, caption, self.img_ids[ann['image_id']]
+
+
+ class coco_karpathy_caption_eval(Dataset):
+     def __init__(self, transform, image_root, ann_root, split):
+         '''
+         image_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         split (string): val or test
+         '''
+         urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+         filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+
+         download_url(urls[split],ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+         self.transform = transform
+         self.image_root = image_root
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.image_root,ann['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
+
+         return image, int(img_id)
+
+
+ class coco_karpathy_retrieval_eval(Dataset):
+     def __init__(self, transform, image_root, ann_root, split, max_words=30):
+         '''
+         image_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         split (string): val or test
+         '''
+         urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+         filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+
+         download_url(urls[split],ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+         self.transform = transform
+         self.image_root = image_root
+
+         self.text = []
+         self.image = []
+         self.txt2img = {}
+         self.img2txt = {}
+
+         txt_id = 0
+         for img_id, ann in enumerate(self.annotation):
+             self.image.append(ann['image'])
+             self.img2txt[img_id] = []
+             for i, caption in enumerate(ann['caption']):
+                 self.text.append(pre_caption(caption,max_words))
+                 self.img2txt[img_id].append(txt_id)
+                 self.txt2img[txt_id] = img_id
+                 txt_id += 1
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         return image, index
data/flickr30k_dataset.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import json
+
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ from PIL import Image
+
+ from BLIP_main.data.utils import pre_caption
+
+ class flickr30k_train(Dataset):
+     def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+         '''
+         image_root (string): Root directory of images (e.g. flickr30k/)
+         ann_root (string): directory to store the annotation file
+         '''
+         url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
+         filename = 'flickr30k_train.json'
+
+         download_url(url,ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
+         self.transform = transform
+         self.image_root = image_root
+         self.max_words = max_words
+         self.prompt = prompt
+
+         self.img_ids = {}
+         n = 0
+         for ann in self.annotation:
+             img_id = ann['image_id']
+             if img_id not in self.img_ids.keys():
+                 self.img_ids[img_id] = n
+                 n += 1
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.image_root,ann['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         caption = self.prompt+pre_caption(ann['caption'], self.max_words)
+
+         return image, caption, self.img_ids[ann['image_id']]
+
+
+ class flickr30k_retrieval_eval(Dataset):
+     def __init__(self, transform, image_root, ann_root, split, max_words=30):
+         '''
+         image_root (string): Root directory of images (e.g. flickr30k/)
+         ann_root (string): directory to store the annotation file
+         split (string): val or test
+         '''
+         urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
+         filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
+
+         download_url(urls[split],ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+         self.transform = transform
+         self.image_root = image_root
+
+         self.text = []
+         self.image = []
+         self.txt2img = {}
+         self.img2txt = {}
+
+         txt_id = 0
+         for img_id, ann in enumerate(self.annotation):
+             self.image.append(ann['image'])
+             self.img2txt[img_id] = []
+             for i, caption in enumerate(ann['caption']):
+                 self.text.append(pre_caption(caption,max_words))
+                 self.img2txt[img_id].append(txt_id)
+                 self.txt2img[txt_id] = img_id
+                 txt_id += 1
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         return image, index
data/nlvr_dataset.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import json
+ import random
+
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ from PIL import Image
+
+ from BLIP_main.data.utils import pre_caption
+
+ class nlvr_dataset(Dataset):
+     def __init__(self, transform, image_root, ann_root, split):
+         '''
+         image_root (string): Root directory of images
+         ann_root (string): directory to store the annotation file
+         split (string): train, val or test
+         '''
+         urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
+                 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
+         filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
+
+         download_url(urls[split],ann_root)
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+
+         self.transform = transform
+         self.image_root = image_root
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image0_path = os.path.join(self.image_root,ann['images'][0])
+         image0 = Image.open(image0_path).convert('RGB')
+         image0 = self.transform(image0)
+
+         image1_path = os.path.join(self.image_root,ann['images'][1])
+         image1 = Image.open(image1_path).convert('RGB')
+         image1 = self.transform(image1)
+
+         sentence = pre_caption(ann['sentence'], 40)
+
+         if ann['label']=='True':
+             label = 1
+         else:
+             label = 0
+
+         words = sentence.split(' ')
+
+         if 'left' not in words and 'right' not in words:
+             if random.random()<0.5:
+                 return image0, image1, sentence, label
+             else:
+                 return image1, image0, sentence, label
+         else:
+             if random.random()<0.5:
+                 return image0, image1, sentence, label
+             else:
+                 new_words = []
+                 for word in words:
+                     if word=='left':
+                         new_words.append('right')
+                     elif word=='right':
+                         new_words.append('left')
+                     else:
+                         new_words.append(word)
+
+                 sentence = ' '.join(new_words)
+                 return image1, image0, sentence, label
data/nocaps_dataset.py ADDED
@@ -0,0 +1,32 @@
+ import os
+ import json
+
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ from PIL import Image
+
+ class nocaps_eval(Dataset):
+     def __init__(self, transform, image_root, ann_root, split):
+         urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
+         filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
+
+         download_url(urls[split],ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+         self.transform = transform
+         self.image_root = image_root
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.image_root,ann['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         return image, int(ann['img_id'])
data/pretrain_dataset.py ADDED
@@ -0,0 +1,57 @@
+ import json
+
+ from torch.utils.data import Dataset
+
+ from PIL import Image
+ from PIL import ImageFile
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ Image.MAX_IMAGE_PIXELS = None
+
+ from BLIP_main.data.utils import pre_caption
+ import os,glob
+
+ class pretrain_dataset(Dataset):
+     def __init__(self, ann_file, laion_path, transform):
+
+         self.ann_pretrain = []
+         for f in ann_file:
+             print('loading '+f)
+             ann = json.load(open(f,'r'))
+             self.ann_pretrain += ann
+
+         self.laion_path = laion_path
+         if self.laion_path:
+             self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
+
+             print('loading '+self.laion_files[0])
+             with open(self.laion_files[0],'r') as f:
+                 self.ann_laion = json.load(f)
+
+             self.annotation = self.ann_pretrain + self.ann_laion
+         else:
+             self.annotation = self.ann_pretrain
+
+         self.transform = transform
+
+     def reload_laion(self, epoch):
+         n = epoch%len(self.laion_files)
+         print('loading '+self.laion_files[n])
+         with open(self.laion_files[n],'r') as f:
+             self.ann_laion = json.load(f)
+
+         self.annotation = self.ann_pretrain + self.ann_laion
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image = Image.open(ann['image']).convert('RGB')
+         image = self.transform(image)
+         caption = pre_caption(ann['caption'],30)
+
+         return image, caption
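When a laion_path is supplied, pretrain_dataset loads one LAION shard at construction time and reload_laion swaps in a different *.json shard each epoch. A hedged sketch of the intended call pattern (paths and transform_train are placeholders, not from this upload):

# Hypothetical training-loop sketch: rotate LAION shards between epochs.
dataset = pretrain_dataset(ann_file=['annotation/coco_karpathy_train.json'],
                           laion_path='/path/to/laion_shards', transform=transform_train)
for epoch in range(20):
    dataset.reload_laion(epoch)   # picks laion_files[epoch % len(laion_files)]
    # ... rebuild/iterate the DataLoader over `dataset` for this epoch ...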
data/utils.py ADDED
@@ -0,0 +1,112 @@
+ import re
+ import json
+ import os
+
+ import torch.distributed as dist
+
+ from BLIP_main import utils
+
+
+ def pre_caption(caption,max_words=50):
+     caption = re.sub(
+         r"([.!\"()*#:;~])",
+         ' ',
+         caption.lower(),
+     )
+     caption = re.sub(
+         r"\s{2,}",
+         ' ',
+         caption,
+     )
+     caption = caption.rstrip('\n')
+     caption = caption.strip(' ')
+
+     # truncate caption
+     caption_words = caption.split(' ')
+     if len(caption_words)>max_words:
+         caption = ' '.join(caption_words[:max_words])
+
+     return caption
+
+ def pre_question(question,max_ques_words=50):
+     question = re.sub(
+         r"([.!\"()*#:;~])",
+         '',
+         question.lower(),
+     )
+     question = question.rstrip(' ')
+
+     # truncate question
+     question_words = question.split(' ')
+     if len(question_words)>max_ques_words:
+         question = ' '.join(question_words[:max_ques_words])
+
+     return question
+
+
+ def save_result(result, result_dir, filename, remove_duplicate=''):
+     result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, utils.get_rank()))
+     final_result_file = os.path.join(result_dir, '%s.json'%filename)
+
+     json.dump(result,open(result_file,'w'))
+
+     dist.barrier()
+
+     if utils.is_main_process():
+         # combine results from all processes
+         result = []
+
+         for rank in range(utils.get_world_size()):
+             result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
+             res = json.load(open(result_file,'r'))
+             result += res
+
+         if remove_duplicate:
+             result_new = []
+             id_list = []
+             for res in result:
+                 if res[remove_duplicate] not in id_list:
+                     id_list.append(res[remove_duplicate])
+                     result_new.append(res)
+             result = result_new
+
+         json.dump(result,open(final_result_file,'w'))
+         print('result file saved to %s'%final_result_file)
+
+     return final_result_file
+
+
+ from pycocotools.coco import COCO
+ from pycocoevalcap.eval import COCOEvalCap
+ from torchvision.datasets.utils import download_url
+
+ def coco_caption_eval(coco_gt_root, results_file, split):
+     urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
+             'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
+     filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
+
+     download_url(urls[split],coco_gt_root)
+     annotation_file = os.path.join(coco_gt_root,filenames[split])
+
+     # create coco object and coco_result object
+     coco = COCO(annotation_file)
+     coco_result = coco.loadRes(results_file)
+
+     # create coco_eval object by taking coco and coco_result
+     coco_eval = COCOEvalCap(coco, coco_result)
+
+     # evaluate on a subset of images by setting
+     # coco_eval.params['image_id'] = coco_result.getImgIds()
+     # please remove this line when evaluating the full validation set
+     # coco_eval.params['image_id'] = coco_result.getImgIds()
+
+     # evaluate results
+     # SPICE will take a few minutes the first time, but speeds up due to caching
+     coco_eval.evaluate()
+
+     # print output evaluation scores
+     for metric, score in coco_eval.eval.items():
+         print(f'{metric}: {score:.3f}')
+
+     return coco_eval
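For reference, a small illustration of what the two cleaning helpers above do to raw strings (the expected outputs in the comments were worked out by hand from the regexes, not taken from a run):

print(pre_caption('A DOG; running!! (fast)', max_words=3))         # -> 'a dog running'
print(pre_question('What color is the dog?', max_ques_words=50))   # -> 'what color is the dog?'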
data/video_dataset.py ADDED
@@ -0,0 +1,109 @@
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ import torch
+ import numpy as np
+ import random
+ import decord
+ from decord import VideoReader
+ import json
+ import os
+ from BLIP_main.data.utils import pre_caption
+
+ decord.bridge.set_bridge("torch")
+
+ class ImageNorm(object):
+     """Apply Normalization to Image Pixels on GPU
+     """
+     def __init__(self, mean, std):
+         self.mean = torch.tensor(mean).view(1, 3, 1, 1)
+         self.std = torch.tensor(std).view(1, 3, 1, 1)
+
+     def __call__(self, img):
+
+         if torch.max(img) > 1 and self.mean.max() <= 1:
+             img.div_(255.)
+         return img.sub_(self.mean).div_(self.std)
+
+ def load_jsonl(filename):
+     with open(filename, "r") as f:
+         return [json.loads(l.strip("\n")) for l in f.readlines()]
+
+
+ class VideoDataset(Dataset):
+
+     def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
+         '''
+         image_root (string): Root directory of video
+         ann_root (string): directory to store the annotation file
+         '''
+         url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
+         filename = 'msrvtt_test.jsonl'
+
+         download_url(url,ann_root)
+         self.annotation = load_jsonl(os.path.join(ann_root,filename))
+
+         self.num_frm = num_frm
+         self.frm_sampling_strategy = frm_sampling_strategy
+         self.max_img_size = max_img_size
+         self.video_root = video_root
+         self.video_fmt = video_fmt
+         self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
+
+         self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
+         self.txt2video = [i for i in range(len(self.annotation))]
+         self.video2txt = self.txt2video
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
+
+         vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
+
+         video = self.img_norm(vid_frm_array.float())
+
+         return video, ann['clip_name']
+
+     def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
+         try:
+             if not height or not width:
+                 vr = VideoReader(video_path)
+             else:
+                 vr = VideoReader(video_path, width=width, height=height)
+
+             vlen = len(vr)
+
+             if start_time or end_time:
+                 assert fps > 0, 'must provide video fps if specifying start and end time.'
+
+                 start_idx = min(int(start_time * fps), vlen)
+                 end_idx = min(int(end_time * fps), vlen)
+             else:
+                 start_idx, end_idx = 0, vlen
+
+             if self.frm_sampling_strategy == 'uniform':
+                 frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
+             elif self.frm_sampling_strategy == 'rand':
+                 frame_indices = sorted(random.sample(range(vlen), self.num_frm))
+             elif self.frm_sampling_strategy == 'headtail':
+                 frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
+                 frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
+                 frame_indices = frame_indices_head + frame_indices_tail
+             else:
+                 raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
+
+             raw_sample_frms = vr.get_batch(frame_indices)
+         except Exception as e:
+             return None
+
+         raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
+
+         return raw_sample_frms
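The 'uniform' branch of _load_video_from_path_decord samples frames at a fixed stride of vlen / num_frm. A tiny worked example of that arithmetic, with values chosen arbitrarily for illustration:

import numpy as np
vlen, num_frm = 100, 4
frame_indices = np.arange(0, vlen, vlen / num_frm, dtype=int)
print(frame_indices)   # expected: [ 0 25 50 75]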
data/vqa_dataset.py ADDED
@@ -0,0 +1,87 @@
+ import os
+ import json
+ from PIL import Image
+
+ import torch
+ from torch.utils.data import Dataset
+ from BLIP_main.data.utils import pre_question
+
+ from torchvision.datasets.utils import download_url
+
+ class vqa_dataset(Dataset):
+     def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
+         self.split = split
+
+         self.transform = transform
+         self.vqa_root = vqa_root
+         self.vg_root = vg_root
+
+         if split=='train':
+             urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
+                     'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
+                     'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
+
+             self.annotation = []
+             for f in train_files:
+                 download_url(urls[f],ann_root)
+                 self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
+         else:
+             download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
+             self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
+
+             download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
+             self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         if ann['dataset']=='vqa':
+             image_path = os.path.join(self.vqa_root,ann['image'])
+         elif ann['dataset']=='vg':
+             image_path = os.path.join(self.vg_root,ann['image'])
+
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         if self.split == 'test':
+             question = pre_question(ann['question'])
+             question_id = ann['question_id']
+             return image, question, question_id
+
+         elif self.split=='train':
+
+             question = pre_question(ann['question'])
+
+             if ann['dataset']=='vqa':
+                 answer_weight = {}
+                 for answer in ann['answer']:
+                     if answer in answer_weight.keys():
+                         answer_weight[answer] += 1/len(ann['answer'])
+                     else:
+                         answer_weight[answer] = 1/len(ann['answer'])
+
+                 answers = list(answer_weight.keys())
+                 weights = list(answer_weight.values())
+
+             elif ann['dataset']=='vg':
+                 answers = [ann['answer']]
+                 weights = [0.2]
+
+             return image, question, answers, weights
+
+
+ def vqa_collate_fn(batch):
+     image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
+     for image, question, answer, weights in batch:
+         image_list.append(image)
+         question_list.append(question)
+         weight_list += weights
+         answer_list += answer
+         n.append(len(answer))
+     return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
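Because each training sample returns a variable-length list of answers, vqa_collate_fn flattens answers and weights across the batch and records the per-question answer count n. A hedged sketch of wiring it into a DataLoader (train_dataset is assumed to come from create_dataset('vqa', config)):

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                          collate_fn=vqa_collate_fn, drop_last=True)
images, questions, answers, weights, n = next(iter(train_loader))
# images: stacked tensor; answers/weights are flattened lists, with n[i] answers for question i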
demo.ipynb ADDED
@@ -0,0 +1,162 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "a811a65f",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2024-05-01T09:19:50.038353Z",
+      "start_time": "2024-05-01T09:19:42.352132Z"
+     }
+    },
+    "outputs": [],
+    "source": [
+     "from PIL import Image\n",
+     "import requests\n",
+     "import torch\n",
+     "from torchvision import transforms\n",
+     "from torchvision.transforms.functional import InterpolationMode\n",
+     "from lzma import FILTER_LZMA1\n",
+     "try:\n",
+     "    from _lzma import *\n",
+     "    from _lzma import _encode_filter_properties, _decode_filter_properties\n",
+     "except ImportError:\n",
+     "    from backports.lzma import *\n",
+     "    from backports.lzma import _encode_filter_properties, _decode_filter_properties\n",
+     "\n",
+     "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+     "\n",
+     "def load_demo_image(image_size,device):\n",
+     "    img_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0a/Sukhoi_T-50_Beltyukov.jpg/800px-Sukhoi_T-50_Beltyukov.jpg'\n",
+     "    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')\n",
+     "\n",
+     "    w,h = raw_image.size\n",
+     "    display(raw_image.resize((w//5,h//5)))\n",
+     "\n",
+     "    transform = transforms.Compose([\n",
+     "        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
+     "        transforms.ToTensor(),\n",
+     "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
+     "        ])\n",
+     "    image = transform(raw_image).unsqueeze(0).to(device)\n",
+     "    return image"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "id": "5e6f3fb1",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2024-05-01T09:22:02.157115Z",
+      "start_time": "2024-05-01T09:21:26.706787Z"
+     }
+    },
+    "outputs": [
+     {
+      "data": {
+       "text/plain": "<PIL.Image.Image image mode=RGB size=160x106>",
+       "image/png": "iVBORw0KGgo..."
hErC5s3fuD3/+dh2sP2EnmO2hnz54/ZEKczE1O5NlabgbaLEm/8NwLC7OzHBOhrTNW3tje6IWcre29o93tdq2ZTTvnZnOXT848df5Ugk9RsmwsZ1bZbQrUItNbnRMbS9ePNj+ameBdpBitRNwsDZnxlbyBiNyowylQadcYDM3oI3qRdDN+iS8eimX45xlBkp4VXbiUxbUyUs3KqMkZOrJjrEWG5A1eUrOE0lzGK1KXBkwm8l0VFG7ukCjcq3uDfK2DB8Fah4Q/cMVRhEGJQ0THPvr08RnJRRcfc+68tr/2B7//Pzx4eIfWlEmnF06fSeem5JNNnPrgjkVDwctPXT579uz+0dH6+vre7k671ZtIJs7P5c4tzs1OZp1u8wRL3NiJD4Cw2+/QmXea0fRWa77aRUJw4+GH1a0bE1k+J8KJfbexmnNF9PYcmubjCFQ7WXsjFjMM6CTPNSKTMonyPNKPvZ6LLETqq4Q54yazMUr5+i3Y2sbUHvskCTWyH2Qs3newNbptrBZiqNBF1eE+al+XsXoRQp417ZdIepDnUGb/0c+nDz0+pcajMIygm8Wtf/a//M7aowenz5zGcJyk4FMhMqfhPUTZTJaxBq+QdA8PUpHAzMLcpcXZU/MzHKmlZ+SLHLzCks/mg3wXXIK4GeCGQzu9k4V2iomSfHcpHNhb/ai0eXMql5JQjztljiQ7xOYlKkI0IugCeLkj3HGmm84pKp9M0zwH23LYYhp7ueB+s/bM/PM72KMUlm4vq6b3A7ROCER6GonSrtuMkwQ+elmNydL0KM6TQPx8+vgi/9jaZeYetGNeQY21yzt/+s3/beXREtvrcpKC/yNZdpKX+5jq8DGCcKc9m0mca5TPnpiYuPq0BEo+JSZTW75xWJ2bmmZLT3f0GBLHI8FSdH6nOYNw8RlVRrYqQntrt4pbNyazvAcqK4n4mYkZxzo5gkJami+vLUaT7eSFViAp3x6R83d+u5uy+G7SiI2MUQe7IyYfMskBXsY3ah6DZRq4eQbNVFN5ICzIn2aaJWtLQgJMF1lJBsVJ/QDCXROaa9M2MUT1cz2qBrb8ltatDeIoWbfihYVwcuY3/7P//sz5q/I/RmHFiZf6q2VekaHXPeUEPxNsfNnpPHX1QubkqdrdB82bt3tNPgEgTmUOw9SWtz359BLfZGKkXQ9l91tZ6reMfXUgwnsYne7Uyaezi1cPyw02MGgAtHSmzlIx2ItAMznAEArG5trBhJw0kDNkqCmXqm0T/VKo/w21iwPE/MlY1+JpAogCbcKPoGRiKSEcpgUTKulHBnOteiYmuYpatn6NDeYAgj/Xr6EqaSGW21i1be5QwqwVW5iYEFMHnPzf/+3/9uLll3g/iGrbKJcahYPL8xMvPnt+/vmLwViUd9wavLPaaoSuXGazln5W9vZJUMtpwpzfoFHGkoehE+2efgWy7xsYsggytXAlt/j8XrFJ3UA8CyN09QzMwOOvGU61IlN+82phrXFGbaJlGLIJj9qC1VV+b6GGxkmVyN3FgWaU+yBESOxljWcIeVIp4PyNLpVoyzOogFavn5+/aSJYucf+Q9j59f/oHz73/Kt8hIUVwlatfPPu/e3DQisaq1Qq3cIBDY7Tec39Ai8+8clDlidFBzkbx7FMaUqF3mSlHZOxjpbVMBcjChZL0e383KWp0y8elfnKS5dDopVag79ag+9BB3uxRakjbuRVy39yoYaMoOV3z2S5D0YJjahqO4+GJzsGljoBvD/T4gm9fZGAXB49WptlWzY5/oo17AlL6M8Y1MfN8UsZi3wcgh9Z04OYsoXL+gPf0vzht//VG2/8kAVFTmLEY6GXnrkyN5kL8ZUXPpLGB2XYEODcVjSKx1imWjxxAt8Sv+rBzG53lpOOMtiSHt641Wclict0V3wbqnLI/wuC4CwDALpiFjt4hy2S6srbd3yfyLvcpiiPal4vY8yv2koz5OUz+2wKKRwskRbb5zs3R0gMFuQSro1USzX0CIrJEjyT7vP3k3i5QzB5VA09Pn2EUYjmWbhN9GmOSVlMTWBP1G114q/92n+RiGSuvfWd5154IZ/P1hkrc1S5hff5yBzTpzAfWeJ7u7xp0uF0JCWVzfroQXeSN8QMD5EnjdZ3IcLYjIXNTiyVD6VyDJWNkUx3DKZpC1IrRi6hHfWHhyb+Ehu7DUwe/+gvHwhDH41pXsLaD6Sb8Ji44zKVbph5OR4JwD7IpESSKZRUDMnlT8uodxd9iNBQuaxselArIRyicnkZZY7LsjijCSUxpZXRDnGSUz6H26v37lzPJDopJ8KpW45/ZNgcSDhMD1hyZCLLC8m80DAxleNrArXwVDEwaU5cYGjRb0iKlkXvxirylQ1eo1G3049LN2EGSa6th+jNo40Kmml5+nEpi+tgPxSNfO4ezBl5MuaQWZDaRSqQWyLXcybHrVAeNej9XWQP6P4qHz/Qx9yVQq4Falrxh2j9OH6G/rQlAdkP9xhiC8wuS8ehXv3tt/4iUNvP8YIhk+ZeJ5WI8ZIxGCF2isVDvHWej+dPhZxpXqLADMaSYhHjZbGMVUnFId3U+D7cyFVNBP+4y/IRBDeYaosZLsZxVUS5I8n+HSNLlHFVF8+aCyEj9hpQd6RO95lDaGi1kANweNtn0oppkPuNWB8VzY9vCUkYHV1WyqSfqwFOMYz1BZ/SsJoViHzmc7+eyJzc2juo8PYK33EzAxNmtnSgvMMWy19KnHixl1zkO55qbJWl8vx6WiVJaJlswUwWTxbQV42UElpyhQygwk7/PDqZaI9eRqoQ+sZWA3wsiQkUUgSkcmnCkls0A7FPsBU1hi5ogcDE3OWmCMpTsxTHwv1AzfIjKJr/7mdl4Qp8PKEK4gXsq6985WDnUq1WrldL5Tr/U2fGVsGJ/FRq9jyTK7pfTtqZVn9cy5ECWllaWNHEM4fNEuC4y09OvuB7bQHzkVS7ejA+TW0MqnxH0wz0rBQ/jmENel9XA3Fxjf/sDBtZKm6YleWsCZWORNFZKGyvL/m25B6aC7FwcDQLyBCOkgNUuN4tTx4tE+0FhZXooPVVnsQBEquZBAUnZ8+Bbzwir6zxXQd05pN9TGZFTQnJphlZw8PAp795cm/CxrrCl+GppHnCUIR7l6eYQISBp6rmk+shyu/4FuzHGJs25evbxcOR2mOyBOAlBuR5mI/7HdTQz6qfHiqGZecJtYBjE2CqpWwCjyhEafpwbRUUjqIw02VFwkAEXVCNDWVU0a/BBnW86CfS0FSSwRoyntsQdNQs/m90iLamVPy6zrMEWlo/O1O4vo9Vb7+BFFk5jJZqFNPjIHRGnE30pShPRbBpy8rjIPgq149j00puqfpw8ZtpDgZjFE3ZiqNN41Z00PTRz5B0n61JqW5DwPGPpk1KFRFJWhCtUWoW4Wz5jyaUpxXnfmVnBA/ecllLWQSFmyy5iR5yV3zL1mLBwdZrDblKIpxBknhmLiIQIgB6rnENpHI1y0Ud/BmnmNs0BxEHnixbC1URVh8LJ+GX7urnqa25ikzW6KNmDd39DG1WX7R41rBS4wrbAadCY
s1kyW3CqqEJ1/r6YJFsQg1hH0cT1FT8onfNHSFBP/kzOChKvus8t1Z4tQNyyRM3S2JUliKMwodKaziIgbiO46O5Y+/SpXmaKPkQf6Wi3Fb/UT7WnqMK+NUbJRSIG/qN3VzjjUd8EqhtXlIpPLXcxKhyytFDEzsIqjGm3kHwcoeNqwgDxTOtFpJRZyoT4W0uwTEXT6qDvStkCG41H4JbKn/Cz0EIfdqQ5c+FynL2c9C0UdAltmhKbhFI+CEDTDy5ItIUUz2srPSuWdwt/wEOhjlZ/ly3DwbkMhXWriiV4sf266cBVQWISKjMDxChF+2k1fJEwDH8BVelSMpcZPmGhx7U+x3UyoVaoIc18Ovl2nrgmc34Zki6UkJCQu+WlxTlGLjFGU1YKmU4JM7/6Onp54EaVm0/XHRTzkNKWiQr1yJIkIeKJRpF8ss2EITpn2UiiUE0JkLiRlVLjDRyiddNhtFAsq0GijtaoEERIxyfCCAqSVOUv75eRof+45NwsuTQ6vUkVE+CA7dPRENXq4BNWCrLwSZsFgnhbsrqOniU3mALmhZM7wboSrUa9u0IF2NQ1d1Nu+0YPi4rZTJ6H1LUSlQ+o/hjIZaJIe/XSKvMWKohoDJRBTTt1+G4tJTQZy7SfkwrwgJtwsU05AB9fhUXiQImS9OqlXKzuvlZaZZLaOfBimrzvIQI0DYwDsHUENHAoHOXIC0USq5gZHvqKfi4u+UyjIBow0QRhnOPeXZ1sLnj9LeZw4khcUO0Q48QC8Q/5DKajqKpGODK300QbAwLm2vMaDoIibOuA9zCizMk6T5KwnSORgdhq1x8909Y6LC7AnD2awwvJKOqlSVyDFTR9NHoMirUJ99NQiSxRAwlG89iLWGtVWNQtAKlMCZXE/bR8IP0/5fLlM/TSiwvglQNW0jRxNjfGlvRVCEX2e9jT1Ox5eAlzjMi3K1EV5wiGQO5SVc4zC0TUcPYx3WwfR4U0S8McIujlOKGfk0SOg/u8hh6dKEjP4rmMYensRrm8Qb4qr6HMMxcyZ9Q1ojwYwGII2+ULRDVZEgfy2gsCbl+fH/aEoo15Z8FDCTGkiiGtiUItWINkNml5lG1hvAsApKGsv6Gj36G8DaCjKpmlORnPoj5C1bDL2g0raLRjYT/Pop5HMRvQJseQraNHrjFsYkhZH20NqFi2LSfvB+ilZEfaZAja60AKJ7cB7OOewLVrL2bfDaxFc+vsYrTOkPaZLn5QgyBiOKf/PgJSVtV/XAV8Zi7J+UxKCLIj2YFWR1UIncbEvXU5OOYenlK6z35fiXosmYkhbbXKLLmo5I5UECDwKqEcJqE2WIf8Ux/mgQNl2U9knhM1gjuCMBo4KqugvQOXLKk3hhvjiigaPBTBD9joTSXH/iJaSgU5zG0CLVoIPvTo/xtQUazgJA7Fj42yx9jhwhVB1EMfrjV4zqk2xAVUmjBatnj1JAIT55ftqJKNJWUJ0qhwm0I4ma4bOTHSHTRBB9WPPXJTEqQ5J9cLo2igiuyvTwfnUUmYbK5CS+jap+9oj353fjIFScmduuiYeCpLZJUbRfRfbRSbBGAqE6u/ZSN1c7i2ehgswTTcDfFF3EiVenNfeDm5snnds1XYNzQYKk8j1IiN3J6jlA2bjmIACbyuvZ2yyvKyQq1CBF/wMG6RBj7dJZMNDXn01DcfVQZ8n8xlyDv6ioi1YpaTh4F5F7wFbYguLoYuJRHd/JED6OWwAVZ2YHuAoRei28o7c0qawgMdymd8FD7evq5FCbLTRuVSBsZ8isijJoGwbBwUeVHRYHkViJflskWtY29XckW6JbabyrXUv8fFCHeAtDTSnIAAAAASUVORK5CYII=",
61
+ "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCABqAKADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDs8UUtFeqcIUUUUALRRRSAWikFLQAtFJS0ALRSUUDFFLTaWgBaKKKACiiigBaM0lFIZFRRS1QhKWiigQUUUtIAooooAWiiigYUtJS0AGKKa8iRrudlUepOKw73xVZWxKQK07jjj5VB/Hn9KlyUdxqLexvdKrNf24PD7/8Ad5rAtNQv9TVpp28q3PCRx/Ln3J6mpoYWjJYsAvJOK554jpE2hRvrJm9DPHODsPTqCKlrk4tcsbXVUN1exwIVOATwR74rqYpY54klidXjdQyspyCD0IrWlPmWpnONnoPpaSitSCFmWNSzHAHJOM1Gl1buFKzJ83QE4J/A1PQQG4Iz9aNQGhlJxvXJ5+8Ko6lrem6Suby7jRtpYIp3sR7AZNQ3vh20uiXhZ7WU/wAUR4P1XpXA6z4F8RLdSXME0d8rH+FtrY9Np/oaxnOpHoaRjF9Tqbj4jeGbdSReyTN/cihbP6gVFH8SvD7l90zRlADtccnPoO5ryfU9AuoZSb+xuIJD/EVIz/SsO/tlhMZG4gjHze1YfWZXsaewR9K2Os2GoQrLbzgqwyMjBq7HIkozG6vxn5TmvmbS9bv9IlWWzu5oZFO5WRyCp9RXceF/iFcQ6kz3ZaVpFbe5bLEk5Jye9UsQ+xLpI9am1CCE7QTI/wDdQZNIl6XGTFj23ZqhYeLtOvwB5ybj/DIBWzHPauu4RQnPoKiWIn6FxpRM+6u75FLWtvDNj+BnINRWniYsjR3OmRxyjgnzTkH8q7PSYNGvBse1UXAH3Sxw3uKt32g6VchDLpkTlPulTtP5jrUOrN9S1CK6Hk9/PfS3LSW7Io7DO4/mRVfSfD97rWsJAUbDtumlA+6vc8f5zXq62ek2Yyugqcdwgf8AnWXd/ELQNKd4bezkMq8MkKKoB9CajVjbSNebQraawjs1i8tIlCxEDlMV5P4r8RR6RczafaNHNeqSrupykZ/qfbt39Kd4t+Id3fPIum3GoWyzqVISQEIOhAAIrzmOxvfMOyJnQ9GY4z+RNJq27BXlsh0jPM7yyuXkc5ZieSa9R8I6qH8K2q43PDuiYk8DB4/QivMY4yJSk0UiFeGAINbNnrmm6fZyaes9xBvfdL9oXABwBgEdvrirpycXoTKN1qd5ZeJjc6olp9nDJI21ZFOPxx6V0VcB4aeG51u2aGVJFG5sowI+6fSu/rtoyco6nPUST0G0UUVqQFFFFACMoYYYAj0IzWTfeFtD1If6Vpdu59VXafzGK16Wk4p7oak1scJqHwq0C5hf7IJbaYj5Dv3KD7jrj8a8q1Xw7caLqL2t3GUZTwRnkeoNfSFY/iLw7aeIrAwTgJMo/dTAZKH+o9qxlRjukX7SXVnh1gJ/MRI2lfcQBzz+ddjaTa3Y7fs7tINwG1mGAPWhfA+t6NcLL5cd5AmSGgJLD/gJ5/LNaFo7zyeUqESjqjDDD8DXPKNi4ts7bQhrL7JL22a2xyHBy31xXeW+qWksYj+1R+fjkSDZk1xHh7U7u3C2d/FKgxiKR1OPoT/KtbXLzTodLmurxARGPl28MzHoBWTWpqm7Glqmi6lrNu0D3qwQt3tmOT9fUe1cZceB10583RllQnhx8qn/AArI8P6xe3Ia1a7ZbgZKFThZB1xjsR/KneKNe8TadayXMU8s8QjICL8xVvp+tU+ZbEe63qiS58H6PPLHi1kUD+KNvlGOeQfWubufB9vHcyiHV54mViBFPFuj698HpXMWPiu5h3RhJ4ZDjJ+ZSp/Pv71uL4hW7gUG9dbpiAXZc5+uc5rFNX1R1OLS0Zm3vhTxhpsnmJCt7CDndbndkY/u9f0rKPibVI9RSGW2KymQLiaLBXPHRhXcaR4k1qIOPMs7uDPBbKsD+daUPjmSW4WF9MlaQMCREC5Iz0wR3qkk3oS5SSs0dfbWkVtGgEUSyBQGZEC5PfpU9VrS/wDtqBjp91bH1lAUflnP6VZr04O6OCS1GUtFFWSFLSUtABRRRQAUUtFABVW/0+DUIGjlUB8YSXaCyH1GatUUmr6MNjhk1rUtAvf7L1smW3b/AFVwSSCPXPp+oqPXtXTWTDAceRBnaykje394/wAq6/VtJtNa0+Szu0JRh8rLwyH1B9a8ovNO1Twxemwul863PNvc9FZff39v6Vxzo8rujojUvuaZWOxT7VHdNC8fzIxPRh0x3q5pmsNrsFyTdzJqDLmZBKxRx/eVTwPoOlZulRaZNOX1cS3DdU5Owe20V11okrwhNP02K2hPRpMKCP8AdHJqeVJasTk5PRHlfiDSr+xvGZmkkic5V2GefTNZlspRz5qsyn+7wR/OveovDkUiML+5+0ow+aNY1VP6nj61SbSYLGUJ9mh2/wADiMDP/wBesuV2utjVO2jPIZbe8mIFlPFGzcZkYxEH1JPy/jmvVPh5pV7puhyPqEqvdSyHOyVZBtHT5lJB71tW0ce0/ImP90Vet7eOIFkjRC3XaoFaUI+/ewqsny2uTUUtFegcgyilpKACloooAKKKKAFopKWgAooooAKq6hp1tqlm9rdxh4m/NT6g9jVqlpbgcFD4duNCmfzZUlti/wC6k2fMPTnsfb8q6LT3VWTbkgtu461tEBhggEeh5pqQxRklI0QnrtUCueWHTd0zaNdpWJWfdj0qKWJJoyjjKmn0VtGKiuVGTk27sqxWSxH/AFjEehAq1RRTjCMdgcm9wpKKKoQlFPwPSjApAM
op+B6UYHpRcBlFPwM9KMD0pgNop2BRjmgBtFPwKMD0pAMop+BRgUANop2KMD0oAbRTsD0oIoAbSU/AowPSgBtFOwPSjA9KBn//2Q=="
62
+ },
63
+ "metadata": {},
64
+ "output_type": "display_data"
65
+ },
66
+ {
67
+ "name": "stdout",
68
+ "output_type": "stream",
69
+ "text": [
70
+ "load checkpoint from /workspace/BLIP/weights/model_base_vqa_capfilt_large.pth\n",
71
+ "answer: jet\n"
72
+ ]
73
+ }
74
+ ],
75
+ "source": [
76
+ "from models.blip_vqa import blip_vqa\n",
77
+ "\n",
78
+ "image_size = 480\n",
79
+ "image = load_demo_image(image_size=image_size, device=device) \n",
80
+ "\n",
81
+ "model_url = '/workspace/BLIP/weights/model_base_vqa_capfilt_large.pth'\n",
82
+ " \n",
83
+ "model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')\n",
84
+ "model.eval()\n",
85
+ "model = model.to(device)\n",
86
+ "\n",
87
+ "question = 'Can you description about plane in image?'\n",
88
+ "\n",
89
+ "with torch.no_grad():\n",
90
+ " answer = model(image, question, train=False, inference='generate') \n",
91
+ " print('answer: '+answer[0])"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "outputs": [
97
+ {
98
+ "name": "stderr",
99
+ "output_type": "stream",
100
+ "text": [
101
+ "IOPub data rate exceeded.\n",
102
+ "The Jupyter server will temporarily stop sending output\n",
103
+ "to the client in order to avoid crashing it.\n",
104
+ "To change this limit, set the config variable\n",
105
+ "`--ServerApp.iopub_data_rate_limit`.\n",
106
+ "\n",
107
+ "Current values:\n",
108
+ "ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)\n",
109
+ "ServerApp.rate_limit_window=3.0 (secs)\n",
110
+ "\n"
111
+ ]
112
+ }
113
+ ],
114
+ "source": [
115
+ "import json\n",
116
+ "\n",
117
+ "with open('objects.json') as f:\n",
118
+ " d = json.load(f)\n",
119
+ " print(d)"
120
+ ],
121
+ "metadata": {
122
+ "collapsed": false,
123
+ "ExecuteTime": {
124
+ "end_time": "2024-05-01T09:07:37.393500Z",
125
+ "start_time": "2024-05-01T09:07:15.627630Z"
126
+ }
127
+ },
128
+ "id": "7cd25587bfee1a30",
129
+ "execution_count": 2
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "outputs": [],
134
+ "source": [],
135
+ "metadata": {
136
+ "collapsed": false
137
+ },
138
+ "id": "f868e97453edacd8"
139
+ }
140
+ ],
141
+ "metadata": {
142
+ "kernelspec": {
143
+ "display_name": "Python 3",
144
+ "language": "python",
145
+ "name": "python3"
146
+ },
147
+ "language_info": {
148
+ "codemirror_mode": {
149
+ "name": "ipython",
150
+ "version": 3
151
+ },
152
+ "file_extension": ".py",
153
+ "mimetype": "text/x-python",
154
+ "name": "python",
155
+ "nbconvert_exporter": "python",
156
+ "pygments_lexer": "ipython3",
157
+ "version": "3.8.10"
158
+ }
159
+ },
160
+ "nbformat": 4,
161
+ "nbformat_minor": 5
162
+ }
eval_nocaps.py ADDED
@@ -0,0 +1,111 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ from pathlib import Path
14
+
15
+ import torch
16
+ import torch.backends.cudnn as cudnn
17
+
18
+ from models.blip import blip_decoder
19
+ import utils
20
+ from data import create_dataset, create_sampler, create_loader
21
+ from data.utils import save_result
22
+
23
+ @torch.no_grad()
24
+ def evaluate(model, data_loader, device, config):
25
+ # evaluate
26
+ model.eval()
27
+
28
+ metric_logger = utils.MetricLogger(delimiter=" ")
29
+ header = 'Evaluation:'
30
+ print_freq = 10
31
+
32
+ result = []
33
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
34
+
35
+ image = image.to(device)
36
+
37
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
38
+ min_length=config['min_length'], repetition_penalty=1.1)
39
+
40
+ for caption, img_id in zip(captions, image_id):
41
+ result.append({"image_id": img_id.item(), "caption": caption})
42
+
43
+ return result
44
+
45
+
46
+ def main(args, config):
47
+ utils.init_distributed_mode(args)
48
+
49
+ device = torch.device(args.device)
50
+
51
+ # fix the seed for reproducibility
52
+ seed = args.seed + utils.get_rank()
53
+ torch.manual_seed(seed)
54
+ np.random.seed(seed)
55
+ random.seed(seed)
56
+ cudnn.benchmark = True
57
+
58
+ #### Dataset ####
59
+ print("Creating captioning dataset")
60
+ val_dataset, test_dataset = create_dataset('nocaps', config)
61
+
62
+ if args.distributed:
63
+ num_tasks = utils.get_world_size()
64
+ global_rank = utils.get_rank()
65
+ samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
66
+ else:
67
+ samplers = [None,None]
68
+
69
+ val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
70
+ batch_size=[config['batch_size']]*2,num_workers=[4,4],
71
+ is_trains=[False, False], collate_fns=[None,None])
72
+
73
+ #### Model ####
74
+ print("Creating model")
75
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
76
+ prompt=config['prompt'])
77
+
78
+ model = model.to(device)
79
+
80
+ model_without_ddp = model
81
+ if args.distributed:
82
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
83
+ model_without_ddp = model.module
84
+
85
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
86
+ val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
87
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
88
+ test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
89
+
90
+
91
+ if __name__ == '__main__':
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument('--config', default='./configs/nocaps.yaml')
94
+ parser.add_argument('--output_dir', default='output/NoCaps')
95
+ parser.add_argument('--device', default='cuda')
96
+ parser.add_argument('--seed', default=42, type=int)
97
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
98
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
99
+ parser.add_argument('--distributed', default=True, type=bool)
100
+ args = parser.parse_args()
101
+
102
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
103
+
104
+ args.result_dir = os.path.join(args.output_dir, 'result')
105
+
106
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
107
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
108
+
109
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
110
+
111
+ main(args, config)
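Note: eval_nocaps.py is written for a distributed launch through utils.init_distributed_mode. A minimal single-process sketch of the same captioning evaluation, assuming configs/nocaps.yaml and the checkpoint referenced by its 'pretrained' key are available locally, could look like this:

# Hedged single-process sketch of the evaluation above, bypassing the distributed
# setup in main(); paths and config keys are taken from this repo's layout and
# may need adjusting.
import ruamel_yaml as yaml
import torch

from models.blip import blip_decoder
from data import create_dataset, create_loader

config = yaml.load(open('./configs/nocaps.yaml', 'r'), Loader=yaml.Loader)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

val_dataset, test_dataset = create_dataset('nocaps', config)
val_loader, _ = create_loader([val_dataset, test_dataset], [None, None],
                              batch_size=[config['batch_size']] * 2, num_workers=[4, 4],
                              is_trains=[False, False], collate_fns=[None, None])

model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'],
                     vit=config['vit'], prompt=config['prompt']).eval().to(device)

results = []
with torch.no_grad():
    for image, image_id in val_loader:
        captions = model.generate(image.to(device), sample=False,
                                  num_beams=config['num_beams'],
                                  max_length=config['max_length'],
                                  min_length=config['min_length'])
        results.extend({'image_id': i.item(), 'caption': c}
                       for c, i in zip(captions, image_id))
print(len(results), 'captions generated')

The distributed path in main() additionally shards the loaders with create_sampler and merges per-rank results via save_result.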
eval_retrieval_video.py ADDED
@@ -0,0 +1,249 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torch.backends.cudnn as cudnn
21
+ import torch.distributed as dist
22
+ from torch.utils.data import DataLoader
23
+
24
+ from models.blip_retrieval import blip_retrieval
25
+ import utils
26
+ from data.video_dataset import VideoDataset
27
+
28
+
29
+ @torch.no_grad()
30
+ def evaluation(model, data_loader, tokenizer, device, config):
31
+ # test
32
+ model.eval()
33
+
34
+ metric_logger = utils.MetricLogger(delimiter=" ")
35
+ header = 'Evaluation:'
36
+
37
+ print('Computing features for evaluation...')
38
+ start_time = time.time()
39
+
40
+ texts = data_loader.dataset.text
41
+ num_text = len(texts)
42
+ text_bs = 256
43
+ text_ids = []
44
+ text_embeds = []
45
+ text_atts = []
46
+ for i in range(0, num_text, text_bs):
47
+ text = texts[i: min(num_text, i+text_bs)]
48
+ text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
49
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
50
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
51
+ text_embeds.append(text_embed)
52
+ text_ids.append(text_input.input_ids)
53
+ text_atts.append(text_input.attention_mask)
54
+
55
+ text_embeds = torch.cat(text_embeds,dim=0)
56
+ text_ids = torch.cat(text_ids,dim=0)
57
+ text_atts = torch.cat(text_atts,dim=0)
58
+ text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
59
+
60
+ video_feats = []
61
+ video_embeds = []
62
+ for video, video_id in data_loader:
63
+
64
+ B,N,C,W,H = video.size()
65
+ video = video.view(-1,C,W,H)
66
+ video = video.to(device,non_blocking=True)
67
+ video_feat = model.visual_encoder(video)
68
+ video_embed = model.vision_proj(video_feat[:,0,:])
69
+ video_embed = video_embed.view(B,N,-1).mean(dim=1)
70
+ video_embed = F.normalize(video_embed,dim=-1)
71
+
72
+ video_feat = video_feat.view(B,-1,video_feat.shape[-1])
73
+ video_feats.append(video_feat.cpu())
74
+ video_embeds.append(video_embed)
75
+
76
+ video_feats = torch.cat(video_feats,dim=0)
77
+ video_embeds = torch.cat(video_embeds,dim=0)
78
+
79
+ sims_matrix = video_embeds @ text_embeds.t()
80
+ score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
81
+
82
+ num_tasks = utils.get_world_size()
83
+ rank = utils.get_rank()
84
+ step = sims_matrix.size(0)//num_tasks + 1
85
+ start = rank*step
86
+ end = min(sims_matrix.size(0),start+step)
87
+
88
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
89
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
90
+
91
+ encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
92
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
93
+ output = model.text_encoder(text_ids[topk_idx],
94
+ attention_mask = text_atts[topk_idx],
95
+ encoder_hidden_states = encoder_output,
96
+ encoder_attention_mask = encoder_att,
97
+ return_dict = True,
98
+ )
99
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
100
+ score_matrix_v2t[start+i,topk_idx] = score + topk_sim
101
+
102
+ sims_matrix = sims_matrix.t()
103
+ score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
104
+
105
+ step = sims_matrix.size(0)//num_tasks + 1
106
+ start = rank*step
107
+ end = min(sims_matrix.size(0),start+step)
108
+
109
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
110
+
111
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
112
+ encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
113
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
114
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
115
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
116
+ encoder_hidden_states = encoder_output,
117
+ encoder_attention_mask = encoder_att,
118
+ return_dict = True,
119
+ )
120
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
121
+ score_matrix_t2v[start+i,topk_idx] = score + topk_sim
122
+
123
+ if args.distributed:
124
+ dist.barrier()
125
+ torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
126
+ torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
127
+
128
+ total_time = time.time() - start_time
129
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
130
+ print('Evaluation time {}'.format(total_time_str))
131
+
132
+ return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
133
+
134
+
135
+
136
+ @torch.no_grad()
137
+ def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
138
+
139
+ #Video->Text
140
+ ranks = np.zeros(scores_v2t.shape[0])
141
+ for index,score in enumerate(scores_v2t):
142
+ inds = np.argsort(score)[::-1]
143
+ ranks[index] = np.where(inds == vid2txt[index])[0][0]
144
+
145
+ # Compute metrics
146
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
147
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
148
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
149
+
150
+ #Text->Video
151
+ ranks = np.zeros(scores_t2v.shape[0])
152
+
153
+ for index,score in enumerate(scores_t2v):
154
+ inds = np.argsort(score)[::-1]
155
+ ranks[index] = np.where(inds == txt2vmg[index])[0][0]
156
+
157
+ mdR = np.median(ranks+1)
158
+
159
+ # Compute metrics
160
+ vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
161
+ vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
162
+ vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
163
+
164
+ tr_mean = (tr1 + tr5 + tr10) / 3
165
+ vr_mean = (vr1 + vr5 + vr10) / 3
166
+ r_mean = (tr_mean + vr_mean) / 2
167
+
168
+ eval_result = {'txt_r1': tr1,
169
+ 'txt_r5': tr5,
170
+ 'txt_r10': tr10,
171
+ 'txt_r_mean': tr_mean,
172
+ 'vid_r1': vr1,
173
+ 'vid_r5': vr5,
174
+ 'vid_r10': vr10,
175
+ 'vid_r_mean': vr_mean,
176
+ 'vid_mdR': mdR,
177
+ 'r_mean': r_mean}
178
+ return eval_result
179
+
180
+
181
+
182
+
183
+ def main(args, config):
184
+ utils.init_distributed_mode(args)
185
+
186
+ device = torch.device(args.device)
187
+
188
+ # fix the seed for reproducibility
189
+ seed = args.seed + utils.get_rank()
190
+ torch.manual_seed(seed)
191
+ np.random.seed(seed)
192
+ random.seed(seed)
193
+ cudnn.benchmark = True
194
+
195
+ #### Dataset ####
196
+ print("Creating retrieval dataset")
197
+ test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
198
+ max_img_size=config['image_size'], frm_sampling_strategy='uniform')
199
+
200
+ test_loader = DataLoader(
201
+ test_dataset,
202
+ batch_size=config['batch_size'],
203
+ num_workers=4,
204
+ pin_memory=True,
205
+ drop_last=False,
206
+ shuffle=False,
207
+ )
208
+
209
+ #### Model ####
210
+ print("Creating model")
211
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
212
+
213
+ model = model.to(device)
214
+
215
+ model_without_ddp = model
216
+ if args.distributed:
217
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
218
+ model_without_ddp = model.module
219
+
220
+ score_v2t, score_t2v, = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
221
+
222
+ if utils.is_main_process():
223
+
224
+ test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
225
+ print(test_result)
226
+
227
+ log_stats = {**{f'{k}': v for k, v in test_result.items()},}
228
+ with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
229
+ f.write(json.dumps(log_stats) + "\n")
230
+
231
+
232
+ if __name__ == '__main__':
233
+ parser = argparse.ArgumentParser()
234
+ parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
235
+ parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
236
+ parser.add_argument('--device', default='cuda')
237
+ parser.add_argument('--seed', default=42, type=int)
238
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
239
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
240
+ parser.add_argument('--distributed', default=True, type=bool)
241
+ args = parser.parse_args()
242
+
243
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
244
+
245
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
246
+
247
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
248
+
249
+ main(args, config)
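The recall metrics in itm_eval reduce to where the ground-truth index lands once each row of the score matrix is sorted in descending order. A toy example with three videos (values are made up):

# Toy illustration of the rank computation used in itm_eval above: for each query
# row, sort candidate scores descending and record the position of the ground-truth
# match; R@K is the fraction of queries with rank < K.
import numpy as np

scores_v2t = np.array([[0.9, 0.2, 0.1],   # video 0: correct text ranked 1st -> rank 0
                       [0.3, 0.1, 0.8],   # video 1: correct text ranked 3rd -> rank 2
                       [0.2, 0.6, 0.7]])  # video 2: correct text ranked 1st -> rank 0
vid2txt = {0: 0, 1: 1, 2: 2}              # ground-truth text index per video

ranks = np.zeros(scores_v2t.shape[0])
for index, score in enumerate(scores_v2t):
    inds = np.argsort(score)[::-1]
    ranks[index] = np.where(inds == vid2txt[index])[0][0]

r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)   # 66.7
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)   # 100.0
print(ranks, r1, r5)                                    # [0. 2. 0.] 66.66... 100.0

mdR in the text-to-video direction is simply the median of ranks + 1 computed the same way.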
models/__init__.py ADDED
File without changes
models/blip.py ADDED
@@ -0,0 +1,237 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ from BLIP_main.models.vit import VisionTransformer, interpolate_pos_embed
12
+ from BLIP_main.models.med import BertConfig, BertModel, BertLMHeadModel
13
+ from transformers import BertTokenizer
14
+
15
+ import torch
16
+ from torch import nn
17
+
18
+ import os
19
+ from urllib.parse import urlparse
20
+ from timm.models.hub import download_cached_file
21
+
22
+ class BLIP_Base(nn.Module):
23
+ def __init__(self,
24
+ med_config = 'configs/med_config.json',
25
+ image_size = 224,
26
+ vit = 'base',
27
+ vit_grad_ckpt = False,
28
+ vit_ckpt_layer = 0,
29
+ ):
30
+ """
31
+ Args:
32
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
33
+ image_size (int): input image size
34
+ vit (str): model size of vision transformer
35
+ """
36
+ super().__init__()
37
+
38
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
39
+ self.tokenizer = init_tokenizer()
40
+ med_config = BertConfig.from_json_file(med_config)
41
+ med_config.encoder_width = vision_width
42
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
43
+
44
+
45
+ def forward(self, image, caption, mode):
46
+
47
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
48
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
49
+
50
+ if mode=='image':
51
+ # return image features
52
+ image_embeds = self.visual_encoder(image)
53
+ return image_embeds
54
+
55
+ elif mode=='text':
56
+ # return text features
57
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
58
+ return_dict = True, mode = 'text')
59
+ return text_output.last_hidden_state
60
+
61
+ elif mode=='multimodal':
62
+ # return multimodel features
63
+ image_embeds = self.visual_encoder(image)
64
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
65
+
66
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
67
+ output = self.text_encoder(text.input_ids,
68
+ attention_mask = text.attention_mask,
69
+ encoder_hidden_states = image_embeds,
70
+ encoder_attention_mask = image_atts,
71
+ return_dict = True,
72
+ )
73
+ return output.last_hidden_state
74
+
75
+
76
+
77
+ class BLIP_Decoder(nn.Module):
78
+ def __init__(self,
79
+ med_config = 'configs/med_config.json',
80
+ image_size = 384,
81
+ vit = 'base',
82
+ vit_grad_ckpt = False,
83
+ vit_ckpt_layer = 0,
84
+ prompt = 'a picture of ',
85
+ ):
86
+ """
87
+ Args:
88
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
89
+ image_size (int): input image size
90
+ vit (str): model size of vision transformer
91
+ """
92
+ super().__init__()
93
+
94
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
95
+ self.tokenizer = init_tokenizer()
96
+ med_config = BertConfig.from_json_file(med_config)
97
+ med_config.encoder_width = vision_width
98
+ self.text_decoder = BertLMHeadModel(config=med_config)
99
+
100
+ self.prompt = prompt
101
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
102
+
103
+
104
+ def forward(self, image, caption):
105
+
106
+ image_embeds = self.visual_encoder(image)
107
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
108
+
109
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
110
+
111
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
112
+
113
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
114
+ decoder_targets[:,:self.prompt_length] = -100
115
+
116
+ decoder_output = self.text_decoder(text.input_ids,
117
+ attention_mask = text.attention_mask,
118
+ encoder_hidden_states = image_embeds,
119
+ encoder_attention_mask = image_atts,
120
+ labels = decoder_targets,
121
+ return_dict = True,
122
+ )
123
+ loss_lm = decoder_output.loss
124
+
125
+ return loss_lm
126
+
127
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
128
+ image_embeds = self.visual_encoder(image)
129
+
130
+ if not sample:
131
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
132
+
133
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
134
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
135
+
136
+ prompt = [self.prompt] * image.size(0)
137
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
138
+ input_ids[:,0] = self.tokenizer.bos_token_id
139
+ input_ids = input_ids[:, :-1]
140
+
141
+ if sample:
142
+ #nucleus sampling
143
+ outputs = self.text_decoder.generate(input_ids=input_ids,
144
+ max_length=max_length,
145
+ min_length=min_length,
146
+ do_sample=True,
147
+ top_p=top_p,
148
+ num_return_sequences=1,
149
+ eos_token_id=self.tokenizer.sep_token_id,
150
+ pad_token_id=self.tokenizer.pad_token_id,
151
+ repetition_penalty=1.1,
152
+ **model_kwargs)
153
+ else:
154
+ #beam search
155
+ outputs = self.text_decoder.generate(input_ids=input_ids,
156
+ max_length=max_length,
157
+ min_length=min_length,
158
+ num_beams=num_beams,
159
+ eos_token_id=self.tokenizer.sep_token_id,
160
+ pad_token_id=self.tokenizer.pad_token_id,
161
+ repetition_penalty=repetition_penalty,
162
+ **model_kwargs)
163
+
164
+ captions = []
165
+ for output in outputs:
166
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
167
+ captions.append(caption[len(self.prompt):])
168
+ return captions
169
+
170
+
171
+ def blip_decoder(pretrained='',**kwargs):
172
+ model = BLIP_Decoder(**kwargs)
173
+ if pretrained:
174
+ model,msg = load_checkpoint(model,pretrained)
175
+ assert(len(msg.missing_keys)==0)
176
+ return model
177
+
178
+ def blip_feature_extractor(pretrained='',**kwargs):
179
+ model = BLIP_Base(**kwargs)
180
+ if pretrained:
181
+ model,msg = load_checkpoint(model,pretrained)
182
+ assert(len(msg.missing_keys)==0)
183
+ return model
184
+
185
+ def init_tokenizer():
186
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
187
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
188
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
189
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
190
+ return tokenizer
191
+
192
+
193
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
194
+
195
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
196
+ if vit=='base':
197
+ vision_width = 768
198
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
199
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
200
+ drop_path_rate=0 or drop_path_rate
201
+ )
202
+ elif vit=='large':
203
+ vision_width = 1024
204
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
205
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
206
+ drop_path_rate=0.1 or drop_path_rate
207
+ )
208
+ return visual_encoder, vision_width
209
+
210
+ def is_url(url_or_filename):
211
+ parsed = urlparse(url_or_filename)
212
+ return parsed.scheme in ("http", "https")
213
+
214
+ def load_checkpoint(model,url_or_filename):
215
+ if is_url(url_or_filename):
216
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
217
+ checkpoint = torch.load(cached_file, map_location='cpu')
218
+ elif os.path.isfile(url_or_filename):
219
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
220
+ else:
221
+ raise RuntimeError('checkpoint url or path is invalid')
222
+
223
+ state_dict = checkpoint['model']
224
+
225
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
226
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
227
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
228
+ model.visual_encoder_m)
229
+ for key in model.state_dict().keys():
230
+ if key in state_dict.keys():
231
+ if state_dict[key].shape!=model.state_dict()[key].shape:
232
+ del state_dict[key]
233
+
234
+ msg = model.load_state_dict(state_dict,strict=False)
235
+ print('load checkpoint from %s'%url_or_filename)
236
+ return model,msg
237
+
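A hedged usage sketch for BLIP_Decoder.generate defined above; the checkpoint path and demo image are placeholders, and the normalization constants are the ones commonly paired with BLIP demo code, so adjust them if your config differs:

# Minimal captioning sketch using blip_decoder from this file.
# Assumptions: 'demo.jpg' and the checkpoint path are placeholders; the Normalize
# values mirror the preprocessing typically used with BLIP checkpoints.
import torch
from PIL import Image
from torchvision import transforms
from models.blip import blip_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384

transform = transforms.Compose([
    transforms.Resize((image_size, image_size),
                      interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])
image = transform(Image.open('demo.jpg').convert('RGB')).unsqueeze(0).to(device)

model = blip_decoder(pretrained='weights/model_base_capfilt_large.pth',
                     image_size=image_size, vit='base').eval().to(device)

with torch.no_grad():
    caption = model.generate(image, sample=False, num_beams=3,
                             max_length=20, min_length=5)[0]
print('caption:', caption)

Passing sample=True switches generate to the nucleus-sampling branch (top_p) instead of beam search.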
models/blip_itm.py ADDED
@@ -0,0 +1,75 @@
1
+ from BLIP_main.models.med import BertConfig, BertModel
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+ from BLIP_main.models.blip import create_vit, init_tokenizer, load_checkpoint
8
+
9
+ class BLIP_ITM(nn.Module):
10
+ def __init__(self,
11
+ med_config = 'configs/med_config.json',
12
+ image_size = 384,
13
+ vit = 'base',
14
+ vit_grad_ckpt = False,
15
+ vit_ckpt_layer = 0,
16
+ embed_dim = 256,
17
+ ):
18
+ """
19
+ Args:
20
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
21
+ image_size (int): input image size
22
+ vit (str): model size of vision transformer
23
+ """
24
+ super().__init__()
25
+
26
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
27
+ self.tokenizer = init_tokenizer()
28
+ med_config = BertConfig.from_json_file(med_config)
29
+ med_config.encoder_width = vision_width
30
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
31
+
32
+ text_width = self.text_encoder.config.hidden_size
33
+
34
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
35
+ self.text_proj = nn.Linear(text_width, embed_dim)
36
+
37
+ self.itm_head = nn.Linear(text_width, 2)
38
+
39
+
40
+ def forward(self, image, caption, match_head='itm'):
41
+
42
+ image_embeds = self.visual_encoder(image)
43
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
44
+
45
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
46
+ return_tensors="pt").to(image.device)
47
+
48
+
49
+ if match_head=='itm':
50
+ output = self.text_encoder(text.input_ids,
51
+ attention_mask = text.attention_mask,
52
+ encoder_hidden_states = image_embeds,
53
+ encoder_attention_mask = image_atts,
54
+ return_dict = True,
55
+ )
56
+ itm_output = self.itm_head(output.last_hidden_state[:,0,:])
57
+ return itm_output
58
+
59
+ elif match_head=='itc':
60
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
61
+ return_dict = True, mode = 'text')
62
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
63
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
64
+
65
+ sim = image_feat @ text_feat.t()
66
+ return sim
67
+
68
+
69
+ def blip_itm(pretrained='',**kwargs):
70
+ model = BLIP_ITM(**kwargs)
71
+ if pretrained:
72
+ model,msg = load_checkpoint(model,pretrained)
73
+ assert(len(msg.missing_keys)==0)
74
+ return model
75
+
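A hedged sketch of scoring one image-text pair with BLIP_ITM above; the checkpoint name is a placeholder and the random tensor stands in for a properly preprocessed 384x384 image:

# Image-text matching (itm) vs. contrastive similarity (itc) with BLIP_ITM.
# The checkpoint path is a placeholder; replace the random tensor with a real
# preprocessed image batch of shape [B, 3, 384, 384].
import torch
import torch.nn.functional as F
from models.blip_itm import blip_itm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = blip_itm(pretrained='weights/model_base_retrieval_coco.pth',
                 image_size=384, vit='base').eval().to(device)

image = torch.randn(1, 3, 384, 384).to(device)   # placeholder input
caption = 'a woman sitting on the beach with a dog'

with torch.no_grad():
    itm_logits = model(image, caption, match_head='itm')   # [B, 2] logits
    itm_score = F.softmax(itm_logits, dim=1)[:, 1]         # probability of a match
    itc_sim = model(image, caption, match_head='itc')      # cosine similarity
print(itm_score.item(), itc_sim.item())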
models/blip_nlvr.py ADDED
@@ -0,0 +1,102 @@
1
+ from BLIP_main.models.med import BertConfig
2
+ from BLIP_main.models.nlvr_encoder import BertModel
3
+ from BLIP_main.models.vit import interpolate_pos_embed
4
+ from BLIP_main.models.blip import create_vit, init_tokenizer, is_url
5
+
6
+ from timm.models.hub import download_cached_file
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
+ import os
11
+
12
+
13
+ class BLIP_NLVR(nn.Module):
14
+ def __init__(self,
15
+ med_config = 'configs/med_config.json',
16
+ image_size = 480,
17
+ vit = 'base',
18
+ vit_grad_ckpt = False,
19
+ vit_ckpt_layer = 0,
20
+ ):
21
+ """
22
+ Args:
23
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
24
+ image_size (int): input image size
25
+ vit (str): model size of vision transformer
26
+ """
27
+ super().__init__()
28
+
29
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
30
+ self.tokenizer = init_tokenizer()
31
+ med_config = BertConfig.from_json_file(med_config)
32
+ med_config.encoder_width = vision_width
33
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
34
+
35
+ self.cls_head = nn.Sequential(
36
+ nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
37
+ nn.ReLU(),
38
+ nn.Linear(self.text_encoder.config.hidden_size, 2)
39
+ )
40
+
41
+ def forward(self, image, text, targets, train=True):
42
+
43
+ image_embeds = self.visual_encoder(image)
44
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
45
+ image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
46
+
47
+ text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
48
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
49
+
50
+ output = self.text_encoder(text.input_ids,
51
+ attention_mask = text.attention_mask,
52
+ encoder_hidden_states = [image0_embeds,image1_embeds],
53
+ encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
54
+ image_atts[image0_embeds.size(0):]],
55
+ return_dict = True,
56
+ )
57
+ hidden_state = output.last_hidden_state[:,0,:]
58
+ prediction = self.cls_head(hidden_state)
59
+
60
+ if train:
61
+ loss = F.cross_entropy(prediction, targets)
62
+ return loss
63
+ else:
64
+ return prediction
65
+
66
+ def blip_nlvr(pretrained='',**kwargs):
67
+ model = BLIP_NLVR(**kwargs)
68
+ if pretrained:
69
+ model,msg = load_checkpoint(model,pretrained)
70
+ print("missing keys:")
71
+ print(msg.missing_keys)
72
+ return model
73
+
74
+
75
+ def load_checkpoint(model,url_or_filename):
76
+ if is_url(url_or_filename):
77
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
78
+ checkpoint = torch.load(cached_file, map_location='cpu')
79
+ elif os.path.isfile(url_or_filename):
80
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
81
+ else:
82
+ raise RuntimeError('checkpoint url or path is invalid')
83
+ state_dict = checkpoint['model']
84
+
85
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
86
+
87
+ for key in list(state_dict.keys()):
88
+ if 'crossattention.self.' in key:
89
+ new_key0 = key.replace('self','self0')
90
+ new_key1 = key.replace('self','self1')
91
+ state_dict[new_key0] = state_dict[key]
92
+ state_dict[new_key1] = state_dict[key]
93
+ elif 'crossattention.output.dense.' in key:
94
+ new_key0 = key.replace('dense','dense0')
95
+ new_key1 = key.replace('dense','dense1')
96
+ state_dict[new_key0] = state_dict[key]
97
+ state_dict[new_key1] = state_dict[key]
98
+
99
+ msg = model.load_state_dict(state_dict,strict=False)
100
+ print('load checkpoint from %s'%url_or_filename)
101
+ return model,msg
102
+
models/blip_pretrain.py ADDED
@@ -0,0 +1,339 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ from BLIP_main.models.med import BertConfig, BertModel, BertLMHeadModel
9
+ import transformers
10
+ transformers.logging.set_verbosity_error()
+ import logging
+ logger = logging.getLogger(__name__)  # referenced in tie_encoder_decoder_weights below
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+
16
+ from BLIP_main.models.blip import create_vit, init_tokenizer
17
+
18
+
19
+ class BLIP_Pretrain(nn.Module):
20
+ def __init__(self,
21
+ med_config = 'configs/bert_config.json',
22
+ image_size = 224,
23
+ vit = 'base',
24
+ vit_grad_ckpt = False,
25
+ vit_ckpt_layer = 0,
26
+ embed_dim = 256,
27
+ queue_size = 57600,
28
+ momentum = 0.995,
29
+ ):
30
+ """
31
+ Args:
32
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
33
+ image_size (int): input image size
34
+ vit (str): model size of vision transformer
35
+ """
36
+ super().__init__()
37
+
38
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
39
+
40
+ if vit=='base':
41
+ checkpoint = torch.hub.load_state_dict_from_url(
42
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
43
+ map_location="cpu", check_hash=True)
44
+ state_dict = checkpoint["model"]
45
+ msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
46
+ elif vit=='large':
47
+ from timm.models.helpers import load_custom_pretrained
48
+ from timm.models.vision_transformer import default_cfgs
49
+ load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
50
+
51
+ self.tokenizer = init_tokenizer()
52
+ encoder_config = BertConfig.from_json_file(med_config)
53
+ encoder_config.encoder_width = vision_width
54
+ self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
55
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
56
+
57
+ text_width = self.text_encoder.config.hidden_size
58
+
59
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
60
+ self.text_proj = nn.Linear(text_width, embed_dim)
61
+
62
+ self.itm_head = nn.Linear(text_width, 2)
63
+
64
+ # create momentum encoders
65
+ self.visual_encoder_m, vision_width = create_vit(vit,image_size)
66
+ self.vision_proj_m = nn.Linear(vision_width, embed_dim)
67
+ self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
68
+ self.text_proj_m = nn.Linear(text_width, embed_dim)
69
+
70
+ self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
71
+ [self.vision_proj,self.vision_proj_m],
72
+ [self.text_encoder,self.text_encoder_m],
73
+ [self.text_proj,self.text_proj_m],
74
+ ]
75
+ self.copy_params()
76
+
77
+ # create the queue
78
+ self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
79
+ self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
80
+ self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
81
+
82
+ self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
83
+ self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
84
+
85
+ self.queue_size = queue_size
86
+ self.momentum = momentum
87
+ self.temp = nn.Parameter(0.07*torch.ones([]))
88
+
89
+ # create the decoder
90
+ decoder_config = BertConfig.from_json_file(med_config)
91
+ decoder_config.encoder_width = vision_width
92
+ self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
93
+ self.text_decoder.resize_token_embeddings(len(self.tokenizer))
94
+ tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
95
+
96
+
97
+ def forward(self, image, caption, alpha):
98
+ with torch.no_grad():
99
+ self.temp.clamp_(0.001,0.5)
100
+
101
+ image_embeds = self.visual_encoder(image)
102
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
103
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
104
+
105
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
106
+ return_tensors="pt").to(image.device)
107
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
108
+ return_dict = True, mode = 'text')
109
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
110
+
111
+ # get momentum features
112
+ with torch.no_grad():
113
+ self._momentum_update()
114
+ image_embeds_m = self.visual_encoder_m(image)
115
+ image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
116
+ image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
117
+
118
+ text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
119
+ return_dict = True, mode = 'text')
120
+ text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
121
+ text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
122
+
123
+ sim_i2t_m = image_feat_m @ text_feat_all / self.temp
124
+ sim_t2i_m = text_feat_m @ image_feat_all / self.temp
125
+
126
+ sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
127
+ sim_targets.fill_diagonal_(1)
128
+
129
+ sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
130
+ sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
131
+
132
+ sim_i2t = image_feat @ text_feat_all / self.temp
133
+ sim_t2i = text_feat @ image_feat_all / self.temp
134
+
135
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
136
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
137
+
138
+ loss_ita = (loss_i2t+loss_t2i)/2
139
+
140
+ self._dequeue_and_enqueue(image_feat_m, text_feat_m)
141
+
142
+ ###============== Image-text Matching ===================###
143
+ encoder_input_ids = text.input_ids.clone()
144
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
145
+
146
+ # forward the positve image-text pair
147
+ bs = image.size(0)
148
+ output_pos = self.text_encoder(encoder_input_ids,
149
+ attention_mask = text.attention_mask,
150
+ encoder_hidden_states = image_embeds,
151
+ encoder_attention_mask = image_atts,
152
+ return_dict = True,
153
+ )
154
+ with torch.no_grad():
155
+ weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
156
+ weights_t2i.fill_diagonal_(0)
157
+ weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
158
+ weights_i2t.fill_diagonal_(0)
159
+
160
+ # select a negative image for each text
161
+ image_embeds_neg = []
162
+ for b in range(bs):
163
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
164
+ image_embeds_neg.append(image_embeds[neg_idx])
165
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
166
+
167
+ # select a negative text for each image
168
+ text_ids_neg = []
169
+ text_atts_neg = []
170
+ for b in range(bs):
171
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
172
+ text_ids_neg.append(encoder_input_ids[neg_idx])
173
+ text_atts_neg.append(text.attention_mask[neg_idx])
174
+
175
+ text_ids_neg = torch.stack(text_ids_neg,dim=0)
176
+ text_atts_neg = torch.stack(text_atts_neg,dim=0)
177
+
178
+ text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
179
+ text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
180
+
181
+ image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
182
+ image_atts_all = torch.cat([image_atts,image_atts],dim=0)
183
+
184
+ output_neg = self.text_encoder(text_ids_all,
185
+ attention_mask = text_atts_all,
186
+ encoder_hidden_states = image_embeds_all,
187
+ encoder_attention_mask = image_atts_all,
188
+ return_dict = True,
189
+ )
190
+
191
+ vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
192
+ vl_output = self.itm_head(vl_embeddings)
193
+
194
+ itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
195
+ dim=0).to(image.device)
196
+ loss_itm = F.cross_entropy(vl_output, itm_labels)
197
+
198
+ ##================= LM ========================##
199
+ decoder_input_ids = text.input_ids.clone()
200
+ decoder_input_ids[:,0] = self.tokenizer.bos_token_id
201
+ decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
202
+
203
+ decoder_output = self.text_decoder(decoder_input_ids,
204
+ attention_mask = text.attention_mask,
205
+ encoder_hidden_states = image_embeds,
206
+ encoder_attention_mask = image_atts,
207
+ labels = decoder_targets,
208
+ return_dict = True,
209
+ )
210
+
211
+ loss_lm = decoder_output.loss
212
+ return loss_ita, loss_itm, loss_lm
213
+
214
+
215
+
216
+ @torch.no_grad()
217
+ def copy_params(self):
218
+ for model_pair in self.model_pairs:
219
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
220
+ param_m.data.copy_(param.data) # initialize
221
+ param_m.requires_grad = False # not update by gradient
222
+
223
+
224
+ @torch.no_grad()
225
+ def _momentum_update(self):
226
+ for model_pair in self.model_pairs:
227
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
228
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
229
+
230
+
231
+ @torch.no_grad()
232
+ def _dequeue_and_enqueue(self, image_feat, text_feat):
233
+ # gather keys before updating queue
234
+ image_feats = concat_all_gather(image_feat)
235
+ text_feats = concat_all_gather(text_feat)
236
+
237
+ batch_size = image_feats.shape[0]
238
+
239
+ ptr = int(self.queue_ptr)
240
+ assert self.queue_size % batch_size == 0 # for simplicity
241
+
242
+ # replace the keys at ptr (dequeue and enqueue)
243
+ self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
244
+ self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
245
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
246
+
247
+ self.queue_ptr[0] = ptr
248
+
249
+
250
+ def blip_pretrain(**kwargs):
251
+ model = BLIP_Pretrain(**kwargs)
252
+ return model
253
+
254
+
255
+ @torch.no_grad()
256
+ def concat_all_gather(tensor):
257
+ """
258
+ Performs all_gather operation on the provided tensors.
259
+ *** Warning ***: torch.distributed.all_gather has no gradient.
260
+ """
261
+ tensors_gather = [torch.ones_like(tensor)
262
+ for _ in range(torch.distributed.get_world_size())]
263
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
264
+
265
+ output = torch.cat(tensors_gather, dim=0)
266
+ return output
267
+
268
+
269
+ from typing import List
270
+ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
271
+ uninitialized_encoder_weights: List[str] = []
272
+ if decoder.__class__ != encoder.__class__:
273
+ logger.info(
274
+ f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
275
+ )
276
+
277
+ def tie_encoder_to_decoder_recursively(
278
+ decoder_pointer: nn.Module,
279
+ encoder_pointer: nn.Module,
280
+ module_name: str,
281
+ uninitialized_encoder_weights: List[str],
282
+ skip_key: str,
283
+ depth=0,
284
+ ):
285
+ assert isinstance(decoder_pointer, nn.Module) and isinstance(
286
+ encoder_pointer, nn.Module
287
+ ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
288
+ if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
289
+ assert hasattr(encoder_pointer, "weight")
290
+ encoder_pointer.weight = decoder_pointer.weight
291
+ if hasattr(decoder_pointer, "bias"):
292
+ assert hasattr(encoder_pointer, "bias")
293
+ encoder_pointer.bias = decoder_pointer.bias
294
+ print(module_name+' is tied')
295
+ return
296
+
297
+ encoder_modules = encoder_pointer._modules
298
+ decoder_modules = decoder_pointer._modules
299
+ if len(decoder_modules) > 0:
300
+ assert (
301
+ len(encoder_modules) > 0
302
+ ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
303
+
304
+ all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
305
+ encoder_layer_pos = 0
306
+ for name, module in decoder_modules.items():
307
+ if name.isdigit():
308
+ encoder_name = str(int(name) + encoder_layer_pos)
309
+ decoder_name = name
310
+ if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
311
+ encoder_modules
312
+ ) != len(decoder_modules):
313
+ # this can happen if the name corresponds to the position in a module list of layers
314
+ # in this case the decoder has added a cross-attention that the encoder does not have
315
+ # thus skip this step and subtract one layer pos from encoder
316
+ encoder_layer_pos -= 1
317
+ continue
318
+ elif name not in encoder_modules:
319
+ continue
320
+ elif depth > 500:
321
+ raise ValueError(
322
+ "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
323
+ )
324
+ else:
325
+ decoder_name = encoder_name = name
326
+ tie_encoder_to_decoder_recursively(
327
+ decoder_modules[decoder_name],
328
+ encoder_modules[encoder_name],
329
+ module_name + "/" + name,
330
+ uninitialized_encoder_weights,
331
+ skip_key,
332
+ depth=depth + 1,
333
+ )
334
+ all_encoder_weights.remove(module_name + "/" + encoder_name)
335
+
336
+ uninitialized_encoder_weights += list(all_encoder_weights)
337
+
338
+ # tie weights recursively
339
+ tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
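The helper above is what lets BLIP share most text-encoder and text-decoder parameters during pretraining: it walks the two module trees in parallel and re-points each encoder weight/bias at the matching decoder parameter, skipping any module whose qualified name contains skip_key. A minimal usage sketch follows; the attribute names and the skip_key value are assumptions based on the BLIP recipe (self-attention layers are kept separate), since the actual call site in BLIP_Pretrain.__init__ is not part of this hunk.

# Sketch only: attribute names and skip_key are assumptions, not taken from this diff.
# Shares every matching submodule weight between text encoder and decoder, except
# modules whose qualified name contains skip_key.
tie_encoder_decoder_weights(
    encoder=model.text_encoder,         # BertModel (assumed attribute)
    decoder=model.text_decoder.bert,    # backbone of BertLMHeadModel (assumed attribute)
    base_model_prefix='',
    skip_key='/attention',              # assumed: keep self-attention layers un-shared
)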
models/blip_retrieval.py ADDED
@@ -0,0 +1,318 @@
1
+ from BLIP_main.models.med import BertConfig, BertModel
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+ from BLIP_main.models.blip import create_vit, init_tokenizer, load_checkpoint
8
+
9
+ class BLIP_Retrieval(nn.Module):
10
+ def __init__(self,
11
+ med_config = 'configs/med_config.json',
12
+ image_size = 384,
13
+ vit = 'base',
14
+ vit_grad_ckpt = False,
15
+ vit_ckpt_layer = 0,
16
+ embed_dim = 256,
17
+ queue_size = 57600,
18
+ momentum = 0.995,
19
+ negative_all_rank = False,
20
+ ):
21
+ """
22
+ Args:
23
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
24
+ image_size (int): input image size
25
+ vit (str): model size of vision transformer
26
+ """
27
+ super().__init__()
28
+
29
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
30
+ self.tokenizer = init_tokenizer()
31
+ med_config = BertConfig.from_json_file(med_config)
32
+ med_config.encoder_width = vision_width
33
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
34
+
35
+ text_width = self.text_encoder.config.hidden_size
36
+
37
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
38
+ self.text_proj = nn.Linear(text_width, embed_dim)
39
+
40
+ self.itm_head = nn.Linear(text_width, 2)
41
+
42
+ # create momentum encoders
43
+ self.visual_encoder_m, vision_width = create_vit(vit,image_size)
44
+ self.vision_proj_m = nn.Linear(vision_width, embed_dim)
45
+ self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
46
+ self.text_proj_m = nn.Linear(text_width, embed_dim)
47
+
48
+ self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
49
+ [self.vision_proj,self.vision_proj_m],
50
+ [self.text_encoder,self.text_encoder_m],
51
+ [self.text_proj,self.text_proj_m],
52
+ ]
53
+ self.copy_params()
54
+
55
+ # create the queue
56
+ self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
57
+ self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
58
+ self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
59
+ self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
60
+
61
+ self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
62
+ self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
63
+
64
+ self.queue_size = queue_size
65
+ self.momentum = momentum
66
+ self.temp = nn.Parameter(0.07*torch.ones([]))
67
+
68
+ self.negative_all_rank = negative_all_rank
69
+
70
+
71
+ def forward(self, image, caption, alpha, idx):
72
+ with torch.no_grad():
73
+ self.temp.clamp_(0.001,0.5)
74
+
75
+ image_embeds = self.visual_encoder(image)
76
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
77
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
78
+
79
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
80
+ return_tensors="pt").to(image.device)
81
+
82
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
83
+ return_dict = True, mode = 'text')
84
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
85
+
86
+ ###============== Image-text Contrastive Learning ===================###
87
+ idx = idx.view(-1,1)
88
+ idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
89
+ pos_idx = torch.eq(idx, idx_all).float()
90
+ sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
91
+
92
+ # get momentum features
93
+ with torch.no_grad():
94
+ self._momentum_update()
95
+ image_embeds_m = self.visual_encoder_m(image)
96
+ image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
97
+ image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
98
+
99
+ text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
100
+ return_dict = True, mode = 'text')
101
+ text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
102
+ text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
103
+
104
+ sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
105
+ sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
106
+
107
+ sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
108
+ sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
109
+
110
+ sim_i2t = image_feat @ text_feat_m_all / self.temp
111
+ sim_t2i = text_feat @ image_feat_m_all / self.temp
112
+
113
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
114
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
115
+
116
+ loss_ita = (loss_i2t+loss_t2i)/2
117
+
118
+ idxs = concat_all_gather(idx)
119
+ self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
120
+
121
+ ###============== Image-text Matching ===================###
122
+ encoder_input_ids = text.input_ids.clone()
123
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
124
+
125
+ # forward the positive image-text pair
126
+ bs = image.size(0)
127
+ output_pos = self.text_encoder(encoder_input_ids,
128
+ attention_mask = text.attention_mask,
129
+ encoder_hidden_states = image_embeds,
130
+ encoder_attention_mask = image_atts,
131
+ return_dict = True,
132
+ )
133
+
134
+
135
+ if self.negative_all_rank:
136
+ # compute sample similarity
137
+ with torch.no_grad():
138
+ mask = torch.eq(idx, idxs.t())
139
+
140
+ image_feat_world = concat_all_gather(image_feat)
141
+ text_feat_world = concat_all_gather(text_feat)
142
+
143
+ sim_i2t = image_feat @ text_feat_world.t() / self.temp
144
+ sim_t2i = text_feat @ image_feat_world.t() / self.temp
145
+
146
+ weights_i2t = F.softmax(sim_i2t,dim=1)
147
+ weights_i2t.masked_fill_(mask, 0)
148
+
149
+ weights_t2i = F.softmax(sim_t2i,dim=1)
150
+ weights_t2i.masked_fill_(mask, 0)
151
+
152
+ image_embeds_world = all_gather_with_grad(image_embeds)
153
+
154
+ # select a negative image (from all ranks) for each text
155
+ image_embeds_neg = []
156
+ for b in range(bs):
157
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
158
+ image_embeds_neg.append(image_embeds_world[neg_idx])
159
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
160
+
161
+ # select a negative text (from all ranks) for each image
162
+ input_ids_world = concat_all_gather(encoder_input_ids)
163
+ att_mask_world = concat_all_gather(text.attention_mask)
164
+
165
+ text_ids_neg = []
166
+ text_atts_neg = []
167
+ for b in range(bs):
168
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
169
+ text_ids_neg.append(input_ids_world[neg_idx])
170
+ text_atts_neg.append(att_mask_world[neg_idx])
171
+
172
+ else:
173
+ with torch.no_grad():
174
+ mask = torch.eq(idx, idx.t())
175
+
176
+ sim_i2t = image_feat @ text_feat.t() / self.temp
177
+ sim_t2i = text_feat @ image_feat.t() / self.temp
178
+
179
+ weights_i2t = F.softmax(sim_i2t,dim=1)
180
+ weights_i2t.masked_fill_(mask, 0)
181
+
182
+ weights_t2i = F.softmax(sim_t2i,dim=1)
183
+ weights_t2i.masked_fill_(mask, 0)
184
+
185
+ # select a negative image (from same rank) for each text
186
+ image_embeds_neg = []
187
+ for b in range(bs):
188
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
189
+ image_embeds_neg.append(image_embeds[neg_idx])
190
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
191
+
192
+ # select a negative text (from same rank) for each image
193
+ text_ids_neg = []
194
+ text_atts_neg = []
195
+ for b in range(bs):
196
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
197
+ text_ids_neg.append(encoder_input_ids[neg_idx])
198
+ text_atts_neg.append(text.attention_mask[neg_idx])
199
+
200
+ text_ids_neg = torch.stack(text_ids_neg,dim=0)
201
+ text_atts_neg = torch.stack(text_atts_neg,dim=0)
202
+
203
+ text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
204
+ text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
205
+
206
+ image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
207
+ image_atts_all = torch.cat([image_atts,image_atts],dim=0)
208
+
209
+ output_neg = self.text_encoder(text_ids_all,
210
+ attention_mask = text_atts_all,
211
+ encoder_hidden_states = image_embeds_all,
212
+ encoder_attention_mask = image_atts_all,
213
+ return_dict = True,
214
+ )
215
+
216
+
217
+ vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
218
+ vl_output = self.itm_head(vl_embeddings)
219
+
220
+ itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
221
+ dim=0).to(image.device)
222
+ loss_itm = F.cross_entropy(vl_output, itm_labels)
223
+
224
+ return loss_ita, loss_itm
225
+
226
+
227
+ @torch.no_grad()
228
+ def copy_params(self):
229
+ for model_pair in self.model_pairs:
230
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
231
+ param_m.data.copy_(param.data) # initialize
232
+ param_m.requires_grad = False # not update by gradient
233
+
234
+
235
+ @torch.no_grad()
236
+ def _momentum_update(self):
237
+ for model_pair in self.model_pairs:
238
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
239
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
240
+
241
+
242
+ @torch.no_grad()
243
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
244
+ # gather keys before updating queue
245
+ image_feats = concat_all_gather(image_feat)
246
+ text_feats = concat_all_gather(text_feat)
247
+
248
+
249
+ batch_size = image_feats.shape[0]
250
+
251
+ ptr = int(self.ptr_queue)
252
+ assert self.queue_size % batch_size == 0 # for simplicity
253
+
254
+ # replace the keys at ptr (dequeue and enqueue)
255
+ self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
256
+ self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
257
+ self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
258
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
259
+
260
+ self.ptr_queue[0] = ptr
261
+
262
+
263
+ def blip_retrieval(pretrained='',**kwargs):
264
+ model = BLIP_Retrieval(**kwargs)
265
+ if pretrained:
266
+ model,msg = load_checkpoint(model,pretrained)
267
+ print("missing keys:")
268
+ print(msg.missing_keys)
269
+ return model
270
+
271
+
272
+ @torch.no_grad()
273
+ def concat_all_gather(tensor):
274
+ """
275
+ Performs all_gather operation on the provided tensors.
276
+ *** Warning ***: torch.distributed.all_gather has no gradient.
277
+ """
278
+ tensors_gather = [torch.ones_like(tensor)
279
+ for _ in range(torch.distributed.get_world_size())]
280
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
281
+
282
+ output = torch.cat(tensors_gather, dim=0)
283
+ return output
284
+
285
+
286
+ class GatherLayer(torch.autograd.Function):
287
+ """
288
+ Gather tensors from all workers with support for backward propagation:
289
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
290
+ """
291
+
292
+ @staticmethod
293
+ def forward(ctx, x):
294
+ output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
295
+ torch.distributed.all_gather(output, x)
296
+ return tuple(output)
297
+
298
+ @staticmethod
299
+ def backward(ctx, *grads):
300
+ all_gradients = torch.stack(grads)
301
+ torch.distributed.all_reduce(all_gradients)
302
+ return all_gradients[torch.distributed.get_rank()]
303
+
304
+
305
+ def all_gather_with_grad(tensors):
306
+ """
307
+ Performs all_gather operation on the provided tensors.
308
+ Graph remains connected for backward grad computation.
309
+ """
310
+ # Gather tensors from all ranks while keeping them in the autograd graph
311
+ world_size = torch.distributed.get_world_size()
312
+ # There is no need for reduction in the single-proc case
313
+ if world_size == 1:
314
+ return tensors
315
+
316
+ tensor_all = GatherLayer.apply(tensors)
317
+
318
+ return torch.cat(tensor_all, dim=0)
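For reference, a hedged sketch of how the BLIP_Retrieval forward interface defined above might be driven during fine-tuning; the import path, optimizer, and hyperparameters are illustrative assumptions. Note that the forward pass calls concat_all_gather and updates the feature queues, so it expects torch.distributed to be initialized and the queue size to be divisible by the global batch size.

# Sketch only: import path, optimizer and hyperparameters are assumptions.
import torch
from models.blip_retrieval import blip_retrieval   # assumed module path

model = blip_retrieval(pretrained='', vit='base', image_size=384, queue_size=57600)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.05)

def train_step(image, caption, idx, alpha):
    # idx: per-sample image ids, used to treat captions of the same image as positives
    # alpha: weight of the momentum-encoder soft targets in the contrastive loss
    loss_ita, loss_itm = model(image, caption, alpha, idx)
    loss = loss_ita + loss_itm
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()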
models/blip_vqa.py ADDED
@@ -0,0 +1,185 @@
1
+ from BLIP_main.models.med import BertConfig, BertModel, BertLMHeadModel
2
+ from BLIP_main.models.blip import create_vit, init_tokenizer, load_checkpoint
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+
9
+ class BLIP_VQA(nn.Module):
10
+ def __init__(self,
11
+ med_config = 'configs/med_config.json',
12
+ image_size = 480,
13
+ vit = 'base',
14
+ vit_grad_ckpt = False,
15
+ vit_ckpt_layer = 0,
16
+ ):
17
+ """
18
+ Args:
19
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
20
+ image_size (int): input image size
21
+ vit (str): model size of vision transformer
22
+ """
23
+ super().__init__()
24
+
25
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
26
+ self.tokenizer = init_tokenizer()
27
+
28
+ encoder_config = BertConfig.from_json_file(med_config)
29
+ encoder_config.encoder_width = vision_width
30
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
31
+
32
+ decoder_config = BertConfig.from_json_file(med_config)
33
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
34
+
35
+
36
+ def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
37
+
38
+ image_embeds = self.visual_encoder(image)
39
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
40
+
41
+ question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
42
+ return_tensors="pt").to(image.device)
43
+ question.input_ids[:,0] = self.tokenizer.enc_token_id
44
+
45
+ if train:
46
+ '''
47
+ n: number of answers for each question
48
+ weights: weight for each answer
49
+ '''
50
+ answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
51
+ answer.input_ids[:,0] = self.tokenizer.bos_token_id
52
+ answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
53
+
54
+ question_output = self.text_encoder(question.input_ids,
55
+ attention_mask = question.attention_mask,
56
+ encoder_hidden_states = image_embeds,
57
+ encoder_attention_mask = image_atts,
58
+ return_dict = True)
59
+
60
+ question_states = []
61
+ question_atts = []
62
+ for b, n in enumerate(n):
63
+ question_states += [question_output.last_hidden_state[b]]*n
64
+ question_atts += [question.attention_mask[b]]*n
65
+ question_states = torch.stack(question_states,0)
66
+ question_atts = torch.stack(question_atts,0)
67
+
68
+ answer_output = self.text_decoder(answer.input_ids,
69
+ attention_mask = answer.attention_mask,
70
+ encoder_hidden_states = question_states,
71
+ encoder_attention_mask = question_atts,
72
+ labels = answer_targets,
73
+ return_dict = True,
74
+ reduction = 'none',
75
+ )
76
+
77
+ loss = weights * answer_output.loss
78
+ loss = loss.sum()/image.size(0)
79
+
80
+ return loss
81
+
82
+
83
+ else:
84
+ question_output = self.text_encoder(question.input_ids,
85
+ attention_mask = question.attention_mask,
86
+ encoder_hidden_states = image_embeds,
87
+ encoder_attention_mask = image_atts,
88
+ return_dict = True)
89
+
90
+ if inference=='generate':
91
+ num_beams = 3
92
+ question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
93
+ question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
94
+ model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
95
+
96
+ bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
97
+
98
+ outputs = self.text_decoder.generate(input_ids=bos_ids,
99
+ max_length=10,
100
+ min_length=1,
101
+ num_beams=num_beams,
102
+ eos_token_id=self.tokenizer.sep_token_id,
103
+ pad_token_id=self.tokenizer.pad_token_id,
104
+ **model_kwargs)
105
+
106
+ answers = []
107
+ for output in outputs:
108
+ answer = self.tokenizer.decode(output, skip_special_tokens=True)
109
+ answers.append(answer)
110
+ return answers
111
+
112
+ elif inference=='rank':
113
+ max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
114
+ answer.input_ids, answer.attention_mask, k_test)
115
+ return max_ids
116
+
117
+
118
+
119
+ def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
120
+
121
+ num_ques = question_states.size(0)
122
+ start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
123
+
124
+ start_output = self.text_decoder(start_ids,
125
+ encoder_hidden_states = question_states,
126
+ encoder_attention_mask = question_atts,
127
+ return_dict = True,
128
+ reduction = 'none')
129
+ logits = start_output.logits[:,0,:] # first token's logit
130
+
131
+ # topk_probs: top-k probability
132
+ # topk_ids: [num_question, k]
133
+ answer_first_token = answer_ids[:,1]
134
+ prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
135
+ topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
136
+
137
+ # answer input: [num_question*k, answer_len]
138
+ input_ids = []
139
+ input_atts = []
140
+ for b, topk_id in enumerate(topk_ids):
141
+ input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
142
+ input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
143
+ input_ids = torch.cat(input_ids,dim=0)
144
+ input_atts = torch.cat(input_atts,dim=0)
145
+
146
+ targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
147
+
148
+ # repeat encoder's output for top-k answers
149
+ question_states = tile(question_states, 0, k)
150
+ question_atts = tile(question_atts, 0, k)
151
+
152
+ output = self.text_decoder(input_ids,
153
+ attention_mask = input_atts,
154
+ encoder_hidden_states = question_states,
155
+ encoder_attention_mask = question_atts,
156
+ labels = targets_ids,
157
+ return_dict = True,
158
+ reduction = 'none')
159
+
160
+ log_probs_sum = -output.loss
161
+ log_probs_sum = log_probs_sum.view(num_ques,k)
162
+
163
+ max_topk_ids = log_probs_sum.argmax(dim=1)
164
+ max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
165
+
166
+ return max_ids
167
+
168
+
169
+ def blip_vqa(pretrained='',**kwargs):
170
+ model = BLIP_VQA(**kwargs)
171
+ if pretrained:
172
+ model,msg = load_checkpoint(model,pretrained)
173
+ # assert(len(msg.missing_keys)==0)
174
+ return model
175
+
176
+
177
+ def tile(x, dim, n_tile):
178
+ init_dim = x.size(dim)
179
+ repeat_idx = [1] * x.dim()
180
+ repeat_idx[dim] = n_tile
181
+ x = x.repeat(*(repeat_idx))
182
+ order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
183
+ return torch.index_select(x, dim, order_index.to(x.device))
184
+
185
+
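A small standalone check of the tile() helper above (assuming it is imported from this module): it repeats each slice along the given dimension consecutively, i.e. it behaves like torch.repeat_interleave, which is what rank_answer relies on to pair every question state with its k candidate answers.

# Sketch only: assumes `tile` from this file is in scope.
import torch

x = torch.tensor([10, 20, 30])
print(tile(x, dim=0, n_tile=2))
# tensor([10, 10, 20, 20, 30, 30])
print(torch.repeat_interleave(x, 2, dim=0))   # equivalent built-in
# tensor([10, 10, 20, 20, 30, 30])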
models/med.py ADDED
@@ -0,0 +1,955 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
577
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
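A minimal illustrative sketch of driving the decoder above directly: image features enter through `encoder_hidden_states`, and the LM head returns the shifted next-token loss with label smoothing 0.1. It assumes the class lives in `models/med.py`, that the config carries `encoder_width` and `add_cross_attention=True` (both read by the attention code above), and it uses random tensors in place of real ViT features.

import torch
from transformers import BertConfig, BertTokenizer
from models.med import BertLMHeadModel   # assumed module path for the class above

config = BertConfig.from_pretrained('bert-base-uncased')
config.add_cross_attention = True    # build the cross-attention sub-layers
config.encoder_width = 768           # width of the visual features fed in below

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
decoder = BertLMHeadModel(config)    # randomly initialised, for illustration only

image_embeds = torch.randn(2, 197, 768)                              # stand-in ViT features
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)   # all patches are valid

text = tokenizer(['a photo of a cat', 'a photo of a dog'],
                 return_tensors='pt', padding=True)
labels = text.input_ids.masked_fill(text.input_ids == tokenizer.pad_token_id, -100)

out = decoder(text.input_ids,
              attention_mask=text.attention_mask,
              encoder_hidden_states=image_embeds,
              encoder_attention_mask=image_atts,
              labels=labels,
              return_dict=True)
print(out.loss)   # causal LM loss over the shifted tokens
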
models/nlvr_encoder.py ADDED
@@ -0,0 +1,843 @@
1
+ import math
2
+ import os
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ from torch import Tensor, device, dtype, nn
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ import torch.nn.functional as F
13
+
14
+ from transformers.activations import ACT2FN
15
+ from transformers.file_utils import (
16
+ ModelOutput,
17
+ )
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPastAndCrossAttentions,
20
+ BaseModelOutputWithPoolingAndCrossAttentions,
21
+ CausalLMOutputWithCrossAttentions,
22
+ MaskedLMOutput,
23
+ MultipleChoiceModelOutput,
24
+ NextSentencePredictorOutput,
25
+ QuestionAnsweringModelOutput,
26
+ SequenceClassifierOutput,
27
+ TokenClassifierOutput,
28
+ )
29
+ from transformers.modeling_utils import (
30
+ PreTrainedModel,
31
+ apply_chunking_to_forward,
32
+ find_pruneable_heads_and_indices,
33
+ prune_linear_layer,
34
+ )
35
+ from transformers.utils import logging
36
+ from transformers.models.bert.configuration_bert import BertConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class BertEmbeddings(nn.Module):
43
+ """Construct the embeddings from word and position embeddings."""
44
+
45
+ def __init__(self, config):
46
+ super().__init__()
47
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
48
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
49
+
50
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
51
+ # any TensorFlow checkpoint file
52
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
53
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
54
+
55
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
56
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
57
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
58
+
59
+ self.config = config
60
+
61
+ def forward(
62
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
63
+ ):
64
+ if input_ids is not None:
65
+ input_shape = input_ids.size()
66
+ else:
67
+ input_shape = inputs_embeds.size()[:-1]
68
+
69
+ seq_length = input_shape[1]
70
+
71
+ if position_ids is None:
72
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
73
+
74
+ if inputs_embeds is None:
75
+ inputs_embeds = self.word_embeddings(input_ids)
76
+
77
+ embeddings = inputs_embeds
78
+
79
+ if self.position_embedding_type == "absolute":
80
+ position_embeddings = self.position_embeddings(position_ids)
81
+ embeddings += position_embeddings
82
+ embeddings = self.LayerNorm(embeddings)
83
+ embeddings = self.dropout(embeddings)
84
+ return embeddings
85
+
86
+
87
+ class BertSelfAttention(nn.Module):
88
+ def __init__(self, config, is_cross_attention):
89
+ super().__init__()
90
+ self.config = config
91
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
92
+ raise ValueError(
93
+ "The hidden size (%d) is not a multiple of the number of attention "
94
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
95
+ )
96
+
97
+ self.num_attention_heads = config.num_attention_heads
98
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
99
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
100
+
101
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
102
+ if is_cross_attention:
103
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
104
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
105
+ else:
106
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
107
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
108
+
109
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
110
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
111
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
112
+ self.max_position_embeddings = config.max_position_embeddings
113
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
114
+ self.save_attention = False
115
+
116
+ def save_attn_gradients(self, attn_gradients):
117
+ self.attn_gradients = attn_gradients
118
+
119
+ def get_attn_gradients(self):
120
+ return self.attn_gradients
121
+
122
+ def save_attention_map(self, attention_map):
123
+ self.attention_map = attention_map
124
+
125
+ def get_attention_map(self):
126
+ return self.attention_map
127
+
128
+ def transpose_for_scores(self, x):
129
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
130
+ x = x.view(*new_x_shape)
131
+ return x.permute(0, 2, 1, 3)
132
+
133
+ def forward(
134
+ self,
135
+ hidden_states,
136
+ attention_mask=None,
137
+ head_mask=None,
138
+ encoder_hidden_states=None,
139
+ encoder_attention_mask=None,
140
+ past_key_value=None,
141
+ output_attentions=False,
142
+ ):
143
+ mixed_query_layer = self.query(hidden_states)
144
+
145
+ # If this is instantiated as a cross-attention module, the keys
146
+ # and values come from an encoder; the attention mask needs to be
147
+ # such that the encoder's padding tokens are not attended to.
148
+ is_cross_attention = encoder_hidden_states is not None
149
+
150
+ if is_cross_attention:
151
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
152
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
153
+ attention_mask = encoder_attention_mask
154
+ elif past_key_value is not None:
155
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
156
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
157
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
158
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
159
+ else:
160
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
161
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
162
+
163
+ query_layer = self.transpose_for_scores(mixed_query_layer)
164
+
165
+ past_key_value = (key_layer, value_layer)
166
+
167
+ # Take the dot product between "query" and "key" to get the raw attention scores.
168
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
169
+
170
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
171
+ seq_length = hidden_states.size()[1]
172
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
173
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
174
+ distance = position_ids_l - position_ids_r
175
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
176
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
177
+
178
+ if self.position_embedding_type == "relative_key":
179
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
180
+ attention_scores = attention_scores + relative_position_scores
181
+ elif self.position_embedding_type == "relative_key_query":
182
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
183
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
184
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
185
+
186
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
187
+ if attention_mask is not None:
188
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
189
+ attention_scores = attention_scores + attention_mask
190
+
191
+ # Normalize the attention scores to probabilities.
192
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
193
+
194
+ if is_cross_attention and self.save_attention:
195
+ self.save_attention_map(attention_probs)
196
+ attention_probs.register_hook(self.save_attn_gradients)
197
+
198
+ # This is actually dropping out entire tokens to attend to, which might
199
+ # seem a bit unusual, but is taken from the original Transformer paper.
200
+ attention_probs_dropped = self.dropout(attention_probs)
201
+
202
+ # Mask heads if we want to
203
+ if head_mask is not None:
204
+ attention_probs_dropped = attention_probs_dropped * head_mask
205
+
206
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
207
+
208
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
209
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
210
+ context_layer = context_layer.view(*new_context_layer_shape)
211
+
212
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
213
+
214
+ outputs = outputs + (past_key_value,)
215
+ return outputs
216
+
217
+
218
+ class BertSelfOutput(nn.Module):
219
+ def __init__(self, config, twin=False, merge=False):
220
+ super().__init__()
221
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
222
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
+ if twin:
224
+ self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
225
+ self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
226
+ else:
227
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
228
+ if merge:
229
+ self.act = ACT2FN[config.hidden_act]
230
+ self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
231
+ self.merge = True
232
+ else:
233
+ self.merge = False
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ if type(hidden_states) == list:
237
+ hidden_states0 = self.dense0(hidden_states[0])
238
+ hidden_states1 = self.dense1(hidden_states[1])
239
+ if self.merge:
240
+ #hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))
241
+ hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1))
242
+ else:
243
+ hidden_states = (hidden_states0+hidden_states1)/2
244
+ else:
245
+ hidden_states = self.dense(hidden_states)
246
+ hidden_states = self.dropout(hidden_states)
247
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
248
+ return hidden_states
249
+
250
+
251
+ class BertAttention(nn.Module):
252
+ def __init__(self, config, is_cross_attention=False, layer_num=-1):
253
+ super().__init__()
254
+ if is_cross_attention:
255
+ self.self0 = BertSelfAttention(config, is_cross_attention)
256
+ self.self1 = BertSelfAttention(config, is_cross_attention)
257
+ else:
258
+ self.self = BertSelfAttention(config, is_cross_attention)
259
+ self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6))
260
+ self.pruned_heads = set()
261
+
262
+ def prune_heads(self, heads):
263
+ if len(heads) == 0:
264
+ return
265
+ heads, index = find_pruneable_heads_and_indices(
266
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
267
+ )
268
+
269
+ # Prune linear layers
270
+ self.self.query = prune_linear_layer(self.self.query, index)
271
+ self.self.key = prune_linear_layer(self.self.key, index)
272
+ self.self.value = prune_linear_layer(self.self.value, index)
273
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
274
+
275
+ # Update hyper params and store pruned heads
276
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
277
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
278
+ self.pruned_heads = self.pruned_heads.union(heads)
279
+
280
+ def forward(
281
+ self,
282
+ hidden_states,
283
+ attention_mask=None,
284
+ head_mask=None,
285
+ encoder_hidden_states=None,
286
+ encoder_attention_mask=None,
287
+ past_key_value=None,
288
+ output_attentions=False,
289
+ ):
290
+ if type(encoder_hidden_states)==list:
291
+ self_outputs0 = self.self0(
292
+ hidden_states,
293
+ attention_mask,
294
+ head_mask,
295
+ encoder_hidden_states[0],
296
+ encoder_attention_mask[0],
297
+ past_key_value,
298
+ output_attentions,
299
+ )
300
+ self_outputs1 = self.self1(
301
+ hidden_states,
302
+ attention_mask,
303
+ head_mask,
304
+ encoder_hidden_states[1],
305
+ encoder_attention_mask[1],
306
+ past_key_value,
307
+ output_attentions,
308
+ )
309
+ attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states)
310
+
311
+ outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them
312
+ else:
313
+ self_outputs = self.self(
314
+ hidden_states,
315
+ attention_mask,
316
+ head_mask,
317
+ encoder_hidden_states,
318
+ encoder_attention_mask,
319
+ past_key_value,
320
+ output_attentions,
321
+ )
322
+ attention_output = self.output(self_outputs[0], hidden_states)
323
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
324
+ return outputs
325
+
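# Note: when `encoder_hidden_states` is a list of two image feature tensors (the NLVR2
# setting), `self0` and `self1` cross-attend to one image each, and `BertSelfOutput`
# combines the two streams: layers built with merge=True (cross-attention layers with
# layer_num >= 6) concatenate them and project through `merge_layer`, while the earlier
# layers simply average the two outputs.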
326
+
327
+ class BertIntermediate(nn.Module):
328
+ def __init__(self, config):
329
+ super().__init__()
330
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
331
+ if isinstance(config.hidden_act, str):
332
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
333
+ else:
334
+ self.intermediate_act_fn = config.hidden_act
335
+
336
+ def forward(self, hidden_states):
337
+ hidden_states = self.dense(hidden_states)
338
+ hidden_states = self.intermediate_act_fn(hidden_states)
339
+ return hidden_states
340
+
341
+
342
+ class BertOutput(nn.Module):
343
+ def __init__(self, config):
344
+ super().__init__()
345
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
346
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
347
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
348
+
349
+ def forward(self, hidden_states, input_tensor):
350
+ hidden_states = self.dense(hidden_states)
351
+ hidden_states = self.dropout(hidden_states)
352
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
353
+ return hidden_states
354
+
355
+
356
+ class BertLayer(nn.Module):
357
+ def __init__(self, config, layer_num):
358
+ super().__init__()
359
+ self.config = config
360
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
361
+ self.seq_len_dim = 1
362
+ self.attention = BertAttention(config)
363
+ self.layer_num = layer_num
364
+ if self.config.add_cross_attention:
365
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
366
+ self.intermediate = BertIntermediate(config)
367
+ self.output = BertOutput(config)
368
+
369
+ def forward(
370
+ self,
371
+ hidden_states,
372
+ attention_mask=None,
373
+ head_mask=None,
374
+ encoder_hidden_states=None,
375
+ encoder_attention_mask=None,
376
+ past_key_value=None,
377
+ output_attentions=False,
378
+ mode=None,
379
+ ):
380
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
381
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
382
+ self_attention_outputs = self.attention(
383
+ hidden_states,
384
+ attention_mask,
385
+ head_mask,
386
+ output_attentions=output_attentions,
387
+ past_key_value=self_attn_past_key_value,
388
+ )
389
+ attention_output = self_attention_outputs[0]
390
+
391
+ outputs = self_attention_outputs[1:-1]
392
+ present_key_value = self_attention_outputs[-1]
393
+
394
+ if mode=='multimodal':
395
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
396
+ cross_attention_outputs = self.crossattention(
397
+ attention_output,
398
+ attention_mask,
399
+ head_mask,
400
+ encoder_hidden_states,
401
+ encoder_attention_mask,
402
+ output_attentions=output_attentions,
403
+ )
404
+ attention_output = cross_attention_outputs[0]
405
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
406
+ layer_output = apply_chunking_to_forward(
407
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
408
+ )
409
+ outputs = (layer_output,) + outputs
410
+
411
+ outputs = outputs + (present_key_value,)
412
+
413
+ return outputs
414
+
415
+ def feed_forward_chunk(self, attention_output):
416
+ intermediate_output = self.intermediate(attention_output)
417
+ layer_output = self.output(intermediate_output, attention_output)
418
+ return layer_output
419
+
420
+
421
+ class BertEncoder(nn.Module):
422
+ def __init__(self, config):
423
+ super().__init__()
424
+ self.config = config
425
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
426
+ self.gradient_checkpointing = False
427
+
428
+ def forward(
429
+ self,
430
+ hidden_states,
431
+ attention_mask=None,
432
+ head_mask=None,
433
+ encoder_hidden_states=None,
434
+ encoder_attention_mask=None,
435
+ past_key_values=None,
436
+ use_cache=None,
437
+ output_attentions=False,
438
+ output_hidden_states=False,
439
+ return_dict=True,
440
+ mode='multimodal',
441
+ ):
442
+ all_hidden_states = () if output_hidden_states else None
443
+ all_self_attentions = () if output_attentions else None
444
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
445
+
446
+ next_decoder_cache = () if use_cache else None
447
+
448
+ for i in range(self.config.num_hidden_layers):
449
+ layer_module = self.layer[i]
450
+ if output_hidden_states:
451
+ all_hidden_states = all_hidden_states + (hidden_states,)
452
+
453
+ layer_head_mask = head_mask[i] if head_mask is not None else None
454
+ past_key_value = past_key_values[i] if past_key_values is not None else None
455
+
456
+ if self.gradient_checkpointing and self.training:
457
+
458
+ if use_cache:
459
+ logger.warn(
460
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
461
+ )
462
+ use_cache = False
463
+
464
+ def create_custom_forward(module):
465
+ def custom_forward(*inputs):
466
+ return module(*inputs, past_key_value, output_attentions)
467
+
468
+ return custom_forward
469
+
470
+ layer_outputs = torch.utils.checkpoint.checkpoint(
471
+ create_custom_forward(layer_module),
472
+ hidden_states,
473
+ attention_mask,
474
+ layer_head_mask,
475
+ encoder_hidden_states,
476
+ encoder_attention_mask,
477
+ mode=mode,
478
+ )
479
+ else:
480
+ layer_outputs = layer_module(
481
+ hidden_states,
482
+ attention_mask,
483
+ layer_head_mask,
484
+ encoder_hidden_states,
485
+ encoder_attention_mask,
486
+ past_key_value,
487
+ output_attentions,
488
+ mode=mode,
489
+ )
490
+
491
+ hidden_states = layer_outputs[0]
492
+ if use_cache:
493
+ next_decoder_cache += (layer_outputs[-1],)
494
+ if output_attentions:
495
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
496
+
497
+ if output_hidden_states:
498
+ all_hidden_states = all_hidden_states + (hidden_states,)
499
+
500
+ if not return_dict:
501
+ return tuple(
502
+ v
503
+ for v in [
504
+ hidden_states,
505
+ next_decoder_cache,
506
+ all_hidden_states,
507
+ all_self_attentions,
508
+ all_cross_attentions,
509
+ ]
510
+ if v is not None
511
+ )
512
+ return BaseModelOutputWithPastAndCrossAttentions(
513
+ last_hidden_state=hidden_states,
514
+ past_key_values=next_decoder_cache,
515
+ hidden_states=all_hidden_states,
516
+ attentions=all_self_attentions,
517
+ cross_attentions=all_cross_attentions,
518
+ )
519
+
520
+
521
+ class BertPooler(nn.Module):
522
+ def __init__(self, config):
523
+ super().__init__()
524
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
525
+ self.activation = nn.Tanh()
526
+
527
+ def forward(self, hidden_states):
528
+ # We "pool" the model by simply taking the hidden state corresponding
529
+ # to the first token.
530
+ first_token_tensor = hidden_states[:, 0]
531
+ pooled_output = self.dense(first_token_tensor)
532
+ pooled_output = self.activation(pooled_output)
533
+ return pooled_output
534
+
535
+
536
+ class BertPredictionHeadTransform(nn.Module):
537
+ def __init__(self, config):
538
+ super().__init__()
539
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
540
+ if isinstance(config.hidden_act, str):
541
+ self.transform_act_fn = ACT2FN[config.hidden_act]
542
+ else:
543
+ self.transform_act_fn = config.hidden_act
544
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
545
+
546
+ def forward(self, hidden_states):
547
+ hidden_states = self.dense(hidden_states)
548
+ hidden_states = self.transform_act_fn(hidden_states)
549
+ hidden_states = self.LayerNorm(hidden_states)
550
+ return hidden_states
551
+
552
+
553
+ class BertLMPredictionHead(nn.Module):
554
+ def __init__(self, config):
555
+ super().__init__()
556
+ self.transform = BertPredictionHeadTransform(config)
557
+
558
+ # The output weights are the same as the input embeddings, but there is
559
+ # an output-only bias for each token.
560
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
561
+
562
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
563
+
564
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
565
+ self.decoder.bias = self.bias
566
+
567
+ def forward(self, hidden_states):
568
+ hidden_states = self.transform(hidden_states)
569
+ hidden_states = self.decoder(hidden_states)
570
+ return hidden_states
571
+
572
+
573
+ class BertOnlyMLMHead(nn.Module):
574
+ def __init__(self, config):
575
+ super().__init__()
576
+ self.predictions = BertLMPredictionHead(config)
577
+
578
+ def forward(self, sequence_output):
579
+ prediction_scores = self.predictions(sequence_output)
580
+ return prediction_scores
581
+
582
+
583
+ class BertPreTrainedModel(PreTrainedModel):
584
+ """
585
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
586
+ models.
587
+ """
588
+
589
+ config_class = BertConfig
590
+ base_model_prefix = "bert"
591
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
592
+
593
+ def _init_weights(self, module):
594
+ """ Initialize the weights """
595
+ if isinstance(module, (nn.Linear, nn.Embedding)):
596
+ # Slightly different from the TF version which uses truncated_normal for initialization
597
+ # cf https://github.com/pytorch/pytorch/pull/5617
598
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
599
+ elif isinstance(module, nn.LayerNorm):
600
+ module.bias.data.zero_()
601
+ module.weight.data.fill_(1.0)
602
+ if isinstance(module, nn.Linear) and module.bias is not None:
603
+ module.bias.data.zero_()
604
+
605
+
606
+ class BertModel(BertPreTrainedModel):
607
+ """
608
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
609
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
610
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
611
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
612
+ To be used as a decoder, the model needs the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
613
+ input to the forward pass.
614
+ """
615
+
616
+ def __init__(self, config, add_pooling_layer=True):
617
+ super().__init__(config)
618
+ self.config = config
619
+
620
+ self.embeddings = BertEmbeddings(config)
621
+
622
+ self.encoder = BertEncoder(config)
623
+
624
+ self.pooler = BertPooler(config) if add_pooling_layer else None
625
+
626
+ self.init_weights()
627
+
628
+
629
+ def get_input_embeddings(self):
630
+ return self.embeddings.word_embeddings
631
+
632
+ def set_input_embeddings(self, value):
633
+ self.embeddings.word_embeddings = value
634
+
635
+ def _prune_heads(self, heads_to_prune):
636
+ """
637
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
638
+ class PreTrainedModel
639
+ """
640
+ for layer, heads in heads_to_prune.items():
641
+ self.encoder.layer[layer].attention.prune_heads(heads)
642
+
643
+
644
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
645
+ """
646
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
647
+
648
+ Arguments:
649
+ attention_mask (:obj:`torch.Tensor`):
650
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
651
+ input_shape (:obj:`Tuple[int]`):
652
+ The shape of the input to the model.
653
+ device: (:obj:`torch.device`):
654
+ The device of the input to the model.
655
+
656
+ Returns:
657
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
658
+ """
659
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
660
+ # ourselves in which case we just need to make it broadcastable to all heads.
661
+ if attention_mask.dim() == 3:
662
+ extended_attention_mask = attention_mask[:, None, :, :]
663
+ elif attention_mask.dim() == 2:
664
+ # Provided a padding mask of dimensions [batch_size, seq_length]
665
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
666
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
667
+ if is_decoder:
668
+ batch_size, seq_length = input_shape
669
+
670
+ seq_ids = torch.arange(seq_length, device=device)
671
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
672
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
673
+ # causal and attention masks must have same type with pytorch version < 1.3
674
+ causal_mask = causal_mask.to(attention_mask.dtype)
675
+
676
+ if causal_mask.shape[1] < attention_mask.shape[1]:
677
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
678
+ causal_mask = torch.cat(
679
+ [
680
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
681
+ causal_mask,
682
+ ],
683
+ axis=-1,
684
+ )
685
+
686
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
687
+ else:
688
+ extended_attention_mask = attention_mask[:, None, None, :]
689
+ else:
690
+ raise ValueError(
691
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
692
+ input_shape, attention_mask.shape
693
+ )
694
+ )
695
+
696
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
697
+ # masked positions, this operation will create a tensor which is 0.0 for
698
+ # positions we want to attend and -10000.0 for masked positions.
699
+ # Since we are adding it to the raw scores before the softmax, this is
700
+ # effectively the same as removing these entirely.
701
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
702
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
703
+ return extended_attention_mask
704
+
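# Note: the conversion above maps mask value 1.0 -> 0.0 and 0.0 -> -10000.0, so an
# illustrative padding mask [1, 1, 0] becomes the additive mask [0, 0, -10000], which
# effectively removes the padded position from the softmax. In decoder mode the causal
# (lower-triangular) mask is multiplied in first, so each token attends only to itself,
# earlier tokens, and any prefix supplied via past_key_values.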
705
+ def forward(
706
+ self,
707
+ input_ids=None,
708
+ attention_mask=None,
709
+ position_ids=None,
710
+ head_mask=None,
711
+ inputs_embeds=None,
712
+ encoder_embeds=None,
713
+ encoder_hidden_states=None,
714
+ encoder_attention_mask=None,
715
+ past_key_values=None,
716
+ use_cache=None,
717
+ output_attentions=None,
718
+ output_hidden_states=None,
719
+ return_dict=None,
720
+ is_decoder=False,
721
+ mode='multimodal',
722
+ ):
723
+ r"""
724
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
725
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
726
+ the model is configured as a decoder.
727
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
728
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
729
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
730
+ - 1 for tokens that are **not masked**,
731
+ - 0 for tokens that are **masked**.
732
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
733
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
734
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
735
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
736
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
737
+ use_cache (:obj:`bool`, `optional`):
738
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
739
+ decoding (see :obj:`past_key_values`).
740
+ """
741
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
742
+ output_hidden_states = (
743
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
744
+ )
745
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
746
+
747
+ if is_decoder:
748
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
749
+ else:
750
+ use_cache = False
751
+
752
+ if input_ids is not None and inputs_embeds is not None:
753
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
754
+ elif input_ids is not None:
755
+ input_shape = input_ids.size()
756
+ batch_size, seq_length = input_shape
757
+ device = input_ids.device
758
+ elif inputs_embeds is not None:
759
+ input_shape = inputs_embeds.size()[:-1]
760
+ batch_size, seq_length = input_shape
761
+ device = inputs_embeds.device
762
+ elif encoder_embeds is not None:
763
+ input_shape = encoder_embeds.size()[:-1]
764
+ batch_size, seq_length = input_shape
765
+ device = encoder_embeds.device
766
+ else:
767
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
768
+
769
+ # past_key_values_length
770
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
771
+
772
+ if attention_mask is None:
773
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
774
+
775
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
776
+ # ourselves in which case we just need to make it broadcastable to all heads.
777
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
778
+ device, is_decoder)
779
+
780
+ # If a 2D or 3D attention mask is provided for the cross-attention
781
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
782
+ if encoder_hidden_states is not None:
783
+ if type(encoder_hidden_states) == list:
784
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
785
+ else:
786
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
787
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
788
+
789
+ if type(encoder_attention_mask) == list:
790
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
791
+ elif encoder_attention_mask is None:
792
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
793
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
794
+ else:
795
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
796
+ else:
797
+ encoder_extended_attention_mask = None
798
+
799
+ # Prepare head mask if needed
800
+ # 1.0 in head_mask indicates we keep the head
801
+ # attention_probs has shape bsz x n_heads x N x N
802
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
803
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
804
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
805
+
806
+ if encoder_embeds is None:
807
+ embedding_output = self.embeddings(
808
+ input_ids=input_ids,
809
+ position_ids=position_ids,
810
+ inputs_embeds=inputs_embeds,
811
+ past_key_values_length=past_key_values_length,
812
+ )
813
+ else:
814
+ embedding_output = encoder_embeds
815
+
816
+ encoder_outputs = self.encoder(
817
+ embedding_output,
818
+ attention_mask=extended_attention_mask,
819
+ head_mask=head_mask,
820
+ encoder_hidden_states=encoder_hidden_states,
821
+ encoder_attention_mask=encoder_extended_attention_mask,
822
+ past_key_values=past_key_values,
823
+ use_cache=use_cache,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ mode=mode,
828
+ )
829
+ sequence_output = encoder_outputs[0]
830
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
831
+
832
+ if not return_dict:
833
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
834
+
835
+ return BaseModelOutputWithPoolingAndCrossAttentions(
836
+ last_hidden_state=sequence_output,
837
+ pooler_output=pooled_output,
838
+ past_key_values=encoder_outputs.past_key_values,
839
+ hidden_states=encoder_outputs.hidden_states,
840
+ attentions=encoder_outputs.attentions,
841
+ cross_attentions=encoder_outputs.cross_attentions,
842
+ )
843
+
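A minimal sketch of how this dual-stream encoder can be exercised, assuming a config that carries `encoder_width` and `add_cross_attention=True`; the two image feature tensors are random placeholders standing in for ViT outputs of an NLVR2 image pair.

import torch
from transformers import BertConfig, BertTokenizer
from models.nlvr_encoder import BertModel

config = BertConfig.from_pretrained('bert-base-uncased')
config.add_cross_attention = True
config.encoder_width = 768

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoder = BertModel(config, add_pooling_layer=False)   # randomly initialised

text = tokenizer('the left image contains twice the number of dogs as the right image',
                 return_tensors='pt')
img0 = torch.randn(1, 197, 768)                  # placeholder features, left image
img1 = torch.randn(1, 197, 768)                  # placeholder features, right image
atts = torch.ones((1, 197), dtype=torch.long)

out = encoder(text.input_ids,
              attention_mask=text.attention_mask,
              encoder_hidden_states=[img0, img1],      # list -> twin cross-attention path
              encoder_attention_mask=[atts, atts],
              return_dict=True)
cls_state = out.last_hidden_state[:, 0]   # a classifier head would typically consume this
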
models/vit.py ADDED
@@ -0,0 +1,305 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
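# Note: with register_hook=True the softmaxed attention map and, once backward() has run,
# its gradient are cached on the module (see get_attention_map / get_attn_gradients above);
# callers can use these, for example, to visualise which image patches drive a prediction.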
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
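# Usage note (illustrative, random weights): the forward pass returns one embedding per
# patch plus the class token, e.g.
#   >>> vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
#   >>> feats = vit(torch.randn(1, 3, 224, 224))
#   >>> feats.shape
#   torch.Size([1, 197, 768])   # 14*14 patches + 1 cls token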
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+                 for j, block in enumerate(stage.blocks):
+                     bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+                     for r in range(3):
+                         getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+                         getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+                         getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+                     if block.downsample is not None:
+                         block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+                         block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+                         block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+         embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+     else:
+         embed_conv_w = adapt_input_conv(
+             model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+     model.patch_embed.proj.weight.copy_(embed_conv_w)
+     model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+     model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+     pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+     if pos_embed_w.shape != model.pos_embed.shape:
+         pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
+             pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+     model.pos_embed.copy_(pos_embed_w)
+     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+     # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+     #     model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+     #     model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+     # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+     #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+     #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+     for i, block in enumerate(model.blocks.children()):
+         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+         block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+         block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+         block.attn.qkv.weight.copy_(torch.cat([
+             _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+         block.attn.qkv.bias.copy_(torch.cat([
+             _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+         block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+         block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+         for r in range(2):
+             getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+         block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+         block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
+     # interpolate position embedding
+     embedding_size = pos_embed_checkpoint.shape[-1]
+     num_patches = visual_encoder.patch_embed.num_patches
+     num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+     # height (== width) for the checkpoint position embedding
+     orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+     # height (== width) for the new position embedding
+     new_size = int(num_patches ** 0.5)
+
+     if orig_size != new_size:
+         # class_token and dist_token are kept unchanged
+         extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+         # only the position tokens are interpolated
+         pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+         pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+         pos_tokens = torch.nn.functional.interpolate(
+             pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+         pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+         new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+         print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))
+
+         return new_pos_embed
+     else:
+         return pos_embed_checkpoint
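Not part of the diff: a minimal usage sketch for interpolate_pos_embed when a checkpoint was trained at a different image resolution than the model being built. The checkpoint path, the `model` object and the surrounding loading code are illustrative assumptions, not taken from this commit.

import torch
# assumptions: `model` is a BLIP model whose `visual_encoder` is the ViT defined above,
# and 'checkpoint.pth' stores the weights under the 'model' key
state_dict = torch.load('checkpoint.pth', map_location='cpu')['model']
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
    state_dict['visual_encoder.pos_embed'], model.visual_encoder)
msg = model.load_state_dict(state_dict, strict=False)  # strict=False: task heads may differ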
objects.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94238b6187c0ea1f97acc1fdaec106bd3dee12bb0f81b23dc8b07e9173514985
+ size 349437266
predict.py ADDED
@@ -0,0 +1,98 @@
+ """
+ Download the weights in ./checkpoints beforehand for fast inference
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
+ """
+
+ from pathlib import Path
+
+ from PIL import Image
+ import torch
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+ import cog
+
+ from models.blip import blip_decoder
+ from models.blip_vqa import blip_vqa
+ from models.blip_itm import blip_itm
+
+
+ class Predictor(cog.Predictor):
+     def setup(self):
+         self.device = "cuda:0"
+
+         self.models = {
+             'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth',
+                                              image_size=384, vit='base'),
+             'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth',
+                                                   image_size=480, vit='base'),
+             'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
+                                             image_size=384, vit='base')
+         }
+
+     @cog.input(
+         "image",
+         type=Path,
+         help="input image",
+     )
+     @cog.input(
+         "task",
+         type=str,
+         default='image_captioning',
+         options=['image_captioning', 'visual_question_answering', 'image_text_matching'],
+         help="Choose a task.",
+     )
+     @cog.input(
+         "question",
+         type=str,
+         default=None,
+         help="Type question for the input image for visual question answering task.",
+     )
+     @cog.input(
+         "caption",
+         type=str,
+         default=None,
+         help="Type caption for the input image for image text matching task.",
+     )
+     def predict(self, image, task, question, caption):
+         if task == 'visual_question_answering':
+             assert question is not None, 'Please type a question for visual question answering task.'
+         if task == 'image_text_matching':
+             assert caption is not None, 'Please type a caption for image text matching task.'
+
+         im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device)
+         model = self.models[task]
+         model.eval()
+         model = model.to(self.device)
+
+         if task == 'image_captioning':
+             with torch.no_grad():
+                 caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5)
+                 return 'Caption: ' + caption[0]
+
+         if task == 'visual_question_answering':
+             with torch.no_grad():
+                 answer = model(im, question, train=False, inference='generate')
+                 return 'Answer: ' + answer[0]
+
+         # image_text_matching
+         itm_output = model(im, caption, match_head='itm')
+         itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
+         itc_score = model(im, caption, match_head='itc')
+         return f'The image and text are matched with a probability of {itm_score.item():.4f}.\n' \
+                f'The image feature and text feature have a cosine similarity of {itc_score.item():.4f}.'
+
+
+ def load_image(image, image_size, device):
+     raw_image = Image.open(str(image)).convert('RGB')
+
+     w, h = raw_image.size
+
+     transform = transforms.Compose([
+         transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+         transforms.ToTensor(),
+         transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+     ])
+     image = transform(raw_image).unsqueeze(0).to(device)
+     return image
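Not part of the diff: a hedged sketch of driving the Predictor above directly from Python, outside the Cog runtime. It assumes the checkpoints named in the module docstring have been downloaded to ./checkpoints, that a CUDA device is available, and that this file is importable as predict.py; sukhoi_su57.jpg is the sample image added in this commit.

from pathlib import Path
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads the captioning, VQA and ITM models onto cuda:0
print(predictor.predict(image=Path('sukhoi_su57.jpg'),
                        task='image_captioning',
                        question=None,
                        caption=None))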
pretrain.py ADDED
@@ -0,0 +1,170 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.backends.cudnn as cudnn
20
+ import torch.distributed as dist
21
+
22
+ from models.blip_pretrain import blip_pretrain
23
+ import utils
24
+ from utils import warmup_lr_schedule, step_lr_schedule
25
+ from data import create_dataset, create_sampler, create_loader
26
+
27
+ def train(model, data_loader, optimizer, epoch, device, config):
28
+ # train
29
+ model.train()
30
+
31
+ metric_logger = utils.MetricLogger(delimiter=" ")
32
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
33
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
34
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
35
+ metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
36
+
37
+ header = 'Train Epoch: [{}]'.format(epoch)
38
+ print_freq = 50
39
+
40
+ if config['laion_path']:
41
+ data_loader.dataset.reload_laion(epoch)
42
+
43
+ data_loader.sampler.set_epoch(epoch)
44
+
45
+ for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
46
+
47
+ if epoch==0:
48
+ warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
49
+
50
+ optimizer.zero_grad()
51
+
52
+ image = image.to(device,non_blocking=True)
53
+
54
+ # ramp up alpha in the first 2 epochs
55
+ alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader)))
56
+
57
+ loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha)
58
+ loss = loss_ita + loss_itm + loss_lm
59
+
60
+ loss.backward()
61
+ optimizer.step()
62
+
63
+ metric_logger.update(loss_ita=loss_ita.item())
64
+ metric_logger.update(loss_itm=loss_itm.item())
65
+ metric_logger.update(loss_lm=loss_lm.item())
66
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
67
+
68
+
69
+ # gather the stats from all processes
70
+ metric_logger.synchronize_between_processes()
71
+ print("Averaged stats:", metric_logger.global_avg())
72
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
73
+
74
+
75
+ def main(args, config):
76
+ utils.init_distributed_mode(args)
77
+
78
+ device = torch.device(args.device)
79
+
80
+ # fix the seed for reproducibility
81
+ seed = args.seed + utils.get_rank()
82
+ torch.manual_seed(seed)
83
+ np.random.seed(seed)
84
+ random.seed(seed)
85
+ cudnn.benchmark = True
86
+
87
+ #### Dataset ####
88
+ print("Creating dataset")
89
+ datasets = [create_dataset('pretrain', config, min_scale=0.2)]
90
+ print('number of training samples: %d'%len(datasets[0]))
91
+
92
+ num_tasks = utils.get_world_size()
93
+ global_rank = utils.get_rank()
94
+ samplers = create_sampler(datasets, [True], num_tasks, global_rank)
95
+
96
+ data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
97
+
98
+ #### Model ####
99
+ print("Creating model")
100
+ model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
101
+ vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'])
102
+
103
+ model = model.to(device)
104
+
105
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
106
+
107
+ start_epoch = 0
108
+ if args.checkpoint:
109
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
110
+ state_dict = checkpoint['model']
111
+ model.load_state_dict(state_dict)
112
+
113
+ optimizer.load_state_dict(checkpoint['optimizer'])
114
+ start_epoch = checkpoint['epoch']+1
115
+ print('resume checkpoint from %s'%args.checkpoint)
116
+
117
+ model_without_ddp = model
118
+ if args.distributed:
119
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
120
+ model_without_ddp = model.module
121
+
122
+ print("Start training")
123
+ start_time = time.time()
124
+ for epoch in range(start_epoch, config['max_epoch']):
125
+
126
+ step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate'])
127
+
128
+ train_stats = train(model, data_loader, optimizer, epoch, device, config)
129
+ if utils.is_main_process():
130
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
131
+ 'epoch': epoch,
132
+ }
133
+ save_obj = {
134
+ 'model': model_without_ddp.state_dict(),
135
+ 'optimizer': optimizer.state_dict(),
136
+ 'config': config,
137
+ 'epoch': epoch,
138
+ }
139
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
140
+
141
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
142
+ f.write(json.dumps(log_stats) + "\n")
143
+
144
+ dist.barrier()
145
+
146
+ total_time = time.time() - start_time
147
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
148
+ print('Training time {}'.format(total_time_str))
149
+
150
+
151
+ if __name__ == '__main__':
152
+ parser = argparse.ArgumentParser()
153
+ parser.add_argument('--config', default='./configs/pretrain.yaml')
154
+ parser.add_argument('--output_dir', default='output/Pretrain')
155
+ parser.add_argument('--checkpoint', default='')
156
+ parser.add_argument('--evaluate', action='store_true')
157
+ parser.add_argument('--device', default='cuda')
158
+ parser.add_argument('--seed', default=42, type=int)
159
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
160
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
161
+ parser.add_argument('--distributed', default=True, type=bool)
162
+ args = parser.parse_args()
163
+
164
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
165
+
166
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
167
+
168
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
169
+
170
+ main(args, config)
read_json.py ADDED
@@ -0,0 +1,3 @@
+ import pandas as pd
+ df = pd.read_json('/workspace/BLIP/BLIP_main/objects.json')
+ print(df.head())
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ timm==0.4.12
+ transformers==4.15.0
+ fairscale==0.4.4
+ pycocoevalcap
+ ruamel_yaml
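Not part of the diff: an optional sanity check after installing the requirements, confirming that the pinned versions are the ones actually imported. It relies only on the standard __version__ attributes of these packages.

import fairscale
import timm
import transformers

assert timm.__version__ == '0.4.12', timm.__version__
assert transformers.__version__ == '4.15.0', transformers.__version__
assert fairscale.__version__ == '0.4.4', fairscale.__version__
print('pinned versions OK')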
sukhoi_su57.jpg ADDED
train_caption.py ADDED
@@ -0,0 +1,203 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.backends.cudnn as cudnn
20
+ import torch.distributed as dist
21
+
22
+ from models.blip import blip_decoder
23
+ import utils
24
+ from utils import cosine_lr_schedule
25
+ from data import create_dataset, create_sampler, create_loader
26
+ from data.utils import save_result, coco_caption_eval
27
+
28
+ def train(model, data_loader, optimizer, epoch, device):
29
+ # train
30
+ model.train()
31
+
32
+ metric_logger = utils.MetricLogger(delimiter=" ")
33
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
34
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
35
+ header = 'Train Caption Epoch: [{}]'.format(epoch)
36
+ print_freq = 50
37
+
38
+ for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
39
+ image = image.to(device)
40
+
41
+ loss = model(image, caption)
42
+
43
+ optimizer.zero_grad()
44
+ loss.backward()
45
+ optimizer.step()
46
+
47
+ metric_logger.update(loss=loss.item())
48
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
49
+
50
+ # gather the stats from all processes
51
+ metric_logger.synchronize_between_processes()
52
+ print("Averaged stats:", metric_logger.global_avg())
53
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
54
+
55
+
56
+ @torch.no_grad()
57
+ def evaluate(model, data_loader, device, config):
58
+ # evaluate
59
+ model.eval()
60
+
61
+ metric_logger = utils.MetricLogger(delimiter=" ")
62
+ header = 'Caption generation:'
63
+ print_freq = 10
64
+
65
+ result = []
66
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
67
+
68
+ image = image.to(device)
69
+
70
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
71
+ min_length=config['min_length'])
72
+
73
+ for caption, img_id in zip(captions, image_id):
74
+ result.append({"image_id": img_id.item(), "caption": caption})
75
+
76
+ return result
77
+
78
+
79
+ def main(args, config):
80
+ utils.init_distributed_mode(args)
81
+
82
+ device = torch.device(args.device)
83
+
84
+ # fix the seed for reproducibility
85
+ seed = args.seed + utils.get_rank()
86
+ torch.manual_seed(seed)
87
+ np.random.seed(seed)
88
+ random.seed(seed)
89
+ cudnn.benchmark = True
90
+
91
+ #### Dataset ####
92
+ print("Creating captioning dataset")
93
+ train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)
94
+
95
+ if args.distributed:
96
+ num_tasks = utils.get_world_size()
97
+ global_rank = utils.get_rank()
98
+ samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)
99
+ else:
100
+ samplers = [None, None, None]
101
+
102
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
103
+ batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
104
+ is_trains=[True, False, False], collate_fns=[None,None,None])
105
+
106
+ #### Model ####
107
+ print("Creating model")
108
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
109
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
110
+ prompt=config['prompt'])
111
+
112
+ model = model.to(device)
113
+
114
+ model_without_ddp = model
115
+ if args.distributed:
116
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
117
+ model_without_ddp = model.module
118
+
119
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
120
+
121
+ best = 0
122
+ best_epoch = 0
123
+
124
+ print("Start training")
125
+ start_time = time.time()
126
+ for epoch in range(0, config['max_epoch']):
127
+ if not args.evaluate:
128
+ if args.distributed:
129
+ train_loader.sampler.set_epoch(epoch)
130
+
131
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
132
+
133
+ train_stats = train(model, train_loader, optimizer, epoch, device)
134
+
135
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
136
+ val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')
137
+
138
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
139
+ test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')
140
+
141
+ if utils.is_main_process():
142
+ coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
143
+ coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
144
+
145
+ if args.evaluate:
146
+ log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
147
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
148
+ }
149
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
150
+ f.write(json.dumps(log_stats) + "\n")
151
+ else:
152
+ save_obj = {
153
+ 'model': model_without_ddp.state_dict(),
154
+ 'optimizer': optimizer.state_dict(),
155
+ 'config': config,
156
+ 'epoch': epoch,
157
+ }
158
+
159
+ if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
160
+ best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
161
+ best_epoch = epoch
162
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
163
+
164
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
165
+ **{f'val_{k}': v for k, v in coco_val.eval.items()},
166
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
167
+ 'epoch': epoch,
168
+ 'best_epoch': best_epoch,
169
+ }
170
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
171
+ f.write(json.dumps(log_stats) + "\n")
172
+
173
+ if args.evaluate:
174
+ break
175
+ dist.barrier()
176
+
177
+ total_time = time.time() - start_time
178
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
179
+ print('Training time {}'.format(total_time_str))
180
+
181
+
182
+ if __name__ == '__main__':
183
+ parser = argparse.ArgumentParser()
184
+ parser.add_argument('--config', default='./configs/caption_coco.yaml')
185
+ parser.add_argument('--output_dir', default='output/Caption_coco')
186
+ parser.add_argument('--evaluate', action='store_true')
187
+ parser.add_argument('--device', default='cuda')
188
+ parser.add_argument('--seed', default=42, type=int)
189
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
190
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
191
+ parser.add_argument('--distributed', default=True, type=bool)
192
+ args = parser.parse_args()
193
+
194
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
195
+
196
+ args.result_dir = os.path.join(args.output_dir, 'result')
197
+
198
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
199
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
200
+
201
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
202
+
203
+ main(args, config)
train_nlvr.py ADDED
@@ -0,0 +1,208 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ from pathlib import Path
16
+ import json
17
+
18
+ import torch
19
+ import torch.backends.cudnn as cudnn
20
+ import torch.distributed as dist
21
+
22
+ from models.blip_nlvr import blip_nlvr
23
+
24
+ import utils
25
+ from utils import cosine_lr_schedule
26
+ from data import create_dataset, create_sampler, create_loader
27
+
28
+ def train(model, data_loader, optimizer, epoch, device, config):
29
+ # train
30
+ model.train()
31
+
32
+ metric_logger = utils.MetricLogger(delimiter=" ")
33
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
34
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
35
+
36
+ header = 'Train Epoch: [{}]'.format(epoch)
37
+ print_freq = 50
38
+ step_size = 10
39
+
40
+ for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
41
+
42
+ images = torch.cat([image0, image1], dim=0)
43
+ images, targets = images.to(device), targets.to(device)
44
+
45
+ loss = model(images, text, targets=targets, train=True)
46
+
47
+ optimizer.zero_grad()
48
+ loss.backward()
49
+ optimizer.step()
50
+
51
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
52
+ metric_logger.update(loss=loss.item())
53
+
54
+ # gather the stats from all processes
55
+ metric_logger.synchronize_between_processes()
56
+ print("Averaged stats:", metric_logger.global_avg())
57
+ return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
58
+
59
+
60
+ @torch.no_grad()
61
+ def evaluate(model, data_loader, device, config):
62
+ # test
63
+ model.eval()
64
+
65
+ metric_logger = utils.MetricLogger(delimiter=" ")
66
+
67
+ header = 'Evaluation:'
68
+ print_freq = 50
69
+
70
+ for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
71
+ images = torch.cat([image0, image1], dim=0)
72
+ images, targets = images.to(device), targets.to(device)
73
+
74
+ prediction = model(images, text, targets=targets, train=False)
75
+
76
+ _, pred_class = prediction.max(1)
77
+ accuracy = (targets==pred_class).sum() / targets.size(0)
78
+
79
+ metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))
80
+
81
+ # gather the stats from all processes
82
+ metric_logger.synchronize_between_processes()
83
+
84
+ print("Averaged stats:", metric_logger.global_avg())
85
+ return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
86
+
87
+
88
+
89
+ def main(args, config):
90
+ utils.init_distributed_mode(args)
91
+
92
+ device = torch.device(args.device)
93
+
94
+ # fix the seed for reproducibility
95
+ seed = args.seed + utils.get_rank()
96
+ torch.manual_seed(seed)
97
+ np.random.seed(seed)
98
+ random.seed(seed)
99
+ cudnn.benchmark = True
100
+
101
+ #### Dataset ####
102
+ print("Creating dataset")
103
+ datasets = create_dataset('nlvr', config)
104
+
105
+ if args.distributed:
106
+ num_tasks = utils.get_world_size()
107
+ global_rank = utils.get_rank()
108
+ samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank)
109
+ else:
110
+ samplers = [None, None, None]
111
+
112
+ batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']]
113
+ train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size,
114
+ num_workers=[4,4,4],is_trains=[True,False,False],
115
+ collate_fns=[None,None,None])
116
+
117
+ #### Model ####
118
+ print("Creating model")
119
+ model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
120
+ vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
121
+
122
+ model = model.to(device)
123
+
124
+ model_without_ddp = model
125
+ if args.distributed:
126
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
127
+ model_without_ddp = model.module
128
+
129
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
130
+
131
+ print("Start training")
132
+ start_time = time.time()
133
+ best = 0
134
+ best_epoch = 0
135
+
136
+ for epoch in range(0, config['max_epoch']):
137
+ if not args.evaluate:
138
+ if args.distributed:
139
+ train_loader.sampler.set_epoch(epoch)
140
+
141
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
142
+
143
+ train_stats = train(model, train_loader, optimizer, epoch, device, config)
144
+
145
+ val_stats = evaluate(model, val_loader, device, config)
146
+ test_stats = evaluate(model, test_loader, device, config)
147
+
148
+ if utils.is_main_process():
149
+ if args.evaluate:
150
+ log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
151
+ **{f'test_{k}': v for k, v in test_stats.items()},
152
+ }
153
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
154
+ f.write(json.dumps(log_stats) + "\n")
155
+
156
+ else:
157
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
158
+ **{f'val_{k}': v for k, v in val_stats.items()},
159
+ **{f'test_{k}': v for k, v in test_stats.items()},
160
+ 'epoch': epoch,
161
+ }
162
+
163
+ if float(val_stats['acc'])>best:
164
+ save_obj = {
165
+ 'model': model_without_ddp.state_dict(),
166
+ 'optimizer': optimizer.state_dict(),
167
+ 'config': config,
168
+ 'epoch': epoch,
169
+ }
170
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
171
+ best = float(val_stats['acc'])
172
+ best_epoch = epoch
173
+
174
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
175
+ f.write(json.dumps(log_stats) + "\n")
176
+ if args.evaluate:
177
+ break
178
+
179
+ dist.barrier()
180
+
181
+ if utils.is_main_process():
182
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
183
+ f.write("best epoch: %d"%best_epoch)
184
+
185
+ total_time = time.time() - start_time
186
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
187
+ print('Training time {}'.format(total_time_str))
188
+
189
+
190
+ if __name__ == '__main__':
191
+ parser = argparse.ArgumentParser()
192
+ parser.add_argument('--config', default='./configs/nlvr.yaml')
193
+ parser.add_argument('--output_dir', default='output/NLVR')
194
+ parser.add_argument('--evaluate', action='store_true')
195
+ parser.add_argument('--device', default='cuda')
196
+ parser.add_argument('--seed', default=42, type=int)
197
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
198
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
199
+ parser.add_argument('--distributed', default=True, type=bool)
200
+ args = parser.parse_args()
201
+
202
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
203
+
204
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
205
+
206
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
207
+
208
+ main(args, config)
train_retrieval.py ADDED
@@ -0,0 +1,343 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torch.backends.cudnn as cudnn
21
+ import torch.distributed as dist
22
+
23
+ from models.blip_retrieval import blip_retrieval
24
+ import utils
25
+ from utils import cosine_lr_schedule
26
+ from data import create_dataset, create_sampler, create_loader
27
+
28
+
29
+ def train(model, data_loader, optimizer, epoch, device, config):
30
+ # train
31
+ model.train()
32
+
33
+ metric_logger = utils.MetricLogger(delimiter=" ")
34
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
35
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
36
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
37
+ header = 'Train Epoch: [{}]'.format(epoch)
38
+ print_freq = 50
39
+
40
+ for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
41
+ image = image.to(device,non_blocking=True)
42
+ idx = idx.to(device,non_blocking=True)
43
+
44
+ if epoch>0:
45
+ alpha = config['alpha']
46
+ else:
47
+ alpha = config['alpha']*min(1,i/len(data_loader))
48
+
49
+ loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
50
+ loss = loss_ita + loss_itm
51
+
52
+ optimizer.zero_grad()
53
+ loss.backward()
54
+ optimizer.step()
55
+
56
+ metric_logger.update(loss_itm=loss_itm.item())
57
+ metric_logger.update(loss_ita=loss_ita.item())
58
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
59
+
60
+ # gather the stats from all processes
61
+ metric_logger.synchronize_between_processes()
62
+ print("Averaged stats:", metric_logger.global_avg())
63
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
64
+
65
+
66
+ @torch.no_grad()
67
+ def evaluation(model, data_loader, device, config):
68
+ # test
69
+ model.eval()
70
+
71
+ metric_logger = utils.MetricLogger(delimiter=" ")
72
+ header = 'Evaluation:'
73
+
74
+ print('Computing features for evaluation...')
75
+ start_time = time.time()
76
+
77
+ texts = data_loader.dataset.text
78
+ num_text = len(texts)
79
+ text_bs = 256
80
+ text_ids = []
81
+ text_embeds = []
82
+ text_atts = []
83
+ for i in range(0, num_text, text_bs):
84
+ text = texts[i: min(num_text, i+text_bs)]
85
+ text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
86
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
87
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
88
+ text_embeds.append(text_embed)
89
+ text_ids.append(text_input.input_ids)
90
+ text_atts.append(text_input.attention_mask)
91
+
92
+ text_embeds = torch.cat(text_embeds,dim=0)
93
+ text_ids = torch.cat(text_ids,dim=0)
94
+ text_atts = torch.cat(text_atts,dim=0)
95
+ text_ids[:,0] = model.tokenizer.enc_token_id
96
+
97
+ image_feats = []
98
+ image_embeds = []
99
+ for image, img_id in data_loader:
100
+ image = image.to(device)
101
+ image_feat = model.visual_encoder(image)
102
+ image_embed = model.vision_proj(image_feat[:,0,:])
103
+ image_embed = F.normalize(image_embed,dim=-1)
104
+
105
+ image_feats.append(image_feat.cpu())
106
+ image_embeds.append(image_embed)
107
+
108
+ image_feats = torch.cat(image_feats,dim=0)
109
+ image_embeds = torch.cat(image_embeds,dim=0)
110
+
111
+ sims_matrix = image_embeds @ text_embeds.t()
112
+ score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)
113
+
114
+ num_tasks = utils.get_world_size()
115
+ rank = utils.get_rank()
116
+ step = sims_matrix.size(0)//num_tasks + 1
117
+ start = rank*step
118
+ end = min(sims_matrix.size(0),start+step)
119
+
120
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
121
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
122
+
123
+ encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device)
124
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
125
+ output = model.text_encoder(text_ids[topk_idx],
126
+ attention_mask = text_atts[topk_idx],
127
+ encoder_hidden_states = encoder_output,
128
+ encoder_attention_mask = encoder_att,
129
+ return_dict = True,
130
+ )
131
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
132
+ score_matrix_i2t[start+i,topk_idx] = score + topk_sim
133
+
134
+ sims_matrix = sims_matrix.t()
135
+ score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device)
136
+
137
+ step = sims_matrix.size(0)//num_tasks + 1
138
+ start = rank*step
139
+ end = min(sims_matrix.size(0),start+step)
140
+
141
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
142
+
143
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
144
+ encoder_output = image_feats[topk_idx].to(device)
145
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
146
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
147
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
148
+ encoder_hidden_states = encoder_output,
149
+ encoder_attention_mask = encoder_att,
150
+ return_dict = True,
151
+ )
152
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
153
+ score_matrix_t2i[start+i,topk_idx] = score + topk_sim
154
+
155
+ if args.distributed:
156
+ dist.barrier()
157
+ torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
158
+ torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
159
+
160
+ total_time = time.time() - start_time
161
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
162
+ print('Evaluation time {}'.format(total_time_str))
163
+
164
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
165
+
166
+
167
+
168
+ @torch.no_grad()
169
+ def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
170
+
171
+ #Images->Text
172
+ ranks = np.zeros(scores_i2t.shape[0])
173
+ for index,score in enumerate(scores_i2t):
174
+ inds = np.argsort(score)[::-1]
175
+ # Score
176
+ rank = 1e20
177
+ for i in img2txt[index]:
178
+ tmp = np.where(inds == i)[0][0]
179
+ if tmp < rank:
180
+ rank = tmp
181
+ ranks[index] = rank
182
+
183
+ # Compute metrics
184
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
185
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
186
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
187
+
188
+ #Text->Images
189
+ ranks = np.zeros(scores_t2i.shape[0])
190
+
191
+ for index,score in enumerate(scores_t2i):
192
+ inds = np.argsort(score)[::-1]
193
+ ranks[index] = np.where(inds == txt2img[index])[0][0]
194
+
195
+ # Compute metrics
196
+ ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
197
+ ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
198
+ ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
199
+
200
+ tr_mean = (tr1 + tr5 + tr10) / 3
201
+ ir_mean = (ir1 + ir5 + ir10) / 3
202
+ r_mean = (tr_mean + ir_mean) / 2
203
+
204
+ eval_result = {'txt_r1': tr1,
205
+ 'txt_r5': tr5,
206
+ 'txt_r10': tr10,
207
+ 'txt_r_mean': tr_mean,
208
+ 'img_r1': ir1,
209
+ 'img_r5': ir5,
210
+ 'img_r10': ir10,
211
+ 'img_r_mean': ir_mean,
212
+ 'r_mean': r_mean}
213
+ return eval_result
214
+
215
+
216
+ def main(args, config):
217
+ utils.init_distributed_mode(args)
218
+
219
+ device = torch.device(args.device)
220
+
221
+ # fix the seed for reproducibility
222
+ seed = args.seed + utils.get_rank()
223
+ torch.manual_seed(seed)
224
+ np.random.seed(seed)
225
+ random.seed(seed)
226
+ cudnn.benchmark = True
227
+
228
+ #### Dataset ####
229
+ print("Creating retrieval dataset")
230
+ train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)
231
+
232
+ if args.distributed:
233
+ num_tasks = utils.get_world_size()
234
+ global_rank = utils.get_rank()
235
+ samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
236
+ else:
237
+ samplers = [None, None, None]
238
+
239
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
240
+ batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
241
+ num_workers=[4,4,4],
242
+ is_trains=[True, False, False],
243
+ collate_fns=[None,None,None])
244
+
245
+
246
+ #### Model ####
247
+ print("Creating model")
248
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
249
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
250
+ queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])
251
+
252
+ model = model.to(device)
253
+
254
+ model_without_ddp = model
255
+ if args.distributed:
256
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
257
+ model_without_ddp = model.module
258
+
259
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
260
+
261
+ best = 0
262
+ best_epoch = 0
263
+
264
+ print("Start training")
265
+ start_time = time.time()
266
+
267
+ for epoch in range(0, config['max_epoch']):
268
+ if not args.evaluate:
269
+ if args.distributed:
270
+ train_loader.sampler.set_epoch(epoch)
271
+
272
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
273
+
274
+ train_stats = train(model, train_loader, optimizer, epoch, device, config)
275
+
276
+ score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, device, config)
277
+ score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
278
+
279
+ if utils.is_main_process():
280
+
281
+ val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
282
+ print(val_result)
283
+
284
+ if val_result['r_mean']>best:
285
+ save_obj = {
286
+ 'model': model_without_ddp.state_dict(),
287
+ 'optimizer': optimizer.state_dict(),
288
+ 'config': config,
289
+ 'epoch': epoch,
290
+ }
291
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
292
+ best = val_result['r_mean']
293
+ best_epoch = epoch
294
+
295
+ test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
296
+ print(test_result)
297
+
298
+ if args.evaluate:
299
+ log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
300
+ **{f'test_{k}': v for k, v in test_result.items()},
301
+ }
302
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
303
+ f.write(json.dumps(log_stats) + "\n")
304
+ else:
305
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
306
+ **{f'val_{k}': v for k, v in val_result.items()},
307
+ **{f'test_{k}': v for k, v in test_result.items()},
308
+ 'epoch': epoch,
309
+ 'best_epoch': best_epoch,
310
+ }
311
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
312
+ f.write(json.dumps(log_stats) + "\n")
313
+
314
+ if args.evaluate:
315
+ break
316
+
317
+ dist.barrier()
318
+ torch.cuda.empty_cache()
319
+
320
+ total_time = time.time() - start_time
321
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
322
+ print('Training time {}'.format(total_time_str))
323
+
324
+
325
+ if __name__ == '__main__':
326
+ parser = argparse.ArgumentParser()
327
+ parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
328
+ parser.add_argument('--output_dir', default='output/Retrieval_flickr')
329
+ parser.add_argument('--evaluate', action='store_true')
330
+ parser.add_argument('--device', default='cuda')
331
+ parser.add_argument('--seed', default=42, type=int)
332
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
333
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
334
+ parser.add_argument('--distributed', default=True, type=bool)
335
+ args = parser.parse_args()
336
+
337
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
338
+
339
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
340
+
341
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
342
+
343
+ main(args, config)
train_vqa.py ADDED
@@ -0,0 +1,199 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.backends.cudnn as cudnn
20
+ import torch.distributed as dist
21
+
22
+ from models.blip_vqa import blip_vqa
23
+ import utils
24
+ from utils import cosine_lr_schedule
25
+ from data import create_dataset, create_sampler, create_loader
26
+ from data.vqa_dataset import vqa_collate_fn
27
+ from data.utils import save_result
28
+
29
+
30
+ def train(model, data_loader, optimizer, epoch, device):
31
+ # train
32
+ model.train()
33
+
34
+ metric_logger = utils.MetricLogger(delimiter=" ")
35
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
36
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
37
+
38
+ header = 'Train Epoch: [{}]'.format(epoch)
39
+ print_freq = 50
40
+
41
+ for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
42
+ image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True)
43
+
44
+ loss = model(image, question, answer, train=True, n=n, weights=weights)
45
+
46
+ optimizer.zero_grad()
47
+ loss.backward()
48
+ optimizer.step()
49
+
50
+ metric_logger.update(loss=loss.item())
51
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
52
+
53
+ # gather the stats from all processes
54
+ metric_logger.synchronize_between_processes()
55
+ print("Averaged stats:", metric_logger.global_avg())
56
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
57
+
58
+
59
+ @torch.no_grad()
60
+ def evaluation(model, data_loader, device, config) :
61
+ # test
62
+ model.eval()
63
+
64
+ metric_logger = utils.MetricLogger(delimiter=" ")
65
+ header = 'Generate VQA test result:'
66
+ print_freq = 50
67
+
68
+ result = []
69
+
70
+ if config['inference']=='rank':
71
+ answer_list = data_loader.dataset.answer_list
72
+ answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
73
+ answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id
74
+
75
+ for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
76
+ image = image.to(device,non_blocking=True)
77
+
78
+ if config['inference']=='generate':
79
+ answers = model(image, question, train=False, inference='generate')
80
+
81
+ for answer, ques_id in zip(answers, question_id):
82
+ ques_id = int(ques_id.item())
83
+ result.append({"question_id":ques_id, "answer":answer})
84
+
85
+ elif config['inference']=='rank':
86
+ answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])
87
+
88
+ for ques_id, answer_id in zip(question_id, answer_ids):
89
+ result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]})
90
+
91
+ return result
92
+
93
+
94
+ def main(args, config):
95
+ utils.init_distributed_mode(args)
96
+
97
+ device = torch.device(args.device)
98
+
99
+ # fix the seed for reproducibility
100
+ seed = args.seed + utils.get_rank()
101
+ torch.manual_seed(seed)
102
+ np.random.seed(seed)
103
+ random.seed(seed)
104
+ cudnn.benchmark = True
105
+
106
+ #### Dataset ####
107
+ print("Creating vqa datasets")
108
+ datasets = create_dataset('vqa', config)
109
+
110
+ if args.distributed:
111
+ num_tasks = utils.get_world_size()
112
+ global_rank = utils.get_rank()
113
+ samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
114
+ else:
115
+ samplers = [None, None]
116
+
117
+ train_loader, test_loader = create_loader(datasets,samplers,
118
+ batch_size=[config['batch_size_train'],config['batch_size_test']],
119
+ num_workers=[4,4],is_trains=[True, False],
120
+ collate_fns=[vqa_collate_fn,None])
121
+ #### Model ####
122
+ print("Creating model")
123
+ model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
124
+ vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
125
+
126
+ model = model.to(device)
127
+
128
+ model_without_ddp = model
129
+ if args.distributed:
130
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
131
+ model_without_ddp = model.module
132
+
133
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
134
+
135
+ best = 0
136
+ best_epoch = 0
137
+
138
+ print("Start training")
139
+ start_time = time.time()
140
+ for epoch in range(0, config['max_epoch']):
141
+ if not args.evaluate:
142
+ if args.distributed:
143
+ train_loader.sampler.set_epoch(epoch)
144
+
145
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
146
+
147
+ train_stats = train(model, train_loader, optimizer, epoch, device)
148
+
149
+ else:
150
+ break
151
+
152
+ if utils.is_main_process():
153
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
154
+ 'epoch': epoch,
155
+ }
156
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
157
+ f.write(json.dumps(log_stats) + "\n")
158
+
159
+ save_obj = {
160
+ 'model': model_without_ddp.state_dict(),
161
+ 'optimizer': optimizer.state_dict(),
162
+ 'config': config,
163
+ 'epoch': epoch,
164
+ }
165
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
166
+
167
+ dist.barrier()
168
+
169
+ vqa_result = evaluation(model_without_ddp, test_loader, device, config)
170
+ result_file = save_result(vqa_result, args.result_dir, 'vqa_result')
171
+
172
+ total_time = time.time() - start_time
173
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
174
+ print('Training time {}'.format(total_time_str))
175
+
176
+
177
+
178
+ if __name__ == '__main__':
179
+ parser = argparse.ArgumentParser()
180
+ parser.add_argument('--config', default='./configs/vqa.yaml')
181
+ parser.add_argument('--output_dir', default='output/VQA')
182
+ parser.add_argument('--evaluate', action='store_true')
183
+ parser.add_argument('--device', default='cuda')
184
+ parser.add_argument('--seed', default=42, type=int)
185
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
186
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
187
+ parser.add_argument('--distributed', default=True, type=bool)
188
+ args = parser.parse_args()
189
+
190
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
191
+
192
+ args.result_dir = os.path.join(args.output_dir, 'result')
193
+
194
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
195
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
196
+
197
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
198
+
199
+ main(args, config)
transform/randaugment.py ADDED
@@ -0,0 +1,340 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ ## aug functions
6
+ def identity_func(img):
7
+ return img
8
+
9
+
10
+ def autocontrast_func(img, cutoff=0):
11
+ '''
12
+ same output as PIL.ImageOps.autocontrast
13
+ '''
14
+ n_bins = 256
15
+
16
+ def tune_channel(ch):
17
+ n = ch.size
18
+ cut = cutoff * n // 100
19
+ if cut == 0:
20
+ high, low = ch.max(), ch.min()
21
+ else:
22
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23
+ low = np.argwhere(np.cumsum(hist) > cut)
24
+ low = 0 if low.shape[0] == 0 else low[0]
25
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27
+ if high <= low:
28
+ table = np.arange(n_bins)
29
+ else:
30
+ scale = (n_bins - 1) / (high - low)
31
+ offset = -low * scale
32
+ table = np.arange(n_bins) * scale + offset
33
+ table[table < 0] = 0
34
+ table[table > n_bins - 1] = n_bins - 1
35
+ table = table.clip(0, 255).astype(np.uint8)
36
+ return table[ch]
37
+
38
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
39
+ out = cv2.merge(channels)
40
+ return out
41
+
42
+
43
+ def equalize_func(img):
44
+ '''
45
+ same output as PIL.ImageOps.equalize
46
+ PIL's implementation is different from cv2.equalize
47
+ '''
48
+ n_bins = 256
49
+
50
+ def tune_channel(ch):
51
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52
+ non_zero_hist = hist[hist != 0].reshape(-1)
53
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54
+ if step == 0: return ch
55
+ n = np.empty_like(hist)
56
+ n[0] = step // 2
57
+ n[1:] = hist[:-1]
58
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
59
+ return table[ch]
60
+
61
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
62
+ out = cv2.merge(channels)
63
+ return out
64
+
65
+
66
+ def rotate_func(img, degree, fill=(0, 0, 0)):
67
+ '''
68
+ like PIL, rotate by degree, not radians
69
+ '''
70
+ H, W = img.shape[0], img.shape[1]
71
+ center = W / 2, H / 2
72
+ M = cv2.getRotationMatrix2D(center, degree, 1)
73
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
74
+ return out
75
+
76
+
77
+ def solarize_func(img, thresh=128):
78
+ '''
79
+     same output as PIL.ImageOps.solarize
80
+ '''
81
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
82
+ table = table.clip(0, 255).astype(np.uint8)
83
+ out = table[img]
84
+ return out
85
+
86
+
87
+ def color_func(img, factor):
88
+ '''
89
+ same output as PIL.ImageEnhance.Color
90
+ '''
91
+ ## implementation according to PIL definition, quite slow
92
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
93
+ # out = blend(degenerate, img, factor)
94
+ # M = (
95
+ # np.eye(3) * factor
96
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
97
+ # )[np.newaxis, np.newaxis, :]
98
+ M = (
99
+ np.float32([
100
+ [0.886, -0.114, -0.114],
101
+ [-0.587, 0.413, -0.587],
102
+ [-0.299, -0.299, 0.701]]) * factor
103
+ + np.float32([[0.114], [0.587], [0.299]])
104
+ )
105
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
106
+ return out
107
+
108
+
109
+ def contrast_func(img, factor):
110
+ """
111
+ same output as PIL.ImageEnhance.Contrast
112
+ """
113
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
114
+ table = np.array([(
115
+ el - mean) * factor + mean
116
+ for el in range(256)
117
+ ]).clip(0, 255).astype(np.uint8)
118
+ out = table[img]
119
+ return out
120
+
121
+
122
+ def brightness_func(img, factor):
123
+ '''
124
+     same output as PIL.ImageEnhance.Brightness
125
+ '''
126
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
127
+ out = table[img]
128
+ return out
129
+
130
+
131
+ def sharpness_func(img, factor):
132
+ '''
133
+     The differences between this result and PIL's are all on the 4 boundaries; the center
134
+     areas are the same
135
+ '''
136
+ kernel = np.ones((3, 3), dtype=np.float32)
137
+ kernel[1][1] = 5
138
+ kernel /= 13
139
+ degenerate = cv2.filter2D(img, -1, kernel)
140
+ if factor == 0.0:
141
+ out = degenerate
142
+ elif factor == 1.0:
143
+ out = img
144
+ else:
145
+ out = img.astype(np.float32)
146
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
147
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
148
+ out = out.astype(np.uint8)
149
+ return out
150
+
151
+
152
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
153
+ H, W = img.shape[0], img.shape[1]
154
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
155
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
160
+ '''
161
+ same output as PIL.Image.transform
162
+ '''
163
+ H, W = img.shape[0], img.shape[1]
164
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
165
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
166
+ return out
167
+
168
+
169
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
170
+ '''
171
+ same output as PIL.Image.transform
172
+ '''
173
+ H, W = img.shape[0], img.shape[1]
174
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
175
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
176
+ return out
177
+
178
+
179
+ def posterize_func(img, bits):
180
+ '''
181
+ same output as PIL.ImageOps.posterize
182
+ '''
183
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
184
+ return out
185
+
186
+
187
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
188
+ H, W = img.shape[0], img.shape[1]
189
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
190
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
191
+ return out
192
+
193
+
194
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
195
+ replace = np.array(replace, dtype=np.uint8)
196
+ H, W = img.shape[0], img.shape[1]
197
+ rh, rw = np.random.random(2)
198
+ pad_size = pad_size // 2
199
+ ch, cw = int(rh * H), int(rw * W)
200
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
201
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
202
+ out = img.copy()
203
+ out[x1:x2, y1:y2, :] = replace
204
+ return out
205
+
206
+
207
+ ### level to args
208
+ def enhance_level_to_args(MAX_LEVEL):
209
+ def level_to_args(level):
210
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
211
+ return level_to_args
212
+
213
+
214
+ def shear_level_to_args(MAX_LEVEL, replace_value):
215
+ def level_to_args(level):
216
+ level = (level / MAX_LEVEL) * 0.3
217
+ if np.random.random() > 0.5: level = -level
218
+ return (level, replace_value)
219
+
220
+ return level_to_args
221
+
222
+
223
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
224
+ def level_to_args(level):
225
+ level = (level / MAX_LEVEL) * float(translate_const)
226
+ if np.random.random() > 0.5: level = -level
227
+ return (level, replace_value)
228
+
229
+ return level_to_args
230
+
231
+
232
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
233
+ def level_to_args(level):
234
+ level = int((level / MAX_LEVEL) * cutout_const)
235
+ return (level, replace_value)
236
+
237
+ return level_to_args
238
+
239
+
240
+ def solarize_level_to_args(MAX_LEVEL):
241
+ def level_to_args(level):
242
+ level = int((level / MAX_LEVEL) * 256)
243
+ return (level, )
244
+ return level_to_args
245
+
246
+
247
+ def none_level_to_args(level):
248
+ return ()
249
+
250
+
251
+ def posterize_level_to_args(MAX_LEVEL):
252
+ def level_to_args(level):
253
+ level = int((level / MAX_LEVEL) * 4)
254
+ return (level, )
255
+ return level_to_args
256
+
257
+
258
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
259
+ def level_to_args(level):
260
+ level = (level / MAX_LEVEL) * 30
261
+ if np.random.random() < 0.5:
262
+ level = -level
263
+ return (level, replace_value)
264
+
265
+ return level_to_args
266
+
267
+
268
+ func_dict = {
269
+ 'Identity': identity_func,
270
+ 'AutoContrast': autocontrast_func,
271
+ 'Equalize': equalize_func,
272
+ 'Rotate': rotate_func,
273
+ 'Solarize': solarize_func,
274
+ 'Color': color_func,
275
+ 'Contrast': contrast_func,
276
+ 'Brightness': brightness_func,
277
+ 'Sharpness': sharpness_func,
278
+ 'ShearX': shear_x_func,
279
+ 'TranslateX': translate_x_func,
280
+ 'TranslateY': translate_y_func,
281
+ 'Posterize': posterize_func,
282
+ 'ShearY': shear_y_func,
283
+ }
284
+
285
+ translate_const = 10
286
+ MAX_LEVEL = 10
287
+ replace_value = (128, 128, 128)
288
+ arg_dict = {
289
+ 'Identity': none_level_to_args,
290
+ 'AutoContrast': none_level_to_args,
291
+ 'Equalize': none_level_to_args,
292
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
293
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
294
+ 'Color': enhance_level_to_args(MAX_LEVEL),
295
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
296
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
297
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
298
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
299
+ 'TranslateX': translate_level_to_args(
300
+ translate_const, MAX_LEVEL, replace_value
301
+ ),
302
+ 'TranslateY': translate_level_to_args(
303
+ translate_const, MAX_LEVEL, replace_value
304
+ ),
305
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
306
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
307
+ }
308
+
309
+
310
+ class RandomAugment(object):
311
+
312
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
313
+ self.N = N
314
+ self.M = M
315
+ self.isPIL = isPIL
316
+ if augs:
317
+ self.augs = augs
318
+ else:
319
+ self.augs = list(arg_dict.keys())
320
+
321
+ def get_random_ops(self):
322
+ sampled_ops = np.random.choice(self.augs, self.N)
323
+ return [(op, 0.5, self.M) for op in sampled_ops]
324
+
325
+ def __call__(self, img):
326
+ if self.isPIL:
327
+ img = np.array(img)
328
+ ops = self.get_random_ops()
329
+ for name, prob, level in ops:
330
+ if np.random.random() > prob:
331
+ continue
332
+ args = arg_dict[name](level)
333
+ img = func_dict[name](img, *args)
334
+ return img
335
+
336
+
337
+ if __name__ == '__main__':
338
+ a = RandomAugment()
339
+ img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # ops expect a uint8 HWC image
340
+ a(img)
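For orientation, the RandomAugment transform defined above is normally composed with standard torchvision preprocessing before the image is converted to a tensor. The snippet below is a minimal usage sketch, not code from this commit: the import path, crop size, scale range, and augmentation list are illustrative assumptions.

    from PIL import Image
    from torchvision import transforms
    from randaugment import RandomAugment   # hypothetical import path; adjust to this repo's layout

    # Illustrative pipeline: crop size and op list are assumptions, not values from this repo's configs.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(),
        RandomAugment(N=2, M=5, isPIL=True,
                      augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                            'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),   # accepts the HxWxC uint8 ndarray returned by RandomAugment
    ])

    img = Image.open('example.jpg').convert('RGB')   # placeholder image path
    tensor = train_transform(img)

Because isPIL=True, RandomAugment converts the PIL image to a NumPy array before applying the N sampled ops at magnitude M, and ToTensor consumes the resulting uint8 array directly.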
utils.py ADDED
@@ -0,0 +1,278 @@
1
+ import math
2
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
3
+ """Decay the learning rate"""
4
+ lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
5
+ for param_group in optimizer.param_groups:
6
+ param_group['lr'] = lr
7
+
8
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
9
+ """Warmup the learning rate"""
10
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
11
+ for param_group in optimizer.param_groups:
12
+ param_group['lr'] = lr
13
+
14
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
15
+ """Decay the learning rate"""
16
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
17
+ for param_group in optimizer.param_groups:
18
+ param_group['lr'] = lr
19
+
20
+ import numpy as np
21
+ import io
22
+ import os
23
+ import time
24
+ from collections import defaultdict, deque
25
+ import datetime
26
+
27
+ import torch
28
+ import torch.distributed as dist
29
+
30
+ class SmoothedValue(object):
31
+ """Track a series of values and provide access to smoothed values over a
32
+ window or the global series average.
33
+ """
34
+
35
+ def __init__(self, window_size=20, fmt=None):
36
+ if fmt is None:
37
+ fmt = "{median:.4f} ({global_avg:.4f})"
38
+ self.deque = deque(maxlen=window_size)
39
+ self.total = 0.0
40
+ self.count = 0
41
+ self.fmt = fmt
42
+
43
+ def update(self, value, n=1):
44
+ self.deque.append(value)
45
+ self.count += n
46
+ self.total += value * n
47
+
48
+ def synchronize_between_processes(self):
49
+ """
50
+ Warning: does not synchronize the deque!
51
+ """
52
+ if not is_dist_avail_and_initialized():
53
+ return
54
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
55
+ dist.barrier()
56
+ dist.all_reduce(t)
57
+ t = t.tolist()
58
+ self.count = int(t[0])
59
+ self.total = t[1]
60
+
61
+ @property
62
+ def median(self):
63
+ d = torch.tensor(list(self.deque))
64
+ return d.median().item()
65
+
66
+ @property
67
+ def avg(self):
68
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
69
+ return d.mean().item()
70
+
71
+ @property
72
+ def global_avg(self):
73
+ return self.total / self.count
74
+
75
+ @property
76
+ def max(self):
77
+ return max(self.deque)
78
+
79
+ @property
80
+ def value(self):
81
+ return self.deque[-1]
82
+
83
+ def __str__(self):
84
+ return self.fmt.format(
85
+ median=self.median,
86
+ avg=self.avg,
87
+ global_avg=self.global_avg,
88
+ max=self.max,
89
+ value=self.value)
90
+
91
+
92
+ class MetricLogger(object):
93
+ def __init__(self, delimiter="\t"):
94
+ self.meters = defaultdict(SmoothedValue)
95
+ self.delimiter = delimiter
96
+
97
+ def update(self, **kwargs):
98
+ for k, v in kwargs.items():
99
+ if isinstance(v, torch.Tensor):
100
+ v = v.item()
101
+ assert isinstance(v, (float, int))
102
+ self.meters[k].update(v)
103
+
104
+ def __getattr__(self, attr):
105
+ if attr in self.meters:
106
+ return self.meters[attr]
107
+ if attr in self.__dict__:
108
+ return self.__dict__[attr]
109
+ raise AttributeError("'{}' object has no attribute '{}'".format(
110
+ type(self).__name__, attr))
111
+
112
+ def __str__(self):
113
+ loss_str = []
114
+ for name, meter in self.meters.items():
115
+ loss_str.append(
116
+ "{}: {}".format(name, str(meter))
117
+ )
118
+ return self.delimiter.join(loss_str)
119
+
120
+ def global_avg(self):
121
+ loss_str = []
122
+ for name, meter in self.meters.items():
123
+ loss_str.append(
124
+ "{}: {:.4f}".format(name, meter.global_avg)
125
+ )
126
+ return self.delimiter.join(loss_str)
127
+
128
+ def synchronize_between_processes(self):
129
+ for meter in self.meters.values():
130
+ meter.synchronize_between_processes()
131
+
132
+ def add_meter(self, name, meter):
133
+ self.meters[name] = meter
134
+
135
+ def log_every(self, iterable, print_freq, header=None):
136
+ i = 0
137
+ if not header:
138
+ header = ''
139
+ start_time = time.time()
140
+ end = time.time()
141
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
142
+ data_time = SmoothedValue(fmt='{avg:.4f}')
143
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
144
+ log_msg = [
145
+ header,
146
+ '[{0' + space_fmt + '}/{1}]',
147
+ 'eta: {eta}',
148
+ '{meters}',
149
+ 'time: {time}',
150
+ 'data: {data}'
151
+ ]
152
+ if torch.cuda.is_available():
153
+ log_msg.append('max mem: {memory:.0f}')
154
+ log_msg = self.delimiter.join(log_msg)
155
+ MB = 1024.0 * 1024.0
156
+ for obj in iterable:
157
+ data_time.update(time.time() - end)
158
+ yield obj
159
+ iter_time.update(time.time() - end)
160
+ if i % print_freq == 0 or i == len(iterable) - 1:
161
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
162
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
163
+ if torch.cuda.is_available():
164
+ print(log_msg.format(
165
+ i, len(iterable), eta=eta_string,
166
+ meters=str(self),
167
+ time=str(iter_time), data=str(data_time),
168
+ memory=torch.cuda.max_memory_allocated() / MB))
169
+ else:
170
+ print(log_msg.format(
171
+ i, len(iterable), eta=eta_string,
172
+ meters=str(self),
173
+ time=str(iter_time), data=str(data_time)))
174
+ i += 1
175
+ end = time.time()
176
+ total_time = time.time() - start_time
177
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
178
+ print('{} Total time: {} ({:.4f} s / it)'.format(
179
+ header, total_time_str, total_time / len(iterable)))
180
+
181
+
182
+ class AttrDict(dict):
183
+ def __init__(self, *args, **kwargs):
184
+ super(AttrDict, self).__init__(*args, **kwargs)
185
+ self.__dict__ = self
186
+
187
+
188
+ def compute_acc(logits, label, reduction='mean'):
189
+ ret = (torch.argmax(logits, dim=1) == label).float()
190
+ if reduction == 'none':
191
+ return ret.detach()
192
+ elif reduction == 'mean':
193
+ return ret.mean().item()
194
+
195
+ def compute_n_params(model, return_str=True):
196
+ tot = 0
197
+ for p in model.parameters():
198
+ w = 1
199
+ for x in p.shape:
200
+ w *= x
201
+ tot += w
202
+ if return_str:
203
+ if tot >= 1e6:
204
+ return '{:.1f}M'.format(tot / 1e6)
205
+ else:
206
+ return '{:.1f}K'.format(tot / 1e3)
207
+ else:
208
+ return tot
209
+
210
+ def setup_for_distributed(is_master):
211
+ """
212
+ This function disables printing when not in the master process
213
+ """
214
+ import builtins as __builtin__
215
+ builtin_print = __builtin__.print
216
+
217
+ def print(*args, **kwargs):
218
+ force = kwargs.pop('force', False)
219
+ if is_master or force:
220
+ builtin_print(*args, **kwargs)
221
+
222
+ __builtin__.print = print
223
+
224
+
225
+ def is_dist_avail_and_initialized():
226
+ if not dist.is_available():
227
+ return False
228
+ if not dist.is_initialized():
229
+ return False
230
+ return True
231
+
232
+
233
+ def get_world_size():
234
+ if not is_dist_avail_and_initialized():
235
+ return 1
236
+ return dist.get_world_size()
237
+
238
+
239
+ def get_rank():
240
+ if not is_dist_avail_and_initialized():
241
+ return 0
242
+ return dist.get_rank()
243
+
244
+
245
+ def is_main_process():
246
+ return get_rank() == 0
247
+
248
+
249
+ def save_on_master(*args, **kwargs):
250
+ if is_main_process():
251
+ torch.save(*args, **kwargs)
252
+
253
+
254
+ def init_distributed_mode(args):
255
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
256
+ args.rank = int(os.environ["RANK"])
257
+ args.world_size = int(os.environ['WORLD_SIZE'])
258
+ args.gpu = int(os.environ['LOCAL_RANK'])
259
+ elif 'SLURM_PROCID' in os.environ:
260
+ args.rank = int(os.environ['SLURM_PROCID'])
261
+ args.gpu = args.rank % torch.cuda.device_count()
262
+ else:
263
+ print('Not using distributed mode')
264
+ args.distributed = False
265
+ return
266
+
267
+ args.distributed = True
268
+
269
+ torch.cuda.set_device(args.gpu)
270
+ args.dist_backend = 'nccl'
271
+ print('| distributed init (rank {}, world {}): {}'.format(
272
+ args.rank, args.world_size, args.dist_url), flush=True)
273
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
274
+ world_size=args.world_size, rank=args.rank)
275
+ torch.distributed.barrier()
276
+ setup_for_distributed(args.rank == 0)
277
+
278
+
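To show how the helpers in utils.py fit together, the sketch below wires warmup_lr_schedule, cosine_lr_schedule, and MetricLogger into a toy single-process training loop. Every name and hyperparameter here (the linear model, the dummy loader, epoch count, learning rates, print frequency) is a placeholder assumption for illustration, not this repo's actual training script.

    import torch
    import torch.nn.functional as F
    import utils   # the module added above

    model = torch.nn.Linear(10, 2)                              # toy model (placeholder)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(100)]  # dummy batches

    max_epoch, init_lr, min_lr = 5, 1e-4, 1e-6                  # illustrative values
    for epoch in range(max_epoch):
        utils.cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr)
        metric_logger = utils.MetricLogger(delimiter="  ")
        header = 'Train Epoch: [{}]'.format(epoch)
        for i, (x, y) in enumerate(metric_logger.log_every(loader, print_freq=20, header=header)):
            if epoch == 0:
                # linear warmup over the first epoch
                utils.warmup_lr_schedule(optimizer, i, len(loader), 1e-6, init_lr)
            loss = F.cross_entropy(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
        print('Averaged stats:', metric_logger.global_avg())

In a distributed run, init_distributed_mode(args) would be called once at startup and metric_logger.synchronize_between_processes() before reading global averages; this single-process sketch skips both.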
vision/visual_genome/test.csv ADDED
The diff for this file is too large to render. See raw diff