from typing import List, Tuple
import numpy as np
import librosa
import torch
import torch.nn.functional as F
from s3tokenizer.utils import padding
from s3tokenizer.model_v2 import (
S3TokenizerV2,
ModelConfig,
)
# Sampling rate of the inputs to S3TokenizerV2
S3_SR = 16_000
S3_HOP = 160 # 100 frames/sec
S3_TOKEN_HOP = 640 # 25 tokens/sec
S3_TOKEN_RATE = 25
SPEECH_VOCAB_SIZE = 6561
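
# Derived relationships between the constants above: one speech token spans
# S3_TOKEN_HOP = 640 samples (40 ms at 16 kHz), i.e. S3_TOKEN_HOP / S3_HOP = 4
# mel frames per token (100 frames/sec vs. 25 tokens/sec).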
class S3Tokenizer(S3TokenizerV2):
"""
s3tokenizer.S3TokenizerV2 with the following changes:
- a more integrated `forward`
    - computes `log_mel_spectrogram` using `_mel_filters` and `window`
      registered as buffers (so they follow the module across devices)
"""
ignore_state_dict_missing = ("_mel_filters", "window")
def __init__(
self,
        name: str = "speech_tokenizer_v2_25hz",
config: ModelConfig = ModelConfig()
):
super().__init__(name)
self.n_fft = 400
_mel_filters = librosa.filters.mel(
sr=S3_SR,
n_fft=self.n_fft,
n_mels=config.n_mels
)
self.register_buffer(
"_mel_filters",
torch.FloatTensor(_mel_filters),
)
self.register_buffer(
"window",
torch.hann_window(self.n_fft),
)
def pad(self, wavs, sr) -> List[torch.Tensor]:
"""
        Given a list of wavs with the same sample rate `sr`, pad each one so
        that its length is a multiple of 40 ms (S3 runs at 25 tokens/sec).
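
        Example (an illustrative sketch; constructing `S3Tokenizer` builds the
        full model, so treat this as shape arithmetic rather than a cheap test):

        >>> tok = S3Tokenizer()
        >>> padded = tok.pad([np.zeros(15_900, dtype=np.float32)], sr=S3_SR)
        >>> padded[0].shape  # 15,900 samples rounded up to 25 * 640 = 16,000
        torch.Size([1, 16000])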
"""
processed_wavs = []
for wav in wavs:
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)
if wav.dim() == 1:
wav = wav.unsqueeze(0)
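            # Round the duration up to a whole number of 25 Hz tokens, then
            # pad the waveform with zeros up to the matching sample count.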
n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
n_tokens = np.ceil(n_tokens)
intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
intended_wav_len = int(intended_wav_len)
wav = torch.nn.functional.pad(
wav,
(0, intended_wav_len - wav.shape[-1]),
mode="constant",
value=0
)
processed_wavs.append(wav)
return processed_wavs
def _prepare_audio(self, wavs):
"""Prepare a list of audios for s3tokenizer processing."""
processed_wavs = []
for wav in wavs:
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)
if wav.dim() == 1:
wav = wav.unsqueeze(0)
processed_wavs.append(wav)
return processed_wavs
@torch.no_grad()
def forward(
self,
        wavs: List[torch.Tensor],
        accelerator: "Accelerator" = None,
        max_len: int = None,
) -> Tuple[torch.Tensor, torch.LongTensor]:
"""
        NOTE: mel-spec has a hop size of 160 samples (100 frames/sec).
        FIXME: this class inherits `nn.Module` but doesn't accept a batched
        `torch.Tensor`; it handles a list of wavs one by one, which is unexpected.

        Args
        ----
        - `wavs`: list of 16 kHz speech waveforms
        - `max_len`: maximum length to truncate the output token sequence to
          (at 25 tokens/sec).
          NOTE: please pad the waveform if a longer sequence is needed.
"""
processed_wavs = self._prepare_audio(wavs)
        mels = []
for wav in processed_wavs:
wav = wav.to(self.device)
mel = self.log_mel_spectrogram(wav) # [B=1, F, T]
if max_len is not None:
mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens
mels.append(mel.squeeze(0))
mels, mel_lens = padding(mels)
if accelerator is None:
tokenizer = self
else:
tokenizer = accelerator.unwrap_model(self)
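        # `quantize` maps the padded mel batch to discrete speech tokens
        # (one token per 4 mel frames, per the constants above).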
speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
return (
speech_tokens.long().detach(),
speech_token_lens.long().detach(),
)
def log_mel_spectrogram(
self,
audio: torch.Tensor,
padding: int = 0,
):
"""
        Compute the log-Mel spectrogram of the given audio.

        Parameters
        ----------
        audio: torch.Tensor, shape = (*)
            A NumPy array or Tensor containing the audio waveform at 16 kHz
        padding: int
            Number of zero samples to pad on the right

        Returns
        -------
        torch.Tensor, shape = (n_mels, n_frames)
            A Tensor that contains the log-Mel spectrogram
"""
if not torch.is_tensor(audio):
audio = torch.from_numpy(audio)
audio = audio.to(self.device)
if padding > 0:
audio = F.pad(audio, (0, padding))
stft = torch.stft(
audio, self.n_fft, S3_HOP,
window=self.window.to(self.device),
return_complex=True
)
        magnitudes = stft[..., :-1].abs() ** 2  # drop the last frame, as in Whisper's log-mel code
mel_spec = self._mel_filters.to(self.device) @ magnitudes
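        # Whisper-style log compression: floor at 1e-10 before log10, clip the
        # dynamic range to 8 (in log10 units) below the peak, then rescale.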
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
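

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): runs the
    # tokenizer with randomly initialized weights, so the token values are
    # meaningless; the point is only to check shapes and the 25 Hz token rate.
    tokenizer = S3Tokenizer()
    wav = torch.randn(1, S3_SR)  # 1 second of noise at 16 kHz
    wavs = tokenizer.pad([wav], sr=S3_SR)
    tokens, token_lens = tokenizer(wavs)
    print(tokens.shape, token_lens)  # expect 25 tokens (1 s * 25 tokens/sec)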