WEBing committed on
Commit 4cbb2ad · 1 Parent(s): 388bd8b

merge eva_clip to vision_tower_builder

Files changed (45)
  1. eva_clip/model_configs/EVA02-CLIP-L-14-448.json → EVA02-CLIP-L-14-448.json +0 -0
  2. eva_clip/__init__.py +0 -11
  3. eva_clip/__pycache__/__init__.cpython-39.pyc +0 -0
  4. eva_clip/__pycache__/constants.cpython-39.pyc +0 -0
  5. eva_clip/__pycache__/eva_vit_model.cpython-39.pyc +0 -0
  6. eva_clip/__pycache__/factory.cpython-39.pyc +0 -0
  7. eva_clip/__pycache__/hf_configs.cpython-39.pyc +0 -0
  8. eva_clip/__pycache__/hf_model.cpython-39.pyc +0 -0
  9. eva_clip/__pycache__/loss.cpython-39.pyc +0 -0
  10. eva_clip/__pycache__/model.cpython-39.pyc +0 -0
  11. eva_clip/__pycache__/modified_resnet.cpython-39.pyc +0 -0
  12. eva_clip/__pycache__/openai.cpython-39.pyc +0 -0
  13. eva_clip/__pycache__/pretrained.cpython-39.pyc +0 -0
  14. eva_clip/__pycache__/rope.cpython-39.pyc +0 -0
  15. eva_clip/__pycache__/timm_model.cpython-39.pyc +0 -0
  16. eva_clip/__pycache__/tokenizer.cpython-39.pyc +0 -0
  17. eva_clip/__pycache__/transform.cpython-39.pyc +0 -0
  18. eva_clip/__pycache__/transformer.cpython-39.pyc +0 -0
  19. eva_clip/__pycache__/utils.cpython-39.pyc +0 -0
  20. eva_clip/bpe_simple_vocab_16e6.txt.gz +0 -3
  21. eva_clip/constants.py +0 -2
  22. eva_clip/factory.py +0 -459
  23. eva_clip/hf_configs.py +0 -57
  24. eva_clip/hf_model.py +0 -248
  25. eva_clip/loss.py +0 -138
  26. eva_clip/model.py +0 -439
  27. eva_clip/model_configs/EVA01-CLIP-B-16.json +0 -19
  28. eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +0 -24
  29. eva_clip/model_configs/EVA01-CLIP-g-14.json +0 -24
  30. eva_clip/model_configs/EVA02-CLIP-B-16.json +0 -29
  31. eva_clip/model_configs/EVA02-CLIP-L-14-336.json +0 -29
  32. eva_clip/model_configs/EVA02-CLIP-L-14.json +0 -29
  33. eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +0 -25
  34. eva_clip/model_configs/EVA02-CLIP-bigE-14.json +0 -25
  35. eva_clip/modified_resnet.py +0 -181
  36. eva_clip/openai.py +0 -144
  37. eva_clip/pretrained.py +0 -332
  38. eva_clip/rope.py +0 -137
  39. eva_clip/timm_model.py +0 -122
  40. eva_clip/tokenizer.py +0 -201
  41. eva_clip/transform.py +0 -103
  42. eva_clip/transformer.py +0 -737
  43. eva_clip/utils.py +0 -326
  44. modeling_kangaroo.py +10 -61
  45. eva_clip/eva_vit_model.py → vision_tower_builder.py +242 -6
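The commit collapses the eva_clip package into a single vision_tower_builder.py module, moves EVA02-CLIP-L-14-448.json to the repository root, and trims modeling_kangaroo.py accordingly. Below is a minimal sketch of how the relocated pieces would presumably fit together. It assumes vision_tower_builder.py still exports the EVAVisionTransformer class that eva_clip/eva_vit_model.py defined; the constructor arguments mirror the deleted _build_vision_tower() in eva_clip/model.py, but the exact public entry point added by this commit is not shown in the diff, so treat the import, file paths, and argument defaults as assumptions.

import json
from functools import partial

import torch
from torch import nn

# Assumed export: eva_clip/eva_vit_model.py defined EVAVisionTransformer and was
# renamed to vision_tower_builder.py, so the class should be importable from there.
from vision_tower_builder import EVAVisionTransformer

# The model config was moved from eva_clip/model_configs/ to the repository root.
with open("EVA02-CLIP-L-14-448.json", "r", encoding="utf8") as f:
    cfg = json.load(f)
v = cfg["vision_cfg"]

# Construction mirrors the deleted _build_vision_tower() in eva_clip/model.py;
# FusedLayerNorm is swapped for plain nn.LayerNorm here for simplicity.
vision_tower = EVAVisionTransformer(
    img_size=v["image_size"],
    patch_size=v["patch_size"],
    num_classes=cfg["embed_dim"],
    embed_dim=v["width"],
    depth=v["layers"],
    num_heads=v["width"] // v.get("head_width", 64),
    mlp_ratio=v.get("mlp_ratio", 4.0),
    qkv_bias=v.get("qkv_bias", True),
    drop_path_rate=v.get("drop_path_rate", 0.0),
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    xattn=v.get("xattn", False),
    rope=v.get("rope", False),
    postnorm=v.get("postnorm", False),
    pt_hw_seq_len=v.get("pt_hw_seq_len", 16),
    intp_freq=v.get("intp_freq", False),
    naiveswiglu=v.get("naiveswiglu", False),
    subln=v.get("subln", False),
)

# Dummy forward pass to sanity-check the wiring.
images = torch.randn(1, 3, v["image_size"], v["image_size"])
with torch.no_grad():
    features = vision_tower(images)
print(features.shape)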
eva_clip/model_configs/EVA02-CLIP-L-14-448.json → EVA02-CLIP-L-14-448.json RENAMED
File without changes
eva_clip/__init__.py DELETED
@@ -1,11 +0,0 @@
- from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
- from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
- from .factory import list_models, add_model_config, get_model_config, load_checkpoint
- from .loss import ClipLoss
- from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
-     convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
- from .openai import load_openai_model, list_openai_models
- from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
-     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
- from .tokenizer import SimpleTokenizer, tokenize
- from .transform import image_transform
 
eva_clip/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (1.28 kB)
 
eva_clip/__pycache__/constants.cpython-39.pyc DELETED
Binary file (313 Bytes)
 
eva_clip/__pycache__/eva_vit_model.cpython-39.pyc DELETED
Binary file (15.8 kB)
 
eva_clip/__pycache__/factory.cpython-39.pyc DELETED
Binary file (11.2 kB)
 
eva_clip/__pycache__/hf_configs.cpython-39.pyc DELETED
Binary file (714 Bytes)
 
eva_clip/__pycache__/hf_model.cpython-39.pyc DELETED
Binary file (7.38 kB)
 
eva_clip/__pycache__/loss.cpython-39.pyc DELETED
Binary file (3.32 kB)
 
eva_clip/__pycache__/model.cpython-39.pyc DELETED
Binary file (13.2 kB)
 
eva_clip/__pycache__/modified_resnet.cpython-39.pyc DELETED
Binary file (6.33 kB)
 
eva_clip/__pycache__/openai.cpython-39.pyc DELETED
Binary file (4.79 kB)
 
eva_clip/__pycache__/pretrained.cpython-39.pyc DELETED
Binary file (9.01 kB)
 
eva_clip/__pycache__/rope.cpython-39.pyc DELETED
Binary file (5.25 kB)
 
eva_clip/__pycache__/timm_model.cpython-39.pyc DELETED
Binary file (3.94 kB)
 
eva_clip/__pycache__/tokenizer.cpython-39.pyc DELETED
Binary file (8.42 kB)
 
eva_clip/__pycache__/transform.cpython-39.pyc DELETED
Binary file (2.78 kB)
 
eva_clip/__pycache__/transformer.cpython-39.pyc DELETED
Binary file (20.7 kB)
 
eva_clip/__pycache__/utils.cpython-39.pyc DELETED
Binary file (9.61 kB)
 
eva_clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
- size 1356917
 
eva_clip/constants.py DELETED
@@ -1,2 +0,0 @@
- OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
- OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
 
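For reference, the two deleted constants are the standard OpenAI CLIP normalization statistics. A minimal sketch of the equivalent image preprocessing follows, assuming the values are simply inlined wherever the transform is now built; the 448-pixel resolution is taken from the retained EVA02-CLIP-L-14-448.json and is an assumption about the intended input size.

from torchvision import transforms

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

# Roughly what the deleted eva_clip.transform.image_transform(..., is_train=False)
# produced: resize, center-crop, tensor conversion, and CLIP-style normalization.
preprocess = transforms.Compose([
    transforms.Resize(448, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])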
eva_clip/factory.py DELETED
@@ -1,459 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- import re
6
- from copy import deepcopy
7
- from pathlib import Path
8
- from typing import Optional, Tuple, Union, Dict, Any
9
- import torch
10
-
11
- from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
12
- from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
13
- get_cast_dtype
14
- from .openai import load_openai_model
15
- from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
16
- from .transform import image_transform
17
- from .tokenizer import HFTokenizer, tokenize
18
- from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
19
-
20
-
21
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
22
- _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
23
-
24
-
25
- def _natural_key(string_):
26
- return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
27
-
28
-
29
- def _rescan_model_configs():
30
- global _MODEL_CONFIGS
31
-
32
- config_ext = ('.json',)
33
- config_files = []
34
- for config_path in _MODEL_CONFIG_PATHS:
35
- if config_path.is_file() and config_path.suffix in config_ext:
36
- config_files.append(config_path)
37
- elif config_path.is_dir():
38
- for ext in config_ext:
39
- config_files.extend(config_path.glob(f'*{ext}'))
40
-
41
- for cf in config_files:
42
- with open(cf, "r", encoding="utf8") as f:
43
- model_cfg = json.load(f)
44
- if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
45
- _MODEL_CONFIGS[cf.stem] = model_cfg
46
-
47
- _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
48
-
49
-
50
- _rescan_model_configs() # initial populate of model config registry
51
-
52
-
53
- def list_models():
54
- """ enumerate available model architectures based on config files """
55
- return list(_MODEL_CONFIGS.keys())
56
-
57
-
58
- def add_model_config(path):
59
- """ add model config path or file and update registry """
60
- if not isinstance(path, Path):
61
- path = Path(path)
62
- _MODEL_CONFIG_PATHS.append(path)
63
- _rescan_model_configs()
64
-
65
-
66
- def get_model_config(model_name):
67
- if model_name in _MODEL_CONFIGS:
68
- return deepcopy(_MODEL_CONFIGS[model_name])
69
- else:
70
- return None
71
-
72
-
73
- def get_tokenizer(model_name):
74
- config = get_model_config(model_name)
75
- tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
76
- return tokenizer
77
-
78
-
79
- # loading openai CLIP weights when is_openai=True for training
80
- def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
81
- if is_openai:
82
- model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
83
- state_dict = model.state_dict()
84
- for key in ["input_resolution", "context_length", "vocab_size"]:
85
- state_dict.pop(key, None)
86
- else:
87
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
88
- for mk in model_key.split('|'):
89
- if isinstance(checkpoint, dict) and mk in checkpoint:
90
- state_dict = checkpoint[mk]
91
- break
92
- else:
93
- state_dict = checkpoint
94
- if next(iter(state_dict.items()))[0].startswith('module'):
95
- state_dict = {k[7:]: v for k, v in state_dict.items()}
96
-
97
- for k in skip_list:
98
- if k in list(state_dict.keys()):
99
- logging.info(f"Removing key {k} from pretrained checkpoint")
100
- del state_dict[k]
101
-
102
- if os.getenv('RoPE') == '1':
103
- for k in list(state_dict.keys()):
104
- if 'freqs_cos' in k or 'freqs_sin' in k:
105
- del state_dict[k]
106
- return state_dict
107
-
108
-
109
-
110
- def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
111
- state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
112
- # detect old format and make compatible with new format
113
- if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
114
- state_dict = convert_to_custom_text_state_dict(state_dict)
115
- if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
116
- state_dict['logit_scale'] = state_dict['text.logit_scale']
117
- del state_dict['text.logit_scale']
118
-
119
- # resize_clip_pos_embed for CLIP and open CLIP
120
- if 'visual.positional_embedding' in state_dict:
121
- resize_clip_pos_embed(state_dict, model)
122
- # specified to eva_vit_model
123
- elif 'visual.pos_embed' in state_dict:
124
- resize_evaclip_pos_embed(state_dict, model)
125
-
126
- # resize_clip_pos_embed(state_dict, model)
127
- incompatible_keys = model.load_state_dict(state_dict, strict=strict)
128
- logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
129
- return incompatible_keys
130
-
131
- def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
132
- state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
133
-
134
- for k in list(state_dict.keys()):
135
- if not k.startswith('visual.'):
136
- del state_dict[k]
137
- for k in list(state_dict.keys()):
138
- if k.startswith('visual.'):
139
- new_k = k[7:]
140
- state_dict[new_k] = state_dict[k]
141
- del state_dict[k]
142
- return state_dict
143
-
144
- def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
145
- state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
146
-
147
- for k in list(state_dict.keys()):
148
- if k.startswith('visual.'):
149
- del state_dict[k]
150
- return state_dict
151
-
152
- def get_pretrained_tag(pretrained_model):
153
- pretrained_model = pretrained_model.lower()
154
- if "laion" in pretrained_model or "open_clip" in pretrained_model:
155
- return "open_clip"
156
- elif "openai" in pretrained_model:
157
- return "clip"
158
- elif "eva" in pretrained_model and "clip" in pretrained_model:
159
- return "eva_clip"
160
- else:
161
- return "other"
162
-
163
- def load_pretrained_checkpoint(
164
- model,
165
- visual_checkpoint_path,
166
- text_checkpoint_path,
167
- strict=True,
168
- visual_model=None,
169
- text_model=None,
170
- model_key="model|module|state_dict",
171
- skip_list=[]):
172
- visual_tag = get_pretrained_tag(visual_model)
173
- text_tag = get_pretrained_tag(text_model)
174
-
175
- logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
176
- visual_incompatible_keys, text_incompatible_keys = None, None
177
- if visual_checkpoint_path:
178
- if visual_tag == "eva_clip" or visual_tag == "open_clip":
179
- visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
180
- elif visual_tag == "clip":
181
- visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
182
- else:
183
- visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
184
-
185
- # resize_clip_pos_embed for CLIP and open CLIP
186
- if 'positional_embedding' in visual_state_dict:
187
- resize_visual_pos_embed(visual_state_dict, model)
188
- # specified to EVA model
189
- elif 'pos_embed' in visual_state_dict:
190
- resize_eva_pos_embed(visual_state_dict, model)
191
-
192
- visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
193
- logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
194
- logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
195
-
196
- if text_checkpoint_path:
197
- if text_tag == "eva_clip" or text_tag == "open_clip":
198
- text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
199
- elif text_tag == "clip":
200
- text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
201
- else:
202
- text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
203
-
204
- text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
205
-
206
- logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
207
- logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
208
-
209
- return visual_incompatible_keys, text_incompatible_keys
210
-
211
- def create_model(
212
- model_name: str,
213
- pretrained: Optional[str] = None,
214
- precision: str = 'fp32',
215
- device: Union[str, torch.device] = 'cpu',
216
- jit: bool = False,
217
- force_quick_gelu: bool = False,
218
- force_custom_clip: bool = False,
219
- force_patch_dropout: Optional[float] = None,
220
- pretrained_image: str = '',
221
- pretrained_text: str = '',
222
- pretrained_hf: bool = True,
223
- pretrained_visual_model: str = None,
224
- pretrained_text_model: str = None,
225
- cache_dir: Optional[str] = None,
226
- skip_list: list = [],
227
- ):
228
- model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
229
- if isinstance(device, str):
230
- device = torch.device(device)
231
-
232
- if pretrained and pretrained.lower() == 'openai':
233
- logging.info(f'Loading pretrained {model_name} from OpenAI.')
234
- model = load_openai_model(
235
- model_name,
236
- precision=precision,
237
- device=device,
238
- jit=jit,
239
- cache_dir=cache_dir,
240
- )
241
- else:
242
- model_cfg = get_model_config(model_name)
243
- if model_cfg is not None:
244
- logging.info(f'Loaded {model_name} model config.')
245
- else:
246
- logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
247
- raise RuntimeError(f'Model config for {model_name} not found.')
248
-
249
- if 'rope' in model_cfg.get('vision_cfg', {}):
250
- if model_cfg['vision_cfg']['rope']:
251
- os.environ['RoPE'] = "1"
252
- else:
253
- os.environ['RoPE'] = "0"
254
-
255
- if force_quick_gelu:
256
- # override for use of QuickGELU on non-OpenAI transformer models
257
- model_cfg["quick_gelu"] = True
258
-
259
- if force_patch_dropout is not None:
260
- # override the default patch dropout value
261
- model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
262
-
263
- cast_dtype = get_cast_dtype(precision)
264
- custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
265
-
266
- if custom_clip:
267
- if 'hf_model_name' in model_cfg.get('text_cfg', {}):
268
- model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
269
- model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
270
- else:
271
- model = CLIP(**model_cfg, cast_dtype=cast_dtype)
272
-
273
- pretrained_cfg = {}
274
- if pretrained:
275
- checkpoint_path = ''
276
- pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
277
- if pretrained_cfg:
278
- checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
279
- elif os.path.exists(pretrained):
280
- checkpoint_path = pretrained
281
-
282
- if checkpoint_path:
283
- logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
284
- load_checkpoint(model,
285
- checkpoint_path,
286
- model_key="model|module|state_dict",
287
- strict=False
288
- )
289
- else:
290
- error_str = (
291
- f'Pretrained weights ({pretrained}) not found for model {model_name}.'
292
- f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
293
- logging.warning(error_str)
294
- raise RuntimeError(error_str)
295
- else:
296
- visual_checkpoint_path = ''
297
- text_checkpoint_path = ''
298
-
299
- if pretrained_image:
300
- pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
301
- pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
302
- if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
303
- # pretrained weight loading for timm models set via vision_cfg
304
- model_cfg['vision_cfg']['timm_model_pretrained'] = True
305
- elif pretrained_image_cfg:
306
- visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
307
- elif os.path.exists(pretrained_image):
308
- visual_checkpoint_path = pretrained_image
309
- else:
310
- logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
311
- raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
312
-
313
- if pretrained_text:
314
- pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
315
- pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
316
- if pretrained_image_cfg:
317
- text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
318
- elif os.path.exists(pretrained_text):
319
- text_checkpoint_path = pretrained_text
320
- else:
321
- logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
322
- raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
323
-
324
- if visual_checkpoint_path:
325
- logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
326
- if text_checkpoint_path:
327
- logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
328
-
329
- if visual_checkpoint_path or text_checkpoint_path:
330
- load_pretrained_checkpoint(
331
- model,
332
- visual_checkpoint_path,
333
- text_checkpoint_path,
334
- strict=False,
335
- visual_model=pretrained_visual_model,
336
- text_model=pretrained_text_model,
337
- model_key="model|module|state_dict",
338
- skip_list=skip_list
339
- )
340
-
341
- if "fp16" in precision or "bf16" in precision:
342
- logging.info(f'convert precision to {precision}')
343
- model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
344
-
345
- model.to(device=device)
346
-
347
- # set image / mean metadata from pretrained_cfg if available, or use default
348
- model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
349
- model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
350
-
351
- if jit:
352
- model = torch.jit.script(model)
353
-
354
- return model
355
-
356
-
357
- def create_model_and_transforms(
358
- model_name: str,
359
- pretrained: Optional[str] = None,
360
- precision: str = 'fp32',
361
- device: Union[str, torch.device] = 'cpu',
362
- jit: bool = False,
363
- force_quick_gelu: bool = False,
364
- force_custom_clip: bool = False,
365
- force_patch_dropout: Optional[float] = None,
366
- pretrained_image: str = '',
367
- pretrained_text: str = '',
368
- pretrained_hf: bool = True,
369
- pretrained_visual_model: str = None,
370
- pretrained_text_model: str = None,
371
- image_mean: Optional[Tuple[float, ...]] = None,
372
- image_std: Optional[Tuple[float, ...]] = None,
373
- cache_dir: Optional[str] = None,
374
- skip_list: list = [],
375
- ):
376
- model = create_model(
377
- model_name,
378
- pretrained,
379
- precision=precision,
380
- device=device,
381
- jit=jit,
382
- force_quick_gelu=force_quick_gelu,
383
- force_custom_clip=force_custom_clip,
384
- force_patch_dropout=force_patch_dropout,
385
- pretrained_image=pretrained_image,
386
- pretrained_text=pretrained_text,
387
- pretrained_hf=pretrained_hf,
388
- pretrained_visual_model=pretrained_visual_model,
389
- pretrained_text_model=pretrained_text_model,
390
- cache_dir=cache_dir,
391
- skip_list=skip_list,
392
- )
393
-
394
- image_mean = image_mean or getattr(model.visual, 'image_mean', None)
395
- image_std = image_std or getattr(model.visual, 'image_std', None)
396
- preprocess_train = image_transform(
397
- model.visual.image_size,
398
- is_train=True,
399
- mean=image_mean,
400
- std=image_std
401
- )
402
- preprocess_val = image_transform(
403
- model.visual.image_size,
404
- is_train=False,
405
- mean=image_mean,
406
- std=image_std
407
- )
408
-
409
- return model, preprocess_train, preprocess_val
410
-
411
- def create_model_from_pretrained(
412
- model_name: str,
413
- pretrained: str,
414
- precision: str = 'fp32',
415
- device: Union[str, torch.device] = 'cpu',
416
- jit: bool = False,
417
- force_quick_gelu: bool = False,
418
- force_custom_clip: bool = False,
419
- force_patch_dropout: Optional[float] = None,
420
- return_transform: bool = True,
421
- image_mean: Optional[Tuple[float, ...]] = None,
422
- image_std: Optional[Tuple[float, ...]] = None,
423
- cache_dir: Optional[str] = None,
424
- is_frozen: bool = False,
425
- ):
426
- if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
427
- raise RuntimeError(
428
- f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
429
- f' Use open_clip.list_pretrained() to find one.')
430
-
431
- model = create_model(
432
- model_name,
433
- pretrained,
434
- precision=precision,
435
- device=device,
436
- jit=jit,
437
- force_quick_gelu=force_quick_gelu,
438
- force_custom_clip=force_custom_clip,
439
- force_patch_dropout=force_patch_dropout,
440
- cache_dir=cache_dir,
441
- )
442
-
443
- if is_frozen:
444
- for param in model.parameters():
445
- param.requires_grad = False
446
-
447
- if not return_transform:
448
- return model
449
-
450
- image_mean = image_mean or getattr(model.visual, 'image_mean', None)
451
- image_std = image_std or getattr(model.visual, 'image_std', None)
452
- preprocess = image_transform(
453
- model.visual.image_size,
454
- is_train=False,
455
- mean=image_mean,
456
- std=image_std
457
- )
458
-
459
- return model, preprocess
 
eva_clip/hf_configs.py DELETED
@@ -1,57 +0,0 @@
- # HF architecture dict:
- arch_dict = {
-     # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
-     "roberta": {
-         "config_names": {
-             "context_length": "max_position_embeddings",
-             "vocab_size": "vocab_size",
-             "width": "hidden_size",
-             "heads": "num_attention_heads",
-             "layers": "num_hidden_layers",
-             "layer_attr": "layer",
-             "token_embeddings_attr": "embeddings"
-         },
-         "pooler": "mean_pooler",
-     },
-     # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
-     "xlm-roberta": {
-         "config_names": {
-             "context_length": "max_position_embeddings",
-             "vocab_size": "vocab_size",
-             "width": "hidden_size",
-             "heads": "num_attention_heads",
-             "layers": "num_hidden_layers",
-             "layer_attr": "layer",
-             "token_embeddings_attr": "embeddings"
-         },
-         "pooler": "mean_pooler",
-     },
-     # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
-     "mt5": {
-         "config_names": {
-             # unlimited seqlen
-             # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
-             # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
-             "context_length": "",
-             "vocab_size": "vocab_size",
-             "width": "d_model",
-             "heads": "num_heads",
-             "layers": "num_layers",
-             "layer_attr": "block",
-             "token_embeddings_attr": "embed_tokens"
-         },
-         "pooler": "mean_pooler",
-     },
-     "bert": {
-         "config_names": {
-             "context_length": "max_position_embeddings",
-             "vocab_size": "vocab_size",
-             "width": "hidden_size",
-             "heads": "num_attention_heads",
-             "layers": "num_hidden_layers",
-             "layer_attr": "layer",
-             "token_embeddings_attr": "embeddings"
-         },
-         "pooler": "mean_pooler",
-     }
- }
 
eva_clip/hf_model.py DELETED
@@ -1,248 +0,0 @@
1
- """ huggingface model adapter
2
-
3
- Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
- """
5
-
6
- import re
7
-
8
- import torch
9
- import torch.nn as nn
10
- from torch.nn import functional as F
11
- from torch import TensorType
12
- try:
13
- import transformers
14
- from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
15
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
- BaseModelOutputWithPoolingAndCrossAttentions
17
- except ImportError as e:
18
- transformers = None
19
-
20
-
21
- class BaseModelOutput:
22
- pass
23
-
24
-
25
- class PretrainedConfig:
26
- pass
27
-
28
- from .hf_configs import arch_dict
29
-
30
- # utils
31
- def _camel2snake(s):
32
- return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
-
34
- # TODO: ?last - for gpt-like models
35
- _POOLERS = {}
36
-
37
- def register_pooler(cls):
38
- """Decorator registering pooler class"""
39
- _POOLERS[_camel2snake(cls.__name__)] = cls
40
- return cls
41
-
42
-
43
- @register_pooler
44
- class MeanPooler(nn.Module):
45
- """Mean pooling"""
46
- def forward(self, x:BaseModelOutput, attention_mask:TensorType):
47
- masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
48
- return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
49
-
50
- @register_pooler
51
- class MaxPooler(nn.Module):
52
- """Max pooling"""
53
- def forward(self, x:BaseModelOutput, attention_mask:TensorType):
54
- masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
55
- return masked_output.max(1).values
56
-
57
- @register_pooler
58
- class ClsPooler(nn.Module):
59
- """CLS token pooling"""
60
- def __init__(self, use_pooler_output=True):
61
- super().__init__()
62
- self.cls_token_position = 0
63
- self.use_pooler_output = use_pooler_output
64
-
65
- def forward(self, x:BaseModelOutput, attention_mask:TensorType):
66
-
67
- if (self.use_pooler_output and
68
- isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
69
- (x.pooler_output is not None)
70
- ):
71
- return x.pooler_output
72
-
73
- return x.last_hidden_state[:, self.cls_token_position, :]
74
-
75
- class HFTextEncoder(nn.Module):
76
- """HuggingFace model adapter"""
77
- def __init__(
78
- self,
79
- model_name_or_path: str,
80
- output_dim: int,
81
- tokenizer_name: str = None,
82
- config: PretrainedConfig = None,
83
- pooler_type: str = None,
84
- proj: str = None,
85
- pretrained: bool = True,
86
- masked_language_modeling: bool = False):
87
- super().__init__()
88
-
89
- self.output_dim = output_dim
90
-
91
- # TODO: find better way to get this information
92
- uses_transformer_pooler = (pooler_type == "cls_pooler")
93
-
94
- if transformers is None:
95
- raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
96
- if config is None:
97
- self.config = AutoConfig.from_pretrained(model_name_or_path)
98
- if masked_language_modeling:
99
- create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
100
- AutoModelForMaskedLM.from_config, self.config)
101
- else:
102
- create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
103
- AutoModel.from_config, self.config)
104
- # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
105
- if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
106
- self.transformer = create_func(model_args)
107
- self.transformer = self.transformer.encoder
108
- else:
109
- self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
110
- else:
111
- self.config = config
112
- if masked_language_modeling:
113
- self.transformer = AutoModelForMaskedLM.from_config(config)
114
- else:
115
- self.transformer = AutoModel.from_config(config)
116
-
117
- if pooler_type is None: # get default arch pooler
118
- self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
119
- else:
120
- self.pooler = _POOLERS[pooler_type]()
121
-
122
- d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
123
- if (d_model == output_dim) and (proj is None): # do we always need a proj?
124
- self.proj = nn.Identity()
125
- elif proj == 'linear':
126
- self.proj = nn.Linear(d_model, output_dim, bias=False)
127
- elif proj == 'mlp':
128
- hidden_size = (d_model + output_dim) // 2
129
- self.proj = nn.Sequential(
130
- nn.Linear(d_model, hidden_size, bias=False),
131
- nn.GELU(),
132
- nn.Linear(hidden_size, output_dim, bias=False),
133
- )
134
-
135
- # self.itm_proj = nn.Linear(d_model, 2, bias=False)
136
- # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
137
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
138
-
139
- # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
140
- # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
141
- # attn_mask = (x != self.config.pad_token_id).long()
142
- # out = self.transformer(
143
- # input_ids=x,
144
- # attention_mask=attn_mask,
145
- # encoder_hidden_states = image_embeds,
146
- # encoder_attention_mask = image_atts,
147
- # )
148
- # pooled_out = self.pooler(out, attn_mask)
149
-
150
- # return self.itm_proj(pooled_out)
151
-
152
- def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
153
- if masked_indices is None:
154
- masked_indices = torch.bernoulli(probability_matrix).bool()
155
-
156
- masked_indices[input_ids == self.tokenizer.pad_token_id] = False
157
- masked_indices[input_ids == self.tokenizer.cls_token_id] = False
158
-
159
- if targets is not None:
160
- targets[~masked_indices] = -100 # We only compute loss on masked tokens
161
-
162
- # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
163
- indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
164
- input_ids[indices_replaced] = self.tokenizer.mask_token_id
165
-
166
- # 10% of the time, we replace masked input tokens with random word
167
- indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
168
- random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
169
- input_ids[indices_random] = random_words[indices_random]
170
- # The rest of the time (10% of the time) we keep the masked input tokens unchanged
171
-
172
- if targets is not None:
173
- return input_ids, targets
174
- else:
175
- return input_ids
176
-
177
- def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
178
- labels = input_ids.clone()
179
- attn_mask = (input_ids != self.config.pad_token_id).long()
180
- image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
181
- vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
182
- probability_matrix = torch.full(labels.shape, mlm_probability)
183
- input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
184
- probability_matrix = probability_matrix)
185
- mlm_output = self.transformer(input_ids,
186
- attention_mask = attn_mask,
187
- encoder_hidden_states = image_embeds,
188
- encoder_attention_mask = image_atts,
189
- return_dict = True,
190
- labels = labels,
191
- )
192
- return mlm_output.loss
193
- # mlm_output = self.transformer(input_ids,
194
- # attention_mask = attn_mask,
195
- # encoder_hidden_states = image_embeds,
196
- # encoder_attention_mask = image_atts,
197
- # return_dict = True,
198
- # ).last_hidden_state
199
- # logits = self.mlm_proj(mlm_output)
200
-
201
- # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
202
- # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
203
- # labels = labels[:, 1:].contiguous().view(-1)
204
-
205
- # mlm_loss = F.cross_entropy(
206
- # logits,
207
- # labels,
208
- # # label_smoothing=0.1,
209
- # )
210
- # return mlm_loss
211
-
212
-
213
- def forward(self, x:TensorType) -> TensorType:
214
- attn_mask = (x != self.config.pad_token_id).long()
215
- out = self.transformer(input_ids=x, attention_mask=attn_mask)
216
- pooled_out = self.pooler(out, attn_mask)
217
-
218
- return self.proj(pooled_out)
219
-
220
- def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
221
- if not unlocked_layers: # full freezing
222
- for n, p in self.transformer.named_parameters():
223
- p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
224
- return
225
-
226
- encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
227
- layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
228
- print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
229
- embeddings = getattr(
230
- self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
231
- modules = [embeddings, *layer_list][:-unlocked_layers]
232
- # freeze layers
233
- for module in modules:
234
- for n, p in module.named_parameters():
235
- p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
236
-
237
-
238
- @torch.jit.ignore
239
- def set_grad_checkpointing(self, enable=True):
240
- self.transformer.gradient_checkpointing_enable()
241
-
242
- def get_num_layers(self):
243
- encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
244
- layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
245
- return len(layer_list)
246
-
247
- def init_parameters(self):
248
- pass
 
eva_clip/loss.py DELETED
@@ -1,138 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- from torch.nn import functional as F
5
-
6
- try:
7
- import torch.distributed.nn
8
- from torch import distributed as dist
9
- has_distributed = True
10
- except ImportError:
11
- has_distributed = False
12
-
13
- try:
14
- import horovod.torch as hvd
15
- except ImportError:
16
- hvd = None
17
-
18
- from timm.loss import LabelSmoothingCrossEntropy
19
-
20
-
21
- def gather_features(
22
- image_features,
23
- text_features,
24
- local_loss=False,
25
- gather_with_grad=False,
26
- rank=0,
27
- world_size=1,
28
- use_horovod=False
29
- ):
30
- assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31
- if use_horovod:
32
- assert hvd is not None, 'Please install horovod'
33
- if gather_with_grad:
34
- all_image_features = hvd.allgather(image_features)
35
- all_text_features = hvd.allgather(text_features)
36
- else:
37
- with torch.no_grad():
38
- all_image_features = hvd.allgather(image_features)
39
- all_text_features = hvd.allgather(text_features)
40
- if not local_loss:
41
- # ensure grads for local rank when all_* features don't have a gradient
42
- gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43
- gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44
- gathered_image_features[rank] = image_features
45
- gathered_text_features[rank] = text_features
46
- all_image_features = torch.cat(gathered_image_features, dim=0)
47
- all_text_features = torch.cat(gathered_text_features, dim=0)
48
- else:
49
- # We gather tensors from all gpus
50
- if gather_with_grad:
51
- all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52
- all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53
- # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
54
- # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
55
- else:
56
- gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
57
- gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
58
- dist.all_gather(gathered_image_features, image_features)
59
- dist.all_gather(gathered_text_features, text_features)
60
- if not local_loss:
61
- # ensure grads for local rank when all_* features don't have a gradient
62
- gathered_image_features[rank] = image_features
63
- gathered_text_features[rank] = text_features
64
- all_image_features = torch.cat(gathered_image_features, dim=0)
65
- all_text_features = torch.cat(gathered_text_features, dim=0)
66
-
67
- return all_image_features, all_text_features
68
-
69
-
70
- class ClipLoss(nn.Module):
71
-
72
- def __init__(
73
- self,
74
- local_loss=False,
75
- gather_with_grad=False,
76
- cache_labels=False,
77
- rank=0,
78
- world_size=1,
79
- use_horovod=False,
80
- smoothing=0.,
81
- ):
82
- super().__init__()
83
- self.local_loss = local_loss
84
- self.gather_with_grad = gather_with_grad
85
- self.cache_labels = cache_labels
86
- self.rank = rank
87
- self.world_size = world_size
88
- self.use_horovod = use_horovod
89
- self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
90
-
91
- # cache state
92
- self.prev_num_logits = 0
93
- self.labels = {}
94
-
95
- def forward(self, image_features, text_features, logit_scale=1.):
96
- device = image_features.device
97
- if self.world_size > 1:
98
- all_image_features, all_text_features = gather_features(
99
- image_features, text_features,
100
- self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
101
-
102
- if self.local_loss:
103
- logits_per_image = logit_scale * image_features @ all_text_features.T
104
- logits_per_text = logit_scale * text_features @ all_image_features.T
105
- else:
106
- logits_per_image = logit_scale * all_image_features @ all_text_features.T
107
- logits_per_text = logits_per_image.T
108
- else:
109
- logits_per_image = logit_scale * image_features @ text_features.T
110
- logits_per_text = logit_scale * text_features @ image_features.T
111
- # calculated ground-truth and cache if enabled
112
- num_logits = logits_per_image.shape[0]
113
- if self.prev_num_logits != num_logits or device not in self.labels:
114
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
115
- if self.world_size > 1 and self.local_loss:
116
- labels = labels + num_logits * self.rank
117
- if self.cache_labels:
118
- self.labels[device] = labels
119
- self.prev_num_logits = num_logits
120
- else:
121
- labels = self.labels[device]
122
-
123
- if self.label_smoothing_cross_entropy:
124
- total_loss = (
125
- self.label_smoothing_cross_entropy(logits_per_image, labels) +
126
- self.label_smoothing_cross_entropy(logits_per_text, labels)
127
- ) / 2
128
- else:
129
- total_loss = (
130
- F.cross_entropy(logits_per_image, labels) +
131
- F.cross_entropy(logits_per_text, labels)
132
- ) / 2
133
-
134
- acc = None
135
- i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
136
- t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
137
- acc = {"i2t": i2t_acc, "t2i": t2i_acc}
138
- return total_loss, acc
 
eva_clip/model.py DELETED
@@ -1,439 +0,0 @@
1
- """ CLIP Model
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
- import os
6
- from dataclasses import dataclass
7
- from typing import Optional, Tuple, Union
8
- from functools import partial
9
-
10
- import numpy as np
11
- import torch
12
- import torch.nn.functional as F
13
- from torch import nn
14
-
15
- try:
16
- from .hf_model import HFTextEncoder
17
- except:
18
- HFTextEncoder = None
19
- from .modified_resnet import ModifiedResNet
20
- from .timm_model import TimmModel
21
- from .eva_vit_model import EVAVisionTransformer
22
- from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
23
-
24
- try:
25
- from apex.normalization import FusedLayerNorm
26
- except:
27
- FusedLayerNorm = LayerNorm
28
- print("Please 'pip install apex'")
29
-
30
- try:
31
- import xformers.ops as xops
32
- except ImportError:
33
- xops = None
34
- print("Please 'pip install xformers'")
35
-
36
- @dataclass
37
- class CLIPVisionCfg:
38
- layers: Union[Tuple[int, int, int, int], int] = 12
39
- width: int = 768
40
- head_width: int = 64
41
- mlp_ratio: float = 4.0
42
- patch_size: int = 16
43
- image_size: Union[Tuple[int, int], int] = 224
44
- ls_init_value: Optional[float] = None # layer scale initial value
45
- patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
46
- global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
47
- drop_path_rate: Optional[float] = None # drop path rate
48
- timm_model_name: str = None # a valid model name overrides layers, width, patch_size
49
- timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
50
- timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
51
- timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
52
- timm_proj_bias: bool = False # enable bias final projection
53
- eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
54
- qkv_bias: bool = True
55
- fusedLN: bool = False
56
- xattn: bool = False
57
- postnorm: bool = False
58
- rope: bool = False
59
- pt_hw_seq_len: int = 16 # 224/14
60
- intp_freq: bool = False
61
- naiveswiglu: bool = False
62
- subln: bool = False
63
-
64
-
65
- @dataclass
66
- class CLIPTextCfg:
67
- context_length: int = 77
68
- vocab_size: int = 49408
69
- width: int = 512
70
- heads: int = 8
71
- layers: int = 12
72
- ls_init_value: Optional[float] = None # layer scale initial value
73
- hf_model_name: str = None
74
- hf_tokenizer_name: str = None
75
- hf_model_pretrained: bool = True
76
- proj: str = 'mlp'
77
- pooler_type: str = 'mean_pooler'
78
- masked_language_modeling: bool = False
79
- fusedLN: bool = False
80
- xattn: bool = False
81
- attn_mask: bool = True
82
-
83
- def get_cast_dtype(precision: str):
84
- cast_dtype = None
85
- if precision == 'bf16':
86
- cast_dtype = torch.bfloat16
87
- elif precision == 'fp16':
88
- cast_dtype = torch.float16
89
- return cast_dtype
90
-
91
-
92
- def _build_vision_tower(
93
- embed_dim: int,
94
- vision_cfg: CLIPVisionCfg,
95
- quick_gelu: bool = False,
96
- cast_dtype: Optional[torch.dtype] = None
97
- ):
98
- if isinstance(vision_cfg, dict):
99
- vision_cfg = CLIPVisionCfg(**vision_cfg)
100
-
101
- # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
102
- # memory efficient in recent PyTorch releases (>= 1.10).
103
- # NOTE: timm models always use native GELU regardless of quick_gelu flag.
104
- act_layer = QuickGELU if quick_gelu else nn.GELU
105
-
106
- if vision_cfg.eva_model_name:
107
- vision_heads = vision_cfg.width // vision_cfg.head_width
108
- norm_layer = LayerNorm
109
-
110
- visual = EVAVisionTransformer(
111
- img_size=vision_cfg.image_size,
112
- patch_size=vision_cfg.patch_size,
113
- num_classes=embed_dim,
114
- use_mean_pooling=vision_cfg.global_average_pool, #False
115
- init_values=vision_cfg.ls_init_value,
116
- patch_dropout=vision_cfg.patch_dropout,
117
- embed_dim=vision_cfg.width,
118
- depth=vision_cfg.layers,
119
- num_heads=vision_heads,
120
- mlp_ratio=vision_cfg.mlp_ratio,
121
- qkv_bias=vision_cfg.qkv_bias,
122
- drop_path_rate=vision_cfg.drop_path_rate,
123
- norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
124
- xattn=vision_cfg.xattn,
125
- rope=vision_cfg.rope,
126
- postnorm=vision_cfg.postnorm,
127
- pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
128
- intp_freq= vision_cfg.intp_freq,
129
- naiveswiglu= vision_cfg.naiveswiglu,
130
- subln= vision_cfg.subln
131
- )
132
- elif vision_cfg.timm_model_name:
133
- visual = TimmModel(
134
- vision_cfg.timm_model_name,
135
- pretrained=vision_cfg.timm_model_pretrained,
136
- pool=vision_cfg.timm_pool,
137
- proj=vision_cfg.timm_proj,
138
- proj_bias=vision_cfg.timm_proj_bias,
139
- embed_dim=embed_dim,
140
- image_size=vision_cfg.image_size
141
- )
142
- act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
143
- elif isinstance(vision_cfg.layers, (tuple, list)):
144
- vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
145
- visual = ModifiedResNet(
146
- layers=vision_cfg.layers,
147
- output_dim=embed_dim,
148
- heads=vision_heads,
149
- image_size=vision_cfg.image_size,
150
- width=vision_cfg.width
151
- )
152
- else:
153
- vision_heads = vision_cfg.width // vision_cfg.head_width
154
- norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
155
- visual = VisionTransformer(
156
- image_size=vision_cfg.image_size,
157
- patch_size=vision_cfg.patch_size,
158
- width=vision_cfg.width,
159
- layers=vision_cfg.layers,
160
- heads=vision_heads,
161
- mlp_ratio=vision_cfg.mlp_ratio,
162
- ls_init_value=vision_cfg.ls_init_value,
163
- patch_dropout=vision_cfg.patch_dropout,
164
- global_average_pool=vision_cfg.global_average_pool,
165
- output_dim=embed_dim,
166
- act_layer=act_layer,
167
- norm_layer=norm_layer,
168
- )
169
-
170
- return visual
171
-
172
-
173
- def _build_text_tower(
174
- embed_dim: int,
175
- text_cfg: CLIPTextCfg,
176
- quick_gelu: bool = False,
177
- cast_dtype: Optional[torch.dtype] = None,
178
- ):
179
- if isinstance(text_cfg, dict):
180
- text_cfg = CLIPTextCfg(**text_cfg)
181
-
182
- if text_cfg.hf_model_name:
183
- text = HFTextEncoder(
184
- text_cfg.hf_model_name,
185
- output_dim=embed_dim,
186
- tokenizer_name=text_cfg.hf_tokenizer_name,
187
- proj=text_cfg.proj,
188
- pooler_type=text_cfg.pooler_type,
189
- masked_language_modeling=text_cfg.masked_language_modeling
190
- )
191
- else:
192
- act_layer = QuickGELU if quick_gelu else nn.GELU
193
- norm_layer = LayerNorm
194
-
195
- text = TextTransformer(
196
- context_length=text_cfg.context_length,
197
- vocab_size=text_cfg.vocab_size,
198
- width=text_cfg.width,
199
- heads=text_cfg.heads,
200
- layers=text_cfg.layers,
201
- ls_init_value=text_cfg.ls_init_value,
202
- output_dim=embed_dim,
203
- act_layer=act_layer,
204
- norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
205
- xattn=text_cfg.xattn,
206
- attn_mask=text_cfg.attn_mask,
207
- )
208
- return text
209
-
210
- class CLIP(nn.Module):
211
- def __init__(
212
- self,
213
- embed_dim: int,
214
- vision_cfg: CLIPVisionCfg,
215
- text_cfg: CLIPTextCfg,
216
- quick_gelu: bool = False,
217
- cast_dtype: Optional[torch.dtype] = None,
218
- ):
219
- super().__init__()
220
- self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
221
-
222
- text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
223
- self.transformer = text.transformer
224
- self.vocab_size = text.vocab_size
225
- self.token_embedding = text.token_embedding
226
- self.positional_embedding = text.positional_embedding
227
- self.ln_final = text.ln_final
228
- self.text_projection = text.text_projection
229
- self.register_buffer('attn_mask', text.attn_mask, persistent=False)
230
-
231
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
232
-
233
- def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
234
- # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
235
- self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
236
-
237
- @torch.jit.ignore
238
- def set_grad_checkpointing(self, enable=True):
239
- self.visual.set_grad_checkpointing(enable)
240
- self.transformer.grad_checkpointing = enable
241
-
242
- @torch.jit.ignore
243
- def no_weight_decay(self):
244
- return {'logit_scale'}
245
-
246
- def encode_image(self, image, normalize: bool = False):
247
- features = self.visual(image)
248
- return F.normalize(features, dim=-1) if normalize else features
249
-
250
- def encode_text(self, text, normalize: bool = False):
251
- cast_dtype = self.transformer.get_cast_dtype()
252
-
253
- x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
254
-
255
- x = x + self.positional_embedding.to(cast_dtype)
256
- x = x.permute(1, 0, 2) # NLD -> LND
257
- x = self.transformer(x, attn_mask=self.attn_mask)
258
- x = x.permute(1, 0, 2) # LND -> NLD
259
- x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
260
- # take features from the eot embedding (eot_token is the highest number in each sequence)
261
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
262
- return F.normalize(x, dim=-1) if normalize else x
263
-
264
- def forward(self, image, text):
265
- image_features = self.encode_image(image, normalize=True)
266
- text_features = self.encode_text(text, normalize=True)
267
- return image_features, text_features, self.logit_scale.exp()
268
-
269
-
270
- class CustomCLIP(nn.Module):
271
- def __init__(
272
- self,
273
- embed_dim: int,
274
- vision_cfg: CLIPVisionCfg,
275
- text_cfg: CLIPTextCfg,
276
- quick_gelu: bool = False,
277
- cast_dtype: Optional[torch.dtype] = None,
278
- itm_task: bool = False,
279
- ):
280
- super().__init__()
281
- self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
282
- self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
283
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
284
-
285
- def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
286
- # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
287
- self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
288
-
289
- def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
290
- self.text.lock(unlocked_layers, freeze_layer_norm)
291
-
292
- @torch.jit.ignore
293
- def set_grad_checkpointing(self, enable=True):
294
- self.visual.set_grad_checkpointing(enable)
295
-         self.text.set_grad_checkpointing(enable)
-
-     @torch.jit.ignore
-     def no_weight_decay(self):
-         return {'logit_scale'}
-
-     def encode_image(self, image, normalize: bool = False):
-         features = self.visual(image)
-         return F.normalize(features, dim=-1) if normalize else features
-
-     def encode_text(self, text, normalize: bool = False):
-         features = self.text(text)
-         return F.normalize(features, dim=-1) if normalize else features
-
-     def forward(self, image, text):
-         image_features = self.encode_image(image, normalize=True)
-         text_features = self.encode_text(text, normalize=True)
-         return image_features, text_features, self.logit_scale.exp()
-
-
- def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
-     """Convert applicable model parameters to low-precision (bf16 or fp16)"""
-
-     def _convert_weights(l):
-
-         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
-             l.weight.data = l.weight.data.to(dtype)
-             if l.bias is not None:
-                 l.bias.data = l.bias.data.to(dtype)
-
-         if isinstance(l, (nn.MultiheadAttention, Attention)):
-             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
-                 tensor = getattr(l, attr, None)
-                 if tensor is not None:
-                     tensor.data = tensor.data.to(dtype)
-
-         if isinstance(l, nn.Parameter):
-             l.data = l.data.to(dtype)
-
-         for name in ["text_projection", "proj"]:
-             if hasattr(l, name) and isinstance(l, nn.Parameter):
-                 attr = getattr(l, name, None)
-                 if attr is not None:
-                     attr.data = attr.data.to(dtype)
-
-     model.apply(_convert_weights)
-
-
- convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
-
-
- # used to maintain checkpoint compatibility
- def convert_to_custom_text_state_dict(state_dict: dict):
-     if 'text_projection' in state_dict:
-         # old format state_dict, move text tower -> .text
-         new_state_dict = {}
-         for k, v in state_dict.items():
-             if any(k.startswith(p) for p in (
-                 'text_projection',
-                 'positional_embedding',
-                 'token_embedding',
-                 'transformer',
-                 'ln_final',
-                 'logit_scale'
-             )):
-                 k = 'text.' + k
-             new_state_dict[k] = v
-         return new_state_dict
-     return state_dict
-
-
- def build_model_from_openai_state_dict(
-         state_dict: dict,
-         quick_gelu=True,
-         cast_dtype=torch.float16,
- ):
-     vit = "visual.proj" in state_dict
-
-     if vit:
-         vision_width = state_dict["visual.conv1.weight"].shape[0]
-         vision_layers = len(
-             [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
-         vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
-         grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
-         image_size = vision_patch_size * grid_size
-     else:
-         counts: list = [
-             len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
-         vision_layers = tuple(counts)
-         vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
-         output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
-         vision_patch_size = None
-         assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
-         image_size = output_width * 32
-
-     embed_dim = state_dict["text_projection"].shape[1]
-     context_length = state_dict["positional_embedding"].shape[0]
-     vocab_size = state_dict["token_embedding.weight"].shape[0]
-     transformer_width = state_dict["ln_final.weight"].shape[0]
-     transformer_heads = transformer_width // 64
-     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
-
-     vision_cfg = CLIPVisionCfg(
-         layers=vision_layers,
-         width=vision_width,
-         patch_size=vision_patch_size,
-         image_size=image_size,
-     )
-     text_cfg = CLIPTextCfg(
-         context_length=context_length,
-         vocab_size=vocab_size,
-         width=transformer_width,
-         heads=transformer_heads,
-         layers=transformer_layers
-     )
-     model = CLIP(
-         embed_dim,
-         vision_cfg=vision_cfg,
-         text_cfg=text_cfg,
-         quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
-         cast_dtype=cast_dtype,
-     )
-
-     for key in ["input_resolution", "context_length", "vocab_size"]:
-         state_dict.pop(key, None)
-
-     convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
-     model.load_state_dict(state_dict)
-     return model.eval()
-
-
- def trace_model(model, batch_size=256, device=torch.device('cpu')):
-     model.eval()
-     image_size = model.visual.image_size
-     example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
-     example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
-     model = torch.jit.trace_module(
-         model,
-         inputs=dict(
-             forward=(example_images, example_text),
-             encode_text=(example_text,),
-             encode_image=(example_images,)
-         ))
-     model.visual.image_size = image_size
-     return model
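
The helpers above are generic: convert_weights_to_fp16 walks any module tree and casts Conv/Linear/attention weights in place, and trace_model JIT-traces a full CLIP through its forward, encode_image and encode_text entry points. A minimal usage sketch of the casting helper, assuming the pre-merge eva_clip.model path this commit removes; the toy tower is illustrative only:

    import torch
    from torch import nn

    from eva_clip.model import convert_weights_to_fp16  # pre-merge module path

    # any module tree works; only Conv/Linear/attention tensors are cast in place
    tower = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 8, 3))
    convert_weights_to_fp16(tower)
    assert tower[0].weight.dtype == torch.float16
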
eva_clip/model_configs/EVA01-CLIP-B-16.json DELETED
@@ -1,19 +0,0 @@
- {
-     "embed_dim": 512,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 12,
-         "width": 768,
-         "patch_size": 16,
-         "eva_model_name": "eva-clip-b-16",
-         "ls_init_value": 0.1,
-         "drop_path_rate": 0.0
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 512,
-         "heads": 8,
-         "layers": 12
-     }
- }
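
Each of these configs pairs an embed_dim with a vision_cfg/text_cfg block whose keys mirror the CLIPVisionCfg and CLIPTextCfg fields used by build_model_from_openai_state_dict above. A sketch of how such a file could be turned into a model, assuming the pre-merge package layout and that the dataclass/constructor keyword names match the JSON keys:

    import json

    from eva_clip.model import CLIP, CLIPTextCfg, CLIPVisionCfg  # pre-merge module path

    with open("eva_clip/model_configs/EVA01-CLIP-B-16.json") as f:
        cfg = json.load(f)

    model = CLIP(
        embed_dim=cfg["embed_dim"],
        vision_cfg=CLIPVisionCfg(**cfg["vision_cfg"]),
        text_cfg=CLIPTextCfg(**cfg["text_cfg"]),
    )
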
eva_clip/model_configs/EVA01-CLIP-g-14-plus.json DELETED
@@ -1,24 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 40,
-         "width": 1408,
-         "head_width": 88,
-         "mlp_ratio": 4.3637,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-g-14-x",
-         "drop_path_rate": 0,
-         "xattn": true,
-         "fusedLN": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 1024,
-         "heads": 16,
-         "layers": 24,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
eva_clip/model_configs/EVA01-CLIP-g-14.json DELETED
@@ -1,24 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 40,
-         "width": 1408,
-         "head_width": 88,
-         "mlp_ratio": 4.3637,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-g-14-x",
-         "drop_path_rate": 0.4,
-         "xattn": true,
-         "fusedLN": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 768,
-         "heads": 12,
-         "layers": 12,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
eva_clip/model_configs/EVA02-CLIP-B-16.json DELETED
@@ -1,29 +0,0 @@
- {
-     "embed_dim": 512,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 12,
-         "width": 768,
-         "head_width": 64,
-         "patch_size": 16,
-         "mlp_ratio": 2.6667,
-         "eva_model_name": "eva-clip-b-16-X",
-         "drop_path_rate": 0.0,
-         "xattn": true,
-         "fusedLN": true,
-         "rope": true,
-         "pt_hw_seq_len": 16,
-         "intp_freq": true,
-         "naiveswiglu": true,
-         "subln": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 512,
-         "heads": 8,
-         "layers": 12,
-         "xattn": true,
-         "fusedLN": true
-     }
- }
eva_clip/model_configs/EVA02-CLIP-L-14-336.json DELETED
@@ -1,29 +0,0 @@
- {
-     "embed_dim": 768,
-     "vision_cfg": {
-         "image_size": 336,
-         "layers": 24,
-         "width": 1024,
-         "drop_path_rate": 0,
-         "head_width": 64,
-         "mlp_ratio": 2.6667,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-l-14-336",
-         "xattn": true,
-         "fusedLN": true,
-         "rope": true,
-         "pt_hw_seq_len": 16,
-         "intp_freq": true,
-         "naiveswiglu": true,
-         "subln": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 768,
-         "heads": 12,
-         "layers": 12,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
eva_clip/model_configs/EVA02-CLIP-L-14.json DELETED
@@ -1,29 +0,0 @@
- {
-     "embed_dim": 768,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 24,
-         "width": 1024,
-         "drop_path_rate": 0,
-         "head_width": 64,
-         "mlp_ratio": 2.6667,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-l-14",
-         "xattn": true,
-         "fusedLN": true,
-         "rope": true,
-         "pt_hw_seq_len": 16,
-         "intp_freq": true,
-         "naiveswiglu": true,
-         "subln": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 768,
-         "heads": 12,
-         "layers": 12,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json DELETED
@@ -1,25 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 64,
-         "width": 1792,
-         "head_width": 112,
-         "mlp_ratio": 8.571428571428571,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-4b-14-x",
-         "drop_path_rate": 0,
-         "xattn": true,
-         "postnorm": true,
-         "fusedLN": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 1280,
-         "heads": 20,
-         "layers": 32,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
eva_clip/model_configs/EVA02-CLIP-bigE-14.json DELETED
@@ -1,25 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 64,
-         "width": 1792,
-         "head_width": 112,
-         "mlp_ratio": 8.571428571428571,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-4b-14-x",
-         "drop_path_rate": 0,
-         "xattn": true,
-         "postnorm": true,
-         "fusedLN": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 1024,
-         "heads": 16,
-         "layers": 24,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
eva_clip/modified_resnet.py DELETED
@@ -1,181 +0,0 @@
1
- from collections import OrderedDict
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from .utils import freeze_batch_norm_2d
8
-
9
-
10
- class Bottleneck(nn.Module):
11
- expansion = 4
12
-
13
- def __init__(self, inplanes, planes, stride=1):
14
- super().__init__()
15
-
16
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
- self.bn1 = nn.BatchNorm2d(planes)
19
- self.act1 = nn.ReLU(inplace=True)
20
-
21
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
- self.bn2 = nn.BatchNorm2d(planes)
23
- self.act2 = nn.ReLU(inplace=True)
24
-
25
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
-
27
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
- self.act3 = nn.ReLU(inplace=True)
30
-
31
- self.downsample = None
32
- self.stride = stride
33
-
34
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
- self.downsample = nn.Sequential(OrderedDict([
37
- ("-1", nn.AvgPool2d(stride)),
38
- ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
- ("1", nn.BatchNorm2d(planes * self.expansion))
40
- ]))
41
-
42
- def forward(self, x: torch.Tensor):
43
- identity = x
44
-
45
- out = self.act1(self.bn1(self.conv1(x)))
46
- out = self.act2(self.bn2(self.conv2(out)))
47
- out = self.avgpool(out)
48
- out = self.bn3(self.conv3(out))
49
-
50
- if self.downsample is not None:
51
- identity = self.downsample(x)
52
-
53
- out += identity
54
- out = self.act3(out)
55
- return out
56
-
57
-
58
- class AttentionPool2d(nn.Module):
59
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
- super().__init__()
61
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
- self.k_proj = nn.Linear(embed_dim, embed_dim)
63
- self.q_proj = nn.Linear(embed_dim, embed_dim)
64
- self.v_proj = nn.Linear(embed_dim, embed_dim)
65
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
- self.num_heads = num_heads
67
-
68
- def forward(self, x):
69
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
- x, _ = F.multi_head_attention_forward(
73
- query=x, key=x, value=x,
74
- embed_dim_to_check=x.shape[-1],
75
- num_heads=self.num_heads,
76
- q_proj_weight=self.q_proj.weight,
77
- k_proj_weight=self.k_proj.weight,
78
- v_proj_weight=self.v_proj.weight,
79
- in_proj_weight=None,
80
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
- bias_k=None,
82
- bias_v=None,
83
- add_zero_attn=False,
84
- dropout_p=0.,
85
- out_proj_weight=self.c_proj.weight,
86
- out_proj_bias=self.c_proj.bias,
87
- use_separate_proj_weight=True,
88
- training=self.training,
89
- need_weights=False
90
- )
91
-
92
- return x[0]
93
-
94
-
95
- class ModifiedResNet(nn.Module):
96
- """
97
- A ResNet class that is similar to torchvision's but contains the following changes:
98
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
- - The final pooling layer is a QKV attention instead of an average pool
101
- """
102
-
103
- def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104
- super().__init__()
105
- self.output_dim = output_dim
106
- self.image_size = image_size
107
-
108
- # the 3-layer stem
109
- self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
- self.bn1 = nn.BatchNorm2d(width // 2)
111
- self.act1 = nn.ReLU(inplace=True)
112
- self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113
- self.bn2 = nn.BatchNorm2d(width // 2)
114
- self.act2 = nn.ReLU(inplace=True)
115
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116
- self.bn3 = nn.BatchNorm2d(width)
117
- self.act3 = nn.ReLU(inplace=True)
118
- self.avgpool = nn.AvgPool2d(2)
119
-
120
- # residual layers
121
- self._inplanes = width # this is a *mutable* variable used during construction
122
- self.layer1 = self._make_layer(width, layers[0])
123
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126
-
127
- embed_dim = width * 32 # the ResNet feature dimension
128
- self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129
-
130
- self.init_parameters()
131
-
132
- def _make_layer(self, planes, blocks, stride=1):
133
- layers = [Bottleneck(self._inplanes, planes, stride)]
134
-
135
- self._inplanes = planes * Bottleneck.expansion
136
- for _ in range(1, blocks):
137
- layers.append(Bottleneck(self._inplanes, planes))
138
-
139
- return nn.Sequential(*layers)
140
-
141
- def init_parameters(self):
142
- if self.attnpool is not None:
143
- std = self.attnpool.c_proj.in_features ** -0.5
144
- nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145
- nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146
- nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147
- nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148
-
149
- for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150
- for name, param in resnet_block.named_parameters():
151
- if name.endswith("bn3.weight"):
152
- nn.init.zeros_(param)
153
-
154
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155
- assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156
- for param in self.parameters():
157
- param.requires_grad = False
158
- if freeze_bn_stats:
159
- freeze_batch_norm_2d(self)
160
-
161
- @torch.jit.ignore
162
- def set_grad_checkpointing(self, enable=True):
163
- # FIXME support for non-transformer
164
- pass
165
-
166
- def stem(self, x):
167
- x = self.act1(self.bn1(self.conv1(x)))
168
- x = self.act2(self.bn2(self.conv2(x)))
169
- x = self.act3(self.bn3(self.conv3(x)))
170
- x = self.avgpool(x)
171
- return x
172
-
173
- def forward(self, x):
174
- x = self.stem(x)
175
- x = self.layer1(x)
176
- x = self.layer2(x)
177
- x = self.layer3(x)
178
- x = self.layer4(x)
179
- x = self.attnpool(x)
180
-
181
- return x
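
ModifiedResNet is self-contained: it takes per-stage block counts, uses a 3-conv stem with anti-aliased strided convolutions, and pools with QKV attention into output_dim features. A usage sketch under the pre-merge module path this commit removes; the RN50-style layer counts and random input are illustrative assumptions:

    import torch

    from eva_clip.modified_resnet import ModifiedResNet  # pre-merge module path

    # 4 stages of (3, 4, 6, 3) bottlenecks, attention-pooled to a 1024-d embedding
    visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
    features = visual(torch.randn(2, 3, 224, 224))
    print(features.shape)  # torch.Size([2, 1024])
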
eva_clip/openai.py DELETED
@@ -1,144 +0,0 @@
1
- """ OpenAI pretrained model functions
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
-
6
- import os
7
- import warnings
8
- from typing import List, Optional, Union
9
-
10
- import torch
11
-
12
- from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
- from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14
-
15
- __all__ = ["list_openai_models", "load_openai_model"]
16
-
17
-
18
- def list_openai_models() -> List[str]:
19
- """Returns the names of available CLIP models"""
20
- return list_pretrained_models_by_tag('openai')
21
-
22
-
23
- def load_openai_model(
24
- name: str,
25
- precision: Optional[str] = None,
26
- device: Optional[Union[str, torch.device]] = None,
27
- jit: bool = True,
28
- cache_dir: Optional[str] = None,
29
- ):
30
- """Load a CLIP model
31
-
32
- Parameters
33
- ----------
34
- name : str
35
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36
- precision: str
37
- Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38
- device : Union[str, torch.device]
39
- The device to put the loaded model
40
- jit : bool
41
- Whether to load the optimized JIT model (default) or the more hackable non-JIT model.
42
- cache_dir : Optional[str]
43
- The directory to cache the downloaded model weights
44
-
45
- Returns
46
- -------
47
- model : torch.nn.Module
48
- The CLIP model
49
- preprocess : Callable[[PIL.Image], torch.Tensor]
50
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51
- """
52
- if device is None:
53
- device = "cuda" if torch.cuda.is_available() else "cpu"
54
- if precision is None:
55
- precision = 'fp32' if device == 'cpu' else 'fp16'
56
-
57
- if get_pretrained_url(name, 'openai'):
58
- model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59
- elif os.path.isfile(name):
60
- model_path = name
61
- else:
62
- raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63
-
64
- try:
65
- # loading JIT archive
66
- model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67
- state_dict = None
68
- except RuntimeError:
69
- # loading saved state dict
70
- if jit:
71
- warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72
- jit = False
73
- state_dict = torch.load(model_path, map_location="cpu")
74
-
75
- if not jit:
76
- # Build a non-jit model from the OpenAI jitted model state dict
77
- cast_dtype = get_cast_dtype(precision)
78
- try:
79
- model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80
- except KeyError:
81
- sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82
- model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83
-
84
- # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85
- model = model.to(device)
86
- if precision.startswith('amp') or precision == 'fp32':
87
- model.float()
88
- elif precision == 'bf16':
89
- convert_weights_to_lp(model, dtype=torch.bfloat16)
90
-
91
- return model
92
-
93
- # patch the device names
94
- device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95
- device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96
-
97
- def patch_device(module):
98
- try:
99
- graphs = [module.graph] if hasattr(module, "graph") else []
100
- except RuntimeError:
101
- graphs = []
102
-
103
- if hasattr(module, "forward1"):
104
- graphs.append(module.forward1.graph)
105
-
106
- for graph in graphs:
107
- for node in graph.findAllNodes("prim::Constant"):
108
- if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109
- node.copyAttributes(device_node)
110
-
111
- model.apply(patch_device)
112
- patch_device(model.encode_image)
113
- patch_device(model.encode_text)
114
-
115
- # patch dtype to float32 (typically for CPU)
116
- if precision == 'fp32':
117
- float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119
- float_node = float_input.node()
120
-
121
- def patch_float(module):
122
- try:
123
- graphs = [module.graph] if hasattr(module, "graph") else []
124
- except RuntimeError:
125
- graphs = []
126
-
127
- if hasattr(module, "forward1"):
128
- graphs.append(module.forward1.graph)
129
-
130
- for graph in graphs:
131
- for node in graph.findAllNodes("aten::to"):
132
- inputs = list(node.inputs())
133
- for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134
- if inputs[i].node()["value"] == 5:
135
- inputs[i].node().copyAttributes(float_node)
136
-
137
- model.apply(patch_float)
138
- patch_float(model.encode_image)
139
- patch_float(model.encode_text)
140
- model.float()
141
-
142
- # ensure image_size attr available at consistent location for both jit and non-jit
143
- model.visual.image_size = model.input_resolution.item()
144
- return model
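
load_openai_model resolves a name through the 'openai' entries of the pretrained registry (next file) or accepts a local checkpoint path, then returns either the JIT archive or a rebuilt non-JIT CLIP. A sketch, assuming the pre-merge module path; the model name comes from the registry below and the checkpoint is downloaded on first use:

    import torch

    from eva_clip.openai import list_openai_models, load_openai_model  # pre-merge module path

    print(list_openai_models())                    # architectures carrying an 'openai' pretrained tag
    model = load_openai_model("OpenaiCLIP-B-32", device="cpu", jit=False)
    with torch.no_grad():
        feats = model.encode_image(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # torch.Size([1, 512])
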
eva_clip/pretrained.py DELETED
@@ -1,332 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
- from functools import partial
6
- from typing import Dict, Union
7
-
8
- from tqdm import tqdm
9
-
10
- try:
11
- from huggingface_hub import hf_hub_download
12
- _has_hf_hub = True
13
- except ImportError:
14
- hf_hub_download = None
15
- _has_hf_hub = False
16
-
17
-
18
- def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
19
- return dict(
20
- url=url,
21
- hf_hub=hf_hub,
22
- mean=mean,
23
- std=std,
24
- )
25
-
26
- _VITB32 = dict(
27
- openai=_pcfg(
28
- "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
29
- laion400m_e31=_pcfg(
30
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
31
- laion400m_e32=_pcfg(
32
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
33
- laion2b_e16=_pcfg(
34
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
35
- laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
36
- )
37
-
38
- _VITB32_quickgelu = dict(
39
- openai=_pcfg(
40
- "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
41
- laion400m_e31=_pcfg(
42
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
43
- laion400m_e32=_pcfg(
44
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
45
- )
46
-
47
- _VITB16 = dict(
48
- openai=_pcfg(
49
- "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
50
- laion400m_e31=_pcfg(
51
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
52
- laion400m_e32=_pcfg(
53
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
54
- laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
55
- )
56
-
57
- _EVAB16 = dict(
58
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
59
- eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
60
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
61
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
62
- )
63
-
64
- _VITB16_PLUS_240 = dict(
65
- laion400m_e31=_pcfg(
66
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
67
- laion400m_e32=_pcfg(
68
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
69
- )
70
-
71
- _VITL14 = dict(
72
- openai=_pcfg(
73
- "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
74
- laion400m_e31=_pcfg(
75
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
76
- laion400m_e32=_pcfg(
77
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
78
- laion2b_s32b_b82k=_pcfg(
79
- hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
80
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
81
- )
82
-
83
- _EVAL14 = dict(
84
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
85
- eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
86
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
87
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
88
- )
89
-
90
- _VITL14_336 = dict(
91
- openai=_pcfg(
92
- "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
93
- )
94
-
95
- _EVAL14_336 = dict(
96
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
97
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
98
- eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
99
- eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
100
- )
101
-
102
- _VITH14 = dict(
103
- laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
104
- )
105
-
106
- _VITg14 = dict(
107
- laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
108
- laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
109
- )
110
-
111
- _EVAg14 = dict(
112
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
113
- eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
114
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
115
- eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
116
- )
117
-
118
- _EVAg14_PLUS = dict(
119
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
120
- eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
121
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
122
- eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
123
- )
124
-
125
- _VITbigG14 = dict(
126
- laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
127
- )
128
-
129
- _EVAbigE14 = dict(
130
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
131
- eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
132
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
133
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
134
- )
135
-
136
- _EVAbigE14_PLUS = dict(
137
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
138
- eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
139
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
140
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
141
- )
142
-
143
-
144
- _PRETRAINED = {
145
- # "ViT-B-32": _VITB32,
146
- "OpenaiCLIP-B-32": _VITB32,
147
- "OpenCLIP-B-32": _VITB32,
148
-
149
- # "ViT-B-32-quickgelu": _VITB32_quickgelu,
150
- "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
151
- "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
152
-
153
- # "ViT-B-16": _VITB16,
154
- "OpenaiCLIP-B-16": _VITB16,
155
- "OpenCLIP-B-16": _VITB16,
156
-
157
- "EVA02-B-16": _EVAB16,
158
- "EVA02-CLIP-B-16": _EVAB16,
159
-
160
- # "ViT-B-16-plus-240": _VITB16_PLUS_240,
161
- "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
162
-
163
- # "ViT-L-14": _VITL14,
164
- "OpenaiCLIP-L-14": _VITL14,
165
- "OpenCLIP-L-14": _VITL14,
166
-
167
- "EVA02-L-14": _EVAL14,
168
- "EVA02-CLIP-L-14": _EVAL14,
169
-
170
- # "ViT-L-14-336": _VITL14_336,
171
- "OpenaiCLIP-L-14-336": _VITL14_336,
172
-
173
- "EVA02-CLIP-L-14-336": _EVAL14_336,
174
-
175
- # "ViT-H-14": _VITH14,
176
- # "ViT-g-14": _VITg14,
177
- "OpenCLIP-H-14": _VITH14,
178
- "OpenCLIP-g-14": _VITg14,
179
-
180
- "EVA01-CLIP-g-14": _EVAg14,
181
- "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
182
-
183
- # "ViT-bigG-14": _VITbigG14,
184
- "OpenCLIP-bigG-14": _VITbigG14,
185
-
186
- "EVA02-CLIP-bigE-14": _EVAbigE14,
187
- "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
188
- }
189
-
190
-
191
- def _clean_tag(tag: str):
192
- # normalize pretrained tags
193
- return tag.lower().replace('-', '_')
194
-
195
-
196
- def list_pretrained(as_str: bool = False):
197
- """ returns list of pretrained models
198
- Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
199
- """
200
- return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
201
-
202
-
203
- def list_pretrained_models_by_tag(tag: str):
204
- """ return all models having the specified pretrain tag """
205
- models = []
206
- tag = _clean_tag(tag)
207
- for k in _PRETRAINED.keys():
208
- if tag in _PRETRAINED[k]:
209
- models.append(k)
210
- return models
211
-
212
-
213
- def list_pretrained_tags_by_model(model: str):
214
- """ return all pretrain tags for the specified model architecture """
215
- tags = []
216
- if model in _PRETRAINED:
217
- tags.extend(_PRETRAINED[model].keys())
218
- return tags
219
-
220
-
221
- def is_pretrained_cfg(model: str, tag: str):
222
- if model not in _PRETRAINED:
223
- return False
224
- return _clean_tag(tag) in _PRETRAINED[model]
225
-
226
-
227
- def get_pretrained_cfg(model: str, tag: str):
228
- if model not in _PRETRAINED:
229
- return {}
230
- model_pretrained = _PRETRAINED[model]
231
- return model_pretrained.get(_clean_tag(tag), {})
232
-
233
-
234
- def get_pretrained_url(model: str, tag: str):
235
- cfg = get_pretrained_cfg(model, _clean_tag(tag))
236
- return cfg.get('url', '')
237
-
238
-
239
- def download_pretrained_from_url(
240
- url: str,
241
- cache_dir: Union[str, None] = None,
242
- ):
243
- if not cache_dir:
244
- cache_dir = os.path.expanduser("~/.cache/clip")
245
- os.makedirs(cache_dir, exist_ok=True)
246
- filename = os.path.basename(url)
247
-
248
- if 'openaipublic' in url:
249
- expected_sha256 = url.split("/")[-2]
250
- elif 'mlfoundations' in url:
251
- expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
252
- else:
253
- expected_sha256 = ''
254
-
255
- download_target = os.path.join(cache_dir, filename)
256
-
257
- if os.path.exists(download_target) and not os.path.isfile(download_target):
258
- raise RuntimeError(f"{download_target} exists and is not a regular file")
259
-
260
- if os.path.isfile(download_target):
261
- if expected_sha256:
262
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
263
- return download_target
264
- else:
265
- warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
266
- else:
267
- return download_target
268
-
269
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
270
- with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
271
- while True:
272
- buffer = source.read(8192)
273
- if not buffer:
274
- break
275
-
276
- output.write(buffer)
277
- loop.update(len(buffer))
278
-
279
- if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
280
- raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
281
-
282
- return download_target
283
-
284
-
285
- def has_hf_hub(necessary=False):
286
- if not _has_hf_hub and necessary:
287
- # if no HF Hub module installed, and it is necessary to continue, raise error
288
- raise RuntimeError(
289
- 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
290
- return _has_hf_hub
291
-
292
-
293
- def download_pretrained_from_hf(
294
- model_id: str,
295
- filename: str = 'open_clip_pytorch_model.bin',
296
- revision=None,
297
- cache_dir: Union[str, None] = None,
298
- ):
299
- has_hf_hub(True)
300
- cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
301
- return cached_file
302
-
303
-
304
- def download_pretrained(
305
- cfg: Dict,
306
- force_hf_hub: bool = False,
307
- cache_dir: Union[str, None] = None,
308
- ):
309
- target = ''
310
- if not cfg:
311
- return target
312
-
313
- download_url = cfg.get('url', '')
314
- download_hf_hub = cfg.get('hf_hub', '')
315
- if download_hf_hub and force_hf_hub:
316
- # use HF hub even if url exists
317
- download_url = ''
318
-
319
- if download_url:
320
- target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
321
- elif download_hf_hub:
322
- has_hf_hub(True)
323
- # we assume the hf_hub entries in pretrained config combine model_id + filename in
324
- # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
325
- # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
326
- model_id, filename = os.path.split(download_hf_hub)
327
- if filename:
328
- target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
329
- else:
330
- target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
331
-
332
- return target
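
The registry above is plain data: _PRETRAINED maps an architecture name to tag-keyed configs, and download_pretrained resolves either a direct URL or an hf_hub entry of the form 'org/model_name/filename.pt'. A sketch, assuming the pre-merge module path and an installed huggingface_hub; the tag is one of the entries listed above and the checkpoint download is large:

    from eva_clip.pretrained import download_pretrained, get_pretrained_cfg, list_pretrained  # pre-merge path

    print(list_pretrained(as_str=True)[:5])                   # e.g. 'OpenaiCLIP-B-32:openai', ...
    cfg = get_pretrained_cfg("EVA02-CLIP-L-14-336", "eva_clip")
    ckpt_path = download_pretrained(cfg)                       # resolves the hf_hub entry for this tag
    print(ckpt_path)
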
eva_clip/rope.py DELETED
@@ -1,137 +0,0 @@
1
- from math import pi
2
- import torch
3
- from torch import nn
4
- from einops import rearrange, repeat
5
- import logging
6
-
7
- def broadcat(tensors, dim = -1):
8
- num_tensors = len(tensors)
9
- shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
- assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
- shape_len = list(shape_lens)[0]
12
- dim = (dim + shape_len) if dim < 0 else dim
13
- dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
- expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
- assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
16
- max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
- expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
- expanded_dims.insert(dim, (dim, dims[dim]))
19
- expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
- tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
- return torch.cat(tensors, dim = dim)
22
-
23
- def rotate_half(x):
24
- x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
- x1, x2 = x.unbind(dim = -1)
26
- x = torch.stack((-x2, x1), dim = -1)
27
- return rearrange(x, '... d r -> ... (d r)')
28
-
29
-
30
- class VisionRotaryEmbedding(nn.Module):
31
- def __init__(
32
- self,
33
- dim,
34
- pt_seq_len,
35
- ft_seq_len=None,
36
- custom_freqs = None,
37
- freqs_for = 'lang',
38
- theta = 10000,
39
- max_freq = 10,
40
- num_freqs = 1,
41
- ):
42
- super().__init__()
43
- if custom_freqs:
44
- freqs = custom_freqs
45
- elif freqs_for == 'lang':
46
- freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
- elif freqs_for == 'pixel':
48
- freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
- elif freqs_for == 'constant':
50
- freqs = torch.ones(num_freqs).float()
51
- else:
52
- raise ValueError(f'unknown modality {freqs_for}')
53
-
54
- if ft_seq_len is None: ft_seq_len = pt_seq_len
55
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
-
57
- freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58
- freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59
-
60
- freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61
- freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62
-
63
- freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64
-
65
- self.register_buffer("freqs_cos", freqs.cos())
66
- self.register_buffer("freqs_sin", freqs.sin())
67
-
68
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69
-
70
- def forward(self, t, start_index = 0):
71
- rot_dim = self.freqs_cos.shape[-1]
72
- end_index = start_index + rot_dim
73
- assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74
- t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75
- t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76
-
77
- return torch.cat((t_left, t, t_right), dim = -1)
78
-
79
- class VisionRotaryEmbeddingFast(nn.Module):
80
- def __init__(
81
- self,
82
- dim,
83
- pt_seq_len,
84
- ft_seq_len=None,
85
- custom_freqs = None,
86
- freqs_for = 'lang',
87
- theta = 10000,
88
- max_freq = 10,
89
- num_freqs = 1,
90
- patch_dropout = 0.
91
- ):
92
- super().__init__()
93
- if custom_freqs:
94
- freqs = custom_freqs
95
- elif freqs_for == 'lang':
96
- freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97
- elif freqs_for == 'pixel':
98
- freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99
- elif freqs_for == 'constant':
100
- freqs = torch.ones(num_freqs).float()
101
- else:
102
- raise ValueError(f'unknown modality {freqs_for}')
103
-
104
- if ft_seq_len is None: ft_seq_len = pt_seq_len
105
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106
-
107
- freqs = torch.einsum('..., f -> ... f', t, freqs)
108
- freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110
-
111
- freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112
- freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113
-
114
- self.patch_dropout = patch_dropout
115
-
116
- self.register_buffer("freqs_cos", freqs_cos)
117
- self.register_buffer("freqs_sin", freqs_sin)
118
-
119
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120
-
121
- def forward(self, t, patch_indices_keep=None):
122
- if patch_indices_keep is not None:
123
- batch = t.size()[0]
124
- batch_indices = torch.arange(batch)
125
- batch_indices = batch_indices[..., None]
126
-
127
- freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128
- freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129
-
130
- freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131
- freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132
- freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133
- freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134
-
135
- return t * freqs_cos + rotate_half(t) * freqs_sin
136
-
137
- return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
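
VisionRotaryEmbeddingFast precomputes cos/sin tables of shape (pt_seq_len^2, 2*dim) and rotates any tensor whose trailing dimensions match that grid. A sketch, assuming the pre-merge module path and einops installed; the shapes mirror EVA-style usage, where rotation is applied to q/k with the class token excluded:

    import torch

    from eva_clip.rope import VisionRotaryEmbeddingFast  # pre-merge module path

    rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)  # 16x16 grid -> 256 patch tokens, 64-d rotation
    q = torch.randn(2, 12, 256, 64)                          # (batch, heads, patch tokens, head_dim)
    q_rot = rope(q)
    print(q_rot.shape)  # torch.Size([2, 12, 256, 64])
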
eva_clip/timm_model.py DELETED
@@ -1,122 +0,0 @@
1
- """ timm model adapter
2
-
3
- Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
- """
5
- import logging
6
- from collections import OrderedDict
7
-
8
- import torch
9
- import torch.nn as nn
10
-
11
- try:
12
- import timm
13
- from timm.models.layers import Mlp, to_2tuple
14
- try:
15
- # old timm imports < 0.8.1
16
- from timm.models.layers.attention_pool2d import RotAttentionPool2d
17
- from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18
- except ImportError:
19
- # new timm imports >= 0.8.1
20
- from timm.layers import RotAttentionPool2d
21
- from timm.layers import AttentionPool2d as AbsAttentionPool2d
22
- except ImportError:
23
- timm = None
24
-
25
- from .utils import freeze_batch_norm_2d
26
-
27
-
28
- class TimmModel(nn.Module):
29
- """ timm model adapter
30
- # FIXME this adapter is a work in progress, may change in ways that break weight compat
31
- """
32
-
33
- def __init__(
34
- self,
35
- model_name,
36
- embed_dim,
37
- image_size=224,
38
- pool='avg',
39
- proj='linear',
40
- proj_bias=False,
41
- drop=0.,
42
- pretrained=False):
43
- super().__init__()
44
- if timm is None:
45
- raise RuntimeError("Please `pip install timm` to use timm models.")
46
-
47
- self.image_size = to_2tuple(image_size)
48
- self.trunk = timm.create_model(model_name, pretrained=pretrained)
49
- feat_size = self.trunk.default_cfg.get('pool_size', None)
50
- feature_ndim = 1 if not feat_size else 2
51
- if pool in ('abs_attn', 'rot_attn'):
52
- assert feature_ndim == 2
53
- # if attn pooling used, remove both classifier and default pool
54
- self.trunk.reset_classifier(0, global_pool='')
55
- else:
56
- # reset global pool if pool config set, otherwise leave as network default
57
- reset_kwargs = dict(global_pool=pool) if pool else {}
58
- self.trunk.reset_classifier(0, **reset_kwargs)
59
- prev_chs = self.trunk.num_features
60
-
61
- head_layers = OrderedDict()
62
- if pool == 'abs_attn':
63
- head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
64
- prev_chs = embed_dim
65
- elif pool == 'rot_attn':
66
- head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
67
- prev_chs = embed_dim
68
- else:
69
- assert proj, 'projection layer needed if non-attention pooling is used.'
70
-
71
- # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
72
- if proj == 'linear':
73
- head_layers['drop'] = nn.Dropout(drop)
74
- head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
75
- elif proj == 'mlp':
76
- head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
77
-
78
- self.head = nn.Sequential(head_layers)
79
-
80
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
81
- """ lock modules
82
- Args:
83
- unlocked_groups (int): leave last n layer groups unlocked (default: 0)
84
- """
85
- if not unlocked_groups:
86
- # lock full model
87
- for param in self.trunk.parameters():
88
- param.requires_grad = False
89
- if freeze_bn_stats:
90
- freeze_batch_norm_2d(self.trunk)
91
- else:
92
- # NOTE: partial freeze requires latest timm (master) branch and is subject to change
93
- try:
94
- # FIXME import here until API stable and in an official release
95
- from timm.models.helpers import group_parameters, group_modules
96
- except ImportError:
97
- raise RuntimeError(
98
- 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
99
- matcher = self.trunk.group_matcher()
100
- gparams = group_parameters(self.trunk, matcher)
101
- max_layer_id = max(gparams.keys())
102
- max_layer_id = max_layer_id - unlocked_groups
103
- for group_idx in range(max_layer_id + 1):
104
- group = gparams[group_idx]
105
- for param in group:
106
- self.trunk.get_parameter(param).requires_grad = False
107
- if freeze_bn_stats:
108
- gmodules = group_modules(self.trunk, matcher, reverse=True)
109
- gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
110
- freeze_batch_norm_2d(self.trunk, gmodules)
111
-
112
- @torch.jit.ignore
113
- def set_grad_checkpointing(self, enable=True):
114
- try:
115
- self.trunk.set_grad_checkpointing(enable)
116
- except Exception as e:
117
- logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
118
-
119
- def forward(self, x):
120
- x = self.trunk(x)
121
- x = self.head(x)
122
- return x
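
TimmModel wraps an arbitrary timm backbone, strips its classifier, and adds a pooling/projection head sized to embed_dim. A sketch, assuming the pre-merge module path and timm installed; the backbone name and embedding width are arbitrary choices for illustration:

    import torch

    from eva_clip.timm_model import TimmModel  # pre-merge module path; requires `pip install timm`

    tower = TimmModel("resnet50", embed_dim=512, image_size=224, pool="avg", proj="linear")
    feats = tower(torch.randn(2, 3, 224, 224))
    print(feats.shape)  # torch.Size([2, 512])
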
eva_clip/tokenizer.py DELETED
@@ -1,201 +0,0 @@
1
- """ CLIP tokenizer
2
-
3
- Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
- import gzip
6
- import html
7
- import os
8
- from functools import lru_cache
9
- from typing import Union, List
10
-
11
- import ftfy
12
- import regex as re
13
- import torch
14
-
15
- # https://stackoverflow.com/q/62691279
16
- import os
17
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
-
19
-
20
- @lru_cache()
21
- def default_bpe():
22
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23
-
24
-
25
- @lru_cache()
26
- def bytes_to_unicode():
27
- """
28
- Returns list of utf-8 byte and a corresponding list of unicode strings.
29
- The reversible bpe codes work on unicode strings.
30
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32
- This is a signficant percentage of your normal, say, 32K bpe vocab.
33
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34
- And avoids mapping to whitespace/control characters the bpe code barfs on.
35
- """
36
- bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37
- cs = bs[:]
38
- n = 0
39
- for b in range(2**8):
40
- if b not in bs:
41
- bs.append(b)
42
- cs.append(2**8+n)
43
- n += 1
44
- cs = [chr(n) for n in cs]
45
- return dict(zip(bs, cs))
46
-
47
-
48
- def get_pairs(word):
49
- """Return set of symbol pairs in a word.
50
- Word is represented as tuple of symbols (symbols being variable-length strings).
51
- """
52
- pairs = set()
53
- prev_char = word[0]
54
- for char in word[1:]:
55
- pairs.add((prev_char, char))
56
- prev_char = char
57
- return pairs
58
-
59
-
60
- def basic_clean(text):
61
- text = ftfy.fix_text(text)
62
- text = html.unescape(html.unescape(text))
63
- return text.strip()
64
-
65
-
66
- def whitespace_clean(text):
67
- text = re.sub(r'\s+', ' ', text)
68
- text = text.strip()
69
- return text
70
-
71
-
72
- class SimpleTokenizer(object):
73
- def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74
- self.byte_encoder = bytes_to_unicode()
75
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76
- merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77
- merges = merges[1:49152-256-2+1]
78
- merges = [tuple(merge.split()) for merge in merges]
79
- vocab = list(bytes_to_unicode().values())
80
- vocab = vocab + [v+'</w>' for v in vocab]
81
- for merge in merges:
82
- vocab.append(''.join(merge))
83
- if not special_tokens:
84
- special_tokens = ['<start_of_text>', '<end_of_text>']
85
- else:
86
- special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87
- vocab.extend(special_tokens)
88
- self.encoder = dict(zip(vocab, range(len(vocab))))
89
- self.decoder = {v: k for k, v in self.encoder.items()}
90
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
91
- self.cache = {t:t for t in special_tokens}
92
- special = "|".join(special_tokens)
93
- self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94
-
95
- self.vocab_size = len(self.encoder)
96
- self.all_special_ids = [self.encoder[t] for t in special_tokens]
97
-
98
- def bpe(self, token):
99
- if token in self.cache:
100
- return self.cache[token]
101
- word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102
- pairs = get_pairs(word)
103
-
104
- if not pairs:
105
- return token+'</w>'
106
-
107
- while True:
108
- bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109
- if bigram not in self.bpe_ranks:
110
- break
111
- first, second = bigram
112
- new_word = []
113
- i = 0
114
- while i < len(word):
115
- try:
116
- j = word.index(first, i)
117
- new_word.extend(word[i:j])
118
- i = j
119
- except:
120
- new_word.extend(word[i:])
121
- break
122
-
123
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
124
- new_word.append(first+second)
125
- i += 2
126
- else:
127
- new_word.append(word[i])
128
- i += 1
129
- new_word = tuple(new_word)
130
- word = new_word
131
- if len(word) == 1:
132
- break
133
- else:
134
- pairs = get_pairs(word)
135
- word = ' '.join(word)
136
- self.cache[token] = word
137
- return word
138
-
139
- def encode(self, text):
140
- bpe_tokens = []
141
- text = whitespace_clean(basic_clean(text)).lower()
142
- for token in re.findall(self.pat, text):
143
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145
- return bpe_tokens
146
-
147
- def decode(self, tokens):
148
- text = ''.join([self.decoder[token] for token in tokens])
149
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150
- return text
151
-
152
-
153
- _tokenizer = SimpleTokenizer()
154
-
155
-
156
- def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157
- """
158
- Returns the tokenized representation of given input string(s)
159
-
160
- Parameters
161
- ----------
162
- texts : Union[str, List[str]]
163
- An input string or a list of input strings to tokenize
164
- context_length : int
165
- The context length to use; all CLIP models use 77 as the context length
166
-
167
- Returns
168
- -------
169
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170
- """
171
- if isinstance(texts, str):
172
- texts = [texts]
173
-
174
- sot_token = _tokenizer.encoder["<start_of_text>"]
175
- eot_token = _tokenizer.encoder["<end_of_text>"]
176
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178
-
179
- for i, tokens in enumerate(all_tokens):
180
- if len(tokens) > context_length:
181
- tokens = tokens[:context_length] # Truncate
182
- tokens[-1] = eot_token
183
- result[i, :len(tokens)] = torch.tensor(tokens)
184
-
185
- return result
186
-
187
-
188
- class HFTokenizer:
189
- "HuggingFace tokenizer wrapper"
190
- def __init__(self, tokenizer_name:str):
191
- from transformers import AutoTokenizer
192
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193
-
194
- def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195
- # same cleaning as for default tokenizer, except lowercasing
196
- # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197
- if isinstance(texts, str):
198
- texts = [texts]
199
- texts = [whitespace_clean(basic_clean(text)) for text in texts]
200
- input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201
- return input_ids
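
The tokenizer is the stock CLIP BPE: tokenize prepends <start_of_text>, appends <end_of_text>, truncates, and zero-pads each string to the context length. A short sketch, assuming the pre-merge module path with its bundled BPE vocab and ftfy/regex installed:

    from eva_clip.tokenizer import tokenize  # pre-merge module path

    tokens = tokenize(["a photo of a cat", "a diagram"])
    print(tokens.shape)   # torch.Size([2, 77])
    print(tokens[0, :6])  # start token followed by the BPE ids for 'a photo of a cat'
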
eva_clip/transform.py DELETED
@@ -1,103 +0,0 @@
1
- from typing import Optional, Sequence, Tuple
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torchvision.transforms.functional as F
6
-
7
- from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8
- CenterCrop
9
-
10
- from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11
-
12
-
13
- class ResizeMaxSize(nn.Module):
14
-
15
- def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16
- super().__init__()
17
- if not isinstance(max_size, int):
18
- raise TypeError(f"Size should be int. Got {type(max_size)}")
19
- self.max_size = max_size
20
- self.interpolation = interpolation
21
- self.fn = min if fn == 'min' else max
22
- self.fill = fill
23
-
24
- def forward(self, img):
25
- if isinstance(img, torch.Tensor):
26
- height, width = img.shape[:2]
27
- else:
28
- width, height = img.size
29
- scale = self.max_size / float(max(height, width))
30
- if scale != 1.0:
31
- new_size = tuple(round(dim * scale) for dim in (height, width))
32
- img = F.resize(img, new_size, self.interpolation)
33
- pad_h = self.max_size - new_size[0]
34
- pad_w = self.max_size - new_size[1]
35
- img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36
- return img
37
-
38
-
39
- def _convert_to_rgb(image):
40
- return image.convert('RGB')
41
-
42
-
43
- # class CatGen(nn.Module):
44
- # def __init__(self, num=4):
45
- # self.num = num
46
- # def mixgen_batch(image, text):
47
- # batch_size = image.shape[0]
48
- # index = np.random.permutation(batch_size)
49
-
50
- # cat_images = []
51
- # for i in range(batch_size):
52
- # # image mixup
53
- # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54
- # # text concat
55
- # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56
- # text = torch.stack(text)
57
- # return image, text
58
-
59
-
60
- def image_transform(
61
- image_size: int,
62
- is_train: bool,
63
- mean: Optional[Tuple[float, ...]] = None,
64
- std: Optional[Tuple[float, ...]] = None,
65
- resize_longest_max: bool = False,
66
- fill_color: int = 0,
67
- ):
68
- mean = mean or OPENAI_DATASET_MEAN
69
- if not isinstance(mean, (list, tuple)):
70
- mean = (mean,) * 3
71
-
72
- std = std or OPENAI_DATASET_STD
73
- if not isinstance(std, (list, tuple)):
74
- std = (std,) * 3
75
-
76
- if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77
- # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78
- image_size = image_size[0]
79
-
80
- normalize = Normalize(mean=mean, std=std)
81
- if is_train:
82
- return Compose([
83
- RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84
- _convert_to_rgb,
85
- ToTensor(),
86
- normalize,
87
- ])
88
- else:
89
- if resize_longest_max:
90
- transforms = [
91
- ResizeMaxSize(image_size, fill=fill_color)
92
- ]
93
- else:
94
- transforms = [
95
- Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96
- CenterCrop(image_size),
97
- ]
98
- transforms.extend([
99
- _convert_to_rgb,
100
- ToTensor(),
101
- normalize,
102
- ])
103
- return Compose(transforms)
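
image_transform returns a torchvision Compose: random resized crops for training, resize plus center crop (or ResizeMaxSize letterboxing) for evaluation, always followed by RGB conversion, ToTensor, and OpenAI-mean/std normalization. A sketch, assuming the pre-merge module path; the blank PIL image stands in for a real photo:

    from PIL import Image

    from eva_clip.transform import image_transform  # pre-merge module path

    preprocess = image_transform(image_size=224, is_train=False)
    tensor = preprocess(Image.new("RGB", (640, 480)))
    print(tensor.shape)  # torch.Size([3, 224, 224])
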
eva_clip/transformer.py DELETED
@@ -1,737 +0,0 @@
1
- import os
2
- import logging
3
- from collections import OrderedDict
4
- import math
5
- from typing import Callable, Optional, Sequence
6
- import numpy as np
7
- import torch
8
- from torch import nn
9
- from torch.nn import functional as F
10
-
11
- try:
12
- from timm.models.layers import trunc_normal_
13
- except:
14
- from timm.layers import trunc_normal_
15
-
16
- from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
- from .utils import to_2tuple
18
-
19
- if os.getenv('ENV_TYPE') == 'deepspeed':
20
- try:
21
- import deepspeed
22
- from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
23
- except:
24
- print("Please 'pip install deepspeed'")
25
- deepspeed = None
26
- from torch.utils.checkpoint import checkpoint
27
- else:
28
- from torch.utils.checkpoint import checkpoint
29
-
30
- try:
31
- import xformers.ops as xops
32
- except ImportError:
33
- xops = None
34
- print("Please 'pip install xformers'")
35
-
36
- class LayerNormFp32(nn.LayerNorm):
37
- """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
38
- def __init__(self, *args, **kwargs):
39
- super().__init__(*args, **kwargs)
40
-
41
- def forward(self, x: torch.Tensor):
42
- output = F.layer_norm(
43
- x.float(),
44
- self.normalized_shape,
45
- self.weight.float() if self.weight is not None else None,
46
- self.bias.float() if self.bias is not None else None,
47
- self.eps,
48
- )
49
- return output.type_as(x)
50
-
51
-
52
- class LayerNorm(nn.LayerNorm):
53
- """Subclass torch's LayerNorm (with cast back to input dtype)."""
54
-
55
- def forward(self, x: torch.Tensor):
56
- orig_type = x.dtype
57
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
58
- return x.to(orig_type)
59
-
60
- class QuickGELU(nn.Module):
61
- # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
62
- def forward(self, x: torch.Tensor):
63
- return x * torch.sigmoid(1.702 * x)
64
-
65
-
66
- class LayerScale(nn.Module):
67
- def __init__(self, dim, init_values=1e-5, inplace=False):
68
- super().__init__()
69
- self.inplace = inplace
70
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
71
-
72
- def forward(self, x):
73
- return x.mul_(self.gamma) if self.inplace else x * self.gamma
74
-
75
- class PatchDropout(nn.Module):
76
- """
77
- https://arxiv.org/abs/2212.00794
78
- """
79
-
80
- def __init__(self, prob, exclude_first_token=True):
81
- super().__init__()
82
- assert 0 <= prob < 1.
83
- self.prob = prob
84
- self.exclude_first_token = exclude_first_token # exclude CLS token
85
- logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
86
-
87
- def forward(self, x):
88
- if not self.training or self.prob == 0.:
89
- return x
90
-
91
- if self.exclude_first_token:
92
- cls_tokens, x = x[:, :1], x[:, 1:]
93
- else:
94
- cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
95
-
96
- batch = x.size()[0]
97
- num_tokens = x.size()[1]
98
-
99
- batch_indices = torch.arange(batch)
100
- batch_indices = batch_indices[..., None]
101
-
102
- keep_prob = 1 - self.prob
103
- num_patches_keep = max(1, int(num_tokens * keep_prob))
104
-
105
- rand = torch.randn(batch, num_tokens)
106
- patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
107
-
108
- x = x[batch_indices, patch_indices_keep]
109
-
110
- if self.exclude_first_token:
111
- x = torch.cat((cls_tokens, x), dim=1)
112
-
113
- if self.training and os.getenv('RoPE') == '1':
114
- return x, patch_indices_keep
115
-
116
- return x
117
-
118
-
119
- def _in_projection_packed(
120
- q: torch.Tensor,
121
- k: torch.Tensor,
122
- v: torch.Tensor,
123
- w: torch.Tensor,
124
- b: Optional[torch.Tensor] = None,
125
- ):
126
- """
127
- https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
128
- """
129
- E = q.size(-1)
130
- if k is v:
131
- if q is k:
132
- # self-attention
133
- return F.linear(q, w, b).chunk(3, dim=-1)
134
- else:
135
- # encoder-decoder attention
136
- w_q, w_kv = w.split([E, E * 2])
137
- if b is None:
138
- b_q = b_kv = None
139
- else:
140
- b_q, b_kv = b.split([E, E * 2])
141
- return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
142
- else:
143
- w_q, w_k, w_v = w.chunk(3)
144
- if b is None:
145
- b_q = b_k = b_v = None
146
- else:
147
- b_q, b_k, b_v = b.chunk(3)
148
- return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
149
-
150
- class Attention(nn.Module):
151
- def __init__(
152
- self,
153
- dim,
154
- num_heads=8,
155
- qkv_bias=True,
156
- scaled_cosine=False,
157
- scale_heads=False,
158
- logit_scale_max=math.log(1. / 0.01),
159
- attn_drop=0.,
160
- proj_drop=0.,
161
- xattn=False,
162
- rope=False
163
- ):
164
- super().__init__()
165
- self.scaled_cosine = scaled_cosine
166
- self.scale_heads = scale_heads
167
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
168
- self.num_heads = num_heads
169
- self.head_dim = dim // num_heads
170
- self.scale = self.head_dim ** -0.5
171
- self.logit_scale_max = logit_scale_max
172
-
173
- # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
174
- self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
175
- if qkv_bias:
176
- self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
177
- else:
178
- self.in_proj_bias = None
179
-
180
- if self.scaled_cosine:
181
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
182
- else:
183
- self.logit_scale = None
184
- self.attn_drop = nn.Dropout(attn_drop)
185
- if self.scale_heads:
186
- self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
187
- else:
188
- self.head_scale = None
189
- self.out_proj = nn.Linear(dim, dim)
190
- self.out_drop = nn.Dropout(proj_drop)
191
- self.xattn = xattn
192
- self.xattn_drop = attn_drop
193
- self.rope = rope
194
-
195
- def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
196
- L, N, C = x.shape
197
- q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
198
- if self.xattn:
199
- q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
200
- k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
201
- v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
202
-
203
- x = xops.memory_efficient_attention(
204
- q, k, v,
205
- p=self.xattn_drop,
206
- scale=self.scale if self.logit_scale is None else None,
207
- attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
208
- )
209
- else:
210
- q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
211
- k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
212
- v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
213
-
214
- if self.logit_scale is not None:
215
- attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
216
- logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
217
- attn = attn.view(N, self.num_heads, L, L) * logit_scale
218
- attn = attn.view(-1, L, L)
219
- else:
220
- q = q * self.scale
221
- attn = torch.bmm(q, k.transpose(-1, -2))
222
-
223
- if attn_mask is not None:
224
- if attn_mask.dtype == torch.bool:
225
- new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
226
- new_attn_mask.masked_fill_(attn_mask, float("-inf"))
227
- attn_mask = new_attn_mask
228
- attn += attn_mask
229
-
230
- attn = attn.softmax(dim=-1)
231
- attn = self.attn_drop(attn)
232
-
233
- x = torch.bmm(attn, v)
234
-
235
- if self.head_scale is not None:
236
- x = x.view(N, self.num_heads, L, C) * self.head_scale
237
- x = x.view(-1, L, C)
238
- x = x.transpose(0, 1).reshape(L, N, C)
239
- x = self.out_proj(x)
240
- x = self.out_drop(x)
241
- return x
242
-
243
- class CustomAttention(nn.Module):
244
- def __init__(
245
- self,
246
- dim,
247
- num_heads=8,
248
- qkv_bias=True,
249
- scaled_cosine=True,
250
- scale_heads=False,
251
- logit_scale_max=math.log(1. / 0.01),
252
- attn_drop=0.,
253
- proj_drop=0.,
254
- xattn=False
255
- ):
256
- super().__init__()
257
- self.scaled_cosine = scaled_cosine
258
- self.scale_heads = scale_heads
259
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
260
- self.num_heads = num_heads
261
- self.head_dim = dim // num_heads
262
- self.scale = self.head_dim ** -0.5
263
- self.logit_scale_max = logit_scale_max
264
-
265
- # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
266
- self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
267
- if qkv_bias:
268
- self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
269
- else:
270
- self.in_proj_bias = None
271
-
272
- if self.scaled_cosine:
273
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
274
- else:
275
- self.logit_scale = None
276
- self.attn_drop = nn.Dropout(attn_drop)
277
- if self.scale_heads:
278
- self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
279
- else:
280
- self.head_scale = None
281
- self.out_proj = nn.Linear(dim, dim)
282
- self.out_drop = nn.Dropout(proj_drop)
283
- self.xattn = xattn
284
- self.xattn_drop = attn_drop
285
-
286
- def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
287
- q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
288
- N_q, B_q, C_q = q.shape
289
- N_k, B_k, C_k = k.shape
290
- N_v, B_v, C_v = v.shape
291
- if self.xattn:
292
- # B, N, C -> B, N, num_heads, C
293
- q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
294
- k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
295
- v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
296
-
297
- x = xops.memory_efficient_attention(
298
- q, k, v,
299
- p=self.xattn_drop,
300
- scale=self.scale if self.logit_scale is None else None,
301
- attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
302
- )
303
- else:
304
- # B*H, L, C
305
- q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
306
- k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
307
- v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
308
-
309
- if self.logit_scale is not None:
310
- # B*H, N_q, N_k
311
- attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
312
- logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
313
- attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
314
- attn = attn.view(-1, N_q, N_k)
315
- else:
316
- q = q * self.scale
317
- attn = torch.bmm(q, k.transpose(-1, -2))
318
-
319
- if attn_mask is not None:
320
- if attn_mask.dtype == torch.bool:
321
- new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
322
- new_attn_mask.masked_fill_(attn_mask, float("-inf"))
323
- attn_mask = new_attn_mask
324
- attn += attn_mask
325
-
326
- attn = attn.softmax(dim=-1)
327
- attn = self.attn_drop(attn)
328
-
329
- x = torch.bmm(attn, v)
330
-
331
- if self.head_scale is not None:
332
- x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
333
- x = x.view(-1, N_q, C_q)
334
- x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
335
- x = self.out_proj(x)
336
- x = self.out_drop(x)
337
- return x
338
-
339
- class CustomResidualAttentionBlock(nn.Module):
340
- def __init__(
341
- self,
342
- d_model: int,
343
- n_head: int,
344
- mlp_ratio: float = 4.0,
345
- ls_init_value: float = None,
346
- act_layer: Callable = nn.GELU,
347
- norm_layer: Callable = LayerNorm,
348
- scale_cosine_attn: bool = False,
349
- scale_heads: bool = False,
350
- scale_attn: bool = False,
351
- scale_fc: bool = False,
352
- cross_attn: bool = False,
353
- xattn: bool = False,
354
- ):
355
- super().__init__()
356
-
357
- self.ln_1 = norm_layer(d_model)
358
- self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
359
- self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
360
- self.attn = CustomAttention(
361
- d_model, n_head,
362
- qkv_bias=True,
363
- attn_drop=0.,
364
- proj_drop=0.,
365
- scaled_cosine=scale_cosine_attn,
366
- scale_heads=scale_heads,
367
- xattn=xattn
368
- )
369
-
370
- self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
371
- self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
372
-
373
- self.ln_2 = norm_layer(d_model)
374
- mlp_width = int(d_model * mlp_ratio)
375
- self.mlp = nn.Sequential(OrderedDict([
376
- ("c_fc", nn.Linear(d_model, mlp_width)),
377
- ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
378
- ("gelu", act_layer()),
379
- ("c_proj", nn.Linear(mlp_width, d_model))
380
- ]))
381
-
382
- self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
383
-
384
- def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
385
- q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
386
- q = q + self.ls_2(self.mlp(self.ln_2(q)))
387
- return q
388
-
389
- class CustomTransformer(nn.Module):
390
- def __init__(
391
- self,
392
- width: int,
393
- layers: int,
394
- heads: int,
395
- mlp_ratio: float = 4.0,
396
- ls_init_value: float = None,
397
- act_layer: Callable = nn.GELU,
398
- norm_layer: Callable = LayerNorm,
399
- scale_cosine_attn: bool = True,
400
- scale_heads: bool = False,
401
- scale_attn: bool = False,
402
- scale_fc: bool = False,
403
- cross_attn: bool = False,
404
- xattn: bool = False,
405
- ):
406
- super().__init__()
407
- self.width = width
408
- self.layers = layers
409
- self.grad_checkpointing = False
410
- self.xattn = xattn
411
-
412
- self.resblocks = nn.ModuleList([
413
- CustomResidualAttentionBlock(
414
- width,
415
- heads,
416
- mlp_ratio,
417
- ls_init_value=ls_init_value,
418
- act_layer=act_layer,
419
- norm_layer=norm_layer,
420
- scale_cosine_attn=scale_cosine_attn,
421
- scale_heads=scale_heads,
422
- scale_attn=scale_attn,
423
- scale_fc=scale_fc,
424
- cross_attn=cross_attn,
425
- xattn=xattn)
426
- for _ in range(layers)
427
- ])
428
-
429
- def get_cast_dtype(self) -> torch.dtype:
430
- return self.resblocks[0].mlp.c_fc.weight.dtype
431
-
432
- def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
433
- if k is None and v is None:
434
- k = v = q
435
- for r in self.resblocks:
436
- if self.grad_checkpointing and not torch.jit.is_scripting():
437
- q = checkpoint(r, q, k, v, attn_mask)
438
- else:
439
- q = r(q, k, v, attn_mask=attn_mask)
440
- return q
441
-
442
-
443
- class ResidualAttentionBlock(nn.Module):
444
- def __init__(
445
- self,
446
- d_model: int,
447
- n_head: int,
448
- mlp_ratio: float = 4.0,
449
- ls_init_value: float = None,
450
- act_layer: Callable = nn.GELU,
451
- norm_layer: Callable = LayerNorm,
452
- xattn: bool = False,
453
- ):
454
- super().__init__()
455
-
456
- self.ln_1 = norm_layer(d_model)
457
- if xattn:
458
- self.attn = Attention(d_model, n_head, xattn=True)
459
- else:
460
- self.attn = nn.MultiheadAttention(d_model, n_head)
461
- self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
462
-
463
- self.ln_2 = norm_layer(d_model)
464
- mlp_width = int(d_model * mlp_ratio)
465
- self.mlp = nn.Sequential(OrderedDict([
466
- ("c_fc", nn.Linear(d_model, mlp_width)),
467
- ("gelu", act_layer()),
468
- ("c_proj", nn.Linear(mlp_width, d_model))
469
- ]))
470
-
471
- self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
472
- self.xattn = xattn
473
-
474
- def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
475
- attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
476
- if self.xattn:
477
- return self.attn(x, attn_mask=attn_mask)
478
- return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
479
-
480
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
481
- x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
482
- x = x + self.ls_2(self.mlp(self.ln_2(x)))
483
- return x
484
-
485
- class Transformer(nn.Module):
486
- def __init__(
487
- self,
488
- width: int,
489
- layers: int,
490
- heads: int,
491
- mlp_ratio: float = 4.0,
492
- ls_init_value: float = None,
493
- act_layer: Callable = nn.GELU,
494
- norm_layer: Callable = LayerNorm,
495
- xattn: bool = False,
496
- ):
497
- super().__init__()
498
- self.width = width
499
- self.layers = layers
500
- self.grad_checkpointing = False
501
-
502
- self.resblocks = nn.ModuleList([
503
- ResidualAttentionBlock(
504
- width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
505
- for _ in range(layers)
506
- ])
507
-
508
- def get_cast_dtype(self) -> torch.dtype:
509
- return self.resblocks[0].mlp.c_fc.weight.dtype
510
-
511
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
512
- for r in self.resblocks:
513
- if self.grad_checkpointing and not torch.jit.is_scripting():
514
- x = checkpoint(r, x, attn_mask)
515
- else:
516
- x = r(x, attn_mask=attn_mask)
517
- return x
518
-
519
-
520
- class VisionTransformer(nn.Module):
521
- def __init__(
522
- self,
523
- image_size: int,
524
- patch_size: int,
525
- width: int,
526
- layers: int,
527
- heads: int,
528
- mlp_ratio: float,
529
- ls_init_value: float = None,
530
- patch_dropout: float = 0.,
531
- global_average_pool: bool = False,
532
- output_dim: int = 512,
533
- act_layer: Callable = nn.GELU,
534
- norm_layer: Callable = LayerNorm,
535
- xattn: bool = False,
536
- ):
537
- super().__init__()
538
- self.image_size = to_2tuple(image_size)
539
- self.patch_size = to_2tuple(patch_size)
540
- self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
541
- self.output_dim = output_dim
542
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
543
-
544
- scale = width ** -0.5
545
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
546
- self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
547
-
548
- # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
549
- self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
550
- self.ln_pre = norm_layer(width)
551
-
552
- self.transformer = Transformer(
553
- width,
554
- layers,
555
- heads,
556
- mlp_ratio,
557
- ls_init_value=ls_init_value,
558
- act_layer=act_layer,
559
- norm_layer=norm_layer,
560
- xattn=xattn
561
- )
562
-
563
- self.global_average_pool = global_average_pool
564
- self.ln_post = norm_layer(width)
565
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
566
-
567
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
568
- for param in self.parameters():
569
- param.requires_grad = False
570
-
571
- if unlocked_groups != 0:
572
- groups = [
573
- [
574
- self.conv1,
575
- self.class_embedding,
576
- self.positional_embedding,
577
- self.ln_pre,
578
- ],
579
- *self.transformer.resblocks[:-1],
580
- [
581
- self.transformer.resblocks[-1],
582
- self.ln_post,
583
- ],
584
- self.proj,
585
- ]
586
-
587
- def _unlock(x):
588
- if isinstance(x, Sequence):
589
- for g in x:
590
- _unlock(g)
591
- else:
592
- if isinstance(x, torch.nn.Parameter):
593
- x.requires_grad = True
594
- else:
595
- for p in x.parameters():
596
- p.requires_grad = True
597
-
598
- _unlock(groups[-unlocked_groups:])
599
-
600
- def get_num_layers(self):
601
- return self.transformer.layers
602
-
603
- @torch.jit.ignore
604
- def set_grad_checkpointing(self, enable=True):
605
- self.transformer.grad_checkpointing = enable
606
-
607
- @torch.jit.ignore
608
- def no_weight_decay(self):
609
- return {'positional_embedding', 'class_embedding'}
610
-
611
- def forward(self, x: torch.Tensor, return_all_features: bool=False):
612
- x = self.conv1(x) # shape = [*, width, grid, grid]
613
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
614
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
615
- x = torch.cat(
616
- [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
617
- x], dim=1) # shape = [*, grid ** 2 + 1, width]
618
- x = x + self.positional_embedding.to(x.dtype)
619
-
620
- # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
621
- x = self.patch_dropout(x)
622
- x = self.ln_pre(x)
623
-
624
- x = x.permute(1, 0, 2) # NLD -> LND
625
- x = self.transformer(x)
626
- x = x.permute(1, 0, 2) # LND -> NLD
627
-
628
- if not return_all_features:
629
- if self.global_average_pool:
630
- x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)
631
- else:
632
- x = x[:, 0]
633
-
634
- x = self.ln_post(x)
635
-
636
- if self.proj is not None:
637
- x = x @ self.proj
638
-
639
- return x
640
-
641
-
642
- class TextTransformer(nn.Module):
643
- def __init__(
644
- self,
645
- context_length: int = 77,
646
- vocab_size: int = 49408,
647
- width: int = 512,
648
- heads: int = 8,
649
- layers: int = 12,
650
- ls_init_value: float = None,
651
- output_dim: int = 512,
652
- act_layer: Callable = nn.GELU,
653
- norm_layer: Callable = LayerNorm,
654
- xattn: bool= False,
655
- attn_mask: bool = True
656
- ):
657
- super().__init__()
658
- self.context_length = context_length
659
- self.vocab_size = vocab_size
660
- self.width = width
661
- self.output_dim = output_dim
662
-
663
- self.token_embedding = nn.Embedding(vocab_size, width)
664
- self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
665
- self.transformer = Transformer(
666
- width=width,
667
- layers=layers,
668
- heads=heads,
669
- ls_init_value=ls_init_value,
670
- act_layer=act_layer,
671
- norm_layer=norm_layer,
672
- xattn=xattn
673
- )
674
-
675
- self.xattn = xattn
676
- self.ln_final = norm_layer(width)
677
- self.text_projection = nn.Parameter(torch.empty(width, output_dim))
678
-
679
- if attn_mask:
680
- self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
681
- else:
682
- self.attn_mask = None
683
-
684
- self.init_parameters()
685
-
686
- def init_parameters(self):
687
- nn.init.normal_(self.token_embedding.weight, std=0.02)
688
- nn.init.normal_(self.positional_embedding, std=0.01)
689
-
690
- proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
691
- attn_std = self.transformer.width ** -0.5
692
- fc_std = (2 * self.transformer.width) ** -0.5
693
- for block in self.transformer.resblocks:
694
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
695
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
696
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
697
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
698
-
699
- if self.text_projection is not None:
700
- nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
701
-
702
- @torch.jit.ignore
703
- def set_grad_checkpointing(self, enable=True):
704
- self.transformer.grad_checkpointing = enable
705
-
706
- @torch.jit.ignore
707
- def no_weight_decay(self):
708
- # return {'positional_embedding', 'token_embedding'}
709
- return {'positional_embedding'}
710
-
711
- def get_num_layers(self):
712
- return self.transformer.layers
713
-
714
- def build_attention_mask(self):
715
- # lazily create causal attention mask, with full attention between the vision tokens
716
- # pytorch uses additive attention mask; fill with -inf
717
- mask = torch.empty(self.context_length, self.context_length)
718
- mask.fill_(float("-inf"))
719
- mask.triu_(1) # zero out the lower diagonal
720
- return mask
721
-
722
- def forward(self, text, return_all_features: bool=False):
723
- cast_dtype = self.transformer.get_cast_dtype()
724
- x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
725
-
726
- x = x + self.positional_embedding.to(cast_dtype)
727
- x = x.permute(1, 0, 2) # NLD -> LND
728
- x = self.transformer(x, attn_mask=self.attn_mask)
729
- # x = self.transformer(x) # no attention mask is applied
730
- x = x.permute(1, 0, 2) # LND -> NLD
731
- x = self.ln_final(x)
732
-
733
- if not return_all_features:
734
- # x.shape = [batch_size, n_ctx, transformer.width]
735
- # take features from the eot embedding (eot_token is the highest number in each sequence)
736
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
737
- return x
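Note: of the deleted `transformer.py`, the detail most likely to matter downstream is how the text tower masked attention. `build_attention_mask` returned an additive causal mask; the small self-contained sketch below (sequence length 4 chosen arbitrarily) reproduces that logic and shows the resulting tensor.

```python
# Sketch of the additive causal mask built by the deleted TextTransformer.build_attention_mask():
# -inf strictly above the diagonal, 0 on and below it, so token i attends only to tokens <= i
# once the mask is added to the attention logits.
import torch

def build_causal_mask(context_length: int) -> torch.Tensor:
    mask = torch.full((context_length, context_length), float("-inf"))
    return mask.triu_(1)  # zero out entries on and below the main diagonal

print(build_causal_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```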
eva_clip/utils.py DELETED
@@ -1,326 +0,0 @@
1
- from itertools import repeat
2
- import collections.abc
3
- import logging
4
- import math
5
- import numpy as np
6
-
7
- import torch
8
- from torch import nn as nn
9
- from torchvision.ops.misc import FrozenBatchNorm2d
10
- import torch.nn.functional as F
11
-
12
- # open CLIP
13
- def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
14
- # Rescale the grid of position embeddings when loading from state_dict
15
- old_pos_embed = state_dict.get('visual.positional_embedding', None)
16
- if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
17
- return
18
- grid_size = to_2tuple(model.visual.grid_size)
19
- extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
20
- new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
21
- if new_seq_len == old_pos_embed.shape[0]:
22
- return
23
-
24
- if extra_tokens:
25
- pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
26
- else:
27
- pos_emb_tok, pos_emb_img = None, old_pos_embed
28
- old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
29
-
30
- logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
31
- pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
32
- pos_emb_img = F.interpolate(
33
- pos_emb_img,
34
- size=grid_size,
35
- mode=interpolation,
36
- align_corners=True,
37
- )
38
- pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
39
- if pos_emb_tok is not None:
40
- new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
41
- else:
42
- new_pos_embed = pos_emb_img
43
- state_dict['visual.positional_embedding'] = new_pos_embed
44
-
45
-
46
- def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
47
- # Rescale the grid of position embeddings when loading from state_dict
48
- old_pos_embed = state_dict.get('positional_embedding', None)
49
- if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
50
- return
51
- grid_size = to_2tuple(model.visual.grid_size)
52
- extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
53
- new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
54
- if new_seq_len == old_pos_embed.shape[0]:
55
- return
56
-
57
- if extra_tokens:
58
- pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
59
- else:
60
- pos_emb_tok, pos_emb_img = None, old_pos_embed
61
- old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
62
-
63
- logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
64
- pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
65
- pos_emb_img = F.interpolate(
66
- pos_emb_img,
67
- size=grid_size,
68
- mode=interpolation,
69
- align_corners=True,
70
- )
71
- pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
72
- if pos_emb_tok is not None:
73
- new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
74
- else:
75
- new_pos_embed = pos_emb_img
76
- state_dict['positional_embedding'] = new_pos_embed
77
-
78
- def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
79
- all_keys = list(state_dict.keys())
80
- # interpolate position embedding
81
- if 'visual.pos_embed' in state_dict:
82
- pos_embed_checkpoint = state_dict['visual.pos_embed']
83
- embedding_size = pos_embed_checkpoint.shape[-1]
84
- num_patches = model.visual.patch_embed.num_patches
85
- num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
86
- # height (== width) for the checkpoint position embedding
87
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
88
- # height (== width) for the new position embedding
89
- new_size = int(num_patches ** 0.5)
90
- # class_token and dist_token are kept unchanged
91
- if orig_size != new_size:
92
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
93
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
94
- # only the position tokens are interpolated
95
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
96
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
97
- pos_tokens = torch.nn.functional.interpolate(
98
- pos_tokens.float(), size=(new_size, new_size), mode='bicubic', align_corners=False)
99
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
100
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
101
- state_dict['visual.pos_embed'] = new_pos_embed
102
-
103
- patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
104
- patch_size = model.visual.patch_embed.patch_size
105
- state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
106
- patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
107
-
108
-
109
- def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
110
- all_keys = list(state_dict.keys())
111
- # interpolate position embedding
112
- if 'pos_embed' in state_dict:
113
- pos_embed_checkpoint = state_dict['pos_embed']
114
- embedding_size = pos_embed_checkpoint.shape[-1]
115
- num_patches = model.visual.patch_embed.num_patches
116
- num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
117
- # height (== width) for the checkpoint position embedding
118
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
119
- # height (== width) for the new position embedding
120
- new_size = int(num_patches ** 0.5)
121
- # class_token and dist_token are kept unchanged
122
- if orig_size != new_size:
123
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
124
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
125
- # only the position tokens are interpolated
126
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
127
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
128
- pos_tokens = torch.nn.functional.interpolate(
129
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
130
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
131
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
132
- state_dict['pos_embed'] = new_pos_embed
133
-
134
- patch_embed_proj = state_dict['patch_embed.proj.weight']
135
- patch_size = model.visual.patch_embed.patch_size
136
- state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
137
- patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
138
-
139
-
140
- def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
141
- all_keys = list(state_dict.keys())
142
- for key in all_keys:
143
- if "relative_position_index" in key:
144
- state_dict.pop(key)
145
-
146
- if "relative_position_bias_table" in key:
147
- rel_pos_bias = state_dict[key]
148
- src_num_pos, num_attn_heads = rel_pos_bias.size()
149
- dst_num_pos, _ = model.visual.state_dict()[key].size()
150
- dst_patch_shape = model.visual.patch_embed.patch_shape
151
- if dst_patch_shape[0] != dst_patch_shape[1]:
152
- raise NotImplementedError()
153
- num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
154
- src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
155
- dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
156
- if src_size != dst_size:
157
- print("Position interpolate for %s from %dx%d to %dx%d" % (
158
- key, src_size, src_size, dst_size, dst_size))
159
- extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
160
- rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
161
-
162
- def geometric_progression(a, r, n):
163
- return a * (1.0 - r ** n) / (1.0 - r)
164
-
165
- left, right = 1.01, 1.5
166
- while right - left > 1e-6:
167
- q = (left + right) / 2.0
168
- gp = geometric_progression(1, q, src_size // 2)
169
- if gp > dst_size // 2:
170
- right = q
171
- else:
172
- left = q
173
-
174
- # if q > 1.090307:
175
- # q = 1.090307
176
-
177
- dis = []
178
- cur = 1
179
- for i in range(src_size // 2):
180
- dis.append(cur)
181
- cur += q ** (i + 1)
182
-
183
- r_ids = [-_ for _ in reversed(dis)]
184
-
185
- x = r_ids + [0] + dis
186
- y = r_ids + [0] + dis
187
-
188
- t = dst_size // 2.0
189
- dx = np.arange(-t, t + 0.1, 1.0)
190
- dy = np.arange(-t, t + 0.1, 1.0)
191
-
192
- print("Original positions = %s" % str(x))
193
- print("Target positions = %s" % str(dx))
194
-
195
- all_rel_pos_bias = []
196
-
197
- for i in range(num_attn_heads):
198
- z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
199
- f = F.interpolate.interp2d(x, y, z, kind='cubic')
200
- all_rel_pos_bias.append(
201
- torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
202
-
203
- rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
204
-
205
- new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
206
- state_dict[key] = new_rel_pos_bias
207
-
208
- # interpolate position embedding
209
- if 'pos_embed' in state_dict:
210
- pos_embed_checkpoint = state_dict['pos_embed']
211
- embedding_size = pos_embed_checkpoint.shape[-1]
212
- num_patches = model.visual.patch_embed.num_patches
213
- num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
214
- # height (== width) for the checkpoint position embedding
215
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
216
- # height (== width) for the new position embedding
217
- new_size = int(num_patches ** 0.5)
218
- # class_token and dist_token are kept unchanged
219
- if orig_size != new_size:
220
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
221
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
222
- # only the position tokens are interpolated
223
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
224
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
225
- pos_tokens = torch.nn.functional.interpolate(
226
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
227
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
228
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
229
- state_dict['pos_embed'] = new_pos_embed
230
-
231
- patch_embed_proj = state_dict['patch_embed.proj.weight']
232
- patch_size = model.visual.patch_embed.patch_size
233
- state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
234
- patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
235
-
236
-
237
- def freeze_batch_norm_2d(module, module_match={}, name=''):
238
- """
239
- Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
240
- itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
241
- returned. Otherwise, the module is walked recursively and submodules are converted in place.
242
-
243
- Args:
244
- module (torch.nn.Module): Any PyTorch module.
245
- module_match (dict): Dictionary of full module names to freeze (all if empty)
246
- name (str): Full module name (prefix)
247
-
248
- Returns:
249
- torch.nn.Module: Resulting module
250
-
251
- Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
252
- """
253
- res = module
254
- is_match = True
255
- if module_match:
256
- is_match = name in module_match
257
- if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
258
- res = FrozenBatchNorm2d(module.num_features)
259
- res.num_features = module.num_features
260
- res.affine = module.affine
261
- if module.affine:
262
- res.weight.data = module.weight.data.clone().detach()
263
- res.bias.data = module.bias.data.clone().detach()
264
- res.running_mean.data = module.running_mean.data
265
- res.running_var.data = module.running_var.data
266
- res.eps = module.eps
267
- else:
268
- for child_name, child in module.named_children():
269
- full_child_name = '.'.join([name, child_name]) if name else child_name
270
- new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
271
- if new_child is not child:
272
- res.add_module(child_name, new_child)
273
- return res
274
-
275
-
276
- # From PyTorch internals
277
- def _ntuple(n):
278
- def parse(x):
279
- if isinstance(x, collections.abc.Iterable):
280
- return x
281
- return tuple(repeat(x, n))
282
- return parse
283
-
284
-
285
- to_1tuple = _ntuple(1)
286
- to_2tuple = _ntuple(2)
287
- to_3tuple = _ntuple(3)
288
- to_4tuple = _ntuple(4)
289
- to_ntuple = lambda n, x: _ntuple(n)(x)
290
-
291
-
292
- def is_logging(args):
293
- def is_global_master(args):
294
- return args.rank == 0
295
-
296
- def is_local_master(args):
297
- return args.local_rank == 0
298
-
299
- def is_master(args, local=False):
300
- return is_local_master(args) if local else is_global_master(args)
301
- return is_master
302
-
303
-
304
- class AllGather(torch.autograd.Function):
305
- """An autograd function that performs allgather on a tensor.
306
- Performs all_gather operation on the provided tensors.
307
- *** Warning ***: torch.distributed.all_gather has no gradient.
308
- """
309
-
310
- @staticmethod
311
- def forward(ctx, tensor, rank, world_size):
312
- tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
313
- torch.distributed.all_gather(tensors_gather, tensor)
314
- ctx.rank = rank
315
- ctx.batch_size = tensor.shape[0]
316
- return torch.cat(tensors_gather, 0)
317
-
318
- @staticmethod
319
- def backward(ctx, grad_output):
320
- return (
321
- grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
322
- None,
323
- None
324
- )
325
-
326
- allgather = AllGather.apply
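Note: the `resize_*_pos_embed` helpers in the deleted `utils.py` all share one core step: keep the class-token embedding as-is, reshape the patch position embeddings into their 2D grid, bicubic-resize that grid to the new resolution, and flatten it back. A condensed, runnable sketch of just that step follows; the single extra (class) token, the 1024-dim width, and the 24-to-32 grid change (336 px to 448 px at patch size 14) are illustrative assumptions.

```python
# Condensed sketch of the 2D position-embedding interpolation used by the deleted
# resize_*_pos_embed helpers: class token kept, patch grid bicubic-resized.
import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed: torch.Tensor, new_grid: int, num_extra_tokens: int = 1) -> torch.Tensor:
    # pos_embed: [1, num_extra_tokens + old_grid**2, dim]
    dim = pos_embed.shape[-1]
    extra = pos_embed[:, :num_extra_tokens]
    patches = pos_embed[:, num_extra_tokens:]
    old_grid = int(patches.shape[1] ** 0.5)
    patches = patches.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patches = F.interpolate(patches.float(), size=(new_grid, new_grid),
                            mode="bicubic", align_corners=False)
    patches = patches.permute(0, 2, 3, 1).flatten(1, 2)
    return torch.cat((extra, patches), dim=1)

# e.g. 336 px -> 448 px at patch size 14 means a 24x24 -> 32x32 grid:
resized = interpolate_pos_embed(torch.randn(1, 1 + 24 * 24, 1024), new_grid=32)
print(resized.shape)  # torch.Size([1, 1025, 1024])
```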
modeling_kangaroo.py CHANGED
@@ -17,8 +17,6 @@
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
- """PyTorch LLaMA model."""
21
-
22
  import math
23
  from typing import List, Optional, Tuple, Union
24
 
@@ -26,16 +24,15 @@ import torch
26
  import torch.nn.functional as F
27
  import torch.utils.checkpoint
28
  from torch import nn
29
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
 
 
31
  from transformers.activations import ACT2FN
32
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
33
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
  from transformers.modeling_outputs import (
35
  BaseModelOutputWithPast,
36
  CausalLMOutputWithPast,
37
- QuestionAnsweringModelOutput,
38
- SequenceClassifierOutputWithPast,
39
  )
40
  from transformers.modeling_utils import PreTrainedModel
41
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -49,15 +46,14 @@ from transformers.utils import (
49
  )
50
  from transformers.models.llama.configuration_llama import LlamaConfig
51
 
52
- from eva_clip import create_model_and_transforms
53
  from .mm_projector_builder import build_vision_projector
 
54
 
55
  if is_flash_attn_2_available():
56
  from flash_attn import flash_attn_func, flash_attn_varlen_func
57
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
58
 
59
- from .data_utils import get_input, add_pred_to_history
60
- import transformers
61
 
62
  logger = logging.get_logger(__name__)
63
 
@@ -107,22 +103,6 @@ class LlamaRotaryEmbedding(nn.Module):
107
  self.register_buffer("inv_freq", inv_freq, persistent=False)
108
  # For BC we register cos and sin cached
109
  self.max_seq_len_cached = max_position_embeddings
110
-
111
- #@torch.no_grad()
112
- #def forward(self, x, position_ids):
113
- # # x: [bs, num_attention_heads, seq_len, head_size]
114
- # inv_freq_expanded = self.inv_freq[None, :, None].to(torch.bfloat16).expand(position_ids.shape[0], -1, 1)
115
- # position_ids_expanded = position_ids[:, None, :].to(torch.bfloat16)
116
- # # Force float32 since bfloat16 loses precision on long contexts
117
- # # See https://github.com/huggingface/transformers/pull/29285
118
- # device_type = x.device.type
119
- # device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
120
- # with torch.autocast(device_type=device_type, enabled=False):
121
- # freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
122
- # emb = torch.cat((freqs, freqs), dim=-1)
123
- # cos = emb.cos()
124
- # sin = emb.sin()
125
- # return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
126
 
127
  @torch.no_grad()
128
  def forward(self, x, position_ids):
@@ -179,7 +159,6 @@ def rotate_half(x):
179
 
180
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
181
  """Applies Rotary Position Embedding to the query and key tensors.
182
-
183
  Args:
184
  q (`torch.Tensor`): The query tensor.
185
  k (`torch.Tensor`): The key tensor.
@@ -504,7 +483,6 @@ class LlamaFlashAttention2(LlamaAttention):
504
  """
505
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
506
  first unpad the input, then computes the attention scores and pad the final attention scores.
507
-
508
  Args:
509
  query_states (`torch.Tensor`):
510
  Input query states to be passed to Flash Attention API
@@ -759,11 +737,9 @@ LLAMA_START_DOCSTRING = r"""
759
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
760
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
761
  etc.)
762
-
763
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
764
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
765
  and behavior.
766
-
767
  Parameters:
768
  config ([`LlamaConfig`]):
769
  Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -804,50 +780,38 @@ LLAMA_INPUTS_DOCSTRING = r"""
804
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
805
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
806
  it.
807
-
808
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
809
  [`PreTrainedTokenizer.__call__`] for details.
810
-
811
  [What are input IDs?](../glossary#input-ids)
812
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
813
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
814
-
815
  - 1 for tokens that are **not masked**,
816
  - 0 for tokens that are **masked**.
817
-
818
  [What are attention masks?](../glossary#attention-mask)
819
-
820
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
821
  [`PreTrainedTokenizer.__call__`] for details.
822
-
823
  If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
824
  `past_key_values`).
825
-
826
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
827
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
828
  information on the default strategy.
829
-
830
  - 1 indicates the head is **not masked**,
831
  - 0 indicates the head is **masked**.
832
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
833
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
834
  config.n_positions - 1]`.
835
-
836
  [What are position IDs?](../glossary#position-ids)
837
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
838
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
839
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
840
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
841
-
842
  Two formats are allowed:
843
  - a [`~cache_utils.Cache`] instance;
844
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
845
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
846
  cache format.
847
-
848
  The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
849
  legacy cache format will be returned.
850
-
851
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
852
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
853
  of shape `(batch_size, sequence_length)`.
@@ -880,7 +844,6 @@ LLAMA_INPUTS_DOCSTRING = r"""
880
  class LlamaModel(LlamaPreTrainedModel):
881
  """
882
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
883
-
884
  Args:
885
  config: LlamaConfig
886
  """
@@ -1107,13 +1070,10 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1107
  super().__init__(config)
1108
  self.model = LlamaModel(config)
1109
  model_name = "EVA02-CLIP-L-14-448"
1110
- pretrained = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-mtcv/liujiajun18/models/models--QuanSun--EVA-CLIP/snapshots/11afd202f2ae80869d6cef18b1ec775e79bd8d12/EVA02_CLIP_L_psz14_s4B.pt"
1111
  self.vocab_size = config.vocab_size
1112
- model, _, preprocess = create_model_and_transforms(model_name, pretrained, force_custom_clip=True)
1113
- model.text = None
1114
- model.logit_scale = None
1115
- self.vision_tower = model.visual
1116
  self.mm_projector = build_vision_projector(mm_hidden_size=self.vision_tower.num_features, hidden_size=config.hidden_size, projector_type="mlp2x_gelu")
 
1117
  self.vocab_size = config.vocab_size
1118
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1119
 
@@ -1121,6 +1081,7 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1121
  self.angle = torch.stack([1 / torch.pow(torch.tensor(10000), torch.tensor(2 * (hid_j // 2) / hidden_dim)) for hid_j in range(hidden_dim)])
1122
 
1123
  self.patch_shape = self.vision_tower.patch_embed.patch_shape[0]
 
1124
  self.adaptive_pooling = torch.nn.Conv3d(in_channels=self.vision_tower.num_features,
1125
  out_channels=self.vision_tower.num_features,
1126
  kernel_size=(2, 2, 2),
@@ -1164,10 +1125,6 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1164
  image_features = image_features.permute(0, 4, 1, 2, 3)
1165
  image_features = self.adaptive_pooling(image_features)
1166
  image_features = image_features.permute(0, 2, 3, 4, 1)
1167
- #B, T, P, _, __ = image_features.shape
1168
- #image_features = image_features.reshape(B, T // 2, 2, P, _, __)
1169
- #image_features = image_features.mean(dim=2)
1170
- #image_features = image_features.reshape(B, T // 2, P, _, __)
1171
  image_features = image_features.reshape(-1, self.patch_shape*self.patch_shape // 4, image_features.shape[-1])
1172
 
1173
  image_features = self.mm_projector(image_features)
@@ -1195,20 +1152,14 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1195
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1196
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1197
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1198
-
1199
  Returns:
1200
-
1201
  Example:
1202
-
1203
  ```python
1204
  >>> from transformers import AutoTokenizer, LlamaForCausalLM
1205
-
1206
  >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1207
  >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1208
-
1209
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1210
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1211
-
1212
  >>> # Generate
1213
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1214
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -1337,6 +1288,7 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1337
  T, C, H, W = video.shape
1338
  video = video.reshape(-1, C, H, W)
1339
  images_features = self.encode_images(video, durations, T)
 
1340
  input_embeds = self.model.embed_tokens.weight[inputs]
1341
  encoder_input = self.fuse_tokens_and_images(input_embeds, images_features, inputs)
1342
  encoder_input = encoder_input.permute(1, 0, 2)
@@ -1420,13 +1372,12 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1420
  )
1421
  return model_inputs
1422
 
1423
-
1424
  @torch.no_grad()
1425
  def chat(
1426
  self,
1427
  video_path : str,
1428
  query : str,
1429
- tokenizer : transformers.PreTrainedTokenizer,
1430
  num_segments : int = 64,
1431
  history : str = None,
1432
  system_prompt_id : int = 0,
@@ -1456,6 +1407,4 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1456
  reordered_past += (
1457
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1458
  )
1459
- return reordered_past
1460
-
1461
-
 
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
 
 
20
  import math
21
  from typing import List, Optional, Tuple, Union
22
 
 
24
  import torch.nn.functional as F
25
  import torch.utils.checkpoint
26
  from torch import nn
27
+ from torch.nn import CrossEntropyLoss
28
 
29
+ from transformers import PreTrainedTokenizer
30
  from transformers.activations import ACT2FN
31
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
32
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
33
  from transformers.modeling_outputs import (
34
  BaseModelOutputWithPast,
35
  CausalLMOutputWithPast,
 
 
36
  )
37
  from transformers.modeling_utils import PreTrainedModel
38
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 
46
  )
47
  from transformers.models.llama.configuration_llama import LlamaConfig
48
 
49
+ from .vision_tower_builder import build_vision_tower
50
  from .mm_projector_builder import build_vision_projector
51
+ from .data_utils import get_input, add_pred_to_history
52
 
53
  if is_flash_attn_2_available():
54
  from flash_attn import flash_attn_func, flash_attn_varlen_func
55
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
56
 
 
 
57
 
58
  logger = logging.get_logger(__name__)
59
 
 
103
  self.register_buffer("inv_freq", inv_freq, persistent=False)
104
  # For BC we register cos and sin cached
105
  self.max_seq_len_cached = max_position_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  @torch.no_grad()
108
  def forward(self, x, position_ids):
 
159
 
160
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
161
  """Applies Rotary Position Embedding to the query and key tensors.
 
162
  Args:
163
  q (`torch.Tensor`): The query tensor.
164
  k (`torch.Tensor`): The key tensor.
 
483
  """
484
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
485
  first unpad the input, then computes the attention scores and pad the final attention scores.
 
486
  Args:
487
  query_states (`torch.Tensor`):
488
  Input query states to be passed to Flash Attention API
 
737
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
738
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
739
  etc.)
 
740
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
741
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
742
  and behavior.
 
743
  Parameters:
744
  config ([`LlamaConfig`]):
745
  Model configuration class with all the parameters of the model. Initializing with a config file does not
 
780
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
781
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
782
  it.
 
783
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
784
  [`PreTrainedTokenizer.__call__`] for details.
 
785
  [What are input IDs?](../glossary#input-ids)
786
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
787
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
788
  - 1 for tokens that are **not masked**,
789
  - 0 for tokens that are **masked**.
 
790
  [What are attention masks?](../glossary#attention-mask)
 
791
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
792
  [`PreTrainedTokenizer.__call__`] for details.
 
793
  If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
794
  `past_key_values`).
 
795
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
796
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
797
  information on the default strategy.
 
798
  - 1 indicates the head is **not masked**,
799
  - 0 indicates the head is **masked**.
800
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
801
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
802
  config.n_positions - 1]`.
 
803
  [What are position IDs?](../glossary#position-ids)
804
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
805
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
806
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
807
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
 
808
  Two formats are allowed:
809
  - a [`~cache_utils.Cache`] instance;
810
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
811
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
812
  cache format.
 
813
  The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
814
  legacy cache format will be returned.
 
815
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
816
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
817
  of shape `(batch_size, sequence_length)`.
 
844
  class LlamaModel(LlamaPreTrainedModel):
845
  """
846
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
 
847
  Args:
848
  config: LlamaConfig
849
  """
 
1070
  super().__init__(config)
1071
  self.model = LlamaModel(config)
1072
  model_name = "EVA02-CLIP-L-14-448"
 
1073
  self.vocab_size = config.vocab_size
1074
+ self.vision_tower = build_vision_tower(model_name)
 
 
 
1075
  self.mm_projector = build_vision_projector(mm_hidden_size=self.vision_tower.num_features, hidden_size=config.hidden_size, projector_type="mlp2x_gelu")
1076
+
1077
  self.vocab_size = config.vocab_size
1078
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1079
 
 
1081
  self.angle = torch.stack([1 / torch.pow(torch.tensor(10000), torch.tensor(2 * (hid_j // 2) / hidden_dim)) for hid_j in range(hidden_dim)])
1082
 
1083
  self.patch_shape = self.vision_tower.patch_embed.patch_shape[0]
1084
+ # patchify module
1085
  self.adaptive_pooling = torch.nn.Conv3d(in_channels=self.vision_tower.num_features,
1086
  out_channels=self.vision_tower.num_features,
1087
  kernel_size=(2, 2, 2),
 
1125
  image_features = image_features.permute(0, 4, 1, 2, 3)
1126
  image_features = self.adaptive_pooling(image_features)
1127
  image_features = image_features.permute(0, 2, 3, 4, 1)
 
 
 
 
1128
  image_features = image_features.reshape(-1, self.patch_shape*self.patch_shape // 4, image_features.shape[-1])
1129
 
1130
  image_features = self.mm_projector(image_features)
 
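A rough, stand-alone sketch of the shape flow in the hunk above: video features are pooled by the (2, 2, 2) Conv3d, flattened to patch tokens, and projected to the LM hidden size. The Conv3d stride and the body of build_vision_projector are not visible in this diff; the sketch assumes stride (2, 2, 2) and a conventional LLaVA-style Linear-GELU-Linear for "mlp2x_gelu", so treat both as assumptions.

```python
import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the vision tower and LM configs.
num_features, hidden_size, patch_shape = 1024, 4096, 32
T = 4  # number of video frames before pooling

# Assumed stride (2, 2, 2): halves the temporal and both spatial dimensions.
adaptive_pooling = nn.Conv3d(num_features, num_features, kernel_size=(2, 2, 2), stride=(2, 2, 2))

# Assumed "mlp2x_gelu" projector (two linear layers with a GELU in between).
mm_projector = nn.Sequential(
    nn.Linear(num_features, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size)
)

x = torch.randn(1, T, patch_shape, patch_shape, num_features)    # (B, T, H, W, C)
x = x.permute(0, 4, 1, 2, 3)                                     # (B, C, T, H, W)
x = adaptive_pooling(x)                                          # (B, C, T/2, H/2, W/2)
x = x.permute(0, 2, 3, 4, 1)                                     # (B, T/2, H/2, W/2, C)
x = x.reshape(-1, patch_shape * patch_shape // 4, x.shape[-1])   # (B*T/2, H*W/4, C)
x = mm_projector(x)                                              # (B*T/2, H*W/4, hidden_size)
print(x.shape)  # torch.Size([2, 256, 4096])
```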
1152
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1153
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1154
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
1155
  Returns:
 
1156
  Example:
 
1157
  ```python
1158
  >>> from transformers import AutoTokenizer, LlamaForCausalLM
 
1159
  >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1160
  >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
 
1161
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1162
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
1163
  >>> # Generate
1164
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1165
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
1288
  T, C, H, W = video.shape
1289
  video = video.reshape(-1, C, H, W)
1290
  images_features = self.encode_images(video, durations, T)
1291
+
1292
  input_embeds = self.model.embed_tokens.weight[inputs]
1293
  encoder_input = self.fuse_tokens_and_images(input_embeds, images_features, inputs)
1294
  encoder_input = encoder_input.permute(1, 0, 2)
 
1372
  )
1373
  return model_inputs
1374
 
 
1375
  @torch.no_grad()
1376
  def chat(
1377
  self,
1378
  video_path : str,
1379
  query : str,
1380
+ tokenizer : PreTrainedTokenizer,
1381
  num_segments : int = 64,
1382
  history : str = None,
1383
  system_prompt_id : int = 0,
 
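A hypothetical call of the chat() helper whose signature begins above. The parameter list continues beyond this hunk and the return value is not shown here, so both are assumptions; the video path and prompt are placeholders.

```python
# model / tokenizer are assumed to be already loaded, e.g. via
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) and
# AutoTokenizer.from_pretrained(...).
answer = model.chat(
    video_path="example.mp4",            # placeholder path
    query="What happens in this video?",
    tokenizer=tokenizer,
    num_segments=64,                     # default shown in the signature
)
print(answer)  # assuming chat() returns the generated reply
```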
1407
  reordered_past += (
1408
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1409
  )
1410
+ return reordered_past
 
 
eva_clip/eva_vit_model.py → vision_tower_builder.py RENAMED
@@ -1,20 +1,25 @@
1
  # --------------------------------------------------------
2
- # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
  # --------------------------------------------------------
4
  import math
5
  import os
6
- from functools import partial
 
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
 
10
  try:
11
  from timm.models.layers import drop_path, to_2tuple, trunc_normal_
12
  except:
13
  from timm.layers import drop_path, to_2tuple, trunc_normal_
14
 
15
- from .transformer import PatchDropout
16
- from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
-
18
  if os.getenv('ENV_TYPE') == 'deepspeed':
19
  try:
20
  from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
@@ -30,6 +35,59 @@ except ImportError:
30
  print("Please 'pip install xformers'")
31
 
32
 
33
  class DropPath(nn.Module):
34
  """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
35
  """
@@ -78,6 +136,7 @@ class Mlp(nn.Module):
78
  x = self.drop(x)
79
  return x
80
 
 
81
  class SwiGLU(nn.Module):
82
  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
83
  norm_layer=nn.LayerNorm, subln=False):
@@ -103,6 +162,7 @@ class SwiGLU(nn.Module):
103
  x = self.drop(x)
104
  return x
105
 
 
106
  class Attention(nn.Module):
107
  def __init__(
108
  self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
@@ -364,6 +424,91 @@ class RelativePositionBias(nn.Module):
364
  return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
365
 
366
 
367
  class EVAVisionTransformer(nn.Module):
368
  """ Vision Transformer with support for patch or hybrid CNN input stage
369
  """
@@ -383,7 +528,6 @@ class EVAVisionTransformer(nn.Module):
383
  num_patches = self.patch_embed.num_patches
384
 
385
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
386
- # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
387
  if use_abs_pos_emb:
388
  self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
389
  else:
@@ -530,3 +674,95 @@ class EVAVisionTransformer(nn.Module):
530
  x = self.forward_features(x)
531
  x = self.head(x)
532
  return x
 
1
  # --------------------------------------------------------
2
+ # Adapted from https://github.com/baaivision/EVA
3
  # --------------------------------------------------------
4
  import math
5
  import os
6
+ import json
7
+ import logging
8
+
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from functools import partial
15
+ from typing import Optional, Tuple, Union
16
+ from dataclasses import dataclass
17
+
18
  try:
19
  from timm.models.layers import drop_path, to_2tuple, trunc_normal_
20
  except:
21
  from timm.layers import drop_path, to_2tuple, trunc_normal_
22
 
 
 
 
23
  if os.getenv('ENV_TYPE') == 'deepspeed':
24
  try:
25
  from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
 
35
  print("Please 'pip install xformers'")
36
 
37
 
38
+ class PatchDropout(nn.Module):
39
+ """
40
+ https://arxiv.org/abs/2212.00794
41
+ """
42
+
43
+ def __init__(self, prob, exclude_first_token=True):
44
+ super().__init__()
45
+ assert 0 <= prob < 1.
46
+ self.prob = prob
47
+ self.exclude_first_token = exclude_first_token # exclude CLS token
48
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
49
+
50
+ def forward(self, x):
51
+ if not self.training or self.prob == 0.:
52
+ return x
53
+
54
+ if self.exclude_first_token:
55
+ cls_tokens, x = x[:, :1], x[:, 1:]
56
+ else:
57
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
58
+
59
+ batch = x.size()[0]
60
+ num_tokens = x.size()[1]
61
+
62
+ batch_indices = torch.arange(batch)
63
+ batch_indices = batch_indices[..., None]
64
+
65
+ keep_prob = 1 - self.prob
66
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
67
+
68
+ rand = torch.randn(batch, num_tokens)
69
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
70
+
71
+ x = x[batch_indices, patch_indices_keep]
72
+
73
+ if self.exclude_first_token:
74
+ x = torch.cat((cls_tokens, x), dim=1)
75
+
76
+ if self.training and os.getenv('RoPE') == '1':
77
+ return x, patch_indices_keep
78
+
79
+ return x
80
+
81
+
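A stand-alone illustration (not part of the file) of the token-selection trick used in PatchDropout.forward above: random scores per token, top-k indices, then gathering with a broadcast batch index.

```python
import torch

# Arbitrary shapes for demonstration.
batch, num_tokens, dim = 2, 16, 8
x = torch.randn(batch, num_tokens, dim)

keep_prob = 0.5
num_keep = max(1, int(num_tokens * keep_prob))

scores = torch.randn(batch, num_tokens)
patch_indices_keep = scores.topk(num_keep, dim=-1).indices   # (batch, num_keep)
batch_indices = torch.arange(batch)[..., None]               # (batch, 1), broadcasts

kept = x[batch_indices, patch_indices_keep]                  # (batch, num_keep, dim)
print(kept.shape)  # torch.Size([2, 8, 8])
```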
82
+ class LayerNorm(nn.LayerNorm):
83
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
84
+
85
+ def forward(self, x: torch.Tensor):
86
+ orig_type = x.dtype
87
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
88
+ return x.to(orig_type)
89
+
90
+
91
  class DropPath(nn.Module):
92
  """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
93
  """
 
136
  x = self.drop(x)
137
  return x
138
 
139
+
140
  class SwiGLU(nn.Module):
141
  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
142
  norm_layer=nn.LayerNorm, subln=False):
 
162
  x = self.drop(x)
163
  return x
164
 
165
+
166
  class Attention(nn.Module):
167
  def __init__(
168
  self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
 
424
  return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
425
 
426
 
427
+ def broadcat(tensors, dim = -1):
428
+ num_tensors = len(tensors)
429
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
430
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
431
+ shape_len = list(shape_lens)[0]
432
+ dim = (dim + shape_len) if dim < 0 else dim
433
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
434
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
435
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
436
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
437
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
438
+ expanded_dims.insert(dim, (dim, dims[dim]))
439
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
440
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
441
+ return torch.cat(tensors, dim = dim)
442
+
443
+
444
+ def rotate_half(x):
445
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
446
+ x1, x2 = x.unbind(dim = -1)
447
+ x = torch.stack((-x2, x1), dim = -1)
448
+ return rearrange(x, '... d r -> ... (d r)')
449
+
450
+
451
+ class VisionRotaryEmbeddingFast(nn.Module):
452
+ def __init__(
453
+ self,
454
+ dim,
455
+ pt_seq_len,
456
+ ft_seq_len=None,
457
+ custom_freqs = None,
458
+ freqs_for = 'lang',
459
+ theta = 10000,
460
+ max_freq = 10,
461
+ num_freqs = 1,
462
+ patch_dropout = 0.
463
+ ):
464
+ super().__init__()
465
+ if custom_freqs:
466
+ freqs = custom_freqs
467
+ elif freqs_for == 'lang':
468
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
469
+ elif freqs_for == 'pixel':
470
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * math.pi
471
+ elif freqs_for == 'constant':
472
+ freqs = torch.ones(num_freqs).float()
473
+ else:
474
+ raise ValueError(f'unknown modality {freqs_for}')
475
+
476
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
477
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
478
+
479
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
480
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
481
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
482
+
483
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
484
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
485
+
486
+ self.patch_dropout = patch_dropout
487
+
488
+ self.register_buffer("freqs_cos", freqs_cos)
489
+ self.register_buffer("freqs_sin", freqs_sin)
490
+
491
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
492
+
493
+ def forward(self, t, patch_indices_keep=None):
494
+ if patch_indices_keep is not None:
495
+ batch = t.size()[0]
496
+ batch_indices = torch.arange(batch)
497
+ batch_indices = batch_indices[..., None]
498
+
499
+ freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
500
+ freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
501
+
502
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
503
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
504
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
505
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
506
+
507
+ return t * freqs_cos + rotate_half(t) * freqs_sin
508
+
509
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
510
+
511
+
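A quick sanity check (illustrative, not from this file) of rotate_half as defined above: features are read as adjacent pairs (x1, x2) and each pair becomes (-x2, x1), which is what the cos/sin combination in VisionRotaryEmbeddingFast.forward relies on.

```python
import torch
from einops import rearrange

def rotate_half(x):
    # Same definition as above, repeated here so the snippet runs on its own.
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')

print(rotate_half(torch.tensor([1., 2., 3., 4.])))  # tensor([-2., 1., -4., 3.])
# The rotary embedding itself is then t * freqs_cos + rotate_half(t) * freqs_sin.
```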
512
  class EVAVisionTransformer(nn.Module):
513
  """ Vision Transformer with support for patch or hybrid CNN input stage
514
  """
 
528
  num_patches = self.patch_embed.num_patches
529
 
530
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
 
531
  if use_abs_pos_emb:
532
  self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
533
  else:
 
674
  x = self.forward_features(x)
675
  x = self.head(x)
676
  return x
677
+
678
+
679
+ @dataclass
680
+ class CLIPVisionCfg:
681
+ layers: Union[Tuple[int, int, int, int], int] = 12
682
+ width: int = 768
683
+ head_width: int = 64
684
+ mlp_ratio: float = 4.0
685
+ patch_size: int = 16
686
+ image_size: Union[Tuple[int, int], int] = 224
687
+ ls_init_value: Optional[float] = None # layer scale initial value
688
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
689
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
690
+ drop_path_rate: Optional[float] = None # drop path rate
691
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
692
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
693
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
694
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
695
+ timm_proj_bias: bool = False # enable bias final projection
696
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
697
+ qkv_bias: bool = True
698
+ fusedLN: bool = False
699
+ xattn: bool = False
700
+ postnorm: bool = False
701
+ rope: bool = False
702
+ pt_hw_seq_len: int = 16 # 224/14
703
+ intp_freq: bool = False
704
+ naiveswiglu: bool = False
705
+ subln: bool = False
706
+
707
+
708
+ def build_vision_tower(
709
+ model_name: str,
710
+ precision: str = 'bf16',
711
+ device: Union[str, torch.device] = 'cpu',
712
+ ):
713
+ if isinstance(device, str):
714
+ device = torch.device(device)
715
+
716
+ model_cfg = json.load(open(model_name + '.json'))
717
+ if 'rope' in model_cfg.get('vision_cfg', {}):
718
+ if model_cfg['vision_cfg']['rope']:
719
+ os.environ['RoPE'] = "1"
720
+ else:
721
+ os.environ['RoPE'] = "0"
722
+
723
+ vision_cfg = CLIPVisionCfg(**model_cfg['vision_cfg'])
724
+
725
+ if vision_cfg.fusedLN:
726
+ try:
727
+ from apex.normalization import FusedLayerNorm
728
+ except:
729
+ FusedLayerNorm = LayerNorm
730
+ print("Please 'pip install apex'")
731
+ norm_layer = partial(FusedLayerNorm, eps=1e-6)
732
+ else:
733
+ norm_layer = partial(LayerNorm, eps=1e-6)
734
+
735
+ vision_tower = EVAVisionTransformer(
736
+ img_size = vision_cfg.image_size,
737
+ patch_size = vision_cfg.patch_size,
738
+ num_classes = model_cfg['embed_dim'],
739
+ use_mean_pooling = vision_cfg.global_average_pool, #False
740
+ init_values = vision_cfg.ls_init_value,
741
+ patch_dropout = vision_cfg.patch_dropout,
742
+ embed_dim = vision_cfg.width,
743
+ depth = vision_cfg.layers,
744
+ num_heads = vision_cfg.width // vision_cfg.head_width,
745
+ mlp_ratio = vision_cfg.mlp_ratio,
746
+ qkv_bias = vision_cfg.qkv_bias,
747
+ drop_path_rate = vision_cfg.drop_path_rate,
748
+ norm_layer = norm_layer,
749
+ xattn = vision_cfg.xattn,
750
+ rope = vision_cfg.rope,
751
+ postnorm = vision_cfg.postnorm,
752
+ pt_hw_seq_len = vision_cfg.pt_hw_seq_len, # 224/14
753
+ intp_freq = vision_cfg.intp_freq,
754
+ naiveswiglu = vision_cfg.naiveswiglu,
755
+ subln = vision_cfg.subln
756
+ )
757
+
758
+ if "fp16" in precision or "bf16" in precision:
759
+ logging.info(f'convert precision to {precision}')
760
+ vision_tower = vision_tower.to(torch.bfloat16) if 'bf16' in precision else vision_tower.to(torch.float16)
761
+
762
+ vision_tower.to(device=device)
763
+
764
+ # set image / mean metadata from pretrained_cfg if available, or use default
765
+ vision_tower.image_mean = (0.48145466, 0.4578275, 0.40821073)
766
+ vision_tower.image_std = (0.26862954, 0.26130258, 0.27577711)
767
+
768
+ return vision_tower
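A minimal sketch of calling the builder above on its own. It assumes a config file named "EVA02-CLIP-L-14-448.json" (with an "embed_dim" key and a "vision_cfg" section) sits in the working directory, since the json.load(open(model_name + '.json')) line resolves the path relative to it.

```python
import torch
from vision_tower_builder import build_vision_tower

# Builds the EVA ViT in bfloat16 on CPU; the JSON config drives every
# architectural choice (depth, width, patch size, RoPE, xattn, ...).
vision_tower = build_vision_tower("EVA02-CLIP-L-14-448", precision="bf16", device="cpu")

print(vision_tower.num_features)                         # feature width consumed by the projector
print(vision_tower.image_mean, vision_tower.image_std)   # CLIP normalization set above
```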