File size: 1,026 Bytes
9f73bcc 3b724b2 9f73bcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from transformers import PreTrainedModel
from package.audio_encoders_pytorch import AutoEncoder1d as AE1d, TanhBottleneck
from .autoencoder_config import AutoEncoder1dConfig
# Registry mapping the bottleneck name stored in the config to the class
# that is instantiated for it when the model is built.
bottleneck = dict(tanh=TanhBottleneck)
class AutoEncoder1d(PreTrainedModel):
    """`PreTrainedModel` wrapper around the `AE1d` autoencoder.

    The wrapped network is built from an `AutoEncoder1dConfig` and stored on
    `self.autoencoder`; `forward`, `encode` and `decode` pass their arguments
    straight through to it.
    """

    config_class = AutoEncoder1dConfig

    def __init__(self, config: AutoEncoder1dConfig):
        """Build the wrapped autoencoder from the fields of `config`.

        Args:
            config: Supplies `in_channels`, `patch_size`, `channels`,
                `multipliers`, `factors`, `num_blocks`, and `bottleneck`
                (a key of the module-level `bottleneck` registry).

        Raises:
            KeyError: If `config.bottleneck` is not a key of the registry.
        """
        super().__init__(config)

        # Config attributes forwarded unchanged as same-named keyword args.
        forwarded_fields = (
            "in_channels",
            "patch_size",
            "channels",
            "multipliers",
            "factors",
            "num_blocks",
        )
        hparams = {field: getattr(config, field) for field in forwarded_fields}

        # The config stores only the bottleneck's name; look up its class
        # and create a fresh instance for this model.
        bottleneck_cls = bottleneck[config.bottleneck]
        hparams["bottleneck"] = bottleneck_cls()

        self.autoencoder = AE1d(**hparams)

    def forward(self, *args, **kwargs):
        """Call the wrapped autoencoder with the given arguments."""
        wrapped = self.autoencoder
        return wrapped(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to the wrapped autoencoder's `encode` method."""
        wrapped = self.autoencoder
        return wrapped.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to the wrapped autoencoder's `decode` method."""
        wrapped = self.autoencoder
        return wrapped.decode(*args, **kwargs)
|