xinyu1205 committed on
Commit
8e6dc9f
1 Parent(s): 8c25077

Create utils.py

Files changed (1)
  1. models/utils.py +278 -0
models/utils.py ADDED
@@ -0,0 +1,278 @@
import os
import json
import torch
import math

from torch import nn
from typing import List
from transformers import BertTokenizer
from transformers.utils import logging
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from models.vit import VisionTransformer, interpolate_pos_embed
from models.swin_transformer import interpolate_relative_pos_embed
from pathlib import Path

# Repository root (one level above models/), used to locate config files.
CONFIG_PATH = Path(__file__).resolve().parents[1]

# Module-level logger used by tie_encoder_decoder_weights below.
logger = logging.get_logger(__name__)


def read_json(rpath):
    with open(rpath, 'r') as f:
        return json.load(f)

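For orientation, a small usage sketch (the exact config path is an assumption, mirrored from how the Swin loaders below build it): CONFIG_PATH resolves to the repository root, so config files are read relative to the checkout.

    # Hypothetical example: read a Swin config the way the loaders below do.
    cfg = read_json(f'{CONFIG_PATH}/configs/swin/config_swinB_224.json')
    window_size = cfg['window_size']
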
def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module,
                                base_model_prefix: str, skip_key: str):
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        logger.info(
            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
        )

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        skip_key: str,
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
        if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
            assert hasattr(encoder_pointer, "weight")
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, "bias"):
                assert hasattr(encoder_pointer, "bias")
                encoder_pointer.bias = decoder_pointer.bias
            print(module_name + ' is tied')
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if len(decoder_modules) > 0:
            assert (
                len(encoder_modules) > 0
            ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

            all_encoder_weights = set([
                module_name + "/" + sub_name
                for sub_name in encoder_modules.keys()
            ])
            encoder_layer_pos = 0
            for name, module in decoder_modules.items():
                if name.isdigit():
                    encoder_name = str(int(name) + encoder_layer_pos)
                    decoder_name = name
                    if not isinstance(
                            decoder_modules[decoder_name],
                            type(encoder_modules[encoder_name])) and len(
                                encoder_modules) != len(decoder_modules):
                        # This can happen when the name is a position in a module
                        # list of layers: the decoder has an extra cross-attention
                        # layer the encoder lacks, so skip it and shift the
                        # encoder position back by one.
                        encoder_layer_pos -= 1
                        continue
                elif name not in encoder_modules:
                    continue
                elif depth > 500:
                    raise ValueError(
                        "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
                    )
                else:
                    decoder_name = encoder_name = name
                tie_encoder_to_decoder_recursively(
                    decoder_modules[decoder_name],
                    encoder_modules[encoder_name],
                    module_name + "/" + name,
                    uninitialized_encoder_weights,
                    skip_key,
                    depth=depth + 1,
                )
                all_encoder_weights.remove(module_name + "/" + encoder_name)

            uninitialized_encoder_weights += list(all_encoder_weights)

    # tie weights recursively
    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix,
                                       uninitialized_encoder_weights, skip_key)

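A hedged usage sketch (the module names are illustrative, not defined in this file): tying a BERT-style decoder's weights onto an encoder of the same class, with skip_key naming the substring of decoder-only layers to leave untied.

    # Illustrative only: text_encoder and text_decoder.bert stand in for two
    # modules of the same architecture; '/attention' is an assumed skip_key.
    tie_encoder_decoder_weights(text_encoder, text_decoder.bert,
                                base_model_prefix='',
                                skip_key='/attention')
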
class GroupWiseLinear(nn.Module):
    # could be changed to:
    # output = torch.einsum('ijk,zjk->ij', x, self.W)
    # or output = torch.einsum('ijk,jk->ij', x, self.W[0])
    def __init__(self, num_class, hidden_dim, bias=True):
        super().__init__()
        self.num_class = num_class
        self.hidden_dim = hidden_dim
        self.bias = bias

        self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
        if bias:
            self.b = nn.Parameter(torch.Tensor(1, num_class))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(2))
        for i in range(self.num_class):
            self.W[0][i].data.uniform_(-stdv, stdv)
        if self.bias:
            for i in range(self.num_class):
                self.b[0][i].data.uniform_(-stdv, stdv)

    def forward(self, x):
        # x: B,K,d
        x = (self.W * x).sum(-1)
        if self.bias:
            x = x + self.b
        return x

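A minimal sketch of the head in use: row i of W is dotted with the i-th per-class feature, so a (B, num_class, hidden_dim) input yields (B, num_class) logits.

    # Sketch: one logit per class from per-class query features.
    head = GroupWiseLinear(num_class=10, hidden_dim=768)
    feats = torch.randn(2, 10, 768)   # (batch, num_class, hidden_dim)
    logits = head(feats)              # shape (2, 10)
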
def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer

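A quick sketch of what the returned tokenizer carries (downloads bert-base-uncased on first use, then reads from the local cache):

    tokenizer = init_tokenizer()
    print(tokenizer.bos_token)       # '[DEC]'
    print(tokenizer.enc_token_id)    # id assigned to '[ENC]'
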
def create_vit(vit,
               image_size,
               use_grad_checkpointing=False,
               ckpt_layer=0,
               drop_path_rate=0):

    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(
            img_size=image_size,
            patch_size=16,
            embed_dim=vision_width,
            depth=12,
            num_heads=12,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            drop_path_rate=drop_path_rate)
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(
            img_size=image_size,
            patch_size=16,
            embed_dim=vision_width,
            depth=24,
            num_heads=16,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            # ViT-L defaults to stochastic depth 0.1 when no rate is given.
            drop_path_rate=drop_path_rate or 0.1)
    return visual_encoder, vision_width

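A minimal sketch: building the base encoder and using vision_width to size a downstream module.

    # ViT-B/16 over 224x224 inputs; vision_width (768) is the embedding dim.
    visual_encoder, vision_width = create_vit('base', image_size=224)
    head = nn.Linear(vision_width, 10)  # hypothetical downstream head
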
def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename,
                                           check_hash=False,
                                           progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
        state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
    # Drop any checkpoint tensors whose shapes do not match the model.
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg

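A hedged usage sketch (the checkpoint path is hypothetical): msg is the NamedTuple returned by load_state_dict(strict=False), so missing and unexpected keys can be inspected after loading.

    # Hypothetical path; shape-mismatched tensors were already dropped above.
    model, msg = load_checkpoint(model, 'path/to/model_base.pth')
    print(msg.missing_keys)
    print(msg.unexpected_keys)
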
def load_checkpoint_swinbase(model, url_or_filename, kwargs):
    if kwargs['image_size'] == 224:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
    elif kwargs['image_size'] == 384:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
    else:
        raise ValueError('image_size must be 224 or 384 for Swin-B checkpoints')
    window_size = read_json(vision_config_path)['window_size']
    print('--------------')
    print(url_or_filename)
    print('--------------')
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename,
                                           check_hash=False,
                                           progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    for k in list(state_dict.keys()):
        if 'relative_position_bias_table' in k:
            # Interpolate relative position bias tables to the target window size.
            dst_num_pos = (2 * window_size - 1)**2
            state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
                                                           dst_num_pos,
                                                           param_name=k)
        elif ('relative_position_index' in k) or ('attn_mask' in k):
            del state_dict[k]
        elif "vision_multi" in k:
            state_dict[k.replace("vision_multi",
                                 "tagging_head")] = state_dict.pop(k)

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg

def load_checkpoint_swinlarge(model, url_or_filename, kwargs):
    if kwargs['image_size'] == 224:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
    elif kwargs['image_size'] == 384:
        vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
    else:
        raise ValueError('image_size must be 224 or 384 for Swin-L checkpoints')
    window_size = read_json(vision_config_path)['window_size']
    print('--------------')
    print(url_or_filename)
    print('--------------')
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename,
                                           check_hash=False,
                                           progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    for k in list(state_dict.keys()):
        if 'relative_position_bias_table' in k:
            # Interpolate relative position bias tables to the target window size.
            dst_num_pos = (2 * window_size - 1)**2
            state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
                                                           dst_num_pos,
                                                           param_name=k)
        elif ('relative_position_index' in k) or ('attn_mask' in k):
            del state_dict[k]
        elif "vision_multi" in k:
            state_dict[k.replace("vision_multi",
                                 "tagging_head")] = state_dict.pop(k)

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
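
A hedged usage sketch for the Swin loaders (path and model are hypothetical): both expect the input resolution in kwargs so the matching window_size config can be read.

    # Hypothetical checkpoint; 'image_size' selects config_swinL_384.json here.
    model, msg = load_checkpoint_swinlarge(model, 'path/to/swin_large.pth',
                                           {'image_size': 384})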