# MGM/minigemini/model/multimodal_encoder/openclip_encoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import json
import logging
import deepspeed
from pathlib import Path
from open_clip.factory import load_state_dict, get_model_config
from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower, convert_to_custom_text_state_dict, resize_pos_embed
from typing import Dict, Optional
from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
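
# Hard-coded OpenCLIP configuration matching the open_clip_config.json shipped with the
# convnext_large_d_320 checkpoint; it is used as this tower's config instead of reading the
# JSON from the checkpoint directory (see the commented-out json.load in __init__ below).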
open_clip_config = {
    "model_cfg": {
        "embed_dim": 768,
        "vision_cfg": {
            "timm_model_name": "convnext_large",
            "timm_model_pretrained": False,
            "timm_pool": "",
            "timm_proj": "mlp",
            "timm_drop": 0.0,
            "timm_drop_path": 0.1,
            "image_size": 320
        },
        "text_cfg": {
            "context_length": 77,
            "vocab_size": 49408,
            "width": 768,
            "heads": 12,
            "layers": 16
        }
    },
    "preprocess_cfg": {
        "mean": [0.48145466, 0.4578275, 0.40821073],
        "std": [0.26862954, 0.26130258, 0.27577711]
    }
}

# Vision tower that wraps an OpenCLIP ConvNeXt trunk and returns its multi-scale stage
# features, upsampled to a common resolution and concatenated along the channel dimension.
class OpenCLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.vision_config = open_clip_config
        # self.vision_config = json.load(open(os.path.join(vision_tower, 'open_clip_config.json'), 'r'))
        self.is_optimize = getattr(args, 'optimize_vision_tower_aux', False)

        if not delay_load:
            self.load_model()
    def load_model(self):
        ckpt_path = os.path.join(self.vision_tower_name, 'open_clip_pytorch_model.bin')
        if 'convnext' in self.vision_tower_name:
            if 'large' in self.vision_tower_name and 'd_320' in self.vision_tower_name:
                self.model_type = 'convnext_large_d_320'
                self.model_channel = [192, 384, 768, 1536]  # stage 0-3
            elif 'base' in self.vision_tower_name and 'w_320' in self.vision_tower_name:
                self.model_type = 'convnext_base_w_320'
                self.model_channel = [128, 256, 512, 1024]
            elif 'xxlarge' in self.vision_tower_name:
                self.model_type = 'convnext_xxlarge'
                self.model_channel = [384, 768, 1536, 3072]

        clip_model = CLIP(**get_model_config(self.model_type))
        # Drop the projection heads and pre-head norm; only the ConvNeXt stem and stages are used.
        clip_model.visual.trunk.norm_pre = None
        clip_model.visual.trunk.head = None
        clip_model.visual.head = None
        print(f'Loading pretrained weights ({self.model_type}).')
        load_checkpoint(clip_model, ckpt_path, strict=False)
        self.is_loaded = True

        # decompose stem and stages blocks in vision tower
        self.vision_stem = clip_model.visual.trunk.stem
        self.vision_stages = clip_model.visual.trunk.stages
        self.vision_stem.requires_grad_(False)
        self.vision_stages.requires_grad_(False)
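
    # forward() accepts either a batched image tensor or a list of per-image tensors (each given
    # a batch dimension of 1) and returns the concatenated multi-scale features from backbone().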
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.backbone(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_features.append(image_feature)
        else:
            image_features = self.backbone(images.to(device=self.device, dtype=self.dtype))
        return image_features
    def backbone(self, images):
        if not self.is_optimize:
            with torch.no_grad():
                results = self.basic_forward(images)
        else:
            results = self.basic_forward(images)

        target_size = (results['stage_0'].shape[-2], results['stage_0'].shape[-1])
        result_cat = []
        for _stage in results:
            if _stage == 'stage_0':
                result_cat.append(results[_stage].contiguous())
            else:
                result_cat.append(F.interpolate(results[_stage].float().contiguous(),
                                                size=target_size,
                                                mode='bilinear',
                                                align_corners=False).to(dtype=results[_stage].dtype))
        result_cat = torch.cat(result_cat, dim=1)

        return result_cat.contiguous()
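
    # basic_forward() records the feature map after every ConvNeXt stage; stage_i has
    # self.model_channel[i] channels at a stride of 4 * 2**i relative to the input, and
    # backbone() above brings the deeper stages back to stage_0's resolution.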
    def basic_forward(self, images):
        results = {}
        x = self.vision_stem(images)
        for _idx in range(len(self.vision_stages)):
            x = self.vision_stages[_idx](x)
            results[f'stage_{_idx}'] = x
        return results
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_stem[0].weight.dtype

    @property
    def device(self):
        return self.vision_stem[0].weight.device

    @property
    def config(self):
        return self.vision_config

    @property
    def hidden_size(self):
        return sum(self.model_channel)
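
# Usage sketch (the checkpoint path and model_args object are illustrative placeholders):
#   tower = OpenCLIPVisionTower('checkpoints/clip-convnext_large_d_320', model_args)
#   feats = tower(images)  # (B, tower.hidden_size, H/4, W/4) for a (B, 3, H, W) input
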
# Modified from open_clip's load_checkpoint to support DeepSpeed ZeRO stage 3.
def load_checkpoint(model, checkpoint_path, strict=True):
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        from open_clip.big_vision import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
    # if 'logit_bias' not in state_dict and model.logit_bias is not None:
    #     state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Certain text transformers no longer expect position_ids after transformers==4.31
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    resize_pos_embed(state_dict, model)
    # resize_text_pos_embed(state_dict, model)
    # incompatible_keys = model.load_state_dict(state_dict, strict=strict)

    if is_deepspeed_zero3_enabled():
        error_msgs = []

        def load(module: nn.Module, state_dict, prefix=""):
            metadata = None
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
            # Parameters of module and children will start with prefix. We can exit early if
            # there are none in this state_dict.
            if len([key for key in state_dict if key.startswith(prefix)]) > 0:
                if is_deepspeed_zero3_enabled():
                    # In sharded models, each shard has only part of the full state_dict, so only
                    # gather parameters that are in the current state_dict.
                    named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
                    params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
                    if len(params_to_gather) > 0:
                        # Because zero3 puts placeholders in model params, this context manager
                        # gathers (unpartitions) the params of the current layer, then loads from
                        # the state dict and then re-partitions them again.
                        with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                            if torch.distributed.get_rank() == 0:
                                module._load_from_state_dict(*args)
                else:
                    module._load_from_state_dict(*args)

            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

        load(model, state_dict)
        incompatible_keys = []
    else:
        incompatible_keys = model.load_state_dict(state_dict, strict=strict)
        logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
    return incompatible_keys
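
# Typical call from load_model above (the path is shown for illustration only):
#   load_checkpoint(clip_model, '<checkpoint_dir>/open_clip_pytorch_model.bin', strict=False)
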
class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)