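# The lines below appear to be residue from the hosting web page; they are kept
# as comments so the module parses.
# OpenJMLA / LMdecoder.py
# sino
# Update LMdecoder.py
# 5fb0bcb
# raw
# history blame
# No virus
# 6.93 kB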
import copy
from doctest import ELLIPSIS_MARKER
from functools import partial
import json
from turtle import forward, shape
import einops
import torch
from torch import nn
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from transformers import GPT2Model, GPT2Config,GPT2LMHeadModel,GPTNeoForCausalLM,GPTNeoModel, \
BartModel, BartConfig, BartForCausalLM, BertForMaskedLM, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
from peft import LoraConfig
from peft import get_peft_model
from mmcv.cnn import build_norm_layer
from mmcv.runner import BaseModule
import math
from ipdb import set_trace
class mixEmbed(nn.Module):
    """Embedding layer that mixes text-token embeddings with precomputed audio embeddings.

    Meant to be swapped in as a language model's input embedding. Each id in
    ``input_ids`` is decoded as:

    * ``id >= 0`` -> text token ``id``, embedded with ``lm_embed``;
    * ``id < 0``  -> row ``-(id + 1)`` of ``audio_embeddings``
      (``-1`` selects row 0, ``-2`` selects row 1, ...).

    Args:
        lm_embed: the language model's own token embedding module.
        audio_embeddings: tensor indexed along its first dimension by the
            decoded audio ids (presumably ``(num_audio_tokens, dim)`` — TODO confirm).
    """

    def __init__(self, lm_embed: nn.Embedding, audio_embeddings, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lm_embed = lm_embed
        # Stored as a plain attribute so no LM source code has to change.
        self.audio_embeddings = audio_embeddings

    def forward(self, input_ids):
        """Embed ``input_ids``, choosing the text or audio embedding per position.

        Returns:
            Tensor of shape ``input_ids.shape + audio_embeddings.shape[1:]``
            with the dtype of ``audio_embeddings``. Text embeddings are copied
            into it, which casts them to that dtype.
        """
        # Negative (audio) ids are clamped to 0 so the text lookup stays in
        # range; those rows are discarded by the mask below.
        text_ids = torch.clamp(input_ids, min=0).long()
        # Audio row k is encoded as -(k + 1). Text ids clamp to row 0, and
        # those rows are discarded by the mask below.
        au_ids = torch.clamp(-(input_ids + 1), min=0).long()
        text_embeds = self.lm_embed(text_ids)
        au_embeds = self.audio_embeddings[au_ids]
        # Id 0 is a valid text token, so the split must be `>= 0`.
        # Using `> 0` would silently replace token 0 with audio row 0.
        text_mask = input_ids >= 0
        mix_embeds = au_embeds.clone()
        mix_embeds[text_mask] = text_embeds[text_mask]
        return mix_embeds
class LMDecoder(nn.Module):
    """Causal-LM decoder that accepts interleaved text-token ids and audio embeddings.

    The language model named by ``decoder_type`` is loaded by ``load_lm``.
    During ``forward`` and ``generate``, the LM's input embedding is
    temporarily replaced by a ``mixEmbed``. Negative ids in ``input_ids`` then
    select rows of ``flatten_embs`` (id ``-(k + 1)`` -> row ``k``), and text
    ids use the LM's own token embedding. ``mixEmbed`` defines the exact split.

    Args:
        img_size, patch_size, in_chans, embed_dim, norm_cfg, additional_layer:
            accepted but not used by this class.
        decoder_embed_dim: fallback for ``self.embed_dim`` when the loaded LM
            has no ``embed_dim`` attribute.
        decoder_type: model name or path passed to the Hugging Face ``Auto*``
            loaders.
        freeze_decoder: if true (the default), set ``requires_grad=False`` on
            every parameter of the loaded LM.
    """

    def __init__(self,
                 img_size=(80, 512),
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim=1024,  # encoder embed dim
                 decoder_embed_dim=512,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 decoder_type='gpt2',
                 freeze_decoder=True,
                 additional_layer: int = 0,
                 ):
        super().__init__()
        self.decoder_type = decoder_type
        self.load_lm()
        self.lm_embed = self.lm.get_input_embeddings()
        try:
            self.lm_pos_embed = self.lm.get_position_embeddings()
        except NotImplementedError:
            # No learned position table (e.g. rotary embeddings).
            self.lm_pos_embed = None
        self.embed_dim = getattr(self.lm, 'embed_dim', decoder_embed_dim)
        # Honour the argument. The previous code always froze the LM while
        # recording freeze_decoder=False.
        self.freeze_decoder = freeze_decoder
        if freeze_decoder:
            for param in self.lm.parameters():
                param.requires_grad = False

    def load_lm(self):
        """Load the tokenizer, config and causal LM for ``self.decoder_type``.

        Sets ``self.tokenizer``, ``self.LMconfig`` and ``self.lm``. If the
        tokenizer has no pad token, its EOS token is reused for padding.

        The access token for gated or private checkpoints comes from the
        ``HF_TOKEN`` environment variable. When the variable is unset, ``None``
        is passed and huggingface_hub falls back to its default credential
        lookup.
        """
        import os
        # Credentials must not live in source code. The token previously
        # hard-coded here is exposed and should be revoked.
        token = os.environ.get('HF_TOKEN')
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type, token=token)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, token=token)
        self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type, token=token)

    def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
        """Run the LM on mixed text/audio ids.

        Args:
            input_ids: ids decoded by ``mixEmbed`` (negative ids select rows
                of ``flatten_embs``).
            flatten_embs: audio embeddings indexed by the decoded audio ids.
            attention_mask, labels, **kwargs: forwarded to ``self.lm``.

        Returns:
            Whatever ``self.lm`` returns when called with ``output_hidden_states=True``.
        """
        self.lm.set_input_embeddings(mixEmbed(self.lm_embed, flatten_embs))
        try:
            return self.lm(input_ids=input_ids, attention_mask=attention_mask,
                           labels=labels, output_hidden_states=True, **kwargs)
        finally:
            # Restore the original embedding even if the LM call raises, so
            # the model never keeps a stale mixEmbed holding the audio tensor.
            self.lm.set_input_embeddings(self.lm_embed)

    def generate(self, input_ids, flatten_embs, *, max_new_tokens=256, **generate_kwargs):
        """Generate from mixed text/audio ids with ``self.lm.generate``.

        Args:
            input_ids: ids decoded by ``mixEmbed``.
            flatten_embs: audio embeddings indexed by the decoded audio ids.
            max_new_tokens: forwarded to ``generate`` (default 256, as before).
            **generate_kwargs: extra generation options (e.g. ``num_beams``,
                ``do_sample``, ``attention_mask``). ``use_cache`` defaults to
                ``False`` as before.

        Returns:
            The output of ``self.lm.generate``.
        """
        # Kept from the original. Presumably the cache was disabled because of
        # the swapped-in mixEmbed — TODO confirm whether it is still required.
        generate_kwargs.setdefault('use_cache', False)
        self.lm.set_input_embeddings(mixEmbed(self.lm_embed, flatten_embs))
        try:
            return self.lm.generate(input_ids=input_ids,
                                    max_new_tokens=max_new_tokens,
                                    **generate_kwargs)
        finally:
            self.lm.set_input_embeddings(self.lm_embed)
# Reference notes on inference hyper-parameters. This module-level string is
# not assigned to any name, so nothing reads it at runtime.
# NOTE(review): these values are not wired into LMDecoder.generate in this file
# (which, e.g., sets no num_beams of its own) — confirm whether they are still
# meant to be applied.
'''
## infer params
max_input_tokens: 40
batch_size_test: 16
max_new_tokens: 64
min_length: 2
num_beams: 5
length_penalty: -2.0
top_p: 0.9
top_k: 3
no_repeat_ngram_size: 2
apply_lemmatizer: False
use_nucleus_sampling: True
'''
class LMDecoder_qlora(LMDecoder):
    """LMDecoder variant that wraps the language model with LoRA adapters (PEFT).

    Positional and keyword arguments match ``LMDecoder``. The keyword-only
    ``lora_*`` arguments configure the adapters and default to the values
    that were previously hard-coded.

    NOTE(review): with ``freeze_decoder=True`` (the default), the base-class
    constructor may set ``requires_grad=False`` on every parameter of
    ``self.lm`` after ``load_lm`` returns. That would include the LoRA adapter
    weights — confirm this is intended.
    """

    def __init__(self,
                 img_size=(80, 512),
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim=1024,  # encoder embed dim
                 decoder_embed_dim=512,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 decoder_type='gpt2',
                 freeze_decoder=True,
                 additional_layer: int = 0,
                 *,
                 lora_r: int = 8,
                 lora_alpha: int = 32,
                 lora_dropout: float = 0.05,
                 lora_target_modules=("query_key_value",),
                 ):
        # The base constructor calls load_lm(), so the LoRA settings must be
        # stored first. Plain (non-Parameter, non-Module) attributes may be set
        # before nn.Module.__init__ runs.
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules
        super().__init__(img_size, patch_size, in_chans, embed_dim, decoder_embed_dim,
                         norm_cfg, decoder_type, freeze_decoder, additional_layer)

    def load_lm(self):
        """Load the tokenizer, config and model, then wrap the model with LoRA.

        Sets ``self.tokenizer``, ``self.LMconfig`` and ``self.lm``. The config
        and model are loaded with ``trust_remote_code=True``. Gradient
        checkpointing is enabled and ``prepare_model_for_kbit_training`` is
        applied before the LoRA adapters are attached.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            # Same padding fallback as LMDecoder.load_lm.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True)
        # NOTE(review): 4-bit loading is disabled here. It would use
        # quantization_config=BitsAndBytesConfig(load_in_4bit=True,
        # bnb_4bit_use_double_quant=True) with device_map='auto'. The original
        # notes on those lines read "if remove, can not add lora" — confirm
        # before re-enabling.
        model = AutoModelForCausalLM.from_pretrained(self.decoder_type, trust_remote_code=True)
        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
        targets = self.lora_target_modules
        # peft accepts a regex string or a list of module names. Only convert
        # non-string iterables, so a string is not split into characters.
        if not isinstance(targets, str):
            targets = list(targets)
        lora_config = LoraConfig(
            r=self.lora_r,
            lora_alpha=self.lora_alpha,
            target_modules=targets,
            lora_dropout=self.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.lm = get_peft_model(model, lora_config)
        self.lm.print_trainable_parameters()