# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text (phoneme) and audio (neural-codec) tokenizers."""

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union

import numpy as np
import torch
import torchaudio

# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator

# One token is either a run of word characters (a phone such as "iː") or a
# single non-word, non-space character (punctuation or a separator symbol).
_PHONE_TOKEN_RE = re.compile(r"\w+|[^\w\s]", re.UNICODE)


class TextTokenizer:
    """Phonemize text with phonemizer's espeak backend and split it into tokens.

    Args:
        language: espeak language code, e.g. ``"en-us"``.
        backend: accepted for interface compatibility but not consulted;
            an ``EspeakBackend`` is always constructed.
        separator: phone / syllable / word separators passed to
            ``phonemize`` and used by :meth:`to_list` to split its output.
            The default instance is created once and shared by every
            tokenizer that does not pass its own.
        preserve_punctuation, punctuation_marks, with_stress, tie,
        language_switch, words_mismatch: forwarded unchanged to
            ``EspeakBackend``.
    """

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        # NOTE: `backend` is intentionally unused; only espeak is supported.
        phonemizer = EspeakBackend(
            language,
            punctuation_marks=punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )
        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Split one phonemized utterance into a flat list of tokens.

        Each word (delimited by ``separator.word``) contributes its phones and
        punctuation marks as separate tokens. ``separator.phone`` tokens are
        dropped, and consecutive words are joined by a ``separator.word``
        token. With the default separators, ``"h|ɪ|z_ɐ"`` becomes
        ``["h", "ɪ", "z", "_", "ɐ"]``.

        Assumes ``separator.phone`` is a single non-word character, so that
        the regex emits it as its own token. TODO confirm for custom
        separators.

        Raises:
            AssertionError: if the produced tokens do not account for every
                character of ``phonemized`` apart from phone separators
                (the check is skipped under ``python -O``).
        """
        fields = []
        for word in phonemized.split(self.separator.word):
            pp = _PHONE_TOKEN_RE.findall(word)
            fields.extend(
                [p for p in pp if p != self.separator.phone]
                + [self.separator.word]
            )
        # Sanity check: nothing except phone separators was lost or invented.
        # fields[:-1] drops the word separator appended after the last word.
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text: Union[str, List[str]], strip=True) -> List[List[str]]:
        """Phonemize ``text`` and return one token list per utterance.

        Args:
            text: a single string or a list of strings.
            strip: forwarded to ``EspeakBackend.phonemize``.

        Returns:
            One list of tokens (see :meth:`to_list`) per input utterance.
        """
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]


def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    """Phonemize one utterance (surrounding whitespace stripped) into tokens."""
    phonemes = tokenizer([text.strip()])
    return phonemes[0]


def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Match a waveform's channel count and sample rate to a codec's needs.

    Args:
        wav: channel-first waveform, presumably ``(channels, samples)`` as
            returned by ``torchaudio.load``. Must have 1 or 2 channels.
        sr: sample rate of ``wav``.
        target_sr: desired sample rate.
        target_channels: desired number of channels.

    Returns:
        The converted waveform. Mono-to-multichannel conversion uses
        ``expand``, so when no resampling happens the result may be a view
        sharing memory with ``wav``. Stereo input with
        ``target_channels > 2`` is returned with its 2 channels unchanged.

    Raises:
        AssertionError: if ``wav`` is neither mono nor stereo.
    """
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
    if target_channels == 1:
        # Downmix by averaging over the channel dimension.
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        # Broadcast to two channels (a no-op for input that is already stereo).
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[0] == 1:
        wav = wav.expand(target_channels, -1)
    # Resample with equal rates is an identity, so skip building the kernel.
    if sr != target_sr:
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav


class AudioTokenizer:
    """Encode/decode audio with an audiocraft compression (EnCodec-style) model.

    Attributes:
        sample_rate: sample rate expected by the codec (read from the model).
        channels: channel count expected by the codec (read from the model).
        codec: the loaded model, moved to :attr:`device`.
    """

    def __init__(
        self,
        device: Any = None,
        signature = None
    ) -> None:
        """Load the codec checkpoint.

        Args:
            device: torch device for the codec. Any falsy value selects
                ``cuda:0`` when CUDA is available and the CPU otherwise.
            signature: checkpoint identifier passed to
                ``CompressionSolver.model_from_checkpoint``. Required.

        Raises:
            ValueError: if ``signature`` is None.
        """
        if signature is None:
            raise ValueError("AudioTokenizer requires a codec checkpoint `signature`.")
        # Imported here rather than at module level, so importing this module
        # does not require audiocraft.
        from audiocraft.solvers import CompressionSolver

        model = CompressionSolver.model_from_checkpoint(signature)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")

        self._device = device

        self.codec = model.to(device)

    @property
    def device(self):
        """Device the codec lives on."""
        return self._device

    def encode(self, wav: torch.Tensor) -> List[Tuple[torch.Tensor, None]]:
        """Encode a batched waveform into discrete codes.

        Args:
            wav: batched waveform, presumably ``(batch, channels, samples)``.
                ``tokenize_audio`` passes ``convert_audio`` output with a
                leading batch dimension added.

        Returns:
            A one-element list ``[(codes, None)]``. ``codes`` is the first
            element of ``self.codec.encode``'s result (audiocraft returns
            ``(codes, scale)``); the second slot is always ``None``.
        """
        codes = self.codec.encode(wav.to(self.device))
        return [(codes[0], None)]

    def decode(self, frames: List[Tuple[torch.Tensor, Any]]) -> torch.Tensor:
        """Decode codes in the format produced by :meth:`encode`.

        Only ``frames[0][0]`` is used; other entries are ignored. The codes
        are moved to the codec's device first, so CPU-resident codes can be
        decoded by a GPU codec.
        """
        codes = frames[0][0]  # original note: [1,4,T]; presumably (batch, codebooks, frames)
        return self.codec.decode(codes.to(self.device))


def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset: int = -1, num_frames: int = -1):
    """Load (part of) an audio file and encode it into codec tokens.

    Args:
        tokenizer: provides the codec plus its target sample rate and channels.
        audio_path: path passed to ``torchaudio.load``.
        offset: first frame to read; ``-1`` reads from the start of the file.
        num_frames: number of frames to read; ``-1`` reads to the end.

    Returns:
        The value returned by ``tokenizer.encode``.
    """
    # Map each sentinel independently onto torchaudio.load's own defaults
    # (frame_offset=0, num_frames=-1), so that an offset or a frame count
    # passed on its own is still honoured.
    frame_offset = 0 if offset == -1 else offset
    wav, sr = torchaudio.load(audio_path, frame_offset=frame_offset, num_frames=num_frames)
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)  # add a leading batch dimension

    # Extract discrete codes from the codec; no gradients are needed.
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames