|
import os

import torch
from torch import nn

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig

from peft import prepare_model_for_kbit_training
from peft import LoraConfig
from peft import get_peft_model
|
|
|
class mixEmbed(nn.Module):
    """Input-embedding wrapper that mixes text and audio embeddings.

    Non-negative ids are looked up in the language model's own embedding table;
    negative ids are audio placeholders encoded as ``-(audio_index + 1)`` and
    index into ``audio_embeddings``.
    """

    def __init__(self, lm_embed: nn.Embedding, audio_embeddings, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lm_embed = lm_embed
        self.audio_embeddings = audio_embeddings
|
|
|
    def forward(self, input_ids):
        # Non-negative ids are text tokens; clamping removes the negative audio
        # placeholders before the vocabulary lookup.
        text_ids = torch.clamp(input_ids.clone(), 0).long()
        # Negative ids encode audio slot k as -(k + 1); recover the slot index.
        au_ids = torch.clamp(-(input_ids.clone() + 1), 0).long()

        text_embeds = self.lm_embed(text_ids)
        au_embeds = self.audio_embeddings[au_ids]

        with torch.no_grad():
            # Text positions are the non-negative ids (id 0 is a valid text token,
            # so the mask must use >= 0 rather than > 0).
            embed_mask = (input_ids >= 0)

        mix_embeds = au_embeds.clone()
        mix_embeds[embed_mask] = text_embeds[embed_mask]
        return mix_embeds
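
# Illustrative sketch (not part of the original code): how mixed ids are expected
# to be assembled upstream.  Audio placeholder k is encoded as -(k + 1), e.g.
#
#   lm_embed = nn.Embedding(50257, 768)
#   audio_embeddings = torch.randn(4, 768)       # 4 precomputed audio embeddings
#   mix = mixEmbed(lm_embed, audio_embeddings)
#   ids = torch.tensor([[15496, -1, -2, 13]])    # hypothetical text ids 15496 / 13
#   embeds = mix(ids)                            # positions 1 and 2 are taken from
#                                                # audio_embeddings[0] and [1]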
|
|
|
|
|
class LMDecoder(nn.Module): |
|
def __init__(self, |
|
|
|
img_size=(80,512), |
|
patch_size:int=16, |
|
in_chans:int=3, |
|
embed_dim=1024, |
|
decoder_embed_dim=512, |
|
norm_cfg=dict(type='LN', eps=1e-6), |
|
|
|
decoder_type='gpt2', |
|
freeze_decoder=True, |
|
additional_layer:int=0, |
|
): |
|
super().__init__() |
|
self.decoder_type = decoder_type |
|
self.load_lm() |
|
|
|
self.lm_embed = self.lm.get_input_embeddings() |
|
try: |
|
self.lm_pos_embed = self.lm.get_position_embeddings() |
|
except NotImplementedError: |
|
self.lm_pos_embed = None |
|
|
|
|
|
if hasattr(self.lm,'embed_dim'): |
|
self.embed_dim = self.lm.embed_dim |
|
else: |
|
self.embed_dim = decoder_embed_dim |
|
|
|
|
|
|
|
|
|
        self.freeze_decoder = freeze_decoder
        if self.freeze_decoder:
            # Keep the pretrained LM frozen; only the surrounding modules train.
            # (The QLoRA subclass should be constructed with freeze_decoder=False
            # so its LoRA adapters remain trainable.)
            for para in self.lm.parameters():
                para.requires_grad = False
|
|
|
def load_lm(self): |
|
|
|
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Read the Hugging Face access token from the environment instead of
        # hard-coding a credential in the source.
        hf_token = os.environ.get("HF_TOKEN")
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, token=hf_token)
        self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type, token=hf_token)
|
|
|
|
|
    def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
        # Temporarily swap in a mixEmbed wrapper so the negative placeholder ids
        # in input_ids resolve to the flattened audio embeddings.
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)
        output = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels,
                         output_hidden_states=True, **kwargs)
        # Restore the original embedding table afterwards.
        self.lm.set_input_embeddings(self.lm_embed)
        return output
|
|
|
    def generate(self, input_ids, flatten_embs):
        # Same embedding swap as in forward(), but for autoregressive decoding.
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)
        outputs = self.lm.generate(input_ids=input_ids, max_new_tokens=256, use_cache=False)
        self.lm.set_input_embeddings(self.lm_embed)
        return outputs
|
''' |
|
## infer params |
|
max_input_tokens: 40 |
|
batch_size_test: 16 |
|
max_new_tokens: 64 |
|
min_length: 2 |
|
num_beams: 5 |
|
length_penalty: -2.0 |
|
top_p: 0.9 |
|
top_k: 3 |
|
no_repeat_ngram_size: 2 |
|
apply_lemmatizer: False |
|
use_nucleus_sampling: True |
|
''' |
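
    # Hedged sketch (not from the original code): if the inference parameters in
    # the comment block above are to be used, they map onto the Hugging Face
    # GenerationMixin API roughly as follows:
    #
    #   outputs = self.lm.generate(
    #       input_ids=input_ids,
    #       max_new_tokens=64,
    #       min_length=2,
    #       num_beams=5,
    #       length_penalty=-2.0,
    #       top_p=0.9,
    #       top_k=3,
    #       no_repeat_ngram_size=2,
    #       do_sample=True,          # "use_nucleus_sampling: True"
    #   )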
|
|
|
class LMDecoder_qlora(LMDecoder): |
|
def __init__(self, |
|
|
|
img_size=(80,512), |
|
patch_size:int=16, |
|
in_chans:int=3, |
|
embed_dim=1024, |
|
decoder_embed_dim=512, |
|
norm_cfg=dict(type='LN', eps=1e-6), |
|
|
|
decoder_type='gpt2', |
|
freeze_decoder=True, |
|
additional_layer:int=0, |
|
): |
|
        super().__init__(img_size, patch_size, in_chans, embed_dim, decoder_embed_dim,
                         norm_cfg, decoder_type, freeze_decoder, additional_layer)
|
|
|
def load_lm(self): |
|
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            # Mirror the pad-token fallback used by the parent class.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True)
|
double_quant_config = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_use_double_quant=True, |
|
) |
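        # Note: bnb_4bit_quant_type (e.g. "nf4") and bnb_4bit_compute_dtype
        # (e.g. torch.bfloat16) can also be set on BitsAndBytesConfig if the
        # defaults are not desired.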
|
        model = AutoModelForCausalLM.from_pretrained(
            self.decoder_type,
            # Pass the 4-bit config defined above so the checkpoint is actually
            # loaded quantized.
            quantization_config=double_quant_config,
            trust_remote_code=True,
        )

        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
|
lora_config = LoraConfig( |
|
r=8, |
|
lora_alpha=32, |
|
target_modules=["query_key_value"], |
|
lora_dropout=0.05, |
|
bias="none", |
|
task_type="CAUSAL_LM" |
|
) |
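        # Editorial note: "query_key_value" matches the fused attention projection
        # of GPT-NeoX / Falcon / BLOOM style checkpoints; other backbones name this
        # module differently (e.g. "c_attn" for GPT-2), so target_modules may need
        # adjusting to the chosen decoder_type.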
|
|
|
self.lm = get_peft_model(model, lora_config) |
|
self.lm.print_trainable_parameters() |
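

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module).  It assumes the small
# "gpt2" checkpoint (hidden size 768) can be downloaded, and shows how audio
# embeddings are spliced into the token sequence via negative placeholder ids.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = LMDecoder(decoder_type="gpt2")

    hidden = decoder.lm.get_input_embeddings().embedding_dim
    flatten_embs = torch.randn(4, hidden)      # dummy "audio" embeddings

    # Text tokens followed by two audio placeholders (-(k + 1) for slots 0 and 1).
    text_ids = decoder.tokenizer("Hello world", return_tensors="pt").input_ids
    input_ids = torch.cat([text_ids, torch.tensor([[-1, -2]])], dim=1)

    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()
    labels[labels < 0] = -100                  # ignore audio positions in the loss

    out = decoder(input_ids, flatten_embs, attention_mask, labels)
    print("loss:", out.loss.item())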
|
|