import json
import os

import torch
from torch import nn, einsum
from typing import Optional, Dict, Tuple

from .mae_vit import MAEViT
from .htsat import HTSAT_Swin_Transformer, create_htsat_model
from .LMdecoder import LMDecoder, LMDecoder_qlora
from .vision_transformer import VisionTransformer
from einops import rearrange, repeat
from einops_exts import rearrange_many
import inspect

from transformers.modeling_utils import PreTrainedModel

from .configuration_maelm import MAELMConfig


class ArgsHandler:
    """Uniform accessor for a hooked module's positional and keyword arguments.

    Forward pre-hooks receive (args, kwargs); this helper resolves an argument
    by name regardless of how it was passed, using the module's forward
    signature (cached on the module to avoid repeated `inspect` calls).
    """

    def __init__(self, module, funcname, fargs, fkargs):
        self.fargs = list(fargs)
        self.fkargs = fkargs
        func = getattr(module, funcname)
        fal_repr = f"{funcname}_argnames_list"
        if (argns_list := getattr(module, fal_repr, None)) is None:
            self.func_sig = inspect.signature(func)
            self.argnames_list = list(self.func_sig.parameters.keys())
            setattr(module, fal_repr, self.argnames_list)
        else:
            self.argnames_list = argns_list

    def get_arg(self, arg_name):
        if arg_name in self.fkargs:
            arg = self.fkargs[arg_name]
        else:
            arg = self.fargs[self.argnames_list.index(arg_name)]
        return arg

    def set_arg(self, arg_name, arg_value):
        if arg_name in self.fkargs:
            self.fkargs[arg_name] = arg_value
        else:
            self.fargs[self.argnames_list.index(arg_name)] = arg_value

    def return_all_args(self):
        return tuple(self.fargs), self.fkargs
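

# Illustrative sketch (not part of the model): how ArgsHandler is typically used
# inside a forward pre-hook to patch `hidden_states` whether it arrives
# positionally or as a keyword. The layer and transform below are hypothetical.
#
#   def prehook(module, args, kwargs):
#       ah = ArgsHandler(module, 'forward', args, kwargs)
#       hs = ah.get_arg("hidden_states")
#       ah.set_arg("hidden_states", hs * 1.0)   # apply some transform
#       return ah.return_all_args()             # returned (args, kwargs) replace the originals
#
#   handle = some_layer.register_forward_pre_hook(prehook, with_kwargs=True)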


class SquaredReLU(nn.Module):
    """Squared ReLU activation function."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.pow(torch.relu(x), 2)


def FeedForward(dim, out_dim, mult=4, act='gelu'):
    """
    lucidrains implementation, slightly modified with the act parameter.
    """
    acts = dict(
        gelu=nn.GELU,
        sqrelu=SquaredReLU,
        relu=nn.ReLU,
    )

    assert act in acts, f"act. can only be one of {acts.keys()}"

    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        acts[act](),
        nn.Linear(inner_dim, out_dim, bias=False),
    )


class PerceiverAttentionLayer(nn.Module):
    def __init__(
        self,
        *,
        feat_dim,
        latent_dim,
        dim_head=64,
        heads=8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(feat_dim)
        self.norm_latents = nn.LayerNorm(latent_dim)

        self.to_q = nn.Linear(latent_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, latent_dim, bias=False)

    def forward(self, features, latents):
        """
        Latent vectors cross-attend to the encoder features.

        :param features: Tensor (n_batch, n_features, feat_dim)
            encoder (audio/visual) features
        :param latents: Tensor (n_batch, n_latents, latent_dim)
            learnt latent vectors from which the queries are computed;
            the same latents replicated along the batch dimension.
        :return: Tensor (n_batch, n_latents, latent_dim)
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]

        n_heads = self.heads
        n_batch, n_features, dim = features.shape
        n_queries = latents.shape[1]

        x = self.norm_media(features)
        latents = self.norm_latents(latents)

        # queries come from the latents
        q = self.to_q(latents)
        q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
        assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])

        # Flamingo-style variant (disabled): let keys/values range over the
        # latents as well as the features.
        # kv_input = torch.cat((x, latents), dim=-2)
        # n_features_latents = n_features + n_queries

        # keys and values come from the features only
        kv_input = x
        n_features_latents = n_features

        k = self.to_k(kv_input)
        v = self.to_v(kv_input)

        k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
        assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head])

        q = q * self.scale

        # scaled dot-product attention of latents (queries) over features (keys)
        sim = einsum('b h q d, b h f d -> b h q f', q, k)

        # subtract the row-wise max for numerical stability before softmax
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        alphas = sim.softmax(dim=-1)

        out = einsum('b h q f, b h f v -> b h q v', alphas, v)

        # merge heads and project back to the latent dimension
        out = rearrange(out, 'b h q v -> b q (h v)')
        return self.to_out(out)
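

# Illustrative sketch (not part of the model): expected shapes for one resampler
# attention step. The dimensions below are hypothetical.
#
#   layer = PerceiverAttentionLayer(feat_dim=1024, latent_dim=4096, dim_head=64, heads=8)
#   feats = torch.randn(2, 256, 1024)   # (batch, n_features, feat_dim) encoder output
#   lats = torch.randn(2, 32, 4096)     # (batch, n_latents, latent_dim) learnt latents
#   out = layer(feats, lats)            # -> (2, 32, 4096), same shape as the latents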


class MAEForCausalLM(PreTrainedModel):
    """Audio-conditioned causal language model.

    An audio encoder (MAEViT or HTSAT) feeds a Perceiver-style resampler whose
    latents are injected into a causal LM decoder through forward pre-hooks and
    bottleneck adapters.

    Args:
        backbone (dict): Config dict for the audio encoder. Defaults to None.
        neck (dict): Config dict for the language-model decoder. Defaults to None.
        head (dict): Config dict for loss functions. Defaults to None.
        init_cfg (dict, optional): Config dict for weight initialization.
            Defaults to None.
    """

    config_class = MAELMConfig

    def __init__(self, config: MAELMConfig) -> None:
        super().__init__(config)
        backbone = config.backbone
        assert backbone is not None
        bk_name = backbone.pop('name')
        self.bk_name = bk_name
        if bk_name == 'MAEViT':
            ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
            self.backbone = MAEViT(**backbone)
        elif bk_name == 'HTSAT':
            ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
            self.backbone = create_htsat_model(backbone)
            if ckpt_path is not None:
                ckpt = torch.load(ckpt_path, map_location='cpu')
                self.backbone.load_state_dict(ckpt['state_dict'])
        elif bk_name == 'qformer':
            raise NotImplementedError("qformer backbone is not supported yet")
        else:
            raise NotImplementedError(f"unknown backbone: {bk_name}")

        neck = config.neck
        assert neck is not None
        nk_name = neck.pop('name')
        if nk_name == 'LMDecoder':
            self.neck = LMDecoder(**neck)
        elif nk_name == 'LMDecoder_qlora':
            self.neck = LMDecoder_qlora(**neck)
        else:
            raise NotImplementedError(f"unknown neck: {nk_name}")
        # expose the underlying LM config as this model's config
        self.config = self.neck.LMconfig

        # self.ae_proj = nn.Linear(
        #     768, self.config.hidden_size
        # )

        self.neck.lm.model.gradient_checkpointing = False

        self.register_buffer('ones', torch.ones((1, 4096), dtype=torch.long), persistent=False)
        self.graft_adapter()
        self.init_weights()

        for p in self.parameters():
            p.data = p.data.to(torch.bfloat16)

        self.first_run = True
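
    # Illustrative sketch (not part of the model): constructing the model from a
    # config. Only config.backbone / config.neck are read here; whether
    # MAELMConfig accepts these fields as keyword arguments is an assumption,
    # and the inner dict contents are hypothetical placeholders.
    #
    #   cfg = MAELMConfig(
    #       backbone=dict(name='HTSAT', ckpt='path/to/htsat.ckpt'),
    #       neck=dict(name='LMDecoder'),
    #   )
    #   model = MAEForCausalLM(cfg)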

    def graft_adapter(self):
        adapter_latent_len = 32
        self.adapter_latent_len = adapter_latent_len
        self.adapter_latent = nn.Parameter(
            torch.rand((1, adapter_latent_len, self.config.hidden_size), dtype=torch.float))
        resampler_latent_len = 32
        self.resampler_latent_len = resampler_latent_len
        self.resampler_latent = nn.Parameter(
            torch.rand((1, resampler_latent_len, self.config.hidden_size), dtype=torch.float))

        self.adapter = nn.ModuleList([])

        ff_mult = 4
        heads = 8
        dim_head = 512  # note: the samplers below use dim_head=64, not this value
        act = 'gelu'

        lm_dim = self.config.hidden_size
        if self.bk_name == 'HTSAT':
            feat_dim = 1024
            depth = len(self.backbone.layers[2].blocks)
        else:
            feat_dim = 768
            depth = int(len(self.neck.lm.model.layers) / 2)
        # bottleneck adapters applied to LM hidden states via forward pre-hooks
        for idx in range(depth):
            self.adapter.append(nn.ModuleList([
                Adapter(input_size=self.config.hidden_size),
            ]))

        # three Perceiver resampler blocks mapping encoder features to LM-width latents
        self.samplers = nn.ModuleList([])
        for _ in range(3):
            self.samplers.append(nn.ModuleList([
                PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=64, heads=heads),
                FeedForward(dim=lm_dim, out_dim=lm_dim, mult=ff_mult, act=act),
            ]))
        self.norm = nn.LayerNorm(lm_dim)

    def init_weights(self):
        try:
            super().init_weights()
        except Exception:
            pass

        # initialize the learnt latent queries
        if getattr(self, 'adapter_latent', None) is not None:
            self.adapter_latent.data.normal_(mean=0.0, std=0.02)
        if getattr(self, 'resampler_latent', None) is not None:
            self.resampler_latent.data.normal_(mean=0.0, std=0.02)

    def forward_resampler(self, x):
        # x: (batch, n_features, feat_dim) encoder features
        latents = repeat(self.resampler_latent, 'b n d -> (bs b) n d', bs=x.shape[0])
        for attn, ff in self.samplers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        v2t_feats = self.norm(latents)

        return v2t_feats
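
    # Data-path note (descriptive): for a batch of spectrograms,
    #   feats = self.backbone(spectrogram)       # (B, n_features, feat_dim)
    #   lats  = self.forward_resampler(feats)    # (B, 32, hidden_size)
    # hook_adapter() then registers pre-hooks on the LM so these latents are
    # scattered into the hidden states entering the first decoder layer at the
    # positions marked by negative placeholder ids, and bottleneck adapters run
    # before every other decoder layer.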

    def hook_adapter(self, audio_embedding, lm, v2t_feats):
        """Attach forward pre-hooks to `lm` that inject the resampled audio
        latents and apply the bottleneck adapters."""

        class PHooker:
            # note: `self` here is the enclosing MAEForCausalLM instance, so these
            # class attributes capture its adapters and the current batch's latents
            adapter = self.adapter
            y = v2t_feats
            handles_list = list()
            cnter = 0

            def layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                # apply the next bottleneck adapter to the incoming hidden states
                adapt = self.adapter[self.cnter][0]
                hs = ahl.get_arg("hidden_states")
                adapter_residual = hs
                neo_hs = adapt(hs, adapter_residual)
                self.cnter += 1
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def first_layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                # scatter the resampled audio latents into the positions marked
                # by negative placeholder ids in the original input_ids
                neo_lm_latents = self.y
                hs = ahl.get_arg("hidden_states")
                hs_msk = self.lm_ahl.get_arg("input_ids") < 0
                neo_hs = hs.masked_scatter(hs_msk.unsqueeze(-1), neo_lm_latents)
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def lm_prehook(self, m, margs, mkargs):
                # runs first: keep a handle on the LM-level arguments (input_ids)
                self.lm_ahl = ArgsHandler(m, 'forward', margs, mkargs)
                return None

            def last_layer_hook(self, m, margs, mkargs):
                # reset the adapter counter for the next forward pass
                self.cnter = 0

        # drop hooks left over from a previous call before registering new ones
        if getattr(lm, 'phooker', False):
            for _ in lm.phooker.handles_list:
                _.remove()
            del lm.phooker
            lm.phooker = None
        phooker = PHooker()
        phooker.handles_list.append(lm.register_forward_pre_hook(phooker.lm_prehook, with_kwargs=True))

        phooker.handles_list.append(
            lm.model.layers[0].register_forward_pre_hook(phooker.first_layer_prehook, with_kwargs=True))

        # adapters are applied before every other decoder layer
        for ii in range(1, len(lm.model.layers), 2):
            l = lm.model.layers[ii]
            handle = l.register_forward_pre_hook(phooker.layer_prehook, with_kwargs=True)
            phooker.handles_list.append(handle)
        phooker.handles_list.append(
            lm.model.layers[-1].register_forward_pre_hook(phooker.last_layer_hook, with_kwargs=True))
        lm.phooker = phooker
        return None
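
    # Hook firing order per LM forward pass (descriptive note):
    #   1. lm_prehook captures the LM call's arguments (in particular `input_ids`);
    #   2. first_layer_prehook scatters the resampled audio latents into the
    #      hidden states wherever input_ids < 0 (the placeholder positions);
    #   3. layer_prehook applies one bottleneck adapter before every other layer;
    #   4. last_layer_hook (registered as a pre-hook on the last layer) resets
    #      the adapter counter for the next forward pass.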

    def prepare_ids(self, batch, audio_ids):
        toker = self.neck.tokenizer

        with torch.no_grad():
            input_ids = batch['input_ids']
            att_msk = batch['attention_mask']
            au_crds = batch['audio_crds']
            ans_crds = batch['ans_crds']
            bsz = input_ids.shape[0]

            merged_ids, merged_msk, label_ids = list(), list(), list()
            for i in range(bsz):
                # encode audio placeholders as negative ids (-1 * id - 1) and
                # prepend them to the text tokens that follow the audio position
                cur_merged_ids = torch.cat([-1 * audio_ids[i] - 1, input_ids[i, au_crds[i]:]])

                cur_au_msk = torch.ones(audio_ids.shape[1], device=audio_ids.device)
                cur_merged_msk = torch.cat([cur_au_msk, att_msk[i, au_crds[i]:]])
                cur_label_ids = cur_merged_ids.clone().detach()
                cur_label_ids[:audio_ids.shape[1] + ans_crds[i]] = -100

                merged_ids.append(cur_merged_ids)
                merged_msk.append(cur_merged_msk)
                label_ids.append(cur_label_ids)

            merged_ids = torch.stack(merged_ids, dim=0)
            merged_msk = torch.stack(merged_msk, dim=0)
            label_ids = torch.stack(label_ids, dim=0)

            assert merged_ids.shape[0] == bsz
            assert merged_ids.shape == merged_msk.shape

            label_msk = merged_msk.clone()
            assert label_msk.shape == merged_msk.shape
            assert merged_msk[:, -1].max() == 1

            # mask everything before the answer span so the loss only covers answers
            for i in range(len(ans_crds)):
                label_ids[i, :audio_ids.shape[1] + ans_crds[i]].fill_(-100)

            merged_labels = label_ids
            merged_ids[merged_ids.eq(-100)] = toker.pad_token_id

        return merged_ids, merged_msk, merged_labels
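
    # Worked example (illustrative, not executed): with audio_ids[i] = [0, 1, 2, 3],
    # the audio placeholders become -1 * audio_ids[i] - 1 = [-1, -2, -3, -4], so the
    # first-layer pre-hook can later locate them via `input_ids < 0`. The merged
    # sequence is [-1, -2, -3, -4] + input_ids[i, au_crds[i]:], and the first
    # audio_ids.shape[1] + ans_crds[i] label positions are set to -100, so the
    # loss is only computed on the tokens after that point.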

    def forward(self, batch, **kwargs):
        """Forward computation during training.

        Args:
            batch (dict): Collated training batch with 'spectrogram', 'input_ids',
                'attention_mask', 'audio_crds' and 'ans_crds'.
            kwargs: Any keyword arguments to be used to forward.
        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        # audio features come from the frozen (detached) backbone
        audio_embedding = self.backbone(spectrogram).detach()
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats)

        # placeholder ids marking where the audio latents go
        audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
        assert audio_ids.max() < 100
        merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)

        assert merged_ids.shape == merged_labels.shape
        outs = self.neck(input_ids=merged_ids.contiguous().long(),
                         flatten_embs=self.adapter_latent.flatten(0, 1),
                         attention_mask=merged_msk.contiguous().long(),
                         labels=merged_labels.contiguous().long(), use_cache=False)

        if os.environ.get("doing_eval", "False").lower() in ("true", "1"):
            outs.merged_ids = merged_ids.cpu()
            outs.merged_labels = merged_labels.cpu()

        return outs

    def forward_test(self, batch, **kwargs):
        """Forward computation during inference/generation.

        Args:
            batch (dict): Collated batch with the same keys as in `forward`.
            kwargs: Any keyword arguments to be used to forward.
        Returns:
            The output of `self.neck.generate` for the padded prompts.
        """
        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        audio_embedding = self.backbone(spectrogram).detach()
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats)

        audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
        assert audio_ids.max() < 100

        merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)
        au_crds = batch['audio_crds']
        ans_crds = batch['ans_crds']

        aid_len = audio_ids.shape[-1]

        toker = self.neck.tokenizer
        with torch.no_grad():
            # left-pad every prompt to the same length so generation starts at the
            # answer position for all samples in the batch
            pad_token = toker.encode(self.neck.tokenizer.eos_token)[0]
            padded_merged_ids = self.ones[:, :aid_len + max(ans_crds)].repeat(bsz, 1).clone().detach() * pad_token
            for i in range(bsz):
                assert au_crds[i] <= ans_crds[i]
                cur_ids = merged_ids[i][:aid_len + ans_crds[i]]
                padded_merged_ids[i][max(ans_crds) - ans_crds[i]:] = cur_ids

            outs = self.neck.generate(padded_merged_ids, self.adapter_latent.flatten(0, 1))

        return outs
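
    # Illustrative sketch (not part of the model): calling forward_test for
    # generation. `model` is a hypothetical MAEForCausalLM instance; the batch
    # layout mirrors forward() and comes from the project's collate function.
    #
    #   batch = {
    #       'input_ids': ...,        # (B, T) prompt token ids
    #       'attention_mask': ...,   # (B, T)
    #       'spectrogram': ...,      # audio features for the backbone
    #       'audio_crds': ...,       # per-sample start of the audio span
    #       'ans_crds': ...,         # per-sample start of the answer span
    #   }
    #   with torch.no_grad():
    #       generated = model.forward_test(batch)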


from transformers.activations import ACT2FN


class Adapter(nn.Module):
    """
    Implementation of a sequential bottleneck adapter block.
    """

    def __init__(
        self,
        input_size,
        down_sample=None,
    ):
        super().__init__()

        self.input_size = input_size

        # if no bottleneck width is given, use half the input size
        self.down_sample = down_sample
        if down_sample is None:
            self.down_sample = self.input_size // 2

        self.adapter_norm_before = nn.LayerNorm(self.input_size)
        self.adapter_down = nn.Linear(self.input_size, self.down_sample)
        self.non_linearity = ACT2FN["silu"]

        self.adapter_up = nn.Linear(self.down_sample, self.input_size)

        # learnable scaling of the adapter output
        self.scaling = nn.Parameter(torch.ones(1))

        self.adapter_down.apply(self._init_weights)
        self.adapter_up.apply(self._init_weights)

    def forward(self, x, residual_input):
        down = self.non_linearity(self.adapter_down(self.adapter_norm_before(x)))

        up = self.adapter_up(down)
        up = up * self.scaling
        output = up

        output = output + residual_input

        return output

    @staticmethod
    def _init_weights(module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
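

# Illustrative sketch (not part of the model): the Adapter is a residual
# bottleneck, so its output shape equals its input shape. Values are hypothetical.
#
#   adapter = Adapter(input_size=4096)       # down-projects to 2048 by default
#   hs = torch.randn(2, 16, 4096)            # (batch, seq, hidden) LM hidden states
#   out = adapter(hs, residual_input=hs)     # -> (2, 16, 4096)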