Rex Cheng commited on
Commit
dbac20f
·
1 Parent(s): f2786fb

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. README.md +152 -14
  3. app.py +149 -0
  4. demo.py +135 -0
  5. docs/images/icon.png +0 -0
  6. docs/index.html +147 -0
  7. docs/style.css +78 -0
  8. docs/style_videos.css +52 -0
  9. docs/video_gen.html +254 -0
  10. docs/video_main.html +98 -0
  11. docs/video_vgg.html +452 -0
  12. mmaudio/__init__.py +0 -0
  13. mmaudio/eval_utils.py +245 -0
  14. mmaudio/ext/__init__.py +1 -0
  15. mmaudio/ext/autoencoder/__init__.py +1 -0
  16. mmaudio/ext/autoencoder/autoencoder.py +48 -0
  17. mmaudio/ext/autoencoder/edm2_utils.py +168 -0
  18. mmaudio/ext/autoencoder/vae.py +369 -0
  19. mmaudio/ext/autoencoder/vae_modules.py +117 -0
  20. mmaudio/ext/bigvgan/LICENSE +21 -0
  21. mmaudio/ext/bigvgan/__init__.py +1 -0
  22. mmaudio/ext/bigvgan/activations.py +120 -0
  23. mmaudio/ext/bigvgan/alias_free_torch/__init__.py +6 -0
  24. mmaudio/ext/bigvgan/alias_free_torch/act.py +28 -0
  25. mmaudio/ext/bigvgan/alias_free_torch/filter.py +95 -0
  26. mmaudio/ext/bigvgan/alias_free_torch/resample.py +49 -0
  27. mmaudio/ext/bigvgan/bigvgan.py +32 -0
  28. mmaudio/ext/bigvgan/bigvgan_vocoder.yml +63 -0
  29. mmaudio/ext/bigvgan/env.py +18 -0
  30. mmaudio/ext/bigvgan/incl_licenses/LICENSE_1 +21 -0
  31. mmaudio/ext/bigvgan/incl_licenses/LICENSE_2 +21 -0
  32. mmaudio/ext/bigvgan/incl_licenses/LICENSE_3 +201 -0
  33. mmaudio/ext/bigvgan/incl_licenses/LICENSE_4 +29 -0
  34. mmaudio/ext/bigvgan/incl_licenses/LICENSE_5 +16 -0
  35. mmaudio/ext/bigvgan/models.py +255 -0
  36. mmaudio/ext/bigvgan/utils.py +31 -0
  37. mmaudio/ext/bigvgan_v2/LICENSE +21 -0
  38. mmaudio/ext/bigvgan_v2/__init__.py +0 -0
  39. mmaudio/ext/bigvgan_v2/activations.py +126 -0
  40. mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/__init__.py +0 -0
  41. mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/activation1d.py +77 -0
  42. mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  43. mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  44. mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/compat.h +29 -0
  45. mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/load.py +86 -0
  46. mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/type_shim.h +92 -0
  47. mmaudio/ext/bigvgan_v2/alias_free_activation/torch/__init__.py +6 -0
  48. mmaudio/ext/bigvgan_v2/alias_free_activation/torch/act.py +32 -0
  49. mmaudio/ext/bigvgan_v2/alias_free_activation/torch/filter.py +101 -0
  50. mmaudio/ext/bigvgan_v2/alias_free_activation/torch/resample.py +54 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Ho Kei Cheng
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,152 @@
1
- ---
2
- title: MMAudio
3
- emoji: 🐨
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.8.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Generating synchronizated audio given video/text inputs.
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis](https://hkchengrex.github.io/MMAudio)
2
+
3
+ [Ho Kei Cheng](https://hkchengrex.github.io/), [Masato Ishii](https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ), [Akio Hayakawa](https://scholar.google.com/citations?user=sXAjHFIAAAAJ), [Takashi Shibuya](https://scholar.google.com/citations?user=XCRO260AAAAJ), [Alexander Schwing](https://www.alexander-schwing.de/), [Yuki Mitsufuji](https://www.yukimitsufuji.com/)
4
+
5
+ University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
6
+
7
+
8
+ [[Paper (being prepared)]](https://hkchengrex.github.io/MMAudio) [[Project Page]](https://hkchengrex.github.io/MMAudio)
9
+
10
+
11
+ **Note: This repository is still under construction. Single-example inference should work as expected. The training code will be added. Code is subject to non-backward-compatible changes.**
12
+
13
+ ## Highlight
14
+
15
+ MMAudio generates synchronized audio given video and/or text inputs.
16
+ Our key innovation is multimodal joint training which allows training on a wide range of audio-visual and audio-text datasets.
17
+ Moreover, a synchronization module aligns the generated audio with the video frames.
18
+
19
+
20
+ ## Results
21
+
22
+ (All audio from our algorithm MMAudio)
23
+
24
+ Videos from Sora:
25
+
26
+ https://github.com/user-attachments/assets/82afd192-0cee-48a1-86ca-bd39b8c8f330
27
+
28
+
29
+ Videos from MovieGen/Hunyuan Video/VGGSound:
30
+
31
+ https://github.com/user-attachments/assets/29230d4e-21c1-4cf8-a221-c28f2af6d0ca
32
+
33
+ For more results, visit https://hkchengrex.com/MMAudio/video_main.html.
34
+
35
+ ## Installation
36
+
37
+ We have only tested this on Ubuntu.
38
+
39
+ ### Prerequisites
40
+
41
+ We recommend using a [miniforge](https://github.com/conda-forge/miniforge) environment.
42
+
43
+ - Python 3.8+
44
+ - PyTorch **2.5.1+** and corresponding torchvision/torchaudio (pick your CUDA version https://pytorch.org/)
45
+ - ffmpeg<7 ([this is required by torchaudio](https://pytorch.org/audio/master/installation.html#optional-dependencies), you can install it in a miniforge environment with `conda install -c conda-forge 'ffmpeg<7'`)
46
+
47
+ **Clone our repository:**
48
+
49
+ ```bash
50
+ git clone https://github.com/hkchengrex/MMAudio.git
51
+ ```
52
+
53
+ **Install with pip:**
54
+
55
+ ```bash
56
+ cd MMAudio
57
+ pip install -e .
58
+ ```
59
+
60
+ (If you encounter the File "setup.py" not found error, upgrade your pip with pip install --upgrade pip)
61
+
62
+ **Pretrained models:**
63
+
64
+ The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`
65
+
66
+ | Model | Download link | File size |
67
+ | -------- | ------- | ------- |
68
+ | Flow prediction network, small 16kHz | <a href="https://databank.illinois.edu/datafiles/k6jve/download" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
69
+ | Flow prediction network, small 44.1kHz | <a href="https://databank.illinois.edu/datafiles/864ya/download" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
70
+ | Flow prediction network, medium 44.1kHz | <a href="https://databank.illinois.edu/datafiles/pa94t/download" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
71
+ | Flow prediction network, large 44.1kHz **(recommended)** | <a href="https://databank.illinois.edu/datafiles/4jx76/download" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
72
+ | 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
73
+ | 16kHz BigVGAN vocoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
74
+ | 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
75
+ | Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
76
+
77
+ The 44.1kHz vocoder will be downloaded automatically.
78
+
79
+ The expected directory structure (full):
80
+
81
+ ```bash
82
+ MMAudio
83
+ ├── ext_weights
84
+ │ ├── best_netG.pt
85
+ │ ├── synchformer_state_dict.pth
86
+ │ ├── v1-16.pth
87
+ │ └── v1-44.pth
88
+ ├── weights
89
+ │ ├── mmaudio_small_16k.pth
90
+ │ ├── mmaudio_small_44k.pth
91
+ │ ├── mmaudio_medium_44k.pth
92
+ │ └── mmaudio_large_44k.pth
93
+ └── ...
94
+ ```
95
+
96
+ The expected directory structure (minimal, for the recommended model only):
97
+
98
+ ```bash
99
+ MMAudio
100
+ ├── ext_weights
101
+ │ ├── synchformer_state_dict.pth
102
+ │ └── v1-44.pth
103
+ ├── weights
104
+ │ └── mmaudio_large_44k.pth
105
+ └── ...
106
+ ```
107
+
108
+ ## Demo
109
+
110
+ By default, these scripts use the `large_44k` model.
111
+ In our experiments, inference only takes around 6GB of GPU memory (in 16-bit mode) which should fit in most modern GPUs.
112
+
113
+ ### Command-line interface
114
+
115
+ With `demo.py`
116
+ ```bash
117
+ python demo.py --duration=8 --video=<path to video> --prompt "your prompt"
118
+ ```
119
+ The output (audio in `.flac` format, and video in `.mp4` format) will be saved in `./output`.
120
+ See the file for more options.
121
+ Simply omit the `--video` option for text-to-audio synthesis.
122
+ The default output (and training) duration is 8 seconds. Longer/shorter durations could also work, but a large deviation from the training duration may result in a lower quality.
123
+
124
+
125
+ ### Gradio interface
126
+
127
+ Supports video-to-audio and text-to-audio synthesis.
128
+
129
+ ```
130
+ python gradio_demo.py
131
+ ```
132
+
133
+ ### Known limitations
134
+
135
+ 1. The model sometimes generates undesired unintelligible human speech-like sounds
136
+ 2. The model sometimes generates undesired background music
137
+ 3. The model struggles with unfamiliar concepts, e.g., it can generate "gunfires" but not "RPG firing".
138
+
139
+ We believe all of these three limitations can be addressed with more high-quality training data.
140
+
141
+ ## Training
142
+ Work in progress.
143
+
144
+ ## Evaluation
145
+ Work in progress.
146
+
147
+ ## Acknowledgement
148
+ Many thanks to:
149
+ - [Make-An-Audio 2](https://github.com/bytedance/Make-An-Audio-2) for the 16kHz BigVGAN pretrained model
150
+ - [BigVGAN](https://github.com/NVIDIA/BigVGAN)
151
+ - [Synchformer](https://github.com/v-iashin/Synchformer)
152
+
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ import torch
7
+ import torchaudio
8
+
9
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
10
+ setup_eval_logging)
11
+ from mmaudio.model.flow_matching import FlowMatching
12
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
13
+ from mmaudio.model.sequence_config import SequenceConfig
14
+ from mmaudio.model.utils.features_utils import FeaturesUtils
15
+
16
+ torch.backends.cuda.matmul.allow_tf32 = True
17
+ torch.backends.cudnn.allow_tf32 = True
18
+
19
+ log = logging.getLogger()
20
+
21
+ device = 'cuda'
22
+ dtype = torch.bfloat16
23
+
24
+ model: ModelConfig = all_model_cfg['large_44k_v2']
25
+ model.download_if_needed()
26
+ output_dir = Path('./output/gradio')
27
+
28
+ setup_eval_logging()
29
+
30
+
31
+ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
32
+ seq_cfg = model.seq_cfg
33
+
34
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
35
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
36
+ log.info(f'Loaded weights from {model.model_path}')
37
+
38
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
39
+ synchformer_ckpt=model.synchformer_ckpt,
40
+ enable_conditions=True,
41
+ mode=model.mode,
42
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path)
43
+ feature_utils = feature_utils.to(device, dtype).eval()
44
+
45
+ return net, feature_utils, seq_cfg
46
+
47
+
48
+ net, feature_utils, seq_cfg = get_model()
49
+
50
+
51
+ @torch.inference_mode()
52
+ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
53
+ cfg_strength: float, duration: float):
54
+
55
+ rng = torch.Generator(device=device)
56
+ rng.manual_seed(seed)
57
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
58
+
59
+ clip_frames, sync_frames, duration = load_video(video, duration)
60
+ clip_frames = clip_frames.unsqueeze(0)
61
+ sync_frames = sync_frames.unsqueeze(0)
62
+ seq_cfg.duration = duration
63
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
64
+
65
+ audios = generate(clip_frames,
66
+ sync_frames, [prompt],
67
+ negative_text=[negative_prompt],
68
+ feature_utils=feature_utils,
69
+ net=net,
70
+ fm=fm,
71
+ rng=rng,
72
+ cfg_strength=cfg_strength)
73
+ audio = audios.float().cpu()[0]
74
+
75
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
76
+ output_dir.mkdir(exist_ok=True, parents=True)
77
+ video_save_path = output_dir / f'{current_time_string}.mp4'
78
+ make_video(video,
79
+ video_save_path,
80
+ audio,
81
+ sampling_rate=seq_cfg.sampling_rate,
82
+ duration_sec=seq_cfg.duration)
83
+ return video_save_path
84
+
85
+
86
+ @torch.inference_mode()
87
+ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
88
+ duration: float):
89
+
90
+ rng = torch.Generator(device=device)
91
+ rng.manual_seed(seed)
92
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
93
+
94
+ clip_frames = sync_frames = None
95
+ seq_cfg.duration = duration
96
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
97
+
98
+ audios = generate(clip_frames,
99
+ sync_frames, [prompt],
100
+ negative_text=[negative_prompt],
101
+ feature_utils=feature_utils,
102
+ net=net,
103
+ fm=fm,
104
+ rng=rng,
105
+ cfg_strength=cfg_strength)
106
+ audio = audios.float().cpu()[0]
107
+
108
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
109
+ output_dir.mkdir(exist_ok=True, parents=True)
110
+ audio_save_path = output_dir / f'{current_time_string}.flac'
111
+ torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
112
+ return audio_save_path
113
+
114
+
115
+ video_to_audio_tab = gr.Interface(
116
+ fn=video_to_audio,
117
+ inputs=[
118
+ gr.Video(),
119
+ gr.Text(label='Prompt'),
120
+ gr.Text(label='Negative prompt', value='music'),
121
+ gr.Number(label='Seed', value=0, precision=0, minimum=0),
122
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
123
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
124
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
125
+ ],
126
+ outputs='playable_video',
127
+ cache_examples=False,
128
+ title='MMAudio — Video-to-Audio Synthesis',
129
+ )
130
+
131
+ text_to_audio_tab = gr.Interface(
132
+ fn=text_to_audio,
133
+ inputs=[
134
+ gr.Text(label='Prompt'),
135
+ gr.Text(label='Negative prompt'),
136
+ gr.Number(label='Seed', value=0, precision=0, minimum=0),
137
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
138
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
139
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
140
+ ],
141
+ outputs='audio',
142
+ cache_examples=False,
143
+ title='MMAudio — Text-to-Audio Synthesis',
144
+ )
145
+
146
+ if __name__ == "__main__":
147
+ gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab],
148
+ ['Video-to-Audio', 'Text-to-Audio']).launch(server_port=17888,
149
+ allowed_paths=[output_dir])
demo.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torchaudio
7
+
8
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate,
9
+ load_video, make_video, setup_eval_logging)
10
+ from mmaudio.model.flow_matching import FlowMatching
11
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
12
+ from mmaudio.model.utils.features_utils import FeaturesUtils
13
+
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.backends.cudnn.allow_tf32 = True
16
+
17
+ log = logging.getLogger()
18
+
19
+
20
+ @torch.inference_mode()
21
+ def main():
22
+ setup_eval_logging()
23
+
24
+ parser = ArgumentParser()
25
+ parser.add_argument('--variant',
26
+ type=str,
27
+ default='large_44k_v2',
28
+ help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
29
+ parser.add_argument('--video', type=Path, help='Path to the video file')
30
+ parser.add_argument('--prompt', type=str, help='Input prompt', default='')
31
+ parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
32
+ parser.add_argument('--duration', type=float, default=8.0)
33
+ parser.add_argument('--cfg_strength', type=float, default=4.5)
34
+ parser.add_argument('--num_steps', type=int, default=25)
35
+
36
+ parser.add_argument('--mask_away_clip', action='store_true')
37
+
38
+ parser.add_argument('--output', type=Path, help='Output directory', default='./output')
39
+ parser.add_argument('--seed', type=int, help='Random seed', default=42)
40
+ parser.add_argument('--skip_video_composite', action='store_true')
41
+ parser.add_argument('--full_precision', action='store_true')
42
+
43
+ args = parser.parse_args()
44
+
45
+ if args.variant not in all_model_cfg:
46
+ raise ValueError(f'Unknown model variant: {args.variant}')
47
+ model: ModelConfig = all_model_cfg[args.variant]
48
+ model.download_if_needed()
49
+ seq_cfg = model.seq_cfg
50
+
51
+ if args.video:
52
+ video_path: Path = Path(args.video).expanduser()
53
+ else:
54
+ video_path = None
55
+ prompt: str = args.prompt
56
+ negative_prompt: str = args.negative_prompt
57
+ output_dir: str = args.output.expanduser()
58
+ seed: int = args.seed
59
+ num_steps: int = args.num_steps
60
+ duration: float = args.duration
61
+ cfg_strength: float = args.cfg_strength
62
+ skip_video_composite: bool = args.skip_video_composite
63
+ mask_away_clip: bool = args.mask_away_clip
64
+
65
+ device = 'cuda'
66
+ dtype = torch.float32 if args.full_precision else torch.bfloat16
67
+
68
+ output_dir.mkdir(parents=True, exist_ok=True)
69
+
70
+ # load a pretrained model
71
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
72
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
73
+ log.info(f'Loaded weights from {model.model_path}')
74
+
75
+ # misc setup
76
+ rng = torch.Generator(device=device)
77
+ rng.manual_seed(seed)
78
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
79
+
80
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
81
+ synchformer_ckpt=model.synchformer_ckpt,
82
+ enable_conditions=True,
83
+ mode=model.mode,
84
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path)
85
+ feature_utils = feature_utils.to(device, dtype).eval()
86
+
87
+ if video_path is not None:
88
+ log.info(f'Using video {video_path}')
89
+ clip_frames, sync_frames, duration = load_video(video_path, duration)
90
+ if mask_away_clip:
91
+ clip_frames = None
92
+ else:
93
+ clip_frames = clip_frames.unsqueeze(0)
94
+ sync_frames = sync_frames.unsqueeze(0)
95
+ else:
96
+ log.info('No video provided -- text-to-audio mode')
97
+ clip_frames = sync_frames = None
98
+
99
+ seq_cfg.duration = duration
100
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
101
+
102
+ log.info(f'Prompt: {prompt}')
103
+ log.info(f'Negative prompt: {negative_prompt}')
104
+
105
+ audios = generate(clip_frames,
106
+ sync_frames, [prompt],
107
+ negative_text=[negative_prompt],
108
+ feature_utils=feature_utils,
109
+ net=net,
110
+ fm=fm,
111
+ rng=rng,
112
+ cfg_strength=cfg_strength)
113
+ audio = audios.float().cpu()[0]
114
+ if video_path is not None:
115
+ save_path = output_dir / f'{video_path.stem}.flac'
116
+ else:
117
+ safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
118
+ save_path = output_dir / f'{safe_filename}.flac'
119
+ torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
120
+
121
+ log.info(f'Audio saved to {save_path}')
122
+ if video_path is not None and not skip_video_composite:
123
+ video_save_path = output_dir / f'{video_path.stem}.mp4'
124
+ make_video(video_path,
125
+ video_save_path,
126
+ audio,
127
+ sampling_rate=seq_cfg.sampling_rate,
128
+ duration_sec=seq_cfg.duration)
129
+ log.info(f'Video saved to {output_dir / video_save_path}')
130
+
131
+ log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
132
+
133
+
134
+ if __name__ == '__main__':
135
+ main()
docs/images/icon.png ADDED
docs/index.html ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link rel="preconnect" href="https://fonts.googleapis.com">
14
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
15
+ <link href="https://fonts.googleapis.com/css2?family=Source+Sans+3&display=swap" rel="stylesheet">
16
+ <meta charset="UTF-8">
17
+ <title>MMAudio</title>
18
+
19
+ <link rel="icon" type="image/png" href="images/icon.png">
20
+
21
+ <meta name="viewport" content="width=device-width, initial-scale=1">
22
+ <!-- CSS only -->
23
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/css/bootstrap.min.css" rel="stylesheet"
24
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
25
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
26
+
27
+ <link rel="stylesheet" href="style.css">
28
+ </head>
29
+ <body>
30
+
31
+ <body>
32
+ <br><br><br><br>
33
+ <div class="container">
34
+ <div class="row text-center" style="font-size:38px">
35
+ <div class="col strong">
36
+ Taming Multimodal Joint Training for High-Quality <br>Video-to-Audio Synthesis
37
+ </div>
38
+ </div>
39
+
40
+ <br>
41
+ <div class="row text-center" style="font-size:28px">
42
+ <div class="col">
43
+ arXiv 2024
44
+ </div>
45
+ </div>
46
+ <br>
47
+
48
+ <div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
49
+ <div class="col-sm-auto px-lg-2">
50
+ <a href="https://hkchengrex.github.io/">Ho Kei Cheng<sup>1</sup></a>
51
+ </div>
52
+ <div class="col-sm-auto px-lg-2">
53
+ <nobr><a href="https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ">Masato Ishii<sup>2</sup></a></nobr>
54
+ </div>
55
+ <div class="col-sm-auto px-lg-2">
56
+ <nobr><a href="https://scholar.google.com/citations?user=sXAjHFIAAAAJ">Akio Hayakawa<sup>2</sup></a></nobr>
57
+ </div>
58
+ <div class="col-sm-auto px-lg-2">
59
+ <nobr><a href="https://scholar.google.com/citations?user=XCRO260AAAAJ">Takashi Shibuya<sup>2</sup></a></nobr>
60
+ </div>
61
+ <div class="col-sm-auto px-lg-2">
62
+ <nobr><a href="https://www.alexander-schwing.de/">Alexander Schwing<sup>1</sup></a></nobr>
63
+ </div>
64
+ <div class="col-sm-auto px-lg-2" >
65
+ <nobr><a href="https://www.yukimitsufuji.com/">Yuki Mitsufuji<sup>2,3</sup></a></nobr>
66
+ </div>
67
+ </div>
68
+
69
+ <div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
70
+ <div class="col-sm-auto px-lg-2">
71
+ <sup>1</sup>University of Illinois Urbana-Champaign
72
+ </div>
73
+ <div class="col-sm-auto px-lg-2">
74
+ <sup>2</sup>Sony AI
75
+ </div>
76
+ <div class="col-sm-auto px-lg-2">
77
+ <sup>3</sup>Sony Group Corporation
78
+ </div>
79
+ </div>
80
+
81
+ <br>
82
+
83
+ <br>
84
+
85
+ <div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
86
+ <!-- <div class="col-sm-2">
87
+ <a href="https://arxiv.org/abs/2310.12982">[arXiv]</a>
88
+ </div> -->
89
+ <div class="col-sm-3">
90
+ <a href="">[Paper (being prepared)]</a>
91
+ </div>
92
+ <div class="col-sm-3">
93
+ <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
94
+ </div>
95
+ <!-- <div class="col-sm-2">
96
+ <a
97
+ href="https://colab.research.google.com/drive/1yo43XTbjxuWA7XgCUO9qxAi7wBI6HzvP?usp=sharing">[Colab]</a>
98
+ </div> -->
99
+ </div>
100
+
101
+ <br>
102
+
103
+ <hr>
104
+
105
+ <div class="row" style="font-size:32px">
106
+ <div class="col strong">
107
+ TL;DR
108
+ </div>
109
+ </div>
110
+ <br>
111
+ <div class="row">
112
+ <div class="col">
113
+ <p class="light" style="text-align: left;">
114
+ MMAudio generates synchronized audio given video and/or text inputs.
115
+ </p>
116
+ </div>
117
+ </div>
118
+
119
+ <br>
120
+ <hr>
121
+ <br>
122
+
123
+ <div class="row" style="font-size:32px">
124
+ <div class="col strong">
125
+ Demo
126
+ </div>
127
+ </div>
128
+ <br>
129
+ <div class="row" style="font-size:48px">
130
+ <div class="col strong text-center">
131
+ <a href="video_main.html" style="text-decoration: underline;">&lt;More results&gt;</a>
132
+ </div>
133
+ </div>
134
+ <br>
135
+ <div class="video-container" style="text-align: center;">
136
+ <iframe src="https://youtube.com/embed/YElewUT2M4M"></iframe>
137
+ </div>
138
+
139
+ <br>
140
+
141
+ <br><br>
142
+ <br><br>
143
+
144
+ </div>
145
+
146
+ </body>
147
+ </html>
docs/style.css ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ font-family: 'Source Sans 3', sans-serif;
3
+ font-size: 18px;
4
+ margin-left: auto;
5
+ margin-right: auto;
6
+ font-weight: 400;
7
+ height: 100%;
8
+ max-width: 1000px;
9
+ }
10
+
11
+ table {
12
+ width: 100%;
13
+ border-collapse: collapse;
14
+ }
15
+ th, td {
16
+ border: 1px solid #ddd;
17
+ padding: 8px;
18
+ text-align: center;
19
+ }
20
+ th {
21
+ background-color: #f2f2f2;
22
+ }
23
+ video {
24
+ width: 100%;
25
+ height: auto;
26
+ }
27
+ p {
28
+ font-size: 28px;
29
+ }
30
+ h2 {
31
+ font-size: 36px;
32
+ }
33
+
34
+ .strong {
35
+ font-weight: 700;
36
+ }
37
+
38
+ .light {
39
+ font-weight: 100;
40
+ }
41
+
42
+ .heavy {
43
+ font-weight: 900;
44
+ }
45
+
46
+ .column {
47
+ float: left;
48
+ }
49
+
50
+ a:link,
51
+ a:visited {
52
+ color: #05538f;
53
+ text-decoration: none;
54
+ }
55
+
56
+ a:hover {
57
+ color: #63cbdd;
58
+ }
59
+
60
+ hr {
61
+ border: 0;
62
+ height: 1px;
63
+ background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
64
+ }
65
+
66
+ .video-container {
67
+ position: relative;
68
+ padding-bottom: 56.25%; /* 16:9 */
69
+ height: 0;
70
+ }
71
+
72
+ .video-container iframe {
73
+ position: absolute;
74
+ top: 0;
75
+ left: 0;
76
+ width: 100%;
77
+ height: 100%;
78
+ }
docs/style_videos.css ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ font-family: 'Source Sans 3', sans-serif;
3
+ font-size: 1.5vh;
4
+ font-weight: 400;
5
+ }
6
+
7
+ table {
8
+ width: 100%;
9
+ border-collapse: collapse;
10
+ }
11
+ th, td {
12
+ border: 1px solid #ddd;
13
+ padding: 8px;
14
+ text-align: center;
15
+ }
16
+ th {
17
+ background-color: #f2f2f2;
18
+ }
19
+ video {
20
+ width: 100%;
21
+ height: auto;
22
+ }
23
+ p {
24
+ font-size: 1.5vh;
25
+ font-weight: bold;
26
+ }
27
+ h2 {
28
+ font-size: 2vh;
29
+ font-weight: bold;
30
+ }
31
+
32
+ .video-container {
33
+ position: relative;
34
+ padding-bottom: 56.25%; /* 16:9 */
35
+ height: 0;
36
+ }
37
+
38
+ .video-container iframe {
39
+ position: absolute;
40
+ top: 0;
41
+ left: 0;
42
+ width: 100%;
43
+ height: 100%;
44
+ }
45
+
46
+ .video-header {
47
+ background-color: #f2f2f2;
48
+ text-align: center;
49
+ font-size: 1.5vh;
50
+ font-weight: bold;
51
+ padding: 8px;
52
+ }
docs/video_gen.html ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <link rel="icon" type="image/png" href="images/icon.png">
18
+
19
+ <meta name="viewport" content="width=device-width, initial-scale=1">
20
+ <!-- CSS only -->
21
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/css/bootstrap.min.css" rel="stylesheet"
22
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
23
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>
24
+
25
+ <link rel="stylesheet" href="style_videos.css">
26
+ </head>
27
+ <body>
28
+
29
+ <div id="moviegen_all">
30
+ <h2 id="moviegen" style="text-align: center;">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</h2>
31
+ <p id="moviegen1" style="overflow: hidden;">
32
+ Example 1: Ice cracking with sharp snapping sound, and metal tool scraping against the ice surface.
33
+ <span style="float: right;"><a href="#index">Back to index</a></span>
34
+ </p>
35
+
36
+ <div class="row g-1">
37
+ <div class="col-sm-6">
38
+ <div class="video-header">Movie Gen Audio</div>
39
+ <div class="video-container">
40
+ <iframe src="https://youtube.com/embed/d7Lb0ihtGcE"></iframe>
41
+ </div>
42
+ </div>
43
+ <div class="col-sm-6">
44
+ <div class="video-header">Ours</div>
45
+ <div class="video-container">
46
+ <iframe src="https://youtube.com/embed/F4JoJ2r2m8U"></iframe>
47
+ </div>
48
+ </div>
49
+ </div>
50
+ <br>
51
+
52
+ <!-- <p id="moviegen2">Example 2: Rhythmic splashing and lapping of water. <span style="float:right;"><a href="#index">Back to index</a></span> </p>
53
+
54
+ <table>
55
+ <thead>
56
+ <tr>
57
+ <th>Movie Gen Audio</th>
58
+ <th>Ours</th>
59
+ </tr>
60
+ </thead>
61
+ <tbody>
62
+ <tr>
63
+ <td width="50%">
64
+ <div class="video-container">
65
+ <iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
66
+ </div>
67
+ </td>
68
+ <td width="50%">
69
+ <div class="video-container">
70
+ <iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
71
+ </div>
72
+ </td>
73
+ </tr>
74
+ </tbody>
75
+ </table> -->
76
+
77
+ <p id="moviegen2" style="overflow: hidden;">
78
+ Example 2: Rhythmic splashing and lapping of water.
79
+ <span style="float:right;"><a href="#index">Back to index</a></span>
80
+ </p>
81
+ <div class="row g-1">
82
+ <div class="col-sm-6">
83
+ <div class="video-header">Movie Gen Audio</div>
84
+ <div class="video-container">
85
+ <iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
86
+ </div>
87
+ </div>
88
+ <div class="col-sm-6">
89
+ <div class="video-header">Ours</div>
90
+ <div class="video-container">
91
+ <iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
92
+ </div>
93
+ </div>
94
+ </div>
95
+ <br>
96
+
97
+ <p id="moviegen3" style="overflow: hidden;">
98
+ Example 3: Shovel scrapes against dry earth.
99
+ <span style="float:right;"><a href="#index">Back to index</a></span>
100
+ </p>
101
+ <div class="row g-1">
102
+ <div class="col-sm-6">
103
+ <div class="video-header">Movie Gen Audio</div>
104
+ <div class="video-container">
105
+ <iframe src="https://youtube.com/embed/PUKGyEve7XQ"></iframe>
106
+ </div>
107
+ </div>
108
+ <div class="col-sm-6">
109
+ <div class="video-header">Ours</div>
110
+ <div class="video-container">
111
+ <iframe src="https://youtube.com/embed/CNn7i8VNkdc"></iframe>
112
+ </div>
113
+ </div>
114
+ </div>
115
+ <br>
116
+
117
+
118
+ <p id="moviegen4" style="overflow: hidden;">
119
+ (Failure case) Example 4: Creamy sound of mashed potatoes being scooped.
120
+ <span style="float:right;"><a href="#index">Back to index</a></span>
121
+ </p>
122
+ <div class="row g-1">
123
+ <div class="col-sm-6">
124
+ <div class="video-header">Movie Gen Audio</div>
125
+ <div class="video-container">
126
+ <iframe src="https://youtube.com/embed/PJv1zxR9JjQ"></iframe>
127
+ </div>
128
+ </div>
129
+ <div class="col-sm-6">
130
+ <div class="video-header">Ours</div>
131
+ <div class="video-container">
132
+ <iframe src="https://youtube.com/embed/c3-LJ1lNsPQ"></iframe>
133
+ </div>
134
+ </div>
135
+ </div>
136
+ <br>
137
+
138
+ </div>
139
+
140
+ <div id="hunyuan_sora_all">
141
+
142
+ <h2 id="hunyuan" style="text-align: center;">Results on Videos Generated by Hunyuan</h2>
143
+ <p style="overflow: hidden;">
144
+ <span style="float:right;"><a href="#index">Back to index</a></span>
145
+ </p>
146
+ <div class="row g-1">
147
+ <div class="col-sm-6">
148
+ <div class="video-header">Typing</div>
149
+ <div class="video-container">
150
+ <iframe src="https://youtube.com/embed/8ln_9hhH_nk"></iframe>
151
+ </div>
152
+ </div>
153
+ <div class="col-sm-6">
154
+ <div class="video-header">Water is rushing down a stream and pouring</div>
155
+ <div class="video-container">
156
+ <iframe src="https://youtube.com/embed/5df1FZFQj30"></iframe>
157
+ </div>
158
+ </div>
159
+ </div>
160
+ <div class="row g-1">
161
+ <div class="col-sm-6">
162
+ <div class="video-header">Waves on beach</div>
163
+ <div class="video-container">
164
+ <iframe src="https://youtube.com/embed/7wQ9D5WgpFc"></iframe>
165
+ </div>
166
+ </div>
167
+ <div class="col-sm-6">
168
+ <div class="video-header">Water droplet</div>
169
+ <div class="video-container">
170
+ <iframe src="https://youtube.com/embed/q7M2nsalGjM"></iframe>
171
+ </div>
172
+ </div>
173
+ </div>
174
+ <br>
175
+
176
+ <h2 id="sora" style="text-align: center;">Results on Videos Generated by Sora</h2>
177
+ <p style="overflow: hidden;">
178
+ <span style="float:right;"><a href="#index">Back to index</a></span>
179
+ </p>
180
+ <div class="row g-1">
181
+ <div class="col-sm-6">
182
+ <div class="video-header">Ships riding waves</div>
183
+ <div class="video-container">
184
+ <iframe src="https://youtube.com/embed/JbgQzHHytk8"></iframe>
185
+ </div>
186
+ </div>
187
+ <div class="col-sm-6">
188
+ <div class="video-header">Train (no text prompt given)</div>
189
+ <div class="video-container">
190
+ <iframe src="https://youtube.com/embed/xOW7zrjpWC8"></iframe>
191
+ </div>
192
+ </div>
193
+ </div>
194
+ <div class="row g-1">
195
+ <div class="col-sm-6">
196
+ <div class="video-header">Seashore (no text prompt given)</div>
197
+ <div class="video-container">
198
+ <iframe src="https://youtube.com/embed/fIuw5Y8ZZ9E"></iframe>
199
+ </div>
200
+ </div>
201
+ <div class="col-sm-6">
202
+ <div class="video-header">Surfing (failure: unprompted music)</div>
203
+ <div class="video-container">
204
+ <iframe src="https://youtube.com/embed/UcSTk-v0M_s"></iframe>
205
+ </div>
206
+ </div>
207
+ </div>
208
+ <br>
209
+
210
+ <div id="mochi_ltx_all">
211
+ <h2 id="mochi" style="text-align: center;">Results on Videos Generated by Mochi 1</h2>
212
+ <p style="overflow: hidden;">
213
+ <span style="float:right;"><a href="#index">Back to index</a></span>
214
+ </p>
215
+ <div class="row g-1">
216
+ <div class="col-sm-6">
217
+ <div class="video-header">Magical fire and lightning (no text prompt given)</div>
218
+ <div class="video-container">
219
+ <iframe src="https://youtube.com/embed/tTlRZaSMNwY"></iframe>
220
+ </div>
221
+ </div>
222
+ <div class="col-sm-6">
223
+ <div class="video-header">Storm (no text prompt given)</div>
224
+ <div class="video-container">
225
+ <iframe src="https://youtube.com/embed/4hrZTMJUy3w"></iframe>
226
+ </div>
227
+ </div>
228
+ </div>
229
+ <br>
230
+
231
+ <h2 id="ltx" style="text-align: center;">Results on Videos Generated by LTX-Video</h2>
232
+ <p style="overflow: hidden;">
233
+ <span style="float:right;"><a href="#index">Back to index</a></span>
234
+ </p>
235
+ <div class="row g-1">
236
+ <div class="col-sm-6">
237
+ <div class="video-header">Firewood burning and cracking</div>
238
+ <div class="video-container">
239
+ <iframe src="https://youtube.com/embed/P7_DDpgev0g"></iframe>
240
+ </div>
241
+ </div>
242
+ <div class="col-sm-6">
243
+ <div class="video-header">Waterfall, water splashing</div>
244
+ <div class="video-container">
245
+ <iframe src="https://youtube.com/embed/4MvjceYnIO0"></iframe>
246
+ </div>
247
+ </div>
248
+ </div>
249
+ <br>
250
+
251
+ </div>
252
+
253
+ </body>
254
+ </html>
docs/video_main.html ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <link rel="icon" type="image/png" href="images/icon.png">
18
+
19
+ <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
20
+ <!-- CSS only -->
21
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/css/bootstrap.min.css" rel="stylesheet"
22
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
23
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>
24
+
25
+ <link rel="stylesheet" href="style_videos.css">
26
+
27
+ <script type="text/javascript">
28
+ $(document).ready(function(){
29
+ $("#content").load("video_gen.html #moviegen_all");
30
+ $("#load_moveigen").click(function(){
31
+ $("#content").load("video_gen.html #moviegen_all");
32
+ });
33
+ $("#load_hunyuan_sora").click(function(){
34
+ $("#content").load("video_gen.html #hunyuan_sora_all");
35
+ });
36
+ $("#load_mochi_ltx").click(function(){
37
+ $("#content").load("video_gen.html #mochi_ltx_all");
38
+ });
39
+ $("#load_vgg1").click(function(){
40
+ $("#content").load("video_vgg.html #vgg1");
41
+ });
42
+ $("#load_vgg2").click(function(){
43
+ $("#content").load("video_vgg.html #vgg2");
44
+ });
45
+ $("#load_vgg3").click(function(){
46
+ $("#content").load("video_vgg.html #vgg3");
47
+ });
48
+ $("#load_vgg4").click(function(){
49
+ $("#content").load("video_vgg.html #vgg4");
50
+ });
51
+ $("#load_vgg5").click(function(){
52
+ $("#content").load("video_vgg.html #vgg5");
53
+ });
54
+ $("#load_vgg6").click(function(){
55
+ $("#content").load("video_vgg.html #vgg6");
56
+ });
57
+ $("#load_vgg_extra").click(function(){
58
+ $("#content").load("video_vgg.html #vgg_extra");
59
+ });
60
+ });
61
+ </script>
62
+ </head>
63
+ <body>
64
+ <h1 id="index" style="text-align: center;">Index</h1>
65
+ <p><b>(Click on the links to load the corresponding videos)</b> <span style="float:right;"><a href="index.html">Back to project page</a></span></p>
66
+
67
+ <ol>
68
+ <li>
69
+ <a href="#" id="load_moveigen">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</a>
70
+ </li>
71
+ <li>
72
+ <a href="#" id="load_hunyuan_sora">Results on Videos Generated by Hunyuan and Sora</a>
73
+ </li>
74
+ <li>
75
+ <a href="#" id="load_mochi_ltx">Results on Videos Generated by Mochi 1 and LTX-Video</a>
76
+ </li>
77
+ <li>
78
+ On VGGSound
79
+ <ol>
80
+ <li><a id='load_vgg1' href="#">Example 1: Wolf howling</a></li>
81
+ <li><a id='load_vgg2' href="#">Example 2: Striking a golf ball</a></li>
82
+ <li><a id='load_vgg3' href="#">Example 3: Hitting a drum</a></li>
83
+ <li><a id='load_vgg4' href="#">Example 4: Dog barking</a></li>
84
+ <li><a id='load_vgg5' href="#">Example 5: Playing a string instrument</a></li>
85
+ <li><a id='load_vgg6' href="#">Example 6: A group of people playing tambourines</a></li>
86
+ <li><a id='load_vgg_extra' href="#">Extra results & failure cases</a></li>
87
+ </ol>
88
+ </li>
89
+ </ol>
90
+
91
+ <div id="content" class="container-fluid">
92
+
93
+ </div>
94
+ <br>
95
+ <br>
96
+
97
+ </body>
98
+ </html>
docs/video_vgg.html ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <meta name="viewport" content="width=device-width, initial-scale=1">
18
+ <!-- CSS only -->
19
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/css/bootstrap.min.css" rel="stylesheet"
20
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
21
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
22
+
23
+ <link rel="stylesheet" href="style_videos.css">
24
+ </head>
25
+ <body>
26
+
27
+ <div id="vgg1">
28
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
29
+ <p style="overflow: hidden;">
30
+ Example 1: Wolf howling.
31
+ <span style="float:right;"><a href="#index">Back to index</a></span>
32
+ </p>
33
+ <div class="row g-1">
34
+ <div class="col-sm-3">
35
+ <div class="video-header">Ground-truth</div>
36
+ <div class="video-container">
37
+ <iframe src="https://youtube.com/embed/9J_V74gqMUA"></iframe>
38
+ </div>
39
+ </div>
40
+ <div class="col-sm-3">
41
+ <div class="video-header">Ours</div>
42
+ <div class="video-container">
43
+ <iframe src="https://youtube.com/embed/P6O8IpjErPc"></iframe>
44
+ </div>
45
+ </div>
46
+ <div class="col-sm-3">
47
+ <div class="video-header">V2A-Mapper</div>
48
+ <div class="video-container">
49
+ <iframe src="https://youtube.com/embed/w-5eyqepvTk"></iframe>
50
+ </div>
51
+ </div>
52
+ <div class="col-sm-3">
53
+ <div class="video-header">FoleyCrafter</div>
54
+ <div class="video-container">
55
+ <iframe src="https://youtube.com/embed/VOLfoZlRkzo"></iframe>
56
+ </div>
57
+ </div>
58
+ </div>
59
+ <div class="row g-1">
60
+ <div class="col-sm-3">
61
+ <div class="video-header">Frieren</div>
62
+ <div class="video-container">
63
+ <iframe src="https://youtube.com/embed/49owKyA5Pa8"></iframe>
64
+ </div>
65
+ </div>
66
+ <div class="col-sm-3">
67
+ <div class="video-header">VATT</div>
68
+ <div class="video-container">
69
+ <iframe src="https://youtube.com/embed/QVtrFgbeGDM"></iframe>
70
+ </div>
71
+ </div>
72
+ <div class="col-sm-3">
73
+ <div class="video-header">V-AURA</div>
74
+ <div class="video-container">
75
+ <iframe src="https://youtube.com/embed/8r0uEfSNjvI"></iframe>
76
+ </div>
77
+ </div>
78
+ <div class="col-sm-3">
79
+ <div class="video-header">Seeing and Hearing</div>
80
+ <div class="video-container">
81
+ <iframe src="https://youtube.com/embed/bn-sLg2qulk"></iframe>
82
+ </div>
83
+ </div>
84
+ </div>
85
+ </div>
86
+
87
+ <div id="vgg2">
88
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
89
+ <p style="overflow: hidden;">
90
+ Example 2: Striking a golf ball.
91
+ <span style="float:right;"><a href="#index">Back to index</a></span>
92
+ </p>
93
+
94
+ <div class="row g-1">
95
+ <div class="col-sm-3">
96
+ <div class="video-header">Ground-truth</div>
97
+ <div class="video-container">
98
+ <iframe src="https://youtube.com/embed/1hwSu42kkho"></iframe>
99
+ </div>
100
+ </div>
101
+ <div class="col-sm-3">
102
+ <div class="video-header">Ours</div>
103
+ <div class="video-container">
104
+ <iframe src="https://youtube.com/embed/kZibDoDCNxI"></iframe>
105
+ </div>
106
+ </div>
107
+ <div class="col-sm-3">
108
+ <div class="video-header">V2A-Mapper</div>
109
+ <div class="video-container">
110
+ <iframe src="https://youtube.com/embed/jgKfLBLhh7Y"></iframe>
111
+ </div>
112
+ </div>
113
+ <div class="col-sm-3">
114
+ <div class="video-header">FoleyCrafter</div>
115
+ <div class="video-container">
116
+ <iframe src="https://youtube.com/embed/Lfsx8mOPcJo"></iframe>
117
+ </div>
118
+ </div>
119
+ </div>
120
+ <div class="row g-1">
121
+ <div class="col-sm-3">
122
+ <div class="video-header">Frieren</div>
123
+ <div class="video-container">
124
+ <iframe src="https://youtube.com/embed/tz-LpbB0MBc"></iframe>
125
+ </div>
126
+ </div>
127
+ <div class="col-sm-3">
128
+ <div class="video-header">VATT</div>
129
+ <div class="video-container">
130
+ <iframe src="https://youtube.com/embed/RTDUHMi08n4"></iframe>
131
+ </div>
132
+ </div>
133
+ <div class="col-sm-3">
134
+ <div class="video-header">V-AURA</div>
135
+ <div class="video-container">
136
+ <iframe src="https://youtube.com/embed/N-3TDOsPnZQ"></iframe>
137
+ </div>
138
+ </div>
139
+ <div class="col-sm-3">
140
+ <div class="video-header">Seeing and Hearing</div>
141
+ <div class="video-container">
142
+ <iframe src="https://youtube.com/embed/QnsHnLn4gB0"></iframe>
143
+ </div>
144
+ </div>
145
+ </div>
146
+ </div>
147
+
148
+ <div id="vgg3">
149
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
150
+ <p style="overflow: hidden;">
151
+ Example 3: Hitting a drum.
152
+ <span style="float:right;"><a href="#index">Back to index</a></span>
153
+ </p>
154
+
155
+ <div class="row g-1">
156
+ <div class="col-sm-3">
157
+ <div class="video-header">Ground-truth</div>
158
+ <div class="video-container">
159
+ <iframe src="https://youtube.com/embed/0oeIwq77w0Q"></iframe>
160
+ </div>
161
+ </div>
162
+ <div class="col-sm-3">
163
+ <div class="video-header">Ours</div>
164
+ <div class="video-container">
165
+ <iframe src="https://youtube.com/embed/-UtPV9ohuIM"></iframe>
166
+ </div>
167
+ </div>
168
+ <div class="col-sm-3">
169
+ <div class="video-header">V2A-Mapper</div>
170
+ <div class="video-container">
171
+ <iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
172
+ </div>
173
+ </div>
174
+ <div class="col-sm-3">
175
+ <div class="video-header">FoleyCrafter</div>
176
+ <div class="video-container">
177
+ <iframe src="https://youtube.com/embed/kkCsXPOlBvY"></iframe>
178
+ </div>
179
+ </div>
180
+ </div>
181
+ <div class="row g-1">
182
+ <div class="col-sm-3">
183
+ <div class="video-header">Frieren</div>
184
+ <div class="video-container">
185
+ <iframe src="https://youtube.com/embed/MbNKsVsuvig"></iframe>
186
+ </div>
187
+ </div>
188
+ <div class="col-sm-3">
189
+ <div class="video-header">VATT</div>
190
+ <div class="video-container">
191
+ <iframe src="https://youtube.com/embed/2yYviBjrpBw"></iframe>
192
+ </div>
193
+ </div>
194
+ <div class="col-sm-3">
195
+ <div class="video-header">V-AURA</div>
196
+ <div class="video-container">
197
+ <iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
198
+ </div>
199
+ </div>
200
+ <div class="col-sm-3">
201
+ <div class="video-header">Seeing and Hearing</div>
202
+ <div class="video-container">
203
+ <iframe src="https://youtube.com/embed/6dnyQt4Fuhs"></iframe>
204
+ </div>
205
+ </div>
206
+ </div>
207
+ </div>
208
+ </div>
209
+
210
+ <div id="vgg4">
211
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
212
+ <p style="overflow: hidden;">
213
+ Example 4: Dog barking.
214
+ <span style="float:right;"><a href="#index">Back to index</a></span>
215
+ </p>
216
+
217
+ <div class="row g-1">
218
+ <div class="col-sm-3">
219
+ <div class="video-header">Ground-truth</div>
220
+ <div class="video-container">
221
+ <iframe src="https://youtube.com/embed/ckaqvTyMYAw"></iframe>
222
+ </div>
223
+ </div>
224
+ <div class="col-sm-3">
225
+ <div class="video-header">Ours</div>
226
+ <div class="video-container">
227
+ <iframe src="https://youtube.com/embed/_aRndFZzZ-I"></iframe>
228
+ </div>
229
+ </div>
230
+ <div class="col-sm-3">
231
+ <div class="video-header">V2A-Mapper</div>
232
+ <div class="video-container">
233
+ <iframe src="https://youtube.com/embed/mNCISP3LBl0"></iframe>
234
+ </div>
235
+ </div>
236
+ <div class="col-sm-3">
237
+ <div class="video-header">FoleyCrafter</div>
238
+ <div class="video-container">
239
+ <iframe src="https://youtube.com/embed/phZBQ3L7foE"></iframe>
240
+ </div>
241
+ </div>
242
+ </div>
243
+ <div class="row g-1">
244
+ <div class="col-sm-3">
245
+ <div class="video-header">Frieren</div>
246
+ <div class="video-container">
247
+ <iframe src="https://youtube.com/embed/Sb5Mg1-ORao"></iframe>
248
+ </div>
249
+ </div>
250
+ <div class="col-sm-3">
251
+ <div class="video-header">VATT</div>
252
+ <div class="video-container">
253
+ <iframe src="https://youtube.com/embed/eHmAGOmtDDg"></iframe>
254
+ </div>
255
+ </div>
256
+ <div class="col-sm-3">
257
+ <div class="video-header">V-AURA</div>
258
+ <div class="video-container">
259
+ <iframe src="https://youtube.com/embed/NEGa3krBrm0"></iframe>
260
+ </div>
261
+ </div>
262
+ <div class="col-sm-3">
263
+ <div class="video-header">Seeing and Hearing</div>
264
+ <div class="video-container">
265
+ <iframe src="https://youtube.com/embed/aO0EAXlwE7A"></iframe>
266
+ </div>
267
+ </div>
268
+ </div>
269
+ </div>
270
+
271
+ <div id="vgg5">
272
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
273
+ <p style="overflow: hidden;">
274
+ Example 5: Playing a string instrument.
275
+ <span style="float:right;"><a href="#index">Back to index</a></span>
276
+ </p>
277
+
278
+ <div class="row g-1">
279
+ <div class="col-sm-3">
280
+ <div class="video-header">Ground-truth</div>
281
+ <div class="video-container">
282
+ <iframe src="https://youtube.com/embed/KP1QhWauIOc"></iframe>
283
+ </div>
284
+ </div>
285
+ <div class="col-sm-3">
286
+ <div class="video-header">Ours</div>
287
+ <div class="video-container">
288
+ <iframe src="https://youtube.com/embed/ovaJhWSquYE"></iframe>
289
+ </div>
290
+ </div>
291
+ <div class="col-sm-3">
292
+ <div class="video-header">V2A-Mapper</div>
293
+ <div class="video-container">
294
+ <iframe src="https://youtube.com/embed/N723FS9lcy8"></iframe>
295
+ </div>
296
+ </div>
297
+ <div class="col-sm-3">
298
+ <div class="video-header">FoleyCrafter</div>
299
+ <div class="video-container">
300
+ <iframe src="https://youtube.com/embed/t0N4ZAAXo58"></iframe>
301
+ </div>
302
+ </div>
303
+ </div>
304
+ <div class="row g-1">
305
+ <div class="col-sm-3">
306
+ <div class="video-header">Frieren</div>
307
+ <div class="video-container">
308
+ <iframe src="https://youtube.com/embed/8YSRs03QNNA"></iframe>
309
+ </div>
310
+ </div>
311
+ <div class="col-sm-3">
312
+ <div class="video-header">VATT</div>
313
+ <div class="video-container">
314
+ <iframe src="https://youtube.com/embed/vOpMz55J1kY"></iframe>
315
+ </div>
316
+ </div>
317
+ <div class="col-sm-3">
318
+ <div class="video-header">V-AURA</div>
319
+ <div class="video-container">
320
+ <iframe src="https://youtube.com/embed/9JHC75vr9h0"></iframe>
321
+ </div>
322
+ </div>
323
+ <div class="col-sm-3">
324
+ <div class="video-header">Seeing and Hearing</div>
325
+ <div class="video-container">
326
+ <iframe src="https://youtube.com/embed/9w0JckNzXmY"></iframe>
327
+ </div>
328
+ </div>
329
+ </div>
330
+ </div>
331
+
332
+ <div id="vgg6">
333
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
334
+ <p style="overflow: hidden;">
335
+ Example 6: A group of people playing tambourines.
336
+ <span style="float:right;"><a href="#index">Back to index</a></span>
337
+ </p>
338
+
339
+ <div class="row g-1">
340
+ <div class="col-sm-3">
341
+ <div class="video-header">Ground-truth</div>
342
+ <div class="video-container">
343
+ <iframe src="https://youtube.com/embed/mx6JLxzUkRc"></iframe>
344
+ </div>
345
+ </div>
346
+ <div class="col-sm-3">
347
+ <div class="video-header">Ours</div>
348
+ <div class="video-container">
349
+ <iframe src="https://youtube.com/embed/oLirHhP9Su8"></iframe>
350
+ </div>
351
+ </div>
352
+ <div class="col-sm-3">
353
+ <div class="video-header">V2A-Mapper</div>
354
+ <div class="video-container">
355
+ <iframe src="https://youtube.com/embed/HkLkHMqptv0"></iframe>
356
+ </div>
357
+ </div>
358
+ <div class="col-sm-3">
359
+ <div class="video-header">FoleyCrafter</div>
360
+ <div class="video-container">
361
+ <iframe src="https://youtube.com/embed/rpHiiODjmNU"></iframe>
362
+ </div>
363
+ </div>
364
+ </div>
365
+ <div class="row g-1">
366
+ <div class="col-sm-3">
367
+ <div class="video-header">Frieren</div>
368
+ <div class="video-container">
369
+ <iframe src="https://youtube.com/embed/1mVD3fJ0LpM"></iframe>
370
+ </div>
371
+ </div>
372
+ <div class="col-sm-3">
373
+ <div class="video-header">VATT</div>
374
+ <div class="video-container">
375
+ <iframe src="https://youtube.com/embed/yjVFnJiEJlw"></iframe>
376
+ </div>
377
+ </div>
378
+ <div class="col-sm-3">
379
+ <div class="video-header">V-AURA</div>
380
+ <div class="video-container">
381
+ <iframe src="https://youtube.com/embed/neVeMSWtRkU"></iframe>
382
+ </div>
383
+ </div>
384
+ <div class="col-sm-3">
385
+ <div class="video-header">Seeing and Hearing</div>
386
+ <div class="video-container">
387
+ <iframe src="https://youtube.com/embed/EUE7YwyVWz8"></iframe>
388
+ </div>
389
+ </div>
390
+ </div>
391
+ </div>
392
+
393
+ <div id="vgg_extra">
394
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
395
+ <p style="overflow: hidden;">
396
+ <span style="float:right;"><a href="#index">Back to index</a></span>
397
+ </p>
398
+
399
+ <div class="row g-1">
400
+ <div class="col-sm-3">
401
+ <div class="video-header">Moving train</div>
402
+ <div class="video-container">
403
+ <iframe src="https://youtube.com/embed/Ta6H45rBzJc"></iframe>
404
+ </div>
405
+ </div>
406
+ <div class="col-sm-3">
407
+ <div class="video-header">Water splashing</div>
408
+ <div class="video-container">
409
+ <iframe src="https://youtube.com/embed/hl6AtgHXpb4"></iframe>
410
+ </div>
411
+ </div>
412
+ <div class="col-sm-3">
413
+ <div class="video-header">Skateboarding</div>
414
+ <div class="video-container">
415
+ <iframe src="https://youtube.com/embed/n4sCNi_9buI"></iframe>
416
+ </div>
417
+ </div>
418
+ <div class="col-sm-3">
419
+ <div class="video-header">Synchronized clapping</div>
420
+ <div class="video-container">
421
+ <iframe src="https://youtube.com/embed/oxexfpLn7FE"></iframe>
422
+ </div>
423
+ </div>
424
+ </div>
425
+
426
+ <br><br>
427
+
428
+ <div id="extra-failure">
429
+ <h2 style="text-align: center;">Failure cases</h2>
430
+ <p style="overflow: hidden;">
431
+ <span style="float:right;"><a href="#index">Back to index</a></span>
432
+ </p>
433
+
434
+ <div class="row g-1">
435
+ <div class="col-sm-6">
436
+ <div class="video-header">Human speech</div>
437
+ <div class="video-container">
438
+ <iframe src="https://youtube.com/embed/nx0CyrDu70Y"></iframe>
439
+ </div>
440
+ </div>
441
+ <div class="col-sm-6">
442
+ <div class="video-header">Unfamiliar vision input</div>
443
+ <div class="video-container">
444
+ <iframe src="https://youtube.com/embed/hfnAqmK3X7w"></iframe>
445
+ </div>
446
+ </div>
447
+ </div>
448
+ </div>
449
+ </div>
450
+
451
+ </body>
452
+ </html>
mmaudio/__init__.py ADDED
File without changes
mmaudio/eval_utils.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from colorlog import ColoredFormatter
8
+ from torchvision.transforms import v2
9
+ from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
10
+
11
+ from mmaudio.model.flow_matching import FlowMatching
12
+ from mmaudio.model.networks import MMAudio
13
+ from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
14
+ from mmaudio.model.utils.features_utils import FeaturesUtils
15
+ from mmaudio.utils.download_utils import download_model_if_needed
16
+
17
+ log = logging.getLogger()
18
+
19
+
20
+ @dataclasses.dataclass
21
+ class ModelConfig:
22
+ model_name: str
23
+ model_path: Path
24
+ vae_path: Path
25
+ bigvgan_16k_path: Optional[Path]
26
+ mode: str
27
+ synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')
28
+
29
+ @property
30
+ def seq_cfg(self) -> SequenceConfig:
31
+ if self.mode == '16k':
32
+ return CONFIG_16K
33
+ elif self.mode == '44k':
34
+ return CONFIG_44K
35
+
36
+ def download_if_needed(self):
37
+ download_model_if_needed(self.model_path)
38
+ download_model_if_needed(self.vae_path)
39
+ if self.bigvgan_16k_path is not None:
40
+ download_model_if_needed(self.bigvgan_16k_path)
41
+ download_model_if_needed(self.synchformer_ckpt)
42
+
43
+
44
+ small_16k = ModelConfig(model_name='small_16k',
45
+ model_path=Path('./weights/mmaudio_small_16k.pth'),
46
+ vae_path=Path('./ext_weights/v1-16.pth'),
47
+ bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
48
+ mode='16k')
49
+ small_44k = ModelConfig(model_name='small_44k',
50
+ model_path=Path('./weights/mmaudio_small_44k.pth'),
51
+ vae_path=Path('./ext_weights/v1-44.pth'),
52
+ bigvgan_16k_path=None,
53
+ mode='44k')
54
+ medium_44k = ModelConfig(model_name='medium_44k',
55
+ model_path=Path('./weights/mmaudio_medium_44k.pth'),
56
+ vae_path=Path('./ext_weights/v1-44.pth'),
57
+ bigvgan_16k_path=None,
58
+ mode='44k')
59
+ large_44k = ModelConfig(model_name='large_44k',
60
+ model_path=Path('./weights/mmaudio_large_44k.pth'),
61
+ vae_path=Path('./ext_weights/v1-44.pth'),
62
+ bigvgan_16k_path=None,
63
+ mode='44k')
64
+ large_44k_v2 = ModelConfig(model_name='large_44k_v2',
65
+ model_path=Path('./weights/mmaudio_large_44k_v2.pth'),
66
+ vae_path=Path('./ext_weights/v1-44.pth'),
67
+ bigvgan_16k_path=None,
68
+ mode='44k')
69
+ all_model_cfg: dict[str, ModelConfig] = {
70
+ 'small_16k': small_16k,
71
+ 'small_44k': small_44k,
72
+ 'medium_44k': medium_44k,
73
+ 'large_44k': large_44k,
74
+ 'large_44k_v2': large_44k_v2,
75
+ }
76
+
77
+
78
+ def generate(clip_video: Optional[torch.Tensor],
79
+ sync_video: Optional[torch.Tensor],
80
+ text: Optional[list[str]],
81
+ *,
82
+ negative_text: Optional[list[str]] = None,
83
+ feature_utils: FeaturesUtils,
84
+ net: MMAudio,
85
+ fm: FlowMatching,
86
+ rng: torch.Generator,
87
+ cfg_strength: float):
88
+ device = feature_utils.device
89
+ dtype = feature_utils.dtype
90
+
91
+ bs = len(text)
92
+ if clip_video is not None:
93
+ clip_video = clip_video.to(device, dtype, non_blocking=True)
94
+ clip_features = feature_utils.encode_video_with_clip(clip_video, batch_size=bs)
95
+ else:
96
+ clip_features = net.get_empty_clip_sequence(bs)
97
+
98
+ if sync_video is not None:
99
+ sync_video = sync_video.to(device, dtype, non_blocking=True)
100
+ sync_features = feature_utils.encode_video_with_sync(sync_video, batch_size=bs)
101
+ else:
102
+ sync_features = net.get_empty_sync_sequence(bs)
103
+
104
+ if text is not None:
105
+ text_features = feature_utils.encode_text(text)
106
+ else:
107
+ text_features = net.get_empty_string_sequence(bs)
108
+
109
+ if negative_text is not None:
110
+ assert len(negative_text) == bs
111
+ negative_text_features = feature_utils.encode_text(negative_text)
112
+ else:
113
+ negative_text_features = net.get_empty_string_sequence(bs)
114
+
115
+ x0 = torch.randn(bs,
116
+ net.latent_seq_len,
117
+ net.latent_dim,
118
+ device=device,
119
+ dtype=dtype,
120
+ generator=rng)
121
+ preprocessed_conditions = net.preprocess_conditions(clip_features, sync_features, text_features)
122
+ empty_conditions = net.get_empty_conditions(
123
+ bs, negative_text_features=negative_text_features if negative_text is not None else None)
124
+
125
+ cfg_ode_wrapper = lambda t, x: net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
126
+ cfg_strength)
127
+ x1 = fm.to_data(cfg_ode_wrapper, x0)
128
+ x1 = net.unnormalize(x1)
129
+ spec = feature_utils.decode(x1)
130
+ audio = feature_utils.vocode(spec)
131
+ return audio
132
+
133
+
134
+ LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
135
+
136
+
137
+ def setup_eval_logging(log_level: int = logging.INFO):
138
+ logging.root.setLevel(log_level)
139
+ formatter = ColoredFormatter(LOGFORMAT)
140
+ stream = logging.StreamHandler()
141
+ stream.setLevel(log_level)
142
+ stream.setFormatter(formatter)
143
+ log = logging.getLogger()
144
+ log.setLevel(log_level)
145
+ log.addHandler(stream)
146
+
147
+
148
+ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, torch.Tensor, float]:
149
+ _CLIP_SIZE = 384
150
+ _CLIP_FPS = 8.0
151
+
152
+ _SYNC_SIZE = 224
153
+ _SYNC_FPS = 25.0
154
+
155
+ clip_transform = v2.Compose([
156
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
157
+ v2.ToImage(),
158
+ v2.ToDtype(torch.float32, scale=True),
159
+ ])
160
+
161
+ sync_transform = v2.Compose([
162
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
163
+ v2.CenterCrop(_SYNC_SIZE),
164
+ v2.ToImage(),
165
+ v2.ToDtype(torch.float32, scale=True),
166
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
167
+ ])
168
+
169
+ reader = StreamingMediaDecoder(video_path)
170
+ reader.add_basic_video_stream(
171
+ frames_per_chunk=int(_CLIP_FPS * duration_sec),
172
+ frame_rate=_CLIP_FPS,
173
+ format='rgb24',
174
+ )
175
+ reader.add_basic_video_stream(
176
+ frames_per_chunk=int(_SYNC_FPS * duration_sec),
177
+ frame_rate=_SYNC_FPS,
178
+ format='rgb24',
179
+ )
180
+
181
+ reader.fill_buffer()
182
+ data_chunk = reader.pop_chunks()
183
+ clip_chunk = data_chunk[0]
184
+ sync_chunk = data_chunk[1]
185
+ assert clip_chunk is not None
186
+ assert sync_chunk is not None
187
+
188
+ clip_frames = clip_transform(clip_chunk)
189
+ sync_frames = sync_transform(sync_chunk)
190
+
191
+ clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
192
+ sync_length_sec = sync_frames.shape[0] / _SYNC_FPS
193
+
194
+ if clip_length_sec < duration_sec:
195
+ log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
196
+ log.warning(f'Truncating to {clip_length_sec:.2f} sec')
197
+ duration_sec = clip_length_sec
198
+
199
+ if sync_length_sec < duration_sec:
200
+ log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
201
+ log.warning(f'Truncating to {sync_length_sec:.2f} sec')
202
+ duration_sec = sync_length_sec
203
+
204
+ clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
205
+ sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
206
+
207
+ return clip_frames, sync_frames, duration_sec
208
+
209
+
210
+ def make_video(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int,
211
+ duration_sec: float):
212
+
213
+ approx_max_length = int(duration_sec * 60)
214
+ reader = StreamingMediaDecoder(video_path)
215
+ reader.add_basic_video_stream(
216
+ frames_per_chunk=approx_max_length,
217
+ format='rgb24',
218
+ )
219
+ reader.fill_buffer()
220
+ video_chunk = reader.pop_chunks()[0]
221
+ assert video_chunk is not None
222
+
223
+ fps = int(reader.get_out_stream_info(0).frame_rate)
224
+ if fps > 60:
225
+ log.warning(f'This code supports only up to 60 fps, but the video has {fps} fps')
226
+ log.warning(f'Just change the *60 above me')
227
+
228
+ h, w = video_chunk.shape[-2:]
229
+ video_chunk = video_chunk[:int(fps * duration_sec)]
230
+
231
+ writer = StreamingMediaEncoder(output_path)
232
+ writer.add_audio_stream(
233
+ sample_rate=sampling_rate,
234
+ num_channels=audio.shape[0],
235
+ encoder='aac', # 'flac' does not work for some reason?
236
+ )
237
+ writer.add_video_stream(frame_rate=fps,
238
+ width=w,
239
+ height=h,
240
+ format='rgb24',
241
+ encoder='libx264',
242
+ encoder_format='yuv420p')
243
+ with writer.open():
244
+ writer.write_audio_chunk(0, audio.float().transpose(0, 1))
245
+ writer.write_video_chunk(1, video_chunk)
mmaudio/ext/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
mmaudio/ext/autoencoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .autoencoder import AutoEncoderModule
mmaudio/ext/autoencoder/autoencoder.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from mmaudio.ext.autoencoder.vae import VAE, get_my_vae
7
+ from mmaudio.ext.bigvgan import BigVGAN
8
+ from mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
9
+ from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
10
+
11
+
12
+ class AutoEncoderModule(nn.Module):
13
+
14
+ def __init__(self,
15
+ *,
16
+ vae_ckpt_path,
17
+ vocoder_ckpt_path: Optional[str] = None,
18
+ mode: Literal['16k', '44k']):
19
+ super().__init__()
20
+ self.vae: VAE = get_my_vae(mode).eval()
21
+ vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
22
+ self.vae.load_state_dict(vae_state_dict)
23
+ self.vae.remove_weight_norm()
24
+
25
+ if mode == '16k':
26
+ assert vocoder_ckpt_path is not None
27
+ self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
28
+ elif mode == '44k':
29
+ self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
30
+ use_cuda_kernel=False)
31
+ self.vocoder.remove_weight_norm()
32
+ else:
33
+ raise ValueError(f'Unknown mode: {mode}')
34
+
35
+ for param in self.parameters():
36
+ param.requires_grad = False
37
+
38
+ @torch.inference_mode()
39
+ def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
40
+ return self.vae.encode(x)
41
+
42
+ @torch.inference_mode()
43
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
44
+ return self.vae.decode(z)
45
+
46
+ @torch.inference_mode()
47
+ def vocode(self, spec: torch.Tensor) -> torch.Tensor:
48
+ return self.vocoder(spec)
mmaudio/ext/autoencoder/edm2_utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+ """Improved diffusion model architecture proposed in the paper
8
+ "Analyzing and Improving the Training Dynamics of Diffusion Models"."""
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+ # Variant of constant() that inherits dtype and device from the given
15
+ # reference tensor by default.
16
+
17
+ _constant_cache = dict()
18
+
19
+
20
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
21
+ value = np.asarray(value)
22
+ if shape is not None:
23
+ shape = tuple(shape)
24
+ if dtype is None:
25
+ dtype = torch.get_default_dtype()
26
+ if device is None:
27
+ device = torch.device('cpu')
28
+ if memory_format is None:
29
+ memory_format = torch.contiguous_format
30
+
31
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
32
+ tensor = _constant_cache.get(key, None)
33
+ if tensor is None:
34
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
35
+ if shape is not None:
36
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
37
+ tensor = tensor.contiguous(memory_format=memory_format)
38
+ _constant_cache[key] = tensor
39
+ return tensor
40
+
41
+
42
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
43
+ if dtype is None:
44
+ dtype = ref.dtype
45
+ if device is None:
46
+ device = ref.device
47
+ return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
48
+
49
+
50
+ #----------------------------------------------------------------------------
51
+ # Normalize given tensor to unit magnitude with respect to the given
52
+ # dimensions. Default = all dimensions except the first.
53
+
54
+
55
+ def normalize(x, dim=None, eps=1e-4):
56
+ if dim is None:
57
+ dim = list(range(1, x.ndim))
58
+ norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
59
+ norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
60
+ return x / norm.to(x.dtype)
61
+
62
+
63
+ class Normalize(torch.nn.Module):
64
+
65
+ def __init__(self, dim=None, eps=1e-4):
66
+ super().__init__()
67
+ self.dim = dim
68
+ self.eps = eps
69
+
70
+ def forward(self, x):
71
+ return normalize(x, dim=self.dim, eps=self.eps)
72
+
73
+
74
+ #----------------------------------------------------------------------------
75
+ # Upsample or downsample the given tensor with the given filter,
76
+ # or keep it as is.
77
+
78
+
79
+ def resample(x, f=[1, 1], mode='keep'):
80
+ if mode == 'keep':
81
+ return x
82
+ f = np.float32(f)
83
+ assert f.ndim == 1 and len(f) % 2 == 0
84
+ pad = (len(f) - 1) // 2
85
+ f = f / f.sum()
86
+ f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
87
+ f = const_like(x, f)
88
+ c = x.shape[1]
89
+ if mode == 'down':
90
+ return torch.nn.functional.conv2d(x,
91
+ f.tile([c, 1, 1, 1]),
92
+ groups=c,
93
+ stride=2,
94
+ padding=(pad, ))
95
+ assert mode == 'up'
96
+ return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]),
97
+ groups=c,
98
+ stride=2,
99
+ padding=(pad, ))
100
+
101
+
102
+ #----------------------------------------------------------------------------
103
+ # Magnitude-preserving SiLU (Equation 81).
104
+
105
+
106
+ def mp_silu(x):
107
+ return torch.nn.functional.silu(x) / 0.596
108
+
109
+
110
+ class MPSiLU(torch.nn.Module):
111
+
112
+ def forward(self, x):
113
+ return mp_silu(x)
114
+
115
+
116
+ #----------------------------------------------------------------------------
117
+ # Magnitude-preserving sum (Equation 88).
118
+
119
+
120
+ def mp_sum(a, b, t=0.5):
121
+ return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
122
+
123
+
124
+ #----------------------------------------------------------------------------
125
+ # Magnitude-preserving concatenation (Equation 103).
126
+
127
+
128
+ def mp_cat(a, b, dim=1, t=0.5):
129
+ Na = a.shape[dim]
130
+ Nb = b.shape[dim]
131
+ C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
132
+ wa = C / np.sqrt(Na) * (1 - t)
133
+ wb = C / np.sqrt(Nb) * t
134
+ return torch.cat([wa * a, wb * b], dim=dim)
135
+
136
+
137
+ #----------------------------------------------------------------------------
138
+ # Magnitude-preserving convolution or fully-connected layer (Equation 47)
139
+ # with force weight normalization (Equation 66).
140
+
141
+
142
+ class MPConv1D(torch.nn.Module):
143
+
144
+ def __init__(self, in_channels, out_channels, kernel_size):
145
+ super().__init__()
146
+ self.out_channels = out_channels
147
+ self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
148
+
149
+ self.weight_norm_removed = False
150
+
151
+ def forward(self, x, gain=1):
152
+ assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
153
+
154
+ w = self.weight * gain
155
+ if w.ndim == 2:
156
+ return x @ w.t()
157
+ assert w.ndim == 3
158
+ return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2, ))
159
+
160
+ def remove_weight_norm(self):
161
+ w = self.weight.to(torch.float32)
162
+ w = normalize(w) # traditional weight normalization
163
+ w = w / np.sqrt(w[0].numel())
164
+ w = w.to(self.weight.dtype)
165
+ self.weight.data.copy_(w)
166
+
167
+ self.weight_norm_removed = True
168
+ return self
mmaudio/ext/autoencoder/vae.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from mmaudio.ext.autoencoder.edm2_utils import MPConv1D
8
+ from mmaudio.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
9
+ Upsample1D, nonlinearity)
10
+ from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
11
+
12
+ log = logging.getLogger()
13
+
14
+ DATA_MEAN_80D = [
15
+ -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
16
+ -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
17
+ -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
18
+ -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
19
+ -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
20
+ -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
21
+ -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
22
+ -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
23
+ ]
24
+
25
+ DATA_STD_80D = [
26
+ 1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
27
+ 0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
28
+ 0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
29
+ 0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
30
+ 0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
31
+ 0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
32
+ 1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
33
+ ]
34
+
35
+ DATA_MEAN_128D = [
36
+ -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
37
+ -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
38
+ -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
39
+ -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
40
+ -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
41
+ -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
42
+ -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
43
+ -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
44
+ -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
45
+ -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
46
+ -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
47
+ -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
48
+ -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
49
+ ]
50
+
51
+ DATA_STD_128D = [
52
+ 2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
53
+ 2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
54
+ 2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
55
+ 2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
56
+ 2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
57
+ 2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
58
+ 2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
59
+ 2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
60
+ 2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
61
+ 2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
62
+ 2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
63
+ ]
64
+
65
+
66
+ class VAE(nn.Module):
67
+
68
+ def __init__(
69
+ self,
70
+ *,
71
+ data_dim: int,
72
+ embed_dim: int,
73
+ hidden_dim: int,
74
+ ):
75
+ super().__init__()
76
+
77
+ if data_dim == 80:
78
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32).cuda())
79
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32).cuda())
80
+ elif data_dim == 128:
81
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32).cuda())
82
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32).cuda())
83
+
84
+ self.data_mean = self.data_mean.view(1, -1, 1)
85
+ self.data_std = self.data_std.view(1, -1, 1)
86
+
87
+ self.encoder = Encoder1D(
88
+ dim=hidden_dim,
89
+ ch_mult=(1, 2, 4),
90
+ num_res_blocks=2,
91
+ attn_layers=[3],
92
+ down_layers=[0],
93
+ in_dim=data_dim,
94
+ embed_dim=embed_dim,
95
+ )
96
+ self.decoder = Decoder1D(
97
+ dim=hidden_dim,
98
+ ch_mult=(1, 2, 4),
99
+ num_res_blocks=2,
100
+ attn_layers=[3],
101
+ down_layers=[0],
102
+ in_dim=data_dim,
103
+ out_dim=data_dim,
104
+ embed_dim=embed_dim,
105
+ )
106
+
107
+ self.embed_dim = embed_dim
108
+ # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
109
+ # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
110
+
111
+ self.initialize_weights()
112
+
113
+ def initialize_weights(self):
114
+ pass
115
+
116
+ def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
117
+ if normalize:
118
+ x = self.normalize(x)
119
+ moments = self.encoder(x)
120
+ posterior = DiagonalGaussianDistribution(moments)
121
+ return posterior
122
+
123
+ def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
124
+ dec = self.decoder(z)
125
+ if unnormalize:
126
+ dec = self.unnormalize(dec)
127
+ return dec
128
+
129
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
130
+ return (x - self.data_mean) / self.data_std
131
+
132
+ def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
133
+ return x * self.data_std + self.data_mean
134
+
135
+ def forward(
136
+ self,
137
+ x: torch.Tensor,
138
+ sample_posterior: bool = True,
139
+ rng: Optional[torch.Generator] = None,
140
+ normalize: bool = True,
141
+ unnormalize: bool = True,
142
+ ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
143
+
144
+ posterior = self.encode(x, normalize=normalize)
145
+ if sample_posterior:
146
+ z = posterior.sample(rng)
147
+ else:
148
+ z = posterior.mode()
149
+ dec = self.decode(z, unnormalize=unnormalize)
150
+ return dec, posterior
151
+
152
+ def load_weights(self, src_dict) -> None:
153
+ self.load_state_dict(src_dict, strict=True)
154
+
155
+ @property
156
+ def device(self) -> torch.device:
157
+ return next(self.parameters()).device
158
+
159
+ def get_last_layer(self):
160
+ return self.decoder.conv_out.weight
161
+
162
+ def remove_weight_norm(self):
163
+ for name, m in self.named_modules():
164
+ if isinstance(m, MPConv1D):
165
+ m.remove_weight_norm()
166
+ log.debug(f"Removed weight norm from {name}")
167
+ return self
168
+
169
+
170
+ class Encoder1D(nn.Module):
171
+
172
+ def __init__(self,
173
+ *,
174
+ dim: int,
175
+ ch_mult: tuple[int] = (1, 2, 4, 8),
176
+ num_res_blocks: int,
177
+ attn_layers: list[int] = [],
178
+ down_layers: list[int] = [],
179
+ resamp_with_conv: bool = True,
180
+ in_dim: int,
181
+ embed_dim: int,
182
+ double_z: bool = True,
183
+ kernel_size: int = 3,
184
+ clip_act: float = 256.0):
185
+ super().__init__()
186
+ self.dim = dim
187
+ self.num_layers = len(ch_mult)
188
+ self.num_res_blocks = num_res_blocks
189
+ self.in_channels = in_dim
190
+ self.clip_act = clip_act
191
+ self.down_layers = down_layers
192
+ self.attn_layers = attn_layers
193
+ self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)
194
+
195
+ in_ch_mult = (1, ) + tuple(ch_mult)
196
+ self.in_ch_mult = in_ch_mult
197
+ # downsampling
198
+ self.down = nn.ModuleList()
199
+ for i_level in range(self.num_layers):
200
+ block = nn.ModuleList()
201
+ attn = nn.ModuleList()
202
+ block_in = dim * in_ch_mult[i_level]
203
+ block_out = dim * ch_mult[i_level]
204
+ for i_block in range(self.num_res_blocks):
205
+ block.append(
206
+ ResnetBlock1D(in_dim=block_in,
207
+ out_dim=block_out,
208
+ kernel_size=kernel_size,
209
+ use_norm=True))
210
+ block_in = block_out
211
+ if i_level in attn_layers:
212
+ attn.append(AttnBlock1D(block_in))
213
+ down = nn.Module()
214
+ down.block = block
215
+ down.attn = attn
216
+ if i_level in down_layers:
217
+ down.downsample = Downsample1D(block_in, resamp_with_conv)
218
+ self.down.append(down)
219
+
220
+ # middle
221
+ self.mid = nn.Module()
222
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
223
+ out_dim=block_in,
224
+ kernel_size=kernel_size,
225
+ use_norm=True)
226
+ self.mid.attn_1 = AttnBlock1D(block_in)
227
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
228
+ out_dim=block_in,
229
+ kernel_size=kernel_size,
230
+ use_norm=True)
231
+
232
+ # end
233
+ self.conv_out = MPConv1D(block_in,
234
+ 2 * embed_dim if double_z else embed_dim,
235
+ kernel_size=kernel_size)
236
+
237
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
238
+
239
+ def forward(self, x):
240
+
241
+ # downsampling
242
+ hs = [self.conv_in(x)]
243
+ for i_level in range(self.num_layers):
244
+ for i_block in range(self.num_res_blocks):
245
+ h = self.down[i_level].block[i_block](hs[-1])
246
+ if len(self.down[i_level].attn) > 0:
247
+ h = self.down[i_level].attn[i_block](h)
248
+ h = h.clamp(-self.clip_act, self.clip_act)
249
+ hs.append(h)
250
+ if i_level in self.down_layers:
251
+ hs.append(self.down[i_level].downsample(hs[-1]))
252
+
253
+ # middle
254
+ h = hs[-1]
255
+ h = self.mid.block_1(h)
256
+ h = self.mid.attn_1(h)
257
+ h = self.mid.block_2(h)
258
+ h = h.clamp(-self.clip_act, self.clip_act)
259
+
260
+ # end
261
+ h = nonlinearity(h)
262
+ h = self.conv_out(h, gain=(self.learnable_gain + 1))
263
+ return h
264
+
265
+
266
+ class Decoder1D(nn.Module):
267
+
268
+ def __init__(self,
269
+ *,
270
+ dim: int,
271
+ out_dim: int,
272
+ ch_mult: tuple[int] = (1, 2, 4, 8),
273
+ num_res_blocks: int,
274
+ attn_layers: list[int] = [],
275
+ down_layers: list[int] = [],
276
+ kernel_size: int = 3,
277
+ resamp_with_conv: bool = True,
278
+ in_dim: int,
279
+ embed_dim: int,
280
+ clip_act: float = 256.0):
281
+ super().__init__()
282
+ self.ch = dim
283
+ self.num_layers = len(ch_mult)
284
+ self.num_res_blocks = num_res_blocks
285
+ self.in_channels = in_dim
286
+ self.clip_act = clip_act
287
+ self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
288
+
289
+ # compute in_ch_mult, block_in and curr_res at lowest res
290
+ block_in = dim * ch_mult[self.num_layers - 1]
291
+
292
+ # z to block_in
293
+ self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)
294
+
295
+ # middle
296
+ self.mid = nn.Module()
297
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
298
+ self.mid.attn_1 = AttnBlock1D(block_in)
299
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
300
+
301
+ # upsampling
302
+ self.up = nn.ModuleList()
303
+ for i_level in reversed(range(self.num_layers)):
304
+ block = nn.ModuleList()
305
+ attn = nn.ModuleList()
306
+ block_out = dim * ch_mult[i_level]
307
+ for i_block in range(self.num_res_blocks + 1):
308
+ block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
309
+ block_in = block_out
310
+ if i_level in attn_layers:
311
+ attn.append(AttnBlock1D(block_in))
312
+ up = nn.Module()
313
+ up.block = block
314
+ up.attn = attn
315
+ if i_level in self.down_layers:
316
+ up.upsample = Upsample1D(block_in, resamp_with_conv)
317
+ self.up.insert(0, up) # prepend to get consistent order
318
+
319
+ # end
320
+ self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
321
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
322
+
323
+ def forward(self, z):
324
+ # z to block_in
325
+ h = self.conv_in(z)
326
+
327
+ # middle
328
+ h = self.mid.block_1(h)
329
+ h = self.mid.attn_1(h)
330
+ h = self.mid.block_2(h)
331
+ h = h.clamp(-self.clip_act, self.clip_act)
332
+
333
+ # upsampling
334
+ for i_level in reversed(range(self.num_layers)):
335
+ for i_block in range(self.num_res_blocks + 1):
336
+ h = self.up[i_level].block[i_block](h)
337
+ if len(self.up[i_level].attn) > 0:
338
+ h = self.up[i_level].attn[i_block](h)
339
+ h = h.clamp(-self.clip_act, self.clip_act)
340
+ if i_level in self.down_layers:
341
+ h = self.up[i_level].upsample(h)
342
+
343
+ h = nonlinearity(h)
344
+ h = self.conv_out(h, gain=(self.learnable_gain + 1))
345
+ return h
346
+
347
+
348
+ def VAE_16k(**kwargs) -> VAE:
349
+ return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
350
+
351
+
352
+ def VAE_44k(**kwargs) -> VAE:
353
+ return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
354
+
355
+
356
+ def get_my_vae(name: str, **kwargs) -> VAE:
357
+ if name == '16k':
358
+ return VAE_16k(**kwargs)
359
+ if name == '44k':
360
+ return VAE_44k(**kwargs)
361
+ raise ValueError(f'Unknown model: {name}')
362
+
363
+
364
+ if __name__ == '__main__':
365
+ network = get_my_vae('standard')
366
+
367
+ # print the number of parameters in terms of millions
368
+ num_params = sum(p.numel() for p in network.parameters()) / 1e6
369
+ print(f'Number of parameters: {num_params:.2f}M')
mmaudio/ext/autoencoder/vae_modules.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+ from mmaudio.ext.autoencoder.edm2_utils import (MPConv1D, mp_silu, mp_sum, normalize)
7
+
8
+
9
+ def nonlinearity(x):
10
+ # swish
11
+ return mp_silu(x)
12
+
13
+
14
+ class ResnetBlock1D(nn.Module):
15
+
16
+ def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
17
+ super().__init__()
18
+ self.in_dim = in_dim
19
+ out_dim = in_dim if out_dim is None else out_dim
20
+ self.out_dim = out_dim
21
+ self.use_conv_shortcut = conv_shortcut
22
+ self.use_norm = use_norm
23
+
24
+ self.conv1 = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
25
+ self.conv2 = MPConv1D(out_dim, out_dim, kernel_size=kernel_size)
26
+ if self.in_dim != self.out_dim:
27
+ if self.use_conv_shortcut:
28
+ self.conv_shortcut = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
29
+ else:
30
+ self.nin_shortcut = MPConv1D(in_dim, out_dim, kernel_size=1)
31
+
32
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
33
+
34
+ # pixel norm
35
+ if self.use_norm:
36
+ x = normalize(x, dim=1)
37
+
38
+ h = x
39
+ h = nonlinearity(h)
40
+ h = self.conv1(h)
41
+
42
+ h = nonlinearity(h)
43
+ h = self.conv2(h)
44
+
45
+ if self.in_dim != self.out_dim:
46
+ if self.use_conv_shortcut:
47
+ x = self.conv_shortcut(x)
48
+ else:
49
+ x = self.nin_shortcut(x)
50
+
51
+ return mp_sum(x, h, t=0.3)
52
+
53
+
54
+ class AttnBlock1D(nn.Module):
55
+
56
+ def __init__(self, in_channels, num_heads=1):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+
60
+ self.num_heads = num_heads
61
+ self.qkv = MPConv1D(in_channels, in_channels * 3, kernel_size=1)
62
+ self.proj_out = MPConv1D(in_channels, in_channels, kernel_size=1)
63
+
64
+ def forward(self, x):
65
+ h = x
66
+ y = self.qkv(h)
67
+ y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[-1])
68
+ q, k, v = normalize(y, dim=2).unbind(3)
69
+
70
+ q = rearrange(q, 'b h c l -> b h l c')
71
+ k = rearrange(k, 'b h c l -> b h l c')
72
+ v = rearrange(v, 'b h c l -> b h l c')
73
+
74
+ h = F.scaled_dot_product_attention(q, k, v)
75
+ h = rearrange(h, 'b h l c -> b (h c) l')
76
+
77
+ h = self.proj_out(h)
78
+
79
+ return mp_sum(x, h, t=0.3)
80
+
81
+
82
+ class Upsample1D(nn.Module):
83
+
84
+ def __init__(self, in_channels, with_conv):
85
+ super().__init__()
86
+ self.with_conv = with_conv
87
+ if self.with_conv:
88
+ self.conv = MPConv1D(in_channels, in_channels, kernel_size=3)
89
+
90
+ def forward(self, x):
91
+ x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
92
+ if self.with_conv:
93
+ x = self.conv(x)
94
+ return x
95
+
96
+
97
+ class Downsample1D(nn.Module):
98
+
99
+ def __init__(self, in_channels, with_conv):
100
+ super().__init__()
101
+ self.with_conv = with_conv
102
+ if self.with_conv:
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv1 = MPConv1D(in_channels, in_channels, kernel_size=1)
105
+ self.conv2 = MPConv1D(in_channels, in_channels, kernel_size=1)
106
+
107
+ def forward(self, x):
108
+
109
+ if self.with_conv:
110
+ x = self.conv1(x)
111
+
112
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
113
+
114
+ if self.with_conv:
115
+ x = self.conv2(x)
116
+
117
+ return x
mmaudio/ext/bigvgan/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
mmaudio/ext/bigvgan/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .bigvgan import BigVGAN
mmaudio/ext/bigvgan/activations.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ '''
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ '''
25
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
26
+ '''
27
+ Initialization.
28
+ INPUT:
29
+ - in_features: shape of the input
30
+ - alpha: trainable parameter
31
+ alpha is initialized to 1 by default, higher values = higher-frequency.
32
+ alpha will be trained along with the rest of your model.
33
+ '''
34
+ super(Snake, self).__init__()
35
+ self.in_features = in_features
36
+
37
+ # initialize alpha
38
+ self.alpha_logscale = alpha_logscale
39
+ if self.alpha_logscale: # log scale alphas initialized to zeros
40
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
41
+ else: # linear scale alphas initialized to ones
42
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+
46
+ self.no_div_by_zero = 0.000000001
47
+
48
+ def forward(self, x):
49
+ '''
50
+ Forward pass of the function.
51
+ Applies the function to the input elementwise.
52
+ Snake ∶= x + 1/a * sin^2 (xa)
53
+ '''
54
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
55
+ if self.alpha_logscale:
56
+ alpha = torch.exp(alpha)
57
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
58
+
59
+ return x
60
+
61
+
62
+ class SnakeBeta(nn.Module):
63
+ '''
64
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
65
+ Shape:
66
+ - Input: (B, C, T)
67
+ - Output: (B, C, T), same shape as the input
68
+ Parameters:
69
+ - alpha - trainable parameter that controls frequency
70
+ - beta - trainable parameter that controls magnitude
71
+ References:
72
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
73
+ https://arxiv.org/abs/2006.08195
74
+ Examples:
75
+ >>> a1 = snakebeta(256)
76
+ >>> x = torch.randn(256)
77
+ >>> x = a1(x)
78
+ '''
79
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
80
+ '''
81
+ Initialization.
82
+ INPUT:
83
+ - in_features: shape of the input
84
+ - alpha - trainable parameter that controls frequency
85
+ - beta - trainable parameter that controls magnitude
86
+ alpha is initialized to 1 by default, higher values = higher-frequency.
87
+ beta is initialized to 1 by default, higher values = higher-magnitude.
88
+ alpha will be trained along with the rest of your model.
89
+ '''
90
+ super(SnakeBeta, self).__init__()
91
+ self.in_features = in_features
92
+
93
+ # initialize alpha
94
+ self.alpha_logscale = alpha_logscale
95
+ if self.alpha_logscale: # log scale alphas initialized to zeros
96
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
97
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
98
+ else: # linear scale alphas initialized to ones
99
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
100
+ self.beta = Parameter(torch.ones(in_features) * alpha)
101
+
102
+ self.alpha.requires_grad = alpha_trainable
103
+ self.beta.requires_grad = alpha_trainable
104
+
105
+ self.no_div_by_zero = 0.000000001
106
+
107
+ def forward(self, x):
108
+ '''
109
+ Forward pass of the function.
110
+ Applies the function to the input elementwise.
111
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
112
+ '''
113
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
114
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
115
+ if self.alpha_logscale:
116
+ alpha = torch.exp(alpha)
117
+ beta = torch.exp(beta)
118
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
119
+
120
+ return x
mmaudio/ext/bigvgan/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
mmaudio/ext/bigvgan/alias_free_torch/act.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12):
15
+ super().__init__()
16
+ self.up_ratio = up_ratio
17
+ self.down_ratio = down_ratio
18
+ self.act = activation
19
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
20
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
21
+
22
+ # x: [B,C,T]
23
+ def forward(self, x):
24
+ x = self.upsample(x)
25
+ x = self.act(x)
26
+ x = self.downsample(x)
27
+
28
+ return x
mmaudio/ext/bigvgan/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if 'sinc' in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(x == 0,
21
+ torch.tensor(1., device=x.device, dtype=x.dtype),
22
+ torch.sin(math.pi * x) / math.pi / x)
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ # LICENSE is in incl_licenses directory.
28
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
29
+ even = (kernel_size % 2 == 0)
30
+ half_size = kernel_size // 2
31
+
32
+ #For kaiser window
33
+ delta_f = 4 * half_width
34
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35
+ if A > 50.:
36
+ beta = 0.1102 * (A - 8.7)
37
+ elif A >= 21.:
38
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
39
+ else:
40
+ beta = 0.
41
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42
+
43
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
44
+ if even:
45
+ time = (torch.arange(-half_size, half_size) + 0.5)
46
+ else:
47
+ time = torch.arange(kernel_size) - half_size
48
+ if cutoff == 0:
49
+ filter_ = torch.zeros_like(time)
50
+ else:
51
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
52
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
53
+ # of the constant component in the input signal.
54
+ filter_ /= filter_.sum()
55
+ filter = filter_.view(1, 1, kernel_size)
56
+
57
+ return filter
58
+
59
+
60
+ class LowPassFilter1d(nn.Module):
61
+ def __init__(self,
62
+ cutoff=0.5,
63
+ half_width=0.6,
64
+ stride: int = 1,
65
+ padding: bool = True,
66
+ padding_mode: str = 'replicate',
67
+ kernel_size: int = 12):
68
+ # kernel_size should be even number for stylegan3 setup,
69
+ # in this implementation, odd number is also possible.
70
+ super().__init__()
71
+ if cutoff < -0.:
72
+ raise ValueError("Minimum cutoff must be larger than zero.")
73
+ if cutoff > 0.5:
74
+ raise ValueError("A cutoff above 0.5 does not make sense.")
75
+ self.kernel_size = kernel_size
76
+ self.even = (kernel_size % 2 == 0)
77
+ self.pad_left = kernel_size // 2 - int(self.even)
78
+ self.pad_right = kernel_size // 2
79
+ self.stride = stride
80
+ self.padding = padding
81
+ self.padding_mode = padding_mode
82
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
83
+ self.register_buffer("filter", filter)
84
+
85
+ #input [B, C, T]
86
+ def forward(self, x):
87
+ _, C, _ = x.shape
88
+
89
+ if self.padding:
90
+ x = F.pad(x, (self.pad_left, self.pad_right),
91
+ mode=self.padding_mode)
92
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
93
+ stride=self.stride, groups=C)
94
+
95
+ return out
mmaudio/ext/bigvgan/alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ self.stride = ratio
16
+ self.pad = self.kernel_size // ratio - 1
17
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20
+ half_width=0.6 / ratio,
21
+ kernel_size=self.kernel_size)
22
+ self.register_buffer("filter", filter)
23
+
24
+ # x: [B, C, T]
25
+ def forward(self, x):
26
+ _, C, _ = x.shape
27
+
28
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
29
+ x = self.ratio * F.conv_transpose1d(
30
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31
+ x = x[..., self.pad_left:-self.pad_right]
32
+
33
+ return x
34
+
35
+
36
+ class DownSample1d(nn.Module):
37
+ def __init__(self, ratio=2, kernel_size=None):
38
+ super().__init__()
39
+ self.ratio = ratio
40
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42
+ half_width=0.6 / ratio,
43
+ stride=ratio,
44
+ kernel_size=self.kernel_size)
45
+
46
+ def forward(self, x):
47
+ xx = self.lowpass(x)
48
+
49
+ return xx
mmaudio/ext/bigvgan/bigvgan.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from omegaconf import OmegaConf
6
+
7
+ from mmaudio.ext.bigvgan.models import BigVGANVocoder
8
+
9
+ _bigvgan_vocoder_path = Path(__file__).parent / 'bigvgan_vocoder.yml'
10
+
11
+
12
+ class BigVGAN(nn.Module):
13
+
14
+ def __init__(self, ckpt_path, config_path=_bigvgan_vocoder_path):
15
+ super().__init__()
16
+ vocoder_cfg = OmegaConf.load(config_path)
17
+ self.vocoder = BigVGANVocoder(vocoder_cfg).eval()
18
+ vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)['generator']
19
+ self.vocoder.load_state_dict(vocoder_ckpt)
20
+
21
+ self.weight_norm_removed = False
22
+ self.remove_weight_norm()
23
+
24
+ @torch.inference_mode()
25
+ def forward(self, x):
26
+ assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
27
+ return self.vocoder(x)
28
+
29
+ def remove_weight_norm(self):
30
+ self.vocoder.remove_weight_norm()
31
+ self.weight_norm_removed = True
32
+ return self
mmaudio/ext/bigvgan/bigvgan_vocoder.yml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ resblock: '1'
2
+ num_gpus: 0
3
+ batch_size: 64
4
+ num_mels: 80
5
+ learning_rate: 0.0001
6
+ adam_b1: 0.8
7
+ adam_b2: 0.99
8
+ lr_decay: 0.999
9
+ seed: 1234
10
+ upsample_rates:
11
+ - 4
12
+ - 4
13
+ - 2
14
+ - 2
15
+ - 2
16
+ - 2
17
+ upsample_kernel_sizes:
18
+ - 8
19
+ - 8
20
+ - 4
21
+ - 4
22
+ - 4
23
+ - 4
24
+ upsample_initial_channel: 1536
25
+ resblock_kernel_sizes:
26
+ - 3
27
+ - 7
28
+ - 11
29
+ resblock_dilation_sizes:
30
+ - - 1
31
+ - 3
32
+ - 5
33
+ - - 1
34
+ - 3
35
+ - 5
36
+ - - 1
37
+ - 3
38
+ - 5
39
+ activation: snakebeta
40
+ snake_logscale: true
41
+ resolutions:
42
+ - - 1024
43
+ - 120
44
+ - 600
45
+ - - 2048
46
+ - 240
47
+ - 1200
48
+ - - 512
49
+ - 50
50
+ - 240
51
+ mpd_reshapes:
52
+ - 2
53
+ - 3
54
+ - 5
55
+ - 7
56
+ - 11
57
+ use_spectral_norm: false
58
+ discriminator_channel_mult: 1
59
+ num_workers: 4
60
+ dist_config:
61
+ dist_backend: nccl
62
+ dist_url: tcp://localhost:54341
63
+ world_size: 1
mmaudio/ext/bigvgan/env.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import os
5
+ import shutil
6
+
7
+
8
+ class AttrDict(dict):
9
+ def __init__(self, *args, **kwargs):
10
+ super(AttrDict, self).__init__(*args, **kwargs)
11
+ self.__dict__ = self
12
+
13
+
14
+ def build_env(config, config_name, path):
15
+ t_path = os.path.join(path, config_name)
16
+ if config != t_path:
17
+ os.makedirs(path, exist_ok=True)
18
+ shutil.copyfile(config, os.path.join(path, config_name))
mmaudio/ext/bigvgan/incl_licenses/LICENSE_1 ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jungil Kong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
mmaudio/ext/bigvgan/incl_licenses/LICENSE_2 ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Edward Dixon
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
mmaudio/ext/bigvgan/incl_licenses/LICENSE_3 ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
mmaudio/ext/bigvgan/incl_licenses/LICENSE_4 ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2019, Seungwon Park 박승원
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ 3. Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
mmaudio/ext/bigvgan/incl_licenses/LICENSE_5 ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2020 Alexandre Défossez
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
4
+ associated documentation files (the "Software"), to deal in the Software without restriction,
5
+ including without limitation the rights to use, copy, modify, merge, publish, distribute,
6
+ sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
7
+ furnished to do so, subject to the following conditions:
8
+
9
+ The above copyright notice and this permission notice shall be included in all copies or
10
+ substantial portions of the Software.
11
+
12
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
13
+ NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
14
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
15
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
16
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
mmaudio/ext/bigvgan/models.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import Conv1d, ConvTranspose1d
10
+ from torch.nn.utils.parametrizations import weight_norm
11
+ from torch.nn.utils.parametrize import remove_parametrizations
12
+
13
+ from mmaudio.ext.bigvgan import activations
14
+ from mmaudio.ext.bigvgan.alias_free_torch import *
15
+ from mmaudio.ext.bigvgan.utils import get_padding, init_weights
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ class AMPBlock1(torch.nn.Module):
21
+
22
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
23
+ super(AMPBlock1, self).__init__()
24
+ self.h = h
25
+
26
+ self.convs1 = nn.ModuleList([
27
+ weight_norm(
28
+ Conv1d(channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[0],
33
+ padding=get_padding(kernel_size, dilation[0]))),
34
+ weight_norm(
35
+ Conv1d(channels,
36
+ channels,
37
+ kernel_size,
38
+ 1,
39
+ dilation=dilation[1],
40
+ padding=get_padding(kernel_size, dilation[1]))),
41
+ weight_norm(
42
+ Conv1d(channels,
43
+ channels,
44
+ kernel_size,
45
+ 1,
46
+ dilation=dilation[2],
47
+ padding=get_padding(kernel_size, dilation[2])))
48
+ ])
49
+ self.convs1.apply(init_weights)
50
+
51
+ self.convs2 = nn.ModuleList([
52
+ weight_norm(
53
+ Conv1d(channels,
54
+ channels,
55
+ kernel_size,
56
+ 1,
57
+ dilation=1,
58
+ padding=get_padding(kernel_size, 1))),
59
+ weight_norm(
60
+ Conv1d(channels,
61
+ channels,
62
+ kernel_size,
63
+ 1,
64
+ dilation=1,
65
+ padding=get_padding(kernel_size, 1))),
66
+ weight_norm(
67
+ Conv1d(channels,
68
+ channels,
69
+ kernel_size,
70
+ 1,
71
+ dilation=1,
72
+ padding=get_padding(kernel_size, 1)))
73
+ ])
74
+ self.convs2.apply(init_weights)
75
+
76
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
77
+
78
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
79
+ self.activations = nn.ModuleList([
80
+ Activation1d(
81
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
82
+ for _ in range(self.num_layers)
83
+ ])
84
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
85
+ self.activations = nn.ModuleList([
86
+ Activation1d(
87
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
88
+ for _ in range(self.num_layers)
89
+ ])
90
+ else:
91
+ raise NotImplementedError(
92
+ "activation incorrectly specified. check the config file and look for 'activation'."
93
+ )
94
+
95
+ def forward(self, x):
96
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
97
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
98
+ xt = a1(x)
99
+ xt = c1(xt)
100
+ xt = a2(xt)
101
+ xt = c2(xt)
102
+ x = xt + x
103
+
104
+ return x
105
+
106
+ def remove_weight_norm(self):
107
+ for l in self.convs1:
108
+ remove_parametrizations(l, 'weight')
109
+ for l in self.convs2:
110
+ remove_parametrizations(l, 'weight')
111
+
112
+
113
+ class AMPBlock2(torch.nn.Module):
114
+
115
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
116
+ super(AMPBlock2, self).__init__()
117
+ self.h = h
118
+
119
+ self.convs = nn.ModuleList([
120
+ weight_norm(
121
+ Conv1d(channels,
122
+ channels,
123
+ kernel_size,
124
+ 1,
125
+ dilation=dilation[0],
126
+ padding=get_padding(kernel_size, dilation[0]))),
127
+ weight_norm(
128
+ Conv1d(channels,
129
+ channels,
130
+ kernel_size,
131
+ 1,
132
+ dilation=dilation[1],
133
+ padding=get_padding(kernel_size, dilation[1])))
134
+ ])
135
+ self.convs.apply(init_weights)
136
+
137
+ self.num_layers = len(self.convs) # total number of conv layers
138
+
139
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
140
+ self.activations = nn.ModuleList([
141
+ Activation1d(
142
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
143
+ for _ in range(self.num_layers)
144
+ ])
145
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
146
+ self.activations = nn.ModuleList([
147
+ Activation1d(
148
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
149
+ for _ in range(self.num_layers)
150
+ ])
151
+ else:
152
+ raise NotImplementedError(
153
+ "activation incorrectly specified. check the config file and look for 'activation'."
154
+ )
155
+
156
+ def forward(self, x):
157
+ for c, a in zip(self.convs, self.activations):
158
+ xt = a(x)
159
+ xt = c(xt)
160
+ x = xt + x
161
+
162
+ return x
163
+
164
+ def remove_weight_norm(self):
165
+ for l in self.convs:
166
+ remove_parametrizations(l, 'weight')
167
+
168
+
169
+ class BigVGANVocoder(torch.nn.Module):
170
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
171
+ def __init__(self, h):
172
+ super().__init__()
173
+ self.h = h
174
+
175
+ self.num_kernels = len(h.resblock_kernel_sizes)
176
+ self.num_upsamples = len(h.upsample_rates)
177
+
178
+ # pre conv
179
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
180
+
181
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
182
+ resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
183
+
184
+ # transposed conv-based upsamplers. does not apply anti-aliasing
185
+ self.ups = nn.ModuleList()
186
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
187
+ self.ups.append(
188
+ nn.ModuleList([
189
+ weight_norm(
190
+ ConvTranspose1d(h.upsample_initial_channel // (2**i),
191
+ h.upsample_initial_channel // (2**(i + 1)),
192
+ k,
193
+ u,
194
+ padding=(k - u) // 2))
195
+ ]))
196
+
197
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
198
+ self.resblocks = nn.ModuleList()
199
+ for i in range(len(self.ups)):
200
+ ch = h.upsample_initial_channel // (2**(i + 1))
201
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
202
+ self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
203
+
204
+ # post conv
205
+ if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
206
+ activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
207
+ self.activation_post = Activation1d(activation=activation_post)
208
+ elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
209
+ activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
210
+ self.activation_post = Activation1d(activation=activation_post)
211
+ else:
212
+ raise NotImplementedError(
213
+ "activation incorrectly specified. check the config file and look for 'activation'."
214
+ )
215
+
216
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
217
+
218
+ # weight initialization
219
+ for i in range(len(self.ups)):
220
+ self.ups[i].apply(init_weights)
221
+ self.conv_post.apply(init_weights)
222
+
223
+ def forward(self, x):
224
+ # pre conv
225
+ x = self.conv_pre(x)
226
+
227
+ for i in range(self.num_upsamples):
228
+ # upsampling
229
+ for i_up in range(len(self.ups[i])):
230
+ x = self.ups[i][i_up](x)
231
+ # AMP blocks
232
+ xs = None
233
+ for j in range(self.num_kernels):
234
+ if xs is None:
235
+ xs = self.resblocks[i * self.num_kernels + j](x)
236
+ else:
237
+ xs += self.resblocks[i * self.num_kernels + j](x)
238
+ x = xs / self.num_kernels
239
+
240
+ # post conv
241
+ x = self.activation_post(x)
242
+ x = self.conv_post(x)
243
+ x = torch.tanh(x)
244
+
245
+ return x
246
+
247
+ def remove_weight_norm(self):
248
+ print('Removing weight norm...')
249
+ for l in self.ups:
250
+ for l_i in l:
251
+ remove_parametrizations(l_i, 'weight')
252
+ for l in self.resblocks:
253
+ l.remove_weight_norm()
254
+ remove_parametrizations(self.conv_pre, 'weight')
255
+ remove_parametrizations(self.conv_post, 'weight')
mmaudio/ext/bigvgan/utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import os
5
+
6
+ import torch
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+
10
+ def init_weights(m, mean=0.0, std=0.01):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Conv") != -1:
13
+ m.weight.data.normal_(mean, std)
14
+
15
+
16
+ def apply_weight_norm(m):
17
+ classname = m.__class__.__name__
18
+ if classname.find("Conv") != -1:
19
+ weight_norm(m)
20
+
21
+
22
+ def get_padding(kernel_size, dilation=1):
23
+ return int((kernel_size * dilation - dilation) / 2)
24
+
25
+
26
+ def load_checkpoint(filepath, device):
27
+ assert os.path.isfile(filepath)
28
+ print("Loading '{}'".format(filepath))
29
+ checkpoint_dict = torch.load(filepath, map_location=device)
30
+ print("Complete.")
31
+ return checkpoint_dict
mmaudio/ext/bigvgan_v2/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
mmaudio/ext/bigvgan_v2/__init__.py ADDED
File without changes
mmaudio/ext/bigvgan_v2/activations.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ """
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ """
25
+
26
+ def __init__(
27
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
28
+ ):
29
+ """
30
+ Initialization.
31
+ INPUT:
32
+ - in_features: shape of the input
33
+ - alpha: trainable parameter
34
+ alpha is initialized to 1 by default, higher values = higher-frequency.
35
+ alpha will be trained along with the rest of your model.
36
+ """
37
+ super(Snake, self).__init__()
38
+ self.in_features = in_features
39
+
40
+ # Initialize alpha
41
+ self.alpha_logscale = alpha_logscale
42
+ if self.alpha_logscale: # Log scale alphas initialized to zeros
43
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
44
+ else: # Linear scale alphas initialized to ones
45
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
46
+
47
+ self.alpha.requires_grad = alpha_trainable
48
+
49
+ self.no_div_by_zero = 0.000000001
50
+
51
+ def forward(self, x):
52
+ """
53
+ Forward pass of the function.
54
+ Applies the function to the input elementwise.
55
+ Snake ∶= x + 1/a * sin^2 (xa)
56
+ """
57
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
58
+ if self.alpha_logscale:
59
+ alpha = torch.exp(alpha)
60
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
61
+
62
+ return x
63
+
64
+
65
+ class SnakeBeta(nn.Module):
66
+ """
67
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
68
+ Shape:
69
+ - Input: (B, C, T)
70
+ - Output: (B, C, T), same shape as the input
71
+ Parameters:
72
+ - alpha - trainable parameter that controls frequency
73
+ - beta - trainable parameter that controls magnitude
74
+ References:
75
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
76
+ https://arxiv.org/abs/2006.08195
77
+ Examples:
78
+ >>> a1 = snakebeta(256)
79
+ >>> x = torch.randn(256)
80
+ >>> x = a1(x)
81
+ """
82
+
83
+ def __init__(
84
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
85
+ ):
86
+ """
87
+ Initialization.
88
+ INPUT:
89
+ - in_features: shape of the input
90
+ - alpha - trainable parameter that controls frequency
91
+ - beta - trainable parameter that controls magnitude
92
+ alpha is initialized to 1 by default, higher values = higher-frequency.
93
+ beta is initialized to 1 by default, higher values = higher-magnitude.
94
+ alpha will be trained along with the rest of your model.
95
+ """
96
+ super(SnakeBeta, self).__init__()
97
+ self.in_features = in_features
98
+
99
+ # Initialize alpha
100
+ self.alpha_logscale = alpha_logscale
101
+ if self.alpha_logscale: # Log scale alphas initialized to zeros
102
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
103
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
104
+ else: # Linear scale alphas initialized to ones
105
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
106
+ self.beta = Parameter(torch.ones(in_features) * alpha)
107
+
108
+ self.alpha.requires_grad = alpha_trainable
109
+ self.beta.requires_grad = alpha_trainable
110
+
111
+ self.no_div_by_zero = 0.000000001
112
+
113
+ def forward(self, x):
114
+ """
115
+ Forward pass of the function.
116
+ Applies the function to the input elementwise.
117
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
118
+ """
119
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
120
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
121
+ if self.alpha_logscale:
122
+ alpha = torch.exp(alpha)
123
+ beta = torch.exp(beta)
124
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
125
+
126
+ return x
mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/__init__.py ADDED
File without changes
mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/activation1d.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from alias_free_activation.torch.resample import UpSample1d, DownSample1d
7
+
8
+ # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
9
+ from alias_free_activation.cuda import load
10
+
11
+ anti_alias_activation_cuda = load.load()
12
+
13
+
14
+ class FusedAntiAliasActivation(torch.autograd.Function):
15
+ """
16
+ Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
17
+ The hyperparameters are hard-coded in the kernel to maximize speed.
18
+ NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.
19
+ """
20
+
21
+ @staticmethod
22
+ def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
23
+ activation_results = anti_alias_activation_cuda.forward(
24
+ inputs, up_ftr, down_ftr, alpha, beta
25
+ )
26
+
27
+ return activation_results
28
+
29
+ @staticmethod
30
+ def backward(ctx, output_grads):
31
+ raise NotImplementedError
32
+ return output_grads, None, None
33
+
34
+
35
+ class Activation1d(nn.Module):
36
+ def __init__(
37
+ self,
38
+ activation,
39
+ up_ratio: int = 2,
40
+ down_ratio: int = 2,
41
+ up_kernel_size: int = 12,
42
+ down_kernel_size: int = 12,
43
+ fused: bool = True,
44
+ ):
45
+ super().__init__()
46
+ self.up_ratio = up_ratio
47
+ self.down_ratio = down_ratio
48
+ self.act = activation
49
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
50
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
51
+
52
+ self.fused = fused # Whether to use fused CUDA kernel or not
53
+
54
+ def forward(self, x):
55
+ if not self.fused:
56
+ x = self.upsample(x)
57
+ x = self.act(x)
58
+ x = self.downsample(x)
59
+ return x
60
+ else:
61
+ if self.act.__class__.__name__ == "Snake":
62
+ beta = self.act.alpha.data # Snake uses same params for alpha and beta
63
+ else:
64
+ beta = (
65
+ self.act.beta.data
66
+ ) # Snakebeta uses different params for alpha and beta
67
+ alpha = self.act.alpha.data
68
+ if (
69
+ not self.act.alpha_logscale
70
+ ): # Exp baked into cuda kernel, cancel it out with a log
71
+ alpha = torch.log(alpha)
72
+ beta = torch.log(beta)
73
+
74
+ x = FusedAntiAliasActivation.apply(
75
+ x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
76
+ )
77
+ return x
mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <torch/extension.h>
18
+
19
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
23
+ }
mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_profiler_api.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+ #include <torch/extension.h>
24
+ #include "type_shim.h"
25
+ #include <assert.h>
26
+ #include <cfloat>
27
+ #include <limits>
28
+ #include <stdint.h>
29
+ #include <c10/macros/Macros.h>
30
+
31
+ namespace
32
+ {
33
+ // Hard-coded hyperparameters
34
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
35
+ constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
36
+ constexpr int BUFFER_SIZE = 32;
37
+ constexpr int FILTER_SIZE = 12;
38
+ constexpr int HALF_FILTER_SIZE = 6;
39
+ constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
40
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
41
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
42
+
43
+ template <typename input_t, typename output_t, typename acc_t>
44
+ __global__ void anti_alias_activation_forward(
45
+ output_t *dst,
46
+ const input_t *src,
47
+ const input_t *up_ftr,
48
+ const input_t *down_ftr,
49
+ const input_t *alpha,
50
+ const input_t *beta,
51
+ int batch_size,
52
+ int channels,
53
+ int seq_len)
54
+ {
55
+ // Up and downsample filters
56
+ input_t up_filter[FILTER_SIZE];
57
+ input_t down_filter[FILTER_SIZE];
58
+
59
+ // Load data from global memory including extra indices reserved for replication paddings
60
+ input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
61
+ input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
62
+
63
+ // Output stores downsampled output before writing to dst
64
+ output_t output[BUFFER_SIZE];
65
+
66
+ // blockDim/threadIdx = (128, 1, 1)
67
+ // gridDim/blockIdx = (seq_blocks, channels, batches)
68
+ int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
69
+ int local_offset = threadIdx.x * BUFFER_SIZE;
70
+ int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
71
+
72
+ // intermediate have double the seq_len
73
+ int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
74
+ int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
75
+
76
+ // Get values needed for replication padding before moving pointer
77
+ const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
78
+ input_t seq_left_most_value = right_most_pntr[0];
79
+ input_t seq_right_most_value = right_most_pntr[seq_len - 1];
80
+
81
+ // Move src and dst pointers
82
+ src += block_offset + local_offset;
83
+ dst += block_offset + local_offset;
84
+
85
+ // Alpha and beta values for snake activatons. Applies exp by default
86
+ alpha = alpha + blockIdx.y;
87
+ input_t alpha_val = expf(alpha[0]);
88
+ beta = beta + blockIdx.y;
89
+ input_t beta_val = expf(beta[0]);
90
+
91
+ #pragma unroll
92
+ for (int it = 0; it < FILTER_SIZE; it += 1)
93
+ {
94
+ up_filter[it] = up_ftr[it];
95
+ down_filter[it] = down_ftr[it];
96
+ }
97
+
98
+ // Apply replication padding for upsampling, matching torch impl
99
+ #pragma unroll
100
+ for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
101
+ {
102
+ int element_index = seq_offset + it; // index for element
103
+ if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
104
+ {
105
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
106
+ }
107
+ if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
108
+ {
109
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
110
+ }
111
+ if ((element_index >= 0) && (element_index < seq_len))
112
+ {
113
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
114
+ }
115
+ }
116
+
117
+ // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
118
+ #pragma unroll
119
+ for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
120
+ {
121
+ input_t acc = 0.0;
122
+ int element_index = intermediate_seq_offset + it; // index for intermediate
123
+ #pragma unroll
124
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
125
+ {
126
+ if ((element_index + f_idx) >= 0)
127
+ {
128
+ acc += up_filter[f_idx] * elements[it + f_idx];
129
+ }
130
+ }
131
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
132
+ }
133
+
134
+ // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
135
+ double no_div_by_zero = 0.000000001;
136
+ #pragma unroll
137
+ for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
138
+ {
139
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
140
+ }
141
+
142
+ // Apply replication padding before downsampling conv from intermediates
143
+ #pragma unroll
144
+ for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
145
+ {
146
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
147
+ }
148
+ #pragma unroll
149
+ for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
150
+ {
151
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
152
+ }
153
+
154
+ // Apply downsample strided convolution (assuming stride=2) from intermediates
155
+ #pragma unroll
156
+ for (int it = 0; it < BUFFER_SIZE; it += 1)
157
+ {
158
+ input_t acc = 0.0;
159
+ #pragma unroll
160
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
161
+ {
162
+ // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
163
+ acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
164
+ }
165
+ output[it] = acc;
166
+ }
167
+
168
+ // Write output to dst
169
+ #pragma unroll
170
+ for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
171
+ {
172
+ int element_index = seq_offset + it;
173
+ if (element_index < seq_len)
174
+ {
175
+ dst[it] = output[it];
176
+ }
177
+ }
178
+
179
+ }
180
+
181
+ template <typename input_t, typename output_t, typename acc_t>
182
+ void dispatch_anti_alias_activation_forward(
183
+ output_t *dst,
184
+ const input_t *src,
185
+ const input_t *up_ftr,
186
+ const input_t *down_ftr,
187
+ const input_t *alpha,
188
+ const input_t *beta,
189
+ int batch_size,
190
+ int channels,
191
+ int seq_len)
192
+ {
193
+ if (seq_len == 0)
194
+ {
195
+ return;
196
+ }
197
+ else
198
+ {
199
+ // Use 128 threads per block to maximimize gpu utilization
200
+ constexpr int threads_per_block = 128;
201
+ constexpr int seq_len_per_block = 4096;
202
+ int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
203
+ dim3 blocks(blocks_per_seq_len, channels, batch_size);
204
+ dim3 threads(threads_per_block, 1, 1);
205
+
206
+ anti_alias_activation_forward<input_t, output_t, acc_t>
207
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
208
+ }
209
+ }
210
+ }
211
+
212
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
213
+ {
214
+ // Input is a 3d tensor with dimensions [batches, channels, seq_len]
215
+ const int batches = input.size(0);
216
+ const int channels = input.size(1);
217
+ const int seq_len = input.size(2);
218
+
219
+ // Output
220
+ auto act_options = input.options().requires_grad(false);
221
+
222
+ torch::Tensor anti_alias_activation_results =
223
+ torch::empty({batches, channels, seq_len}, act_options);
224
+
225
+ void *input_ptr = static_cast<void *>(input.data_ptr());
226
+ void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
227
+ void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
228
+ void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
229
+ void *beta_ptr = static_cast<void *>(beta.data_ptr());
230
+ void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
231
+
232
+ DISPATCH_FLOAT_HALF_AND_BFLOAT(
233
+ input.scalar_type(),
234
+ "dispatch anti alias activation_forward",
235
+ dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
236
+ reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
237
+ reinterpret_cast<const scalar_t *>(input_ptr),
238
+ reinterpret_cast<const scalar_t *>(up_filter_ptr),
239
+ reinterpret_cast<const scalar_t *>(down_filter_ptr),
240
+ reinterpret_cast<const scalar_t *>(alpha_ptr),
241
+ reinterpret_cast<const scalar_t *>(beta_ptr),
242
+ batches,
243
+ channels,
244
+ seq_len););
245
+ return anti_alias_activation_results;
246
+ }
mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/compat.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /*This code is copied fron NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+ #ifndef TORCH_CHECK
22
+ #define TORCH_CHECK AT_CHECK
23
+ #endif
24
+
25
+ #ifdef VERSION_GE_1_3
26
+ #define DATA_PTR data_ptr
27
+ #else
28
+ #define DATA_PTR data
29
+ #endif
mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/load.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import pathlib
6
+ import subprocess
7
+
8
+ from torch.utils import cpp_extension
9
+
10
+ """
11
+ Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels.
12
+ Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below
13
+ """
14
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
15
+
16
+
17
+ def load():
18
+ # Check if cuda 11 is installed for compute capability 8.0
19
+ cc_flag = []
20
+ _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
21
+ if int(bare_metal_major) >= 11:
22
+ cc_flag.append("-gencode")
23
+ cc_flag.append("arch=compute_80,code=sm_80")
24
+
25
+ # Build path
26
+ srcpath = pathlib.Path(__file__).parent.absolute()
27
+ buildpath = srcpath / "build"
28
+ _create_build_dir(buildpath)
29
+
30
+ # Helper function to build the kernels.
31
+ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
32
+ return cpp_extension.load(
33
+ name=name,
34
+ sources=sources,
35
+ build_directory=buildpath,
36
+ extra_cflags=[
37
+ "-O3",
38
+ ],
39
+ extra_cuda_cflags=[
40
+ "-O3",
41
+ "-gencode",
42
+ "arch=compute_70,code=sm_70",
43
+ "--use_fast_math",
44
+ ]
45
+ + extra_cuda_flags
46
+ + cc_flag,
47
+ verbose=True,
48
+ )
49
+
50
+ extra_cuda_flags = [
51
+ "-U__CUDA_NO_HALF_OPERATORS__",
52
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
53
+ "--expt-relaxed-constexpr",
54
+ "--expt-extended-lambda",
55
+ ]
56
+
57
+ sources = [
58
+ srcpath / "anti_alias_activation.cpp",
59
+ srcpath / "anti_alias_activation_cuda.cu",
60
+ ]
61
+ anti_alias_activation_cuda = _cpp_extention_load_helper(
62
+ "anti_alias_activation_cuda", sources, extra_cuda_flags
63
+ )
64
+
65
+ return anti_alias_activation_cuda
66
+
67
+
68
+ def _get_cuda_bare_metal_version(cuda_dir):
69
+ raw_output = subprocess.check_output(
70
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
71
+ )
72
+ output = raw_output.split()
73
+ release_idx = output.index("release") + 1
74
+ release = output[release_idx].split(".")
75
+ bare_metal_major = release[0]
76
+ bare_metal_minor = release[1][0]
77
+
78
+ return raw_output, bare_metal_major, bare_metal_minor
79
+
80
+
81
+ def _create_build_dir(buildpath):
82
+ try:
83
+ os.mkdir(buildpath)
84
+ except OSError:
85
+ if not os.path.isdir(buildpath):
86
+ print(f"Creation of the build directory {buildpath} failed")
mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include "compat.h"
19
+
20
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
21
+ switch (TYPE) \
22
+ { \
23
+ case at::ScalarType::Float: \
24
+ { \
25
+ using scalar_t = float; \
26
+ __VA_ARGS__; \
27
+ break; \
28
+ } \
29
+ case at::ScalarType::Half: \
30
+ { \
31
+ using scalar_t = at::Half; \
32
+ __VA_ARGS__; \
33
+ break; \
34
+ } \
35
+ case at::ScalarType::BFloat16: \
36
+ { \
37
+ using scalar_t = at::BFloat16; \
38
+ __VA_ARGS__; \
39
+ break; \
40
+ } \
41
+ default: \
42
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
43
+ }
44
+
45
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
46
+ switch (TYPEIN) \
47
+ { \
48
+ case at::ScalarType::Float: \
49
+ { \
50
+ using scalar_t_in = float; \
51
+ switch (TYPEOUT) \
52
+ { \
53
+ case at::ScalarType::Float: \
54
+ { \
55
+ using scalar_t_out = float; \
56
+ __VA_ARGS__; \
57
+ break; \
58
+ } \
59
+ case at::ScalarType::Half: \
60
+ { \
61
+ using scalar_t_out = at::Half; \
62
+ __VA_ARGS__; \
63
+ break; \
64
+ } \
65
+ case at::ScalarType::BFloat16: \
66
+ { \
67
+ using scalar_t_out = at::BFloat16; \
68
+ __VA_ARGS__; \
69
+ break; \
70
+ } \
71
+ default: \
72
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
73
+ } \
74
+ break; \
75
+ } \
76
+ case at::ScalarType::Half: \
77
+ { \
78
+ using scalar_t_in = at::Half; \
79
+ using scalar_t_out = at::Half; \
80
+ __VA_ARGS__; \
81
+ break; \
82
+ } \
83
+ case at::ScalarType::BFloat16: \
84
+ { \
85
+ using scalar_t_in = at::BFloat16; \
86
+ using scalar_t_out = at::BFloat16; \
87
+ __VA_ARGS__; \
88
+ break; \
89
+ } \
90
+ default: \
91
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
92
+ }
mmaudio/ext/bigvgan_v2/alias_free_activation/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
mmaudio/ext/bigvgan_v2/alias_free_activation/torch/act.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+
6
+ from mmaudio.ext.bigvgan_v2.alias_free_activation.torch.resample import (DownSample1d, UpSample1d)
7
+
8
+
9
+ class Activation1d(nn.Module):
10
+
11
+ def __init__(
12
+ self,
13
+ activation,
14
+ up_ratio: int = 2,
15
+ down_ratio: int = 2,
16
+ up_kernel_size: int = 12,
17
+ down_kernel_size: int = 12,
18
+ ):
19
+ super().__init__()
20
+ self.up_ratio = up_ratio
21
+ self.down_ratio = down_ratio
22
+ self.act = activation
23
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
24
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
25
+
26
+ # x: [B,C,T]
27
+ def forward(self, x):
28
+ x = self.upsample(x)
29
+ x = self.act(x)
30
+ x = self.downsample(x)
31
+
32
+ return x
mmaudio/ext/bigvgan_v2/alias_free_activation/torch/filter.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if "sinc" in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(
21
+ x == 0,
22
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
23
+ torch.sin(math.pi * x) / math.pi / x,
24
+ )
25
+
26
+
27
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28
+ # https://adefossez.github.io/julius/julius/lowpass.html
29
+ # LICENSE is in incl_licenses directory.
30
+ def kaiser_sinc_filter1d(
31
+ cutoff, half_width, kernel_size
32
+ ): # return filter [1,1,kernel_size]
33
+ even = kernel_size % 2 == 0
34
+ half_size = kernel_size // 2
35
+
36
+ # For kaiser window
37
+ delta_f = 4 * half_width
38
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39
+ if A > 50.0:
40
+ beta = 0.1102 * (A - 8.7)
41
+ elif A >= 21.0:
42
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43
+ else:
44
+ beta = 0.0
45
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46
+
47
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48
+ if even:
49
+ time = torch.arange(-half_size, half_size) + 0.5
50
+ else:
51
+ time = torch.arange(kernel_size) - half_size
52
+ if cutoff == 0:
53
+ filter_ = torch.zeros_like(time)
54
+ else:
55
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56
+ """
57
+ Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
58
+ """
59
+ filter_ /= filter_.sum()
60
+ filter = filter_.view(1, 1, kernel_size)
61
+
62
+ return filter
63
+
64
+
65
+ class LowPassFilter1d(nn.Module):
66
+ def __init__(
67
+ self,
68
+ cutoff=0.5,
69
+ half_width=0.6,
70
+ stride: int = 1,
71
+ padding: bool = True,
72
+ padding_mode: str = "replicate",
73
+ kernel_size: int = 12,
74
+ ):
75
+ """
76
+ kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
77
+ """
78
+ super().__init__()
79
+ if cutoff < -0.0:
80
+ raise ValueError("Minimum cutoff must be larger than zero.")
81
+ if cutoff > 0.5:
82
+ raise ValueError("A cutoff above 0.5 does not make sense.")
83
+ self.kernel_size = kernel_size
84
+ self.even = kernel_size % 2 == 0
85
+ self.pad_left = kernel_size // 2 - int(self.even)
86
+ self.pad_right = kernel_size // 2
87
+ self.stride = stride
88
+ self.padding = padding
89
+ self.padding_mode = padding_mode
90
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
91
+ self.register_buffer("filter", filter)
92
+
93
+ # Input [B, C, T]
94
+ def forward(self, x):
95
+ _, C, _ = x.shape
96
+
97
+ if self.padding:
98
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
99
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
100
+
101
+ return out
mmaudio/ext/bigvgan_v2/alias_free_activation/torch/resample.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ from mmaudio.ext.bigvgan_v2.alias_free_activation.torch.filter import (LowPassFilter1d,
8
+ kaiser_sinc_filter1d)
9
+
10
+
11
+ class UpSample1d(nn.Module):
12
+
13
+ def __init__(self, ratio=2, kernel_size=None):
14
+ super().__init__()
15
+ self.ratio = ratio
16
+ self.kernel_size = (int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size)
17
+ self.stride = ratio
18
+ self.pad = self.kernel_size // ratio - 1
19
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
20
+ self.pad_right = (self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2)
21
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
22
+ half_width=0.6 / ratio,
23
+ kernel_size=self.kernel_size)
24
+ self.register_buffer("filter", filter)
25
+
26
+ # x: [B, C, T]
27
+ def forward(self, x):
28
+ _, C, _ = x.shape
29
+
30
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
31
+ x = self.ratio * F.conv_transpose1d(
32
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
33
+ x = x[..., self.pad_left:-self.pad_right]
34
+
35
+ return x
36
+
37
+
38
+ class DownSample1d(nn.Module):
39
+
40
+ def __init__(self, ratio=2, kernel_size=None):
41
+ super().__init__()
42
+ self.ratio = ratio
43
+ self.kernel_size = (int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size)
44
+ self.lowpass = LowPassFilter1d(
45
+ cutoff=0.5 / ratio,
46
+ half_width=0.6 / ratio,
47
+ stride=ratio,
48
+ kernel_size=self.kernel_size,
49
+ )
50
+
51
+ def forward(self, x):
52
+ xx = self.lowpass(x)
53
+
54
+ return xx