|
|
import torch |
|
|
import torch.nn as nn |
|
|
from einops import rearrange |
|
|
|
|
from torch.cuda.amp import autocast |
|
|
from functools import partial |
|
|
from typing import Optional, Tuple, Union |
|
|
import torchaudio.transforms as audio_transforms |
|
|
|
|
from einops.layers.torch import Rearrange |
|
|
from itertools import repeat |
|
|
import collections |
|
|
|
|
|
import torch.nn.functional as F |
|
|
import einops |
|
|
|
|
|
|
|
|
if hasattr(nn.functional, 'scaled_dot_product_attention'): |
|
|
ATTENTION_MODE = 'flash' |
|
|
else: |
|
|
ATTENTION_MODE = 'math' |
|
|
print(f'attention mode is {ATTENTION_MODE}') |
|
|
|
|
|
|
|
|
def _ntuple(n): |
|
|
|
|
|
def parse(x) -> Tuple: |
|
|
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): |
|
|
return tuple(x) |
|
|
return tuple(repeat(x, n)) |
|
|
|
|
|
return parse |
|
|
|
|
|
|
|
|
to_2tuple = _ntuple(2) |
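# Example: to_2tuple(4) -> (4, 4); to_2tuple((64, 4)) -> (64, 4) (unchanged).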
|
|
|
|
|
|
|
|
class MAELoss(torch.nn.Module): |
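    """Patch-wise MSE reconstruction loss for masked autoencoding.

    The squared error is averaged over the feature dimension and then over the
    masked positions only. With ``norm_pix_loss=True`` each target patch is
    normalised to zero mean and unit variance before comparison; the value
    'global' normalises over the whole target tensor instead.
    """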
|
|
|
|
|
    def __init__(self, norm_pix_loss: Union[bool, str] = True):
|
|
super().__init__() |
|
|
self.norm_pix_loss = norm_pix_loss |
|
|
|
|
|
@autocast(enabled=False) |
|
|
def forward(self, pred: torch.Tensor, target: torch.Tensor, |
|
|
mask: torch.Tensor) -> torch.Tensor: |
|
|
if self.norm_pix_loss is True: |
|
|
mean = target.mean(dim=-1, keepdim=True) |
|
|
var = target.var(dim=-1, keepdim=True) |
|
|
target = (target - mean) / (var + 1.e-6)**.5 |
|
|
elif self.norm_pix_loss == 'global': |
|
|
mean = target.mean() |
|
|
var = target.var() |
|
|
target = (target - mean) / (var + 1.e-6)**.5 |
|
|
loss = (pred - target)**2 |
|
|
loss = loss.mean(dim=-1) |
|
|
loss = (loss * mask).sum() / mask.sum() |
|
|
return loss |
|
|
|
|
|
|
|
|
class AudioPatchEmbed(nn.Module): |
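    """Convolutional patch embedding for (batch, 1, mel, time) spectrograms.

    A Conv2d with kernel ``patch_size`` and stride ``patch_stride`` projects
    every patch to ``embed_dim``. With ``flatten=True`` the frequency/time
    grid is flattened into a single token sequence.
    """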
|
|
|
|
|
def __init__(self, |
|
|
input_size: Union[int, Tuple[int, int]] = (64, 100), |
|
|
patch_size: Tuple[int, int] = (64, 4), |
|
|
patch_stride: Tuple[int, int] = (64, 4), |
|
|
in_chans=1, |
|
|
embed_dim=768, |
|
|
norm_layer=None, |
|
|
flatten=False): |
|
|
super().__init__() |
|
|
patch_size = to_2tuple(patch_size) |
|
|
patch_stride = to_2tuple(patch_stride) |
|
|
self.input_size: Tuple[int, int] = to_2tuple(input_size) |
|
|
self.patch_size: Tuple[int, int] = to_2tuple(patch_size) |
|
|
self.patch_stride: Tuple[int, int] = to_2tuple(patch_stride) |
|
|
self.grid_size = (self.input_size[0] // self.patch_stride[0], |
|
|
self.input_size[1] // self.patch_stride[1]) |
|
|
self.num_patches = self.grid_size[0] * self.grid_size[1] |
|
|
self.flatten = flatten |
|
|
|
|
|
self.proj = nn.Conv2d(in_chans, |
|
|
embed_dim, |
|
|
kernel_size=patch_size, |
|
|
stride=patch_stride) |
|
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.proj(x) |
|
|
if self.flatten: |
|
|
x = rearrange(x, 'b c f t -> b (f t) c') |
|
|
x = self.norm(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class LayerScale(nn.Module): |
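    """Learnable per-channel scaling of a residual branch (as introduced in CaiT)."""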
|
|
|
|
|
def __init__(self, dim: int, init_values=1e-5, inplace=False): |
|
|
super().__init__() |
|
|
self.inplace = inplace |
|
|
self.gamma = nn.Parameter(init_values * torch.ones(dim)) |
|
|
|
|
|
def forward(self, x): |
|
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma |
|
|
|
|
|
|
|
|
class Attention(nn.Module): |
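    """Multi-head self-attention.

    Uses torch.nn.functional.scaled_dot_product_attention when available
    ('flash' mode) and falls back to an explicit softmax implementation
    ('math' mode) otherwise.
    """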
|
|
|
|
|
def __init__(self, |
|
|
dim, |
|
|
num_heads=8, |
|
|
qkv_bias=False, |
|
|
attn_drop=0., |
|
|
proj_drop=0.): |
|
|
super().__init__() |
|
|
assert dim % num_heads == 0, 'dim should be divisible by num_heads' |
|
|
self.num_heads = num_heads |
|
|
head_dim = dim // num_heads |
|
|
self.scale = head_dim**-0.5 |
|
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|
|
self.attn_drop = nn.Dropout(attn_drop) |
|
|
self.attn_drop_p = attn_drop |
|
|
self.proj = nn.Linear(dim, dim) |
|
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
|
|
def forward(self, x): |
|
|
B, N, C = x.shape |
|
|
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each has shape (B, num_heads, N, head_dim)
|
|
|
|
|
if ATTENTION_MODE == 'flash': |
|
|
            # scaled_dot_product_attention applies dropout unconditionally, so
            # only pass a non-zero probability while training. (The explicit
            # ``scale`` argument requires PyTorch >= 2.1.)
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop_p if self.training else 0.,
                scale=self.scale)
|
|
x = einops.rearrange(x, 'B H L D -> B L (H D)') |
|
|
elif ATTENTION_MODE == 'math': |
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
attn = attn.softmax(dim=-1) |
|
|
attn = self.attn_drop(attn) |
|
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
|
|
|
|
|
x = self.proj(x) |
|
|
x = self.proj_drop(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class Mlp(nn.Module): |
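    """Two-layer feed-forward network used inside each transformer block."""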
|
|
|
|
|
def __init__(self, |
|
|
in_features, |
|
|
hidden_features=None, |
|
|
out_features=None, |
|
|
act_layer=nn.GELU, |
|
|
drop=0.): |
|
|
super().__init__() |
|
|
out_features = out_features or in_features |
|
|
hidden_features = hidden_features or in_features |
|
|
self.fc1 = nn.Linear(in_features, hidden_features) |
|
|
self.act = act_layer() |
|
|
self.fc2 = nn.Linear(hidden_features, out_features) |
|
|
self.drop = nn.Dropout(drop) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.fc1(x) |
|
|
x = self.act(x) |
|
|
x = self.drop(x) |
|
|
x = self.fc2(x) |
|
|
x = self.drop(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class Block(nn.Module): |
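    """Pre-norm transformer block: self-attention and MLP sub-layers, each in a
    residual branch with optional LayerScale."""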
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim, |
|
|
num_heads, |
|
|
mlp_ratio=4., |
|
|
qkv_bias=False, |
|
|
drop=0., |
|
|
attn_drop=0., |
|
|
init_values=None, |
|
|
act_layer=nn.GELU, |
|
|
norm_layer=nn.LayerNorm, |
|
|
attention_type='Attention', |
|
|
): |
|
|
super().__init__() |
|
|
self.norm1 = norm_layer(dim) |
|
|
attn_type = globals()[attention_type] |
|
|
self.attn = attn_type(dim, |
|
|
num_heads=num_heads, |
|
|
qkv_bias=qkv_bias, |
|
|
attn_drop=attn_drop, |
|
|
proj_drop=drop) |
|
|
self.ls1 = LayerScale( |
|
|
dim, init_values=init_values) if init_values else nn.Identity() |
|
|
|
|
|
self.norm2 = norm_layer(dim) |
|
|
self.mlp = Mlp(in_features=dim, |
|
|
hidden_features=int(dim * mlp_ratio), |
|
|
act_layer=act_layer, |
|
|
drop=drop) |
|
|
self.ls2 = LayerScale( |
|
|
dim, init_values=init_values) if init_values else nn.Identity() |
|
|
|
|
|
def forward(self, x): |
|
|
x = x + self.ls1(self.attn(self.norm1(x))) |
|
|
x = x + self.ls2(self.mlp(self.norm2(x))) |
|
|
return x |
|
|
|
|
|
|
|
|
class AudioTransformerMAE_Encoder(nn.Module): |
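    """MAE encoder operating directly on raw waveforms.

    The front end computes a log-mel spectrogram that is batch-normalised,
    split into patches, combined with separate time and frequency positional
    embeddings, randomly masked, and processed by a stack of transformer
    blocks. Only the visible (unmasked) patches are encoded.
    """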
|
|
|
|
|
def __init__(self, |
|
|
patch_size: Tuple[int, int] = (64, 4), |
|
|
patch_stride: Tuple[int, int] = (64, 4), |
|
|
embed_dim: int = 768, |
|
|
depth: int = 12, |
|
|
num_heads=8, |
|
|
mlp_ratio=4., |
|
|
qkv_bias=True, |
|
|
drop_rate=0., |
|
|
attn_drop_rate=0., |
|
|
norm_layer=None, |
|
|
act_layer=None, |
|
|
init_values=None, |
|
|
target_length=1008, |
|
|
pooling='mean', |
|
|
time_patch_out: Optional[float] = None, |
|
|
freq_patch_out: Optional[float] = None, |
|
|
block_type='Block', |
|
|
attention_type='Attention', |
|
|
eval_avg='cat', |
|
|
n_fft: int = 512, |
|
|
n_mels: int = 64, |
|
|
hop_size: int = 160, |
|
|
win_size: int = 512, |
|
|
f_min: int = 0, |
|
|
f_max: int = 8000, |
|
|
center: bool = True, |
|
|
**kwargs): |
|
|
super().__init__() |
|
|
self.pooling = pooling |
|
|
self.embed_dim = embed_dim |
|
|
self.patch_stride = patch_stride |
|
|
self.patch_size = patch_size |
|
|
self.n_mels = n_mels |
|
|
self.eval_avg = eval_avg |
|
|
self.time_patch_out = time_patch_out |
|
|
self.freq_patch_out = freq_patch_out |
|
|
|
|
|
self.front_end = nn.Sequential( |
|
|
audio_transforms.MelSpectrogram(f_min=f_min, |
|
|
sample_rate=16000, |
|
|
win_length=win_size, |
|
|
center=center, |
|
|
n_fft=n_fft, |
|
|
f_max=f_max, |
|
|
hop_length=hop_size, |
|
|
n_mels=self.n_mels), |
|
|
audio_transforms.AmplitudeToDB(top_db=kwargs.get('top_db', 120))) |
|
|
|
|
|
self.init_bn = nn.Sequential( |
|
|
Rearrange('b c f t -> b f c t'), |
|
|
nn.BatchNorm2d(self.n_mels, momentum=0.01), |
|
|
Rearrange('b f c t -> b c f t')) |
|
|
|
|
|
self.target_length = target_length |
|
|
self.patch_embed = AudioPatchEmbed(input_size=(self.n_mels, |
|
|
target_length), |
|
|
embed_dim=self.embed_dim, |
|
|
patch_size=self.patch_size, |
|
|
flatten=False, |
|
|
patch_stride=self.patch_stride) |
|
|
self.num_patches = self.patch_embed.num_patches |
|
|
|
|
|
if pooling == 'token': |
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
|
|
self.token_pos_embed = nn.Parameter( |
|
|
torch.randn(1, embed_dim) * .02) |
|
|
|
|
|
self.time_pos_embed = nn.Parameter( |
|
|
torch.randn(1, embed_dim, 1, self.patch_embed.grid_size[1]) * .02) |
|
|
self.freq_pos_embed = nn.Parameter( |
|
|
torch.randn(1, embed_dim, self.patch_embed.grid_size[0], 1) * .02) |
|
|
|
|
|
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
|
|
act_layer = act_layer or nn.GELU |
|
|
self.pos_drop = nn.Dropout(p=drop_rate) |
|
|
block_function = globals()[block_type] |
|
|
self.blocks = nn.Sequential(*[ |
|
|
block_function( |
|
|
dim=embed_dim, |
|
|
num_heads=num_heads, |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
init_values=init_values, |
|
|
drop=drop_rate, |
|
|
attn_drop=attn_drop_rate, |
|
|
norm_layer=norm_layer, |
|
|
act_layer=act_layer, |
|
|
attention_type=attention_type, |
|
|
) for _ in range(depth) |
|
|
]) |
|
|
self.norm = norm_layer(embed_dim) |
|
|
self.apply(self.init_weights) |
|
|
if hasattr(self, 'cls_token') and self.cls_token is not None: |
|
|
nn.init.normal_(self.cls_token, std=1e-6) |
|
|
group_masking = kwargs.get('group_masking', False) |
|
|
if isinstance(group_masking, bool): |
|
|
if group_masking is True: |
|
|
self.masking_func = self.random_masking_group |
|
|
else: |
|
|
self.masking_func = self.random_masking |
|
|
elif isinstance(group_masking, int): |
|
|
self.masking_func = partial(self.random_masking_group, |
|
|
group_factor=group_masking) |
|
|
|
|
|
@torch.jit.ignore |
|
|
def no_weight_decay(self): |
|
|
return { |
|
|
'time_pos_embed', 'cls_token', 'freq_pos_embed', 'token_pos_embed' |
|
|
} |
|
|
|
|
|
def init_weights(self, module): |
|
|
if isinstance(module, nn.Linear): |
|
|
torch.nn.init.xavier_uniform_(module.weight) |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
nn.init.constant_(module.bias, 0) |
|
|
nn.init.constant_(module.weight, 1.0) |
|
|
|
|
|
def random_masking_group(self, x, mask_ratio, group_factor: int = 2): |
|
|
""" |
|
|
Perform per-sample random masking by per-sample shuffling. |
|
|
Per-sample shuffling is done by argsort random noise. |
|
|
x: [N, L, D], sequence |
|
|
""" |
|
|
N, L, D = x.shape |
|
|
len_keep = int(L * (1 - mask_ratio)) |
|
|
|
|
|
noise = torch.rand(N, L // group_factor, |
|
|
device=x.device) |
|
|
|
|
|
indices = torch.arange(L, device=x.device).view(-1, group_factor) |
|
|
|
|
|
|
|
|
ids_shuffle = torch.argsort( |
|
|
noise, dim=1) |
|
|
ids_shuffle = indices[ids_shuffle].flatten(-2) |
|
|
ids_restore = torch.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
|
x_masked = torch.gather(x, |
|
|
dim=1, |
|
|
index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
|
|
|
|
|
mask = torch.ones([N, L], device=x.device) |
|
|
mask[:, :len_keep] = 0 |
|
|
|
|
|
mask = torch.gather(mask, dim=1, index=ids_restore) |
|
|
|
|
|
return x_masked, mask, ids_restore |
|
|
|
|
|
def random_masking(self, x, mask_ratio): |
|
|
""" |
|
|
Perform per-sample random masking by per-sample shuffling. |
|
|
Per-sample shuffling is done by argsort random noise. |
|
|
x: [N, L, D], sequence |
|
|
""" |
|
|
N, L, D = x.shape |
|
|
len_keep = int(L * (1 - mask_ratio)) |
|
|
|
|
|
noise = torch.rand(N, L, device=x.device) |
|
|
|
|
|
|
|
|
ids_shuffle = torch.argsort( |
|
|
noise, dim=1) |
|
|
ids_restore = torch.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
|
x_masked = torch.gather(x, |
|
|
dim=1, |
|
|
index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
|
|
|
|
|
mask = torch.ones([N, L], device=x.device) |
|
|
mask[:, :len_keep] = 0 |
|
|
|
|
|
mask = torch.gather(mask, dim=1, index=ids_restore) |
|
|
|
|
|
return x_masked, mask, ids_restore |
|
|
|
|
|
def forward_features(self, x, mask_ratio): |
|
|
x = self.patch_embed(x) |
|
|
b, c, f, t = x.shape |
|
|
x = x + self.time_pos_embed[:, :, :, :t] |
|
|
x = x + self.freq_pos_embed[:, :, :, :] |
|
|
x = rearrange(x, 'b c f t -> b (f t) c') |
|
|
|
|
|
x, mask, ids_restore = self.masking_func(x, mask_ratio) |
|
|
if self.pooling == 'token': |
|
|
cls_token = self.cls_token.expand(x.shape[0], -1, -1) |
|
|
cls_token = cls_token + self.token_pos_embed[:, :] |
|
|
x = torch.cat((cls_token, x), dim=1) |
|
|
x = self.pos_drop(x) |
|
|
x = self.blocks(x) |
|
|
x = self.norm(x) |
|
|
return x, mask, ids_restore |
|
|
|
|
|
def load_state_dict(self, state_dict, strict=True, **kwargs): |
|
|
if 'time_pos_embed' in state_dict and self.time_pos_embed.shape != state_dict[ |
|
|
'time_pos_embed'].shape: |
|
|
            print(
                'Positional embedding shape does not match the model; resizing.'
            )
|
|
self.change_pos_embedding(state_dict) |
|
|
        return super().load_state_dict(state_dict, strict=strict, **kwargs)
|
|
|
|
|
def change_pos_embedding(self, state_dict): |
|
|
target_time_pos_embed_length = self.time_pos_embed.shape[-1] |
|
|
target_freq_pos_embed_length = self.freq_pos_embed.shape[-2] |
|
|
|
|
|
pretrained_time_pos_embed = state_dict['time_pos_embed'] |
|
|
pretrained_freq_pos_embed = state_dict['freq_pos_embed'] |
|
|
|
|
|
        if target_time_pos_embed_length <= pretrained_time_pos_embed.shape[-1]:
|
|
state_dict['time_pos_embed'] = pretrained_time_pos_embed[ |
|
|
..., :target_time_pos_embed_length] |
|
|
else: |
|
|
state_dict['time_pos_embed'] = torch.nn.functional.interpolate( |
|
|
pretrained_time_pos_embed, |
|
|
size=(1, target_time_pos_embed_length), |
|
|
align_corners=False, |
|
|
mode='bilinear') |
|
|
if target_freq_pos_embed_length <= pretrained_freq_pos_embed.shape[-2]: |
|
|
            state_dict['freq_pos_embed'] = pretrained_freq_pos_embed[
                :, :, :target_freq_pos_embed_length, :]
|
|
else: |
|
|
state_dict['freq_pos_embed'] = torch.nn.functional.interpolate( |
|
|
pretrained_freq_pos_embed, |
|
|
size=(target_freq_pos_embed_length, 1), |
|
|
align_corners=False, |
|
|
mode='bilinear') |
|
|
|
|
|
def forward_to_spec(self, x): |
|
|
|
|
|
with autocast(enabled=False): |
|
|
X = self.front_end(x) |
|
|
X = rearrange(X, 'b f t -> b 1 f t') |
|
|
X = self.init_bn(X) |
|
|
return X |
|
|
|
|
|
def forward(self, x, mask_ratio: float = 0.75): |
|
|
x = self.forward_to_spec(x) |
|
|
x, mask, restore_idxs = self.forward_features(x, mask_ratio=mask_ratio) |
|
|
return x, mask, restore_idxs |
|
|
|
|
|
|
|
|
class AudioTransformerMAE_Decoder(nn.Module): |
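    """Lightweight MAE decoder.

    Encoded visible tokens are projected to the decoder width, learned mask
    tokens are inserted for the removed patches, the original patch order is
    restored via ``ids_restore``, and a linear head predicts the flattened
    spectrogram patches.
    """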
|
|
|
|
|
def __init__(self, |
|
|
input_dim: int, |
|
|
outputdim: int, |
|
|
patch_size: int = 16, |
|
|
patch_stride: int = 16, |
|
|
embed_dim: int = 768, |
|
|
num_patches: int = 100, |
|
|
depth: int = 12, |
|
|
num_heads: int = 12, |
|
|
mlp_ratio: float = 4., |
|
|
qkv_bias: bool = True, |
|
|
drop_rate: float = 0., |
|
|
attn_drop_rate: float = 0., |
|
|
norm_layer: Optional[torch.nn.Module] = None, |
|
|
act_layer: Optional[torch.nn.Module] = None, |
|
|
cls_token: bool = False, |
|
|
attention_type='Attention', |
|
|
init_values=None, |
|
|
**kwargs): |
|
|
super().__init__() |
|
|
self.embed_dim = embed_dim |
|
|
self.patch_stride = patch_stride |
|
|
self.patch_size = patch_size |
|
|
self.input_dim = input_dim |
|
|
|
|
|
self.input_proj = nn.Linear(input_dim, embed_dim) |
|
|
|
|
|
self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim) * .02) |
|
|
_norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
|
|
_act_layer = act_layer or nn.GELU |
|
|
self.use_cls = cls_token |
|
|
        # One extra positional entry is needed when a cls token is prepended.
        num_patches_total = num_patches + 1 if cls_token else num_patches
|
|
self.pos_embed = nn.Parameter( |
|
|
torch.zeros(1, num_patches_total, embed_dim)) |
|
|
self.pos_drop = nn.Dropout(p=drop_rate) |
|
|
self.blocks = nn.Sequential(*[ |
|
|
Block( |
|
|
dim=embed_dim, |
|
|
num_heads=num_heads, |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
init_values=init_values, |
|
|
drop=drop_rate, |
|
|
attn_drop=attn_drop_rate, |
|
|
norm_layer=_norm_layer, |
|
|
act_layer=_act_layer, |
|
|
attention_type=attention_type, |
|
|
) for i in range(depth) |
|
|
]) |
|
|
self.norm = _norm_layer(embed_dim) |
|
|
self.outputlayer = nn.Linear(self.embed_dim, outputdim) |
|
|
self.apply(self.init_weights) |
|
|
torch.nn.init.normal_(self.mask_token, std=.02) |
|
|
|
|
|
@torch.jit.ignore |
|
|
def no_weight_decay(self): |
|
|
        return {'pos_embed', 'mask_token'}
|
|
|
|
|
def init_weights(self, module): |
|
|
if isinstance(module, nn.Linear): |
|
|
nn.init.trunc_normal_(module.weight, std=.02) |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
nn.init.constant_(module.bias, 0) |
|
|
nn.init.constant_(module.weight, 1.0) |
|
|
|
|
|
def forward_features(self, x, ids_restore): |
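        """Insert mask tokens, restore the original patch order, and decode."""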
|
|
x = self.input_proj(x) |
|
|
mask_tokens = self.mask_token.repeat( |
|
|
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) |
|
|
if self.use_cls: |
|
|
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) |
|
|
else: |
|
|
x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) |
|
|
x_ = torch.gather(x_, |
|
|
dim=1, |
|
|
index=ids_restore.unsqueeze(-1).repeat( |
|
|
1, 1, x.shape[2])) |
|
|
if self.use_cls: |
|
|
x = torch.cat([x[:, :1, :], x_], dim=1) |
|
|
else: |
|
|
x = x_ |
|
|
t = x.shape[1] |
|
|
|
|
|
x = x + self.pos_embed[:, :t, :] |
|
|
x = self.pos_drop(x) |
|
|
x = self.blocks(x) |
|
|
x = self.norm(x) |
|
|
return x |
|
|
|
|
|
def forward(self, x, restore_idxs): |
|
|
x = self.forward_features(x, restore_idxs) |
|
|
x = self.outputlayer(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class AudioTransformerMAE(nn.Module): |
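    """Complete masked autoencoder: encoder, decoder and reconstruction loss.

    The forward pass encodes a masked view of the waveform, decodes a
    prediction for every patch, and compares it against the patchified
    log-mel spectrogram produced by the encoder's front end.
    """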
|
|
|
|
|
def __init__(self, |
|
|
encoder: AudioTransformerMAE_Encoder, |
|
|
decoder: AudioTransformerMAE_Decoder, |
|
|
loss_fn: Optional[torch.nn.Module] = None): |
|
|
super().__init__() |
|
|
self.encoder = encoder |
|
|
self.decoder = decoder |
|
|
self.unfold = nn.Unfold( |
|
|
kernel_size=self.encoder.patch_embed.patch_size, |
|
|
stride=self.encoder.patch_embed.patch_size) |
|
|
self.loss_fn = MAELoss() if loss_fn is None else loss_fn |
|
|
|
|
|
def forward(self, |
|
|
x: torch.Tensor, |
|
|
mask_ratio: float = 0.75, |
|
|
return_loss: bool = False): |
|
|
latent, mask, restore_ids = self.encoder(x, mask_ratio=mask_ratio) |
|
|
pred = self.decoder(latent, restore_ids) |
|
|
with autocast(enabled=False): |
|
|
targets = self.encoder.front_end(x) |
|
|
targets = self.patchify(targets) |
|
|
if return_loss: |
|
|
return self.loss_fn(pred, targets, mask) |
|
|
return pred, targets, mask |
|
|
|
|
|
def patchify(self, x): |
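        # Split the (batch, n_mels, frames) spectrogram into non-overlapping
        # patch_size patches and flatten each one, matching the decoder output
        # layout of (batch, num_patches, patch_height * patch_width).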
|
|
return self.unfold(x.unsqueeze(1)).transpose(-2, -1) |
|
|
|
|
|
|
|
|
def dasheng_base(**kwargs): |
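    """Base configuration: 12-layer, 768-dim encoder with an 8-layer, 512-dim decoder."""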
|
|
encoder_kwargs = dict(embed_dim=768, |
|
|
depth=12, |
|
|
num_heads=12, |
|
|
target_length=1008, |
|
|
patch_size=[64, 4], |
|
|
patch_stride=[64, 4]) |
|
|
encoder_kwargs.update( |
|
|
(k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs)) |
|
|
encoder_kwargs = {**encoder_kwargs, **kwargs} |
|
|
encoder = AudioTransformerMAE_Encoder(**encoder_kwargs) |
|
|
|
|
|
decoder_kwargs = dict(embed_dim=512, |
|
|
depth=8, |
|
|
num_heads=16, |
|
|
input_dim=encoder_kwargs['embed_dim'], |
|
|
outputdim=encoder.patch_embed.patch_size[0] * |
|
|
encoder.patch_embed.patch_size[1], |
|
|
num_patches=encoder.patch_embed.num_patches) |
|
|
decoder = AudioTransformerMAE_Decoder(**decoder_kwargs) |
|
|
return AudioTransformerMAE(encoder, decoder) |
|
|
|
|
|
|
|
|
def dasheng_06B(**kwargs): |
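    """0.6B-scale configuration: 24-layer, 1536-dim encoder, decoder as in dasheng_base."""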
|
|
encoder_kwargs = dict( |
|
|
patch_size=[64, 4], |
|
|
patch_stride=[64, 4], |
|
|
embed_dim=1536, |
|
|
depth=24, |
|
|
num_heads=24, |
|
|
mlp_ratio=4, |
|
|
) |
|
|
encoder_kwargs.update( |
|
|
(k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs)) |
|
|
encoder_kwargs = {**encoder_kwargs, **kwargs} |
|
|
encoder = AudioTransformerMAE_Encoder(**encoder_kwargs) |
|
|
|
|
|
decoder_kwargs = dict(embed_dim=512, |
|
|
depth=8, |
|
|
num_heads=16, |
|
|
input_dim=encoder_kwargs['embed_dim'], |
|
|
outputdim=encoder.patch_embed.patch_size[0] * |
|
|
encoder.patch_embed.patch_size[1], |
|
|
num_patches=encoder.patch_embed.num_patches) |
|
|
decoder = AudioTransformerMAE_Decoder(**decoder_kwargs) |
|
|
return AudioTransformerMAE(encoder, decoder) |
|
|
|
|
|
|
|
|
def dasheng_12B(**kwargs): |
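    """1.2B-scale configuration: 40-layer, 1536-dim encoder with a wider 768-dim decoder."""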
|
|
encoder_kwargs = dict( |
|
|
patch_size=[64, 4], |
|
|
patch_stride=[64, 4], |
|
|
embed_dim=1536, |
|
|
depth=40, |
|
|
num_heads=24, |
|
|
mlp_ratio=4, |
|
|
) |
|
|
encoder_kwargs.update( |
|
|
(k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs)) |
|
|
encoder_kwargs = {**encoder_kwargs, **kwargs} |
|
|
encoder = AudioTransformerMAE_Encoder(**encoder_kwargs) |
|
|
|
|
|
decoder_kwargs = dict(embed_dim=768, |
|
|
depth=8, |
|
|
num_heads=24, |
|
|
input_dim=encoder_kwargs['embed_dim'], |
|
|
outputdim=encoder.patch_embed.patch_size[0] * |
|
|
encoder.patch_embed.patch_size[1], |
|
|
num_patches=encoder.patch_embed.num_patches) |
|
|
decoder = AudioTransformerMAE_Decoder(**decoder_kwargs) |
|
|
return AudioTransformerMAE(encoder, decoder) |
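

# Minimal usage sketch (illustrative only, not part of the training pipeline):
# build the base model and run one masked forward pass on a batch of 16 kHz
# waveforms, the sample rate assumed by the MelSpectrogram front end above.
if __name__ == '__main__':
    model = dasheng_base()
    waveform = torch.randn(2, 10 * 16000)  # two 10-second clips at 16 kHz
    pred, targets, mask = model(waveform, mask_ratio=0.75)
    print(pred.shape, targets.shape, mask.shape)
    loss = model(waveform, mask_ratio=0.75, return_loss=True)
    print(f'reconstruction loss: {loss.item():.4f}')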
|
|
|