# emo-knob / fam/llm/decoders.py
import os
import pathlib
import uuid
from abc import ABC, abstractmethod
from typing import Callable, Optional, Union
import julius
import torch
from audiocraft.data.audio import audio_read, audio_write
from audiocraft.models import MultiBandDiffusion # type: ignore

# Shared MultiBandDiffusion vocoder (24 kHz EnCodec tokens -> waveform), loaded once at import time.
mbd = MultiBandDiffusion.get_mbd_24khz(bw=6)  # 1.5


class Decoder(ABC):
    @abstractmethod
    def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None):
        raise NotImplementedError


class EncodecDecoder(Decoder):
    def __init__(
        self,
        tokeniser_decode_fn: Callable[[list[int]], str],
        data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]],
        output_dir: str,
    ):
        self._mbd_sample_rate = 24_000  # EnCodec / MultiBandDiffusion operate at 24 kHz
        self._end_of_audio_token = 1024  # codebook indices are 0-1023, so 1024 marks end-of-audio
        self._num_codebooks = 8  # 6 kbps EnCodec uses 8 codebooks
        self.mbd = mbd

        self.tokeniser_decode_fn = tokeniser_decode_fn
        self._data_adapter_fn = data_adapter_fn

        self.output_dir = pathlib.Path(output_dir).resolve()
        os.makedirs(self.output_dir, exist_ok=True)

    def _save_audio(self, name: str, wav: torch.Tensor):
        audio_write(
            name,
            wav.squeeze(0).cpu(),
            self._mbd_sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )

    def get_tokens(self, audio_path: str) -> list[list[int]]:
        """
        Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g.
        limited codebook reconstruction or sampling from second stage model only).
        """
        wav, sr = audio_read(audio_path)
        if sr != self._mbd_sample_rate:
            wav = julius.resample_frac(wav, sr, self._mbd_sample_rate)
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)
        wav = wav.to("cuda")

        tokens = self.mbd.codec_model.encode(wav)
        tokens = tokens[0][0]
        return tokens.tolist()

    def decode(
        self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None
    ) -> Union[pathlib.Path, torch.Tensor]:
        # TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file.
        text_ids, extracted_audio_ids = self._data_adapter_fn(tokens)
        text = self.tokeniser_decode_fn(text_ids)
        # print(f"Text: {text}")

        tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0)

        # Pad missing codebooks with zeros so the tensor always has `_num_codebooks` rows.
        if tokens.shape[1] < self._num_codebooks:
            tokens = torch.cat(
                [tokens, *[torch.zeros_like(tokens[0:1, 0:1])] * (self._num_codebooks - tokens.shape[1])], dim=1
            )

        if causal:
            return tokens
        else:
            with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
                wav = self.mbd.tokens_to_wav(tokens)
                # NOTE: we couldn't just return wav here as it goes through loudness compression etc :)

                if wav.shape[-1] < 9600:
                    # this causes problem for the code below, and is also odd :)
                    # first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!)
                    raise Exception("wav predicted is shorter than 400ms!")

                try:
                    wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}"
                    self._save_audio(wav_file_name, wav)
                    return wav_file_name
                except Exception as e:
                    print(f"Failed to save audio! Reason: {e}")
                    wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}"
                    self._save_audio(wav_file_name, wav)
                    return wav_file_name
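

# --- Usage sketch (illustration only, not part of the original module) ----------
# A minimal sketch of how EncodecDecoder might be wired up, assuming a CUDA
# device and downloaded MultiBandDiffusion weights. The tokeniser decode
# function, data adapter, and file paths below are hypothetical stand-ins: in
# the real pipeline they come from the model's tokenizer and its output format.
if __name__ == "__main__":

    def dummy_tokeniser_decode_fn(text_ids: list[int]) -> str:
        # Hypothetical: render text token ids as a plain string.
        return " ".join(str(i) for i in text_ids)

    def dummy_data_adapter_fn(tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
        # Hypothetical: assume no text ids and that the rows are already per-codebook audio ids.
        return [], tokens

    decoder = EncodecDecoder(
        tokeniser_decode_fn=dummy_tokeniser_decode_fn,
        data_adapter_fn=dummy_data_adapter_fn,
        output_dir="./synth_out",
    )

    # Round-trip a reference clip: encode it to EnCodec tokens, then vocode the
    # tokens back to a wav file (causal=False triggers the save path in decode()).
    audio_tokens = decoder.get_tokens("/path/to/reference.wav")  # nested list: [batch][codebook][frame]
    out_path = decoder.decode(audio_tokens[0], causal=False)
    print(f"Wrote audio to: {out_path}")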