florentgbelidji HF staff commited on
Commit
1ff436e
1 Parent(s): 60e5b4d

Delete models

Browse files
models/.ipynb_checkpoints/blip_decoder-checkpoint.py DELETED
@@ -1,175 +0,0 @@
1
- '''
2
- * Copyright (c) 2022, salesforce.com, inc.
3
- * All rights reserved.
4
- * SPDX-License-Identifier: BSD-3-Clause
5
- * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- * By Junnan Li
7
- '''
8
- import warnings
9
- warnings.filterwarnings("ignore")
10
-
11
- from vit import VisionTransformer, interpolate_pos_embed
12
- from med import BertConfig, BertModel, BertLMHeadModel
13
- from transformers import BertTokenizer
14
-
15
- import torch
16
- from torch import nn
17
- import torch.nn.functional as F
18
-
19
- import os
20
- from urllib.parse import urlparse
21
- from timm.models.hub import download_cached_file
22
-
23
- class BLIP_Decoder(nn.Module):
24
- def __init__(self,
25
- med_config = 'configs/med_config.json',
26
- image_size = 384,
27
- vit = 'base',
28
- vit_grad_ckpt = False,
29
- vit_ckpt_layer = 0,
30
- prompt = 'a picture of ',
31
- ):
32
- """
33
- Args:
34
- med_config (str): path for the mixture of encoder-decoder model's configuration file
35
- image_size (int): input image size
36
- vit (str): model size of vision transformer
37
- """
38
- super().__init__()
39
-
40
- self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
41
- self.tokenizer = init_tokenizer()
42
- med_config = BertConfig.from_json_file(med_config)
43
- med_config.encoder_width = vision_width
44
- self.text_decoder = BertLMHeadModel(config=med_config)
45
-
46
- self.prompt = prompt
47
- self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
48
-
49
-
50
- def forward(self, image, caption):
51
-
52
- image_embeds = self.visual_encoder(image)
53
- image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
54
-
55
- text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
56
-
57
- text.input_ids[:,0] = self.tokenizer.bos_token_id
58
-
59
- decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
60
- decoder_targets[:,:self.prompt_length] = -100
61
-
62
- decoder_output = self.text_decoder(text.input_ids,
63
- attention_mask = text.attention_mask,
64
- encoder_hidden_states = image_embeds,
65
- encoder_attention_mask = image_atts,
66
- labels = decoder_targets,
67
- return_dict = True,
68
- )
69
- loss_lm = decoder_output.loss
70
-
71
- return loss_lm
72
-
73
- def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
74
- image_embeds = self.visual_encoder(image)
75
-
76
- if not sample:
77
- image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
78
-
79
- image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
80
- model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
81
-
82
- prompt = [self.prompt] * image.size(0)
83
- input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
84
- input_ids[:,0] = self.tokenizer.bos_token_id
85
- input_ids = input_ids[:, :-1]
86
-
87
- if sample:
88
- #nucleus sampling
89
- outputs = self.text_decoder.generate(input_ids=input_ids,
90
- max_length=max_length,
91
- min_length=min_length,
92
- do_sample=True,
93
- top_p=top_p,
94
- num_return_sequences=1,
95
- eos_token_id=self.tokenizer.sep_token_id,
96
- pad_token_id=self.tokenizer.pad_token_id,
97
- repetition_penalty=1.1,
98
- **model_kwargs)
99
- else:
100
- #beam search
101
- outputs = self.text_decoder.generate(input_ids=input_ids,
102
- max_length=max_length,
103
- min_length=min_length,
104
- num_beams=num_beams,
105
- eos_token_id=self.tokenizer.sep_token_id,
106
- pad_token_id=self.tokenizer.pad_token_id,
107
- repetition_penalty=repetition_penalty,
108
- **model_kwargs)
109
-
110
- captions = []
111
- for output in outputs:
112
- caption = self.tokenizer.decode(output, skip_special_tokens=True)
113
- captions.append(caption[len(self.prompt):])
114
- return captions
115
-
116
-
117
- def blip_decoder(pretrained='',**kwargs):
118
- model = BLIP_Decoder(**kwargs)
119
- if pretrained:
120
- model,msg = load_checkpoint(model,pretrained)
121
- assert(len(msg.missing_keys)==0)
122
- return model
123
-
124
- def init_tokenizer():
125
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
126
- tokenizer.add_special_tokens({'bos_token':'[DEC]'})
127
- tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
128
- tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
129
- return tokenizer
130
-
131
-
132
- def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
133
-
134
- assert vit in ['base', 'large'], "vit parameter must be base or large"
135
- if vit=='base':
136
- vision_width = 768
137
- visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
138
- num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
139
- drop_path_rate=0 or drop_path_rate
140
- )
141
- elif vit=='large':
142
- vision_width = 1024
143
- visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
144
- num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
145
- drop_path_rate=0.1 or drop_path_rate
146
- )
147
- return visual_encoder, vision_width
148
-
149
- def is_url(url_or_filename):
150
- parsed = urlparse(url_or_filename)
151
- return parsed.scheme in ("http", "https")
152
-
153
- def load_checkpoint(model,url_or_filename):
154
- if is_url(url_or_filename):
155
- cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
156
- checkpoint = torch.load(cached_file, map_location='cpu')
157
- elif os.path.isfile(url_or_filename):
158
- checkpoint = torch.load(url_or_filename, map_location='cpu')
159
- else:
160
- raise RuntimeError('checkpoint url or path is invalid')
161
-
162
- state_dict = checkpoint['model']
163
-
164
- state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
165
- if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
166
- state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
167
- model.visual_encoder_m)
168
- for key in model.state_dict().keys():
169
- if key in state_dict.keys():
170
- if state_dict[key].shape!=model.state_dict()[key].shape:
171
- del state_dict[key]
172
-
173
- msg = model.load_state_dict(state_dict,strict=False)
174
- print('load checkpoint from %s'%url_or_filename)
175
- return model,msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/blip_decoder.py DELETED
@@ -1,175 +0,0 @@
1
- '''
2
- * Copyright (c) 2022, salesforce.com, inc.
3
- * All rights reserved.
4
- * SPDX-License-Identifier: BSD-3-Clause
5
- * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- * By Junnan Li
7
- '''
8
- import warnings
9
- warnings.filterwarnings("ignore")
10
-
11
- from models.vit import VisionTransformer, interpolate_pos_embed
12
- from models.med import BertConfig, BertModel, BertLMHeadModel
13
- from transformers import BertTokenizer
14
-
15
- import torch
16
- from torch import nn
17
- import torch.nn.functional as F
18
-
19
- import os
20
- from urllib.parse import urlparse
21
- from timm.models.hub import download_cached_file
22
-
23
- class BLIP_Decoder(nn.Module):
24
- def __init__(self,
25
- med_config = 'configs/med_config.json',
26
- image_size = 384,
27
- vit = 'base',
28
- vit_grad_ckpt = False,
29
- vit_ckpt_layer = 0,
30
- prompt = 'a picture of ',
31
- ):
32
- """
33
- Args:
34
- med_config (str): path for the mixture of encoder-decoder model's configuration file
35
- image_size (int): input image size
36
- vit (str): model size of vision transformer
37
- """
38
- super().__init__()
39
-
40
- self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
41
- self.tokenizer = init_tokenizer()
42
- med_config = BertConfig.from_json_file(med_config)
43
- med_config.encoder_width = vision_width
44
- self.text_decoder = BertLMHeadModel(config=med_config)
45
-
46
- self.prompt = prompt
47
- self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
48
-
49
-
50
- def forward(self, image, caption):
51
-
52
- image_embeds = self.visual_encoder(image)
53
- image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
54
-
55
- text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
56
-
57
- text.input_ids[:,0] = self.tokenizer.bos_token_id
58
-
59
- decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
60
- decoder_targets[:,:self.prompt_length] = -100
61
-
62
- decoder_output = self.text_decoder(text.input_ids,
63
- attention_mask = text.attention_mask,
64
- encoder_hidden_states = image_embeds,
65
- encoder_attention_mask = image_atts,
66
- labels = decoder_targets,
67
- return_dict = True,
68
- )
69
- loss_lm = decoder_output.loss
70
-
71
- return loss_lm
72
-
73
- def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
74
- image_embeds = self.visual_encoder(image)
75
-
76
- if not sample:
77
- image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
78
-
79
- image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
80
- model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
81
-
82
- prompt = [self.prompt] * image.size(0)
83
- input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
84
- input_ids[:,0] = self.tokenizer.bos_token_id
85
- input_ids = input_ids[:, :-1]
86
-
87
- if sample:
88
- #nucleus sampling
89
- outputs = self.text_decoder.generate(input_ids=input_ids,
90
- max_length=max_length,
91
- min_length=min_length,
92
- do_sample=True,
93
- top_p=top_p,
94
- num_return_sequences=1,
95
- eos_token_id=self.tokenizer.sep_token_id,
96
- pad_token_id=self.tokenizer.pad_token_id,
97
- repetition_penalty=1.1,
98
- **model_kwargs)
99
- else:
100
- #beam search
101
- outputs = self.text_decoder.generate(input_ids=input_ids,
102
- max_length=max_length,
103
- min_length=min_length,
104
- num_beams=num_beams,
105
- eos_token_id=self.tokenizer.sep_token_id,
106
- pad_token_id=self.tokenizer.pad_token_id,
107
- repetition_penalty=repetition_penalty,
108
- **model_kwargs)
109
-
110
- captions = []
111
- for output in outputs:
112
- caption = self.tokenizer.decode(output, skip_special_tokens=True)
113
- captions.append(caption[len(self.prompt):])
114
- return captions
115
-
116
-
117
- def blip_decoder(pretrained='',**kwargs):
118
- model = BLIP_Decoder(**kwargs)
119
- if pretrained:
120
- model,msg = load_checkpoint(model,pretrained)
121
- assert(len(msg.missing_keys)==0)
122
- return model
123
-
124
- def init_tokenizer():
125
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
126
- tokenizer.add_special_tokens({'bos_token':'[DEC]'})
127
- tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
128
- tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
129
- return tokenizer
130
-
131
-
132
- def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
133
-
134
- assert vit in ['base', 'large'], "vit parameter must be base or large"
135
- if vit=='base':
136
- vision_width = 768
137
- visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
138
- num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
139
- drop_path_rate=0 or drop_path_rate
140
- )
141
- elif vit=='large':
142
- vision_width = 1024
143
- visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
144
- num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
145
- drop_path_rate=0.1 or drop_path_rate
146
- )
147
- return visual_encoder, vision_width
148
-
149
- def is_url(url_or_filename):
150
- parsed = urlparse(url_or_filename)
151
- return parsed.scheme in ("http", "https")
152
-
153
- def load_checkpoint(model,url_or_filename):
154
- if is_url(url_or_filename):
155
- cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
156
- checkpoint = torch.load(cached_file, map_location='cpu')
157
- elif os.path.isfile(url_or_filename):
158
- checkpoint = torch.load(url_or_filename, map_location='cpu')
159
- else:
160
- raise RuntimeError('checkpoint url or path is invalid')
161
-
162
- state_dict = checkpoint['model']
163
-
164
- state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
165
- if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
166
- state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
167
- model.visual_encoder_m)
168
- for key in model.state_dict().keys():
169
- if key in state_dict.keys():
170
- if state_dict[key].shape!=model.state_dict()[key].shape:
171
- del state_dict[key]
172
-
173
- msg = model.load_state_dict(state_dict,strict=False)
174
- print('load checkpoint from %s'%url_or_filename)
175
- return model,msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/med.py DELETED
@@ -1,953 +0,0 @@
1
- '''
2
- * Copyright (c) 2022, salesforce.com, inc.
3
- * All rights reserved.
4
- * SPDX-License-Identifier: BSD-3-Clause
5
- * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- * By Junnan Li
7
- * Based on huggingface code base
8
- * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
- '''
10
-
11
- import math
12
- import os
13
- import warnings
14
- from dataclasses import dataclass
15
- from typing import Optional, Tuple
16
-
17
- import torch
18
- from torch import Tensor, device, dtype, nn
19
- import torch.utils.checkpoint
20
- from torch import nn
21
- from torch.nn import CrossEntropyLoss
22
- import torch.nn.functional as F
23
-
24
- from transformers.activations import ACT2FN
25
- from transformers.file_utils import (
26
- ModelOutput,
27
- )
28
- from transformers.modeling_outputs import (
29
- BaseModelOutputWithPastAndCrossAttentions,
30
- BaseModelOutputWithPoolingAndCrossAttentions,
31
- CausalLMOutputWithCrossAttentions,
32
- MaskedLMOutput,
33
- MultipleChoiceModelOutput,
34
- NextSentencePredictorOutput,
35
- QuestionAnsweringModelOutput,
36
- SequenceClassifierOutput,
37
- TokenClassifierOutput,
38
- )
39
- from transformers.modeling_utils import (
40
- PreTrainedModel,
41
- apply_chunking_to_forward,
42
- find_pruneable_heads_and_indices,
43
- prune_linear_layer,
44
- )
45
- from transformers.utils import logging
46
- from transformers.models.bert.configuration_bert import BertConfig
47
-
48
-
49
- logger = logging.get_logger(__name__)
50
-
51
-
52
- class BertEmbeddings(nn.Module):
53
- """Construct the embeddings from word and position embeddings."""
54
-
55
- def __init__(self, config):
56
- super().__init__()
57
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
- self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
-
60
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
- # any TensorFlow checkpoint file
62
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
-
65
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
- self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
-
69
- self.config = config
70
-
71
- def forward(
72
- self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
- ):
74
- if input_ids is not None:
75
- input_shape = input_ids.size()
76
- else:
77
- input_shape = inputs_embeds.size()[:-1]
78
-
79
- seq_length = input_shape[1]
80
-
81
- if position_ids is None:
82
- position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
-
84
- if inputs_embeds is None:
85
- inputs_embeds = self.word_embeddings(input_ids)
86
-
87
- embeddings = inputs_embeds
88
-
89
- if self.position_embedding_type == "absolute":
90
- position_embeddings = self.position_embeddings(position_ids)
91
- embeddings += position_embeddings
92
- embeddings = self.LayerNorm(embeddings)
93
- embeddings = self.dropout(embeddings)
94
- return embeddings
95
-
96
-
97
- class BertSelfAttention(nn.Module):
98
- def __init__(self, config, is_cross_attention):
99
- super().__init__()
100
- self.config = config
101
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
- raise ValueError(
103
- "The hidden size (%d) is not a multiple of the number of attention "
104
- "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
- )
106
-
107
- self.num_attention_heads = config.num_attention_heads
108
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
- self.all_head_size = self.num_attention_heads * self.attention_head_size
110
-
111
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
- if is_cross_attention:
113
- self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
- self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
- else:
116
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
-
119
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
- self.max_position_embeddings = config.max_position_embeddings
123
- self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
- self.save_attention = False
125
-
126
- def save_attn_gradients(self, attn_gradients):
127
- self.attn_gradients = attn_gradients
128
-
129
- def get_attn_gradients(self):
130
- return self.attn_gradients
131
-
132
- def save_attention_map(self, attention_map):
133
- self.attention_map = attention_map
134
-
135
- def get_attention_map(self):
136
- return self.attention_map
137
-
138
- def transpose_for_scores(self, x):
139
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
- x = x.view(*new_x_shape)
141
- return x.permute(0, 2, 1, 3)
142
-
143
- def forward(
144
- self,
145
- hidden_states,
146
- attention_mask=None,
147
- head_mask=None,
148
- encoder_hidden_states=None,
149
- encoder_attention_mask=None,
150
- past_key_value=None,
151
- output_attentions=False,
152
- ):
153
- mixed_query_layer = self.query(hidden_states)
154
-
155
- # If this is instantiated as a cross-attention module, the keys
156
- # and values come from an encoder; the attention mask needs to be
157
- # such that the encoder's padding tokens are not attended to.
158
- is_cross_attention = encoder_hidden_states is not None
159
-
160
- if is_cross_attention:
161
- key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
- value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
- attention_mask = encoder_attention_mask
164
- elif past_key_value is not None:
165
- key_layer = self.transpose_for_scores(self.key(hidden_states))
166
- value_layer = self.transpose_for_scores(self.value(hidden_states))
167
- key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
- value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
- else:
170
- key_layer = self.transpose_for_scores(self.key(hidden_states))
171
- value_layer = self.transpose_for_scores(self.value(hidden_states))
172
-
173
- query_layer = self.transpose_for_scores(mixed_query_layer)
174
-
175
- past_key_value = (key_layer, value_layer)
176
-
177
- # Take the dot product between "query" and "key" to get the raw attention scores.
178
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
-
180
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
- seq_length = hidden_states.size()[1]
182
- position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
- position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
- distance = position_ids_l - position_ids_r
185
- positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
- positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
-
188
- if self.position_embedding_type == "relative_key":
189
- relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
- attention_scores = attention_scores + relative_position_scores
191
- elif self.position_embedding_type == "relative_key_query":
192
- relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
- relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
- attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
-
196
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
- if attention_mask is not None:
198
- # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
- attention_scores = attention_scores + attention_mask
200
-
201
- # Normalize the attention scores to probabilities.
202
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
-
204
- if is_cross_attention and self.save_attention:
205
- self.save_attention_map(attention_probs)
206
- attention_probs.register_hook(self.save_attn_gradients)
207
-
208
- # This is actually dropping out entire tokens to attend to, which might
209
- # seem a bit unusual, but is taken from the original Transformer paper.
210
- attention_probs_dropped = self.dropout(attention_probs)
211
-
212
- # Mask heads if we want to
213
- if head_mask is not None:
214
- attention_probs_dropped = attention_probs_dropped * head_mask
215
-
216
- context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
-
218
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
- context_layer = context_layer.view(*new_context_layer_shape)
221
-
222
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
-
224
- outputs = outputs + (past_key_value,)
225
- return outputs
226
-
227
-
228
- class BertSelfOutput(nn.Module):
229
- def __init__(self, config):
230
- super().__init__()
231
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
-
235
- def forward(self, hidden_states, input_tensor):
236
- hidden_states = self.dense(hidden_states)
237
- hidden_states = self.dropout(hidden_states)
238
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
- return hidden_states
240
-
241
-
242
- class BertAttention(nn.Module):
243
- def __init__(self, config, is_cross_attention=False):
244
- super().__init__()
245
- self.self = BertSelfAttention(config, is_cross_attention)
246
- self.output = BertSelfOutput(config)
247
- self.pruned_heads = set()
248
-
249
- def prune_heads(self, heads):
250
- if len(heads) == 0:
251
- return
252
- heads, index = find_pruneable_heads_and_indices(
253
- heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
- )
255
-
256
- # Prune linear layers
257
- self.self.query = prune_linear_layer(self.self.query, index)
258
- self.self.key = prune_linear_layer(self.self.key, index)
259
- self.self.value = prune_linear_layer(self.self.value, index)
260
- self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
-
262
- # Update hyper params and store pruned heads
263
- self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
- self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
- self.pruned_heads = self.pruned_heads.union(heads)
266
-
267
- def forward(
268
- self,
269
- hidden_states,
270
- attention_mask=None,
271
- head_mask=None,
272
- encoder_hidden_states=None,
273
- encoder_attention_mask=None,
274
- past_key_value=None,
275
- output_attentions=False,
276
- ):
277
- self_outputs = self.self(
278
- hidden_states,
279
- attention_mask,
280
- head_mask,
281
- encoder_hidden_states,
282
- encoder_attention_mask,
283
- past_key_value,
284
- output_attentions,
285
- )
286
- attention_output = self.output(self_outputs[0], hidden_states)
287
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
- return outputs
289
-
290
-
291
- class BertIntermediate(nn.Module):
292
- def __init__(self, config):
293
- super().__init__()
294
- self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
- if isinstance(config.hidden_act, str):
296
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
- else:
298
- self.intermediate_act_fn = config.hidden_act
299
-
300
- def forward(self, hidden_states):
301
- hidden_states = self.dense(hidden_states)
302
- hidden_states = self.intermediate_act_fn(hidden_states)
303
- return hidden_states
304
-
305
-
306
- class BertOutput(nn.Module):
307
- def __init__(self, config):
308
- super().__init__()
309
- self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
-
313
- def forward(self, hidden_states, input_tensor):
314
- hidden_states = self.dense(hidden_states)
315
- hidden_states = self.dropout(hidden_states)
316
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
- return hidden_states
318
-
319
-
320
- class BertLayer(nn.Module):
321
- def __init__(self, config, layer_num):
322
- super().__init__()
323
- self.config = config
324
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
- self.seq_len_dim = 1
326
- self.attention = BertAttention(config)
327
- self.layer_num = layer_num
328
- if self.config.add_cross_attention:
329
- self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
- self.intermediate = BertIntermediate(config)
331
- self.output = BertOutput(config)
332
-
333
- def forward(
334
- self,
335
- hidden_states,
336
- attention_mask=None,
337
- head_mask=None,
338
- encoder_hidden_states=None,
339
- encoder_attention_mask=None,
340
- past_key_value=None,
341
- output_attentions=False,
342
- mode=None,
343
- ):
344
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
- self_attention_outputs = self.attention(
347
- hidden_states,
348
- attention_mask,
349
- head_mask,
350
- output_attentions=output_attentions,
351
- past_key_value=self_attn_past_key_value,
352
- )
353
- attention_output = self_attention_outputs[0]
354
-
355
- outputs = self_attention_outputs[1:-1]
356
- present_key_value = self_attention_outputs[-1]
357
-
358
- if mode=='multimodal':
359
- assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
-
361
- cross_attention_outputs = self.crossattention(
362
- attention_output,
363
- attention_mask,
364
- head_mask,
365
- encoder_hidden_states,
366
- encoder_attention_mask,
367
- output_attentions=output_attentions,
368
- )
369
- attention_output = cross_attention_outputs[0]
370
- outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
- layer_output = apply_chunking_to_forward(
372
- self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
- )
374
- outputs = (layer_output,) + outputs
375
-
376
- outputs = outputs + (present_key_value,)
377
-
378
- return outputs
379
-
380
- def feed_forward_chunk(self, attention_output):
381
- intermediate_output = self.intermediate(attention_output)
382
- layer_output = self.output(intermediate_output, attention_output)
383
- return layer_output
384
-
385
-
386
- class BertEncoder(nn.Module):
387
- def __init__(self, config):
388
- super().__init__()
389
- self.config = config
390
- self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
- self.gradient_checkpointing = False
392
-
393
- def forward(
394
- self,
395
- hidden_states,
396
- attention_mask=None,
397
- head_mask=None,
398
- encoder_hidden_states=None,
399
- encoder_attention_mask=None,
400
- past_key_values=None,
401
- use_cache=None,
402
- output_attentions=False,
403
- output_hidden_states=False,
404
- return_dict=True,
405
- mode='multimodal',
406
- ):
407
- all_hidden_states = () if output_hidden_states else None
408
- all_self_attentions = () if output_attentions else None
409
- all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
-
411
- next_decoder_cache = () if use_cache else None
412
-
413
- for i in range(self.config.num_hidden_layers):
414
- layer_module = self.layer[i]
415
- if output_hidden_states:
416
- all_hidden_states = all_hidden_states + (hidden_states,)
417
-
418
- layer_head_mask = head_mask[i] if head_mask is not None else None
419
- past_key_value = past_key_values[i] if past_key_values is not None else None
420
-
421
- if self.gradient_checkpointing and self.training:
422
-
423
- if use_cache:
424
- logger.warn(
425
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
- )
427
- use_cache = False
428
-
429
- def create_custom_forward(module):
430
- def custom_forward(*inputs):
431
- return module(*inputs, past_key_value, output_attentions)
432
-
433
- return custom_forward
434
-
435
- layer_outputs = torch.utils.checkpoint.checkpoint(
436
- create_custom_forward(layer_module),
437
- hidden_states,
438
- attention_mask,
439
- layer_head_mask,
440
- encoder_hidden_states,
441
- encoder_attention_mask,
442
- mode=mode,
443
- )
444
- else:
445
- layer_outputs = layer_module(
446
- hidden_states,
447
- attention_mask,
448
- layer_head_mask,
449
- encoder_hidden_states,
450
- encoder_attention_mask,
451
- past_key_value,
452
- output_attentions,
453
- mode=mode,
454
- )
455
-
456
- hidden_states = layer_outputs[0]
457
- if use_cache:
458
- next_decoder_cache += (layer_outputs[-1],)
459
- if output_attentions:
460
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
-
462
- if output_hidden_states:
463
- all_hidden_states = all_hidden_states + (hidden_states,)
464
-
465
- if not return_dict:
466
- return tuple(
467
- v
468
- for v in [
469
- hidden_states,
470
- next_decoder_cache,
471
- all_hidden_states,
472
- all_self_attentions,
473
- all_cross_attentions,
474
- ]
475
- if v is not None
476
- )
477
- return BaseModelOutputWithPastAndCrossAttentions(
478
- last_hidden_state=hidden_states,
479
- past_key_values=next_decoder_cache,
480
- hidden_states=all_hidden_states,
481
- attentions=all_self_attentions,
482
- cross_attentions=all_cross_attentions,
483
- )
484
-
485
-
486
- class BertPooler(nn.Module):
487
- def __init__(self, config):
488
- super().__init__()
489
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
- self.activation = nn.Tanh()
491
-
492
- def forward(self, hidden_states):
493
- # We "pool" the model by simply taking the hidden state corresponding
494
- # to the first token.
495
- first_token_tensor = hidden_states[:, 0]
496
- pooled_output = self.dense(first_token_tensor)
497
- pooled_output = self.activation(pooled_output)
498
- return pooled_output
499
-
500
-
501
- class BertPredictionHeadTransform(nn.Module):
502
- def __init__(self, config):
503
- super().__init__()
504
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
- if isinstance(config.hidden_act, str):
506
- self.transform_act_fn = ACT2FN[config.hidden_act]
507
- else:
508
- self.transform_act_fn = config.hidden_act
509
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
-
511
- def forward(self, hidden_states):
512
- hidden_states = self.dense(hidden_states)
513
- hidden_states = self.transform_act_fn(hidden_states)
514
- hidden_states = self.LayerNorm(hidden_states)
515
- return hidden_states
516
-
517
-
518
- class BertLMPredictionHead(nn.Module):
519
- def __init__(self, config):
520
- super().__init__()
521
- self.transform = BertPredictionHeadTransform(config)
522
-
523
- # The output weights are the same as the input embeddings, but there is
524
- # an output-only bias for each token.
525
- self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
-
527
- self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
-
529
- # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
- self.decoder.bias = self.bias
531
-
532
- def forward(self, hidden_states):
533
- hidden_states = self.transform(hidden_states)
534
- hidden_states = self.decoder(hidden_states)
535
- return hidden_states
536
-
537
-
538
- class BertOnlyMLMHead(nn.Module):
539
- def __init__(self, config):
540
- super().__init__()
541
- self.predictions = BertLMPredictionHead(config)
542
-
543
- def forward(self, sequence_output):
544
- prediction_scores = self.predictions(sequence_output)
545
- return prediction_scores
546
-
547
-
548
- class BertPreTrainedModel(PreTrainedModel):
549
- """
550
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
- models.
552
- """
553
-
554
- config_class = BertConfig
555
- base_model_prefix = "bert"
556
- _keys_to_ignore_on_load_missing = [r"position_ids"]
557
-
558
- def _init_weights(self, module):
559
- """ Initialize the weights """
560
- if isinstance(module, (nn.Linear, nn.Embedding)):
561
- # Slightly different from the TF version which uses truncated_normal for initialization
562
- # cf https://github.com/pytorch/pytorch/pull/5617
563
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
- elif isinstance(module, nn.LayerNorm):
565
- module.bias.data.zero_()
566
- module.weight.data.fill_(1.0)
567
- if isinstance(module, nn.Linear) and module.bias is not None:
568
- module.bias.data.zero_()
569
-
570
-
571
- class BertModel(BertPreTrainedModel):
572
- """
573
- The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
- cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
- all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
- Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
- argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
- input to the forward pass.
579
- """
580
-
581
- def __init__(self, config, add_pooling_layer=True):
582
- super().__init__(config)
583
- self.config = config
584
-
585
- self.embeddings = BertEmbeddings(config)
586
-
587
- self.encoder = BertEncoder(config)
588
-
589
- self.pooler = BertPooler(config) if add_pooling_layer else None
590
-
591
- self.init_weights()
592
-
593
-
594
- def get_input_embeddings(self):
595
- return self.embeddings.word_embeddings
596
-
597
- def set_input_embeddings(self, value):
598
- self.embeddings.word_embeddings = value
599
-
600
- def _prune_heads(self, heads_to_prune):
601
- """
602
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
- class PreTrainedModel
604
- """
605
- for layer, heads in heads_to_prune.items():
606
- self.encoder.layer[layer].attention.prune_heads(heads)
607
-
608
-
609
- def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
- """
611
- Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
- Arguments:
613
- attention_mask (:obj:`torch.Tensor`):
614
- Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
615
- input_shape (:obj:`Tuple[int]`):
616
- The shape of the input to the model.
617
- device: (:obj:`torch.device`):
618
- The device of the input to the model.
619
- Returns:
620
- :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
621
- """
622
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
623
- # ourselves in which case we just need to make it broadcastable to all heads.
624
- if attention_mask.dim() == 3:
625
- extended_attention_mask = attention_mask[:, None, :, :]
626
- elif attention_mask.dim() == 2:
627
- # Provided a padding mask of dimensions [batch_size, seq_length]
628
- # - if the model is a decoder, apply a causal mask in addition to the padding mask
629
- # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
630
- if is_decoder:
631
- batch_size, seq_length = input_shape
632
-
633
- seq_ids = torch.arange(seq_length, device=device)
634
- causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
635
- # in case past_key_values are used we need to add a prefix ones mask to the causal mask
636
- # causal and attention masks must have same type with pytorch version < 1.3
637
- causal_mask = causal_mask.to(attention_mask.dtype)
638
-
639
- if causal_mask.shape[1] < attention_mask.shape[1]:
640
- prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
641
- causal_mask = torch.cat(
642
- [
643
- torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
644
- causal_mask,
645
- ],
646
- axis=-1,
647
- )
648
-
649
- extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
650
- else:
651
- extended_attention_mask = attention_mask[:, None, None, :]
652
- else:
653
- raise ValueError(
654
- "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
655
- input_shape, attention_mask.shape
656
- )
657
- )
658
-
659
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
660
- # masked positions, this operation will create a tensor which is 0.0 for
661
- # positions we want to attend and -10000.0 for masked positions.
662
- # Since we are adding it to the raw scores before the softmax, this is
663
- # effectively the same as removing these entirely.
664
- extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
665
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
666
- return extended_attention_mask
667
-
668
- def forward(
669
- self,
670
- input_ids=None,
671
- attention_mask=None,
672
- position_ids=None,
673
- head_mask=None,
674
- inputs_embeds=None,
675
- encoder_embeds=None,
676
- encoder_hidden_states=None,
677
- encoder_attention_mask=None,
678
- past_key_values=None,
679
- use_cache=None,
680
- output_attentions=None,
681
- output_hidden_states=None,
682
- return_dict=None,
683
- is_decoder=False,
684
- mode='multimodal',
685
- ):
686
- r"""
687
- encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
688
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
689
- the model is configured as a decoder.
690
- encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
691
- Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
692
- the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
693
- - 1 for tokens that are **not masked**,
694
- - 0 for tokens that are **masked**.
695
- past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
696
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
697
- If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
698
- (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
699
- instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
700
- use_cache (:obj:`bool`, `optional`):
701
- If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
702
- decoding (see :obj:`past_key_values`).
703
- """
704
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
705
- output_hidden_states = (
706
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
707
- )
708
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
709
-
710
- if is_decoder:
711
- use_cache = use_cache if use_cache is not None else self.config.use_cache
712
- else:
713
- use_cache = False
714
-
715
- if input_ids is not None and inputs_embeds is not None:
716
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
717
- elif input_ids is not None:
718
- input_shape = input_ids.size()
719
- batch_size, seq_length = input_shape
720
- device = input_ids.device
721
- elif inputs_embeds is not None:
722
- input_shape = inputs_embeds.size()[:-1]
723
- batch_size, seq_length = input_shape
724
- device = inputs_embeds.device
725
- elif encoder_embeds is not None:
726
- input_shape = encoder_embeds.size()[:-1]
727
- batch_size, seq_length = input_shape
728
- device = encoder_embeds.device
729
- else:
730
- raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
731
-
732
- # past_key_values_length
733
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
734
-
735
- if attention_mask is None:
736
- attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
737
-
738
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
739
- # ourselves in which case we just need to make it broadcastable to all heads.
740
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
741
- device, is_decoder)
742
-
743
- # If a 2D or 3D attention mask is provided for the cross-attention
744
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
745
- if encoder_hidden_states is not None:
746
- if type(encoder_hidden_states) == list:
747
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
748
- else:
749
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
750
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
751
-
752
- if type(encoder_attention_mask) == list:
753
- encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
754
- elif encoder_attention_mask is None:
755
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
756
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
757
- else:
758
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
- else:
760
- encoder_extended_attention_mask = None
761
-
762
- # Prepare head mask if needed
763
- # 1.0 in head_mask indicate we keep the head
764
- # attention_probs has shape bsz x n_heads x N x N
765
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
766
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
767
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
768
-
769
- if encoder_embeds is None:
770
- embedding_output = self.embeddings(
771
- input_ids=input_ids,
772
- position_ids=position_ids,
773
- inputs_embeds=inputs_embeds,
774
- past_key_values_length=past_key_values_length,
775
- )
776
- else:
777
- embedding_output = encoder_embeds
778
-
779
- encoder_outputs = self.encoder(
780
- embedding_output,
781
- attention_mask=extended_attention_mask,
782
- head_mask=head_mask,
783
- encoder_hidden_states=encoder_hidden_states,
784
- encoder_attention_mask=encoder_extended_attention_mask,
785
- past_key_values=past_key_values,
786
- use_cache=use_cache,
787
- output_attentions=output_attentions,
788
- output_hidden_states=output_hidden_states,
789
- return_dict=return_dict,
790
- mode=mode,
791
- )
792
- sequence_output = encoder_outputs[0]
793
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
794
-
795
- if not return_dict:
796
- return (sequence_output, pooled_output) + encoder_outputs[1:]
797
-
798
- return BaseModelOutputWithPoolingAndCrossAttentions(
799
- last_hidden_state=sequence_output,
800
- pooler_output=pooled_output,
801
- past_key_values=encoder_outputs.past_key_values,
802
- hidden_states=encoder_outputs.hidden_states,
803
- attentions=encoder_outputs.attentions,
804
- cross_attentions=encoder_outputs.cross_attentions,
805
- )
806
-
807
-
808
-
809
- class BertLMHeadModel(BertPreTrainedModel):
810
-
811
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
812
- _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
813
-
814
- def __init__(self, config):
815
- super().__init__(config)
816
-
817
- self.bert = BertModel(config, add_pooling_layer=False)
818
- self.cls = BertOnlyMLMHead(config)
819
-
820
- self.init_weights()
821
-
822
- def get_output_embeddings(self):
823
- return self.cls.predictions.decoder
824
-
825
- def set_output_embeddings(self, new_embeddings):
826
- self.cls.predictions.decoder = new_embeddings
827
-
828
- def forward(
829
- self,
830
- input_ids=None,
831
- attention_mask=None,
832
- position_ids=None,
833
- head_mask=None,
834
- inputs_embeds=None,
835
- encoder_hidden_states=None,
836
- encoder_attention_mask=None,
837
- labels=None,
838
- past_key_values=None,
839
- use_cache=None,
840
- output_attentions=None,
841
- output_hidden_states=None,
842
- return_dict=None,
843
- return_logits=False,
844
- is_decoder=True,
845
- reduction='mean',
846
- mode='multimodal',
847
- ):
848
- r"""
849
- encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
850
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
851
- the model is configured as a decoder.
852
- encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
853
- Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
854
- the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
855
- - 1 for tokens that are **not masked**,
856
- - 0 for tokens that are **masked**.
857
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
858
- Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
859
- ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
860
- ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
861
- past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
862
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
863
- If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
864
- (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
865
- instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
866
- use_cache (:obj:`bool`, `optional`):
867
- If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
868
- decoding (see :obj:`past_key_values`).
869
- Returns:
870
- Example::
871
- >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
872
- >>> import torch
873
- >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
874
- >>> config = BertConfig.from_pretrained("bert-base-cased")
875
- >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
876
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
877
- >>> outputs = model(**inputs)
878
- >>> prediction_logits = outputs.logits
879
- """
880
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
881
- if labels is not None:
882
- use_cache = False
883
-
884
- outputs = self.bert(
885
- input_ids,
886
- attention_mask=attention_mask,
887
- position_ids=position_ids,
888
- head_mask=head_mask,
889
- inputs_embeds=inputs_embeds,
890
- encoder_hidden_states=encoder_hidden_states,
891
- encoder_attention_mask=encoder_attention_mask,
892
- past_key_values=past_key_values,
893
- use_cache=use_cache,
894
- output_attentions=output_attentions,
895
- output_hidden_states=output_hidden_states,
896
- return_dict=return_dict,
897
- is_decoder=is_decoder,
898
- mode=mode,
899
- )
900
-
901
- sequence_output = outputs[0]
902
- prediction_scores = self.cls(sequence_output)
903
-
904
- if return_logits:
905
- return prediction_scores[:, :-1, :].contiguous()
906
-
907
- lm_loss = None
908
- if labels is not None:
909
- # we are doing next-token prediction; shift prediction scores and input ids by one
910
- shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
911
- labels = labels[:, 1:].contiguous()
912
- loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
913
- lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
914
- if reduction=='none':
915
- lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
916
-
917
- if not return_dict:
918
- output = (prediction_scores,) + outputs[2:]
919
- return ((lm_loss,) + output) if lm_loss is not None else output
920
-
921
- return CausalLMOutputWithCrossAttentions(
922
- loss=lm_loss,
923
- logits=prediction_scores,
924
- past_key_values=outputs.past_key_values,
925
- hidden_states=outputs.hidden_states,
926
- attentions=outputs.attentions,
927
- cross_attentions=outputs.cross_attentions,
928
- )
929
-
930
- def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
931
- input_shape = input_ids.shape
932
- # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
933
- if attention_mask is None:
934
- attention_mask = input_ids.new_ones(input_shape)
935
-
936
- # cut decoder_input_ids if past is used
937
- if past is not None:
938
- input_ids = input_ids[:, -1:]
939
-
940
- return {
941
- "input_ids": input_ids,
942
- "attention_mask": attention_mask,
943
- "past_key_values": past,
944
- "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
945
- "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
946
- "is_decoder": True,
947
- }
948
-
949
- def _reorder_cache(self, past, beam_idx):
950
- reordered_past = ()
951
- for layer_past in past:
952
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
953
- return reordered_past
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/vit.py DELETED
@@ -1,305 +0,0 @@
- '''
- * Copyright (c) 2022, salesforce.com, inc.
- * All rights reserved.
- * SPDX-License-Identifier: BSD-3-Clause
- * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
- * By Junnan Li
- * Based on timm code base
- * https://github.com/rwightman/pytorch-image-models/tree/master/timm
- '''
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from functools import partial
-
- from timm.models.vision_transformer import _cfg, PatchEmbed
- from timm.models.registry import register_model
- from timm.models.layers import trunc_normal_, DropPath
- from timm.models.helpers import named_apply, adapt_input_conv
-
- from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-
- class Mlp(nn.Module):
-     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
-     """
-     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
-         super().__init__()
-         out_features = out_features or in_features
-         hidden_features = hidden_features or in_features
-         self.fc1 = nn.Linear(in_features, hidden_features)
-         self.act = act_layer()
-         self.fc2 = nn.Linear(hidden_features, out_features)
-         self.drop = nn.Dropout(drop)
-
-     def forward(self, x):
-         x = self.fc1(x)
-         x = self.act(x)
-         x = self.drop(x)
-         x = self.fc2(x)
-         x = self.drop(x)
-         return x
-
-
- class Attention(nn.Module):
-     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
-         super().__init__()
-         self.num_heads = num_heads
-         head_dim = dim // num_heads
-         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
-         self.scale = qk_scale or head_dim ** -0.5
-         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-         self.attn_drop = nn.Dropout(attn_drop)
-         self.proj = nn.Linear(dim, dim)
-         self.proj_drop = nn.Dropout(proj_drop)
-         self.attn_gradients = None
-         self.attention_map = None
-
-     def save_attn_gradients(self, attn_gradients):
-         self.attn_gradients = attn_gradients
-
-     def get_attn_gradients(self):
-         return self.attn_gradients
-
-     def save_attention_map(self, attention_map):
-         self.attention_map = attention_map
-
-     def get_attention_map(self):
-         return self.attention_map
-
-     def forward(self, x, register_hook=False):
-         B, N, C = x.shape
-         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-         q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
-
-         attn = (q @ k.transpose(-2, -1)) * self.scale
-         attn = attn.softmax(dim=-1)
-         attn = self.attn_drop(attn)
-
-         if register_hook:
-             self.save_attention_map(attn)
-             attn.register_hook(self.save_attn_gradients)
-
-         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-         x = self.proj(x)
-         x = self.proj_drop(x)
-         return x
-
-
- class Block(nn.Module):
-
-     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
-                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
-         super().__init__()
-         self.norm1 = norm_layer(dim)
-         self.attn = Attention(
-             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
-         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-         self.norm2 = norm_layer(dim)
-         mlp_hidden_dim = int(dim * mlp_ratio)
-         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
-         if use_grad_checkpointing:
-             self.attn = checkpoint_wrapper(self.attn)
-             self.mlp = checkpoint_wrapper(self.mlp)
-
-     def forward(self, x, register_hook=False):
-         x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
-         x = x + self.drop_path(self.mlp(self.norm2(x)))
-         return x
-
-
- class VisionTransformer(nn.Module):
-     """ Vision Transformer
-     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
-         https://arxiv.org/abs/2010.11929
-     """
-     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                  num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
-                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
-                  use_grad_checkpointing=False, ckpt_layer=0):
-         """
-         Args:
-             img_size (int, tuple): input image size
-             patch_size (int, tuple): patch size
-             in_chans (int): number of input channels
-             num_classes (int): number of classes for classification head
-             embed_dim (int): embedding dimension
-             depth (int): depth of transformer
-             num_heads (int): number of attention heads
-             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
-             qkv_bias (bool): enable bias for qkv if True
-             qk_scale (float): override default qk scale of head_dim ** -0.5 if set
-             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
-             drop_rate (float): dropout rate
-             attn_drop_rate (float): attention dropout rate
-             drop_path_rate (float): stochastic depth rate
-             norm_layer: (nn.Module): normalization layer
-         """
-         super().__init__()
-         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
-
-         self.patch_embed = PatchEmbed(
-             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
-
-         num_patches = self.patch_embed.num_patches
-
-         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
-         self.pos_drop = nn.Dropout(p=drop_rate)
-
-         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
-         self.blocks = nn.ModuleList([
-             Block(
-                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
-                 use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
-             )
-             for i in range(depth)])
-         self.norm = norm_layer(embed_dim)
-
-         trunc_normal_(self.pos_embed, std=.02)
-         trunc_normal_(self.cls_token, std=.02)
-         self.apply(self._init_weights)
-
-     def _init_weights(self, m):
-         if isinstance(m, nn.Linear):
-             trunc_normal_(m.weight, std=.02)
-             if isinstance(m, nn.Linear) and m.bias is not None:
-                 nn.init.constant_(m.bias, 0)
-         elif isinstance(m, nn.LayerNorm):
-             nn.init.constant_(m.bias, 0)
-             nn.init.constant_(m.weight, 1.0)
-
-     @torch.jit.ignore
-     def no_weight_decay(self):
-         return {'pos_embed', 'cls_token'}
-
-     def forward(self, x, register_blk=-1):
-         B = x.shape[0]
-         x = self.patch_embed(x)
-
-         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-         x = torch.cat((cls_tokens, x), dim=1)
-
-         x = x + self.pos_embed[:,:x.size(1),:]
-         x = self.pos_drop(x)
-
-         for i,blk in enumerate(self.blocks):
-             x = blk(x, register_blk==i)
-         x = self.norm(x)
-
-         return x
-
-     @torch.jit.ignore()
-     def load_pretrained(self, checkpoint_path, prefix=''):
-         _load_weights(self, checkpoint_path, prefix)
-
-
- @torch.no_grad()
- def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
-     """ Load weights from .npz checkpoints for official Google Brain Flax implementation
-     """
-     import numpy as np
-
-     def _n2p(w, t=True):
-         if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
-             w = w.flatten()
-         if t:
-             if w.ndim == 4:
-                 w = w.transpose([3, 2, 0, 1])
-             elif w.ndim == 3:
-                 w = w.transpose([2, 0, 1])
-             elif w.ndim == 2:
-                 w = w.transpose([1, 0])
-         return torch.from_numpy(w)
-
-     w = np.load(checkpoint_path)
-     if not prefix and 'opt/target/embedding/kernel' in w:
-         prefix = 'opt/target/'
-
-     if hasattr(model.patch_embed, 'backbone'):
-         # hybrid
-         backbone = model.patch_embed.backbone
-         stem_only = not hasattr(backbone, 'stem')
-         stem = backbone if stem_only else backbone.stem
-         stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
-         stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
-         stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
-         if not stem_only:
-             for i, stage in enumerate(backbone.stages):
-                 for j, block in enumerate(stage.blocks):
-                     bp = f'{prefix}block{i + 1}/unit{j + 1}/'
-                     for r in range(3):
-                         getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
-                         getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
-                         getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
-                     if block.downsample is not None:
-                         block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
-                         block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
-                         block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
-         embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
-     else:
-         embed_conv_w = adapt_input_conv(
-             model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
-     model.patch_embed.proj.weight.copy_(embed_conv_w)
-     model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
-     model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
-     pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
-     if pos_embed_w.shape != model.pos_embed.shape:
-         pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
-             pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
-     model.pos_embed.copy_(pos_embed_w)
-     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
-     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-     # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
-     #     model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
-     #     model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
-     # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
-     #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
-     #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
-     for i, block in enumerate(model.blocks.children()):
-         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
-         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
-         block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-         block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
-         block.attn.qkv.weight.copy_(torch.cat([
-             _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
-         block.attn.qkv.bias.copy_(torch.cat([
-             _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-         block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-         block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-         for r in range(2):
-             getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
-             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
-         block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
-         block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
-
-
- def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
-     # interpolate position embedding
-     embedding_size = pos_embed_checkpoint.shape[-1]
-     num_patches = visual_encoder.patch_embed.num_patches
-     num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
-     # height (== width) for the checkpoint position embedding
-     orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-     # height (== width) for the new position embedding
-     new_size = int(num_patches ** 0.5)
-
-     if orig_size!=new_size:
-         # class_token and dist_token are kept unchanged
-         extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
-         # only the position tokens are interpolated
-         pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
-         pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
-         pos_tokens = torch.nn.functional.interpolate(
-             pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
-         pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
-         new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-         print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
-
-         return new_pos_embed
-     else:
-         return pos_embed_checkpoint
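
Note: the deleted interpolate_pos_embed above adapts a checkpoint trained at one resolution to a model built at another: the class token is kept unchanged, while the patch position embeddings are reshaped to a 2-D grid, resized with bicubic interpolation, and flattened back. A minimal sketch of the same resizing for a hypothetical ViT-B/16 going from 224x224 to 384x384 input (the shapes are illustrative assumptions, not taken from this commit):

import torch

# Checkpoint trained at 224x224: 1 cls token + 14*14 patch tokens, dim 768.
ckpt_pos_embed = torch.zeros(1, 1 + (224 // 16) ** 2, 768)   # (1, 197, 768)
# New model built at 384x384: 1 cls token + 24*24 patch tokens, dim 768.
target_shape = (1, 1 + (384 // 16) ** 2, 768)                # (1, 577, 768)

num_extra_tokens = 1                                                    # just the cls token
orig_size = int((ckpt_pos_embed.shape[1] - num_extra_tokens) ** 0.5)    # 14
new_size = int((target_shape[1] - num_extra_tokens) ** 0.5)             # 24

extra_tokens = ckpt_pos_embed[:, :num_extra_tokens]
pos_tokens = ckpt_pos_embed[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(1, orig_size, orig_size, -1).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
assert new_pos_embed.shape == target_shape                   # ready to copy into the model's pos_embed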