import torch


def extend_instance(obj, mixin):
    """Apply mixins to a class instance after creation."""
    base_cls = obj.__class__
    base_cls_name = obj.__class__.__name__
    # Put the mixin first so its methods take precedence in the MRO.
    obj.__class__ = type(base_cls_name, (mixin, base_cls), {})
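
# Usage sketch (illustrative only; `Greeter` and `Dog` are hypothetical names,
# not part of this module):
#
#     class Greeter:
#         def greet(self):
#             return f"hello from {type(self).__name__}"
#
#     class Dog:
#         pass
#
#     d = Dog()
#     extend_instance(d, Greeter)
#     d.greet()  # -> "hello from Dog"; isinstance(d, Dog) is still True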
def getattr_recursive(obj, att):
    """
    Return nested attribute of obj
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    """
    if att == "":
        return obj
    i = att.find(".")
    if i < 0:
        return getattr(obj, att)
    else:
        return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
def setattr_recursive(obj, att, val):
    """
    Set nested attribute of obj
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    if "." in att:
        obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
    setattr(obj, att.split(".")[-1], val)
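
# Usage sketch for the two helpers above (illustrative only; the namespace
# objects stand in for arbitrary nested modules or configs):
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(model=SimpleNamespace(decoder=SimpleNamespace(num_layers=12)))
#     getattr_recursive(cfg, "model.decoder.num_layers")   # -> 12
#     setattr_recursive(cfg, "model.decoder.num_layers", 24)
#     cfg.model.decoder.num_layers                          # -> 24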
def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        f"Could not infer the decoder layers attribute name for {model.__class__.__name__}."
        " We require the attribute name of the nn.ModuleList in the decoder that stores"
        " the transformer block layers. Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "llama": "model.layers",
    "mistral": "model.layers",
}
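
# Usage sketch (illustrative only; `AutoModelForCausalLM` is the Hugging Face
# transformers loader and the checkpoint path is a placeholder):
#
#     lang_model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
#     layers_attr = _infer_decoder_layers_attr_name(lang_model)    # -> "model.layers"
#     decoder_blocks = getattr_recursive(lang_model, layers_attr)  # the nn.ModuleList of blocks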
def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
    # Note: seq_dim is currently unused; it is kept for signature compatibility.
    if "pos_embed" in state_dict:
        pos_embed_checkpoint = state_dict["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches

        # Height/width (in patches) of the checkpoint grid and the target grid.
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches**0.5)

        if orig_size != new_size:
            print(
                f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size}"
            )
            # Keep the extra tokens (e.g. class token) untouched; only the
            # patch position embeddings are interpolated.
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)

            # Interpolate in float32, then cast back to half precision.
            pos_tokens = pos_tokens.float()
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode=interpolation,
                align_corners=False,
            )
            pos_tokens = pos_tokens.half()
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            state_dict["pos_embed"] = new_pos_embed

        # Resize the patch embedding kernel to the model's patch size as well,
        # again interpolating in float32 and casting back to half precision.
        patch_embed_proj = state_dict["patch_embed.proj.weight"]
        patch_size = model.patch_embed.patch_size
        state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(
            patch_embed_proj.float(),
            size=patch_size,
            mode=interpolation,
            align_corners=False,
        )
        state_dict["patch_embed.proj.weight"] = state_dict[
            "patch_embed.proj.weight"
        ].half()