|
from typing import Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torchaudio
|
|
import torch.nn as nn
|
|
import torchaudio.transforms as transforms
|
|
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
|
import dac
|
|
from audiotools import AudioSignal
|
|
|
|
from utils import freeze
|
|
|
|
|
|
class DACConfig(PretrainedConfig):
    """Configuration for the ``DAC`` codec wrapper.

    Stores which pretrained codec variant to use (selected by sampling
    frequency) and the chunking parameters used when encoding and decoding.
    """

    model_type = 'dac'

    def __init__(self,
                 model_type_by_sampling_freq: str = '44khz',
                 encoding_chunk_size_in_sec: int = 1,
                 decoding_chunk_rate: float = 0.1,
                 decoding_overlap_rate: float = 0.1,
                 **kwargs):
        """Initialize and validate the configuration.

        Args:
            model_type_by_sampling_freq (str, optional): The model type based on the
                sampling frequency, compared case-insensitively. One of
                ['44khz', '24khz', '16khz']. Defaults to '44khz'.
            encoding_chunk_size_in_sec (int, optional): Length, in seconds, of each
                waveform chunk passed to the encoder. Defaults to 1.
            decoding_chunk_rate (float, optional): Fraction of the latent sequence
                length decoded per chunk. Must lie in (0, 1]. Defaults to 0.1.
            decoding_overlap_rate (float, optional): Fraction of a decoding chunk that
                overlaps the previous chunk. Must lie in [0, 1). Defaults to 0.1.
            **kwargs: Additional keyword arguments forwarded to ``PretrainedConfig``.

        Raises:
            AssertionError: If model_type_by_sampling_freq is not one of
                ['44khz', '24khz', '16khz'].
            AssertionError: If decoding_chunk_rate is not in (0, 1].
            AssertionError: If decoding_overlap_rate is not in [0, 1).
        """
        super().__init__(**kwargs)

        # Validation raises AssertionError explicitly instead of using `assert`, so it
        # still runs under `python -O` while keeping the documented exception type.
        supported_model_types = ['44khz', '24khz', '16khz']
        if model_type_by_sampling_freq.lower() not in supported_model_types:
            raise AssertionError(
                f'`model_type_by_sampling_freq` must be one of {supported_model_types}, '
                f'got {model_type_by_sampling_freq!r}.')
        if not 0 < decoding_chunk_rate <= 1.0:
            raise AssertionError('`decoding_chunk_rate` must be in the interval (0, 1].')
        if not 0 <= decoding_overlap_rate < 1.0:
            raise AssertionError('`decoding_overlap_rate` must be in the interval [0, 1).')

        # Stored as given (not lower-cased); the model lower-cases it when reading.
        self.model_type_by_sampling_freq = model_type_by_sampling_freq
        self.encoding_chunk_size_in_sec = encoding_chunk_size_in_sec
        self.decoding_chunk_rate = decoding_chunk_rate
        self.decoding_overlap_rate = decoding_overlap_rate
