# Source: MassivelyMultilingualTTS / Preprocessing / EnCodecAudioPreprocessor.py
# Uploaded by Flux9665, commit 9e275b8: "use explicit code instead of relying on release download" (5.48 kB)
from collections import OrderedDict
import torch
from torchaudio.transforms import Resample
from Preprocessing.Codec.encodec import EnCodec
class CodecAudioPreprocessor:
    """
    Converts waveforms into EnCodec codebook indexes and codebook indexes back into waveforms.

    Audio is resampled to ``output_sr`` before it is passed to the codec model. The resampler
    is cached and only rebuilt when the sampling rate of the incoming audio changes.
    """

    def __init__(self, input_sr, output_sr=16000, device="cpu", path_to_model="Preprocessing/Codec/encodec_16k_320d.pt"):
        """
        Args:
            input_sr: sampling rate the initial resampler is built for
            output_sr: sampling rate the audio is resampled to before encoding
            device: device the resampler and the codec model are moved to
            path_to_model: path of the checkpoint that is loaded with torch.load
        """
        self.device = device
        self.input_sr = input_sr
        self.output_sr = output_sr
        self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
        self.model = EnCodec(n_filters=32, D=512)
        # NOTE(review): torch.load unpickles the file, so only trusted checkpoints should be passed here.
        parameter_dict = torch.load(path_to_model, map_location="cpu")
        # The first 7 characters of every key are dropped before loading.
        # Presumably this removes a "module." prefix left by a parallel wrapper - TODO confirm against the checkpoint.
        new_state_dict = OrderedDict((k[7:], v) for k, v in parameter_dict.items())
        self.model.load_state_dict(new_state_dict)
        remove_encodec_weight_norm(self.model)
        self.model.eval()
        self.model.to(device)

    def _as_tensor(self, audio):
        """
        Returns audio unchanged if it already is a torch.Tensor (or a subclass of it),
        otherwise builds a float32 tensor on self.device from it.
        """
        # isinstance instead of comparing type(audio): torch.tensor is a factory function and never
        # equals a type, and an exact type comparison would needlessly copy Tensor subclasses.
        if isinstance(audio, torch.Tensor):
            return audio
        return torch.tensor(audio, device=self.device, dtype=torch.float32)

    def resample_audio(self, audio, current_sampling_rate):
        """
        Resamples audio from current_sampling_rate to self.output_sr.

        Args:
            audio: a torch.Tensor or anything torch.tensor() accepts
            current_sampling_rate: sampling rate of the passed audio

        Returns:
            the resampled audio as float tensor on self.device
        """
        if current_sampling_rate != self.input_sr:
            print("warning, change in sampling rate detected. If this happens too often, consider re-ordering the audios so that the sampling rate stays constant for multiple samples")
            # the cached resampler only fits one input rate, so it has to be replaced
            self.resample = Resample(orig_freq=current_sampling_rate, new_freq=self.output_sr).to(self.device)
            self.input_sr = current_sampling_rate
        audio = self._as_tensor(audio)
        audio = self.resample(audio.float().to(self.device))
        return audio

    @torch.inference_mode()
    def audio_to_codebook_indexes(self, audio, current_sampling_rate):
        """
        Encodes a waveform with the codec model.

        Args:
            audio: a torch.Tensor or anything torch.tensor() accepts
            current_sampling_rate: sampling rate of the passed audio

        Returns:
            the squeezed output of self.model.encode
        """
        if current_sampling_rate != self.output_sr:
            audio = self.resample_audio(audio, current_sampling_rate)
        else:
            audio = self._as_tensor(audio)
        # two leading dimensions of size 1 are added before the audio is passed to the model
        return self.model.encode(audio.float().unsqueeze(0).unsqueeze(0).to(self.device)).squeeze()

    @torch.inference_mode()
    def indexes_to_audio(self, codebook_indexes):
        """
        Decodes codebook indexes with the codec model and returns the squeezed result.
        """
        return self.model.decode(codebook_indexes).squeeze()
def remove_encodec_weight_norm(model):
    """
    Removes the weight norm from the convolutions in the encoder and decoder of the passed model.

    Walks over the direct children of model.encoder.model and model.decoder.model. The model
    is modified in place and nothing is returned.
    """
    from Preprocessing.Codec.seanet import SConv1d
    from Preprocessing.Codec.seanet import SConvTranspose1d
    from Preprocessing.Codec.seanet import SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    def strip_resnet_block(resnet_block):
        # the shortcut convolution first, then every SConv1d inside the block
        remove_weight_norm(resnet_block.shortcut.conv.conv)
        for inner_layer in resnet_block.block._modules.values():
            if isinstance(inner_layer, SConv1d):
                remove_weight_norm(inner_layer.conv.conv)

    for layer in model.encoder.model._modules.values():
        if isinstance(layer, SEANetResnetBlock):
            strip_resnet_block(layer)
        elif isinstance(layer, SConv1d):
            remove_weight_norm(layer.conv.conv)

    # only the decoder loop additionally checks for transposed convolutions
    for layer in model.decoder.model._modules.values():
        if isinstance(layer, SEANetResnetBlock):
            strip_resnet_block(layer)
        elif isinstance(layer, SConvTranspose1d):
            remove_weight_norm(layer.convtr.convtr)
        elif isinstance(layer, SConv1d):
            remove_weight_norm(layer.conv.conv)
if __name__ == '__main__':
    import soundfile
    import time

    # Round trip check: encode four audio files, decode them again, time the decoding and save the results.
    with torch.inference_mode():
        test_audio_paths = (
            "../audios/ad01_0000.wav",
            "../audios/angry.wav",
            "../audios/ry.wav",
            "../audios/test.wav",
        )

        # input_sr=1 will not match the rate of a real file, so resample_audio builds a fitting
        # resampler (and prints its warning) whenever a file needs to be resampled
        ap = CodecAudioPreprocessor(input_sr=1, path_to_model="Codec/encodec_16k_320d.pt")

        all_indexes = []
        for audio_path in test_audio_paths:
            wav, sr = soundfile.read(audio_path)
            all_indexes.append(ap.audio_to_codebook_indexes(wav, current_sampling_rate=sr))
        print(all_indexes[-1])

        # only the decoding of all four files is timed
        t0 = time.time()
        audio1, audio2, audio3, audio4 = [ap.indexes_to_audio(indexes) for indexes in all_indexes]
        t1 = time.time()

        for reconstruction in (audio1, audio2, audio3, audio4):
            print(reconstruction.shape)
        print(t1 - t0)

        # the total decoding time becomes part of every file name
        soundfile.write(file=f"../audios/1_reconstructed_in_{t1 - t0}_encodec.wav", data=audio1, samplerate=16000)
        soundfile.write(file=f"../audios/2_reconstructed_in_{t1 - t0}_encodec.wav", data=audio2, samplerate=16000)
        soundfile.write(file=f"../audios/3_reconstructed_in_{t1 - t0}_encodec.wav", data=audio3, samplerate=16000)
        soundfile.write(file=f"../audios/4_reconstructed_in_{t1 - t0}_encodec.wav", data=audio4, samplerate=16000)