dmae1d-ATC64-v2 / dmae.py
flavioschneider's picture
Upload DMAE1d
3ffeea6
raw
history blame contribute delete
No virus
1.48 kB
import torch
from transformers import PreTrainedModel
from .dmae_config import DMAE1dConfig
from audio_encoders_pytorch import ME1d, TanhBottleneck # pip install audio_encoders_pytorch==0.0.20
from audio_diffusion_pytorch import UNetV0, LTPlugin, DiffusionAE
class DMAE1d(PreTrainedModel):
    """`PreTrainedModel` wrapper around an `audio_diffusion_pytorch.DiffusionAE`.

    All architecture hyperparameters are fixed in `__init__`; the config
    object is handed to the `PreTrainedModel` base class and is not read
    anywhere else in this class. The wrapped network is stored on
    `self.model`, and `forward`, `encode` and `decode` delegate to it.
    """

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        """Build the diffusion autoencoder with its hard-coded architecture.

        Args:
            config: Configuration instance forwarded to `PreTrainedModel`.
        """
        super().__init__(config)

        # UNetV0 wrapped by LTPlugin; the result is passed to DiffusionAE
        # as the network type to instantiate (`net_t`).
        unet_type = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=64,
        )

        # Encoder handed to DiffusionAE; its output goes through a tanh
        # bottleneck and has 32 channels.
        latent_encoder = ME1d(
            in_channels=2,
            channels=512,
            multipliers=[1, 1, 1],
            factors=[2, 2],
            num_blocks=[4, 8],
            stft_num_fft=1023,
            stft_hop_length=256,
            out_channels=32,
            bottleneck=TanhBottleneck(),
        )

        # NOTE: the attribute must stay named `model` so that parameter
        # names in saved checkpoints keep their `model.` prefix.
        self.model = DiffusionAE(
            net_t=unet_type,
            dim=1,
            in_channels=2,
            channels=[256, 512, 512, 512, 1024, 1024, 1024],
            factors=[1, 2, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2],
            encoder=latent_encoder,
            inject_depth=4,
        )

    def forward(self, *args, **kwargs):
        """Call the wrapped `DiffusionAE` with the given arguments and return its result."""
        result = self.model(*args, **kwargs)
        return result

    def encode(self, *args, **kwargs):
        """Delegate to `DiffusionAE.encode`; gradient tracking is left as-is."""
        result = self.model.encode(*args, **kwargs)
        return result

    def decode(self, *args, **kwargs):
        """Delegate to `DiffusionAE.decode` with gradient tracking disabled."""
        with torch.no_grad():
            return self.model.decode(*args, **kwargs)