|
|
|
|
|
|
|
|
class DAC(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a pretrained Descript Audio Codec (DAC).

    Offers file-to-latent encoding (``encode``) and latent/code-to-waveform
    decoding (``decode``). Both work on the input in chunks.
    """

    config_class = DACConfig

    def __init__(self, config):
        """Load the pretrained codec selected by ``config``.

        Args:
            config (DACConfig): codec variant and chunking parameters.
        """
        super().__init__(config)

        self.model_type_by_sampling_freq = config.model_type_by_sampling_freq.lower()
        # Sampling rate (Hz) that audio is resampled to before encoding and written at.
        self.model_type_by_sampling_freq_int = {'44khz': 44100, '24khz': 24000, '16khz': 16000}[self.model_type_by_sampling_freq]
        self.encoding_chunk_size_in_sec = config.encoding_chunk_size_in_sec
        self.decoding_chunk_rate = config.decoding_chunk_rate
        self.decoding_overlap_rate = config.decoding_overlap_rate

        # Fetch the weights for the chosen variant, load the codec, put it in eval
        # mode and pass it to `freeze`.
        dac_path = dac.utils.download(model_type=self.model_type_by_sampling_freq)
        self.dac = dac.DAC.load(dac_path)
        self.dac.eval()
        freeze(self.dac)

        # Product of the encoder rates; used by `_chunk_decoding` as the number of
        # waveform samples per latent frame.
        self.downsampling_rate = int(np.prod(self.dac.encoder_rates))

    def load_audio(self, filename: str):
        """Load ``filename`` with ``torchaudio.load``.

        Returns:
            tuple: ``(waveform, sample_rate)`` as returned by ``torchaudio.load``.
        """
        waveform, sample_rate = torchaudio.load(filename)
        return waveform, sample_rate

    def resample_audio(self, waveform: torch.FloatTensor, orig_sr: int, target_sr: int):
        """Resample ``waveform`` from ``orig_sr`` to ``target_sr``.

        Args:
            waveform: tensor of shape (n_channels, length).
            orig_sr: current sampling rate.
            target_sr: desired sampling rate.

        Returns:
            The resampled waveform, or the input object itself when the rates match.
        """
        if orig_sr == target_sr:
            return waveform
        converter = transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
        return converter(waveform)

    def to_mono_channel(self, waveform: torch.FloatTensor):
        """Average a multi-channel waveform down to one channel.

        Args:
            waveform: tensor of shape (n_channels, length).

        Returns:
            A (1, length) channel mean when n_channels > 1; otherwise the input unchanged.
        """
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

    @torch.no_grad()
    def encode(self, audio_fname: str):
        """Encode an audio file into DAC latents and codebook indices.

        The file is loaded, resampled to the codec's sampling rate, reduced to a
        single channel, and encoded chunk by chunk.

        Args:
            audio_fname: path of an audio file readable by ``torchaudio.load``.

        Returns:
            tuple: ``(zq, s)`` as produced by ``_chunk_encoding``.
        """
        self.eval()

        waveform, sr = self.load_audio(audio_fname)
        waveform = self.resample_audio(waveform, orig_sr=sr, target_sr=self.model_type_by_sampling_freq_int)
        sr = self.model_type_by_sampling_freq_int
        waveform = self.to_mono_channel(waveform)

        zq, s = self._chunk_encoding(waveform, sr)
        return zq, s

    def _chunk_encoding(self, waveform: torch.FloatTensor, sr: int):
        """Encode ``waveform`` in fixed-length chunks and concatenate the results.

        Args:
            waveform: tensor of shape (c, l); ``encode`` passes a mono (1, l) waveform.
            sr: sampling rate of ``waveform``.

        Returns:
            tuple:
                zq: float tensor [B x D x T]. The quantized continuous representation
                    of the input, i.e. the summation of the residual quantized vectors
                    across every RVQ level (z = sum over n=1..N of zq_n, where N is the
                    number of codebooks).
                s: long tensor [B x N x T]. Codebook indices for each codebook (the
                    quantized discrete representation); the first element along N is
                    the first RVQ level.
        """
        x = waveform.unsqueeze(1)  # (c, l) -> (c, 1, l)

        # Round the chunk length down to a multiple of the hop length so chunk
        # boundaries fall on hop boundaries. Never let it drop to zero, which would
        # make the range() step below invalid.
        chunk_size = int(self.encoding_chunk_size_in_sec * sr)
        chunk_size -= chunk_size % self.dac.hop_length
        chunk_size = max(chunk_size, self.dac.hop_length)

        zq_chunks, s_chunks = [], []
        for start in range(0, x.shape[-1], chunk_size):
            chunk = self.dac.preprocess(x[:, :, start:start + chunk_size], sr)
            zq, s, _, _, _ = self.dac.encode(chunk.to(self.device))
            zq_chunks.append(zq.cpu())
            s_chunks.append(s.cpu())
            # Release cached, unused GPU memory between chunks.
            torch.cuda.empty_cache()

        zq = torch.cat(zq_chunks, dim=2).float()
        s = torch.cat(s_chunks, dim=2).long()
        return zq, s

    @torch.no_grad()
    def decode(self, *, zq: Union[torch.FloatTensor, None] = None, s: Union[torch.IntTensor, None] = None):
        """Decode latents or codebook indices back into a waveform.

        Args:
            zq: latent tensor of shape (b, d, length).
            s: codebook indices of shape (b, n_rvq, length). If both ``zq`` and ``s``
                are given, ``s`` is used and ``zq`` is ignored.

        Returns:
            The waveform produced by ``_chunk_decoding``.

        Raises:
            AssertionError: if neither ``zq`` nor ``s`` is given. Raised explicitly
                so the check survives ``python -O``.
        """
        if zq is None and s is None:
            raise AssertionError('one of them must be valid.')
        self.eval()

        if s is not None:
            zq = self.code_to_zq(s)
        return self._chunk_decoding(zq)

    def _chunk_decoding(self, zq: torch.FloatTensor):
        """Decode latents in overlapping chunks and stitch the waveforms together.

        Each chunk covers ``int(decoding_chunk_rate * length)`` frames (at least 1).
        Consecutive chunks overlap by ``round(decoding_overlap_rate * chunk_size)``
        frames (at most chunk_size - 1). In the overlap, the two decodings are
        averaged. Decoding stops once a chunk reaches the end of the sequence.

        Args:
            zq: latent tensor of shape (b, d, length).

        Returns:
            The stitched waveform on CPU, or ``None`` if ``length`` is 0.
        """
        length = zq.shape[-1]
        # At least one frame per chunk; otherwise the loop step would be zero.
        chunk_size = max(1, int(self.decoding_chunk_rate * length))
        # Keep the overlap strictly smaller than the chunk so the loop always advances.
        overlap_size = min(round(self.decoding_overlap_rate * chunk_size), chunk_size - 1)
        # Assumes the decoder emits `downsampling_rate` samples per latent frame.
        overlap_size_in_data_space = overlap_size * self.downsampling_rate

        waveform_concat = None
        for start in range(0, length, chunk_size - overlap_size):
            end = start + chunk_size
            waveform = self.dac.decode(zq[:, :, start:end].to(self.device)).cpu()

            if waveform_concat is None:
                waveform_concat = waveform.clone()
            elif overlap_size_in_data_space > 0:
                # Test the realized overlap size, not the configured rate: the rate can
                # be non-zero while the overlap rounds to 0 frames, and a zero-length
                # slice here would misalign the tensors.
                prev_x = waveform_concat[:, :, :-overlap_size_in_data_space]
                overlap = (waveform_concat[:, :, -overlap_size_in_data_space:]
                           + waveform[:, :, :overlap_size_in_data_space]) / 2
                rest_of_new_x = waveform[:, :, overlap_size_in_data_space:]
                waveform_concat = torch.cat((prev_x, overlap, rest_of_new_x), dim=-1)
            else:
                waveform_concat = torch.cat((waveform_concat, waveform), dim=-1)

            # The sequence is fully decoded once a chunk reaches its end. Further starts
            # would lie inside this chunk, and a final chunk shorter than the overlap
            # would not line up with the overlap slices above.
            if end >= length:
                break
        return waveform_concat

    def code_to_zq(self, s: torch.IntTensor):
        """Convert codebook indices into the quantized latent ``zq``.

        Args:
            s: codebook indices of shape (b, n_rvq, length).

        Returns:
            The first output of ``self.dac.quantizer.from_codes``, moved to CPU.
        """
        zq, _, _ = self.dac.quantizer.from_codes(s.to(self.device))
        return zq.cpu()

    def save_tensor(self, tensor: torch.Tensor, fname: str) -> None:
        """Save ``tensor`` (moved to CPU first) to ``fname`` with ``torch.save``."""
        torch.save(tensor.cpu(), fname)

    def load_tensor(self, fname: str):
        """Load a tensor previously written by ``save_tensor``.

        NOTE(security): ``torch.load`` unpickles its input; only load files from
        trusted sources.
        """
        return torch.load(fname)

    def waveform_to_audiofile(self, waveform: torch.FloatTensor, fname: str) -> None:
        """Write ``waveform`` to ``fname`` at the codec's sampling rate via ``AudioSignal``."""
        AudioSignal(waveform, sample_rate=self.model_type_by_sampling_freq_int).write(fname)
|
|
|