import copy
import inspect
import json
import logging
import os
import pdb
from typing import Optional, Dict, Tuple

import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops_exts import rearrange_many
from mmcv.cnn.bricks import padding
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel

from .mae_vit import MAEViT
from .htsat import HTSAT_Swin_Transformer, create_htsat_model
from .LMdecoder import LMDecoder, LMDecoder_qlora
from .vision_transformer import VisionTransformer
from .configuration_maelm import MAELMConfig

logger = logging.getLogger(__name__)


def _env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment without using ``eval``.

    ``1``/``true``/``yes``/``on`` (case-insensitive) count as true; any other
    value that is set counts as false; an unset variable yields ``default``.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")


class ArgsHandler:
    """Read and rewrite the arguments of one call by parameter name.

    Intended for ``register_forward_pre_hook(..., with_kwargs=True)`` callbacks,
    which receive the positional arguments and the keyword arguments separately.
    The parameter names of ``module.<funcname>`` are computed once with
    ``inspect.signature`` and cached on the module as ``<funcname>_argnames_list``.
    """

    def __init__(self, module, funcname, fargs, fkargs):
        self.fargs = list(fargs)
        # The kwargs dict is kept by reference, so set_arg() mutates the caller's dict.
        self.fkargs = fkargs
        func = getattr(module, funcname)
        fal_repr = f"{funcname}_argnames_list"
        if (argns_list := getattr(module, fal_repr, None)) is None:
            # Signature of the bound method, so ``self`` is not in the list and the
            # indices line up with the positional-argument tuple.
            self.func_sig = inspect.signature(func)
            self.argnames_list = list(self.func_sig.parameters.keys())
            setattr(module, fal_repr, self.argnames_list)
        else:
            self.argnames_list = argns_list

    def get_arg(self, arg_name):
        """Return ``arg_name`` from the kwargs if present, else from the positional args."""
        if arg_name in self.fkargs:
            return self.fkargs[arg_name]
        return self.fargs[self.argnames_list.index(arg_name)]

    def set_arg(self, arg_name, arg_value):
        """Overwrite ``arg_name`` wherever it was passed (kwargs first, else positional)."""
        if arg_name in self.fkargs:
            self.fkargs[arg_name] = arg_value
        else:
            self.fargs[self.argnames_list.index(arg_name)] = arg_value

    def return_all_args(self,):
        """Return ``(args_tuple, kwargs_dict)`` in the form a kwargs pre-hook returns."""
        return tuple(self.fargs), self.fkargs


class SquaredReLU(nn.Module):
    """Squared ReLU activation: ``relu(x) ** 2``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.pow(torch.relu(x), 2)


def FeedForward(dim, out_dim, mult=4, act='gelu'):
    """Pre-norm MLP: LayerNorm -> Linear(dim, dim*mult) -> act -> Linear(-, out_dim).

    Based on lucidrains' implementation, with a selectable activation.

    Args:
        dim: input feature size.
        out_dim: output feature size.
        mult: hidden-size multiplier (``int(dim * mult)`` hidden units).
        act: one of ``'gelu'``, ``'sqrelu'``, ``'relu'``.

    Raises:
        ValueError: if ``act`` is not a supported activation name.
    """
    acts = dict(
        gelu=nn.GELU,
        sqrelu=SquaredReLU,
        relu=nn.ReLU,
    )
    if act not in acts:
        raise ValueError(f"act. can only be one of {acts.keys()}")
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        acts[act](),
        nn.Linear(inner_dim, out_dim, bias=False),
    )


class PerceiverAttentionLayer(nn.Module):
    """Cross-attention in which latent vectors (queries) attend to input features (keys/values)."""

    def __init__(
            self,
            *,
            feat_dim,
            latent_dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = dim_head * heads

        # trainable components of PerceiverAttentionLayer
        self.norm_media = nn.LayerNorm(feat_dim)
        self.norm_latents = nn.LayerNorm(latent_dim)

        self.to_q = nn.Linear(latent_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, latent_dim, bias=False)

    def forward(self, features, latents):
        """Let ``latents`` cross-attend to ``features``.

        :param features: Tensor (n_batch, n_features, feat_dim) input features.
        :param latents: Tensor (n_batch, n_latents, latent_dim) vectors from which
            the queries are computed.
        :return: Tensor (n_batch, n_latents, latent_dim)
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]

        n_heads = self.heads
        n_batch, n_features, _ = features.shape
        n_queries = latents.shape[1]

        x = self.norm_media(features)
        latents = self.norm_latents(latents)

        # Queries from the latents, for all heads at once.
        q = self.to_q(latents)
        q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
        assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])

        # Keys/values come from the features only (the latents are not concatenated in).
        kv_input = x
        k = self.to_k(kv_input)
        v = self.to_v(kv_input)
        k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
        assert v.shape == torch.Size([n_batch, n_heads, n_features, self.dim_head])

        q = q * self.scale

        sim = einsum('b h q d, b h f d -> b h q f', q, k)
        # Subtracting the row max does not change the softmax result.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        alphas = sim.softmax(dim=-1)

        out = einsum('b h q f, b h f v -> b h q v', alphas, v)
        out = rearrange(out, 'b h q v -> b q (h v)')
        return self.to_out(out)


