flavioschneider committed on
Commit
e661388
1 Parent(s): 3548637

Upload DMAE1d

Browse files
Files changed (3) hide show
  1. config.json +6 -1
  2. dmae.py +52 -0
  3. pytorch_model.bin +3 -0
config.json CHANGED
@@ -1,7 +1,12 @@
1
  {
 
 
 
2
  "auto_map": {
3
- "AutoConfig": "dmae_config.DMAE1dConfig"
 
4
  },
5
  "model_type": "archinetai/dmae1d-ATC64-v2",
 
6
  "transformers_version": "4.24.0"
7
  }
 
1
  {
2
+ "architectures": [
3
+ "DMAE1d"
4
+ ],
5
  "auto_map": {
6
+ "AutoConfig": "dmae_config.DMAE1dConfig",
7
+ "AutoModel": "dmae.DMAE1d"
8
  },
9
  "model_type": "archinetai/dmae1d-ATC64-v2",
10
+ "torch_dtype": "float32",
11
  "transformers_version": "4.24.0"
12
  }
dmae.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PreTrainedModel
3
+ from .dmae_config import DMAE1dConfig
4
+ from audio_encoders_pytorch import ME1d, TanhBottleneck # pip install audio_encoders_pytorch==0.0.20
5
+ from audio_diffusion_pytorch.unets import UNetV0, LTPlugin # pip install -U git+https://github.com/archinetai/audio-diffusion-pytorch.git@nightly # v0.0.2
6
+ from audio_diffusion_pytorch.models import DiffusionAE
7
+
8
class DMAE1d(PreTrainedModel):
    """Hugging Face wrapper around a 1D diffusion autoencoder.

    Builds a ``DiffusionAE`` (a ``UNetV0`` extended via ``LTPlugin`` plus an
    ``ME1d`` encoder) inside a ``PreTrainedModel`` shell so the checkpoint can
    be loaded through ``AutoModel`` / ``from_pretrained`` using the
    ``auto_map`` declared in ``config.json``.
    """

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        """Instantiate the fixed DMAE architecture; ``config`` carries no
        architecture hyperparameters here — all sizes are hard-coded below."""
        super().__init__(config)

        # UNetV0 wrapped with a learned-transform plugin
        # (window_length/stride 64, 128 filters).
        net_type = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=64,
        )

        # Diffusion autoencoder over stereo (2-channel) 1D signals; the
        # encoder latent is injected into the UNet at depth 4.
        self.model = DiffusionAE(
            net_t=net_type,
            dim=1,
            in_channels=2,
            channels=[256, 512, 512, 512, 1024, 1024, 1024],
            factors=[1, 2, 2, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 2, 2, 2],
            encoder=ME1d(
                in_channels=2,
                channels=512,
                multipliers=[1, 1, 1],
                factors=[2, 2],
                num_blocks=[4, 8],
                stft_num_fft=1023,
                stft_hop_length=256,
                out_channels=32,
                bottleneck=TanhBottleneck(),
            ),
            inject_depth=4,
        )

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped ``DiffusionAE``."""
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to the wrapped model's encoder.

        NOTE(review): unlike ``decode`` this runs with gradients enabled —
        presumably so the encoder can be trained/fine-tuned; confirm before
        adding ``no_grad`` here.
        """
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()  # decoding is inference-only; skip autograd bookkeeping
    def decode(self, *args, **kwargs):
        """Delegate to the wrapped model's decoder with gradients disabled."""
        return self.model.decode(*args, **kwargs)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9ad49fb4b5ba60c7db2774eebee21590731c9b2d423efff47d8e57119982f20
3
+ size 740732261