| import io |
| from hashlib import sha256 |
| from pathlib import Path |
| from typing import Callable, Literal, Tuple |
|
|
| import torch |
| import torchaudio |
| from loguru import logger |
|
|
| from fish_speech.models.dac.modded_dac import DAC |
| from fish_speech.utils.file import ( |
| AUDIO_EXTENSIONS, |
| audio_to_bytes, |
| list_files, |
| read_ref_text, |
| ) |
| from fish_speech.utils.schema import ServeReferenceAudio |
|
|
|
|
class ReferenceLoader:
    """
    Component of the TTSInferenceEngine class.
    Loads and manages the cache for the reference audio and text.

    Two independent caches are kept on the instance:

    * ``ref_by_id``   -- reference id -> ``(prompt_tokens, prompt_texts)``
    * ``ref_by_hash`` -- sha256 hex digest of the raw audio bytes ->
      ``(prompt_token, text)``

    ``decoder_model`` and ``encode_reference`` are only declared here; no
    value is assigned in this class (presumably the class this component is
    combined with provides them -- confirm against TTSInferenceEngine).
    """

    def __init__(self) -> None:
        """Create empty caches and select the torchaudio decoding backend."""
        self.ref_by_id: dict = {}
        self.ref_by_hash: dict = {}

        # Annotation-only declarations: nothing is bound on the instance here.
        self.decoder_model: DAC
        self.encode_reference: Callable

        # Prefer ffmpeg when torchaudio reports it; otherwise use soundfile.
        backends = torchaudio.list_audio_backends()
        self.backend = "ffmpeg" if "ffmpeg" in backends else "soundfile"

    def load_by_id(
        self,
        id: str,
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """
        Return the encoded prompts and texts stored under ``references/<id>``.

        Args:
            id: Name of the reference sub-folder below ``references``.
            use_cache: ``"off"`` forces re-encoding (the cache entry is still
                refreshed with the new result); any other value reuses a
                cached entry when one exists.

        Returns:
            ``(prompt_tokens, prompt_texts)`` -- two lists in the same order,
            one entry per audio file found in the folder.
        """
        # NOTE(review): `id` is joined into a filesystem path without any
        # validation and the folder is created on demand -- confirm callers
        # sanitise it (e.g. reject ".." or absolute paths).
        ref_folder = Path("references") / id
        ref_folder.mkdir(parents=True, exist_ok=True)

        if use_cache != "off" and id in self.ref_by_id:
            # Cache hit: no need to scan the folder or touch the encoder.
            logger.info("Use same references")
            prompt_tokens, prompt_texts = self.ref_by_id[id]
            return prompt_tokens, prompt_texts

        # The directory listing is only consumed on this (miss/refresh) path,
        # so the recursive scan is deferred until it is actually needed.
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        prompt_tokens = [
            self.encode_reference(
                reference_audio=audio_to_bytes(str(ref_audio)),
                enable_reference_audio=True,
            )
            for ref_audio in ref_audios
        ]
        # Each transcript is read from the ".lab" file next to its audio file.
        prompt_texts = [
            read_ref_text(str(ref_audio.with_suffix(".lab")))
            for ref_audio in ref_audios
        ]
        self.ref_by_id[id] = (prompt_tokens, prompt_texts)

        return prompt_tokens, prompt_texts

    def load_by_hash(
        self,
        references: "list[ServeReferenceAudio]",
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """
        Return the encoded prompts and texts for in-request reference audio.

        Each reference is keyed by the sha256 digest of its ``audio`` bytes.

        Args:
            references: Objects exposing ``audio`` (bytes) and ``text``.
            use_cache: ``"off"`` forces re-encoding of every reference (cache
                entries are still refreshed); any other value reuses cached
                entries when present.

        Returns:
            ``(prompt_tokens, prompt_texts)`` -- two lists aligned with
            ``references``.
        """
        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]

        cache_used = False
        prompt_tokens, prompt_texts = [], []
        for ref, digest in zip(references, audio_hashes):
            if use_cache == "off" or digest not in self.ref_by_hash:
                token = self.encode_reference(
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                prompt_tokens.append(token)
                prompt_texts.append(ref.text)
                self.ref_by_hash[digest] = (token, ref.text)
            else:
                # NOTE(review): on a hit the *cached* text is returned, not
                # ref.text -- identical audio sent with a different
                # transcript keeps the first transcript. Confirm intended.
                cached_token, cached_text = self.ref_by_hash[digest]
                prompt_tokens.append(cached_token)
                prompt_texts.append(cached_text)
                cache_used = True

        if cache_used:
            logger.info("Use same references")

        return prompt_tokens, prompt_texts

    def load_audio(self, reference_audio, sr: int):
        """
        Load the audio data from a file or bytes.

        Args:
            reference_audio: Raw encoded audio as a bytes-like object, or a
                path (``str`` / path-like) to an audio file.
            sr: Target sample rate; the audio is resampled when it differs.

        Returns:
            A 1-D numpy array holding the mono waveform at ``sr``.
        """
        # Decide by type rather than by length / filesystem probing: the old
        # check called Path() on short byte strings (TypeError) and wrapped
        # long path strings in BytesIO (TypeError).
        if isinstance(reference_audio, (bytes, bytearray, memoryview)):
            source = io.BytesIO(reference_audio)
        else:
            source = reference_audio

        waveform, original_sr = torchaudio.load(source, backend=self.backend)

        # Down-mix multi-channel audio by averaging over the channel axis.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if original_sr != sr:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=sr
            )
            waveform = resampler(waveform)

        # Drop only the channel axis so a one-sample clip stays 1-D instead
        # of collapsing to a 0-d scalar array.
        audio = waveform.squeeze(0).numpy()
        return audio
|
|