Spaces:
Running
on
Zero
Running
on
Zero
SunderAli17
commited on
Create model.py
Browse files- eva_clip/model.py +439 -0
eva_clip/model.py
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP Model
|
2 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Optional, Tuple, Union
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
try:
|
15 |
+
from .hf_model import HFTextEncoder
|
16 |
+
except:
|
17 |
+
HFTextEncoder = None
|
18 |
+
from .modified_resnet import ModifiedResNet
|
19 |
+
# from .timm_model import TimmModel
|
20 |
+
from .eva_vit_model import EVAVisionTransformer
|
21 |
+
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
22 |
+
|
23 |
+
try:
|
24 |
+
from apex.normalization import FusedLayerNorm
|
25 |
+
except:
|
26 |
+
FusedLayerNorm = LayerNorm
|
27 |
+
print("Please 'pip install apex'")
|
28 |
+
|
29 |
+
try:
|
30 |
+
import xformers.ops as xops
|
31 |
+
except ImportError:
|
32 |
+
xops = None
|
33 |
+
print("Please 'pip install xformers'")
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class CLIPVisionCfg:
|
37 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
38 |
+
width: int = 768
|
39 |
+
head_width: int = 64
|
40 |
+
mlp_ratio: float = 4.0
|
41 |
+
patch_size: int = 16
|
42 |
+
image_size: Union[Tuple[int, int], int] = 224
|
43 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
44 |
+
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
45 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
46 |
+
drop_path_rate: Optional[float] = None # drop path rate
|
47 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
48 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
49 |
+
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
50 |
+
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
51 |
+
timm_proj_bias: bool = False # enable bias final projection
|
52 |
+
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
|
53 |
+
qkv_bias: bool = True
|
54 |
+
fusedLN: bool = False
|
55 |
+
xattn: bool = False
|
56 |
+
postnorm: bool = False
|
57 |
+
rope: bool = False
|
58 |
+
pt_hw_seq_len: int = 16 # 224/14
|
59 |
+
intp_freq: bool = False
|
60 |
+
naiveswiglu: bool = False
|
61 |
+
subln: bool = False
|
62 |
+
|
63 |
+
|
64 |
+
@dataclass
|
65 |
+
class CLIPTextCfg:
|
66 |
+
context_length: int = 77
|
67 |
+
vocab_size: int = 49408
|
68 |
+
width: int = 512
|
69 |
+
heads: int = 8
|
70 |
+
layers: int = 12
|
71 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
72 |
+
hf_model_name: str = None
|
73 |
+
hf_tokenizer_name: str = None
|
74 |
+
hf_model_pretrained: bool = True
|
75 |
+
proj: str = 'mlp'
|
76 |
+
pooler_type: str = 'mean_pooler'
|
77 |
+
masked_language_modeling: bool = False
|
78 |
+
fusedLN: bool = False
|
79 |
+
xattn: bool = False
|
80 |
+
attn_mask: bool = True
|
81 |
+
|
82 |
+
def get_cast_dtype(precision: str):
|
83 |
+
cast_dtype = None
|
84 |
+
if precision == 'bf16':
|
85 |
+
cast_dtype = torch.bfloat16
|
86 |
+
elif precision == 'fp16':
|
87 |
+
cast_dtype = torch.float16
|
88 |
+
return cast_dtype
|
89 |
+
|
90 |
+
|
91 |
+
def _build_vision_tower(
|
92 |
+
embed_dim: int,
|
93 |
+
vision_cfg: CLIPVisionCfg,
|
94 |
+
quick_gelu: bool = False,
|
95 |
+
cast_dtype: Optional[torch.dtype] = None
|
96 |
+
):
|
97 |
+
if isinstance(vision_cfg, dict):
|
98 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
99 |
+
|
100 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
101 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
102 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
103 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
104 |
+
|
105 |
+
if vision_cfg.eva_model_name:
|
106 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
107 |
+
norm_layer = LayerNorm
|
108 |
+
|
109 |
+
visual = EVAVisionTransformer(
|
110 |
+
img_size=vision_cfg.image_size,
|
111 |
+
patch_size=vision_cfg.patch_size,
|
112 |
+
num_classes=embed_dim,
|
113 |
+
use_mean_pooling=vision_cfg.global_average_pool, #False
|
114 |
+
init_values=vision_cfg.ls_init_value,
|
115 |
+
patch_dropout=vision_cfg.patch_dropout,
|
116 |
+
embed_dim=vision_cfg.width,
|
117 |
+
depth=vision_cfg.layers,
|
118 |
+
num_heads=vision_heads,
|
119 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
120 |
+
qkv_bias=vision_cfg.qkv_bias,
|
121 |
+
drop_path_rate=vision_cfg.drop_path_rate,
|
122 |
+
norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
|
123 |
+
xattn=vision_cfg.xattn,
|
124 |
+
rope=vision_cfg.rope,
|
125 |
+
postnorm=vision_cfg.postnorm,
|
126 |
+
pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
|
127 |
+
intp_freq= vision_cfg.intp_freq,
|
128 |
+
naiveswiglu= vision_cfg.naiveswiglu,
|
129 |
+
subln= vision_cfg.subln
|
130 |
+
)
|
131 |
+
elif vision_cfg.timm_model_name:
|
132 |
+
# visual = TimmModel(
|
133 |
+
# vision_cfg.timm_model_name,
|
134 |
+
# pretrained=vision_cfg.timm_model_pretrained,
|
135 |
+
# pool=vision_cfg.timm_pool,
|
136 |
+
# proj=vision_cfg.timm_proj,
|
137 |
+
# proj_bias=vision_cfg.timm_proj_bias,
|
138 |
+
# embed_dim=embed_dim,
|
139 |
+
# image_size=vision_cfg.image_size
|
140 |
+
# )
|
141 |
+
# act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
142 |
+
raise ValueError
|
143 |
+
elif isinstance(vision_cfg.layers, (tuple, list)):
|
144 |
+
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
145 |
+
visual = ModifiedResNet(
|
146 |
+
layers=vision_cfg.layers,
|
147 |
+
output_dim=embed_dim,
|
148 |
+
heads=vision_heads,
|
149 |
+
image_size=vision_cfg.image_size,
|
150 |
+
width=vision_cfg.width
|
151 |
+
)
|
152 |
+
else:
|
153 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
154 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
155 |
+
visual = VisionTransformer(
|
156 |
+
image_size=vision_cfg.image_size,
|
157 |
+
patch_size=vision_cfg.patch_size,
|
158 |
+
width=vision_cfg.width,
|
159 |
+
layers=vision_cfg.layers,
|
160 |
+
heads=vision_heads,
|
161 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
162 |
+
ls_init_value=vision_cfg.ls_init_value,
|
163 |
+
patch_dropout=vision_cfg.patch_dropout,
|
164 |
+
global_average_pool=vision_cfg.global_average_pool,
|
165 |
+
output_dim=embed_dim,
|
166 |
+
act_layer=act_layer,
|
167 |
+
norm_layer=norm_layer,
|
168 |
+
)
|
169 |
+
|
170 |
+
return visual
|
171 |
+
|
172 |
+
|
173 |
+
def _build_text_tower(
|
174 |
+
embed_dim: int,
|
175 |
+
text_cfg: CLIPTextCfg,
|
176 |
+
quick_gelu: bool = False,
|
177 |
+
cast_dtype: Optional[torch.dtype] = None,
|
178 |
+
):
|
179 |
+
if isinstance(text_cfg, dict):
|
180 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
181 |
+
|
182 |
+
if text_cfg.hf_model_name:
|
183 |
+
text = HFTextEncoder(
|
184 |
+
text_cfg.hf_model_name,
|
185 |
+
output_dim=embed_dim,
|
186 |
+
tokenizer_name=text_cfg.hf_tokenizer_name,
|
187 |
+
proj=text_cfg.proj,
|
188 |
+
pooler_type=text_cfg.pooler_type,
|
189 |
+
masked_language_modeling=text_cfg.masked_language_modeling
|
190 |
+
)
|
191 |
+
else:
|
192 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
193 |
+
norm_layer = LayerNorm
|
194 |
+
|
195 |
+
text = TextTransformer(
|
196 |
+
context_length=text_cfg.context_length,
|
197 |
+
vocab_size=text_cfg.vocab_size,
|
198 |
+
width=text_cfg.width,
|
199 |
+
heads=text_cfg.heads,
|
200 |
+
layers=text_cfg.layers,
|
201 |
+
ls_init_value=text_cfg.ls_init_value,
|
202 |
+
output_dim=embed_dim,
|
203 |
+
act_layer=act_layer,
|
204 |
+
norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
|
205 |
+
xattn=text_cfg.xattn,
|
206 |
+
attn_mask=text_cfg.attn_mask,
|
207 |
+
)
|
208 |
+
return text
|
209 |
+
|
210 |
+
class CLIP(nn.Module):
|
211 |
+
def __init__(
|
212 |
+
self,
|
213 |
+
embed_dim: int,
|
214 |
+
vision_cfg: CLIPVisionCfg,
|
215 |
+
text_cfg: CLIPTextCfg,
|
216 |
+
quick_gelu: bool = False,
|
217 |
+
cast_dtype: Optional[torch.dtype] = None,
|
218 |
+
):
|
219 |
+
super().__init__()
|
220 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
221 |
+
|
222 |
+
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
223 |
+
self.transformer = text.transformer
|
224 |
+
self.vocab_size = text.vocab_size
|
225 |
+
self.token_embedding = text.token_embedding
|
226 |
+
self.positional_embedding = text.positional_embedding
|
227 |
+
self.ln_final = text.ln_final
|
228 |
+
self.text_projection = text.text_projection
|
229 |
+
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
230 |
+
|
231 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
232 |
+
|
233 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
234 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
235 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
236 |
+
|
237 |
+
@torch.jit.ignore
|
238 |
+
def set_grad_checkpointing(self, enable=True):
|
239 |
+
self.visual.set_grad_checkpointing(enable)
|
240 |
+
self.transformer.grad_checkpointing = enable
|
241 |
+
|
242 |
+
@torch.jit.ignore
|
243 |
+
def no_weight_decay(self):
|
244 |
+
return {'logit_scale'}
|
245 |
+
|
246 |
+
def encode_image(self, image, normalize: bool = False):
|
247 |
+
features = self.visual(image)
|
248 |
+
return F.normalize(features, dim=-1) if normalize else features
|
249 |
+
|
250 |
+
def encode_text(self, text, normalize: bool = False):
|
251 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
252 |
+
|
253 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
254 |
+
|
255 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
256 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
257 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
258 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
259 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
260 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
261 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
262 |
+
return F.normalize(x, dim=-1) if normalize else x
|
263 |
+
|
264 |
+
def forward(self, image, text):
|
265 |
+
image_features = self.encode_image(image, normalize=True)
|
266 |
+
text_features = self.encode_text(text, normalize=True)
|
267 |
+
return image_features, text_features, self.logit_scale.exp()
|
268 |
+
|
269 |
+
|
270 |
+
class CustomCLIP(nn.Module):
|
271 |
+
def __init__(
|
272 |
+
self,
|
273 |
+
embed_dim: int,
|
274 |
+
vision_cfg: CLIPVisionCfg,
|
275 |
+
text_cfg: CLIPTextCfg,
|
276 |
+
quick_gelu: bool = False,
|
277 |
+
cast_dtype: Optional[torch.dtype] = None,
|
278 |
+
itm_task: bool = False,
|
279 |
+
):
|
280 |
+
super().__init__()
|
281 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
282 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
283 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
284 |
+
|
285 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
286 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
287 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
288 |
+
|
289 |
+
def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
|
290 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
291 |
+
|
292 |
+
@torch.jit.ignore
|
293 |
+
def set_grad_checkpointing(self, enable=True):
|
294 |
+
self.visual.set_grad_checkpointing(enable)
|
295 |
+
self.text.set_grad_checkpointing(enable)
|
296 |
+
|
297 |
+
@torch.jit.ignore
|
298 |
+
def no_weight_decay(self):
|
299 |
+
return {'logit_scale'}
|
300 |
+
|
301 |
+
def encode_image(self, image, normalize: bool = False):
|
302 |
+
features = self.visual(image)
|
303 |
+
return F.normalize(features, dim=-1) if normalize else features
|
304 |
+
|
305 |
+
def encode_text(self, text, normalize: bool = False):
|
306 |
+
features = self.text(text)
|
307 |
+
return F.normalize(features, dim=-1) if normalize else features
|
308 |
+
|
309 |
+
def forward(self, image, text):
|
310 |
+
image_features = self.encode_image(image, normalize=True)
|
311 |
+
text_features = self.encode_text(text, normalize=True)
|
312 |
+
return image_features, text_features, self.logit_scale.exp()
|
313 |
+
|
314 |
+
|
315 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
316 |
+
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
317 |
+
|
318 |
+
def _convert_weights(l):
|
319 |
+
|
320 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
321 |
+
l.weight.data = l.weight.data.to(dtype)
|
322 |
+
if l.bias is not None:
|
323 |
+
l.bias.data = l.bias.data.to(dtype)
|
324 |
+
|
325 |
+
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
326 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
327 |
+
tensor = getattr(l, attr, None)
|
328 |
+
if tensor is not None:
|
329 |
+
tensor.data = tensor.data.to(dtype)
|
330 |
+
|
331 |
+
if isinstance(l, nn.Parameter):
|
332 |
+
l.data = l.data.to(dtype)
|
333 |
+
|
334 |
+
for name in ["text_projection", "proj"]:
|
335 |
+
if hasattr(l, name) and isinstance(l, nn.Parameter):
|
336 |
+
attr = getattr(l, name, None)
|
337 |
+
if attr is not None:
|
338 |
+
attr.data = attr.data.to(dtype)
|
339 |
+
|
340 |
+
model.apply(_convert_weights)
|
341 |
+
|
342 |
+
|
343 |
+
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
344 |
+
|
345 |
+
|
346 |
+
# used to maintain checkpoint compatibility
|
347 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
|
348 |
+
if 'text_projection' in state_dict:
|
349 |
+
# old format state_dict, move text tower -> .text
|
350 |
+
new_state_dict = {}
|
351 |
+
for k, v in state_dict.items():
|
352 |
+
if any(k.startswith(p) for p in (
|
353 |
+
'text_projection',
|
354 |
+
'positional_embedding',
|
355 |
+
'token_embedding',
|
356 |
+
'transformer',
|
357 |
+
'ln_final',
|
358 |
+
'logit_scale'
|
359 |
+
)):
|
360 |
+
k = 'text.' + k
|
361 |
+
new_state_dict[k] = v
|
362 |
+
return new_state_dict
|
363 |
+
return state_dict
|
364 |
+
|
365 |
+
|
366 |
+
def build_model_from_openai_state_dict(
|
367 |
+
state_dict: dict,
|
368 |
+
quick_gelu=True,
|
369 |
+
cast_dtype=torch.float16,
|
370 |
+
):
|
371 |
+
vit = "visual.proj" in state_dict
|
372 |
+
|
373 |
+
if vit:
|
374 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
375 |
+
vision_layers = len(
|
376 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
377 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
378 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
379 |
+
image_size = vision_patch_size * grid_size
|
380 |
+
else:
|
381 |
+
counts: list = [
|
382 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
383 |
+
vision_layers = tuple(counts)
|
384 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
385 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
386 |
+
vision_patch_size = None
|
387 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
388 |
+
image_size = output_width * 32
|
389 |
+
|
390 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
391 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
392 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
393 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
394 |
+
transformer_heads = transformer_width // 64
|
395 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
396 |
+
|
397 |
+
vision_cfg = CLIPVisionCfg(
|
398 |
+
layers=vision_layers,
|
399 |
+
width=vision_width,
|
400 |
+
patch_size=vision_patch_size,
|
401 |
+
image_size=image_size,
|
402 |
+
)
|
403 |
+
text_cfg = CLIPTextCfg(
|
404 |
+
context_length=context_length,
|
405 |
+
vocab_size=vocab_size,
|
406 |
+
width=transformer_width,
|
407 |
+
heads=transformer_heads,
|
408 |
+
layers=transformer_layers
|
409 |
+
)
|
410 |
+
model = CLIP(
|
411 |
+
embed_dim,
|
412 |
+
vision_cfg=vision_cfg,
|
413 |
+
text_cfg=text_cfg,
|
414 |
+
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
415 |
+
cast_dtype=cast_dtype,
|
416 |
+
)
|
417 |
+
|
418 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
419 |
+
state_dict.pop(key, None)
|
420 |
+
|
421 |
+
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
422 |
+
model.load_state_dict(state_dict)
|
423 |
+
return model.eval()
|
424 |
+
|
425 |
+
|
426 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
427 |
+
model.eval()
|
428 |
+
image_size = model.visual.image_size
|
429 |
+
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
430 |
+
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
431 |
+
model = torch.jit.trace_module(
|
432 |
+
model,
|
433 |
+
inputs=dict(
|
434 |
+
forward=(example_images, example_text),
|
435 |
+
encode_text=(example_text,),
|
436 |
+
encode_image=(example_images,)
|
437 |
+
))
|
438 |
+
model.visual.image_size = image_size
|
439 |
+
return model
|