# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from einops import rearrange
from librosa import filters
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio


class ChromaExtractor(nn.Module):
    """Compute (optionally one-hot quantized) chroma features from a waveform.

    A power spectrogram of the input is projected onto a librosa chroma filter
    bank, each frame's chroma vector is normalized, and the result is returned
    time-major.

    Args:
        sample_rate (int): Sample rate of the incoming audio, used to build the
            chroma filter bank.
        n_chroma (int): Number of chroma bins. Defaults to 12.
        radix2_exp (int): Base-2 exponent of the default STFT window length
            (e.g. 12 -> 2 ** 12 samples); only used when ``winlen`` is not given.
        nfft (int, optional): FFT size. Falls back to the window length when
            not given (or falsy).
        winlen (int, optional): STFT window length. Falls back to
            ``2 ** radix2_exp`` when not given (or falsy).
        winhop (int, optional): STFT hop size. Falls back to a quarter of the
            window length when not given (or falsy).
        argmax (bool, optional): If True, keep only the strongest chroma bin of
            every frame as a one-hot vector. Defaults to False.
        norm (float, optional): Order of the norm used to normalize each
            frame's chroma vector. Defaults to inf.
    """
    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
                 nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None,
                 winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        # `or` (not `is None`) means an explicit 0 also falls back to the default.
        window_length = winlen or 2 ** radix2_exp
        self.winlen = window_length
        self.nfft = nfft or window_length
        self.winhop = winhop or (window_length // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax

        # The filter bank is derived entirely from constructor arguments, so it is
        # registered as a non-persistent buffer (follows .to()/.cuda() but is not
        # written to state_dict).
        chroma_bank = filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, n_chroma=self.n_chroma)
        self.register_buffer('fbanks', torch.from_numpy(chroma_bank), persistent=False)
        # power=2 -> power spectrogram.
        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.nfft,
            win_length=self.winlen,
            hop_length=self.winhop,
            power=2,
            center=True,
            pad=0,
            normalized=True,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract chroma features from ``wav``.

        Args:
            wav (torch.Tensor): Waveform with time on the last axis. Presumably
                shaped [B, 1, T] or [B, T] (the spectrogram must be 3-D after
                ``squeeze(1)``) -- TODO confirm against callers.

        Returns:
            torch.Tensor: Tensor of shape [B, frames, chroma bins]. With
            ``argmax`` enabled each frame is one-hot on its strongest bin.
        """
        num_samples = wav.shape[-1]
        # A wav nullified (dropped out) by the conditioner may be too short for
        # the STFT: zero-pad it symmetrically so it is no shorter than nfft,
        # giving any odd leftover sample to the right side.
        if num_samples < self.nfft:
            missing = self.nfft - num_samples
            left = missing // 2
            wav = F.pad(wav, (left, missing - left), mode='constant', value=0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        power_spec = self.spec(wav).squeeze(1)
        # Contract the frequency axis against the filter bank: [..., F, T] -> [..., C, T].
        chroma = torch.einsum('cf,...ft->...ct', self.fbanks, power_spec)
        chroma = F.normalize(chroma, p=self.norm, dim=-2, eps=1e-6)
        chroma = rearrange(chroma, 'b d t -> b t d')

        if not self.argmax:
            return chroma

        # Quantize in place: zero everything, then set the winning bin of each frame to 1.
        winners = chroma.argmax(-1, keepdim=True)
        chroma.zero_()
        chroma.scatter_(dim=-1, index=winners, value=1)
        return chroma