import copy
from functools import partial
import json
import math

import einops
import torch
from torch import nn
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, \
    BartModel, BartConfig, BartForCausalLM, BertForMaskedLM, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
from peft import LoraConfig
from peft import get_peft_model
from mmcv.cnn import build_norm_layer
from mmcv.runner import BaseModule
from ipdb import set_trace

class mixEmbed(nn.Module):
    """Embedding wrapper that mixes text-token embeddings with precomputed audio embeddings.

    Convention: non-negative ids are ordinary vocabulary ids; audio slot k is
    encoded as the negative id -(k + 1).
    """

    def __init__(self, lm_embed: nn.Embedding, audio_embeddings, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lm_embed = lm_embed
        self.audio_embeddings = audio_embeddings  # ugly but works without modifying the raw model code

    def forward(self, input_ids):
        text_ids = torch.clamp(input_ids.clone(), 0).long()        # negative (audio) ids collapse to 0 here
        au_ids = torch.clamp(-(input_ids.clone() + 1), 0).long()   # -(k + 1) -> k; text ids collapse to 0
        text_embeds = self.lm_embed(text_ids)
        au_embeds = self.audio_embeddings[au_ids]
        with torch.no_grad():
            embed_mask = (input_ids >= 0)  # non-negative ids are text tokens
        mix_embeds = au_embeds.clone()
        mix_embeds[embed_mask] = text_embeds[embed_mask]
        return mix_embeds
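
# Illustrative sketch (not part of the pipeline): a tiny self-check of the id-mixing
# convention above. The toy sizes and ids below are made up for demonstration only.
def _mix_embed_demo():
    lm_embed = nn.Embedding(10, 4)              # toy vocabulary of 10 tokens, width 4
    audio_embeddings = torch.randn(3, 4)        # three precomputed audio slots
    mixer = mixEmbed(lm_embed, audio_embeddings)
    input_ids = torch.tensor([[2, -1, -2, 5]])  # text, audio[0], audio[1], text
    out = mixer(input_ids)
    assert out.shape == (1, 4, 4)
    return out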

class LMDecoder(nn.Module):
    def __init__(self,
                 # num_patches=196,
                 img_size=(80, 512),
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim=1024,  # encoder embed dim
                 decoder_embed_dim=512,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 # patch_resolution=14,
                 decoder_type='gpt2',
                 freeze_decoder=True,
                 additional_layer: int = 0,
                 ):
        super().__init__()
        self.decoder_type = decoder_type
        self.load_lm()
        self.lm_embed = self.lm.get_input_embeddings()
        try:
            self.lm_pos_embed = self.lm.get_position_embeddings()
        except NotImplementedError:
            self.lm_pos_embed = None  # rotary embeddings
        if hasattr(self.lm, 'embed_dim'):
            self.embed_dim = self.lm.embed_dim
        else:
            self.embed_dim = decoder_embed_dim
        # self.asLM = asLM  # if it generates tokens rather than hidden states
        # if self.asLM:  # TODO: why was this added in the first place?
        #     self.lm.set_output_embeddings(nn.Linear(self.embed_dim, self.LMconfig.vocab_size, bias=False))
        self.freeze_decoder = freeze_decoder
        if self.freeze_decoder:
            for para in self.lm.parameters():
                para.requires_grad = False
    def load_lm(self):
        ## ---------------------LM setting----------------------
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, token='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')
        self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type, token='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')
    def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)      # temporarily swap in the mixed embedding
        output = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels, output_hidden_states=True, **kwargs)
        self.lm.set_input_embeddings(self.lm_embed)  # restore the original embedding
        return output

    def generate(self, input_ids, flatten_embs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)      # temporarily swap in the mixed embedding
        outputs = self.lm.generate(input_ids=input_ids, max_new_tokens=256, use_cache=False)
        # outputs = self.lm.generate(input_ids=input_ids,
        #                            max_new_tokens=1024,
        #                            do_sample=True,
        #                            temperature=1.5,
        #                            num_beams=1,
        #                            top_p=0.9,
        #                            top_k=3,
        #                            use_cache=False)
        self.lm.set_input_embeddings(self.lm_embed)  # restore the original embedding
        return outputs
'''
## infer params
max_input_tokens: 40
batch_size_test: 16
max_new_tokens: 64
min_length: 2
num_beams: 5
length_penalty: -2.0
top_p: 0.9
top_k: 3
no_repeat_ngram_size: 2
apply_lemmatizer: False
use_nucleus_sampling: True
'''
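
# Illustrative sketch only (not called anywhere in this file): one way the inference
# parameters listed above could be passed to HuggingFace `generate`. The helper name is
# an assumption; the values are copied from the note above, and `apply_lemmatizer`,
# `max_input_tokens`, and `batch_size_test` have no `generate` counterpart, so they are omitted.
def _generate_with_sampling(decoder, input_ids, flatten_embs):
    mix_embed = mixEmbed(decoder.lm_embed, flatten_embs)
    decoder.lm.set_input_embeddings(mix_embed)         # same temporary swap as LMDecoder.generate
    outputs = decoder.lm.generate(
        input_ids=input_ids,
        max_new_tokens=64,
        min_length=2,
        num_beams=5,
        do_sample=True,          # "use_nucleus_sampling: True"
        top_p=0.9,
        top_k=3,
        length_penalty=-2.0,
        no_repeat_ngram_size=2,
        use_cache=False,
    )
    decoder.lm.set_input_embeddings(decoder.lm_embed)  # restore the original embedding
    return outputs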

class LMDecoder_qlora(LMDecoder):
    def __init__(self,
                 # num_patches=196,
                 img_size=(80, 512),
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim=1024,  # encoder embed dim
                 decoder_embed_dim=512,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 # patch_resolution=14,
                 decoder_type='gpt2',
                 freeze_decoder=True,
                 additional_layer: int = 0,
                 ):
        super().__init__(img_size, patch_size, in_chans, embed_dim, decoder_embed_dim, norm_cfg, decoder_type, freeze_decoder, additional_layer)
    def load_lm(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True)
        double_quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(self.decoder_type,
                                                     # device_map='auto',                        # if removed, LoRA cannot be added
                                                     # load_in_4bit=True,                        # if removed, LoRA cannot be added
                                                     # torch_dtype=torch.bfloat16,
                                                     # quantization_config=double_quant_config,  # if removed, LoRA cannot be added
                                                     trust_remote_code=True)
        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["query_key_value"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        self.lm = get_peft_model(model, lora_config)
        self.lm.print_trainable_parameters()
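

# Minimal smoke test (illustrative only): builds a frozen gpt2 decoder, prepends four
# "audio" slots encoded as negative ids, and runs one forward pass. The prompt, slot
# count, and random features below are assumptions for demonstration, not part of the model.
if __name__ == "__main__":
    decoder = LMDecoder(decoder_type='gpt2')            # downloads the public gpt2 checkpoint
    tok = decoder.tokenizer
    audio_embeds = torch.randn(4, decoder.lm_embed.embedding_dim)  # fake audio features at LM width
    text = tok("Describe the sound:", return_tensors='pt')
    audio_ids = torch.tensor([[-1, -2, -3, -4]])        # audio slot k is encoded as -(k + 1)
    input_ids = torch.cat([audio_ids, text.input_ids], dim=1)
    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()
    labels[:, :audio_ids.shape[1]] = -100               # no LM loss on the audio positions
    out = decoder(input_ids, audio_embeds, attention_mask, labels)
    print(out.loss)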