import torch
from transformers import PreTrainedModel
from .dmae_config import DMAE1dConfig
from audio_encoders_pytorch import ME1d, TanhBottleneck  # pip install audio_encoders_pytorch==0.0.20
from audio_diffusion_pytorch import UNetV0, LTPlugin, DiffusionAE

class DMAE1d(PreTrainedModel):
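    """Diffusion autoencoder for stereo audio, wrapped as a Hugging Face
    `PreTrainedModel` so it can be saved and loaded via the `transformers` API.

    Encoding compresses a waveform into a 32-channel latent; decoding
    reconstructs audio from that latent by diffusion sampling.
    """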

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # Wrap the v0 U-Net with the learned-transform plugin, which frames the
        # raw waveform (window_length=64, stride=64) before the network sees it.
        UNet = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=64,
        )

        self.model = DiffusionAE(
            net_t=UNet,
            dim=1,  # 1d diffusion over raw waveforms
            in_channels=2,  # stereo in, stereo out
            channels=[256, 512, 512, 512, 1024, 1024, 1024],
            factors=[1, 2, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2],
            # Encoder operating on STFT magnitudes: compresses the input down
            # to a 32-channel latent, kept in [-1, 1] by the tanh bottleneck.
            encoder=ME1d(
                in_channels=2,
                channels=512,
                multipliers=[1, 1, 1],
                factors=[2, 2],
                num_blocks=[4, 8],
                stft_num_fft=1023,
                stft_hop_length=256,
                out_channels=32,
                bottleneck=TanhBottleneck(),
            ),
            inject_depth=4,  # U-Net depth at which the latent is injected
        )

    def forward(self, *args, **kwargs):
        # Training objective: DiffusionAE's forward returns the diffusion loss.
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        # Compress audio into the latent produced by the magnitude encoder.
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        # Reconstruct audio from a latent by iterative diffusion sampling.
        return self.model.decode(*args, **kwargs)
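

if __name__ == "__main__":
    # Usage sketch (illustrative only): round-trip random stereo audio through
    # the autoencoder. Assumes `DMAE1dConfig()` is constructible with its
    # defaults and that this module is run from within its package, since the
    # config import above is relative. `num_steps=50` is an arbitrary choice;
    # more steps trade decoding speed for reconstruction quality.
    model = DMAE1d(DMAE1dConfig()).eval()
    audio = torch.randn(1, 2, 2**18)  # (batch, channels, samples), stereo
    with torch.no_grad():
        latent = model.encode(audio)  # compressed 32-channel latent
    sample = model.decode(latent, num_steps=50)  # diffusion-based reconstruction
    print(tuple(latent.shape), tuple(sample.shape))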