# --------------------------------------------------------
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
# Github source: https://github.com/microsoft/unilm/tree/master/beats
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import logging
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import LayerNorm
import torchaudio.compliance.kaldi as ta_kaldi

from beats.backbone import (
    TransformerEncoder,
)

logger = logging.getLogger(__name__)


class BEATsConfig:
    """Hyperparameters for a BEATs model.

    Every attribute starts at the default assigned in ``__init__``; entries in
    the optional ``cfg`` dict then overwrite (or add) attributes via ``update``.
    """

    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # patch size of the patch embedding (used as conv kernel size and stride)
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # enable deep_norm in the transformer (mutually exclusive with layer_norm_first)

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model (adds the classification head)
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Overwrite attributes with the key/value pairs in ``cfg``.

        Keys are not validated: unknown keys simply become new attributes.
        """
        self.__dict__.update(cfg)


class BEATs(nn.Module):
    """BEATs audio encoder, optionally with a clip-level classification head.

    Pipeline (see ``extract_features``): Kaldi fbank features -> Conv2d patch
    embedding -> LayerNorm -> optional linear projection -> transformer encoder
    -> (fine-tuned models only) linear predictor pooled over time + sigmoid.
    """

    def __init__(
        self,
        cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        # Lazy %-formatting: the dict is only stringified if INFO is enabled.
        logger.info("BEATs Config: %s", cfg.__dict__)

        self.cfg = cfg

        self.embed = cfg.embed_dim
        # Project patch embeddings to the encoder width only when the widths differ.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        # kernel_size == stride: non-overlapping square patches over the fbank "image".
        self.patch_embedding = nn.Conv2d(
            1,
            self.embed,
            kernel_size=self.input_patch_size,
            stride=self.input_patch_size,
            bias=cfg.conv_bias,
        )

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first, (
            "deep_norm and layer_norm_first cannot both be enabled"
        )
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            # predictor_dropout only exists on fine-tuned models; it is used
            # solely on the predictor path of extract_features.
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Downsample ``padding_mask`` to the time resolution of ``features``.

        Trailing mask entries that do not divide evenly into
        ``features.size(1)`` groups are dropped. The rest is reshaped into
        ``(batch, features.size(1), -1)`` consecutive groups, and a position is
        marked as padding only if every entry in its group is padding.

        Args:
            features: tensor whose ``size(1)`` is the target number of steps.
            padding_mask: 2-D ``(batch, length)`` mask, assumed boolean, where
                True marks padding.

        Returns:
            Boolean mask of shape ``(batch, features.size(1))``.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        return padding_mask.all(-1)

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Convert a batch of waveforms into normalized Kaldi fbank features.

        Each waveform is scaled by 2**15 (presumably mapping float audio in
        [-1, 1] to the int16 range Kaldi-style features expect — TODO confirm).
        It is then turned into a 128-bin filterbank with 25 ms frames and a
        10 ms shift, using a hard-coded 16 kHz sample rate. The stacked result
        is normalized as ``(fbank - fbank_mean) / (2 * fbank_std)``.

        Args:
            source: batch of waveforms iterated along dim 0; assumed shape
                ``(batch, num_samples)`` already sampled at 16 kHz. All
                waveforms must yield the same number of frames, otherwise
                ``torch.stack`` fails.
            fbank_mean: mean subtracted from the filterbank features.
            fbank_std: features are divided by twice this value.

        Returns:
            Tensor of shape ``(batch, num_frames, 128)``.
        """
        fbanks = [
            ta_kaldi.fbank(
                waveform.unsqueeze(0) * 2 ** 15,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10,
            )
            for waveform in source
        ]
        fbank = torch.stack(fbanks, dim=0)
        return (fbank - fbank_mean) / (2 * fbank_std)

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
        feature_only: bool = False,
    ):
        """Run the full BEATs forward pass on raw waveforms.

        Args:
            source: batch of waveforms (see ``preprocess`` for assumptions).
            padding_mask: optional sample-level mask, True marking padding.
                It is downsampled first to fbank frames, then to patch tokens.
            fbank_mean: forwarded to ``preprocess``.
            fbank_std: forwarded to ``preprocess``.
            feature_only: when True, skip the classification head even if the
                model has one.

        Returns:
            ``(x, padding_mask)`` with the encoder output ``x``, when
            ``feature_only`` is True or the model has no predictor. Otherwise
            ``(lprobs, padding_mask)``, where ``lprobs`` are per-class sigmoid
            probabilities of shape ``(batch, predictor_class)``. Despite the
            name, these are not log-probabilities. ``padding_mask`` is
            token-level, or None if none was given.
        """
        fbank = self.preprocess(
            source, fbank_mean=fbank_mean, fbank_std=fbank_std
        ).to(torch.float32)

        if padding_mask is not None:
            # Sample-level mask -> fbank-frame-level mask.
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # (B, T, F) -> (B, 1, T, F): add a channel axis for the Conv2d patch embedding.
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        # (B, C, T', F') -> (B, T' * F', C): flatten the patch grid into a token sequence.
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            # Frame-level mask -> token-level mask.
            # NOTE(review): this splits frames evenly across tokens in sequence
            # order. It does not follow the time x frequency patch layout
            # exactly, so tokens of a partially padded time patch may get
            # different mask values.
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, _layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if feature_only or self.predictor is None:
            return x, padding_mask

        x = self.predictor_dropout(x)
        logits = self.predictor(x)

        if padding_mask is not None and padding_mask.any():
            # Mean over non-padded tokens only. Padded positions are zeroed
            # out-of-place (also discarding any NaNs there). The valid-token
            # count is clamped to >= 1 so a fully padded sample yields 0
            # instead of 0/0 = NaN.
            logits = logits.masked_fill(padding_mask.unsqueeze(-1), 0)
            valid_counts = (~padding_mask).sum(dim=1).clamp(min=1)
            logits = logits.sum(dim=1) / valid_counts.unsqueeze(-1)
        else:
            logits = logits.mean(dim=1)

        lprobs = torch.sigmoid(logits)
        return lprobs, padding_mask