# FlexSED: src/models/.ipynb_checkpoints/dasheng-checkpoint.py
# (Hugging Face upload by OpenSound, "Upload 544 files", commit 3b6a091; 25 kB)
import torch
import torch.nn as nn
from einops import rearrange
import torch
from torch.cuda.amp import autocast
from functools import partial
from typing import Optional, Tuple, Union
import torchaudio.transforms as audio_transforms
from einops import rearrange
from einops.layers.torch import Rearrange
from itertools import repeat
import collections
import torch.nn.functional as F
import einops
# Choose the attention backend once, at import time: the fused
# F.scaled_dot_product_attention kernel when this torch build provides it,
# otherwise the explicit softmax(QK^T)V computation in Attention.forward.
ATTENTION_MODE = ('flash'
                  if hasattr(nn.functional, 'scaled_dot_product_attention')
                  else 'math')
print(f'attention mode is {ATTENTION_MODE}')
def _ntuple(n):
    """Return a converter that normalises a value to a tuple.

    Non-string iterables are converted with ``tuple`` as-is (their length is
    not checked against ``n``). Any other value, including a string, is
    repeated ``n`` times.
    """

    def parse(x) -> Tuple:
        iterable_non_str = (isinstance(x, collections.abc.Iterable)
                            and not isinstance(x, str))
        if not iterable_non_str:
            return tuple(repeat(x, n))
        return tuple(x)

    return parse