class MAEForCausalLM(PreTrainedModel):
    """Causal LM conditioned on audio features.

    Combines an audio backbone, a Perceiver resampler whose outputs replace
    placeholder token positions, and bottleneck adapters applied before every
    odd decoder layer of the LM held by ``neck``.

    Args:
        config (MAELMConfig): must provide ``backbone`` and ``neck`` mappings, each
            with a ``'name'`` key selecting the implementation; the remaining
            keys are passed to that implementation's constructor/factory.
    """
    config_class = MAELMConfig

    def __init__(self, config: MAELMConfig) -> None:
        super().__init__(config)

        # Work on copies so that popping 'name'/'ckpt' does not mutate the caller's config.
        backbone = copy.copy(config.backbone)
        assert backbone is not None
        bk_name = backbone.pop('name')
        self.bk_name = bk_name
        if bk_name == 'MAEViT':
            # The checkpoint key is removed but not loaded for this backbone.
            backbone.pop('ckpt', None)
            self.backbone = MAEViT(**backbone)
        elif bk_name == 'HTSAT':
            ckpt_path = backbone.pop('ckpt', None)
            self.backbone = create_htsat_model(backbone)
            if ckpt_path is not None:
                ckpt = torch.load(ckpt_path, 'cpu')
                self.backbone.load_state_dict(ckpt['state_dict'])
        elif bk_name == 'qformer':
            raise NotImplementedError("backbone 'qformer' is not implemented")
        else:
            raise NotImplementedError(f"unknown backbone: {bk_name!r}")

        neck = copy.copy(config.neck)
        assert neck is not None
        nk_name = neck.pop('name')
        if nk_name == 'LMDecoder':
            self.neck = LMDecoder(**neck)
        elif nk_name == 'LMDecoder_qlora':
            self.neck = LMDecoder_qlora(**neck)
        else:
            raise NotImplementedError(f"unknown neck: {nk_name!r}")

        # From here on self.config is the LM's config (hidden_size etc. are read from it).
        self.config = self.neck.LMconfig  # TODO

        self.neck.lm.model.gradient_checkpointing = False

        self.register_buffer('ones', torch.ones((1, 4096), dtype=torch.long), persistent=False)
        self.graft_adapter()
        self.init_weights()

        # float32 --> bfloat16
        for p in self.parameters():
            p.data = p.data.to(torch.bfloat16)

        self.first_run = True

    def graft_adapter(self):
        """Create the learned latents, the per-layer adapters and the resampler."""
        adapter_latent_len = 32
        self.adapter_latent_len = adapter_latent_len
        self.adapter_latent = nn.Parameter(
            torch.rand((1, adapter_latent_len, self.config.hidden_size), dtype=torch.float))

        resampler_latent_len = 32
        self.resampler_latent_len = resampler_latent_len
        self.resampler_latent = nn.Parameter(
            torch.rand((1, resampler_latent_len, self.config.hidden_size), dtype=torch.float))

        self.adapter = nn.ModuleList([])

        heads = 8
        lm_dim = self.config.hidden_size
        if self.bk_name == 'HTSAT':
            feat_dim = 1024
            # NOTE(review): hook_adapter indexes self.adapter once per odd LM layer
            # (len(layers) // 2 times); this depth must be at least that — confirm.
            depth = len(self.backbone.layers[2].blocks)
        else:
            feat_dim = 768
            # One adapter per odd decoder layer (see hook_adapter).
            depth = len(self.neck.lm.model.layers) // 2

        for _ in range(depth):
            self.adapter.append(nn.ModuleList([
                Adapter(input_size=self.config.hidden_size),
            ]))

        self.samplers = nn.ModuleList([])
        for _ in range(3):
            self.samplers.append(nn.ModuleList([
                PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=64, heads=heads),
                FeedForward(dim=lm_dim, out_dim=lm_dim, mult=4),
            ]))
        self.norm = nn.LayerNorm(lm_dim)

    def init_weights(self):
        """Run the base initialisation (best effort), then re-draw both latent tables."""
        try:
            super().init_weights()
        except Exception:
            # Deliberately best effort: the base initialiser may not support this
            # composite model; keep going but leave a trace for debugging.
            logger.debug("PreTrainedModel.init_weights failed; continuing", exc_info=True)
        if getattr(self, 'adapter_latent', None) is not None:
            self.adapter_latent.data.normal_(mean=0.0, std=0.02)
        if getattr(self, 'resampler_latent', None) is not None:
            self.resampler_latent.data.normal_(mean=0.0, std=0.02)

    def forward_resampler(self, x):
        """Map backbone features ``x`` to ``resampler_latent_len`` LM-sized vectors.

        The learned latents are repeated over the batch, refined by three
        residual (cross-attention, feed-forward) pairs and layer-normalised.
        Output shape: (batch, resampler_latent_len, hidden_size).
        """
        latents = repeat(self.resampler_latent, 'b n d -> (bs b) n d', bs=x.shape[0])
        for attn, ff in self.samplers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        return self.norm(latents)

    def hook_adapter(self, audio_embedding, lm, v2t_feats):
        """Install forward pre-hooks on ``lm`` for the next forward pass.

        Any hooks from a previous call are removed first. The hooks:
          * on ``lm``: capture its call arguments (to read ``input_ids``);
          * on layer 0: scatter ``v2t_feats`` into the hidden states at the
            positions where ``input_ids < 0`` (the audio placeholders);
          * on every odd layer: apply the next adapter to the hidden states;
          * on the last layer: reset the adapter counter.

        Args:
            audio_embedding: backbone output; not used by the hooks.
            lm: decoder LM exposing ``lm.model.layers``.
            v2t_feats: resampler output to place at the placeholder positions.
        """

        class PHooker:
            def __init__(self, adapter, y):
                self.adapter = adapter
                self.y = y
                self.handles_list = []
                self.cnter = 0
                self.lm_ahl = None

            def layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                adapt = self.adapter[self.cnter][0]
                hs = ahl.get_arg("hidden_states")
                # The hidden states are both the adapter input and its residual.
                neo_hs = adapt(hs, hs)
                self.cnter += 1
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def first_layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                hs = ahl.get_arg("hidden_states")
                # prepare_ids encodes the audio placeholders as negative ids.
                hs_msk = self.lm_ahl.get_arg("input_ids") < 0
                neo_hs = hs.masked_scatter(hs_msk.unsqueeze(-1), self.y)
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def lm_prehook(self, m, margs, mkargs):
                self.lm_ahl = ArgsHandler(m, 'forward', margs, mkargs)
                return None

            def last_layer_hook(self, m, margs, mkargs):
                self.cnter = 0

        previous = getattr(lm, 'phooker', False)
        if previous:
            for handle in previous.handles_list:
                handle.remove()
            del lm.phooker
            lm.phooker = None

        phooker = PHooker(self.adapter, v2t_feats)
        phooker.handles_list.append(
            lm.register_forward_pre_hook(phooker.lm_prehook, with_kwargs=True))
        phooker.handles_list.append(
            lm.model.layers[0].register_forward_pre_hook(phooker.first_layer_prehook, with_kwargs=True))
        for layer in list(lm.model.layers)[1::2]:
            phooker.handles_list.append(
                layer.register_forward_pre_hook(phooker.layer_prehook, with_kwargs=True))
        # Registered last so that, on the last layer, it runs after layer_prehook.
        phooker.handles_list.append(
            lm.model.layers[-1].register_forward_pre_hook(phooker.last_layer_hook, with_kwargs=True))
        lm.phooker = phooker
        return None

    def prepare_ids(self, batch, audio_ids):
        """Build the LM inputs: audio placeholders followed by the text tokens.

        For each sample ``i`` the merged ids are ``-audio_ids[i] - 1`` (negative
        placeholders) followed by ``input_ids[i, au_crds[i]:]``. The attention
        mask is ones over the placeholders followed by the matching slice of
        ``attention_mask``. The labels copy the merged ids, with the first
        ``audio_len + ans_crds[i]`` positions set to -100. Finally, -100 entries
        in the ids are replaced by the tokenizer's pad id.

        NOTE(review): tokens before ``au_crds[i]`` are dropped, yet labels are
        masked up to ``audio_len + ans_crds[i]`` rather than
        ``audio_len + ans_crds[i] - au_crds[i]``. This is only exact when
        ``au_crds`` is 0 — confirm.

        Returns:
            ``(merged_ids, merged_msk, merged_labels)``, each stacked over the batch.
        """
        toker = self.neck.tokenizer
        with torch.no_grad():
            input_ids = batch['input_ids']
            att_msk = batch['attention_mask']
            au_crds = batch['audio_crds']
            ans_crds = batch['ans_crds']
            bsz = input_ids.shape[0]
            audio_len = audio_ids.shape[1]

            merged_ids, merged_msk, label_ids = [], [], []
            for i in range(bsz):
                cur_merged_ids = torch.cat([-1 * audio_ids[i] - 1, input_ids[i, au_crds[i]:]])
                cur_au_msk = torch.ones(audio_len, device=audio_ids.device)
                cur_merged_msk = torch.cat([cur_au_msk, att_msk[i, au_crds[i]:]])
                cur_label_ids = cur_merged_ids.clone().detach()
                cur_label_ids[:audio_len + ans_crds[i]] = -100

                merged_ids.append(cur_merged_ids)
                merged_msk.append(cur_merged_msk)
                label_ids.append(cur_label_ids)

            merged_ids = torch.stack(merged_ids, dim=0)
            merged_msk = torch.stack(merged_msk, dim=0)
            merged_labels = torch.stack(label_ids, dim=0)

            assert merged_ids.shape[0] == bsz
            assert merged_ids.shape == merged_msk.shape
            assert merged_msk[:, -1].max() == 1

            # -100 in the ids (presumably the collator's padding value — confirm) -> pad token.
            merged_ids[merged_ids.eq(-100)] = toker.pad_token_id
            return merged_ids, merged_msk, merged_labels

    def _prepare_inputs(self, batch):
        """Encode audio, install the LM hooks and build the merged ids.

        Shared by ``forward`` and ``forward_test``.

        Returns:
            ``(audio_ids, merged_ids, merged_msk, merged_labels)``.
        """
        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        # The backbone output is detached, so no gradient reaches the backbone;
        # running it under no_grad avoids storing its activations.
        with torch.no_grad():
            audio_embedding = self.backbone(spectrogram).detach()
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats)

        audio_ids = torch.arange(self.adapter_latent.shape[1], device=device).unsqueeze(0).repeat((bsz, 1))
        assert audio_ids.max() < 100
        merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)
        return audio_ids, merged_ids, merged_msk, merged_labels

    def forward(self, batch, **kwargs):
        """Training / loss forward pass.

        Args:
            batch (dict): must contain ``input_ids``, ``attention_mask``,
                ``audio_crds``, ``ans_crds`` and ``spectrogram``.
            kwargs: accepted for interface compatibility; unused.

        Returns:
            The output of ``self.neck``. When the ``doing_eval`` environment flag
            is true, ``merged_ids`` and ``merged_labels`` (on CPU) are attached to it.
        """
        _, merged_ids, merged_msk, merged_labels = self._prepare_inputs(batch)
        assert merged_ids.shape == merged_labels.shape

        outs = self.neck(
            input_ids=merged_ids.contiguous().long(),
            flatten_embs=self.adapter_latent.flatten(0, 1),
            attention_mask=merged_msk.contiguous().long(),
            labels=merged_labels.contiguous().long(),
            use_cache=False,
        )

        if _env_flag("doing_eval"):
            outs.merged_ids = merged_ids.cpu()
            outs.merged_labels = merged_labels.cpu()
        return outs

    def forward_test(self, batch, **kwargs):
        """Generation forward pass.

        Each sample's prompt (``audio_len + ans_crds[i]`` merged ids) is
        left-padded to the longest prompt in the batch and passed to
        ``self.neck.generate``.

        Args:
            batch (dict): same keys as ``forward``.
            kwargs: accepted for interface compatibility; unused.

        Returns:
            The output of ``self.neck.generate``.
        """
        audio_ids, merged_ids, _, _ = self._prepare_inputs(batch)
        bsz = len(batch['input_ids'])
        au_crds = batch['audio_crds']
        ans_crds = batch['ans_crds']
        aid_len = audio_ids.shape[-1]
        toker = self.neck.tokenizer

        with torch.no_grad():
            # NOTE(review): if the tokenizer prepends BOS in encode(), index 0 is
            # BOS rather than EOS — confirm this is the intended pad id.
            pad_token = toker.encode(toker.eos_token)[0]
            max_ans = int(max(ans_crds))
            padded_merged_ids = torch.full(
                (bsz, aid_len + max_ans), pad_token, dtype=torch.long, device=merged_ids.device)
            for i in range(bsz):
                assert au_crds[i] <= ans_crds[i]
                ans_len = int(ans_crds[i])
                padded_merged_ids[i, max_ans - ans_len:] = merged_ids[i, :aid_len + ans_len]
            outs = self.neck.generate(padded_merged_ids, self.adapter_latent.flatten(0, 1))
        return outs


