File size: 1,026 Bytes
9f73bcc
3b724b2
9f73bcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from transformers import PreTrainedModel
from package.audio_encoders_pytorch import AutoEncoder1d as AE1d, TanhBottleneck
from .autoencoder_config import AutoEncoder1dConfig

# Maps the string stored in ``AutoEncoder1dConfig.bottleneck`` to the bottleneck
# class that ``AutoEncoder1d.__init__`` instantiates (called with no arguments).
bottleneck = dict(tanh=TanhBottleneck)

class AutoEncoder1d(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``audio_encoders_pytorch``'s ``AutoEncoder1d``.

    The underlying network is built from an ``AutoEncoder1dConfig`` and stored as
    ``self.autoencoder``. ``forward``, ``encode`` and ``decode`` pass their arguments
    straight through to it, so this class adds only the ``PreTrainedModel``
    machinery (config handling, ``from_pretrained`` / ``save_pretrained``) on top.
    """

    # Tells ``PreTrainedModel`` which config class to use when loading this model.
    config_class = AutoEncoder1dConfig

    def __init__(self, config: AutoEncoder1dConfig):
        """Build the wrapped autoencoder from ``config``.

        Args:
            config: Model configuration. ``in_channels``, ``patch_size``,
                ``channels``, ``multipliers``, ``factors`` and ``num_blocks`` are
                forwarded as-is to ``AE1d``; ``bottleneck`` must be a key of the
                module-level ``bottleneck`` mapping.

        Raises:
            ValueError: If ``config.bottleneck`` is not a registered bottleneck name.
        """
        super().__init__(config)

        # Resolve the bottleneck name explicitly so a typo in the config yields an
        # actionable error listing the valid choices, instead of a bare KeyError
        # that only echoes the offending key.
        try:
            bottleneck_cls = bottleneck[config.bottleneck]
        except KeyError:
            valid = ", ".join(repr(name) for name in sorted(bottleneck))
            raise ValueError(
                f"Unknown bottleneck {config.bottleneck!r}; expected one of: {valid}"
            ) from None

        self.autoencoder = AE1d(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            channels=config.channels,
            multipliers=config.multipliers,
            factors=config.factors,
            num_blocks=config.num_blocks,
            bottleneck=bottleneck_cls(),
        )

    def forward(self, *args, **kwargs):
        """Call the wrapped autoencoder with the given arguments and return its output."""
        return self.autoencoder(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to ``self.autoencoder.encode`` unchanged and return its result."""
        return self.autoencoder.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to ``self.autoencoder.decode`` unchanged and return its result."""
        return self.autoencoder.decode(*args, **kwargs)