# Converter for 2-element sizes such as convolution kernel/stride shapes.
to_2tuple = _ntuple(2)
class MAELoss(torch.nn.Module):
    """Mean-squared reconstruction error averaged over masked patches.

    Args:
        norm_pix_loss: How the target is normalised before comparison.
            ``True``: per patch, with the mean and (unbiased) variance over
            the last dimension. ``'global'``: with the mean and variance of
            the whole target tensor. Any other value (e.g. ``False``): the
            raw target is used.
    """

    def __init__(self, norm_pix_loss: Union[bool, str] = True):
        super().__init__()
        self.norm_pix_loss = norm_pix_loss

    @autocast(enabled=False)
    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Compute the masked reconstruction loss.

        Args:
            pred: Predicted patch values; the last dimension holds the values
                of one patch.
            target: Target patch values, same shape as ``pred``.
            mask: Per-patch weights shaped like ``pred`` without its last
                dimension; 1 marks a removed (to be reconstructed) patch,
                0 a visible one.

        Returns:
            Scalar tensor: per-patch MSE summed with ``mask`` as weights and
            divided by ``mask.sum()``. If ``mask`` sums to zero (nothing was
            masked), returns 0 instead of NaN.
        """
        if self.norm_pix_loss is True:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5
        elif self.norm_pix_loss == 'global':
            mean = target.mean()
            var = target.var()
            target = (target - mean) / (var + 1.e-6)**.5
        loss = (pred - target)**2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        masked_total = (loss * mask).sum()
        mask_total = mask.sum()
        # Guard 0/0 (e.g. mask_ratio=0) without a device->host sync: for a
        # 0/1 mask, masked_total is also 0 whenever mask_total is 0, so
        # dividing by 1 yields a loss of 0 instead of NaN.
        safe_total = torch.where(mask_total == 0,
                                 torch.ones_like(mask_total), mask_total)
        return masked_total / safe_total  # mean loss on removed patches
class AudioPatchEmbed(nn.Module):
    """Cut a ``(b, in_chans, freq, time)`` spectrogram into patch embeddings.

    Patches are produced by a ``Conv2d`` with ``kernel_size=patch_size`` and
    ``stride=patch_stride``. With ``flatten=True`` the output is
    ``(b, num_tokens, embed_dim)``; otherwise the conv layout
    ``(b, embed_dim, f, t)`` is kept.

    Args:
        input_size: Nominal ``(freq, time)`` input size, used only to compute
            ``grid_size`` / ``num_patches``.
        patch_size: Conv kernel size.
        patch_stride: Conv stride.
        in_chans: Number of input channels.
        embed_dim: Output embedding width.
        norm_layer: Optional norm constructor, called as
            ``norm_layer(embed_dim)``; it is always applied over the
            embedding axis.
        flatten: Whether to flatten the ``(f, t)`` grid into a token axis.
    """

    def __init__(self,
                 input_size: Union[int, Tuple[int, int]] = (64, 100),
                 patch_size: Tuple[int, int] = (64, 4),
                 patch_stride: Tuple[int, int] = (64, 4),
                 in_chans=1,
                 embed_dim=768,
                 norm_layer=None,
                 flatten=False):
        super().__init__()
        self.input_size: Tuple[int, int] = to_2tuple(input_size)
        self.patch_size: Tuple[int, int] = to_2tuple(patch_size)
        self.patch_stride: Tuple[int, int] = to_2tuple(patch_stride)
        # NOTE(review): this equals the conv's output grid only when
        # patch_size == patch_stride; with larger kernels it over-counts
        # (callers slice positional embeddings, so that is tolerated).
        self.grid_size = (self.input_size[0] // self.patch_stride[0],
                          self.input_size[1] // self.patch_stride[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=self.patch_size,
                              stride=self.patch_stride)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = rearrange(x, 'b c f t -> b (f t) c')
            return self.norm(x)
        if isinstance(self.norm, nn.Identity):
            return x
        # The norm was built for embed_dim, which sits on dim 1 here: move it
        # last, normalise, and move it back (previously the norm ran over the
        # time axis).
        x = rearrange(x, 'b c f t -> b f t c')
        x = self.norm(x)
        return rearrange(x, 'b f t c -> b c f t')
class LayerScale(nn.Module):
    """Multiply the last dimension of the input by a learnable vector ``gamma``.

    Args:
        dim: Size of the last input dimension (length of ``gamma``).
        init_values: Initial value of every entry of ``gamma``.
        inplace: If True, scale the input tensor in place via ``mul_``.
    """

    def __init__(self, dim: int, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)
class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, dim)`` sequence.

    The backend is selected by the module-level ``ATTENTION_MODE``:
    ``'flash'`` calls ``F.scaled_dot_product_attention``; ``'math'`` computes
    ``softmax(q @ k^T * scale) @ v`` explicitly.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        attn_drop: Dropout probability on attention weights (training only).
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # The fused kernel takes a probability, not a Dropout module.
        self.attn_drop_p = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(
            0)  # make torchscript happy (cannot use tensor as tuple)
        if ATTENTION_MODE == 'flash':
            # scaled_dot_product_attention applies dropout_p whenever it is
            # non-zero (it has no train/eval notion); gate it on
            # self.training so eval matches nn.Dropout in the 'math' path.
            dropout_p = self.attn_drop_p if self.training else 0.
            x = F.scaled_dot_product_attention(q, k, v,
                                               dropout_p=dropout_p,
                                               scale=self.scale)
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        else:
            raise ValueError(f'Unsupported ATTENTION_MODE: {ATTENTION_MODE!r}')
        # (B, heads, N, head_dim) -> (B, N, heads * head_dim)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features: Input width.
        hidden_features: Hidden width (falls back to ``in_features`` if falsy).
        out_features: Output width (falls back to ``in_features`` if falsy).
        act_layer: Activation class, instantiated with no arguments.
        drop: Dropout probability, shared by both dropout applications.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Block(nn.Module):
    """Pre-norm Transformer layer with optional LayerScale on both residuals.

    ``x = x + ls1(attn(norm1(x)))`` followed by
    ``x = x + ls2(mlp(norm2(x)))``. ``attention_type`` names a class in this
    module's globals (e.g. ``'Attention'``); ``ls1``/``ls2`` are
    :class:`LayerScale` when ``init_values`` is truthy, otherwise identity.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        drop=0.,
        attn_drop=0.,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attention_type='Attention',
    ):
        super().__init__()

        def make_layer_scale():
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        self.norm1 = norm_layer(dim)
        attention_cls = globals()[attention_type]
        self.attn = attention_cls(dim,
                                  num_heads=num_heads,
                                  qkv_bias=qkv_bias,
                                  attn_drop=attn_drop,
                                  proj_drop=drop)
        self.ls1 = make_layer_scale()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)
        self.ls2 = make_layer_scale()

    def forward(self, x):
        attn_branch = self.ls1(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.ls2(self.mlp(self.norm2(x)))
        return x + mlp_branch
class AudioTransformerMAE_Encoder(nn.Module):
    """Masked-autoencoder Transformer encoder operating on 16 kHz audio.

    ``forward`` runs: waveform -> dB-scaled mel spectrogram (``front_end``)
    -> per-mel-bin BatchNorm (``init_bn``) -> patch embedding
    (``patch_embed``) plus separate time and frequency positional embeddings
    -> random token masking -> optional prepended class token (``pooling ==
    'token'``) -> ``depth`` Transformer blocks -> final norm.

    Extra keyword arguments read from ``**kwargs``:
        top_db: ``top_db`` of the ``AmplitudeToDB`` stage (default 120).
        group_masking: ``False``/``None`` -> :meth:`random_masking`;
            ``True`` -> :meth:`random_masking_group` with the default group
            factor; an ``int`` -> :meth:`random_masking_group` with that group
            factor. Any other type raises ``TypeError``.
    Other extra keyword arguments are ignored.
    """

    def __init__(self,
                 patch_size: Tuple[int, int] = (64, 4),
                 patch_stride: Tuple[int, int] = (64, 4),
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads=8,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 norm_layer=None,
                 act_layer=None,
                 init_values=None,
                 target_length=1008,
                 pooling='mean',
                 time_patch_out: Optional[float] = None,
                 freq_patch_out: Optional[float] = None,
                 block_type='Block',
                 attention_type='Attention',
                 eval_avg='cat',
                 n_fft: int = 512,
                 n_mels: int = 64,
                 hop_size: int = 160,
                 win_size: int = 512,
                 f_min: int = 0,
                 f_max: int = 8000,
                 center: bool = True,
                 **kwargs):
        super().__init__()
        self.pooling = pooling
        self.embed_dim = embed_dim
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.n_mels = n_mels
        # Stored for callers; not used inside this class.
        self.eval_avg = eval_avg
        self.time_patch_out = time_patch_out
        self.freq_patch_out = freq_patch_out
        # Waveform -> dB mel spectrogram (sample rate fixed at 16 kHz).
        self.front_end = nn.Sequential(
            audio_transforms.MelSpectrogram(f_min=f_min,
                                            sample_rate=16000,
                                            win_length=win_size,
                                            center=center,
                                            n_fft=n_fft,
                                            f_max=f_max,
                                            hop_length=hop_size,
                                            n_mels=self.n_mels),
            audio_transforms.AmplitudeToDB(top_db=kwargs.get('top_db', 120)))
        # BatchNorm per mel bin: move the mel axis into the channel slot,
        # normalise, then move it back.
        self.init_bn = nn.Sequential(
            Rearrange('b c f t -> b f c t'),
            nn.BatchNorm2d(self.n_mels, momentum=0.01),
            Rearrange('b f c t -> b c f t'))
        self.target_length = target_length
        self.patch_embed = AudioPatchEmbed(input_size=(self.n_mels,
                                                       target_length),
                                           embed_dim=self.embed_dim,
                                           patch_size=self.patch_size,
                                           flatten=False,
                                           patch_stride=self.patch_stride)
        self.num_patches = self.patch_embed.num_patches
        if pooling == 'token':
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.token_pos_embed = nn.Parameter(
                torch.randn(1, embed_dim) * .02)
        # Factorised positional embeddings: one over the time grid, one over
        # the frequency grid, broadcast-added to the (b, c, f, t) patches.
        self.time_pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, 1, self.patch_embed.grid_size[1]) * .02)
        self.freq_pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, self.patch_embed.grid_size[0], 1) * .02)
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Resolve the block class by name among this module's globals.
        block_function = globals()[block_type]
        self.blocks = nn.Sequential(*[
            block_function(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                act_layer=act_layer,
                attention_type=attention_type,
            ) for _ in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.apply(self.init_weights)
        if hasattr(self, 'cls_token') and self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        group_masking = kwargs.get('group_masking', False)
        if group_masking is None or group_masking is False:
            self.masking_func = self.random_masking
        elif group_masking is True:
            self.masking_func = self.random_masking_group
        elif isinstance(group_masking, int):
            self.masking_func = partial(self.random_masking_group,
                                        group_factor=group_masking)
        else:
            # Previously masking_func was silently left unset here, which
            # only surfaced later as an AttributeError in forward_features.
            raise TypeError('group_masking must be a bool, an int or None, '
                            f'got {type(group_masking).__name__}')

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters intended to be excluded from weight decay."""
        return {
            'time_pos_embed', 'cls_token', 'freq_pos_embed', 'token_pos_embed'
        }

    def init_weights(self, module):
        """Xavier-uniform Linear weights / zero biases; LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def random_masking_group(self, x, mask_ratio, group_factor: int = 2):
        """Randomly mask tokens in groups of ``group_factor`` consecutive indices.

        Like :meth:`random_masking`, but the random permutation is drawn over
        groups of consecutive token indices, so the tokens of a group are
        kept or removed together (except possibly the group straddling the
        ``len_keep`` cut). ``L`` must be divisible by ``group_factor`` for the
        ``view`` below to succeed.

        Args:
            x: ``[N, L, D]`` token sequence.
            mask_ratio: Fraction of tokens to remove.
            group_factor: Number of consecutive tokens per group.

        Returns:
            ``(x_masked, mask, ids_restore)``: kept tokens
            ``[N, len_keep, D]``, binary mask ``[N, L]`` in original order
            (0 is keep, 1 is remove), and indices that undo the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L // group_factor,
                           device=x.device)  # noise in [0, 1]
        # Row g of `indices` lists the token ids belonging to group g.
        indices = torch.arange(L, device=x.device).view(-1, group_factor)
        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        # Expand shuffled group ids into shuffled token ids: [N, L].
        ids_shuffle = indices[ids_shuffle].flatten(-2)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x,
                                dim=1,
                                index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    def random_masking(self, x, mask_ratio):
        """Per-sample random masking by argsort of uniform noise.

        Args:
            x: ``[N, L, D]`` token sequence.
            mask_ratio: Fraction of tokens to remove.

        Returns:
            ``(x_masked, mask, ids_restore)``: kept tokens
            ``[N, len_keep, D]``, binary mask ``[N, L]`` in original order
            (0 is keep, 1 is remove), and indices that undo the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x,
                                dim=1,
                                index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    def forward_features(self, x, mask_ratio):
        """Embed, mask and encode a spectrogram batch.

        Args:
            x: Spectrogram batch in ``(b, 1, freq, time)`` layout, as produced
                by :meth:`forward_to_spec`.
            mask_ratio: Fraction of patch tokens to remove.

        Returns:
            ``(tokens, mask, ids_restore)`` from the selected masking function,
            with ``tokens`` passed through the Transformer blocks and final
            norm (a class token is prepended first when ``pooling ==
            'token'``).
        """
        x = self.patch_embed(x)
        b, c, f, t = x.shape
        # Slice the time embedding to the actual number of time patches.
        x = x + self.time_pos_embed[:, :, :, :t]
        # The frequency embedding is added over its full extent.
        x = x + self.freq_pos_embed[:, :, :, :]
        x = rearrange(x, 'b c f t -> b (f t) c')
        x, mask, ids_restore = self.masking_func(x, mask_ratio)
        if self.pooling == 'token':
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            cls_token = cls_token + self.token_pos_embed[:, :]
            x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x, mask, ids_restore

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load weights, resizing positional embeddings that do not fit.

        If either ``time_pos_embed`` or ``freq_pos_embed`` in ``state_dict``
        differs in shape from this model's, the entries are resized **in
        place in the given dict** via :meth:`change_pos_embedding` before
        delegating to ``nn.Module.load_state_dict``, whose result is returned.
        """
        needs_resize = any(
            name in state_dict
            and getattr(self, name).shape != state_dict[name].shape
            for name in ('time_pos_embed', 'freq_pos_embed'))
        if needs_resize:
            print(
                "Positional Embedding shape not the same with model, resizing!"
            )
            self.change_pos_embedding(state_dict)
        return super().load_state_dict(state_dict, strict=strict, **kwargs)

    def change_pos_embedding(self, state_dict):
        """Resize pretrained positional embeddings in ``state_dict`` in place.

        Each embedding present in ``state_dict`` is truncated when the model's
        grid is not longer than the pretrained one, and bilinearly
        interpolated to the model's length otherwise. Time and frequency
        embeddings are handled independently (the time branch previously
        compared against the frequency target length by mistake).
        """
        if 'time_pos_embed' in state_dict:
            target_time_len = self.time_pos_embed.shape[-1]
            pretrained_time = state_dict['time_pos_embed']
            if target_time_len <= pretrained_time.shape[-1]:
                state_dict['time_pos_embed'] = pretrained_time[
                    ..., :target_time_len]
            else:
                state_dict['time_pos_embed'] = F.interpolate(
                    pretrained_time,
                    size=(1, target_time_len),
                    align_corners=False,
                    mode='bilinear')
        if 'freq_pos_embed' in state_dict:
            target_freq_len = self.freq_pos_embed.shape[-2]
            pretrained_freq = state_dict['freq_pos_embed']
            if target_freq_len <= pretrained_freq.shape[-2]:
                state_dict['freq_pos_embed'] = pretrained_freq[
                    :, :, :target_freq_len, :]
            else:
                state_dict['freq_pos_embed'] = F.interpolate(
                    pretrained_freq,
                    size=(target_freq_len, 1),
                    align_corners=False,
                    mode='bilinear')

    def forward_to_spec(self, x):
        """Convert a waveform batch to a normalised ``(b, 1, n_mels, frames)`` spectrogram.

        ``x`` is presumably a ``(batch, samples)`` waveform (the front-end
        output is rearranged from ``b f t``); confirm against callers.
        """
        # Do not use fp16 for feature extraction, that is likely to get nan
        with autocast(enabled=False):
            X = self.front_end(x)
        X = rearrange(X, 'b f t -> b 1 f t')
        X = self.init_bn(X)
        return X

    def forward(self, x, mask_ratio: float = 0.75):
        """Return ``(encoded_tokens, mask, restore_idxs)`` for a waveform batch."""
        x = self.forward_to_spec(x)
        x, mask, restore_idxs = self.forward_features(x, mask_ratio=mask_ratio)
        return x, mask, restore_idxs
class AudioTransformerMAE_Decoder(nn.Module):
    """MAE decoder: re-inserts mask tokens, unshuffles and predicts patch values.

    Encoder tokens are projected to ``embed_dim`` and padded with a learned
    mask token. They are put back into original patch order using
    ``ids_restore``, get a learned positional embedding, run through
    ``depth`` :class:`Block` layers, and are mapped to ``outputdim`` values
    per token.

    NOTE(review): ``pos_embed`` holds ``num_patches + 1`` positions when
    ``cls_token`` is False, but only ``num_patches`` when it is True. The
    forward pass, however, produces ``L + 1`` tokens exactly in the
    ``cls_token`` case, so a full-length input then overruns ``pos_embed``.
    This looks inverted. It is left unchanged because fixing it would alter
    the parameter shape of existing checkpoints.
    """

    def __init__(self,
                 input_dim: int,
                 outputdim: int,
                 patch_size: int = 16,
                 patch_stride: int = 16,
                 embed_dim: int = 768,
                 num_patches: int = 100,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 norm_layer: Optional[torch.nn.Module] = None,
                 act_layer: Optional[torch.nn.Module] = None,
                 cls_token: bool = False,
                 attention_type='Attention',
                 init_values=None,
                 **kwargs):
        super().__init__()
        # patch_size / patch_stride are stored but not used in this class.
        self.embed_dim = embed_dim
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.input_dim = input_dim
        # Keep the construction order below stable: it fixes the order in
        # which parameter initialisation consumes the global RNG.
        self.input_proj = nn.Linear(input_dim, embed_dim)
        self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim) * .02)
        norm_cls = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_cls = act_layer or nn.GELU
        self.use_cls = cls_token
        n_positions = num_patches if cls_token else num_patches + 1
        self.pos_embed = nn.Parameter(torch.zeros(1, n_positions, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        layers = [
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_cls,
                act_layer=act_cls,
                attention_type=attention_type,
            ) for _ in range(depth)
        ]
        self.blocks = nn.Sequential(*layers)
        self.norm = norm_cls(embed_dim)
        self.outputlayer = nn.Linear(self.embed_dim, outputdim)
        self.apply(self.init_weights)
        torch.nn.init.normal_(self.mask_token, std=.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names intended to be excluded from weight decay.

        NOTE(review): these names mirror the encoder's. None of them is an
        attribute of this decoder, whose own embeddings are ``pos_embed``
        and ``mask_token``. If callers match these names against parameter
        names, nothing in the decoder is exempted; verify the intent.
        """
        return {
            'time_pos_embed', 'cls_token', 'freq_pos_embed', 'token_pos_embed'
        }

    def init_weights(self, module):
        """Truncated-normal (std 0.02) Linear weights / zero biases; LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward_features(self, x, ids_restore):
        """Restore the full token sequence and run the decoder blocks.

        Args:
            x: Encoder output tokens (with a leading class token when
                ``use_cls``).
            ids_restore: Indices that unshuffle the masked sequence, shape
                ``[batch, L]``.
        """
        x = self.input_proj(x)
        n_batch = x.shape[0]
        n_tokens = x.shape[1]
        width = x.shape[2]
        # The +1 accounts for a class token. Without one it adds a single
        # surplus mask token, which the gather below never selects as long
        # as ids_restore only indexes 0..L-1.
        fillers = self.mask_token.repeat(
            n_batch, ids_restore.shape[1] + 1 - n_tokens, 1)
        patch_tokens = x[:, 1:, :] if self.use_cls else x[:, :, :]
        padded = torch.cat([patch_tokens, fillers], dim=1)
        gather_index = ids_restore.unsqueeze(-1).repeat(1, 1, width)
        restored = torch.gather(padded, dim=1, index=gather_index)  # unshuffle
        if self.use_cls:
            restored = torch.cat([x[:, :1, :], restored], dim=1)  # prepend cls
        seq_len = restored.shape[1]
        restored = restored + self.pos_embed[:, :seq_len, :]
        restored = self.pos_drop(restored)
        restored = self.blocks(restored)
        return self.norm(restored)

    def forward(self, x, restore_idxs):
        """Return per-token predictions of width ``outputdim``."""
        decoded = self.forward_features(x, restore_idxs)
        return self.outputlayer(decoded)
class AudioTransformerMAE(nn.Module):
def __init__(self,
encoder: AudioTransformerMAE_Encoder,
decoder: AudioTransformerMAE_Decoder,
loss_fn: Optional[torch.nn.Module] = None):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.unfold = nn.Unfold(
kernel_size=self.encoder.patch_embed.patch_size,
stride=self.encoder.patch_embed.patch_size)
self.loss_fn = MAELoss() if loss_fn is None else loss_fn
def forward(self,
x: torch.Tensor,
mask_ratio: float = 0.75,
return_loss: bool = False):
latent, mask, restore_ids = self.encoder(x, mask_ratio=mask_ratio)
pred = self.decoder(latent, restore_ids)
with autocast(enabled=False):
targets = self.encoder.front_end(x)
targets = self.patchify(targets)
if return_loss:
return self.loss_fn(pred, targets, mask)
return pred, targets, mask
def patchify(self, x):
return self.unfold(x.unsqueeze(1)).transpose(-2, -1)
def _build_dasheng_mae(encoder_defaults, decoder_dims, kwargs):
    """Build an AudioTransformerMAE from encoder defaults, user kwargs and decoder sizes.

    ``kwargs`` override ``encoder_defaults``, and every key in ``kwargs`` is
    forwarded to the encoder. The decoder's input width, output width and
    patch count are derived from the built encoder.
    """
    encoder_kwargs = {**encoder_defaults, **kwargs}
    encoder = AudioTransformerMAE_Encoder(**encoder_kwargs)
    patch_embed = encoder.patch_embed
    decoder = AudioTransformerMAE_Decoder(
        input_dim=encoder_kwargs['embed_dim'],
        outputdim=patch_embed.patch_size[0] * patch_embed.patch_size[1],
        num_patches=patch_embed.num_patches,
        **decoder_dims)
    return AudioTransformerMAE(encoder, decoder)


def dasheng_base(**kwargs):
    """Base MAE: 768-d, 12-layer, 12-head encoder; 512-d, 8-layer, 16-head decoder.

    Keyword arguments override the encoder defaults and are all forwarded to
    the encoder.
    """
    return _build_dasheng_mae(
        dict(embed_dim=768,
             depth=12,
             num_heads=12,
             target_length=1008,
             patch_size=[64, 4],
             patch_stride=[64, 4]),
        dict(embed_dim=512, depth=8, num_heads=16),
        kwargs)


def dasheng_06B(**kwargs):
    """MAE with a 1536-d, 24-layer, 24-head encoder and a 512-d, 8-layer, 16-head decoder.

    Keyword arguments override the encoder defaults and are all forwarded to
    the encoder.
    """
    return _build_dasheng_mae(
        dict(patch_size=[64, 4],
             patch_stride=[64, 4],
             embed_dim=1536,
             depth=24,
             num_heads=24,
             mlp_ratio=4),
        dict(embed_dim=512, depth=8, num_heads=16),
        kwargs)


def dasheng_12B(**kwargs):
    """MAE with a 1536-d, 40-layer, 24-head encoder and a 768-d, 8-layer, 24-head decoder.

    Keyword arguments override the encoder defaults and are all forwarded to
    the encoder.
    """
    return _build_dasheng_mae(
        dict(patch_size=[64, 4],
             patch_stride=[64, 4],
             embed_dim=1536,
             depth=40,
             num_heads=24,
             mlp_ratio=4),
        dict(embed_dim=768, depth=8, num_heads=24),
        kwargs)