import torch
from transformers import PreTrainedModel
from .dmae_config import DMAE1dConfig
from audio_encoders_pytorch import ME1d, TanhBottleneck # pip install audio_encoders_pytorch==0.0.20
from audio_diffusion_pytorch import UNetV0, LTPlugin, DiffusionAE
class DMAE1d(PreTrainedModel):
    """Diffusion autoencoder for stereo audio, packaged as a HF `PreTrainedModel`.

    Wires together three pieces from `audio_diffusion_pytorch` /
    `audio_encoders_pytorch`:

    * a ``UNetV0`` diffusion network wrapped by ``LTPlugin`` (window/stride 64,
      128 filters),
    * an ``ME1d`` encoder with a ``TanhBottleneck`` producing a 32-channel
      latent (presumably bounded to (-1, 1) by the tanh — not verified here),
    * a ``DiffusionAE`` that injects the latent into the U-Net at depth 4.

    ``forward``/``encode``/``decode`` are thin pass-throughs to the inner
    ``DiffusionAE``; only ``decode`` runs under ``torch.no_grad()``.
    """

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # Diffusion net type: UNetV0 augmented with the learned-transform
        # plugin so the waveform is framed (window 64, hop 64) before the UNet.
        net_type = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=64,
        )

        # STFT-based encoder; emits the 32-channel latent consumed by
        # DiffusionAE. Pinned upstream: audio_encoders_pytorch==0.0.20.
        latent_encoder = ME1d(
            in_channels=2,
            channels=512,
            multipliers=[1, 1, 1],
            factors=[2, 2],
            num_blocks=[4, 8],
            stft_num_fft=1023,
            stft_hop_length=256,
            out_channels=32,
            bottleneck=TanhBottleneck(),
        )

        self.model = DiffusionAE(
            net_t=net_type,
            dim=1,
            in_channels=2,  # stereo audio
            channels=[256, 512, 512, 512, 1024, 1024, 1024],
            factors=[1, 2, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2],
            encoder=latent_encoder,
            inject_depth=4,
        )

    def forward(self, *args, **kwargs):
        """Delegate to the inner ``DiffusionAE`` call (training path)."""
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to ``DiffusionAE.encode``; gradients are NOT disabled here."""
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        """Delegate to ``DiffusionAE.decode`` under ``no_grad`` (inference only)."""
        return self.model.decode(*args, **kwargs)