Spaces:
Runtime error
Runtime error
import os | |
import pathlib | |
import uuid | |
from abc import ABC, abstractmethod | |
from typing import Callable, Optional, Union | |
import julius | |
import torch | |
from audiocraft.data.audio import audio_read, audio_write | |
from audiocraft.models import MultiBandDiffusion # type: ignore | |
mbd = MultiBandDiffusion.get_mbd_24khz(bw=6) # 1.5 | |
class Decoder(ABC): | |
def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None): | |
raise NotImplementedError | |
class EncodecDecoder(Decoder): | |
def __init__( | |
self, | |
tokeniser_decode_fn: Callable[[list[int]], str], | |
data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]], | |
output_dir: str, | |
): | |
self._mbd_sample_rate = 24_000 | |
self._end_of_audio_token = 1024 | |
self._num_codebooks = 8 | |
self.mbd = mbd | |
self.tokeniser_decode_fn = tokeniser_decode_fn | |
self._data_adapter_fn = data_adapter_fn | |
self.output_dir = pathlib.Path(output_dir).resolve() | |
os.makedirs(self.output_dir, exist_ok=True) | |
def _save_audio(self, name: str, wav: torch.Tensor): | |
audio_write( | |
name, | |
wav.squeeze(0).cpu(), | |
self._mbd_sample_rate, | |
strategy="loudness", | |
loudness_compressor=True, | |
) | |
def get_tokens(self, audio_path: str) -> list[list[int]]: | |
""" | |
Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g. | |
limited codebook reconstruction or sampling from second stage model only). | |
""" | |
pass | |
wav, sr = audio_read(audio_path) | |
if sr != self._mbd_sample_rate: | |
wav = julius.resample_frac(wav, sr, self._mbd_sample_rate) | |
if wav.ndim == 2: | |
wav = wav.unsqueeze(1) | |
wav = wav.to("cuda") | |
tokens = self.mbd.codec_model.encode(wav) | |
tokens = tokens[0][0] | |
return tokens.tolist() | |
def decode( | |
self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None | |
) -> Union[str, torch.Tensor]: | |
# TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file. | |
text_ids, extracted_audio_ids = self._data_adapter_fn(tokens) | |
text = self.tokeniser_decode_fn(text_ids) | |
# print(f"Text: {text}") | |
tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0) | |
if tokens.shape[1] < self._num_codebooks: | |
tokens = torch.cat( | |
[tokens, *[torch.ones_like(tokens[0:1, 0:1]) * 0] * (self._num_codebooks - tokens.shape[1])], dim=1 | |
) | |
if causal: | |
return tokens | |
else: | |
with torch.amp.autocast(device_type="cuda", dtype=torch.float32): | |
wav = self.mbd.tokens_to_wav(tokens) | |
# NOTE: we couldn't just return wav here as it goes through loudness compression etc :) | |
if wav.shape[-1] < 9600: | |
# this causes problem for the code below, and is also odd :) | |
# first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!) | |
raise Exception("wav predicted is shorter than 400ms!") | |
try: | |
wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}" | |
self._save_audio(wav_file_name, wav) | |
return wav_file_name | |
except Exception as e: | |
print(f"Failed to save audio! Reason: {e}") | |
wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}" | |
self._save_audio(wav_file_name, wav) | |
return wav_file_name | |