class Adapter(nn.Module):
    """Sequential bottleneck adapter: LayerNorm -> down-proj -> SiLU -> up-proj -> scale, plus a residual."""

    def __init__(
            self,
            input_size,
            down_sample=None,
    ):
        super().__init__()
        self.input_size = input_size
        # Without an explicit bottleneck size, use half the input size.
        self.down_sample = down_sample
        if down_sample is None:
            self.down_sample = self.input_size // 2

        self.adapter_norm_before = nn.LayerNorm(self.input_size)
        self.adapter_down = nn.Linear(self.input_size, self.down_sample)
        self.non_linearity = ACT2FN["silu"]
        self.adapter_up = nn.Linear(self.down_sample, self.input_size)
        # Learnable output scaling factor (He et al., 2021).
        self.scaling = nn.Parameter(torch.ones(1))

        self.adapter_down.apply(self._init_weights)
        self.adapter_up.apply(self._init_weights)

    def forward(self, x, residual_input=None):
        """Return ``scaling * up(silu(down(norm(x)))) + residual_input``.

        ``residual_input`` defaults to ``x`` when omitted.
        """
        if residual_input is None:
            residual_input = x
        down = self.non_linearity(self.adapter_down(self.adapter_norm_before(x)))
        up = self.adapter_up(down) * self.scaling
        return up + residual_input

    @staticmethod
    def _init_weights(module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02), zero biases, identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()