from collections import OrderedDict
import torch
from torchaudio.transforms import Resample
from Preprocessing.Codec.encodec import EnCodec


class CodecAudioPreprocessor:
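    """
    Thin wrapper around the EnCodec neural codec: resamples incoming audio to
    the codec's sampling rate and converts between raw waveforms and codebook
    indexes, using the pretrained 16 kHz checkpoint shipped with this repo.
    """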
    def __init__(self, input_sr, output_sr=16000, device="cpu", path_to_model="Preprocessing/Codec/encodec_16k_320d.pt"):
        self.device = device
        self.input_sr = input_sr
        self.output_sr = output_sr
        self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
        self.model = EnCodec(n_filters=32, D=512)
        parameter_dict = torch.load(path_to_model, map_location="cpu")
        new_state_dict = OrderedDict()
        for k, v in parameter_dict.items():
            # checkpoint keys carry a "module." prefix (presumably left over
            # from DataParallel training); strip the first 7 characters
            name = k[7:]
            new_state_dict[name] = v
        self.model.load_state_dict(new_state_dict)
        # weight norm is a training-time reparameterization; removing it
        # leaves plain fused weights for inference
        remove_encodec_weight_norm(self.model)
        self.model.eval()
        self.model.to(device)

    def resample_audio(self, audio, current_sampling_rate):
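        """
        Resample audio to the target sampling rate, rebuilding the Resample
        transform whenever the incoming sampling rate changes.
        """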
        if current_sampling_rate != self.input_sr:
            print("Warning: change in sampling rate detected. If this happens often, consider ordering the audios so that the sampling rate stays constant across consecutive samples.")
            self.resample = Resample(orig_freq=current_sampling_rate, new_freq=self.output_sr).to(self.device)
            self.input_sr = current_sampling_rate
        if not isinstance(audio, torch.Tensor):
            audio = torch.tensor(audio, device=self.device, dtype=torch.float32)
        audio = self.resample(audio.float().to(self.device))
        return audio

    @torch.inference_mode()
    def audio_to_codebook_indexes(self, audio, current_sampling_rate):
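        """
        Encode a waveform into a sequence of codebook indexes, resampling
        first if the audio is not already at the codec's sampling rate.
        """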
        if current_sampling_rate != self.output_sr:
            audio = self.resample_audio(audio, current_sampling_rate)
        elif not isinstance(audio, torch.Tensor):
            audio = torch.tensor(audio, device=self.device, dtype=torch.float32)
        # add batch and channel dimensions for the encoder, then drop them again
        return self.model.encode(audio.float().unsqueeze(0).unsqueeze(0).to(self.device)).squeeze()

    @torch.inference_mode()
    def indexes_to_audio(self, codebook_indexes):
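        """
        Decode codebook indexes back into a waveform at the codec's sampling rate.
        """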
        return self.model.decode(codebook_indexes).squeeze()


def remove_encodec_weight_norm(model):
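    """
    Strip the weight-norm reparameterization from every convolution in the
    EnCodec encoder and decoder, so that the fused weights can be used
    directly at inference time.
    """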
    from Preprocessing.Codec.seanet import SConv1d, SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm
    encoder = model.encoder.model
    for module in encoder._modules.values():
        if isinstance(module, SEANetResnetBlock):
            remove_weight_norm(module.shortcut.conv.conv)
            for sub_module in module.block._modules.values():
                if isinstance(sub_module, SConv1d):
                    remove_weight_norm(sub_module.conv.conv)
        elif isinstance(module, SConv1d):
            remove_weight_norm(module.conv.conv)
    decoder = model.decoder.model
    for module in decoder._modules.values():
        if isinstance(module, SEANetResnetBlock):
            remove_weight_norm(module.shortcut.conv.conv)
            for sub_module in module.block._modules.values():
                if isinstance(sub_module, SConv1d):
                    remove_weight_norm(sub_module.conv.conv)
        elif isinstance(module, SConvTranspose1d):
            remove_weight_norm(module.convtr.convtr)
        elif isinstance(module, SConv1d):
            remove_weight_norm(module.conv.conv)


if __name__ == '__main__':
    import soundfile
    import time
    with torch.inference_mode():
        test_audio1 = "../audios/ad01_0000.wav"
        test_audio2 = "../audios/angry.wav"
        test_audio3 = "../audios/ry.wav"
        test_audio4 = "../audios/test.wav"
        # input_sr=1 acts as a placeholder; the first real sampling rate seen
        # triggers a rebuild of the resampling transform
        ap = CodecAudioPreprocessor(input_sr=1, path_to_model="Codec/encodec_16k_320d.pt")
        wav, sr = soundfile.read(test_audio1)
        indexes_1 = ap.audio_to_codebook_indexes(wav, current_sampling_rate=sr)
        wav, sr = soundfile.read(test_audio2)
        indexes_2 = ap.audio_to_codebook_indexes(wav, current_sampling_rate=sr)
        wav, sr = soundfile.read(test_audio3)
        indexes_3 = ap.audio_to_codebook_indexes(wav, current_sampling_rate=sr)
        wav, sr = soundfile.read(test_audio4)
        indexes_4 = ap.audio_to_codebook_indexes(wav, current_sampling_rate=sr)
        print(indexes_4)
        t0 = time.time()
        audio1 = ap.indexes_to_audio(indexes_1)
        audio2 = ap.indexes_to_audio(indexes_2)
        audio3 = ap.indexes_to_audio(indexes_3)
        audio4 = ap.indexes_to_audio(indexes_4)
        t1 = time.time()
        print(audio1.shape)
        print(audio2.shape)
        print(audio3.shape)
        print(audio4.shape)
        print(t1 - t0)
        soundfile.write(file=f"../audios/1_reconstructed_in_{t1 - t0}_encodec.wav", data=audio1, samplerate=16000)
        soundfile.write(file=f"../audios/2_reconstructed_in_{t1 - t0}_encodec.wav", data=audio2, samplerate=16000)
        soundfile.write(file=f"../audios/3_reconstructed_in_{t1 - t0}_encodec.wav", data=audio3, samplerate=16000)
        soundfile.write(file=f"../audios/4_reconstructed_in_{t1 - t0}_encodec.wav", data=audio4, samplerate=16000)