File size: 9,396 Bytes
2e9807d fb2212a 2e9807d fb2212a 2e9807d fb2212a 2e9807d fb2212a 2e9807d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
from typing import Union
import numpy as np
import torch
import torchaudio
import torch.nn as nn
import torchaudio.transforms as transforms
from transformers import PretrainedConfig, PreTrainedModel
import dac
from audiotools import AudioSignal
from utils import freeze
class DACConfig(PretrainedConfig):
    """Configuration for the `DAC` wrapper model.

    Args:
        model_type_by_sampling_freq (str, optional): Pretrained DAC variant, selected by its
            sampling frequency. One of ['44khz', '24khz', '16khz'] (case-insensitive).
            Defaults to '44khz'.
        encoding_chunk_size_in_sec (int, optional): Length of each chunk fed to the encoder,
            in seconds. Defaults to 1.
        decoding_chunk_rate (float, optional): Length of each decoding chunk as a fraction of
            the total latent length. Must be in (0, 1]. Defaults to 0.1.
        decoding_overlap_rate (float, optional): Overlap between consecutive decoding chunks
            as a fraction of the decoding chunk length. Must be in [0, 1). Defaults to 0.1.
        **kwargs: Forwarded to `PretrainedConfig.__init__`.

    Raises:
        AssertionError: If `model_type_by_sampling_freq` is not one of ['44khz', '24khz', '16khz'].
        AssertionError: If `decoding_chunk_rate` is not in (0, 1].
        AssertionError: If `decoding_overlap_rate` is not in [0, 1).
    """
    model_type = 'dac'

    def __init__(self,
                 model_type_by_sampling_freq: str = '44khz',
                 encoding_chunk_size_in_sec: int = 1,
                 decoding_chunk_rate: float = 0.1,
                 decoding_overlap_rate: float = 0.1,
                 **kwargs):
        """Validate and store the DAC settings; see the class docstring for details."""
        super().__init__(**kwargs)
        supported_model_types = ('44khz', '24khz', '16khz')
        # Explicit raises instead of `assert` so validation still runs under `python -O`;
        # AssertionError is kept as the exception type for backward compatibility.
        if model_type_by_sampling_freq.lower() not in supported_model_types:
            raise AssertionError(
                f'`model_type_by_sampling_freq` must be one of {list(supported_model_types)}, '
                f'got {model_type_by_sampling_freq!r}.')
        if not 0 < decoding_chunk_rate <= 1.0:
            raise AssertionError('`decoding_chunk_rate` must be in (0, 1].')
        if not 0 <= decoding_overlap_rate < 1.0:
            raise AssertionError('`decoding_overlap_rate` must be in [0, 1).')
        self.model_type_by_sampling_freq = model_type_by_sampling_freq
        self.encoding_chunk_size_in_sec = encoding_chunk_size_in_sec
        self.decoding_chunk_rate = decoding_chunk_rate
        self.decoding_overlap_rate = decoding_overlap_rate
