fffiloni committed on
Commit
c197497
1 Parent(s): afe4c7a

Upload 3 files

xdecoder/language/LangEncoder/build.py ADDED
@@ -0,0 +1,36 @@
+ import os
+
+ from transformers import CLIPTokenizer, CLIPTokenizerFast
+ from transformers import AutoTokenizer
+
+ from .registry import lang_encoders
+ from .registry import is_lang_encoder
+
+
+ def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
+     model_name = config_encoder['NAME']
+
+     if not is_lang_encoder(model_name):
+         raise ValueError(f'Unknown model: {model_name}')
+
+     return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
+
+
+ def build_tokenizer(config_encoder):
+     tokenizer = None
+     os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+     if config_encoder['TOKENIZER'] == 'clip':
+         pretrained_tokenizer = config_encoder.get(
+             'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
+         )
+         tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
+         tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
+     elif config_encoder['TOKENIZER'] == 'clip-fast':
+         pretrained_tokenizer = config_encoder.get(
+             'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
+         )
+         tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
+     else:
+         tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
+
+     return tokenizer
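
For orientation, a minimal sketch of how these two helpers might be wired together. The config keys mirror the ones read above and in transformer.py, but the concrete values (the 512-wide, 6-layer settings, context length 77) are illustrative assumptions rather than values taken from this commit, the import path simply follows the directory layout shown, and the encoder module must already be imported so its @register_lang_encoder decorator has run.

# Illustrative usage sketch; config values are assumptions, not part of this commit.
from xdecoder.language.LangEncoder.build import build_tokenizer, build_lang_encoder
import xdecoder.language.LangEncoder.transformer  # noqa: F401  (runs @register_lang_encoder)

config_encoder = {
    'NAME': 'transformer',    # must match a registry key (the encoder's module name)
    'TOKENIZER': 'clip',      # 'clip', 'clip-fast', or any Hugging Face tokenizer name
    'CONTEXT_LENGTH': 77,     # keys below are consumed by the 'transformer' encoder
    'WIDTH': 512,
    'LAYERS': 6,
    'HEADS': 8,
}

tokenizer = build_tokenizer(config_encoder)        # CLIPTokenizer with cls_token set to eos_token
lang_encoder = build_lang_encoder(config_encoder, tokenizer, verbose=True)
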
xdecoder/language/LangEncoder/registry.py ADDED
@@ -0,0 +1,18 @@
+ _lang_encoders = {}
+
+
+ def register_lang_encoder(fn):
+     module_name_split = fn.__module__.split('.')
+     model_name = module_name_split[-1]
+
+     _lang_encoders[model_name] = fn
+
+     return fn
+
+
+ def lang_encoders(model_name):
+     return _lang_encoders[model_name]
+
+
+ def is_lang_encoder(model_name):
+     return model_name in _lang_encoders
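
The registry keys each entry by the defining module's name, which is how transformer.py below ends up registered as 'transformer'. As a hedged sketch, a hypothetical extra encoder module (the file name my_encoder.py and its trivial body are assumptions) would plug in like this:

# Hypothetical file xdecoder/language/LangEncoder/my_encoder.py (name is an assumption).
# register_lang_encoder stores the function under its module name, so this entry
# becomes selectable with NAME: 'my_encoder' in the encoder config.
import torch.nn as nn

from .registry import register_lang_encoder


@register_lang_encoder
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    # build.py invokes this as lang_encoders(NAME)(config_encoder, tokenizer, verbose, **kwargs),
    # so any callable returning an nn.Module with the expected interface will do.
    return nn.Embedding(tokenizer.vocab_size, config_encoder['WIDTH'])
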
xdecoder/language/LangEncoder/transformer.py ADDED
@@ -0,0 +1,222 @@
+ from collections import OrderedDict
+ from typing import Tuple, Union
+ import logging
+ import os
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from timm.models.layers import DropPath, trunc_normal_
+
+ from .registry import register_lang_encoder
+ from utils.distributed import is_main_process
+ from utils.model import register_norm_module
+
+ logger = logging.getLogger(__name__)
+
+
+ @register_norm_module
+ class LayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-12):
+         """Construct a layernorm module in the TF style (epsilon inside the square root).
+         """
+         super(LayerNorm, self).__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.bias = nn.Parameter(torch.zeros(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, x):
+         pdtype = x.dtype
+         x = x.float()
+         u = x.mean(-1, keepdim=True)
+         s = (x - u).pow(2).mean(-1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+         return self.weight * x.to(pdtype) + self.bias
+
+
+ class QuickGELU(nn.Module):
+     def forward(self, x: torch.Tensor):
+         return x * torch.sigmoid(1.702 * x)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(self,
+                  d_model: int,
+                  n_head: int,
+                  attn_mask: torch.Tensor = None,
+                  drop_path: float = 0.0):
+         super().__init__()
+
+         self.attn = nn.MultiheadAttention(d_model, n_head)
+         self.ln_1 = LayerNorm(d_model)
+         self.mlp = nn.Sequential(OrderedDict([
+             ("c_fc", nn.Linear(d_model, d_model * 4)),
+             ("gelu", QuickGELU()),
+             ("c_proj", nn.Linear(d_model * 4, d_model))
+         ]))
+         self.ln_2 = LayerNorm(d_model)
+         self.attn_mask = attn_mask
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+     def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
+             if self.attn_mask is not None else None
+
+         return self.attn(
+             x, x, x,
+             key_padding_mask=key_padding_mask,
+             need_weights=False,
+             attn_mask=self.attn_mask
+         )[0]
+
+     def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+         x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
+         x = x + self.drop_path(self.mlp(self.ln_2(x)))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(self,
+                  context_length: int,
+                  vocab_size: int,
+                  width: int,
+                  layers: int,
+                  heads: int,
+                  drop_path: float = 0.0,
+                  autogressive: bool = True):
+         super().__init__()
+
+         self.token_embedding = nn.Embedding(vocab_size, width)
+
+         self.context_length = context_length
+         self.positional_embedding = nn.Parameter(
+             torch.empty(self.context_length, width)
+         )
+
+         self.width = width
+         self.layers = layers
+         self.autogressive = autogressive
+         attn_mask = self.build_attention_mask() if autogressive else None
+         dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule
+         self.resblocks = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
+                 for i in range(layers)
+             ]
+         )
+
+         self.ln_final = LayerNorm(width)
+
+         trunc_normal_(self.positional_embedding, std=.02)
+         # nn.init.normal_(self.token_embedding, std=.02)
+         trunc_normal_(self.token_embedding.weight, std=.02)
+         self.apply(self._init_weights)
+
+     @property
+     def dim_out(self):
+         return self.width
+
+     def build_attention_mask(self):
+         # lazily create causal attention mask, with full attention between the vision tokens
+         # pytorch uses additive attention mask; fill with -inf
+         mask = torch.empty(self.context_length, self.context_length)
+         mask.fill_(float("-inf"))
+         mask.triu_(1)  # zero out the lower diagonal
+         return mask
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Linear, nn.Conv2d)):
+             if is_main_process():
+                 logger.info('=> init weight of Linear/Conv2d from trunc norm')
+             trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 if is_main_process():
+                     logger.info('=> init bias of Linear/Conv2d to zeros')
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+
+     def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
+         if os.path.isfile(pretrained):
+             pretrained_dict = torch.load(pretrained, map_location='cpu')
+             logger.info(f'=> loading pretrained model {pretrained}')
+             model_dict = self.state_dict()
+             stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x
+             pretrained_dict = {
+                 stripped_key(k): v for k, v in pretrained_dict.items()
+                 if stripped_key(k) in model_dict.keys()
+             }
+             need_init_state_dict = {}
+             for k, v in pretrained_dict.items():
+                 need_init = (
+                     k.split('.')[0] in pretrained_layers
+                     or pretrained_layers[0] == '*'
+                 )
+                 if need_init:
+                     if verbose:
+                         logger.info(f'=> init {k} from {pretrained}')
+
+                     if 'positional_embedding' in k and v.size() != model_dict[k].size():
+                         positional_embedding_pretrained = v
+                         positional_embedding_current = model_dict[k]
+                         L1, nH1 = positional_embedding_pretrained.size()
+                         L2, nH2 = positional_embedding_current.size()
+                         if nH1 != nH2:
+                             logger.info(f"Error in loading {k}, passing")
+                         else:
+                             if L1 != L2:
+                                 logger.info(
+                                     '=> load_pretrained: resized variant: {} to {}'
+                                     .format((L1, nH1), (L2, nH2))
+                                 )
+
+                                 posemb = positional_embedding_pretrained.float()
+                                 posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
+                                 posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
+                                 posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
+                                 v = posemb_grid
+
+                     need_init_state_dict[k] = v
+
+             self.load_state_dict(need_init_state_dict, strict=False)
+
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {
+             'positional_embedding',
+             'token_embedding',
+         }
+
+     def forward(self, input_ids, attention_mask=None):
+         key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None
+         # key_padding_mask = (input_ids == 0) if not self.autogressive else None
+         x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]
+         x = x + self.positional_embedding
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         for block in self.resblocks:
+             x = block(x, key_padding_mask)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         x = self.ln_final(x)
+
+         return {'last_hidden_state': x}
+
+
+ @register_lang_encoder
+ def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
+     transformer = Transformer(
+         context_length=config_encoder['CONTEXT_LENGTH'],
+         vocab_size=tokenizer.vocab_size,
+         width=config_encoder['WIDTH'],
+         layers=config_encoder['LAYERS'],
+         heads=config_encoder['HEADS'],
+         autogressive=config_encoder.get('AUTOGRESSIVE', True)
+     )
+
+     if config_encoder.get('LOAD_PRETRAINED', False):
+         transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*']))
+     return transformer
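
As a quick sanity check of the encoder's interface, a minimal sketch run from inside the repository (so the utils imports above resolve). All sizes are illustrative assumptions; 49408 is simply the CLIP vocabulary size, not anything fixed by this commit.

# Minimal shape check of the Transformer above; all sizes are illustrative assumptions.
import torch

model = Transformer(context_length=77, vocab_size=49408, width=512,
                    layers=6, heads=8, autogressive=True)
input_ids = torch.randint(0, 49408, (2, 77))   # [batch_size, n_ctx]
out = model(input_ids)                         # attention_mask is ignored in the causal setting
print(out['last_hidden_state'].shape)          # torch.Size([2, 77, 512])
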