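"""MusicDCAE: a mel-spectrogram deep-compression autoencoder (DCAE) for music.

Waveforms are resampled to 44.1 kHz, converted to mel spectrograms, and
compressed into latents with a diffusers `AutoencoderDC`; decoding runs the
DCAE decoder and renders audio with an ADaMoS HiFi-GAN vocoder.
"""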
import os
import torch
from diffusers import AutoencoderDC
import torchaudio
import torchvision.transforms as transforms
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import FromOriginalModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


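# allow running both as a package module and as a standalone script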
try:
    from .music_vocoder import ADaMoSHiFiGANV1
except ImportError:
    from music_vocoder import ADaMoSHiFiGANV1


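# default checkpoint locations, resolved two directories above this file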
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")


class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
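    """Mel-spectrogram DCAE audio codec paired with a neural vocoder.

    `encode` maps waveform batches to normalized latents and `decode` inverts
    the mapping back to audio; `forward` runs the full round trip.
    """
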
    @register_to_config
    def __init__(self, source_sample_rate=None, dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH, vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH):
        super().__init__()

        self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
        self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)

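        # inputs default to 48 kHz; all mel/vocoder processing runs at 44.1 kHz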
        if source_sample_rate is None:
            source_sample_rate = 48000

        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)

        self.transform = transforms.Compose([
            transforms.Normalize(0.5, 0.5),
        ])
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        # one mel chunk (1024 frames x 512-sample hop at 44.1 kHz) expressed in 48 kHz samples
        self.audio_chunk_size = int(round(1024 * 512 / 44100 * 48000))
        self.mel_chunk_size = 1024
        self.time_dimension_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimension_multiple
        # affine normalization applied to latents in encode() and inverted in decode()
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091

    def load_audio(self, audio_path):
        audio, sr = torchaudio.load(audio_path)
        return audio, sr

    def forward_mel(self, audios):
        # compute one mel spectrogram per batch element via the vocoder's mel transform
        mels = [self.vocoder.mel_transform(audio) for audio in audios]
        return torch.stack(mels)

    @torch.no_grad()
    def encode(self, audios, audio_lengths=None, sr=None):
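        """Encode waveforms into normalized DCAE latents.

        Args:
            audios: N x C x T waveform batch.
            audio_lengths: valid samples per waveform; defaults to the full length T.
            sr: input sample rate; defaults to 48000.

        Returns:
            Tuple of (latents, latent_lengths), with lengths in latent frames.
        """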
        if audio_lengths is None:
            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
            audio_lengths = audio_lengths.to(audios.device)

        # audios: N x C x T waveforms at `sr` (48 kHz by default)
        device = audios.device
        dtype = audios.dtype

        if sr is None:
            sr = 48000
            resampler = self.resampler
        else:
            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)

        audio = resampler(audios)

        max_audio_len = audio.shape[-1]
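        # pad to a multiple of 8 * 512 samples: 512 is the mel hop length and the
        # DCAE downsamples the time axis by 8, so whole latent frames are preserved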
        if max_audio_len % (8 * 512) != 0:
            audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))

        mels = self.forward_mel(audio)
        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        mels = self.transform(mels)
        latents = []
        for mel in mels:
            latent = self.dcae.encoder(mel.unsqueeze(0))
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimension_multiple).long()
        latents = (latents - self.shift_factor) * self.scale_factor
        return latents, latent_lengths

    @torch.no_grad()
    def decode(self, latents, audio_lengths=None, sr=None):
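        """Decode DCAE latents back to waveforms.

        Args:
            latents: latent batch produced by `encode`.
            audio_lengths: if given, trim each output to this many samples.
            sr: target sample rate; defaults to the vocoder's native 44100.

        Returns:
            Tuple of (sr, list of decoded waveforms).
        """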
        # invert the affine latent normalization applied in encode()
        latents = latents / self.scale_factor + self.shift_factor

        # the vocoder always outputs 44.1 kHz audio; build any output resampler once
        if sr is None or sr == 44100:
            sr = 44100
            resampler = None
        else:
            resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)

        pred_wavs = []
        for latent in latents:
            mels = self.dcae.decoder(latent.unsqueeze(0))
            # undo the Normalize(0.5, 0.5) transform, then map back to the raw mel range
            mels = mels * 0.5 + 0.5
            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
            wav = self.vocoder.decode(mels[0]).squeeze(1)
            if resampler is not None:
                wav = resampler(wav)
            pred_wavs.append(wav)

        if audio_lengths is not None:
            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
        return sr, pred_wavs

    def forward(self, audios, audio_lengths=None, sr=None):
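        """Round-trip: encode audio to latents, then decode back to waveforms."""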
        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return sr, pred_wavs, latents, latent_lengths


if __name__ == "__main__":
    audio, sr = torchaudio.load("test.wav")
    audio_lengths = torch.tensor([audio.shape[1]])
    audios = audio.unsqueeze(0)

    model = MusicDCAE()

    # test encode only
    # latents, latent_lengths = model.encode(audios, audio_lengths)
    # print("latents shape: ", latents.shape)
    # print("latent_lengths: ", latent_lengths)

    # test encode and decode
    sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
    print("reconstructed wavs: ", pred_wavs[0].shape)
    print("latents shape: ", latents.shape)
    print("latent_lengths: ", latent_lengths)
    print("sr: ", sr)
    torchaudio.save("test_reconstructed.flac", pred_wavs[0], sr)
    print("saved to test_reconstructed.flac")