import logging
import argparse
import os.path

import numpy as np
import torch
from torch import nn
from transformers import AutoConfig

from model.base_model import CLIPModel
from model.process_clip import add_time_attn_block, convert_model_to_lora, set_global_value, resize_pos
from open_clip import convert_weights_to_lp
from open_clip.transformer import PatchDropout
from training.distributed import is_master


def SET_GLOBAL_VALUE(k, v):
    set_global_value(k, v)


def create_vat_model(args):
    config = AutoConfig.from_pretrained(args.model, cache_dir=args.cache_dir)
    model = CLIPModel(config, args.num_frames, args.add_time_attn)
    model.vision_model.patch_dropout = PatchDropout(args.force_patch_dropout)

    device = args.device
    precision = args.precision
    if precision in ("fp16", "bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        model.to(device=device)
        convert_weights_to_lp(model, dtype=dtype)
    elif precision in ("pure_fp16", "pure_bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        model.to(device=device, dtype=dtype)
    else:
        model.to(device=device)

    if args.pretrained:
        try:
            args.pretrained = os.path.join(args.cache_dir, args.pretrained)
            if is_master(args):
                logging.info(f'Loading pretrained {args.model} weights ({args.pretrained}).')
            # incompatible_keys = load_checkpoint(model, pretrained, strict=False)
            ckpt = torch.load(args.pretrained, map_location='cpu')
            # Non-strict loading when time-attention blocks will be added, since their
            # weights are not present in the pretrained checkpoint.
            incompatible_keys = model.load_state_dict(ckpt, strict=not args.add_time_attn)
            if is_master(args):
                logging.info(incompatible_keys)
        except Exception as e:
            if is_master(args):
                logging.info(f"Failed loading pretrained model with {e}")
    else:
        if is_master(args):
            logging.info(f"No pretrained model to load in '{args.pretrained}'")

    if args.add_time_attn:
        add_time_attn_block(model.vision_model.encoder, device=device)
        if is_master(args):
            logging.info('Converted pretrained spatial attention to time attention.')

    if args.clip_type == 'al':
        resize_pos(model.vision_model.embeddings, args)
        if is_master(args):
            logging.info('Resized position embedding successfully.')

    if args.init_temp != 0:
        with torch.no_grad():
            model.logit_scale.fill_(np.log(1 / float(args.init_temp)))
        if is_master(args):
            logging.info(f'Reset logit scale to {args.init_temp} (log-scale) and trainable {args.learn_temp}.')

    if args.convert_to_lora:
        convert_model_to_lora(args, model)
        if is_master(args):
            logging.info('Successfully converted model to LoRA style.')

    # if output_dict and hasattr(model, "output_dict"):
    #     model.output_dict = True

    return model


if __name__ == '__main__':
    MODEL_DICT = {
        "ViT-L-14": "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K",
        "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    }
    CHECKPOINT_DICT = {
        "ViT-L-14": "models--laion--CLIP-ViT-L-14-DataComp.XL-s13B-b90K/snapshots/84c9828e63dc9a9351d1fe637c346d4c1c4db341/pytorch_model.bin",
        "ViT-H-14": "models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K/snapshots/94a64189c3535c1cb44acfcccd7b0908c1c8eb23/pytorch_model.bin",
    }

    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    args.model = MODEL_DICT["ViT-L-14"]
    args.pretrained = CHECKPOINT_DICT["ViT-L-14"]
    args.cache_dir = r'D:\Omni-modal-valdt-1kw'  # raw string so "\O" is not treated as an escape
    args.device = 'cpu'
    args.precision = None
    args.lock_text = True
    args.lock_image = True
    args.init_temp = 0
    args.force_patch_dropout = 0.5
    args.add_time_attn = True
    args.convert_to_lora = True
    args.lora_r = 16
    args.lora_alpha = 16
    args.lora_dropout = 0.0  # 0.1?
    args.num_frames = 8
    args.clip_type = 'vl'
    args.num_mel_bins = 128
    args.target_length = 1024
    args.audio_sample_rate = 16000
    args.audio_mean = 1
    args.audio_std = 1
    args.rank = 0

    SET_GLOBAL_VALUE('PATCH_DROPOUT', args.force_patch_dropout)
    SET_GLOBAL_VALUE('NUM_FRAMES', args.num_frames)

    model = create_vat_model(args)

    '''Method 1: a custom parameter-listing function, adapted from
    https://blog.csdn.net/qq_33757398/article/details/109210240'''
    def model_structure(model):
        blank = ' '
        print('-' * 150)
        print('|' + ' ' * 44 + 'weight name' + ' ' * 45 + '|'
              + ' ' * 10 + 'weight shape' + ' ' * 10 + '|'
              + ' ' * 3 + 'number' + ' ' * 3 + '|')
        print('-' * 150)
        num_para = 0
        type_size = 1  # 4 if the parameters are stored as float32
        for index, (key, w_variable) in enumerate(model.named_parameters()):
            if len(key) <= 100:
                key = key + (100 - len(key)) * blank
            shape = str(w_variable.shape)
            if len(shape) <= 30:
                shape = shape + (30 - len(shape)) * blank
            each_para = 1
            for k in w_variable.shape:
                each_para *= k
            num_para += each_para
            str_num = str(each_para)
            if len(str_num) <= 10:
                str_num = str_num + (10 - len(str_num)) * blank
            print('| {} | {} | {} |'.format(key, shape, str_num))
        print('-' * 150)
        print('The total number of parameters: ' + str(num_para))
        print('The parameters of Model {}: {:.4f}M'.format(model._get_name(), num_para * type_size / 1000 / 1000))
        print('-' * 150)

    model_structure(model)
    # model_structure(model.vision_model)
    # model_structure(model.text_model)

    # model.lock_image_tower(unlocked_groups=1)
    # model.lock_text_tower(unlocked_layers=0)
    # model.unlock_time_attn()

    if args.lock_image:
        # if args.clip_type == 'al' or args.clip_type == 'dl':
        #     for param in model.vision_model.embeddings.parameters():
        #         param.requires_grad = True
        #     for param in model.vision_model.pre_layrnorm.parameters():
        #         param.requires_grad = True
        # else:
        for param in model.vision_model.embeddings.parameters():
            param.requires_grad = False
        for param in model.vision_model.pre_layrnorm.parameters():
            param.requires_grad = False
        for param in model.vision_model.embeddings.position_embedding.parameters():
            param.requires_grad = False
        model.vision_model.embeddings.class_embedding.requires_grad = True

    if args.lock_text:
        for param in model.text_model.parameters():
            param.requires_grad = False
        for param in model.text_projection.parameters():
            param.requires_grad = False

    for n, p in model.named_parameters():
        # if p.requires_grad:
        print(n, '--->', p.requires_grad)

    b, c, t, h, w = 2, 3, args.num_frames, 224, 224
    x = torch.randn(b, c, t, h, w)
    y = model(image=x)
    print()
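    # A minimal sanity-check sketch: cross-check model_structure's total against
    # torch's own numel() counters, and report how many parameters remain trainable
    # after the lock_image / lock_text freezing and the LoRA conversion above.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'total parameters: {total_params / 1e6:.4f}M, '
          f'trainable parameters: {trainable_params / 1e6:.4f}M')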