import torch
from transformers import PreTrainedModel

from .dmae_config import DMAE1dConfig
from audio_encoders_pytorch import ME1d, TanhBottleneck  # pip install audio_encoders_pytorch==0.0.20
from audio_diffusion_pytorch import UNetV0, LTPlugin, DiffusionAE


class DMAE1d(PreTrainedModel):
    """Diffusion magnitude autoencoder (1d) wrapped as a `transformers` model."""

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # Wrap the diffusion UNet with a learned-transform (LT) plugin that
        # frames the waveform into non-overlapping 64-sample windows.
        UNet = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=64,
        )

        self.model = DiffusionAE(
            net_t=UNet,
            dim=1,
            in_channels=2,  # stereo waveform
            channels=[256, 512, 512, 512, 1024, 1024, 1024],
            factors=[1, 2, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2],
            # Magnitude (STFT) encoder producing a 32-channel latent,
            # squashed into [-1, 1] by the tanh bottleneck.
            encoder=ME1d(
                in_channels=2,
                channels=512,
                multipliers=[1, 1, 1],
                factors=[2, 2],
                num_blocks=[4, 8],
                stft_num_fft=1023,
                stft_hop_length=256,
                out_channels=32,
                bottleneck=TanhBottleneck(),
            ),
            # The latent is injected into the UNet at this resolution depth.
            inject_depth=4,
        )

    def forward(self, *args, **kwargs):
        # Diffusion training objective: returns the loss for a batch of audio.
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        # Compress audio to the bottlenecked latent.
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        # Resynthesize audio from a latent with the diffusion sampler.
        return self.model.decode(*args, **kwargs)
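

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: DMAE1dConfig constructs with sensible
    # defaults, and the input length and num_steps below are illustrative
    # choices (taken from the audio_diffusion_pytorch examples), not values
    # prescribed by this module. encode/decode are DiffusionAE's API.
    model = DMAE1d(DMAE1dConfig())

    audio = torch.randn(1, 2, 2**18)  # (batch, stereo channels, samples)

    # Training-style forward pass: returns the diffusion loss.
    loss = model(audio)

    # Round trip: compress to the 32-channel latent, then sample audio back
    # from it; more sampling steps trade decode speed for quality.
    latent = model.encode(audio)
    reconstruction = model.decode(latent, num_steps=10)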