import copy
import json
import math
import os
from functools import partial

import einops
import torch
from torch import nn
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, \
    BartModel, BartConfig, BartForCausalLM, BertForMaskedLM, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
from peft import LoraConfig
from peft import get_peft_model
from mmcv.cnn import build_norm_layer
from mmcv.runner import BaseModule
from ipdb import set_trace


class mixEmbed(nn.Module):
    """Embedding wrapper that mixes text token embeddings with precomputed audio embeddings.

    Non-negative input ids are ordinary vocabulary tokens; negative ids encode audio
    positions as -(audio_index + 1), so -1 maps to audio_embeddings[0], -2 to [1], etc.
    """

    def __init__(self, lm_embed: nn.Embedding, audio_embeddings, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lm_embed = lm_embed
        self.audio_embeddings = audio_embeddings  # ugly, but works without modifying the raw model code

    def forward(self, input_ids):
        text_ids = torch.clamp(input_ids.clone(), 0).long()
        au_ids = torch.clamp(-(input_ids.clone() + 1), 0).long()
        text_embeds = self.lm_embed(text_ids)
        au_embeds = self.audio_embeddings[au_ids]
        with torch.no_grad():
            embed_mask = (input_ids >= 0)  # non-negative ids are text tokens, negative ids are audio slots
        mix_embeds = au_embeds.clone()
        mix_embeds[embed_mask] = text_embeds[embed_mask]
        return mix_embeds


class LMDecoder(nn.Module):
    def __init__(self,
                 # num_patches=196,
                 img_size=(80, 512),
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim=1024,  # encoder embed dim
                 decoder_embed_dim=512,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 # patch_resolution=14,
                 decoder_type='gpt2',
                 freeze_decoder=True,
                 additional_layer: int = 0,
                 ):
        super().__init__()
        self.decoder_type = decoder_type
        self.load_lm()
        self.lm_embed = self.lm.get_input_embeddings()
        try:
            self.lm_pos_embed = self.lm.get_position_embeddings()
        except NotImplementedError:
            self.lm_pos_embed = None  # rotary embeddings
        if hasattr(self.lm, 'embed_dim'):
            self.embed_dim = self.lm.embed_dim
        else:
            self.embed_dim = decoder_embed_dim
        # self.asLM = asLM  # whether to generate tokens rather than hidden states
        # if self.asLM:  # TODO: why was this written back then?
        #     self.lm.set_output_embeddings(nn.Linear(self.embed_dim, self.LMconfig.vocab_size, bias=False))
        self.freeze_decoder = freeze_decoder
        if self.freeze_decoder:
            # freeze all LM parameters
            for para in self.lm.parameters():
                para.requires_grad = False

    def load_lm(self):
        # ---------------------LM setting----------------------
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        hf_token = os.environ.get('HF_TOKEN')  # read the Hugging Face access token from the environment instead of hard-coding it
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, token=hf_token)
        self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type, token=hf_token)

    def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)  # temporarily swap in the mixed text/audio embedding
        output = self.lm(input_ids=input_ids,
                         attention_mask=attention_mask,
                         labels=labels,
                         output_hidden_states=True,
                         **kwargs)
        self.lm.set_input_embeddings(self.lm_embed)  # restore the original embedding
        return output

    def generate(self, input_ids, flatten_embs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)  # temporarily swap in the mixed text/audio embedding
        outputs = self.lm.generate(input_ids=input_ids, max_new_tokens=256, use_cache=False)
        # outputs = self.lm.generate(input_ids=input_ids,
        #                            max_new_tokens=1024,
        #                            do_sample=True,
        #                            temperature=1.5,
        #                            num_beams=1,
        #                            top_p=0.9,
        #                            top_k=3,
        #                            use_cache=False)
        self.lm.set_input_embeddings(self.lm_embed)  # restore the original embedding
        return outputs


'''
## infer params
max_input_tokens: 40
batch_size_test: 16
max_new_tokens: 64
min_length: 2
num_beams: 5
length_penalty: -2.0
top_p: 0.9
top_k: 3
no_repeat_ngram_size: 2
apply_lemmatizer: False
use_nucleus_sampling: True
'''


class LMDecoder_qlora(LMDecoder):
    def __init__(self,
                 # num_patches=196,
                 img_size=(80, 512),
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim=1024,  # encoder embed dim
                 decoder_embed_dim=512,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 # patch_resolution=14,
                 decoder_type='gpt2',
                 freeze_decoder=True,
                 additional_layer: int = 0,
                 ):
        super().__init__(img_size, patch_size, in_chans, embed_dim, decoder_embed_dim,
                         norm_cfg, decoder_type, freeze_decoder, additional_layer)

    def load_lm(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True)
        double_quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            self.decoder_type,
            # device_map='auto',  # if removed, LoRA cannot be added
            # load_in_4bit=True,  # if removed, LoRA cannot be added
            # torch_dtype=torch.bfloat16,
            # quantization_config=double_quant_config,  # if removed, LoRA cannot be added
            trust_remote_code=True
        )
        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["query_key_value"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        self.lm = get_peft_model(model, lora_config)
        self.lm.print_trainable_parameters()
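

# ----------------------------------------------------------------------------
# Minimal usage sketch, not part of the training pipeline. It assumes a plain
# GPT-2 decoder on CPU and uses random tensors in place of real audio
# embeddings; the prompt text and sequence layout are illustrative only.
# Negative ids in `input_ids` index `flatten_embs` as -(audio_index + 1),
# while non-negative ids are ordinary vocabulary tokens.
if __name__ == '__main__':
    decoder = LMDecoder(decoder_type='gpt2')
    tokenizer = decoder.tokenizer

    # two fake "audio" embeddings, referenced in the sequence by ids -1 and -2
    flatten_embs = torch.randn(2, decoder.lm_embed.embedding_dim)
    audio_ids = torch.tensor([[-1, -2]])

    text = tokenizer('describe the sound:', return_tensors='pt')
    input_ids = torch.cat([audio_ids, text.input_ids], dim=1)
    attention_mask = torch.ones_like(input_ids)

    # do not compute the LM loss on audio positions
    labels = input_ids.clone()
    labels[labels < 0] = -100

    output = decoder(input_ids, flatten_embs, attention_mask, labels)
    print('loss:', output.loss.item())

    # decode only the newly generated tokens (the prompt still contains negative audio ids)
    generated = decoder.generate(input_ids, flatten_embs)
    new_tokens = generated[0, input_ids.shape[1]:]
    print(tokenizer.decode(new_tokens, skip_special_tokens=True))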