class DAC(PreTrainedModel):
    """`PreTrainedModel` wrapper around a frozen, pretrained DAC codec.

    Encodes audio files into DAC latents/codes (`encode`) and decodes latents or codes back
    into waveforms (`decode`). Both directions process the input in chunks.
    """
    config_class = DACConfig

    def __init__(self, config):
        super().__init__(config)
        self.model_type_by_sampling_freq = config.model_type_by_sampling_freq.lower()
        self.model_type_by_sampling_freq_int = {'44khz': 44100, '24khz': 24000, '16khz': 16000}[self.model_type_by_sampling_freq]
        self.encoding_chunk_size_in_sec = config.encoding_chunk_size_in_sec
        self.decoding_chunk_rate = config.decoding_chunk_rate
        self.decoding_overlap_rate = config.decoding_overlap_rate

        # Obtain the checkpoint path for the selected variant and load the codec.
        dac_path = dac.utils.download(model_type=self.model_type_by_sampling_freq)
        self.dac = dac.DAC.load(dac_path)
        self.dac.eval()
        freeze(self.dac)  # project helper; presumably disables gradients — TODO confirm
        # Waveform samples per latent time step = product of the encoder strides
        # (depends on the checkpoint; e.g. 512 was noted for the 44khz model).
        self.downsampling_rate = int(np.prod(self.dac.encoder_rates))

    def load_audio(self, filename: str):
        """Load an audio file with torchaudio.

        Returns:
            waveform: (n_channels, length) tensor.
            sample_rate: sampling rate of the file.
        """
        waveform, sample_rate = torchaudio.load(filename)
        return waveform, sample_rate

    def resample_audio(self, waveform: torch.FloatTensor, orig_sr: int, target_sr: int):
        """Resample `waveform` from `orig_sr` to `target_sr` (no-op when they are equal).

        Args:
            waveform: (n_channels, length) tensor.
            orig_sr: current sampling rate.
            target_sr: desired sampling rate.

        Returns:
            (n_channels, new_length) tensor.
        """
        if orig_sr == target_sr:
            return waveform
        converter = transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
        return converter(waveform)

    def to_mono_channel(self, waveform: torch.FloatTensor):
        """Downmix a (n_channels, length) waveform to (1, length) by averaging the channels."""
        n_channels = waveform.shape[0]
        if n_channels > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

    @torch.no_grad()
    def encode(self, audio_fname: str):
        """Encode an audio file into DAC latents and codes.

        The audio is resampled to the model's sampling rate and downmixed to mono before
        chunked encoding.

        Returns:
            zq: (1, d, T) float tensor on CPU (quantized continuous latent).
            s: (1, n_rvq, T) long tensor on CPU (codebook indices).
        """
        self.eval()
        waveform, sr = self.load_audio(audio_fname)
        waveform = self.resample_audio(waveform, orig_sr=sr, target_sr=self.model_type_by_sampling_freq_int)
        sr = self.model_type_by_sampling_freq_int
        waveform = self.to_mono_channel(waveform)  # DAC accepts a mono channel only.
        zq, s = self._chunk_encoding(waveform, sr)
        return zq, s

    def _chunk_encoding(self, waveform: torch.FloatTensor, sr: int):
        """Encode `waveform` chunk by chunk and concatenate the results along time.

        Args:
            waveform: (c, l) tensor; `encode` passes a mono (1, l) waveform.
            sr: sampling rate of `waveform`.

        Returns:
            zq: (c, d, T) float tensor on CPU. It is the quantized continuous representation,
                i.e. the sum of the residual-quantized vectors over all RVQ levels.
            s: (c, n_rvq, T) long tensor on CPU holding the codebook indices; index 0 along
                dim 1 is the first RVQ level.

        Raises:
            ValueError: If `waveform` has zero length.
        """
        # TODO: can the chunks be encoded in parallel?
        x = waveform.unsqueeze(1)  # (c, 1, l); the channel dim acts as the batch dim
        audio_length = x.shape[-1]
        if audio_length == 0:
            raise ValueError('Cannot encode an empty waveform.')

        hop_length = self.dac.hop_length
        chunk_size = int(self.encoding_chunk_size_in_sec * sr)
        # Round `chunk_size` down to a multiple of `hop_length` so `dac.preprocess` never pads
        # an inner chunk (padding leaves gaps between chunks in the resulting audio).
        # Keep at least one hop so the loop step can never be zero.
        chunk_size = max(hop_length, chunk_size - chunk_size % hop_length)

        zq_list, s_list = [], []
        for start in range(0, audio_length, chunk_size):
            chunk = x[:, :, start:start + chunk_size]
            chunk = self.dac.preprocess(chunk, sr)
            zq, s, _, _, _ = self.dac.encode(chunk.to(self.device))
            zq_list.append(zq.cpu())
            s_list.append(s.cpu())
            torch.cuda.empty_cache()  # release cached GPU memory between chunks
        zq = torch.cat(zq_list, dim=2).float()  # (c, d, T)
        s = torch.cat(s_list, dim=2).long()  # (c, n_rvq, T)
        return zq, s

    @torch.no_grad()
    def decode(self, *, zq: Union[torch.FloatTensor, None] = None, s: Union[torch.IntTensor, None] = None):
        """Decode either a continuous latent `zq` or discrete codes `s` into a waveform.

        Args:
            zq: (b, d, length) latent tensor.
            s: (b, n_rvq, length) codebook indices. If both `zq` and `s` are given, `s` takes
                precedence and `zq` is ignored.

        Returns:
            (b, 1, n_samples) waveform tensor on CPU; the output is always mono.

        Raises:
            AssertionError: If neither `zq` nor `s` is given.
        """
        if zq is None and s is None:
            # Explicit raise (not `assert`) so it also fires under `python -O`.
            raise AssertionError('Either `zq` or `s` must be provided.')
        self.eval()
        if s is not None:
            zq = self.code_to_zq(s)
        return self._chunk_decoding(zq)

    def _chunk_decoding(self, zq: torch.FloatTensor):
        """Decode `zq` in overlapping chunks and stitch them into a single waveform.

        Where two consecutive chunks overlap, their samples are averaged.

        Args:
            zq: (b, d, length) latent tensor.

        Returns:
            (b, 1, n_samples) waveform tensor on CPU. This assumes `self.dac.decode` yields
            `downsampling_rate` samples per latent step, which the overlap arithmetic relies on.

        Raises:
            ValueError: If `zq` has zero length.
        """
        length = zq.shape[-1]
        if length == 0:
            raise ValueError('Cannot decode an empty latent sequence.')
        # At least one token per chunk; a zero chunk size would make `range` fail on short inputs.
        chunk_size = max(1, int(self.decoding_chunk_rate * length))
        # Overlap in tokens, capped below `chunk_size` so the loop step stays positive.
        overlap_size = min(round(self.decoding_overlap_rate * chunk_size), chunk_size - 1)
        overlap_size_in_data_space = overlap_size * self.downsampling_rate
        step = chunk_size - overlap_size

        waveform_concat = None
        for start in range(0, length, step):
            end = start + chunk_size
            chunk = zq[:, :, start:end]  # (b, d, <= chunk_size)
            waveform = self.dac.decode(chunk.to(self.device)).cpu()  # (b, 1, chunk_len * downsampling_rate)
            if waveform_concat is None:
                waveform_concat = waveform.clone()
            elif overlap_size_in_data_space > 0:
                # Check the actual overlap size, not the rate: a non-zero rate can still round
                # to zero tokens, and `[:-0]` would then drop the whole previous waveform.
                n = overlap_size_in_data_space
                prev_x = waveform_concat[:, :, :-n]
                rest_of_new_x = waveform[:, :, n:]
                # Take the mean over the shared region; a simple strategy that works fine in practice.
                overlap = (waveform_concat[:, :, -n:] + waveform[:, :, :n]) / 2
                waveform_concat = torch.cat((prev_x, overlap, rest_of_new_x), dim=-1)
            else:
                waveform_concat = torch.cat((waveform_concat, waveform), dim=-1)
            if end >= length:
                # This chunk already reaches the end. Any further chunk would lie inside the
                # overlap region and could be shorter than the overlap, breaking the averaging.
                break
        return waveform_concat

    def code_to_zq(self, s: torch.IntTensor):
        """Convert codebook indices `s` of shape (b, n_rvq, length) into a (b, d, length) CPU latent."""
        zq, _, _ = self.dac.quantizer.from_codes(s.to(self.device))
        return zq.cpu()

    def save_tensor(self, tensor: torch.Tensor, fname: str) -> None:
        """Save `tensor` (moved to CPU first) to `fname` with `torch.save`."""
        torch.save(tensor.cpu(), fname)

    def load_tensor(self, fname: str):
        """Load a tensor previously written by `save_tensor`."""
        # NOTE(security): `torch.load` unpickles its input; only load files from trusted sources.
        return torch.load(fname)

    def waveform_to_audiofile(self, waveform: torch.FloatTensor, fname: str) -> None:
        """Write `waveform` to `fname` at the model's sampling rate using audiotools."""
        AudioSignal(waveform, sample_rate=self.model_type_by_sampling_freq_int).write(fname)
|