hugo flores garcia
commited on
Commit
•
2b3cdf0
1
Parent(s):
e612fff
gitattributes
Browse files- .gitattributes +4 -0
- .gitignore +191 -0
- .pre-commit-config.yaml +15 -0
- LICENSE +21 -0
- app.py +704 -0
- assets/example.wav +0 -0
- conf/c2f.yml +14 -0
- conf/interface.yml +10 -0
- conf/lora/lora.yml +22 -0
- conf/salad_bowl.yml +0 -0
- conf/vampnet.yml +49 -0
- requirements.txt +10 -0
- scripts/exp/eval.py +110 -0
- scripts/exp/experiment.py +254 -0
- scripts/exp/fine_tune.py +81 -0
- scripts/exp/train.py +686 -0
- scripts/utils/README.md +28 -0
- scripts/utils/gtzan_embeddings.py +264 -0
- scripts/utils/plots.py +43 -0
- scripts/utils/remove_quiet_files.py +29 -0
- scripts/utils/split.py +64 -0
- scripts/utils/split_long_audio_file.py +34 -0
- scripts/utils/stage.py +30 -0
- scripts/utils/visualize_embeddings.py +265 -0
- scripts/utils/xeno-canto-dl.py +234 -0
- setup.py +43 -0
- vampnet/__init__.py +6 -0
- vampnet/beats.py +250 -0
- vampnet/interface.py +432 -0
- vampnet/mask.py +242 -0
- vampnet/modules/__init__.py +6 -0
- vampnet/modules/activations.py +55 -0
- vampnet/modules/layers.py +164 -0
- vampnet/modules/transformer.py +953 -0
- vampnet/scheduler.py +47 -0
- vampnet/util.py +46 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
coarse.pth filter=lfs diff=lfs merge=lfs -text
|
37 |
+
c2f.pth filter=lfs diff=lfs merge=lfs -text
|
38 |
+
wavebeat.pth filter=lfs diff=lfs merge=lfs -text
|
39 |
+
codec.pth filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/env.sh
|
108 |
+
venv/
|
109 |
+
env.bak/
|
110 |
+
venv.bak/
|
111 |
+
|
112 |
+
# Spyder project settings
|
113 |
+
.spyderproject
|
114 |
+
.spyproject
|
115 |
+
|
116 |
+
# Rope project settings
|
117 |
+
.ropeproject
|
118 |
+
|
119 |
+
# mkdocs documentation
|
120 |
+
/site
|
121 |
+
|
122 |
+
# mypy
|
123 |
+
.mypy_cache/
|
124 |
+
.dmypy.json
|
125 |
+
dmypy.json
|
126 |
+
|
127 |
+
# Pyre type checker
|
128 |
+
.pyre/
|
129 |
+
|
130 |
+
# Files created by experiments
|
131 |
+
output/
|
132 |
+
snapshot/
|
133 |
+
*.m4a
|
134 |
+
notebooks/scratch.ipynb
|
135 |
+
notebooks/inspect.ipynb
|
136 |
+
notebooks/effects.ipynb
|
137 |
+
notebooks/*.ipynb
|
138 |
+
notebooks/*.gif
|
139 |
+
notebooks/*.wav
|
140 |
+
notebooks/*.mp4
|
141 |
+
*runs/
|
142 |
+
boards/
|
143 |
+
samples/
|
144 |
+
*.ipynb
|
145 |
+
|
146 |
+
results.json
|
147 |
+
metrics.csv
|
148 |
+
mprofile_*
|
149 |
+
mem.png
|
150 |
+
|
151 |
+
results/
|
152 |
+
mprofile*
|
153 |
+
*.png
|
154 |
+
# do not ignore the test wav file
|
155 |
+
!tests/audio/short_test_audio.wav
|
156 |
+
!tests/audio/output.wav
|
157 |
+
*/.DS_Store
|
158 |
+
.DS_Store
|
159 |
+
env.sh
|
160 |
+
_codebraid/
|
161 |
+
**/*.html
|
162 |
+
**/*.exec.md
|
163 |
+
flagged/
|
164 |
+
log.txt
|
165 |
+
ckpt/
|
166 |
+
.syncthing*
|
167 |
+
tests/assets/
|
168 |
+
archived/
|
169 |
+
|
170 |
+
scratch/
|
171 |
+
|
172 |
+
runs-archive
|
173 |
+
lyrebird-audiotools
|
174 |
+
lyrebird-audio-codec
|
175 |
+
samples-*/**
|
176 |
+
|
177 |
+
gradio-outputs/
|
178 |
+
samples*/
|
179 |
+
models-all/
|
180 |
+
models.zip
|
181 |
+
.git-old
|
182 |
+
conf/generated/*
|
183 |
+
runs*/
|
184 |
+
|
185 |
+
|
186 |
+
gtzan.zip
|
187 |
+
.gtzan_emb_cache
|
188 |
+
runs
|
189 |
+
|
190 |
+
data/
|
191 |
+
src/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/asottile/reorder_python_imports
|
3 |
+
rev: v2.5.0
|
4 |
+
hooks:
|
5 |
+
- id: reorder-python-imports
|
6 |
+
- repo: https://github.com/psf/black
|
7 |
+
rev: 23.1.0
|
8 |
+
hooks:
|
9 |
+
- id: black
|
10 |
+
language_version: python3
|
11 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
12 |
+
rev: v4.0.1
|
13 |
+
hooks:
|
14 |
+
- id: end-of-file-fixer
|
15 |
+
- id: trailing-whitespace
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Hugo Flores García and Prem Seetharaman
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
app.py
ADDED
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# huggingface space exclusive
|
2 |
+
import os
|
3 |
+
|
4 |
+
# print("installing pyharp")
|
5 |
+
# os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"')
|
6 |
+
# print("installing madmom")
|
7 |
+
os.system('pip install cython')
|
8 |
+
os.system('pip install madmom')
|
9 |
+
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Tuple
|
12 |
+
import yaml
|
13 |
+
import tempfile
|
14 |
+
import uuid
|
15 |
+
import shutil
|
16 |
+
from dataclasses import dataclass, asdict
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import audiotools as at
|
20 |
+
import argbind
|
21 |
+
import torch
|
22 |
+
|
23 |
+
import gradio as gr
|
24 |
+
from vampnet.interface import Interface
|
25 |
+
from vampnet import mask as pmask
|
26 |
+
|
27 |
+
from pyharp import ModelCard, build_endpoint
|
28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
+
|
30 |
+
interface = Interface(
|
31 |
+
device=device,
|
32 |
+
coarse_ckpt="models/vampvat/coarse.pth",
|
33 |
+
coarse2fine_ckpt="models/vampvat/c2f.pth",
|
34 |
+
codec_ckpt="models/vampvat/codec.pth",
|
35 |
+
)
|
36 |
+
|
37 |
+
# populate the model choices with any interface.yml files in the generated confs
|
38 |
+
MODEL_CHOICES = {
|
39 |
+
"default": {
|
40 |
+
"Interface.coarse_ckpt": str(interface.coarse_path),
|
41 |
+
"Interface.coarse2fine_ckpt": str(interface.c2f_path),
|
42 |
+
"Interface.codec_ckpt": str(interface.codec_path),
|
43 |
+
}
|
44 |
+
}
|
45 |
+
generated_confs = Path("conf/generated")
|
46 |
+
for conf_file in generated_confs.glob("*/interface.yml"):
|
47 |
+
with open(conf_file) as f:
|
48 |
+
_conf = yaml.safe_load(f)
|
49 |
+
MODEL_CHOICES[conf_file.parent.name] = _conf
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
OUT_DIR = Path("gradio-outputs")
|
54 |
+
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
55 |
+
|
56 |
+
|
57 |
+
def load_audio(file):
|
58 |
+
print(file)
|
59 |
+
filepath = file.name
|
60 |
+
sig = at.AudioSignal.salient_excerpt(
|
61 |
+
filepath,
|
62 |
+
duration=interface.coarse.chunk_size_s
|
63 |
+
)
|
64 |
+
sig = interface.preprocess(sig)
|
65 |
+
|
66 |
+
out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
|
67 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
68 |
+
sig.write(out_dir / "input.wav")
|
69 |
+
return sig.path_to_file
|
70 |
+
|
71 |
+
|
72 |
+
def load_example_audio():
|
73 |
+
return "./assets/example.wav"
|
74 |
+
|
75 |
+
from torch_pitch_shift import pitch_shift, get_fast_shifts
|
76 |
+
def shift_pitch(signal, interval: int):
|
77 |
+
signal.samples = pitch_shift(
|
78 |
+
signal.samples,
|
79 |
+
shift=interval,
|
80 |
+
sample_rate=signal.sample_rate
|
81 |
+
)
|
82 |
+
return signal
|
83 |
+
|
84 |
+
def _vamp(data, return_mask=False):
|
85 |
+
# remove any old files in the output directory (from previous runs)
|
86 |
+
shutil.rmtree(OUT_DIR)
|
87 |
+
OUT_DIR.mkdir()
|
88 |
+
|
89 |
+
out_dir = OUT_DIR / str(uuid.uuid4())
|
90 |
+
out_dir.mkdir()
|
91 |
+
sig = at.AudioSignal(data[input_audio])
|
92 |
+
sig = interface.preprocess(sig)
|
93 |
+
|
94 |
+
# reload the model if necessary
|
95 |
+
interface.reload(
|
96 |
+
coarse_ckpt=MODEL_CHOICES[data[model_choice]]["Interface.coarse_ckpt"],
|
97 |
+
c2f_ckpt=MODEL_CHOICES[data[model_choice]]["Interface.coarse2fine_ckpt"],
|
98 |
+
)
|
99 |
+
|
100 |
+
loudness = sig.loudness()
|
101 |
+
print(f"input loudness is {loudness}")
|
102 |
+
|
103 |
+
if data[pitch_shift_amt] != 0:
|
104 |
+
sig = shift_pitch(sig, data[pitch_shift_amt])
|
105 |
+
|
106 |
+
z = interface.encode(sig)
|
107 |
+
|
108 |
+
ncc = data[n_conditioning_codebooks]
|
109 |
+
|
110 |
+
# build the mask
|
111 |
+
mask = pmask.linear_random(z, data[rand_mask_intensity])
|
112 |
+
mask = pmask.mask_and(
|
113 |
+
mask, pmask.inpaint(
|
114 |
+
z,
|
115 |
+
interface.s2t(data[prefix_s]),
|
116 |
+
interface.s2t(data[suffix_s])
|
117 |
+
)
|
118 |
+
)
|
119 |
+
mask = pmask.mask_and(
|
120 |
+
mask, pmask.periodic_mask(
|
121 |
+
z,
|
122 |
+
data[periodic_p],
|
123 |
+
data[periodic_w],
|
124 |
+
random_roll=True
|
125 |
+
)
|
126 |
+
)
|
127 |
+
if data[onset_mask_width] > 0:
|
128 |
+
mask = pmask.mask_or(
|
129 |
+
mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
|
130 |
+
)
|
131 |
+
if data[beat_mask_width] > 0:
|
132 |
+
beat_mask = interface.make_beat_mask(
|
133 |
+
sig,
|
134 |
+
after_beat_s=(data[beat_mask_width]/1000),
|
135 |
+
mask_upbeats=not data[beat_mask_downbeats],
|
136 |
+
)
|
137 |
+
mask = pmask.mask_and(mask, beat_mask)
|
138 |
+
|
139 |
+
# these should be the last two mask ops
|
140 |
+
mask = pmask.dropout(mask, data[dropout])
|
141 |
+
mask = pmask.codebook_unmask(mask, ncc)
|
142 |
+
mask = pmask.codebook_mask(mask, int(data[n_mask_codebooks]))
|
143 |
+
|
144 |
+
print(f"dropout {data[dropout]}")
|
145 |
+
print(f"masktemp {data[masktemp]}")
|
146 |
+
print(f"sampletemp {data[sampletemp]}")
|
147 |
+
print(f"top_p {data[top_p]}")
|
148 |
+
print(f"prefix_s {data[prefix_s]}")
|
149 |
+
print(f"suffix_s {data[suffix_s]}")
|
150 |
+
print(f"rand_mask_intensity {data[rand_mask_intensity]}")
|
151 |
+
print(f"num_steps {data[num_steps]}")
|
152 |
+
print(f"periodic_p {data[periodic_p]}")
|
153 |
+
print(f"periodic_w {data[periodic_w]}")
|
154 |
+
print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
|
155 |
+
print(f"use_coarse2fine {data[use_coarse2fine]}")
|
156 |
+
print(f"onset_mask_width {data[onset_mask_width]}")
|
157 |
+
print(f"beat_mask_width {data[beat_mask_width]}")
|
158 |
+
print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
|
159 |
+
print(f"stretch_factor {data[stretch_factor]}")
|
160 |
+
print(f"seed {data[seed]}")
|
161 |
+
print(f"pitch_shift_amt {data[pitch_shift_amt]}")
|
162 |
+
print(f"sample_cutoff {data[sample_cutoff]}")
|
163 |
+
|
164 |
+
|
165 |
+
_top_p = data[top_p] if data[top_p] > 0 else None
|
166 |
+
# save the mask as a txt file
|
167 |
+
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
168 |
+
|
169 |
+
_seed = data[seed] if data[seed] > 0 else None
|
170 |
+
print(f"processing coarse...")
|
171 |
+
zv, mask_z = interface.coarse_vamp(
|
172 |
+
z,
|
173 |
+
mask=mask,
|
174 |
+
sampling_steps=data[num_steps],
|
175 |
+
mask_temperature=data[masktemp]*10,
|
176 |
+
sampling_temperature=data[sampletemp],
|
177 |
+
return_mask=True,
|
178 |
+
typical_filtering=data[typical_filtering],
|
179 |
+
typical_mass=data[typical_mass],
|
180 |
+
typical_min_tokens=data[typical_min_tokens],
|
181 |
+
top_p=_top_p,
|
182 |
+
gen_fn=interface.coarse.generate,
|
183 |
+
seed=_seed,
|
184 |
+
sample_cutoff=data[sample_cutoff],
|
185 |
+
)
|
186 |
+
|
187 |
+
if use_coarse2fine:
|
188 |
+
print(f"processing coarse to fine...")
|
189 |
+
zv = interface.coarse_to_fine(
|
190 |
+
zv,
|
191 |
+
mask_temperature=data[masktemp]*10,
|
192 |
+
sampling_temperature=data[sampletemp],
|
193 |
+
mask=mask,
|
194 |
+
sampling_steps=data[num_steps] // 2,
|
195 |
+
sample_cutoff=data[sample_cutoff],
|
196 |
+
seed=_seed,
|
197 |
+
)
|
198 |
+
|
199 |
+
sig = interface.to_signal(zv).cpu()
|
200 |
+
print("done")
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
sig.write(out_dir / "output.wav")
|
205 |
+
|
206 |
+
if return_mask:
|
207 |
+
mask = interface.to_signal(mask_z).cpu()
|
208 |
+
mask.write(out_dir / "mask.wav")
|
209 |
+
return sig.path_to_file, mask.path_to_file
|
210 |
+
else:
|
211 |
+
return sig.path_to_file
|
212 |
+
|
213 |
+
def vamp(data):
|
214 |
+
return _vamp(data, return_mask=True)
|
215 |
+
|
216 |
+
def api_vamp(data):
|
217 |
+
return _vamp(data, return_mask=False)
|
218 |
+
|
219 |
+
def save_vamp(data):
|
220 |
+
out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
|
221 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
222 |
+
|
223 |
+
sig_in = at.AudioSignal(data[input_audio])
|
224 |
+
sig_out = at.AudioSignal(data[output_audio])
|
225 |
+
|
226 |
+
sig_in.write(out_dir / "input.wav")
|
227 |
+
sig_out.write(out_dir / "output.wav")
|
228 |
+
|
229 |
+
_data = {
|
230 |
+
"masktemp": data[masktemp],
|
231 |
+
"sampletemp": data[sampletemp],
|
232 |
+
"top_p": data[top_p],
|
233 |
+
"prefix_s": data[prefix_s],
|
234 |
+
"suffix_s": data[suffix_s],
|
235 |
+
"rand_mask_intensity": data[rand_mask_intensity],
|
236 |
+
"num_steps": data[num_steps],
|
237 |
+
"notes": data[notes_text],
|
238 |
+
"periodic_period": data[periodic_p],
|
239 |
+
"periodic_width": data[periodic_w],
|
240 |
+
"n_conditioning_codebooks": data[n_conditioning_codebooks],
|
241 |
+
"use_coarse2fine": data[use_coarse2fine],
|
242 |
+
"stretch_factor": data[stretch_factor],
|
243 |
+
"seed": data[seed],
|
244 |
+
"samplecutoff": data[sample_cutoff],
|
245 |
+
}
|
246 |
+
|
247 |
+
# save with yaml
|
248 |
+
with open(out_dir / "data.yaml", "w") as f:
|
249 |
+
yaml.dump(_data, f)
|
250 |
+
|
251 |
+
import zipfile
|
252 |
+
zip_path = out_dir.with_suffix(".zip")
|
253 |
+
with zipfile.ZipFile(zip_path, "w") as zf:
|
254 |
+
for file in out_dir.iterdir():
|
255 |
+
zf.write(file, file.name)
|
256 |
+
|
257 |
+
return f"saved! your save code is {out_dir.stem}", zip_path
|
258 |
+
|
259 |
+
|
260 |
+
def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
|
261 |
+
|
262 |
+
out_dir = OUT_DIR / str(uuid.uuid4())
|
263 |
+
out_dir.mkdir()
|
264 |
+
sig = at.AudioSignal(_input_audio)
|
265 |
+
sig = interface.preprocess(sig)
|
266 |
+
|
267 |
+
z = interface.encode(sig)
|
268 |
+
|
269 |
+
# build the mask
|
270 |
+
mask = pmask.linear_random(z, 1.0)
|
271 |
+
if _beat_mask_width > 0:
|
272 |
+
beat_mask = interface.make_beat_mask(
|
273 |
+
sig,
|
274 |
+
after_beat_s=(_beat_mask_width/1000),
|
275 |
+
)
|
276 |
+
mask = pmask.mask_and(mask, beat_mask)
|
277 |
+
|
278 |
+
# save the mask as a txt file
|
279 |
+
zv, mask_z = interface.coarse_vamp(
|
280 |
+
z,
|
281 |
+
mask=mask,
|
282 |
+
sampling_temperature=_sampletemp,
|
283 |
+
return_mask=True,
|
284 |
+
gen_fn=interface.coarse.generate,
|
285 |
+
)
|
286 |
+
|
287 |
+
|
288 |
+
zv = interface.coarse_to_fine(
|
289 |
+
zv,
|
290 |
+
sampling_temperature=_sampletemp,
|
291 |
+
mask=mask,
|
292 |
+
)
|
293 |
+
|
294 |
+
sig = interface.to_signal(zv).cpu()
|
295 |
+
print("done")
|
296 |
+
|
297 |
+
sig.write(out_dir / "output.wav")
|
298 |
+
|
299 |
+
return sig.path_to_file
|
300 |
+
|
301 |
+
with gr.Blocks() as demo:
|
302 |
+
|
303 |
+
with gr.Row():
|
304 |
+
with gr.Column():
|
305 |
+
gr.Markdown("# VampNet Audio Vamping")
|
306 |
+
gr.Markdown("""## Description:
|
307 |
+
This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings.
|
308 |
+
You can control the extent and nature of variation with a set of manual controls and presets.
|
309 |
+
Use this interface to experiment with different mask settings and explore the audio outputs.
|
310 |
+
""")
|
311 |
+
|
312 |
+
gr.Markdown("""
|
313 |
+
## Instructions:
|
314 |
+
1. You can start by uploading some audio, or by loading the example audio.
|
315 |
+
2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
|
316 |
+
3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
|
317 |
+
4. Optionally, you can add some notes and save the result.
|
318 |
+
5. You can also use the output as the new input and continue experimenting!
|
319 |
+
""")
|
320 |
+
with gr.Row():
|
321 |
+
with gr.Column():
|
322 |
+
|
323 |
+
|
324 |
+
manual_audio_upload = gr.File(
|
325 |
+
label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
|
326 |
+
file_types=["audio"]
|
327 |
+
)
|
328 |
+
load_example_audio_button = gr.Button("or load example audio")
|
329 |
+
|
330 |
+
input_audio = gr.Audio(
|
331 |
+
label="input audio",
|
332 |
+
interactive=False,
|
333 |
+
type="filepath",
|
334 |
+
)
|
335 |
+
|
336 |
+
audio_mask = gr.Audio(
|
337 |
+
label="audio mask (listen to this to hear the mask hints)",
|
338 |
+
interactive=False,
|
339 |
+
type="filepath",
|
340 |
+
)
|
341 |
+
|
342 |
+
# connect widgets
|
343 |
+
load_example_audio_button.click(
|
344 |
+
fn=load_example_audio,
|
345 |
+
inputs=[],
|
346 |
+
outputs=[ input_audio]
|
347 |
+
)
|
348 |
+
|
349 |
+
manual_audio_upload.change(
|
350 |
+
fn=load_audio,
|
351 |
+
inputs=[manual_audio_upload],
|
352 |
+
outputs=[ input_audio]
|
353 |
+
)
|
354 |
+
|
355 |
+
# mask settings
|
356 |
+
with gr.Column():
|
357 |
+
|
358 |
+
|
359 |
+
presets = {
|
360 |
+
"unconditional": {
|
361 |
+
"periodic_p": 0,
|
362 |
+
"onset_mask_width": 0,
|
363 |
+
"beat_mask_width": 0,
|
364 |
+
"beat_mask_downbeats": False,
|
365 |
+
},
|
366 |
+
"slight periodic variation": {
|
367 |
+
"periodic_p": 5,
|
368 |
+
"onset_mask_width": 5,
|
369 |
+
"beat_mask_width": 0,
|
370 |
+
"beat_mask_downbeats": False,
|
371 |
+
},
|
372 |
+
"moderate periodic variation": {
|
373 |
+
"periodic_p": 13,
|
374 |
+
"onset_mask_width": 5,
|
375 |
+
"beat_mask_width": 0,
|
376 |
+
"beat_mask_downbeats": False,
|
377 |
+
},
|
378 |
+
"strong periodic variation": {
|
379 |
+
"periodic_p": 17,
|
380 |
+
"onset_mask_width": 5,
|
381 |
+
"beat_mask_width": 0,
|
382 |
+
"beat_mask_downbeats": False,
|
383 |
+
},
|
384 |
+
"very strong periodic variation": {
|
385 |
+
"periodic_p": 21,
|
386 |
+
"onset_mask_width": 5,
|
387 |
+
"beat_mask_width": 0,
|
388 |
+
"beat_mask_downbeats": False,
|
389 |
+
},
|
390 |
+
"beat-driven variation": {
|
391 |
+
"periodic_p": 0,
|
392 |
+
"onset_mask_width": 0,
|
393 |
+
"beat_mask_width": 50,
|
394 |
+
"beat_mask_downbeats": False,
|
395 |
+
},
|
396 |
+
"beat-driven variation (downbeats only)": {
|
397 |
+
"periodic_p": 0,
|
398 |
+
"onset_mask_width": 0,
|
399 |
+
"beat_mask_width": 50,
|
400 |
+
"beat_mask_downbeats": True,
|
401 |
+
},
|
402 |
+
"beat-driven variation (downbeats only, strong)": {
|
403 |
+
"periodic_p": 0,
|
404 |
+
"onset_mask_width": 0,
|
405 |
+
"beat_mask_width": 20,
|
406 |
+
"beat_mask_downbeats": True,
|
407 |
+
},
|
408 |
+
}
|
409 |
+
|
410 |
+
preset = gr.Dropdown(
|
411 |
+
label="preset",
|
412 |
+
choices=list(presets.keys()),
|
413 |
+
value="strong periodic variation",
|
414 |
+
)
|
415 |
+
load_preset_button = gr.Button("load_preset")
|
416 |
+
|
417 |
+
with gr.Accordion("manual controls", open=True):
|
418 |
+
periodic_p = gr.Slider(
|
419 |
+
label="periodic prompt (0 - unconditional, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
|
420 |
+
minimum=0,
|
421 |
+
maximum=128,
|
422 |
+
step=1,
|
423 |
+
value=3,
|
424 |
+
)
|
425 |
+
|
426 |
+
|
427 |
+
onset_mask_width = gr.Slider(
|
428 |
+
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
429 |
+
minimum=0,
|
430 |
+
maximum=100,
|
431 |
+
step=1,
|
432 |
+
value=5,
|
433 |
+
)
|
434 |
+
|
435 |
+
beat_mask_width = gr.Slider(
|
436 |
+
label="beat prompt (ms)",
|
437 |
+
minimum=0,
|
438 |
+
maximum=200,
|
439 |
+
value=0,
|
440 |
+
)
|
441 |
+
beat_mask_downbeats = gr.Checkbox(
|
442 |
+
label="beat mask downbeats only?",
|
443 |
+
value=False
|
444 |
+
)
|
445 |
+
|
446 |
+
n_mask_codebooks = gr.Number(
|
447 |
+
label="first upper codebook level to mask",
|
448 |
+
value=9,
|
449 |
+
)
|
450 |
+
|
451 |
+
|
452 |
+
with gr.Accordion("extras ", open=False):
|
453 |
+
pitch_shift_amt = gr.Slider(
|
454 |
+
label="pitch shift amount (semitones)",
|
455 |
+
minimum=-12,
|
456 |
+
maximum=12,
|
457 |
+
step=1,
|
458 |
+
value=0,
|
459 |
+
)
|
460 |
+
|
461 |
+
rand_mask_intensity = gr.Slider(
|
462 |
+
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
463 |
+
minimum=0.0,
|
464 |
+
maximum=1.0,
|
465 |
+
value=1.0
|
466 |
+
)
|
467 |
+
|
468 |
+
periodic_w = gr.Slider(
|
469 |
+
label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
|
470 |
+
minimum=1,
|
471 |
+
maximum=20,
|
472 |
+
step=1,
|
473 |
+
value=1,
|
474 |
+
)
|
475 |
+
n_conditioning_codebooks = gr.Number(
|
476 |
+
label="number of conditioning codebooks. probably 0",
|
477 |
+
value=0,
|
478 |
+
precision=0,
|
479 |
+
)
|
480 |
+
|
481 |
+
stretch_factor = gr.Slider(
|
482 |
+
label="time stretch factor",
|
483 |
+
minimum=0,
|
484 |
+
maximum=64,
|
485 |
+
step=1,
|
486 |
+
value=1,
|
487 |
+
)
|
488 |
+
|
489 |
+
preset_outputs = {
|
490 |
+
periodic_p,
|
491 |
+
onset_mask_width,
|
492 |
+
beat_mask_width,
|
493 |
+
beat_mask_downbeats,
|
494 |
+
}
|
495 |
+
|
496 |
+
def load_preset(_preset):
|
497 |
+
return tuple(presets[_preset].values())
|
498 |
+
|
499 |
+
load_preset_button.click(
|
500 |
+
fn=load_preset,
|
501 |
+
inputs=[preset],
|
502 |
+
outputs=preset_outputs
|
503 |
+
)
|
504 |
+
|
505 |
+
|
506 |
+
with gr.Accordion("prefix/suffix prompts", open=False):
|
507 |
+
prefix_s = gr.Slider(
|
508 |
+
label="prefix hint length (seconds)",
|
509 |
+
minimum=0.0,
|
510 |
+
maximum=10.0,
|
511 |
+
value=0.0
|
512 |
+
)
|
513 |
+
suffix_s = gr.Slider(
|
514 |
+
label="suffix hint length (seconds)",
|
515 |
+
minimum=0.0,
|
516 |
+
maximum=10.0,
|
517 |
+
value=0.0
|
518 |
+
)
|
519 |
+
|
520 |
+
masktemp = gr.Slider(
|
521 |
+
label="mask temperature",
|
522 |
+
minimum=0.0,
|
523 |
+
maximum=100.0,
|
524 |
+
value=1.5
|
525 |
+
)
|
526 |
+
sampletemp = gr.Slider(
|
527 |
+
label="sample temperature",
|
528 |
+
minimum=0.1,
|
529 |
+
maximum=10.0,
|
530 |
+
value=1.0,
|
531 |
+
step=0.001
|
532 |
+
)
|
533 |
+
|
534 |
+
|
535 |
+
|
536 |
+
with gr.Accordion("sampling settings", open=False):
|
537 |
+
top_p = gr.Slider(
|
538 |
+
label="top p (0.0 = off)",
|
539 |
+
minimum=0.0,
|
540 |
+
maximum=1.0,
|
541 |
+
value=0.9
|
542 |
+
)
|
543 |
+
typical_filtering = gr.Checkbox(
|
544 |
+
label="typical filtering ",
|
545 |
+
value=False
|
546 |
+
)
|
547 |
+
typical_mass = gr.Slider(
|
548 |
+
label="typical mass (should probably stay between 0.1 and 0.5)",
|
549 |
+
minimum=0.01,
|
550 |
+
maximum=0.99,
|
551 |
+
value=0.15
|
552 |
+
)
|
553 |
+
typical_min_tokens = gr.Slider(
|
554 |
+
label="typical min tokens (should probably stay between 1 and 256)",
|
555 |
+
minimum=1,
|
556 |
+
maximum=256,
|
557 |
+
step=1,
|
558 |
+
value=64
|
559 |
+
)
|
560 |
+
sample_cutoff = gr.Slider(
|
561 |
+
label="sample cutoff",
|
562 |
+
minimum=0.0,
|
563 |
+
maximum=1.0,
|
564 |
+
value=0.5,
|
565 |
+
step=0.01
|
566 |
+
)
|
567 |
+
|
568 |
+
use_coarse2fine = gr.Checkbox(
|
569 |
+
label="use coarse2fine",
|
570 |
+
value=True,
|
571 |
+
visible=False
|
572 |
+
)
|
573 |
+
|
574 |
+
num_steps = gr.Slider(
|
575 |
+
label="number of steps (should normally be between 12 and 36)",
|
576 |
+
minimum=1,
|
577 |
+
maximum=128,
|
578 |
+
step=1,
|
579 |
+
value=36
|
580 |
+
)
|
581 |
+
|
582 |
+
dropout = gr.Slider(
|
583 |
+
label="mask dropout",
|
584 |
+
minimum=0.0,
|
585 |
+
maximum=1.0,
|
586 |
+
step=0.01,
|
587 |
+
value=0.0
|
588 |
+
)
|
589 |
+
|
590 |
+
|
591 |
+
seed = gr.Number(
|
592 |
+
label="seed (0 for random)",
|
593 |
+
value=0,
|
594 |
+
precision=0,
|
595 |
+
)
|
596 |
+
|
597 |
+
|
598 |
+
|
599 |
+
# mask settings
|
600 |
+
with gr.Column():
|
601 |
+
|
602 |
+
model_choice = gr.Dropdown(
|
603 |
+
label="model choice",
|
604 |
+
choices=list(MODEL_CHOICES.keys()),
|
605 |
+
value="default",
|
606 |
+
visible=True
|
607 |
+
)
|
608 |
+
|
609 |
+
vamp_button = gr.Button("generate (vamp)!!!")
|
610 |
+
output_audio = gr.Audio(
|
611 |
+
label="output audio",
|
612 |
+
interactive=False,
|
613 |
+
type="filepath"
|
614 |
+
)
|
615 |
+
|
616 |
+
notes_text = gr.Textbox(
|
617 |
+
label="type any notes about the generated audio here",
|
618 |
+
value="",
|
619 |
+
interactive=True
|
620 |
+
)
|
621 |
+
save_button = gr.Button("save vamp")
|
622 |
+
download_file = gr.File(
|
623 |
+
label="vamp to download will appear here",
|
624 |
+
interactive=False
|
625 |
+
)
|
626 |
+
use_as_input_button = gr.Button("use output as input")
|
627 |
+
|
628 |
+
thank_you = gr.Markdown("")
|
629 |
+
|
630 |
+
|
631 |
+
_inputs = {
|
632 |
+
input_audio,
|
633 |
+
num_steps,
|
634 |
+
masktemp,
|
635 |
+
sampletemp,
|
636 |
+
top_p,
|
637 |
+
prefix_s, suffix_s,
|
638 |
+
rand_mask_intensity,
|
639 |
+
periodic_p, periodic_w,
|
640 |
+
n_conditioning_codebooks,
|
641 |
+
dropout,
|
642 |
+
use_coarse2fine,
|
643 |
+
stretch_factor,
|
644 |
+
onset_mask_width,
|
645 |
+
typical_filtering,
|
646 |
+
typical_mass,
|
647 |
+
typical_min_tokens,
|
648 |
+
beat_mask_width,
|
649 |
+
beat_mask_downbeats,
|
650 |
+
seed,
|
651 |
+
model_choice,
|
652 |
+
n_mask_codebooks,
|
653 |
+
pitch_shift_amt,
|
654 |
+
sample_cutoff
|
655 |
+
}
|
656 |
+
|
657 |
+
# connect widgets
|
658 |
+
vamp_button.click(
|
659 |
+
fn=vamp,
|
660 |
+
inputs=_inputs,
|
661 |
+
outputs=[output_audio, audio_mask],
|
662 |
+
)
|
663 |
+
|
664 |
+
api_vamp_button = gr.Button("api vamp", visible=False)
|
665 |
+
api_vamp_button.click(
|
666 |
+
fn=api_vamp,
|
667 |
+
inputs=_inputs,
|
668 |
+
outputs=[output_audio],
|
669 |
+
api_name="vamp"
|
670 |
+
)
|
671 |
+
|
672 |
+
use_as_input_button.click(
|
673 |
+
fn=lambda x: x,
|
674 |
+
inputs=[output_audio],
|
675 |
+
outputs=[input_audio]
|
676 |
+
)
|
677 |
+
|
678 |
+
save_button.click(
|
679 |
+
fn=save_vamp,
|
680 |
+
inputs=_inputs | {notes_text, output_audio},
|
681 |
+
outputs=[thank_you, download_file]
|
682 |
+
)
|
683 |
+
|
684 |
+
# harp stuff
|
685 |
+
harp_inputs = [
|
686 |
+
input_audio,
|
687 |
+
beat_mask_width,
|
688 |
+
sampletemp,
|
689 |
+
]
|
690 |
+
|
691 |
+
build_endpoint(
|
692 |
+
inputs=harp_inputs,
|
693 |
+
output=output_audio,
|
694 |
+
process_fn=harp_vamp,
|
695 |
+
card=ModelCard(
|
696 |
+
name="vampnet",
|
697 |
+
description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. ",
|
698 |
+
author="Hugo Flores García",
|
699 |
+
tags=["music", "generative"]
|
700 |
+
),
|
701 |
+
visible=False
|
702 |
+
)
|
703 |
+
|
704 |
+
demo.queue().launch()
|
assets/example.wav
ADDED
Binary file (883 kB). View file
|
|
conf/c2f.yml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/vampnet.yml
|
3 |
+
|
4 |
+
VampNet.n_codebooks: 14
|
5 |
+
VampNet.n_conditioning_codebooks: 4
|
6 |
+
|
7 |
+
VampNet.embedding_dim: 1280
|
8 |
+
VampNet.n_layers: 16
|
9 |
+
VampNet.n_heads: 20
|
10 |
+
|
11 |
+
AudioDataset.duration: 3.0
|
12 |
+
|
13 |
+
|
14 |
+
AudioDataset.loudness_cutoff: -40.0
|
conf/interface.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Interface.coarse_ckpt: ./models/vampnet/coarse.pth
|
2 |
+
Interface.coarse2fine_ckpt: ./models/vampnet/c2f.pth
|
3 |
+
Interface.codec_ckpt: ./models/vampnet/codec.pth
|
4 |
+
Interface.coarse_chunk_size_s: 10
|
5 |
+
Interface.coarse2fine_chunk_size_s: 3
|
6 |
+
Interface.wavebeat_ckpt: ./models/wavebeat.pth
|
7 |
+
|
8 |
+
# AudioLoader.sources:
|
9 |
+
# - /media/CHONK/null
|
10 |
+
|
conf/lora/lora.yml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/vampnet.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioDataset.n_examples: 100000000
|
7 |
+
val/AudioDataset.n_examples: 500
|
8 |
+
|
9 |
+
|
10 |
+
NoamScheduler.warmup: 500
|
11 |
+
|
12 |
+
batch_size: 6
|
13 |
+
num_workers: 7
|
14 |
+
save_iters: [2000, 4000, 10000,20000, 40000, 100000]
|
15 |
+
sample_freq: 2000
|
16 |
+
val_freq: 1000
|
17 |
+
|
18 |
+
AdamW.lr: 0.0001
|
19 |
+
|
20 |
+
# lets us organize sound classes into folders and choose from those sound classes uniformly
|
21 |
+
AudioDataset.without_replacement: False
|
22 |
+
num_iters: 500000
|
conf/salad_bowl.yml
ADDED
File without changes
|
conf/vampnet.yml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
codec_ckpt: ./models/vampnet/codec.pth
|
3 |
+
save_path: ckpt
|
4 |
+
|
5 |
+
num_iters: 1000000000
|
6 |
+
save_iters: [10000, 50000, 100000, 300000, 500000]
|
7 |
+
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
8 |
+
sample_freq: 10000
|
9 |
+
val_freq: 1000
|
10 |
+
|
11 |
+
batch_size: 8
|
12 |
+
num_workers: 10
|
13 |
+
|
14 |
+
# Optimization
|
15 |
+
amp: false
|
16 |
+
|
17 |
+
CrossEntropyLoss.label_smoothing: 0.1
|
18 |
+
|
19 |
+
AdamW.lr: 0.001
|
20 |
+
|
21 |
+
NoamScheduler.factor: 2.0
|
22 |
+
NoamScheduler.warmup: 10000
|
23 |
+
|
24 |
+
VampNet.vocab_size: 1024
|
25 |
+
VampNet.n_codebooks: 4
|
26 |
+
VampNet.n_conditioning_codebooks: 0
|
27 |
+
VampNet.r_cond_dim: 0
|
28 |
+
VampNet.noise_mode: mask
|
29 |
+
VampNet.embedding_dim: 1280
|
30 |
+
VampNet.n_layers: 20
|
31 |
+
VampNet.n_heads: 20
|
32 |
+
VampNet.flash_attn: false
|
33 |
+
VampNet.dropout: 0.1
|
34 |
+
|
35 |
+
AudioLoader.relative_path: ""
|
36 |
+
AudioDataset.loudness_cutoff: -30.0
|
37 |
+
AudioDataset.without_replacement: true
|
38 |
+
AudioLoader.shuffle: true
|
39 |
+
|
40 |
+
AudioDataset.duration: 10.0
|
41 |
+
|
42 |
+
train/AudioDataset.n_examples: 10000000
|
43 |
+
train/AudioLoader.sources:
|
44 |
+
- /media/CHONK/hugo/spotdl/audio-train
|
45 |
+
|
46 |
+
val/AudioDataset.n_examples: 2000
|
47 |
+
val/AudioLoader.sources:
|
48 |
+
- /media/CHONK/hugo/spotdl/audio-val
|
49 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
argbind>=0.3.2
|
3 |
+
numpy==1.23
|
4 |
+
gradio
|
5 |
+
loralib
|
6 |
+
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
7 |
+
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
8 |
+
descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
|
9 |
+
-e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
|
10 |
+
torch_pitch_shift
|
scripts/exp/eval.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
from frechet_audio_distance import FrechetAudioDistance
|
6 |
+
import pandas
|
7 |
+
import argbind
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
import audiotools
|
12 |
+
from audiotools import AudioSignal
|
13 |
+
|
14 |
+
@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline",
    audio_ext: str = ".wav",
):
    """Evaluate every condition folder in ``exp_dir`` against the baseline folder.

    Computes a Frechet Audio Distance per condition (directory level) and a
    mel-spectrogram distance per file pair, then writes ``stats-<metric>.csv``
    summaries and a combined ``metrics-all.csv`` into ``exp_dir``.

    Args:
        exp_dir: experiment directory; one subdirectory per condition.
        baseline_key: name of the subdirectory holding reference audio.
        audio_ext: extension of the audio files to compare.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    # sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=True,
        audio_load_worker=4,
    )
    frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]

    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))

    metrics = []
    for condition in tqdm(conditions):
        cond_dir = exp_dir / condition
        cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))

        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # make sure we compare the same number of files.
        # BUG FIX: the original reassigned `baseline_files = baseline_files[:num_files]`
        # inside this loop, permanently truncating the baseline set for every
        # *subsequent* condition. Slice into per-condition locals instead.
        num_files = min(len(baseline_files), len(cond_files))
        cond_baseline_files = baseline_files[:num_files]
        cond_files = cond_files[:num_files]
        assert len(cond_baseline_files) == len(cond_files), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(cond_baseline_files)} vs {len(cond_files)}"

        def process(baseline_file, cond_file):
            # make sure the files match (same name)
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"

            # load the files
            baseline_sig = AudioSignal(str(baseline_file))
            cond_sig = AudioSignal(str(cond_file))

            cond_sig.resample(baseline_sig.sample_rate)
            cond_sig.truncate_samples(baseline_sig.length)

            # if our condition is inpainting, we need to trim the conditioning off
            if "inpaint" in condition:
                ctx_amt = float(condition.split("_")[-1])
                ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
                print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
                cond_sig.trim(ctx_samples, ctx_samples)
                baseline_sig.trim(ctx_samples, ctx_samples)

            return {
                # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
                # "stft": stft_loss(baseline_sig, cond_sig).item(),
                "mel": mel_loss(baseline_sig, cond_sig).item(),
                "frechet": frechet_score,
                # "visqol": vsq,
                "condition": condition,
                "file": baseline_file.stem,
            }

        print(f"processing {num_files} files in {baseline_dir} and {cond_dir}")
        metrics.extend(tqdm(map(process, cond_baseline_files, cond_files), total=num_files))

    # build the DataFrame once and reuse it for every per-metric summary
    # (the original rebuilt it inside the loop below)
    df = pandas.DataFrame(metrics)
    metric_keys = [k for k in df.columns if k not in ("condition", "file")]

    for mk in metric_keys:
        stat = df.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
args = argbind.parse_args()
|
108 |
+
|
109 |
+
with argbind.scope(args):
|
110 |
+
eval()
|
scripts/exp/experiment.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import random
|
3 |
+
from typing import List
|
4 |
+
import tempfile
|
5 |
+
import subprocess
|
6 |
+
|
7 |
+
import argbind
|
8 |
+
from tqdm import tqdm
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from vampnet.interface import Interface
|
12 |
+
from vampnet import mask as pmask
|
13 |
+
import audiotools as at
|
14 |
+
|
15 |
+
Interface: Interface = argbind.bind(Interface)
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def calculate_bitrate(
    interface, num_codebooks,
    downsample_factor
):
    """Return the effective token bitrate (bits/second) of a vamp condition.

    Each token is assumed to carry 10 bits (1024-way vocab); the rate scales
    with the number of codebooks kept and shrinks with temporal downsampling.
    """
    bits_per_token = 10
    frames_per_second = interface.codec.sample_rate / interface.codec.hop_size
    bits_per_frame = (bits_per_token * num_codebooks) / downsample_factor
    return frames_per_second * bits_per_frame
|
28 |
+
|
29 |
+
def baseline(sig, interface):
    """Reference condition: the input signal after standard preprocessing only."""
    return interface.preprocess(sig)
|
31 |
+
|
32 |
+
def reconstructed(sig, interface):
    """Codec round-trip condition: encode to tokens and decode straight back."""
    return interface.to_signal(
        interface.encode(sig)
    )
|
36 |
+
|
37 |
+
def coarse2fine(sig, interface):
    """Keep only the c2f conditioning (coarse) codebooks and regenerate the fine ones."""
    z = interface.encode(sig)
    # drop every codebook past the ones the c2f model conditions on
    z = z[:, :interface.c2f.n_conditioning_codebooks, :]

    z = interface.coarse_to_fine(z)
    return interface.to_signal(z)
|
43 |
+
|
44 |
+
class CoarseCond:
    """Condition: keep ``num_conditioning_codebooks`` codebooks plus a periodic
    subsample of timesteps (1 in ``downsample_factor``), and regenerate the rest
    with the coarse model followed by coarse-to-fine."""

    def __init__(self, num_conditioning_codebooks, downsample_factor):
        # number of codebooks left fully unmasked (kept as conditioning)
        self.num_conditioning_codebooks = num_conditioning_codebooks
        # keep only every `downsample_factor`-th timestep of the rest
        self.downsample_factor = downsample_factor

    def __call__(self, sig, interface):
        z = interface.encode(sig)
        mask = pmask.full_mask(z)
        # unmask (keep) the first N codebooks entirely
        mask = pmask.codebook_unmask(mask, self.num_conditioning_codebooks)
        # NOTE(review): periodic_mask is applied to `mask` (not `z`) here —
        # presumably it operates elementwise on the mask tensor; verify
        # against pmask.periodic_mask's signature.
        mask = pmask.periodic_mask(mask, self.downsample_factor)

        zv = interface.coarse_vamp(z, mask)
        zv = interface.coarse_to_fine(zv)
        return interface.to_signal(zv)
|
59 |
+
|
60 |
+
def opus(sig, interface, bitrate=128):
    """Baseline codec condition: round-trip the signal through Opus via ffmpeg.

    Args:
        sig: input AudioSignal.
        interface: vampnet Interface (used only for preprocessing).
        bitrate: Opus target bitrate passed to ffmpeg's ``-b:a``.
    """
    sig = interface.preprocess(sig)

    with tempfile.NamedTemporaryFile(suffix=".wav") as f:
        sig.write(f.name)

        opus_name = Path(f.name).with_suffix(".opus")
        # convert to opus
        cmd = [
            "ffmpeg", "-y", "-i", f.name,
            "-c:a", "libopus",
            "-b:a", f"{bitrate}",
            opus_name
        ]
        subprocess.run(cmd, check=True)

        # convert back to wav
        # NOTE(review): f.name already ends in ".wav", so the suffix of
        # "<name>.wav-opus" is ".wav-opus" and .with_suffix(".wav") collapses
        # output_name back to f.name — ffmpeg then overwrites the input temp
        # file. The decoded audio is still what gets loaded below, but this
        # looks unintended; also opus_name is never deleted. Confirm.
        output_name = Path(f"{f.name}-opus").with_suffix(".wav")
        cmd = [
            "ffmpeg", "-y", "-i", opus_name,
            output_name
        ]

        subprocess.run(cmd, check=True)

        sig = at.AudioSignal(
            output_name,
            sample_rate=sig.sample_rate
        )
    return sig
|
90 |
+
|
91 |
+
def mask_ratio_1_step(ratio=1.0):
    """Build a condition fn that masks `ratio` of the coarse tokens and
    resamples them in a single decoding step."""
    def wrapper(sig, interface):
        codes = interface.encode(sig)
        rand_mask = pmask.linear_random(codes, ratio)
        resampled = interface.coarse_vamp(
            codes,
            rand_mask,
            sampling_steps=1,
        )
        return interface.to_signal(resampled)

    return wrapper
|
103 |
+
|
104 |
+
def num_sampling_steps(num_steps=1):
    """Build a condition fn that applies a period-16 mask and vamps with a
    configurable number of sampling steps, then runs coarse-to-fine."""
    def wrapper(sig, interface: Interface):
        codes = interface.encode(sig)
        periodic = pmask.periodic_mask(codes, 16)
        coarse = interface.coarse_vamp(
            codes,
            periodic,
            sampling_steps=num_steps,
        )
        refined = interface.coarse_to_fine(coarse)
        return interface.to_signal(refined)

    return wrapper
|
117 |
+
|
118 |
+
def beat_mask(ctx_time):
    """Build a condition fn that keeps `ctx_time` seconds of audio centered on
    each detected beat and regenerates everything else."""
    def wrapper(sig, interface):
        half_ctx = ctx_time / 2
        mask = interface.make_beat_mask(
            sig,
            before_beat_s=half_ctx,
            after_beat_s=half_ctx,
            invert=True
        )

        codes = interface.encode(sig)

        coarse = interface.coarse_vamp(codes, mask)

        refined = interface.coarse_to_fine(coarse)
        return interface.to_signal(refined)

    return wrapper
|
136 |
+
|
137 |
+
def inpaint(ctx_time):
    """Build a condition fn that regenerates the middle of the clip, keeping
    `ctx_time` seconds of context at each edge."""
    def wrapper(sig, interface: Interface):
        codes = interface.encode(sig)
        n_ctx_tokens = interface.s2t(ctx_time)
        inpaint_mask = pmask.inpaint(codes, n_ctx_tokens, n_ctx_tokens)

        coarse = interface.coarse_vamp(codes, inpaint_mask)
        refined = interface.coarse_to_fine(coarse)

        return interface.to_signal(refined)

    return wrapper
|
147 |
+
|
148 |
+
def token_noise(noise_amt):
    """Build a condition fn that replaces a random fraction `noise_amt` of
    tokens with uniformly random codes, then decodes directly."""
    def wrapper(sig, interface: Interface):
        codes = interface.encode(sig)
        noise_mask = pmask.random(codes, noise_amt)
        noisy = torch.where(
            noise_mask,
            torch.randint_like(codes, 0, interface.coarse.vocab_size),
            codes
        )
        return interface.to_signal(noisy)

    return wrapper
|
159 |
+
|
160 |
+
# Registry of experiments: name -> {condition_name: fn(sig, interface) -> AudioSignal}.
EXP_REGISTRY = {}

EXP_REGISTRY["gen-compression"] = {
    "baseline": baseline,
    "reconstructed": reconstructed,
    "coarse2fine": coarse2fine,
    **{
        f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
        for (n, x) in (
            (1, 1), # 1 codebook, no downsampling
            (4, 4), # 4 codebooks, downsampled 4x
            (4, 16), # 4 codebooks, downsampled 16x
            (4, 32), # 4 codebooks, downsampled 32x
        )
    },
    **{
        f"token_noise_{x}": mask_ratio_1_step(ratio=x)
        for x in [0.25, 0.5, 0.75]
    },

}


EXP_REGISTRY["sampling-steps"] = {
    # "codec": reconstructed,
    **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
}


EXP_REGISTRY["musical-sampling"] = {
    **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
    **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
}
|
193 |
+
|
194 |
+
@argbind.bind(without_prefix=True)
def main(
    # NOTE(review): mutable default list; argbind typically overrides it, but a
    # tuple would be safer against accidental sharing across calls.
    sources=[
        "/media/CHONK/hugo/spotdl/val",
    ],
    output_dir: str = "./samples",
    max_excerpts: int = 2000,
    exp_type: str = "gen-compression",
    seed: int = 0,
    ext: List[str] = [".mp3"],  # annotation fixed: the default is a list of extensions
):
    """Render every condition of an EXP_REGISTRY experiment over random dataset
    excerpts, writing one wav per (condition, excerpt index) into output_dir.
    Already-rendered excerpts are skipped, so the script is resumable."""
    at.util.seed(seed)
    interface = Interface()

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    from audiotools.data.datasets import AudioLoader, AudioDataset

    loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
    dataset = AudioDataset(loader,
        sample_rate=interface.codec.sample_rate,
        duration=interface.coarse.chunk_size_s,
        n_examples=max_excerpts,
        without_replacement=True,
    )

    if exp_type in EXP_REGISTRY:
        SAMPLE_CONDS = EXP_REGISTRY[exp_type]
    else:
        raise ValueError(f"Unknown exp_type {exp_type}")


    indices = list(range(max_excerpts))
    random.shuffle(indices)
    for i in tqdm(indices):
        # if all our files are already there, skip
        done = []
        for name in SAMPLE_CONDS:
            o_dir = Path(output_dir) / name
            done.append((o_dir / f"{i}.wav").exists())
        if all(done):
            continue

        sig = dataset[i]["signal"]
        results = {
            name: cond(sig, interface).cpu()
            for name, cond in SAMPLE_CONDS.items()
        }

        # NOTE: `sig` is rebound here, shadowing the input excerpt above
        for name, sig in results.items():
            o_dir = Path(output_dir) / name
            o_dir.mkdir(exist_ok=True, parents=True)

            sig.write(o_dir / f"{i}.wav")
|
249 |
+
|
250 |
+
if __name__ == "__main__":
|
251 |
+
args = argbind.parse_args()
|
252 |
+
|
253 |
+
with argbind.scope(args):
|
254 |
+
main()
|
scripts/exp/fine_tune.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argbind
|
2 |
+
from pathlib import Path
|
3 |
+
import yaml
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
"""example output: (yaml)
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
@argbind.bind(without_prefix=True, positional=True)
def fine_tune(audio_files_or_folders: List[str], name: str):
    """Generate fine-tuning config files under ``conf/generated/<name>``.

    Writes three yml files: ``c2f.yml`` and ``coarse.yml`` (training configs
    pointing at the pretrained checkpoints) and ``interface.yml`` (pointing the
    Interface at the fine-tuned weights once training has run).

    Args:
        audio_files_or_folders: audio sources to fine-tune on.
        name: name of the generated config folder.
    """
    conf_dir = Path("conf")
    assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"

    conf_dir = conf_dir / "generated"
    conf_dir.mkdir(exist_ok=True)

    finetune_dir = conf_dir / name
    finetune_dir.mkdir(exist_ok=True)

    finetune_c2f_conf = {
        "$include": ["conf/lora/lora.yml"],
        "fine_tune": True,
        "train/AudioLoader.sources": audio_files_or_folders,
        "val/AudioLoader.sources": audio_files_or_folders,
        # c2f architecture overrides (must match the pretrained c2f checkpoint)
        "VampNet.n_codebooks": 14,
        "VampNet.n_conditioning_codebooks": 4,
        "VampNet.embedding_dim": 1280,
        "VampNet.n_layers": 16,
        "VampNet.n_heads": 20,
        "AudioDataset.duration": 3.0,
        "AudioDataset.loudness_cutoff": -40.0,
        "save_path": str(finetune_dir / "ckpt/c2f"),
        "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
    }

    finetune_coarse_conf = {
        "$include": ["conf/lora/lora.yml"],
        "fine_tune": True,
        "train/AudioLoader.sources": audio_files_or_folders,
        "val/AudioLoader.sources": audio_files_or_folders,
        "save_path": str(finetune_dir / "ckpt/coarse"),
        "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
    }

    interface_conf = {
        "Interface.coarse_ckpt": f"{finetune_dir}/ckpt/coarse/latest/vampnet/weights.pth",

        "Interface.coarse2fine_ckpt": f"{finetune_dir}/ckpt/c2f/latest/vampnet/weights.pth",
        "Interface.wavebeat_ckpt": "./models/wavebeat.pth",

        "Interface.codec_ckpt": "./models/vampnet/codec.pth",
        # fix: audio_files_or_folders is already a list — don't nest it in
        # another list (the train/val confs above use it unwrapped).
        "AudioLoader.sources": audio_files_or_folders,
    }

    # save the confs
    with open(finetune_dir / "c2f.yml", "w") as f:
        yaml.dump(finetune_c2f_conf, f)

    with open(finetune_dir / "coarse.yml", "w") as f:
        yaml.dump(finetune_coarse_conf, f)

    with open(finetune_dir / "interface.yml", "w") as f:
        yaml.dump(interface_conf, f)


    print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
args = argbind.parse_args()
|
75 |
+
|
76 |
+
with argbind.scope(args):
|
77 |
+
fine_tune()
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
scripts/exp/train.py
ADDED
@@ -0,0 +1,686 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import warnings
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
import argbind
|
9 |
+
import audiotools as at
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from audiotools import AudioSignal
|
13 |
+
from audiotools.data import transforms as tfm
|
14 |
+
from einops import rearrange
|
15 |
+
from rich import pretty
|
16 |
+
from rich.traceback import install
|
17 |
+
from torch.utils.tensorboard import SummaryWriter
|
18 |
+
|
19 |
+
import vampnet
|
20 |
+
from vampnet.modules.transformer import VampNet
|
21 |
+
from vampnet.util import codebook_unflatten, codebook_flatten
|
22 |
+
from vampnet import mask as pmask
|
23 |
+
# from dac.model.dac import DAC
|
24 |
+
from lac.model.lac import LAC as DAC
|
25 |
+
|
26 |
+
from audiotools.ml.decorators import (
|
27 |
+
timer, Tracker, when
|
28 |
+
)
|
29 |
+
|
30 |
+
import loralib as lora
|
31 |
+
|
32 |
+
import torch._dynamo
|
33 |
+
torch._dynamo.config.verbose=True
|
34 |
+
|
35 |
+
|
36 |
+
# Enable cudnn autotuner to speed up training
|
37 |
+
# (can be altered by the funcs.seed function)
|
38 |
+
torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1)))
|
39 |
+
# Uncomment to trade memory for speed.
|
40 |
+
|
41 |
+
# Install to make things look nice
|
42 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
43 |
+
pretty.install()
|
44 |
+
install()
|
45 |
+
|
46 |
+
# optim
|
47 |
+
Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
|
48 |
+
CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
|
49 |
+
AdamW = argbind.bind(torch.optim.AdamW)
|
50 |
+
NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
|
51 |
+
|
52 |
+
# transforms
|
53 |
+
filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
|
54 |
+
"BaseTransform",
|
55 |
+
"Compose",
|
56 |
+
"Choose",
|
57 |
+
]
|
58 |
+
|
59 |
+
# model
|
60 |
+
VampNet = argbind.bind(VampNet)
|
61 |
+
|
62 |
+
|
63 |
+
# data
|
64 |
+
AudioLoader = argbind.bind(at.datasets.AudioLoader)
|
65 |
+
AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
|
66 |
+
|
67 |
+
IGNORE_INDEX = -100
|
68 |
+
|
69 |
+
|
70 |
+
@argbind.bind("train", "val", without_prefix=True)
|
71 |
+
def build_transform():
|
72 |
+
transform = tfm.Compose(
|
73 |
+
tfm.VolumeNorm(("const", -24)),
|
74 |
+
# tfm.PitchShift(),
|
75 |
+
tfm.RescaleAudio(),
|
76 |
+
)
|
77 |
+
return transform
|
78 |
+
|
79 |
+
|
80 |
+
@torch.no_grad()
def apply_transform(transform_fn, batch):
    """Apply a batched audiotools transform to ``batch["signal"]`` using the
    per-item arguments the dataset sampled into ``batch["transform_args"]``."""
    sig: AudioSignal = batch["signal"]
    kwargs = batch["transform_args"]

    # clone so the raw batch signal is left untouched
    sig: AudioSignal = transform_fn(sig.clone(), **kwargs)
    return sig
|
87 |
+
|
88 |
+
|
89 |
+
def build_datasets(args, sample_rate: int):
    """Construct the train and val AudioDatasets, each configured under its
    own argbind scope ("train" / "val") so per-split options apply."""
    with argbind.scope(args, "train"):
        train_data = AudioDataset(
            AudioLoader(), sample_rate, transform=build_transform()
        )
    with argbind.scope(args, "val"):
        val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
    return train_data, val_data
|
97 |
+
|
98 |
+
|
99 |
+
def rand_float(shape, low, high, rng):
    """Draw `shape` quasi-random floats uniformly mapped into [low, high)."""
    unit_samples = rng.draw(shape)[:, 0]
    return low + unit_samples * (high - low)
|
101 |
+
|
102 |
+
|
103 |
+
def flip_coin(shape, p, rng):
    """Quasi-random Bernoulli draws: True with probability ~p, shape `shape`."""
    unit_samples = rng.draw(shape)[:, 0]
    return unit_samples < p
|
105 |
+
|
106 |
+
|
107 |
+
def num_params_hook(o, p):
    """Append a parameter count (in millions, 3 decimals) to an extra_repr string."""
    millions = p / 1e6
    return f"{o} {millions:<.3f}M params."
|
109 |
+
|
110 |
+
|
111 |
+
def add_num_params_repr_hook(model):
    """Patch ``extra_repr`` on every submodule so printing the model shows
    each module's parameter count alongside its usual repr."""
    import numpy as np
    from functools import partial

    for n, m in model.named_modules():
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])

        # bind the current repr string and count; extra_repr() then takes no args
        setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
120 |
+
|
121 |
+
|
122 |
+
def accuracy(
    preds: torch.Tensor,
    target: torch.Tensor,
    top_k: int = 1,
    ignore_index: Optional[int] = None,
) -> torch.Tensor:
    """Top-k classification accuracy over a batch of sequences.

    Args:
        preds: logits of shape (batch, n_class, seq).
        target: class indices of shape (batch, seq).
        top_k: a prediction counts as correct if the target is in its top-k.
        ignore_index: positions whose target equals this value are excluded.

    Returns:
        Scalar tensor with the fraction of correct positions.
    """
    # flatten to (batch*seq, n_class) / (batch*seq,) — same layout as
    # einops' "b p s -> (b s) p" and "b s -> (b s)"
    n_class = preds.shape[1]
    preds = preds.permute(0, 2, 1).reshape(-1, n_class)
    target = target.reshape(-1)

    if ignore_index is not None:
        # drop ignored positions from both tensors
        keep = target != ignore_index
        preds = preds[keep]
        target = target[keep]

    # indices of the k highest-scoring classes per position
    _, topk_idx = torch.topk(preds, k=top_k, dim=-1)

    # a position is correct when its target appears among the top-k indices
    # (topk indices are distinct, so the membership sum is 0 or 1)
    hits = torch.sum(torch.eq(topk_idx, target.unsqueeze(1)), dim=1)

    return torch.mean(hits.float())
|
150 |
+
|
151 |
+
def _metrics(z_hat, r, target, flat_mask, output):
    """Accumulate top-k accuracy metrics into ``output`` (mutated in place),
    split by mask-ratio bucket (r in [0, 0.5) vs [0.5, 1.0)) and by whether a
    token position was masked or unmasked."""
    for r_range in [(0, 0.5), (0.5, 1.0)]:
        # IGNORE_INDEX excludes the complementary positions from each flavor
        unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
        masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)

        assert target.shape[0] == r.shape[0]
        # grab the indices of the r values that are in the range
        r_idx = (r >= r_range[0]) & (r < r_range[1])

        # grab the target and z_hat values that are in the range
        r_unmasked_target = unmasked_target[r_idx]
        r_masked_target = masked_target[r_idx]
        r_z_hat = z_hat[r_idx]

        for topk in (1, 25):
            s, e = r_range
            tag = f"accuracy-{s}-{e}/top{topk}"

            output[f"{tag}/unmasked"] = accuracy(
                preds=r_z_hat,
                target=r_unmasked_target,
                ignore_index=IGNORE_INDEX,
                top_k=topk,
            )
            output[f"{tag}/masked"] = accuracy(
                preds=r_z_hat,
                target=r_masked_target,
                ignore_index=IGNORE_INDEX,
                top_k=topk,
            )
|
181 |
+
|
182 |
+
|
183 |
+
@dataclass
class State:
    """Bag of everything the train/val loops need."""
    model: VampNet   # the model being trained
    codec: DAC       # frozen codec used to tokenize audio

    optimizer: AdamW
    scheduler: NoamScheduler
    criterion: CrossEntropyLoss
    grad_clip_val: float  # max grad norm passed to clip_grad_norm_

    # quasi-random source for per-example mask ratios
    rng: torch.quasirandom.SobolEngine

    train_data: AudioDataset
    val_data: AudioDataset

    tracker: Tracker
|
199 |
+
|
200 |
+
|
201 |
+
@timer()
def train_loop(state: State, batch: dict, accel: Accelerator):
    """Run one optimization step: tokenize the batch with the frozen codec,
    mask tokens at a quasi-random ratio, predict the masked tokens, and update
    the model. Returns a sorted dict of scalar metrics for logging."""
    state.model.train()
    batch = at.util.prepare_batch(batch, accel.device)
    signal = apply_transform(state.train_data.transform, batch)

    output = {}
    vn = accel.unwrap(state.model)
    with accel.autocast():
        # the codec is frozen: encode under inference_mode for speed/memory
        with torch.inference_mode():
            state.codec.to(accel.device)
            z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
            z = z[:, : vn.n_codebooks, :]

        n_batch = z.shape[0]
        # per-example mask ratio r in [0, 1) from the Sobol engine
        r = state.rng.draw(n_batch)[:, 0].to(accel.device)

        mask = pmask.random(z, r)
        # never mask the conditioning codebooks
        mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
        z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

        z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)

        # NOTE(review): nested autocast — under amp the forward runs in
        # bfloat16; confirm this matches the Accelerator's expectations.
        dtype = torch.bfloat16 if accel.amp else None
        with accel.autocast(dtype=dtype):
            z_hat = state.model(z_mask_latent)

        # flatten predicted (non-conditioning) codebooks into one sequence
        target = codebook_flatten(
            z[:, vn.n_conditioning_codebooks :, :],
        )

        flat_mask = codebook_flatten(
            mask[:, vn.n_conditioning_codebooks :, :],
        )

        # replace target with ignore index for masked tokens
        # (loss is computed only on positions that were masked)
        t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
        output["loss"] = state.criterion(z_hat, t_masked)

        _metrics(
            r=r,
            z_hat=z_hat,
            target=target,
            flat_mask=flat_mask,
            output=output,
        )


    accel.backward(output["loss"])

    output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
    output["other/batch_size"] = z.shape[0]


    # unscale before clipping so the clip threshold applies to true gradients
    accel.scaler.unscale_(state.optimizer)
    output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
        state.model.parameters(), state.grad_clip_val
    )

    accel.step(state.optimizer)
    state.optimizer.zero_grad()

    state.scheduler.step()
    accel.update()


    return {k: v for k, v in sorted(output.items())}
|
268 |
+
|
269 |
+
|
270 |
+
@timer()
@torch.no_grad()
def val_loop(state: State, batch: dict, accel: Accelerator):
    """Run one validation step: mask a batch of codec tokens, predict the
    masked tokens, and return the loss/metrics dict.

    Mirrors the masking + forward pass of the training loop, but in eval
    mode, without gradients, and without AMP autocast.

    Args:
        state: training State bundle (model, codec, rng, criterion, data).
        batch: raw batch from the validation dataloader.
        accel: Accelerator used for device placement and model unwrapping.

    Returns:
        dict with "loss" plus the metrics added by ``_metrics``.
    """
    state.model.eval()
    state.codec.eval()
    batch = at.util.prepare_batch(batch, accel.device)
    signal = apply_transform(state.val_data.transform, batch)

    vn = accel.unwrap(state.model)
    # encode audio to discrete codec tokens; keep only the codebooks the
    # model is trained on
    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
    z = z[:, : vn.n_codebooks, :]

    n_batch = z.shape[0]
    # one masking ratio per batch item, drawn from the Sobol rng
    r = state.rng.draw(n_batch)[:, 0].to(accel.device)

    # random token mask at ratio r; conditioning codebooks are never masked
    mask = pmask.random(z, r)
    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)

    z_hat = state.model(z_mask_latent)

    # flatten the non-conditioning codebooks into a single token stream
    target = codebook_flatten(
        z[:, vn.n_conditioning_codebooks :, :],
    )

    flat_mask = codebook_flatten(
        mask[:, vn.n_conditioning_codebooks :, :]
    )

    output = {}
    # replace target with ignore index for masked tokens
    # (loss is computed only on positions that were actually masked)
    t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
    output["loss"] = state.criterion(z_hat, t_masked)

    _metrics(
        r=r,
        z_hat=z_hat,
        target=target,
        flat_mask=flat_mask,
        output=output,
    )

    return output
|
315 |
+
|
316 |
+
|
317 |
+
def validate(state, val_dataloader, accel):
    """Run ``val_loop`` over every validation batch.

    Returns the output dict of the final batch (the tracker wrapper
    applied in ``train`` aggregates per-batch logs itself).
    """
    for val_batch in val_dataloader:
        metrics = val_loop(state, val_batch, accel)
    # ZeroRedundancyOptimizer shards optimizer state across ranks; gather
    # it here so a subsequent checkpoint sees the full state dict.
    consolidate = getattr(state.optimizer, "consolidate_state_dict", None)
    if consolidate is not None:
        consolidate()
    return metrics
|
324 |
+
|
325 |
+
|
326 |
+
def checkpoint(state, save_iters, save_path, fine_tune):
    """Save model/optimizer/scheduler/tracker state under tag folders.

    Tags always include "latest"; "<N>k" is added at milestone iterations
    from ``save_iters`` and "best" when validation loss improved. When
    fine-tuning, the LoRA weights are additionally saved per tag.

    NOTE(review): ``accel`` is read from module scope (set in __main__),
    not passed as a parameter — this function only works inside this
    script's entry point.
    """
    # only rank 0 writes checkpoints
    if accel.local_rank != 0:
        state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
        return

    metadata = {"logs": dict(state.tracker.history)}

    tags = ["latest"]
    state.tracker.print(f"Saving to {str(Path('.').absolute())}")

    if state.tracker.step in save_iters:
        tags.append(f"{state.tracker.step // 1000}k")

    if state.tracker.is_best("val", "loss"):
        state.tracker.print(f"Best model so far")
        tags.append("best")

    if fine_tune:
        for tag in tags:
            # save the lora model
            (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
            torch.save(
                lora.lora_state_dict(accel.unwrap(state.model)),
                f"{save_path}/{tag}/lora.pth"
            )

    for tag in tags:
        # extra artifacts saved alongside the model weights in each tag folder
        model_extra = {
            "optimizer.pth": state.optimizer.state_dict(),
            "scheduler.pth": state.scheduler.state_dict(),
            "tracker.pth": state.tracker.state_dict(),
            "metadata.pth": metadata,
        }

        accel.unwrap(state.model).metadata = metadata
        accel.unwrap(state.model).save_to_folder(
            f"{save_path}/{tag}", model_extra, package=False
        )
|
364 |
+
|
365 |
+
|
366 |
+
def save_sampled(state, z, writer):
    """Generate audio from each item's tokens and log it to TensorBoard.

    One ``generate`` call per batch item, seeded with that item's tokens;
    results are written under the "sampled/<i>" tag at the current step.

    NOTE(review): ``accel`` is read from module scope (set in __main__).
    """
    num_samples = z.shape[0]

    for i in range(num_samples):
        sampled = accel.unwrap(state.model).generate(
            codec=state.codec,
            time_steps=z.shape[-1],
            start_tokens=z[i : i + 1],
        )
        sampled.cpu().write_audio_to_tb(
            f"sampled/{i}",
            writer,
            step=state.tracker.step,
            plot_fn=None,
        )
|
381 |
+
|
382 |
+
|
383 |
+
def save_imputation(state, z, val_idx, writer):
    """Run an inpainting demo and log prompt / result / reference audio.

    Masks the middle 50% of each token sequence (keeping a 25% prefix and
    25% suffix), generates the masked region, and writes three signals per
    item to TensorBoard: the masked prompt, the inpainted result, and the
    straight reconstruction of the original tokens.

    NOTE(review): ``accel`` is read from module scope (set in __main__).
    """
    # keep the first and last quarter of the sequence as context
    n_prefix = int(z.shape[-1] * 0.25)
    n_suffix = int(z.shape[-1] * 0.25)

    vn = accel.unwrap(state.model)

    mask = pmask.inpaint(z, n_prefix, n_suffix)
    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

    imputed_noisy = vn.to_signal(z_mask, state.codec)
    imputed_true = vn.to_signal(z, state.codec)

    # generate each item separately, then batch the resulting signals
    imputed = []
    for i in range(len(z)):
        imputed.append(
            vn.generate(
                codec=state.codec,
                time_steps=z.shape[-1],
                start_tokens=z[i][None, ...],
                mask=mask[i][None, ...],
            )
        )
    imputed = AudioSignal.batch(imputed)

    for i in range(len(val_idx)):
        imputed_noisy[i].cpu().write_audio_to_tb(
            f"inpainted_prompt/{i}",
            writer,
            step=state.tracker.step,
            plot_fn=None,
        )
        imputed[i].cpu().write_audio_to_tb(
            f"inpainted_middle/{i}",
            writer,
            step=state.tracker.step,
            plot_fn=None,
        )
        imputed_true[i].cpu().write_audio_to_tb(
            f"reconstructed/{i}",
            writer,
            step=state.tracker.step,
            plot_fn=None,
        )
|
427 |
+
|
428 |
+
|
429 |
+
@torch.no_grad()
def save_samples(state: State, val_idx: int, writer: SummaryWriter):
    """Log qualitative audio samples for a fixed set of validation items.

    For each index in ``val_idx``: encode the audio, mask it at a ratio
    swept linearly from 0.1 to 0.95 across the batch, do a single-step
    prediction, and log original/masked/generated/reconstructed audio.
    Then run the iterative sampling and inpainting demos.

    NOTE(review): ``accel`` is read from module scope (set in __main__).
    """
    state.model.eval()
    state.codec.eval()
    vn = accel.unwrap(state.model)

    batch = [state.val_data[i] for i in val_idx]
    batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)

    signal = apply_transform(state.val_data.transform, batch)

    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
    z = z[:, : vn.n_codebooks, :]

    # one masking ratio per item, evenly spread over [0.1, 0.95]
    r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)

    mask = pmask.random(z, r)
    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)

    z_hat = state.model(z_mask_latent)

    # greedy single-step decode: argmax over the class dimension, then
    # restore codebook structure and re-attach the conditioning codebooks
    z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
    z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
    z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)

    generated = vn.to_signal(z_pred, state.codec)
    reconstructed = vn.to_signal(z, state.codec)
    masked = vn.to_signal(z_mask.squeeze(1), state.codec)

    for i in range(generated.batch_size):
        audio_dict = {
            "original": signal[i],
            "masked": masked[i],
            "generated": generated[i],
            "reconstructed": reconstructed[i],
        }
        for k, v in audio_dict.items():
            v.cpu().write_audio_to_tb(
                f"onestep/_{i}.r={r[i]:0.2f}/{k}",
                writer,
                step=state.tracker.step,
                plot_fn=None,
            )

    save_sampled(state=state, z=z, writer=writer)
    save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
|
479 |
+
|
480 |
+
|
481 |
+
|
482 |
+
@argbind.bind(without_prefix=True)
def load(
    args,
    accel: at.ml.Accelerator,
    tracker: Tracker,
    save_path: str,
    resume: bool = False,
    tag: str = "latest",
    fine_tune_checkpoint: Optional[str] = None,
    grad_clip_val: float = 5.0,
) -> State:
    """Build the full training State: codec, model, optimizer, scheduler,
    criterion, rng, and datasets.

    Model resolution order: a fine-tune checkpoint (when args["fine_tune"]),
    then a resumed checkpoint folder (when ``resume``), else a fresh VampNet.

    Args:
        args: argbind args dict (reads "codec_ckpt", "fine_tune", "seed").
        accel: Accelerator for model prep / DDP wrapping.
        tracker: experiment tracker; its state is restored on resume.
        save_path: experiment directory (checkpoints, model summary).
        resume: load model/optimizer/scheduler/tracker from save_path/tag.
        tag: checkpoint tag folder to resume from.
        fine_tune_checkpoint: weights to start from when fine-tuning.
        grad_clip_val: max grad norm used by the training loop.

    Raises:
        ValueError: if ``resume`` is set but no checkpoint folder exists.
    """
    # the (frozen) codec that tokenizes audio; runs in eval mode only
    codec = DAC.load(args["codec_ckpt"], map_location="cpu")
    codec.eval()

    model, v_extra = None, {}

    if args["fine_tune"]:
        assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
        model = torch.compile(
            VampNet.load(location=Path(fine_tune_checkpoint),
                         map_location="cpu",
            )
        )

    if resume:
        kwargs = {
            "folder": f"{save_path}/{tag}",
            "map_location": "cpu",
            "package": False,
        }
        tracker.print(f"Loading checkpoint from {kwargs['folder']}")
        if (Path(kwargs["folder"]) / "vampnet").exists():
            # v_extra carries optimizer/scheduler/tracker state dicts
            model, v_extra = VampNet.load_from_folder(**kwargs)
        else:
            raise ValueError(
                f"Could not find a VampNet checkpoint in {kwargs['folder']}"
            )

    # fresh model when neither fine-tuning nor resuming
    model = torch.compile(VampNet()) if model is None else model
    model = accel.prepare_model(model)

    # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
    # model vocabulary must match the codec's codebook size
    assert (
        accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
    )

    if accel.world_size > 1:
        # shard optimizer state across ranks in multi-GPU runs
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(model.parameters(), AdamW)
        print(f"OPTIMIZER LR is {optimizer.param_groups[0]['lr']}")
    else:
        optimizer = AdamW(model.parameters())

    scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
    # prime the scheduler so the optimizer starts at a nonzero LR
    scheduler.step()

    if "optimizer.pth" in v_extra:
        optimizer.load_state_dict(v_extra["optimizer.pth"])
        scheduler.load_state_dict(v_extra["scheduler.pth"])
    if "tracker.pth" in v_extra:
        tracker.load_state_dict(v_extra["tracker.pth"])

    criterion = CrossEntropyLoss()

    sample_rate = codec.sample_rate

    # a better rng for sampling from our schedule
    rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])

    # log a model summary w/ num params
    if accel.local_rank == 0:
        add_num_params_repr_hook(accel.unwrap(model))
        with open(f"{save_path}/model.txt", "w") as f:
            f.write(repr(accel.unwrap(model)))

    # load the datasets
    train_data, val_data = build_datasets(args, sample_rate)

    return State(
        tracker=tracker,
        model=model,
        codec=codec,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        rng=rng,
        train_data=train_data,
        val_data=val_data,
        grad_clip_val=grad_clip_val,
    )
|
576 |
+
|
577 |
+
|
578 |
+
@argbind.bind(without_prefix=True)
def train(
    args,
    accel: at.ml.Accelerator,
    seed: int = 0,
    codec_ckpt: str = None,
    save_path: str = "ckpt",
    num_iters: int = int(1000e6),
    save_iters: list = [10000, 50000, 100000, 300000, 500000,],
    sample_freq: int = 10000,
    val_freq: int = 1000,
    batch_size: int = 12,
    val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    num_workers: int = 10,
    fine_tune: bool = False,
):
    """Main training entry point.

    Sets up logging/tracking, builds the State and dataloaders, wraps the
    loop functions with tracker decorators (rebinding the module-level
    names), then iterates: train step, periodic sample logging, periodic
    validation + checkpointing.

    Args:
        args: argbind args dict (forwarded to ``load``).
        accel: Accelerator managing devices/distributed setup.
        seed: base seed; offset by local rank so ranks differ.
        codec_ckpt: required path to the codec checkpoint.
        save_path: experiment output directory.
        num_iters: total training iterations.
        save_iters: milestone steps that get their own checkpoint tag.
        sample_freq: steps between audio-sample logging.
        val_freq: steps between validation/checkpoint passes.
        batch_size: per-dataloader batch size.
        val_idx: validation item indices used for qualitative samples.
        num_workers: dataloader worker count.
        fine_tune: if True, train only the LoRA parameters.
    """
    assert codec_ckpt is not None, "codec_ckpt is required"

    # per-rank seed so data augmentation differs across ranks
    seed = seed + accel.local_rank
    at.util.seed(seed)
    writer = None

    # only rank 0 writes TensorBoard logs and the args snapshot
    if accel.local_rank == 0:
        writer = SummaryWriter(log_dir=f"{save_path}/logs/")
        argbind.dump_args(args, f"{save_path}/args.yml")

    tracker = Tracker(
        writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
    )

    # load the codec model
    state: State = load(
        args=args,
        accel=accel,
        tracker=tracker,
        save_path=save_path)
    print("initialized state.")

    # start_idx lets a resumed run skip data it has already seen
    train_dataloader = accel.prepare_dataloader(
        state.train_data,
        start_idx=state.tracker.step * batch_size,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=state.train_data.collate,
    )
    val_dataloader = accel.prepare_dataloader(
        state.val_data,
        start_idx=0,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=state.val_data.collate,
        persistent_workers=num_workers > 0,
    )
    print("initialized dataloader.")

    if fine_tune:
        lora.mark_only_lora_as_trainable(state.model)
        print("marked only lora as trainable.")

    # Wrap the functions so that they neatly track in TensorBoard + progress bars
    # and only run when specific conditions are met.
    global train_loop, val_loop, validate, save_samples, checkpoint

    train_loop = tracker.log("train", "value", history=False)(
        tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
    )
    val_loop = tracker.track("val", len(val_dataloader))(val_loop)
    validate = tracker.log("val", "mean")(validate)

    # sampling and checkpointing only run on rank 0
    save_samples = when(lambda: accel.local_rank == 0)(save_samples)
    checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)

    print("starting training loop.")
    with tracker.live:
        for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
            train_loop(state, batch, accel)

            last_iter = (
                tracker.step == num_iters - 1 if num_iters is not None else False
            )

            if tracker.step % sample_freq == 0 or last_iter:
                save_samples(state, val_idx, writer)

            if tracker.step % val_freq == 0 or last_iter:
                validate(state, val_dataloader, accel)
                checkpoint(
                    state=state,
                    save_iters=save_iters,
                    save_path=save_path,
                    fine_tune=fine_tune)

                # Reset validation progress bar, print summary since last validation.
                tracker.done("val", f"Iteration {tracker.step}")

            if last_iter:
                break
|
677 |
+
|
678 |
+
|
679 |
+
if __name__ == "__main__":
    # parse argbind CLI/config args; debug flag is set only on rank 0
    args = argbind.parse_args()
    args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
    with argbind.scope(args):
        # Accelerator is a module-level name here — the loop helpers above
        # (checkpoint, save_sampled, ...) read `accel` from this scope.
        with Accelerator() as accel:
            if accel.local_rank != 0:
                # silence duplicate tracebacks on non-zero ranks
                sys.tracebacklimit = 0
            train(args, accel)
|
scripts/utils/README.md
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scripts
|
2 |
+
|
3 |
+
## process_zip.py
|
4 |
+
|
5 |
+
Some requirements that may not be installed in the docker image:
|
6 |
+
* argbind
|
7 |
+
* wav2wav (pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@<branchname>`)
|
8 |
+
|
9 |
+
### zip folder structure
|
10 |
+
|
11 |
+
The zip folder should have the following internal structure:
|
12 |
+
|
13 |
+
```
|
14 |
+
base_folder/
|
15 |
+
test_case_1/
|
16 |
+
before.wav
|
17 |
+
test_case_2/
|
18 |
+
before.wav
|
19 |
+
...
|
20 |
+
test_case_n/
|
21 |
+
before.wav
|
22 |
+
```
|
23 |
+
|
24 |
+
Note: There can be issues with the output zip if the input zip folder structure is too deep or too shallow. If you want or need to use a zip file with a different folder structure, adjust this:
|
25 |
+
https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28
|
26 |
+
|
27 |
+
### Execution
|
28 |
+
`python process_zip.py <path/to/zip> -tag <string>`
|
scripts/utils/gtzan_embeddings.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
TODO: train a linear probe
|
3 |
+
usage:
|
4 |
+
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
|
5 |
+
"""
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import audiotools as at
|
10 |
+
from audiotools import AudioSignal
|
11 |
+
import argbind
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
import zipfile
|
15 |
+
import json
|
16 |
+
|
17 |
+
from vampnet.interface import Interface
|
18 |
+
import tqdm
|
19 |
+
|
20 |
+
# bind the Interface to argbind
|
21 |
+
Interface = argbind.bind(Interface)
|
22 |
+
|
23 |
+
DEBUG = False
|
24 |
+
|
25 |
+
def smart_plotly_export(fig, save_path):
    """Export a plotly figure in the format implied by ``save_path``'s extension.

    Supported extensions:
        html  -> writes an interactive HTML file to ``save_path``.
        bytes -> returns the figure as PNG bytes (nothing written to disk).
        numpy -> returns the figure rendered as a numpy RGB(A) array.
        jpeg/png/webp -> writes a static image to ``save_path``.

    Args:
        fig: plotly figure to export.
        save_path: destination path; its extension selects the format.

    Returns:
        PNG bytes for 'bytes', a numpy array for 'numpy', otherwise None.

    Raises:
        ValueError: if the extension is not one of the supported formats.
    """
    img_format = save_path.split('.')[-1]
    if img_format == 'html':
        fig.write_html(save_path)
    elif img_format == 'bytes':
        return fig.to_image(format='png')
    #TODO: come back and make this prettier
    elif img_format == 'numpy':
        import io
        from PIL import Image

        def plotly_fig2array(fig):
            # convert Plotly fig to an array by round-tripping through PNG
            fig_bytes = fig.to_image(format="png", width=1200, height=700)
            buf = io.BytesIO(fig_bytes)
            img = Image.open(buf)
            return np.asarray(img)

        return plotly_fig2array(fig)
    elif img_format in ('jpeg', 'png', 'webp'):
        # bug fix: the original `img_format == 'jpeg' or 'png' or 'webp'`
        # was always truthy ('png' is a non-empty string), so unknown
        # extensions silently fell into write_image instead of raising.
        fig.write_image(save_path)
    else:
        raise ValueError("invalid image format")
|
48 |
+
|
49 |
+
def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
    """
    dimensionality reduction for visualization!
    saves an html plotly figure to save_path
    parameters:
        emb (np.ndarray): the samples to be reduced with shape (samples, features)
        labels (list): list of labels for embedding
        save_path (str): path where u wanna save ur figure
        n_components (int): 2 or 3 output dimensions
        method (str): umap, tsne, or pca
        title (str): title for ur figure
    returns:
        proj (np.ndarray): projection vector with shape (samples, dimensions)
    raises:
        ValueError: for an unknown method or n_components not in (2, 3)
    """
    import pandas as pd
    import plotly.express as px
    if method == 'umap':
        from umap import UMAP
        # bug fix: the original called `umap.UMAP(...)` after importing only
        # the UMAP class, which raised NameError ('umap' was never bound).
        reducer = UMAP(n_components=n_components)
    elif method == 'tsne':
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=n_components)
    elif method == 'pca':
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=n_components)
    else:
        raise ValueError

    proj = reducer.fit_transform(emb)

    if n_components == 2:
        df = pd.DataFrame(dict(
            x=proj[:, 0],
            y=proj[:, 1],
            instrument=labels
        ))
        fig = px.scatter(df, x='x', y='y', color='instrument',
                         title=title+f"_{method}")

    elif n_components == 3:
        df = pd.DataFrame(dict(
            x=proj[:, 0],
            y=proj[:, 1],
            z=proj[:, 2],
            instrument=labels
        ))
        fig = px.scatter_3d(df, x='x', y='y', z='z',
                            color='instrument',
                            title=title)
    else:
        raise ValueError("cant plot more than 3 components")

    fig.update_traces(marker=dict(size=6,
                                  line=dict(width=1,
                                            color='DarkSlateGrey')),
                      selector=dict(mode='markers'))

    return smart_plotly_export(fig, save_path)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
# per JukeMIR, we want the embeddings from the middle layer?
|
110 |
+
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
    """Embed an audio signal with the coarse VampNet model.

    Encodes the signal to codec tokens, runs a forward pass with
    ``return_activations=True``, and mean-pools the per-layer activations
    over time.

    Args:
        sig: mono/batch-of-1 audio signal to embed.
        interface: vampnet Interface providing the codec + coarse model.
        layer: layer index checked against the model depth (note: the
            full [num_layers, n_dims] stack is returned, not just `layer`).

    Returns:
        Tensor of shape [num_layers, n_dims] — time-pooled activations
        for every layer.
    """
    with torch.inference_mode():
        # preprocess the signal
        sig = interface.preprocess(sig)

        # get the coarse vampnet model
        vampnet = interface.coarse

        # get the tokens
        z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
        z_latents = vampnet.embedding.from_codes(z, interface.codec)

        # do a forward pass through the model, get the embeddings
        _z, embeddings = vampnet(z_latents, return_activations=True)
        # print(f"got embeddings with shape {embeddings.shape}")
        # [layer, batch, time, n_dims]
        # [20, 1, 600ish, 768]

        # squeeze batch dim (1 bc layer should be dim 0)
        assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
        embeddings = embeddings.squeeze(1)

        num_layers = embeddings.shape[0]
        assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"

        # do meanpooling over the time dimension
        embeddings = embeddings.mean(dim=-2)
        # [20, 768]

        # return the embeddings
        return embeddings
|
142 |
+
|
143 |
+
from dataclasses import dataclass, fields
|
144 |
+
@dataclass
class Embedding:
    """A single GTZAN embedding: the genre label, the source filename,
    and the pooled embedding array. Serializes to a zip archive holding
    'embedding.npy' (the array) and 'data.json' (everything else).
    """
    genre: str
    filename: str
    embedding: np.ndarray

    def save(self, path):
        """Save the Embedding object to a given path as a zip file."""
        # everything except the array is JSON-serializable metadata
        meta = {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name != 'embedding'
        }
        with zipfile.ZipFile(path, 'w') as archive:
            with archive.open('embedding.npy', 'w') as npy_file:
                np.save(npy_file, self.embedding)
            with archive.open('data.json', 'w') as json_file:
                json_file.write(json.dumps(meta).encode('utf-8'))

    @classmethod
    def load(cls, path):
        """Load the Embedding object from a given zip path."""
        with zipfile.ZipFile(path, 'r') as archive:
            with archive.open('embedding.npy') as npy_file:
                arr = np.load(npy_file)
            with archive.open('data.json') as json_file:
                meta = json.loads(json_file.read().decode('utf-8'))
        return cls(embedding=arr, **meta)
|
177 |
+
|
178 |
+
|
179 |
+
@argbind.bind(without_prefix=True)
def main(
    path_to_gtzan: str = None,
    cache_dir: str = "./.gtzan_emb_cache",
    output_dir: str = "./gtzan_vampnet_embeddings",
    # NOTE(review): mutable default list — harmless here since it is never
    # mutated, but a tuple would be safer.
    layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
):
    """Embed the GTZAN dataset with VampNet and visualize per-layer t-SNE plots.

    Walks one subfolder per genre under ``path_to_gtzan``, computes (or
    loads cached) embeddings for each audio file, then writes a 2-D t-SNE
    HTML figure per requested layer into ``output_dir``.

    Args:
        path_to_gtzan: root of GTZAN's genres_original folder (required).
        cache_dir: where per-file embeddings are cached as .emb zips.
        output_dir: where the plotly HTML figures are written.
        layers: model layers to visualize.
    """
    path_to_gtzan = Path(path_to_gtzan)
    assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"

    cache_dir = Path(cache_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # load our interface
    # argbind will automatically load the default config,
    interface = Interface()

    # gtzan should have a folder for each genre, so let's get the list of genres
    genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
    print(f"Found {len(genres)} genres")
    print(f"genres: {genres}")

    # collect audio files, genres, and embeddings
    data = []
    for genre in genres:
        audio_files = list(at.util.find_audio(path_to_gtzan / genre))
        print(f"Found {len(audio_files)} audio files for genre {genre}")

        for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
            # check if we have a cached embedding for this file
            cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
            if cached_path.exists():
                # if so, load it
                if DEBUG:
                    print(f"loading cached embedding for {cached_path.stem}")
                embedding = Embedding.load(cached_path)
            else:
                try:
                    sig = AudioSignal(audio_file)
                except Exception as e:
                    # unreadable files are skipped rather than aborting the run
                    print(f"failed to load {audio_file.name} with error {e}")
                    print(f"skipping {audio_file.name}")
                    continue

                # gets the embedding
                emb = vampnet_embed(sig, interface).cpu().numpy()

                # create an embedding we can save/load
                embedding = Embedding(
                    genre=genre,
                    filename=audio_file.name,
                    embedding=emb
                )

                # cache the embeddings
                cached_path.parent.mkdir(exist_ok=True, parents=True)
                embedding.save(cached_path)
            data.append(embedding)

    # now, let's do a dim reduction on the embeddings
    # and visualize them.

    # collect a list of embeddings and labels
    embeddings = [d.embedding for d in data]
    labels = [d.genre for d in data]

    # convert the embeddings to a numpy array
    embeddings = np.stack(embeddings)

    # do dimensionality reduction for each layer we're given
    for layer in tqdm.tqdm(layers, desc="dim reduction"):
        dim_reduce(
            embeddings[:, layer, :], labels,
            save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
            n_components=2, method='tsne',
            title=f'vampnet-gtzan-layer={layer}'
        )


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
|
scripts/utils/plots.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import seaborn as sns
|
3 |
+
from pandas.api.types import CategoricalDtype
|
4 |
+
|
5 |
+
def plot_metrics(metrics, condition_to_latex, title, color_palette):
    """Plot mel-loss boxplots and FAD bar charts per experimental condition.

    NOTE: mutates ``metrics`` in place by adding a 'condition_latex' column.

    Args:
        metrics: DataFrame with 'condition', 'mel', and 'frechet' columns.
        condition_to_latex: ordered mapping of condition name -> LaTeX label;
            its value order defines the x-axis category order.
        title: figure suptitle.
        color_palette: mapping of latex label -> color for both subplots.
    """
    # Add a new column to your dataframe with the latex representation
    metrics['condition_latex'] = metrics['condition'].map(condition_to_latex)

    # Order condition_latex as per the condition_to_latex dictionary
    cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
    metrics['condition_latex'] = metrics['condition_latex'].astype(cat_type)

    # Compute mean and std for each condition for each metric
    grouped = metrics.groupby('condition_latex')[['mel', 'frechet']].agg(['mean', 'std'])

    fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))

    # Set the main title for the figure
    fig.suptitle(title, fontsize=16)

    # Get color for each bar in the plot
    bar_colors = [color_palette[condition] for condition in grouped.index]

    # Plot mel (distribution per condition; arrows in labels mean lower-is-better)
    sns.boxplot(x='condition_latex', y='mel', data=metrics, ax=axs[0], palette=color_palette, showfliers=False)
    axs[0].set_ylabel('Mel Spectrogram Loss \u2190')
    axs[0].set_xlabel('')  # Remove x-axis label
    axs[0].set_xticklabels(grouped.index, rotation=0, ha='center')

    # Plot frechet (mean with std error bars)
    axs[1].bar(grouped.index, grouped['frechet']['mean'], yerr=grouped['frechet']['std'], color=bar_colors)
    axs[1].set_ylabel('FAD \u2190')
    axs[1].set_xlabel('')  # Remove x-axis label
    axs[1].set_xticklabels(grouped.index, rotation=0, ha='center')

    # Adjust the space between plots
    plt.subplots_adjust(hspace=0.1)

    # Remove any unnecessary space around the plot
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    # Reduce the space between suptitle and the plot
    plt.subplots_adjust(top=0.92)
|
scripts/utils/remove_quiet_files.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# removes audio files quieter than a loudness threshold (default: -30 dB)
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
import shutil
|
5 |
+
import audiotools as at
|
6 |
+
import argbind
|
7 |
+
|
8 |
+
@argbind.bind(without_prefix=True)
def remove_quiet_files(
    src_dir: Path = None,
    dest_dir: Path = None,
    min_loudness: float = -30,
):
    """Copy ``src_dir`` into ``dest_dir``, then delete every audio file in
    the copy whose measured loudness is below ``min_loudness`` (dB).

    The source directory is left untouched; only the copy is pruned.

    Args:
        src_dir: directory tree to copy from.
        dest_dir: directory tree to copy to and prune.
        min_loudness: loudness threshold in dB; quieter files are removed.
    """
    # copy src to dest
    dest_dir.mkdir(parents=True, exist_ok=True)
    shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)

    audio_files = at.util.find_audio(dest_dir)
    for audio_file in audio_files:
        sig = at.AudioSignal(audio_file)
        if sig.loudness() < min_loudness:
            audio_file.unlink()
            print(f"removed {audio_file}")

if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        remove_quiet_files()
|
scripts/utils/split.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import random
|
3 |
+
import shutil
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
|
7 |
+
import argbind
|
8 |
+
from tqdm import tqdm
|
9 |
+
from tqdm.contrib.concurrent import thread_map
|
10 |
+
|
11 |
+
from audiotools.core import util
|
12 |
+
|
13 |
+
|
14 |
+
@argbind.bind(without_prefix=True)
def train_test_split(
    audio_folder: str = ".",
    test_size: float = 0.2,
    seed: int = 42,
    pattern: str = "**/*.mp3",
):
    """Split audio files into train/test sets via symlinks.

    Shuffles files matching ``pattern`` under ``audio_folder`` and, after
    interactive confirmation, symlinks them into sibling folders
    "<name>-train" and "<name>-test". The file lists are also saved as
    train.json / test.json inside ``audio_folder``.

    Args:
        audio_folder: root folder to scan for audio.
        test_size: fraction of files assigned to the test split.
        seed: shuffle seed, for a reproducible split.
        pattern: glob pattern used to find audio files.
    """
    print(f"finding audio")

    audio_folder = Path(audio_folder)
    audio_files = list(tqdm(audio_folder.glob(pattern)))
    print(f"found {len(audio_files)} audio files")

    # split according to test_size
    n_test = int(len(audio_files) * test_size)
    n_train = len(audio_files) - n_test

    # shuffle
    random.seed(seed)
    random.shuffle(audio_files)

    train_files = audio_files[:n_train]
    test_files = audio_files[n_train:]

    print(f"Train files: {len(train_files)}")
    print(f"Test files: {len(test_files)}")
    # interactive guard: defaults to "n" (abort) on empty input
    continue_ = input("Continue [yn]? ") or "n"

    if continue_ != "y":
        return

    for split, files in (
        ("train", train_files), ("test", test_files)
    ):
        for file in tqdm(files):
            # symlink (not copy) each file into the split folder
            out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
            out_file.parent.mkdir(exist_ok=True, parents=True)
            os.symlink(file, out_file)

        # save split as json
        with open(Path(audio_folder) / f"{split}.json", "w") as f:
            json.dump([str(f) for f in files], f)



if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        train_test_split()
|
scripts/utils/split_long_audio_file.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import argbind
|
3 |
+
|
4 |
+
import audiotools as at
|
5 |
+
import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
@argbind.bind(without_prefix=True)
def split_long_audio_file(
    file: str = None,
    max_chunk_size_s: int = 60*10
):
    """Split one long audio file into half-overlapping chunks.

    Chunks are ``max_chunk_size_s`` seconds long with 50% hop and are written
    as ``0.wav``, ``1.wav``, ... into a new directory named after the input
    file's stem, next to the input file.

    Args:
        file: path to the audio file to split.
        max_chunk_size_s: window length (seconds) of each chunk.

    Returns:
        Path to the directory the chunks were written to.
    """
    file = Path(file)
    output_dir = file.parent / file.stem
    # fails loudly if the directory already exists so a previous
    # run's chunks are never silently clobbered
    output_dir.mkdir()

    sig = at.AudioSignal(file)

    # bug fix: the loop variable used to shadow `sig`, discarding the
    # full-length signal after the first window; use a distinct name.
    for i, chunk in tqdm.tqdm(enumerate(sig.windows(
        window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
        preprocess=True))
    ):
        chunk.write(output_dir / f"{i}.wav")

    print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")

    return output_dir

if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        split_long_audio_file()
split_long_audio_file()
|
scripts/utils/stage.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import argbind
|
6 |
+
import rich
|
7 |
+
from audiotools.ml import Experiment
|
8 |
+
|
9 |
+
|
10 |
+
@argbind.bind(without_prefix=True)
def run(
    run_dir: str = os.getenv("PATH_TO_RUNS", "runs"),
    name: str = None,
    recent: bool = False,
):
    """Create a snapshot of an experiment run directory.

    When ``recent`` is set, ``name`` is replaced by the most recently
    modified run directory found under ``run_dir``.
    """
    if recent:
        # walk run_dir ordered by mtime and keep only directories;
        # the last entry is the most recently touched run
        candidates = [
            entry.name
            for entry in sorted(Path(run_dir).iterdir(), key=os.path.getmtime)
            if entry.is_dir()
        ]
        if candidates:
            name = candidates[-1]

    with Experiment(run_dir, name) as exp:
        exp.snapshot()
        rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}")


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        run()
|
scripts/utils/visualize_embeddings.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
TODO: train a linear probe
|
3 |
+
usage:
|
4 |
+
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
|
5 |
+
"""
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import audiotools as at
|
10 |
+
from audiotools import AudioSignal
|
11 |
+
import argbind
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
import zipfile
|
15 |
+
import json
|
16 |
+
|
17 |
+
from vampnet.interface import Interface
|
18 |
+
import tqdm
|
19 |
+
|
20 |
+
# bind the Interface to argbind
|
21 |
+
Interface = argbind.bind(Interface)
|
22 |
+
|
23 |
+
DEBUG = False
|
24 |
+
|
25 |
+
|
26 |
+
def smart_plotly_export(fig, save_path: Path):
    """Export a plotly figure according to the extension of ``save_path``.

    Supported suffixes: ``.html`` (interactive page written to disk),
    ``.bytes`` (returns PNG bytes), ``.numpy`` (returns an image array),
    and ``.jpeg``/``.png``/``.webp`` (static image written to disk).

    Raises:
        ValueError: for any other suffix.
    """
    img_format = save_path.suffix[1:]
    if img_format == "html":
        fig.write_html(save_path)
    elif img_format == 'bytes':
        return fig.to_image(format='png')
    #TODO: come back and make this prettier
    elif img_format == 'numpy':
        import io
        from PIL import Image

        def plotly_fig2array(fig):
            #convert Plotly fig to an array
            fig_bytes = fig.to_image(format="png", width=1200, height=700)
            buf = io.BytesIO(fig_bytes)
            img = Image.open(buf)
            return np.asarray(img)

        return plotly_fig2array(fig)
    # bug fix: `img_format == 'jpeg' or 'png' or 'webp'` was always truthy
    # (non-empty strings are truthy), so every unknown format fell into
    # this branch and the ValueError below was unreachable
    elif img_format in ('jpeg', 'png', 'webp'):
        fig.write_image(save_path)
    else:
        raise ValueError("invalid image format")
|
49 |
+
|
50 |
+
|
51 |
+
def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="tsne"):
    """
    dimensionality reduction for visualization!
    saves an html plotly figure under output_dir
    parameters:
        annotated_embeddings (list): AnnotatedEmbeddings to reduce; each
            .embedding is expected to have shape (layers, features)
        layer (int): which model layer's embeddings to visualize
        output_dir (Path): directory where the figure is written
        n_components (int): number of output dimensions (2 or 3)
        method (str): umap, tsne, or pca
    returns:
        whatever smart_plotly_export returns for the produced figure
    """
    import pandas as pd
    import plotly.express as px

    fig_name = f"vampnet-embeddings-layer={layer}"
    fig_title = f"{fig_name}_{method}"
    save_path = (output_dir / fig_name).with_suffix(".html")

    if method == "umap":
        from umap import UMAP

        # bug fix: was `umap.UMAP(...)` — a NameError, since only the
        # class `UMAP` is imported, never the `umap` module itself
        reducer = UMAP(n_components=n_components)
    elif method == "tsne":
        from sklearn.manifold import TSNE

        reducer = TSNE(n_components=n_components)
    elif method == "pca":
        from sklearn.decomposition import PCA

        reducer = PCA(n_components=n_components)
    else:
        raise ValueError(f"invalid method: {method}")

    labels = [emb.label for emb in annotated_embeddings]
    names = [emb.filename for emb in annotated_embeddings]
    embs = [emb.embedding for emb in annotated_embeddings]
    # select a single layer: (samples, layers, features) -> (samples, features)
    embs_at_layer = np.stack(embs)[:, layer, :]
    projs = reducer.fit_transform(embs_at_layer)

    df = pd.DataFrame(
        {
            "label": labels,
            "name": names,
            "x": projs[:, 0],
            "y": projs[:, 1],
        }
    )
    if n_components == 2:
        fig = px.scatter(
            df, x="x", y="y", color="label", hover_name="name", title=fig_title,
        )

    elif n_components == 3:
        df['z'] = projs[:, 2]
        fig = px.scatter_3d(
            df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
        )
    else:
        raise ValueError(f"can't plot {n_components} components")

    fig.update_traces(
        marker=dict(size=6, line=dict(width=1, color="DarkSlateGrey")),
        selector=dict(mode="markers"),
    )

    return smart_plotly_export(fig, save_path)
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
# per JukeMIR, we want the embeddings from the middle layer?
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
    """Compute per-layer, time-averaged vampnet activations for one signal.

    Returns a tensor of shape (n_layers, n_dims): each transformer layer's
    activations, mean-pooled over the time dimension.

    NOTE(review): ``layer`` is only bounds-checked here; the caller selects
    the layer from the returned stack — confirm that is the intent.
    """
    with torch.inference_mode():
        # preprocess the signal
        sig = interface.preprocess(sig)

        # get the coarse vampnet model
        vampnet = interface.coarse

        # get the tokens (keep only the codebooks the coarse model uses)
        z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
        z_latents = vampnet.embedding.from_codes(z, interface.codec)

        # do a forward pass through the model, get the embeddings
        _z, embeddings = vampnet(z_latents, return_activations=True)
        # activations come back as [layer, batch, time, n_dims]

        # squeeze batch dim (1 bc layer should be dim 0)
        # bug fix: the error message reported shape[0] (the layer count)
        # even though the check is on shape[1] (the batch dim)
        assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[1]}"
        embeddings = embeddings.squeeze(1)

        num_layers = embeddings.shape[0]
        assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"

        # do meanpooling over the time dimension -> [n_layers, n_dims]
        embeddings = embeddings.mean(dim=-2)

        # return the embeddings
        return embeddings
|
154 |
+
|
155 |
+
from dataclasses import dataclass, fields
@dataclass
class AnnotatedEmbedding:
    """An embedding paired with its class label and source filename,
    serializable to and from a single zip archive."""
    label: str
    filename: str
    embedding: np.ndarray

    def save(self, path):
        """Write this object to ``path`` as a zip archive."""
        # everything except the array is serialized as json metadata
        metadata = {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name != 'embedding'
        }
        with zipfile.ZipFile(path, 'w') as archive:
            # the array is stored as a .npy entry inside the archive
            with archive.open('embedding.npy', 'w') as f:
                np.save(f, self.embedding)
            with archive.open('data.json', 'w') as f:
                f.write(json.dumps(metadata).encode('utf-8'))

    @classmethod
    def load(cls, path):
        """Reconstruct an AnnotatedEmbedding from a zip written by save()."""
        with zipfile.ZipFile(path, 'r') as archive:
            with archive.open('embedding.npy') as f:
                embedding = np.load(f)
            with archive.open('data.json') as f:
                metadata = json.loads(f.read().decode('utf-8'))
        return cls(embedding=embedding, **metadata)
|
189 |
+
|
190 |
+
|
191 |
+
@argbind.bind(without_prefix=True)
def main(
    path_to_audio: str = None,
    cache_dir: str = "./.emb_cache",
    output_dir: str = "./vampnet_embeddings",
    layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    method: str = "tsne",
    n_components: int = 2,
):
    """Embed a labeled audio folder with vampnet and plot dim-reduced embeddings.

    ``path_to_audio`` must contain one subfolder per label. Per-file embeddings
    are cached under ``cache_dir`` (one zip archive each), then one figure per
    entry in ``layers`` is produced in ``output_dir``.

    NOTE(review): ``layers`` is a mutable default argument; it is only read
    here, but a tuple default would be safer — confirm argbind accepts tuples.
    """
    path_to_audio = Path(path_to_audio)
    assert path_to_audio.exists(), f"{path_to_audio} does not exist"

    cache_dir = Path(cache_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # load our interface
    # argbind will automatically load the default config,
    interface = Interface()

    # we expect path_to_audio to consist of a folder for each label, so let's get the list of labels
    labels = [Path(x).name for x in path_to_audio.iterdir() if x.is_dir()]
    print(f"Found {len(labels)} labels")
    print(f"labels: {labels}")

    # collect audio files, labels, and embeddings
    annotated_embeddings = []
    for label in labels:
        audio_files = list(at.util.find_audio(path_to_audio / label))
        print(f"Found {len(audio_files)} audio files for label {label}")

        for audio_file in tqdm.tqdm(audio_files, desc=f"embedding label {label}"):
            # check if we have a cached embedding for this file
            cached_path = cache_dir / f"{label}_{audio_file.stem}.emb"
            if cached_path.exists():
                # if so, load it
                if DEBUG:
                    print(f"loading cached embedding for {cached_path.stem}")
                embedding = AnnotatedEmbedding.load(cached_path)
            else:
                # unreadable/corrupt files are skipped rather than aborting the run
                try:
                    sig = AudioSignal(audio_file)
                except Exception as e:
                    print(f"failed to load {audio_file.name} with error {e}")
                    print(f"skipping {audio_file.name}")
                    continue

                # gets the embedding
                emb = vampnet_embed(sig, interface).cpu().numpy()

                # create an embedding we can save/load
                embedding = AnnotatedEmbedding(
                    label=label, filename=audio_file.name, embedding=emb
                )

                # cache the embeddings
                cached_path.parent.mkdir(exist_ok=True, parents=True)
                embedding.save(cached_path)
            annotated_embeddings.append(embedding)

    # now, let's do a dim reduction on the embeddings and visualize them.
    for layer in tqdm.tqdm(layers, desc="dim reduction"):
        dim_reduce(
            annotated_embeddings,
            layer,
            output_dir=output_dir,
            n_components=n_components,
            method=method,
        )


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
|
scripts/utils/xeno-canto-dl.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from xenopy import Query
|
2 |
+
|
3 |
+
|
4 |
+
SPECIES = [
|
5 |
+
"American Robin",
|
6 |
+
"Northern Cardinal",
|
7 |
+
"Mourning Dove",
|
8 |
+
"American Crow",
|
9 |
+
"Baltimore Oriole",
|
10 |
+
"Blue Jay",
|
11 |
+
"Eastern Bluebird",
|
12 |
+
"House Finch",
|
13 |
+
"American Goldfinch",
|
14 |
+
"House Sparrow",
|
15 |
+
"Song Sparrow",
|
16 |
+
"Tufted Titmouse",
|
17 |
+
"White-breasted Nuthatch",
|
18 |
+
"European Starling",
|
19 |
+
"American Redstart",
|
20 |
+
"Red-winged Blackbird",
|
21 |
+
"Brown-headed Cowbird",
|
22 |
+
"Common Grackle",
|
23 |
+
"Boat-tailed Grackle",
|
24 |
+
"Common Yellowthroat",
|
25 |
+
"Northern Mockingbird",
|
26 |
+
"Carolina Wren",
|
27 |
+
"Eastern Meadowlark",
|
28 |
+
"Chipping Sparrow",
|
29 |
+
"Tree Swallow",
|
30 |
+
"Barn Swallow",
|
31 |
+
"Cliff Swallow",
|
32 |
+
"Pine Siskin",
|
33 |
+
"Indigo Bunting",
|
34 |
+
"Eastern Towhee",
|
35 |
+
"Carolina Chickadee",
|
36 |
+
"Great Crested Flycatcher",
|
37 |
+
"Eastern Wood-Pewee",
|
38 |
+
"Ovenbird",
|
39 |
+
"Northern Flicker",
|
40 |
+
"Red-eyed Vireo",
|
41 |
+
"American Woodcock",
|
42 |
+
"Eastern Phoebe",
|
43 |
+
"Downy Woodpecker",
|
44 |
+
"Scarlet Tanager",
|
45 |
+
"Yellow Warbler",
|
46 |
+
"White-eyed Vireo",
|
47 |
+
"Common Loon",
|
48 |
+
"White-throated Sparrow",
|
49 |
+
"Yellow-throated Vireo",
|
50 |
+
"Great Blue Heron",
|
51 |
+
"Belted Kingfisher",
|
52 |
+
"Pied-billed Grebe",
|
53 |
+
"Wild Turkey",
|
54 |
+
"Wood Thrush",
|
55 |
+
"Rose-breasted Grosbeak",
|
56 |
+
"Field Sparrow",
|
57 |
+
"Hooded Warbler",
|
58 |
+
"Northern Parula",
|
59 |
+
"Chestnut-sided Warbler",
|
60 |
+
"Blue-winged Warbler",
|
61 |
+
"Red-bellied Woodpecker",
|
62 |
+
"Yellow-billed Cuckoo",
|
63 |
+
"Gray Catbird",
|
64 |
+
"Northern Saw-whet Owl",
|
65 |
+
"Osprey",
|
66 |
+
"Common Nighthawk",
|
67 |
+
"Broad-winged Hawk",
|
68 |
+
"Black-throated Green Warbler",
|
69 |
+
"Great Horned Owl",
|
70 |
+
"Common Raven",
|
71 |
+
"Barred Owl",
|
72 |
+
"Canada Warbler",
|
73 |
+
"Magnolia Warbler",
|
74 |
+
"Black-and-white Warbler",
|
75 |
+
"Eastern Kingbird",
|
76 |
+
"Swainson's Thrush",
|
77 |
+
"Worm-eating Warbler",
|
78 |
+
"Prairie Warbler",
|
79 |
+
"Baltimore Oriole",
|
80 |
+
"Black-throated Blue Warbler",
|
81 |
+
"Louisiana Waterthrush",
|
82 |
+
"Blackburnian Warbler",
|
83 |
+
"Black-capped Chickadee",
|
84 |
+
"Cerulean Warbler",
|
85 |
+
"Red-shouldered Hawk",
|
86 |
+
"Cooper's Hawk",
|
87 |
+
"Yellow-throated Warbler",
|
88 |
+
"Blue-headed Vireo",
|
89 |
+
"Blackpoll Warbler",
|
90 |
+
"Ruffed Grouse",
|
91 |
+
"Kentucky Warbler",
|
92 |
+
"Hermit Thrush",
|
93 |
+
"Cedar Waxwing",
|
94 |
+
"Eastern Screech-Owl",
|
95 |
+
"Northern Goshawk",
|
96 |
+
"Green Heron",
|
97 |
+
"Red-tailed Hawk",
|
98 |
+
"Black Vulture",
|
99 |
+
"Hairy Woodpecker",
|
100 |
+
"Golden-crowned Kinglet",
|
101 |
+
"Ruby-crowned Kinglet",
|
102 |
+
"Bicknell's Thrush",
|
103 |
+
"Blue-gray Gnatcatcher",
|
104 |
+
"Veery",
|
105 |
+
"Pileated Woodpecker",
|
106 |
+
"Purple Finch",
|
107 |
+
"White-crowned Sparrow",
|
108 |
+
"Snow Bunting",
|
109 |
+
"Pine Grosbeak",
|
110 |
+
"American Tree Sparrow",
|
111 |
+
"Dark-eyed Junco",
|
112 |
+
"Snowy Owl",
|
113 |
+
"White-winged Crossbill",
|
114 |
+
"Red Crossbill",
|
115 |
+
"Common Redpoll",
|
116 |
+
"Northern Shrike",
|
117 |
+
"Northern Harrier",
|
118 |
+
"Rough-legged Hawk",
|
119 |
+
"Long-eared Owl",
|
120 |
+
"Evening Grosbeak",
|
121 |
+
"Northern Pintail",
|
122 |
+
"American Black Duck",
|
123 |
+
"Mallard",
|
124 |
+
"Canvasback",
|
125 |
+
"Redhead",
|
126 |
+
"Ring-necked Duck",
|
127 |
+
"Greater Scaup",
|
128 |
+
"Lesser Scaup",
|
129 |
+
"Bufflehead",
|
130 |
+
"Common Goldeneye",
|
131 |
+
"Hooded Merganser",
|
132 |
+
"Common Merganser",
|
133 |
+
"Red-breasted Merganser",
|
134 |
+
"Ruddy Duck",
|
135 |
+
"Wood Duck",
|
136 |
+
"Gadwall",
|
137 |
+
"American Wigeon",
|
138 |
+
"Northern Shoveler",
|
139 |
+
"Green-winged Teal",
|
140 |
+
"Blue-winged Teal",
|
141 |
+
"Cinnamon Teal",
|
142 |
+
"Ringed Teal",
|
143 |
+
"Cape Teal",
|
144 |
+
"Northern Fulmar",
|
145 |
+
"Yellow-billed Loon",
|
146 |
+
"Red-throated Loon",
|
147 |
+
"Arctic Loon",
|
148 |
+
"Pacific Loon",
|
149 |
+
"Horned Grebe",
|
150 |
+
"Red-necked Grebe",
|
151 |
+
"Eared Grebe",
|
152 |
+
"Western Grebe",
|
153 |
+
"Clark's Grebe",
|
154 |
+
"Double-crested Cormorant",
|
155 |
+
"Pelagic Cormorant",
|
156 |
+
"Great Cormorant",
|
157 |
+
"American White Pelican",
|
158 |
+
"Brown Pelican",
|
159 |
+
"Brandt's Cormorant",
|
160 |
+
"Least Bittern",
|
161 |
+
"Great Egret",
|
162 |
+
"Snowy Egret",
|
163 |
+
"Little Blue Heron",
|
164 |
+
"Tricolored Heron",
|
165 |
+
"Reddish Egret",
|
166 |
+
"Black-crowned Night-Heron",
|
167 |
+
"Yellow-crowned Night-Heron",
|
168 |
+
"White Ibis",
|
169 |
+
"Glossy Ibis",
|
170 |
+
"Roseate Spoonbill",
|
171 |
+
"Wood Stork",
|
172 |
+
"Black-bellied Whistling-Duck",
|
173 |
+
"Fulvous Whistling-Duck",
|
174 |
+
"Greater White-fronted Goose",
|
175 |
+
"Snow Goose",
|
176 |
+
"Ross's Goose",
|
177 |
+
"Canada Goose",
|
178 |
+
"Brant",
|
179 |
+
"Mute Swan",
|
180 |
+
"Tundra Swan",
|
181 |
+
"Whooper Swan",
|
182 |
+
"Sandhill Crane",
|
183 |
+
"Black-necked Stilt",
|
184 |
+
"American Avocet",
|
185 |
+
"Northern Jacana",
|
186 |
+
"Greater Yellowlegs",
|
187 |
+
"Lesser Yellowlegs",
|
188 |
+
"Willet",
|
189 |
+
"Spotted Sandpiper",
|
190 |
+
"Upland Sandpiper",
|
191 |
+
"Whimbrel",
|
192 |
+
"Long-billed Curlew",
|
193 |
+
"Marbled Godwit",
|
194 |
+
"Ruddy Turnstone",
|
195 |
+
"Red Knot",
|
196 |
+
"Sanderling",
|
197 |
+
"Semipalmated Sandpiper",
|
198 |
+
"Western Sandpiper",
|
199 |
+
"Least Sandpiper",
|
200 |
+
"White-rumped Sandpiper",
|
201 |
+
"Baird's Sandpiper",
|
202 |
+
"Pectoral Sandpiper",
|
203 |
+
"Dunlin",
|
204 |
+
"Buff-breasted Sandpiper",
|
205 |
+
"Short-billed Dowitcher",
|
206 |
+
"Long-billed Dowitcher",
|
207 |
+
"Common Snipe",
|
208 |
+
"American Woodcock",
|
209 |
+
"Wilson's Phalarope",
|
210 |
+
"Red-necked Phalarope",
|
211 |
+
"Red Phalarope"
|
212 |
+
]
|
213 |
+
|
214 |
+
from pathlib import Path
|
215 |
+
|
216 |
+
def remove_spaces(s):
    """Return ``s`` with every space character removed."""
    return "".join(s.split(" "))
|
218 |
+
|
219 |
+
# download xeno-canto recordings for every species in SPECIES,
# skipping species whose output folder already exists from a prior run
for species in SPECIES:
    if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
        continue
    try:
        # quality "A" recordings between 10 and 30 seconds long
        q = Query(
            name=species, q="A", length="10-30",
        )

        # retrieve metadata
        metafiles = q.retrieve_meta(verbose=True)
        # retrieve recordings
        q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")

    # bug fix: a bare `except:` also swallowed KeyboardInterrupt and
    # SystemExit, making the long-running loop impossible to interrupt;
    # catch ordinary errors only
    except Exception:
        print("Failed to download " + species)
        continue
|
setup.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import find_packages
|
2 |
+
from setuptools import setup
|
3 |
+
|
4 |
+
# use the README as the package's long description on PyPI
with open("README.md") as f:
    long_description = f.read()

setup(
    name="vampnet",
    version="0.0.1",
    classifiers=[
        "Intended Audience :: Developers",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.7",
        "Topic :: Artistic Software",
        "Topic :: Multimedia",
        "Topic :: Multimedia :: Sound/Audio",
        "Topic :: Multimedia :: Sound/Audio :: Editors",
        "Topic :: Software Development :: Libraries",
    ],
    description="Generative Music Modeling.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Hugo Flores García, Prem Seetharaman",
    # NOTE(review): "hfgacrcia" looks like a typo for "hfgarcia" — confirm with the author
    author_email="hfgacrcia@descript.com",
    url="https://github.com/hugofloresgarcia/vampnet",
    license="MIT",
    packages=find_packages(),
    # several dependencies are pinned to git forks rather than PyPI releases
    install_requires=[
        "torch",
        "argbind>=0.3.2",
        "numpy==1.23",
        "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
        "lac @ git+https://github.com/hugofloresgarcia/lac.git",
        "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
        "gradio",
        "loralib",
        "torch_pitch_shift",
        "madmom",
        "pyharp @ git+https://github.com/audacitorch/pyharp.git",
        "plotly",
        "umap_learn",
    ],
)
|
vampnet/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from . import modules
|
3 |
+
from . import scheduler
|
4 |
+
from .interface import Interface
|
5 |
+
|
6 |
+
__version__ = "0.0.1"
|
vampnet/beats.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import warnings
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any
|
7 |
+
from typing import List
|
8 |
+
from typing import Tuple
|
9 |
+
from typing import Union
|
10 |
+
|
11 |
+
import librosa
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
from audiotools import AudioSignal
|
15 |
+
|
16 |
+
|
17 |
+
logging.basicConfig(level=logging.INFO)
|
18 |
+
|
19 |
+
###################
|
20 |
+
# beat sync utils #
|
21 |
+
###################
|
22 |
+
|
23 |
+
# maps aggregator names to the numpy reduction they apply
AGGREGATOR_REGISTRY = {
    "mean": np.mean,
    "median": np.median,
    "max": np.max,
    "min": np.min,
}


def list_aggregators() -> list:
    """Return the names of all registered feature aggregators."""
    return [*AGGREGATOR_REGISTRY]
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
class TimeSegment:
    """A span of time in seconds, from ``start`` to ``end``."""
    start: float
    end: float

    @property
    def duration(self):
        """Length of the segment, in seconds."""
        return self.end - self.start

    def __str__(self) -> str:
        return f"{self.start} - {self.end}"

    def find_overlapping_segment(
        self, segments: List["TimeSegment"]
    ) -> Union["TimeSegment", None]:
        """Return the first segment that fully contains this one, or None."""
        containing = (
            seg for seg in segments
            if seg.start <= self.start and seg.end >= self.end
        )
        return next(containing, None)
|
55 |
+
|
56 |
+
|
57 |
+
def mkdir(path: Union[Path, str]) -> Path:
    """Create ``path`` (including missing parents) if needed; return it as a Path."""
    out = Path(path)
    out.mkdir(parents=True, exist_ok=True)
    return out
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
###################
|
65 |
+
# beat data #
|
66 |
+
###################
|
67 |
+
@dataclass
class BeatSegment(TimeSegment):
    """A beat-to-beat time segment, flagging whether it starts on a downbeat."""

    downbeat: bool = False  # if there's a downbeat on the start_time
|
70 |
+
|
71 |
+
|
72 |
+
class Beats:
    """Holds beat and downbeat times (in seconds) with helpers for
    beat segmentation, beat-synchronous feature aggregation, and
    json (de)serialization."""

    def __init__(self, beat_times, downbeat_times):
        # normalize numpy arrays to plain lists so json serialization works
        if isinstance(beat_times, np.ndarray):
            beat_times = beat_times.tolist()
        if isinstance(downbeat_times, np.ndarray):
            downbeat_times = downbeat_times.tolist()
        self._beat_times = beat_times
        self._downbeat_times = downbeat_times
        # when True, get_beats() returns downbeat times instead of beat times
        self._use_downbeats = False

    def use_downbeats(self, use_downbeats: bool = True):
        """use downbeats instead of beats when calling beat_times"""
        self._use_downbeats = use_downbeats

    def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]:
        """
        segments a song into time segments corresponding to beats.
        the first segment starts at 0 and ends at the first beat time.
        the last segment starts at the last beat time and ends at the end of the song.
        """
        # copy so inserting the 0/duration endpoints doesn't mutate state
        beat_times = self._beat_times.copy()
        downbeat_times = self._downbeat_times
        beat_times.insert(0, 0)
        beat_times.append(signal.signal_duration)

        # indices (into beat_times) of beats that are also downbeats.
        # NOTE(review): relies on exact float equality between beat and
        # downbeat times — valid only if both lists come from the same
        # tracker output; confirm.
        downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[
            1
        ]
        is_downbeat = [
            True if i in downbeat_ids else False for i in range(len(beat_times))
        ]
        # one segment per consecutive pair of beat times
        segments = [
            BeatSegment(start_time, end_time, downbeat)
            for start_time, end_time, downbeat in zip(
                beat_times[:-1], beat_times[1:], is_downbeat
            )
        ]
        return segments

    def get_beats(self) -> np.ndarray:
        """returns an array of beat times, in seconds
        if downbeats is True, returns an array of downbeat times, in seconds
        """
        return np.array(
            self._downbeat_times if self._use_downbeats else self._beat_times
        )

    @property
    def beat_times(self) -> np.ndarray:
        """return beat times"""
        return np.array(self._beat_times)

    @property
    def downbeat_times(self) -> np.ndarray:
        """return downbeat times"""
        return np.array(self._downbeat_times)

    def beat_times_to_feature_frames(
        self, signal: AudioSignal, features: np.ndarray
    ) -> np.ndarray:
        """convert beat times to frames, given an array of time-varying features"""
        beat_times = self.get_beats()
        # scale seconds -> samples -> feature-frame index
        beat_frames = (
            beat_times * signal.sample_rate / signal.signal_length * features.shape[-1]
        ).astype(np.int64)
        return beat_frames

    def sync_features(
        self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median"
    ) -> np.ndarray:
        """sync features to beats"""
        if aggregate not in AGGREGATOR_REGISTRY:
            raise ValueError(f"unknown aggregation method {aggregate}")

        # aggregate features within each inter-beat interval
        return librosa.util.sync(
            features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate]
        )

    def to_json(self) -> dict:
        """return beats and downbeats as json"""
        return {
            "beats": self._beat_times,
            "downbeats": self._downbeat_times,
            "use_downbeats": self._use_downbeats,
        }

    @classmethod
    def from_dict(cls, data: dict):
        """load beats and downbeats from json"""
        inst = cls(data["beats"], data["downbeats"])
        inst.use_downbeats(data["use_downbeats"])
        return inst

    def save(self, output_dir: Path):
        """save beats and downbeats to json"""
        mkdir(output_dir)
        with open(output_dir / "beats.json", "w") as f:
            json.dump(self.to_json(), f)

    @classmethod
    def load(cls, input_dir: Path):
        """load beats and downbeats from json"""
        beats_file = Path(input_dir) / "beats.json"
        with open(beats_file, "r") as f:
            data = json.load(f)
        return cls.from_dict(data)
|
178 |
+
|
179 |
+
|
180 |
+
###################
|
181 |
+
# beat tracking #
|
182 |
+
###################
|
183 |
+
|
184 |
+
|
185 |
+
class BeatTracker:
    """Base class for beat trackers; subclasses implement extract_beats()."""

    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """extract beats from an audio signal"""
        raise NotImplementedError

    def __call__(self, signal: AudioSignal) -> Beats:
        """Run beat tracking on ``signal`` and wrap the result.

        Args:
            signal (AudioSignal): signal to beat track

        Returns:
            Beats: detected beat and downbeat times
        """
        beat_times, downbeat_times = self.extract_beats(signal)
        return Beats(beat_times, downbeat_times)
|
201 |
+
|
202 |
+
|
203 |
+
class WaveBeat(BeatTracker):
    """Beat tracker backed by a pretrained WaveBeat (dsTCN) model checkpoint."""

    def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
        # imported lazily so the wavebeat package is only required when
        # this tracker is actually instantiated
        from wavebeat.dstcn import dsTCNModel

        model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
        model.eval()

        self.device = device
        self.model = model

    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """returns beat and downbeat times, in seconds"""
        # extract beats; squeeze(0) drops the leading (batch/channel) dim
        # of the audio tensor before handing it to the model
        beats, downbeats = self.model.predict_beats_from_array(
            audio=signal.audio_data.squeeze(0),
            sr=signal.sample_rate,
            use_gpu=self.device != "cpu",
        )

        return beats, downbeats
|
223 |
+
|
224 |
+
|
225 |
+
class MadmomBeats(BeatTracker):
    """Placeholder for a madmom-based beat tracker; not implemented yet."""

    def __init__(self):
        raise NotImplementedError

    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """returns beat and downbeat times, in seconds"""
        pass
|
232 |
+
|
233 |
+
|
234 |
+
# registry mapping user-facing names to BeatTracker implementations
BEAT_TRACKER_REGISTRY = {
    "wavebeat": WaveBeat,
    "madmom": MadmomBeats,
}
|
238 |
+
|
239 |
+
|
240 |
+
def list_beat_trackers() -> list:
    """Return the names of all registered beat trackers."""
    return [name for name in BEAT_TRACKER_REGISTRY]
|
242 |
+
|
243 |
+
|
244 |
+
def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
    """Instantiate a registered beat tracker by name, forwarding kwargs to it.

    Raises:
        ValueError: if ``beat_tracker`` is not a registered name.
    """
    try:
        tracker_cls = BEAT_TRACKER_REGISTRY[beat_tracker]
    except KeyError:
        raise ValueError(
            f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
        ) from None
    return tracker_cls(**kwargs)
|
vampnet/interface.py
ADDED
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from audiotools import AudioSignal
|
8 |
+
import tqdm
|
9 |
+
|
10 |
+
from .modules.transformer import VampNet
|
11 |
+
from .beats import WaveBeat
|
12 |
+
from .mask import *
|
13 |
+
|
14 |
+
# from dac.model.dac import DAC
|
15 |
+
from lac.model.lac import LAC as DAC
|
16 |
+
|
17 |
+
|
18 |
+
def signal_concat(
    audio_signals: list,
):
    """Concatenate a list of AudioSignals along the time axis.

    Args:
        audio_signals: non-empty list of AudioSignal objects with matching
            batch/channel dimensions (sample rates are assumed equal; the
            first signal's rate is used).

    Returns:
        AudioSignal: one signal whose audio data is the time-wise concatenation.

    Raises:
        ValueError: if ``audio_signals`` is empty (previously an opaque IndexError).
    """
    if not audio_signals:
        raise ValueError("audio_signals must be a non-empty list")

    audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1)

    return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
|
24 |
+
|
25 |
+
|
26 |
+
def _load_model(
    ckpt: str,
    lora_ckpt: str = None,
    device: str = "cpu",
    chunk_size_s: int = 10,
):
    """Load a VampNet checkpoint (optionally overlaying LoRA weights) for inference.

    Args:
        ckpt: path to the base VampNet checkpoint.
        lora_ckpt: optional path to a LoRA state dict applied on top.
        device: device the model is moved to after loading.
        chunk_size_s: chunk length (seconds) stored on the model for chunked generation.

    Returns:
        The loaded model in eval mode, with ``chunk_size_s`` attached.
    """
    # we need to set strict to False if the model has lora weights to add later
    model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)

    # load lora weights if needed
    if lora_ckpt is not None:
        if not Path(lora_ckpt).exists():
            # NOTE(review): interactive stdin prompt — unsuitable for headless runs
            should_cont = input(
                f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
            )
            if should_cont != "y":
                raise Exception("aborting")
        else:
            model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)

    model.to(device)
    model.eval()
    # stored on the model so Interface can size its generation chunks
    model.chunk_size_s = chunk_size_s
    return model
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
class Interface(torch.nn.Module):
    """High-level inference wrapper that ties together the neural codec,
    the coarse VampNet, the optional coarse-to-fine VampNet, and an
    optional WaveBeat beat tracker for beat-synced masking."""

    def __init__(
        self,
        coarse_ckpt: str = None,
        coarse_lora_ckpt: str = None,
        coarse2fine_ckpt: str = None,
        coarse2fine_lora_ckpt: str = None,
        codec_ckpt: str = None,
        wavebeat_ckpt: str = None,
        device: str = "cpu",
        coarse_chunk_size_s: int = 10,
        coarse2fine_chunk_size_s: int = 3,
    ):
        super().__init__()
        assert codec_ckpt is not None, "must provide a codec checkpoint"
        self.codec = DAC.load(Path(codec_ckpt))
        self.codec.eval()
        self.codec.to(device)
        self.codec_path = Path(codec_ckpt)

        assert coarse_ckpt is not None, "must provide a coarse checkpoint"
        self.coarse = _load_model(
            ckpt=coarse_ckpt,
            lora_ckpt=coarse_lora_ckpt,
            device=device,
            chunk_size_s=coarse_chunk_size_s,
        )
        self.coarse_path = Path(coarse_ckpt)

        # check if we have a coarse2fine ckpt
        if coarse2fine_ckpt is not None:
            self.c2f_path = Path(coarse2fine_ckpt)
            self.c2f = _load_model(
                ckpt=coarse2fine_ckpt,
                lora_ckpt=coarse2fine_lora_ckpt,
                device=device,
                chunk_size_s=coarse2fine_chunk_size_s,
            )
        else:
            self.c2f_path = None
            self.c2f = None

        if wavebeat_ckpt is not None:
            print(f"loading wavebeat from {wavebeat_ckpt}")
            self.beat_tracker = WaveBeat(wavebeat_ckpt)
            self.beat_tracker.model.to(device)
        else:
            self.beat_tracker = None

        self.device = device

    def reload(
        self,
        coarse_ckpt: str = None,
        c2f_ckpt: str = None,
    ):
        """Hot-swap the coarse and/or coarse-to-fine checkpoints in place."""
        if coarse_ckpt is not None:
            # check if we already loaded, if so, don't reload
            if self.coarse_path == Path(coarse_ckpt):
                print(f"already loaded {coarse_ckpt}")
                # NOTE(review): this early return also skips the c2f reload
                # below when both checkpoints are passed — confirm intended.
                return
            self.coarse = _load_model(
                ckpt=coarse_ckpt,
                device=self.device,
                chunk_size_s=self.coarse.chunk_size_s,
            )
            self.coarse_path = Path(coarse_ckpt)
            print(f"loaded {coarse_ckpt}")

        if c2f_ckpt is not None:
            if self.c2f_path == Path(c2f_ckpt):
                print(f"already loaded {c2f_ckpt}")
                return
            self.c2f = _load_model(
                ckpt=c2f_ckpt,
                device=self.device,
                chunk_size_s=self.c2f.chunk_size_s,
            )
            self.c2f_path = Path(c2f_ckpt)
            print(f"loaded {c2f_ckpt}")

    def s2t(self, seconds: float):
        """seconds to tokens"""
        # ceil so any partial codec frame still maps to a whole token
        if isinstance(seconds, np.ndarray):
            return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
        else:
            return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)

    def s2t2s(self, seconds: float):
        """seconds to tokens to seconds"""
        # round-trips through the token grid, snapping a duration to a token boundary
        return self.t2s(self.s2t(seconds))

    def t2s(self, tokens: int):
        """tokens to seconds"""
        return tokens * self.codec.hop_length / self.codec.sample_rate

    def to(self, device):
        """Move every submodel to ``device`` and remember it for tensor creation."""
        self.device = device
        self.coarse.to(device)
        self.codec.to(device)

        if self.c2f is not None:
            self.c2f.to(device)

        if self.beat_tracker is not None:
            self.beat_tracker.model.to(device)
        return self

    def to_signal(self, z: torch.Tensor):
        """Decode a token tensor back to an AudioSignal via the coarse model."""
        return self.coarse.to_signal(z, self.codec)

    def preprocess(self, signal: AudioSignal):
        """Return a mono, codec-rate, loudness-normalized, peak-limited copy."""
        signal = (
            signal.clone()
            .resample(self.codec.sample_rate)
            .to_mono()
            .normalize(-24)
            .ensure_max_of_audio(1.0)
        )
        return signal

    @torch.inference_mode()
    def encode(self, signal: AudioSignal):
        """Encode a signal to discrete codec tokens (batch, n_codebooks, seq)."""
        signal = self.preprocess(signal).to(self.device)
        z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
        return z

    def snap_to_beats(
        self,
        signal: AudioSignal
    ):
        """Trim a signal so it starts on the first detected beat and ends on the last."""
        assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
        beats, downbeats = self.beat_tracker.extract_beats(signal)

        # trim the signal around the first beat time
        samples_begin = int(beats[0] * signal.sample_rate )
        samples_end = int(beats[-1] * signal.sample_rate)
        # NOTE(review): debug print left in
        print(beats[0])
        signal = signal.clone().trim(samples_begin, signal.length - samples_end)

        return signal

    def make_beat_mask(self,
                       signal: AudioSignal,
                       before_beat_s: float = 0.0,
                       after_beat_s: float = 0.02,
                       mask_downbeats: bool = True,
                       mask_upbeats: bool = True,
                       downbeat_downsample_factor: int = None,
                       beat_downsample_factor: int = None,
                       dropout: float = 0.0,
                       invert: bool = True,
                       ):
        """make a beat synced mask. that is, make a mask that
        places 1s at and around the beat, and 0s everywhere else.
        With ``invert=True`` (the default) the mask is flipped at the end,
        so beat regions become 0 (kept) and everything else 1 (masked).
        """
        assert self.beat_tracker is not None, "No beat tracker loaded"

        # get the beat times
        beats, downbeats = self.beat_tracker.extract_beats(signal)

        # get the beat indices in z
        beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)

        # remove downbeats from beats
        beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
        beats_z = beats_z.tolist()
        downbeats_z = downbeats_z.tolist()

        # make the mask
        seq_len = self.s2t(signal.duration)
        mask = torch.zeros(seq_len, device=self.device)

        # window half-widths (in tokens) around each beat
        mask_b4 = self.s2t(before_beat_s)
        mask_after = self.s2t(after_beat_s)

        if beat_downsample_factor is not None:
            if beat_downsample_factor < 1:
                raise ValueError("mask_beat_downsample_factor must be >= 1 or None")
        else:
            beat_downsample_factor = 1

        if downbeat_downsample_factor is not None:
            if downbeat_downsample_factor < 1:
                raise ValueError("mask_beat_downsample_factor must be >= 1 or None")
        else:
            downbeat_downsample_factor = 1

        # optionally keep only every k-th (down)beat
        beats_z = beats_z[::beat_downsample_factor]
        downbeats_z = downbeats_z[::downbeat_downsample_factor]
        print(f"beats_z: {len(beats_z)}")
        print(f"downbeats_z: {len(downbeats_z)}")

        if mask_upbeats:
            for beat_idx in beats_z:
                _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
                num_steps = mask[_slice[0]:_slice[1]].shape[0]
                _m = torch.ones(num_steps, device=self.device)
                # per-position dropout: randomly zero some beat positions
                _m_mask = torch.bernoulli(_m * (1 - dropout))
                _m = _m * _m_mask.long()

                mask[_slice[0]:_slice[1]] = _m

        if mask_downbeats:
            for downbeat_idx in downbeats_z:
                _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
                num_steps = mask[_slice[0]:_slice[1]].shape[0]
                _m = torch.ones(num_steps, device=self.device)
                _m_mask = torch.bernoulli(_m * (1 - dropout))
                _m = _m * _m_mask.long()

                mask[_slice[0]:_slice[1]] = _m

        mask = mask.clamp(0, 1)
        if invert:
            mask = 1 - mask

        # broadcast to (1, n_codebooks, seq)
        mask = mask[None, None, :].bool().long()
        if self.c2f is not None:
            mask = mask.repeat(1, self.c2f.n_codebooks, 1)
        else:
            mask = mask.repeat(1, self.coarse.n_codebooks, 1)
        return mask

    def coarse_to_fine(
        self,
        z: torch.Tensor,
        mask: torch.Tensor = None,
        **kwargs
    ):
        """Fill in the fine codebooks of ``z`` with the coarse-to-fine model,
        processing the sequence in fixed-size chunks."""
        assert self.c2f is not None, "No coarse2fine model loaded"
        length = z.shape[-1]
        chunk_len = self.s2t(self.c2f.chunk_size_s)
        n_chunks = math.ceil(z.shape[-1] / chunk_len)

        # zero pad to chunk_len
        if length % chunk_len != 0:
            pad_len = chunk_len - (length % chunk_len)
            z = torch.nn.functional.pad(z, (0, pad_len))
            mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None

        # append empty fine codebooks if z only carries the coarse ones
        n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
        if n_codebooks_to_append > 0:
            z = torch.cat([
                z,
                torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
            ], dim=1)

        # set the mask to 0 for all conditioning codebooks
        if mask is not None:
            mask = mask.clone()
            mask[:, :self.c2f.n_conditioning_codebooks, :] = 0

        fine_z = []
        for i in range(n_chunks):
            chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
            mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None

            chunk = self.c2f.generate(
                codec=self.codec,
                time_steps=chunk_len,
                start_tokens=chunk,
                return_signal=False,
                mask=mask_chunk,
                **kwargs
            )
            fine_z.append(chunk)

        fine_z = torch.cat(fine_z, dim=-1)
        # drop the zero padding before returning
        return fine_z[:, :, :length].clone()

    def coarse_vamp(
        self,
        z,
        mask,
        return_mask=False,
        gen_fn=None,
        **kwargs
    ):
        """Regenerate the masked coarse tokens of ``z`` with the coarse model,
        leaving the fine codebooks untouched."""
        # coarse z
        cz = z[:, : self.coarse.n_codebooks, :].clone()
        assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"

        mask = mask[:, : self.coarse.n_codebooks, :]

        cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
        cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]

        # default to the coarse model's own generate fn
        gen_fn = gen_fn or self.coarse.generate
        c_vamp = gen_fn(
            codec=self.codec,
            time_steps=cz.shape[-1],
            start_tokens=cz,
            mask=mask,
            return_signal=False,
            **kwargs
        )

        # add the fine codes back in
        c_vamp = torch.cat(
            [c_vamp, z[:, self.coarse.n_codebooks :, :]],
            dim=1
        )

        if return_mask:
            return c_vamp, cz_masked

        return c_vamp
|
361 |
+
|
362 |
+
# def chunked_coarse_vamp(
|
363 |
+
# self,
|
364 |
+
# z,
|
365 |
+
# mask,
|
366 |
+
# return_mask=False,
|
367 |
+
# gen_fn=None,
|
368 |
+
# **kwargs
|
369 |
+
# )
|
370 |
+
|
371 |
+
|
372 |
+
if __name__ == "__main__":
    # smoke-test / demo: encode an example file, vamp it, and decode.
    # Requires local checkpoints under ./models and a CUDA device.
    import audiotools as at
    import logging
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    torch.set_printoptions(threshold=10000)
    at.util.seed(42)  # deterministic sampling for reproducible demos

    interface = Interface(
        coarse_ckpt="./models/vampnet/coarse.pth",
        coarse2fine_ckpt="./models/vampnet/c2f.pth",
        codec_ckpt="./models/vampnet/codec.pth",
        device="cuda",
        wavebeat_ckpt="./models/wavebeat.pth"
    )


    sig = at.AudioSignal('assets/example.wav')

    z = interface.encode(sig)
    # NOTE(review): interactive breakpoints left in this demo script
    breakpoint()

    # alternative masking strategies, kept for reference:
    # mask = linear_random(z, 1.0)
    # mask = mask_and(
    #     mask, periodic_mask(
    #         z,
    #         32,
    #         1,
    #         random_roll=True
    #     )
    # )

    # mask = interface.make_beat_mask(
    #     sig, 0.0, 0.075
    # )
    # mask = dropout(mask, 0.0)
    # mask = codebook_unmask(mask, 0)

    # keep the first and last 100 tokens, regenerate the middle
    mask = inpaint(z, n_prefix=100, n_suffix=100)

    zv, mask_z = interface.coarse_vamp(
        z,
        mask=mask,
        sampling_steps=36,
        temperature=8.0,
        return_mask=True,
        gen_fn=interface.coarse.generate
    )


    use_coarse2fine = True
    if use_coarse2fine:
        zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
        breakpoint()

    # decode both the masked input (for inspection) and the vamped output
    mask = interface.to_signal(mask_z).cpu()

    sig = interface.to_signal(zv).cpu()
    print("done")
|
431 |
+
|
432 |
+
|
vampnet/mask.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from audiotools import AudioSignal
|
5 |
+
|
6 |
+
from .util import scalar_to_batch_tensor
|
7 |
+
|
8 |
+
def _gamma(r):
    """Cosine mask schedule: map ratio r in [0, 1] to cos(r*pi/2), clamped to (1e-10, 1.0]."""
    return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
|
10 |
+
|
11 |
+
def _invgamma(y):
    """Inverse of ``_gamma``: recover r from cos(r*pi/2). Accepts scalars or tensors."""
    if not torch.is_tensor(y):
        # wrap scalars so the result is always a 1-element tensor
        y = torch.tensor(y)[None]
    return 2 * y.acos() / torch.pi
|
15 |
+
|
16 |
+
def full_mask(x: torch.Tensor):
    """Return an all-ones (everything masked) mask shaped like ``x``."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    ones = torch.ones_like(x)
    return ones.long()
|
19 |
+
|
20 |
+
def empty_mask(x: torch.Tensor):
    """Return an all-zeros (nothing masked) mask shaped like ``x``."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    zeros = torch.zeros_like(x)
    return zeros.long()
|
23 |
+
|
24 |
+
def apply_mask(
    x: torch.Tensor,
    mask: torch.Tensor,
    mask_token: int
):
    """Replace masked positions of ``x`` with ``mask_token``.

    Args:
        x: (batch, n_codebooks, seq) token tensor.
        mask: binary long tensor of the same shape; 1 = masked.
        mask_token: token id written at masked positions.

    Returns:
        Tuple of (masked x, mask) — the mask is returned unchanged for convenience.
    """
    # fix: these two messages were missing the f-prefix, so the {…} placeholders
    # never interpolated
    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
    assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
    assert ~torch.any(mask > 1), "mask must be binary"
    assert ~torch.any(mask < 0), "mask must be binary"

    fill_x = torch.full_like(x, mask_token)
    x = x * (1 - mask) + fill_x * mask

    return x, mask
|
39 |
+
|
40 |
+
def random(
    x: torch.Tensor,
    r: torch.Tensor
):
    """Random mask: each position is masked i.i.d. with probability gamma(r)."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)

    # per-batch keep probability from the cosine schedule, broadcast over
    # codebooks and time
    keep_prob = _gamma(r)[:, None, None]
    drawn = torch.bernoulli(keep_prob * torch.ones_like(x))
    return drawn.round().long()
|
55 |
+
|
56 |
+
def linear_random(
    x: torch.Tensor,
    r: torch.Tensor,
):
    """Random mask with masking probability exactly ``r`` (no schedule applied)."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()

    probs = torch.ones_like(x).to(x.device).float()
    # expand to batch and codebook dims (shape already matches x, so this is a no-op)
    probs = probs.expand(x.shape[0], x.shape[1], -1)
    probs = probs * r

    drawn = torch.bernoulli(probs)
    return drawn.round().long()
|
73 |
+
|
74 |
+
def inpaint(x: torch.Tensor,
            n_prefix,
            n_suffix,
            ):
    """Inpainting mask: everything masked (1) except the first ``n_prefix`` and
    last ``n_suffix`` timesteps, which are kept (0).

    ``n_prefix``/``n_suffix`` may be ints (applied to every batch item) or
    per-item tensors of length batch.
    NOTE(review): the ``n_prefix > 0`` check runs before the tensor conversion;
    for a multi-element tensor argument this comparison is ambiguous — confirm
    callers only pass ints or 1-element tensors here.
    """
    assert n_prefix is not None
    assert n_suffix is not None

    mask = full_mask(x)

    # if we have a prefix or suffix, set their mask prob to 0
    if n_prefix > 0:
        if not isinstance(n_prefix, torch.Tensor):
            n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
        for i, n in enumerate(n_prefix):
            if n > 0:
                mask[i, :, :n] = 0.0
    if n_suffix > 0:
        if not isinstance(n_suffix, torch.Tensor):
            n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
        for i, n in enumerate(n_suffix):
            if n > 0:
                mask[i, :, -n:] = 0.0


    return mask
|
99 |
+
|
100 |
+
def periodic_mask(x: torch.Tensor,
                  period: int, width: int = 1,
                  random_roll=False,
                  ):
    """Mask everything except a periodic grid of kept positions.

    Every ``period``-th timestep (and ``width`` neighbors around it, across all
    codebooks) is set to 0 (kept); the rest stays 1 (masked). ``period`` may be
    an int or a per-batch tensor; ``random_roll`` shifts the whole grid by a
    random offset.
    """
    mask = full_mask(x)
    if period == 0:
        # no periodic structure requested: everything stays masked
        return mask

    if not isinstance(period, torch.Tensor):
        period = scalar_to_batch_tensor(period, x.shape[0])
    for i, factor in enumerate(period):
        if factor == 0:
            continue
        for j in range(mask.shape[-1]):
            if j % factor == 0:
                # figure out how wide the mask should be
                j_start = max(0, j - width // 2 )
                j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1
                # flip a coin for each position in the mask
                # NOTE(review): bernoulli of all-ones is always 1, so these two
                # asserts always hold and j_fill is always zeros — the "coin
                # flip" is effectively disabled.
                j_mask = torch.bernoulli(torch.ones(j_end - j_start))
                assert torch.all(j_mask == 1)
                j_fill = torch.ones_like(j_mask) * (1 - j_mask)
                assert torch.all(j_fill == 0)
                # fill
                mask[i, :, j_start:j_end] = j_fill
    if random_roll:
        # add a random offset to the mask
        # NOTE(review): uses period[0] for the whole batch — per-item periods
        # all share one roll amount.
        offset = torch.randint(0, period[0], (1,))
        mask = torch.roll(mask, offset.item(), dims=-1)

    return mask
|
131 |
+
|
132 |
+
def codebook_unmask(
    mask: torch.Tensor,
    n_conditioning_codebooks: int
):
    """Unmask (set to 0) the first ``n_conditioning_codebooks`` codebook levels.

    Args:
        mask: (batch, n_codebooks, seq) binary mask.
        n_conditioning_codebooks: number of leading codebooks to unmask;
            ``None`` leaves the mask untouched.

    Returns:
        A clone of ``mask`` with the conditioning codebooks set to 0 (or the
        original mask object when ``n_conditioning_codebooks`` is None).
    """
    # fix: identity comparison with None (was `== None`)
    if n_conditioning_codebooks is None:
        return mask
    # if we have any conditioning codebooks, set their mask to 0
    mask = mask.clone()
    mask[:, :n_conditioning_codebooks, :] = 0
    return mask
|
142 |
+
|
143 |
+
def codebook_mask(mask: torch.Tensor, start: int):
    """Force-mask (set to 1) every codebook level from ``start`` onward."""
    out = mask.clone()
    out[:, start:, :] = 1
    return out
|
147 |
+
|
148 |
+
def mask_and(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    """Intersection of two binary masks (elementwise AND via minimum)."""
    assert mask1.shape == mask2.shape, "masks must be same shape"
    intersection = torch.min(mask1, mask2)
    return intersection
|
154 |
+
|
155 |
+
def dropout(
    mask: torch.Tensor,
    p: float,
):
    """Randomly flip unmasked (0) positions of a binary mask to masked (1)
    with probability ``p``; already-masked positions stay masked."""
    assert 0 <= p <= 1, "p must be between 0 and 1"
    assert mask.max() <= 1, "mask must be binary"
    assert mask.min() >= 0, "mask must be binary"
    # work on the kept (0) positions: keep each with probability (1 - p)
    kept = (~mask.bool()).float()
    survivors = torch.bernoulli(kept * (1 - p))
    # positions that did not survive become masked again
    result = ~survivors.round().bool()
    return result.long()
|
166 |
+
|
167 |
+
def mask_or(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    """Union of two binary masks (elementwise OR via saturating add)."""
    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
    assert mask1.max() <= 1, "mask1 must be binary"
    assert mask2.max() <= 1, "mask2 must be binary"
    assert mask1.min() >= 0, "mask1 must be binary"
    assert mask2.min() >= 0, "mask2 must be binary"
    union = (mask1 + mask2).clamp(0, 1)
    return union
|
177 |
+
|
178 |
+
def time_stretch_mask(
    x: torch.Tensor,
    stretch_factor: int,
):
    """Mask for time-stretching: repeat each token ``stretch_factor`` times,
    trim back to the original length, and keep one token per stretched group
    via a periodic mask so the model regenerates the in-between positions."""
    assert stretch_factor >= 1, "stretch factor must be >= 1"
    c_seq_len = x.shape[-1]
    x = x.repeat_interleave(stretch_factor, dim=-1)

    # trim cz to the original length
    x = x[:, :, :c_seq_len]

    mask = periodic_mask(x, stretch_factor, width=1)
    return mask
|
191 |
+
|
192 |
+
def onset_mask(
    sig: AudioSignal,
    z: torch.Tensor,
    interface,
    width: int = 1
):
    """Mask everything except detected note onsets.

    Runs madmom's RNN onset detector on ``sig`` (via a temporary wav file),
    converts the onset times to token frames, and zeroes (keeps) a window of
    ``width`` frames around each onset. Returns an empty mask if no onsets
    are found.
    """
    # heavy deps imported lazily — only required when this mask is used
    import librosa
    import madmom
    from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
    import tempfile
    import numpy as np

    with tempfile.NamedTemporaryFile(suffix='.wav') as f:
        # madmom reads from disk, so round-trip the signal through a temp wav
        sig = sig.clone()
        sig.write(f.name)

        proc = RNNOnsetProcessor(online=False)
        # fps chosen so picked onsets line up with codec token frames
        onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
                                              fps=sig.sample_rate/interface.codec.hop_length)

        act = proc(f.name)
        onset_times = onsetproc(act)

        # convert to indices for z array
        onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)

        if onset_indices.shape[0] == 0:
            mask = empty_mask(z)
            print(f"no onsets found, returning empty mask")
        else:
            torch.set_printoptions(threshold=1000)
            print("onset indices: ", onset_indices)
            print("onset times: ", onset_times)

            # create a mask, set onset
            mask = torch.ones_like(z)
            n_timesteps = z.shape[-1]

            for onset_index in onset_indices:
                # clamp into range so the window slice is always valid
                onset_index = min(onset_index, n_timesteps - 1)
                onset_index = max(onset_index, 0)
                mask[:, :, onset_index - width:onset_index + width] = 0.0

            print(mask)

    return mask
|
238 |
+
|
239 |
+
|
240 |
+
|
241 |
+
if __name__ == "__main__":
    # no demo / smoke test for the mask utilities yet
    pass
|
vampnet/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import audiotools

# Register serialization scopes for audiotools' BaseModel: vampnet.modules
# sources are treated as internal, while the listed third-party packages are
# external. (NOTE(review): presumably controls how audiotools packages model
# source code with checkpoints — confirm against audiotools.ml docs.)
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]

from .transformer import VampNet
|
vampnet/modules/activations.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
|
9 |
+
class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo
    (identical to OpenAI GPT). Also see the Gaussian Error Linear Units
    paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, x):
        # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))
|
27 |
+
|
28 |
+
class GatedGELU(nn.Module):
    """Gated GELU (GEGLU): splits the input in two halves along ``dim`` and
    gates the first half with the GELU of the second."""

    def __init__(self):
        super().__init__()
        self.gelu = NewGELU()

    def forward(self, x, dim: int = -1):
        value, gate = x.chunk(2, dim=dim)
        return value * self.gelu(gate)
|
36 |
+
|
37 |
+
class Snake1d(nn.Module):
    """Snake activation, x + sin^2(alpha*x)/alpha, with a learnable
    per-channel alpha (initialized to 1)."""

    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(channels))

    def forward(self, x):
        # 1e-9 guards against division by a zero alpha
        sin_sq = torch.sin(self.alpha * x).pow(2)
        return x + sin_sq * (self.alpha + 1e-9).reciprocal()
|
44 |
+
|
45 |
+
def get_activation(name: str = "relu"):
    """Look up an activation class by name.

    Known names: "relu", "gelu", "geglu", "snake".
    Raises ValueError for anything else.
    """
    if name == "relu":
        return nn.ReLU
    if name == "gelu":
        return NewGELU
    if name == "geglu":
        return GatedGELU
    if name == "snake":
        return Snake1d
    raise ValueError(f"Unrecognized activation {name}")
|
vampnet/modules/layers.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import Optional
|
3 |
+
from typing import Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    """Snake activation over (B, C, ...) input with per-channel alpha."""
    orig_shape = x.shape
    # flatten trailing dims so alpha (1, C, 1) broadcasts cleanly
    flat = x.reshape(orig_shape[0], orig_shape[1], -1)
    flat = flat + (alpha + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
    return flat.reshape(orig_shape)
|
19 |
+
|
20 |
+
|
21 |
+
class Snake1d(nn.Module):
    """Module wrapper around the scripted ``snake`` activation with a
    learnable per-channel alpha."""

    def __init__(self, channels):
        super().__init__()
        # one alpha per channel, shaped to broadcast over batch and time
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
|
28 |
+
|
29 |
+
|
30 |
+
def num_params(model):
    """Count the trainable (requires_grad) parameters of a module."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
|
32 |
+
|
33 |
+
|
34 |
+
def recurse_children(module, fn):
    """Depth-first generator over all descendants of ``module``, yielding
    ``fn(child)`` for each one (recursing through ModuleList/ModuleDict).

    Fix: the original ``yield recurse_children(...)`` yielded nested generator
    OBJECTS (always truthy), so consumers like ``any(...)`` saw spurious True
    values; ``yield from`` yields the actual fn results instead.
    """
    for child in module.children():
        if isinstance(child, nn.ModuleList):
            for c in child:
                yield from recurse_children(c, fn)
        if isinstance(child, nn.ModuleDict):
            for c in child.values():
                yield from recurse_children(c, fn)

        yield from recurse_children(child, fn)
        yield fn(child)
|
45 |
+
|
46 |
+
|
47 |
+
def WNConv1d(*args, **kwargs):
    """Factory for an ``nn.Conv1d`` wrapped with weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
49 |
+
|
50 |
+
|
51 |
+
def WNConvTranspose1d(*args, **kwargs):
    """Factory for an ``nn.ConvTranspose1d`` wrapped with weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
|
53 |
+
|
54 |
+
|
55 |
+
class SequentialWithFiLM(nn.Module):
    """
    handy wrapper for nn.Sequential that allows FiLM layers to be
    inserted in between other layers.
    """

    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    @staticmethod
    def has_film(module):
        # True if any descendant of ``module`` is a FiLM layer.
        # NOTE(review): recurse_children yields nested generator objects
        # (which are truthy), so this may over-report for modules with
        # container children — verify against recurse_children.
        mod_has_film = any(
            [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))]
        )
        return mod_has_film

    def forward(self, x, cond):
        # route the conditioning tensor only into layers that contain FiLM
        for layer in self.layers:
            if self.has_film(layer):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x
|
79 |
+
|
80 |
+
|
81 |
+
class FiLM(nn.Module):
    """Feature-wise Linear Modulation: scales and shifts channel features
    from a conditioning vector. Acts as the identity when input_dim == 0."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        if input_dim > 0:
            # projections from the conditioning vector to per-channel shift/scale
            self.beta = nn.Linear(input_dim, output_dim)
            self.gamma = nn.Linear(input_dim, output_dim)

    def forward(self, x, r):
        # identity when there is nothing to condition on
        if self.input_dim == 0:
            return x
        shift = self.beta(r).view(x.size(0), self.output_dim, 1)
        scale = self.gamma(r).view(x.size(0), self.output_dim, 1)
        return x * (scale + 1) + shift
|
103 |
+
|
104 |
+
|
105 |
+
class CodebookEmbedding(nn.Module):
    """Embed multi-codebook discrete token sequences into a continuous space.

    Per-codebook latents are looked up from a codec's codebooks (plus learned
    rows for any special tokens such as ``<MASK>``), stacked along the channel
    axis, and projected to ``emb_dim`` with a 1x1 convolution.

    Parameters
    ----------
    vocab_size : int
        Number of entries in each codec codebook; special tokens are assigned
        ids starting at ``vocab_size``.
    latent_dim : int
        Dimensionality of each codebook's latent vectors.
    n_codebooks : int
        Number of codebooks in the token sequence.
    emb_dim : int
        Output embedding dimensionality of the projection.
    special_tokens : Optional[Tuple[str]]
        Names of special tokens (e.g. ``("MASK",)``) to allocate learned
        latents and ids for.
    """

    def __init__(
        self,
        vocab_size: int,
        latent_dim: int,
        n_codebooks: int,
        emb_dim: int,
        special_tokens: Optional[Tuple[str]] = None,
    ):
        super().__init__()
        self.n_codebooks = n_codebooks
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim
        self.vocab_size = vocab_size

        if special_tokens is not None:
            # FIX: previously this dict was rebuilt once per special token
            # inside a redundant ``for tkn in special_tokens`` loop; build it
            # a single time (identical result, fewer wasted RNG draws).
            self.special = nn.ParameterDict(
                {
                    tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
                    for tkn in special_tokens
                }
            )
            # special tokens occupy ids immediately after the codec vocabulary
            self.special_idxs = {
                tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
            }

        # fuse the stacked per-codebook latents into one embedding per timestep
        self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)

    def from_codes(self, codes: torch.Tensor, codec):
        """Look up continuous latents for a batch of discrete codes.

        Unlike its counterpart in the original VQ-VAE, this appends rows for
        any special tokens required by the language model (like ``<MASK>``)
        to each codebook's lookup table, so ids >= ``vocab_size`` resolve to
        the learned special-token latents.

        Parameters
        ----------
        codes : Tensor[B x n_codebooks x T]
        codec
            Model exposing ``quantizer.quantizers[i].codebook.weight``.

        Returns
        -------
        Tensor[B x (n_codebooks * latent_dim) x T]
        """
        n_codebooks = codes.shape[1]
        latent = []
        for i in range(n_codebooks):
            c = codes[:, i, :]

            lookup_table = codec.quantizer.quantizers[i].codebook.weight
            if hasattr(self, "special"):
                # append special-token latents for this codebook after the
                # codec vocabulary so ids from `special_idxs` resolve correctly
                special_lookup = torch.cat(
                    [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
                )
                lookup_table = torch.cat([lookup_table, special_lookup], dim=0)

            l = F.embedding(c, lookup_table).transpose(1, 2)
            latent.append(l)

        return torch.cat(latent, dim=1)

    def forward(self, latents: torch.Tensor):
        """Project stacked codebook latents to the model embedding space.

        Parameters
        ----------
        latents : Tensor[B x (n_codebooks * latent_dim) x T]

        Returns
        -------
        Tensor[B x emb_dim x T]
        """
        return self.out_proj(latents)
|
164 |
+
|
vampnet/modules/transformer.py
ADDED
@@ -0,0 +1,953 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import logging
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange
|
10 |
+
import loralib as lora
|
11 |
+
import audiotools as at
|
12 |
+
|
13 |
+
from .activations import get_activation
|
14 |
+
from .layers import CodebookEmbedding
|
15 |
+
from .layers import FiLM
|
16 |
+
from .layers import SequentialWithFiLM
|
17 |
+
from .layers import WNConv1d
|
18 |
+
from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
|
19 |
+
from ..mask import _gamma
|
20 |
+
|
21 |
+
LORA_R = 8
|
22 |
+
|
23 |
+
# def log(t, eps=1e-20):
|
24 |
+
# return torch.log(t + eps)
|
25 |
+
|
26 |
+
|
27 |
+
def gumbel_noise_like(t):
    """Draw Gumbel(0, 1) noise with the same shape/device/dtype as ``t``.

    Uses the standard inverse-CDF trick ``-log(-log(u))`` with ``u`` drawn
    uniformly from (1e-20, 1) to avoid log(0).
    """
    u = torch.zeros_like(t).uniform_(1e-20, 1)
    return u.log().neg().log().neg()


def gumbel_sample(t, temperature=1.0, dim=-1):
    """Sample token indices from logits ``t`` via the Gumbel-max trick.

    Temperature is clamped below at 1e-10 so a temperature of 0 degrades to
    (near-)greedy argmax instead of dividing by zero.
    """
    scaled = t / max(temperature, 1e-10)
    return (scaled + gumbel_noise_like(t)).argmax(dim=dim)
|
34 |
+
|
35 |
+
|
36 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (T5 style).

    T5 uses a layer norm that only rescales and never shifts (Root Mean
    Square Layer Normalization, https://arxiv.org/abs/1910.07467): the
    variance is computed without subtracting the mean and there is no bias,
    only a learned per-feature gain.
    """

    def __init__(self, hidden_size: int, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.var_eps = eps

    def forward(self, x):
        """Normalize ``x`` by its RMS over the last dimension.

        Parameters
        ----------
        x : Tensor[B x T x D]

        Returns
        -------
        Tensor[B x T x D]
        """
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.var_eps)
        return self.weight * normed
|
58 |
+
|
59 |
+
|
60 |
+
class FeedForward(nn.Module):
    """Position-wise feed-forward block with a LoRA-adapted 4x expansion.

    Applies ``w_1 -> activation -> dropout -> w_2``.  With the default
    "geglu" activation the 4x hidden width is split in two (gate * value),
    so ``w_2`` consumes half the expanded width.
    """

    def __init__(
        self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
    ):
        super().__init__()
        # geglu halves the effective hidden width seen by w_2
        factor = 2 if activation == "geglu" else 1
        self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R)
        self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R)
        self.drop = nn.Dropout(dropout)
        self.act = get_activation(activation)()

    def forward(self, x):
        """Compute the position-wise feed-forward layer.

        Parameters
        ----------
        x : Tensor[B x T x D]

        Returns
        -------
        Tensor[B x T x D]
        """
        hidden = self.act(self.w_1(x))
        return self.w_2(self.drop(hidden))
|
85 |
+
|
86 |
+
|
87 |
+
class MultiHeadRelativeAttention(nn.Module):
    """Multi-head attention with optional T5-style relative position biases.

    The query, value, and output projections are LoRA-adapted
    (``lora.Linear`` with rank ``LORA_R``); the key projection is a plain
    ``nn.Linear``.  When ``has_relative_attention_bias`` is set, relative
    query/key offsets are bucketed (half exact, half log-spaced) and
    embedded into one scalar bias per head, added to the attention logits.
    """

    def __init__(
        self,
        n_head: int = 8,
        d_model: int = 512,
        dropout: float = 0.1,
        bidirectional: bool = True,
        has_relative_attention_bias: bool = True,
        attention_num_buckets: int = 32,
        attention_max_distance: int = 128,
    ):
        super().__init__()
        # per-head dimensionality; assumes d_model is divisible by n_head
        d_head = d_model // n_head
        self.n_head = n_head
        self.d_head = d_head
        self.bidirectional = bidirectional
        self.has_relative_attention_bias = has_relative_attention_bias
        self.attention_num_buckets = attention_num_buckets
        self.attention_max_distance = attention_max_distance

        # Create linear query, key, value projections
        # NOTE(review): only q/v and the output projection carry LoRA
        # adapters; the key projection is a plain Linear.
        self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
        self.w_ks = nn.Linear(d_model, d_model, bias=False)
        self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)

        # Create linear final output projection
        self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R)

        # Dropout for attention output weights
        self.dropout = nn.Dropout(dropout)

        # Create relative positional embeddings (if turned on)
        if has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)

    def _relative_position_bucket(self, relative_position):
        """Converts unbounded relative position into bounded set of buckets
        with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
        buckets
        Parameters
        ----------
        relative_position : Tensor[T_q x T_kv]
            Relative positions between queries and key_value items
        Returns
        -------
        Tensor[T_q x T_kv]
            Input relative positions converted into buckets
        """
        relative_buckets = 0
        num_buckets = self.attention_num_buckets
        max_distance = self.attention_max_distance

        # Convert relative position for (-inf, inf) to [0, inf]
        # Negative relative positions correspond to past
        # Positive relative positions correspond to future
        if self.bidirectional:
            # use half buckets for each side (past / future)
            num_buckets //= 2

            # Shift the position positions by `num_buckets` to wrap around
            # negative positions
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # If not bidirectional, ignore positive positions and wrap
            # negative positions to positive
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )

        # Allocate half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in
        # positions up to `max_distance`
        # NOTE(review): log(0 / max_exact) is -inf for position 0, but those
        # entries satisfy `is_small` and are discarded by the torch.where below.
        relative_postion_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)

        # Clip the max relative position to `num_buckets - 1`
        relative_postion_if_large = torch.min(
            relative_postion_if_large,
            torch.full_like(relative_postion_if_large, num_buckets - 1),
        )

        # Choose relative buckets based on small or large positions
        relative_buckets += torch.where(
            is_small, relative_position, relative_postion_if_large
        )

        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Computes a position bias scalar for each index in query_length x key_length
        Parameters
        ----------
        query_length : int
        key_length : int
        Returns
        -------
        Tensor[heads x 1 x T_q x T_kv]
            Position bias to be applied on attention logits
        """

        query_position = torch.arange(query_length, dtype=torch.long)[:, None]
        key_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = key_position - query_position

        # Convert relative position to buckets
        relative_position_bucket = self._relative_position_bucket(relative_position)
        relative_position_bucket = relative_position_bucket.to(
            self.relative_attention_bias.weight.device
        )

        # Index attention bias values
        values = self.relative_attention_bias(relative_position_bucket)
        values = rearrange(values, "q k h -> h 1 q k")

        return values

    def forward(self, q, k, v, mask=None, position_bias=None):
        """Computes attention over (keys, values) for every timestep in query
        Parameters
        ----------
        q : Tensor[B x T_q x d_model]
            Query vectors
        k : Tensor[B x T_kv x d_model]
            Key vectors to compute attention over
        v : Tensor[B x T_kv x d_model]
            Value vectors corresponding to the keys
        mask : Tensor[B x T_q x T_kv], optional
        position_bias: Tensor[head x 1 x T_q x T_kv]
        Returns
        -------
        Tensor[B x T_q x d_model]
            Outputs after attending (key, value) using queries
        """
        # Compute query, key, value projections
        q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
        k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
        v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)

        # Compute attention matrix (scaled dot product over head dim)
        attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])

        # Add relative position bias to attention scores
        # (computed once per stack and threaded back in by the caller)
        if position_bias is None:
            if self.has_relative_attention_bias:
                position_bias = self.compute_bias(q.size(-2), k.size(-2))
            else:
                position_bias = torch.zeros_like(attn)
        attn += position_bias

        # Apply mask to attention scores to prevent looking up invalid locations
        if mask is not None:
            attn = attn.masked_fill(mask[None] == 0, -1e9)

        # Normalize attention scores and add dropout
        attn = torch.softmax(attn, dim=3)
        attn = self.dropout(attn)

        # Compute attended outputs (product of attention matrix and values)
        output = torch.einsum("hblt,hbtv->hblv", [attn, v])
        output = rearrange(output, "head b l v -> b l (head v)")
        output = self.fc(output)

        # also return the bias so the caller can reuse it across layers
        return output, position_bias
|
257 |
+
|
258 |
+
|
259 |
+
class TransformerLayer(nn.Module):
    """One pre-norm transformer block.

    Self-attention, optional cross-attention (decoder mode), and a
    feed-forward layer; each sub-block is preceded by RMSNorm plus FiLM
    conditioning and wrapped in a residual connection with dropout.
    With ``flash_attn``, self-attention uses FlashMHA under bf16 autocast
    and no relative position bias is produced for that path.
    """

    def __init__(
        self,
        d_model: int = 512,
        d_cond: int = 64,
        n_heads: int = 8,
        bidirectional: bool = True,
        is_decoder: bool = False,
        has_relative_attention_bias: bool = False,
        flash_attn: bool = False,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Store args
        self.is_decoder = is_decoder

        # Create self-attention layer
        self.norm_1 = RMSNorm(d_model)
        self.film_1 = FiLM(d_cond, d_model)
        self.flash_attn = flash_attn

        if flash_attn:
            # deferred import so flash-attn is only required when enabled
            from flash_attn.flash_attention import FlashMHA
            self.self_attn = FlashMHA(
                embed_dim=d_model,
                num_heads=n_heads,
                attention_dropout=dropout,
                causal=False,
            )
        else:
            self.self_attn = MultiHeadRelativeAttention(
                n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
            )

        # (Optional) Create cross-attention layer
        if is_decoder:
            self.norm_2 = RMSNorm(d_model)
            self.film_2 = FiLM(d_cond, d_model)
            # cross attention carries no relative bias of its own
            self.cross_attn = MultiHeadRelativeAttention(
                n_heads,
                d_model,
                dropout,
                bidirectional=True,
                has_relative_attention_bias=False,
            )

        # Create last feed-forward layer
        self.norm_3 = RMSNorm(d_model)
        self.film_3 = FiLM(d_cond, d_model)
        self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)

        # Create dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        x_mask,
        cond,
        src=None,
        src_mask=None,
        position_bias=None,
        encoder_decoder_position_bias=None,
    ):
        """Computes one transformer layer consisting of self attention, (op) cross attention
        and feedforward layer
        Parameters
        ----------
        x : Tensor[B x T_q x D]
        x_mask : Tensor[B x T_q]
        cond : conditioning tensor for the FiLM layers (ignored when d_cond == 0)
        src : Tensor[B x T_kv x D], optional
        src_mask : Tensor[B x T_kv x D], optional
        position_bias : Tensor[heads x B x T_q x T_q], optional
            Relative position bias for self attention layer
        encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional
            Relative position bias for cross attention layer
        Returns
        -------
        Tensor[B x T_q x D]
        """
        # pre-norm + FiLM (FiLM operates channels-first, hence the permutes)
        y = self.norm_1(x)
        y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
        if self.flash_attn:
            # FlashMHA path: bf16 autocast, no mask / position bias used here
            with torch.autocast(y.device.type, dtype=torch.bfloat16):
                y = self.self_attn(y)[0]
        else:
            y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
        x = x + self.dropout(y)

        if self.is_decoder:
            # cross-attend to encoder output `src`
            y = self.norm_2(x)
            y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
            y, encoder_decoder_position_bias = self.cross_attn(
                y, src, src, src_mask, encoder_decoder_position_bias
            )
            x = x + self.dropout(y)

        # feed-forward sub-block with the same pre-norm + FiLM treatment
        y = self.norm_3(x)
        y = self.film_3(
            y.permute(
                0,
                2,
                1,
            ),
            cond,
        ).permute(0, 2, 1)
        y = self.feed_forward(y)
        x = x + self.dropout(y)

        # biases are returned so the stack can share them across layers
        return x, position_bias, encoder_decoder_position_bias
|
369 |
+
|
370 |
+
|
371 |
+
class TransformerStack(nn.Module):
    """A stack of ``TransformerLayer`` blocks with shared relative biases.

    Only the first layer owns a relative-attention-bias embedding; the bias
    it computes is threaded through the remaining layers (T5 convention).
    Optionally applies a final RMSNorm and can return per-layer activations.
    """

    def __init__(
        self,
        d_model: int = 512,
        d_cond: int = 64,
        n_heads: int = 8,
        n_layers: int = 8,
        last_layer: bool = True,
        bidirectional: bool = True,
        flash_attn: bool = False,
        is_decoder: bool = False,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Store args
        self.bidirectional = bidirectional
        self.is_decoder = is_decoder

        # Create transformer layers
        # In T5, relative attention bias is shared by all layers in the stack
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    d_model,
                    d_cond,
                    n_heads,
                    bidirectional,
                    is_decoder,
                    has_relative_attention_bias=True if (i == 0) else False,
                    flash_attn=flash_attn,
                    dropout=dropout,
                )
                for i in range(n_layers)
            ]
        )

        # Perform last normalization
        self.norm = RMSNorm(d_model) if last_layer else None

    def subsequent_mask(self, size):
        # lower-triangular causal mask (1 x size x size) for unidirectional mode
        return torch.ones(1, size, size).tril().bool()

    def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
                return_activations: bool = False
                ):
        """Computes a full transformer stack
        Parameters
        ----------
        x : Tensor[B x T_q x D]
        x_mask : Tensor[B x T_q]
        src : Tensor[B x T_kv x D], optional
        src_mask : Tensor[B x T_kv], optional
        return_activations : bool
            when True, also return detached per-layer outputs stacked along dim 0
        Returns
        -------
        Tensor[B x T_q x D]
        """

        # Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking
        if self.is_decoder:
            src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)

        # Convert `x_mask` to (B x T_q x T_q) shape for self attention masking
        x_mask = x_mask.unsqueeze(-2)
        if not self.bidirectional:
            x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)

        # Initialize position biases (layer 0 computes them; later layers reuse)
        position_bias = None
        encoder_decoder_position_bias = None

        # Compute transformer layers
        if return_activations:
            activations = []
        for layer in self.layers:
            x, position_bias, encoder_decoder_position_bias = layer(
                x=x,
                x_mask=x_mask,
                cond=cond,
                src=src,
                src_mask=src_mask,
                position_bias=position_bias,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
            )
            if return_activations:
                activations.append(x.detach())


        out = self.norm(x) if self.norm is not None else x
        if return_activations:
            return out, torch.stack(activations)
        else:
            return out
|
463 |
+
|
464 |
+
|
465 |
+
class VampNet(at.ml.BaseModel):
    """Masked token-generation model over codec token sequences.

    Embeds multi-codebook token sequences (with a learned <MASK> token) via
    ``CodebookEmbedding``, runs a bidirectional ``TransformerStack``, and
    classifies a distribution over ``vocab_size`` tokens for each
    non-conditioning codebook at each timestep.  ``generate`` performs
    iterative parallel decoding: sample all masked positions, keep the most
    confident samples, and re-mask the rest on the ``_gamma`` schedule.
    """

    def __init__(
        self,
        n_heads: int = 20,
        n_layers: int = 16,
        r_cond_dim: int = 0,  # must be 0; FiLM conditioning is not supported
        n_codebooks: int = 9,
        n_conditioning_codebooks: int = 0,
        latent_dim: int = 8,
        embedding_dim: int = 1280,
        vocab_size: int = 1024,
        flash_attn: bool = True,
        noise_mode: str = "mask",  # must be "mask"; other modes are deprecated
        dropout: float = 0.1
    ):
        super().__init__()
        assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.r_cond_dim = r_cond_dim
        self.n_codebooks = n_codebooks
        self.n_conditioning_codebooks = n_conditioning_codebooks
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.flash_attn = flash_attn
        self.noise_mode = noise_mode

        assert self.noise_mode == "mask", "deprecated"

        # token embedding with a learned <MASK> special token
        self.embedding = CodebookEmbedding(
            latent_dim=latent_dim,
            n_codebooks=n_codebooks,
            vocab_size=vocab_size,
            emb_dim=embedding_dim,
            special_tokens=["MASK"],
        )
        # id of the <MASK> token (== vocab_size)
        self.mask_token = self.embedding.special_idxs["MASK"]

        self.transformer = TransformerStack(
            d_model=embedding_dim,
            d_cond=r_cond_dim,
            n_heads=n_heads,
            n_layers=n_layers,
            last_layer=True,
            bidirectional=True,
            flash_attn=flash_attn,
            is_decoder=False,
            dropout=dropout,
        )

        # Add final conv layer
        # one vocab_size-way classifier per predicted (non-conditioning) codebook
        self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
        self.classifier = SequentialWithFiLM(
            WNConv1d(
                embedding_dim,
                vocab_size * self.n_predict_codebooks,
                kernel_size=1,
                padding="same",
                # groups=self.n_predict_codebooks,
            ),
        )

    def forward(self, x, return_activations: bool = False):
        """Map latents (B x (n_codebooks * latent_dim) x T) to logits.

        Returns logits shaped (B x vocab_size x (T * n_predict_codebooks)):
        the final rearrange folds the codebook axis into the time axis.
        When ``return_activations``, also returns the stacked per-layer
        transformer activations.
        """
        x = self.embedding(x)
        # attention mask is all ones — no padding handling here
        x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)

        x = rearrange(x, "b d n -> b n d")
        out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
        if return_activations:
            out, activations = out

        out = rearrange(out, "b n d -> b d n")

        out = self.classifier(out, None)  # no cond here!

        # fold codebook dim into time: matches the codebook_flatten layout
        out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)

        if return_activations:
            return out, activations
        else:
            return out

    def r_embed(self, r, max_positions=10000):
        """Sinusoidal embedding of the schedule scalar ``r``.

        Since __init__ asserts ``r_cond_dim == 0``, this currently always
        returns ``r`` unchanged; the sinusoidal branch is dead code kept for
        checkpoints that enable r-conditioning.
        """
        if self.r_cond_dim > 0:
            dtype = r.dtype

            r = _gamma(r) * max_positions
            half_dim = self.r_cond_dim // 2

            emb = math.log(max_positions) / (half_dim - 1)
            emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()

            emb = r[:, None] * emb[None, :]
            emb = torch.cat([emb.sin(), emb.cos()], dim=1)

            if self.r_cond_dim % 2 == 1:  # zero pad
                emb = nn.functional.pad(emb, (0, 1), mode="constant")

            return emb.to(dtype)
        else:
            return r

    @torch.no_grad()
    def to_signal(self, z, codec):
        """
        convert a sequence of latents to a signal.

        Decodes codes ``z`` (B x n_codebooks x T) through the codec, then
        zeroes out any audio frame whose timestep still contains the <MASK>
        token so unfilled regions render as silence.
        """
        assert z.ndim == 3

        signal = at.AudioSignal(
            codec.decode(
                codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
            )["audio"],
            codec.sample_rate,
        )

        # find where the mask token is and replace it with silence in the audio
        for tstep in range(z.shape[-1]):
            if torch.any(z[:, :, tstep] == self.mask_token):
                sample_idx_0 = tstep * codec.hop_length
                sample_idx_1 = sample_idx_0 + codec.hop_length
                signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0

        return signal


    @torch.no_grad()
    def generate(
        self,
        codec,
        time_steps: int = 300,
        sampling_steps: int = 36,
        start_tokens: Optional[torch.Tensor] = None,
        sampling_temperature: float = 1.0,
        mask: Optional[torch.Tensor] = None,
        mask_temperature: float = 10.5,
        typical_filtering=False,
        typical_mass=0.2,
        typical_min_tokens=1,
        top_p=None,
        return_signal=True,
        seed: int = None,
        sample_cutoff: float = 1.0,
    ):
        """Iteratively decode a token sequence (and optionally an audio signal).

        Starts from ``start_tokens`` (or an all-<MASK> sequence of length
        ``time_steps``), then for ``sampling_steps`` rounds: sample every
        masked position, keep the most confident samples, and re-mask the
        least confident according to the ``_gamma`` schedule (confidence is
        sharpened by ``mask_temperature * (1 - r)``).  Returns an
        ``at.AudioSignal`` when ``return_signal`` is True, else the sampled
        codes (B x n_codebooks x T).
        """
        if seed is not None:
            at.util.seed(seed)
        logging.debug(f"beginning generation with {sampling_steps} steps")



        #####################
        # resolve initial z #
        #####################
        z = start_tokens

        if z is None:
            z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
                self.device
            )

        logging.debug(f"created z with shape {z.shape}")


        #################
        # resolve mask #
        #################

        if mask is None:
            # mask everything except the conditioning codebooks
            mask = torch.ones_like(z).to(self.device).int()
            mask[:, : self.n_conditioning_codebooks, :] = 0.0
        if mask.ndim == 2:
            # broadcast a (B x T) mask across all codebooks
            mask = mask[:, None, :].repeat(1, z.shape[1], 1)
        # init_mask = mask.clone()

        logging.debug(f"created mask with shape {mask.shape}")


        ###########
        # set up #
        ##########
        # apply the mask to z
        z_masked = z.masked_fill(mask.bool(), self.mask_token)
        # logging.debug(f"z_masked: {z_masked}")

        # how many mask tokens to begin with?
        num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
        logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")

        # how many codebooks are we inferring vs conditioning on?
        n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
        logging.debug(f"n infer codebooks: {n_infer_codebooks}")

        #################
        # begin sampling #
        #################

        for i in range(sampling_steps):
            logging.debug(f"step {i} of {sampling_steps}")

            # our current schedule step
            r = scalar_to_batch_tensor(
                (i + 1) / sampling_steps,
                z.shape[0]
            ).to(z.device)
            logging.debug(f"r: {r}")

            # get latents
            latents = self.embedding.from_codes(z_masked, codec)
            logging.debug(f"computed latents with shape: {latents.shape}")


            # infer from latents
            # NOTE: this collapses the codebook dimension into the sequence dimension
            logits = self.forward(latents)  # b, prob, seq
            logits = logits.permute(0, 2, 1)  # b, seq, prob
            b = logits.shape[0]  # NOTE(review): unused

            logging.debug(f"permuted logits with shape: {logits.shape}")

            # after sample_cutoff of the schedule, switch to greedy decoding
            sampled_z, selected_probs = sample_from_logits(
                logits, sample=(
                    (i / sampling_steps) <= sample_cutoff
                ),
                temperature=sampling_temperature,
                typical_filtering=typical_filtering, typical_mass=typical_mass,
                typical_min_tokens=typical_min_tokens,
                top_k=None, top_p=top_p, return_probs=True,
            )

            logging.debug(f"sampled z with shape: {sampled_z.shape}")

            # flatten z_masked and mask, so we can deal with the sampling logic
            # we'll unflatten them at the end of the loop for the next forward pass
            # remove conditioning codebooks, we'll add them back at the end
            z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])

            mask = (z_masked == self.mask_token).int()

            # update the mask, remove conditioning codebooks from the mask
            logging.debug(f"updated mask with shape: {mask.shape}")
            # add z back into sampled z where the mask was false
            sampled_z = torch.where(
                mask.bool(), sampled_z, z_masked
            )
            logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")

            # ignore any tokens that weren't masked
            # (inf confidence keeps already-resolved tokens from being re-masked)
            selected_probs = torch.where(
                mask.bool(), selected_probs, torch.inf
            )

            # get the num tokens to mask, according to the schedule
            num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
            logging.debug(f"num to mask: {num_to_mask}")

            # keep at least one token masked until the final step, and never
            # more than (currently masked - 1)
            if i != (sampling_steps - 1):
                num_to_mask = torch.maximum(
                    torch.tensor(1),
                    torch.minimum(
                        mask.sum(dim=-1, keepdim=True) - 1,
                        num_to_mask
                    )
                )


            # get our new mask
            mask = mask_by_random_topk(
                num_to_mask, selected_probs, mask_temperature * (1-r)
            )

            # update the mask
            z_masked = torch.where(
                mask.bool(), self.mask_token, sampled_z
            )
            logging.debug(f"updated z_masked with shape: {z_masked.shape}")

            z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
            mask = codebook_unflatten(mask, n_infer_codebooks)
            logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")

            # add conditioning codebooks back to z_masked
            z_masked = torch.cat(
                (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
            )
            logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")


        # add conditioning codebooks back to sampled_z
        sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
        sampled_z = torch.cat(
            (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
        )

        logging.debug(f"finished sampling")

        if return_signal:
            return self.to_signal(sampled_z, codec)
        else:
            return sampled_z
|
765 |
+
|
766 |
+
def sample_from_logits(
    logits,
    sample: bool = True,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = None,
    typical_filtering: bool = False,
    typical_mass: float = 0.2,
    typical_min_tokens: int = 1,
    return_probs: bool = False
):
    """Convenience function to sample from a categorical distribution with input as
    unnormalized logits.

    Parameters
    ----------
    logits : Tensor[..., vocab_size]
        Unnormalized logits over the vocabulary (last dimension).
    sample : bool, optional
        Whether to perform multinomial sampling, by default True.
        When False, the argmax token is returned instead.
    temperature : float, optional
        Scaling parameter when multinomial samping, by default 1.0.
        If <= 0, the un-scaled softmax of the logits is used.
    top_k : int, optional
        Restricts sampling to only `top_k` values acc. to probability,
        by default None
    top_p : float, optional
        Restricts sampling to only those values with cumulative
        probability <= `top_p`, by default None
    typical_filtering : bool, optional
        Apply locally-typical filtering before top-k/top-p, by default False.
    typical_mass : float, optional
        Probability mass kept by typical filtering, by default 0.2.
    typical_min_tokens : int, optional
        Minimum tokens kept per position by typical filtering, by default 1.
    return_probs : bool, optional
        Also return the probability of each chosen token, by default False.

    Returns
    -------
    Tensor[...]
        Sampled tokens (and their probabilities when `return_probs` is True).
    """
    shp = logits.shape[:-1]

    if typical_filtering:
        # BUG FIX: typical_filter builds new tensors (masked_fill) and returns
        # the filtered logits; it does NOT modify its argument in place. The
        # previous code discarded the return value, making typical filtering
        # a silent no-op. Capture the result.
        logits = typical_filter(
            logits,
            typical_mass=typical_mass,
            typical_min_tokens=typical_min_tokens,
        )

    # Apply top_k sampling
    if top_k is not None:
        v, _ = logits.topk(top_k)
        # everything below the k-th largest logit is excluded
        logits[logits < v[..., [-1]]] = -float("inf")

    # Apply top_p (nucleus) sampling
    if top_p is not None and top_p < 1.0:
        v, sorted_indices = logits.sort(descending=True)
        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Right shift indices_to_remove to keep 1st token over threshold
        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
            ..., :-1
        ]

        # Compute indices_to_remove in unsorted array
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )

        logits[indices_to_remove] = -float("inf")

    # Perform multinomial sampling after normalizing logits
    probs = (
        F.softmax(logits / temperature, dim=-1)
        if temperature > 0
        else logits.softmax(dim=-1)
    )
    token = (
        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
        if sample
        else logits.argmax(-1)
    )

    if return_probs:
        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
        return token, token_probs
    else:
        return token
852 |
+
def mask_by_random_topk(
    num_to_mask: int,
    probs: torch.Tensor,
    temperature: float = 1.0,
):
    """
    Choose which positions to re-mask: perturb log-probabilities with Gumbel
    noise, then mask everything strictly below the `num_to_mask`-th smallest
    confidence in each row.

    Args:
        num_to_mask (int): number of tokens to mask
        probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
        temperature (float, optional): temperature. Defaults to 1.0.
    """
    logging.debug(f"masking by random topk")
    logging.debug(f"num to mask: {num_to_mask}")
    logging.debug(f"probs shape: {probs.shape}")
    logging.debug(f"temperature: {temperature}")
    logging.debug("")

    # Gumbel-perturbed confidence; higher temperature -> more randomness.
    scores = torch.log(probs) + temperature * gumbel_noise_like(probs)
    logging.debug(f"confidence shape: {scores.shape}")

    ordered_scores, ordered_idx = scores.sort(dim=-1)
    logging.debug(f"sorted confidence shape: {ordered_scores.shape}")
    logging.debug(f"sorted idx shape: {ordered_idx.shape}")

    # Threshold = the num_to_mask-th smallest confidence per row.
    threshold = torch.take_along_dim(ordered_scores, num_to_mask, axis=-1)
    logging.debug(f"cut off shape: {threshold.shape}")

    # Mask every position whose confidence falls below the threshold.
    low_confidence = scores < threshold
    logging.debug(f"mask shape: {low_confidence.shape}")

    return low_confidence
889 |
+
def typical_filter(
    logits,
    typical_mass: float = 0.95,
    typical_min_tokens: int = 1,
):
    """Filter logits via locally typical sampling.

    Tokens whose surprisal (negative log-probability) is far from the
    distribution's entropy are set to -inf, keeping roughly `typical_mass`
    of the probability mass and at least `typical_min_tokens` tokens per
    position.

    Parameters
    ----------
    logits : Tensor[batch, seq, vocab]
        Unnormalized logits.
    typical_mass : float, optional
        Cumulative probability mass to keep, by default 0.95.
    typical_min_tokens : int, optional
        Minimum number of tokens kept per position, by default 1.

    Returns
    -------
    Tensor[batch, seq, vocab]
        Filtered logits (input is not modified in place).
    """
    nb, nt, nv = logits.shape
    # Flatten (batch, seq, vocab) -> (batch * seq, vocab); a plain reshape
    # replaces the previous einops.rearrange for this trivial operation.
    x_flat = logits.reshape(nb * nt, nv)
    x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
    x_flat_norm_p = torch.exp(x_flat_norm)
    # Per-position entropy of the predicted distribution.
    entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)

    # Distance of each token's surprisal from the entropy ("typicality").
    c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
    c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
    x_flat_cumsum = (
        x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
    )

    # Keep the most typical tokens until `typical_mass` is covered.
    last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
    sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
        1, last_ind.view(-1, 1)
    )
    if typical_min_tokens > 1:
        # Never remove the `typical_min_tokens` most typical tokens.
        sorted_indices_to_remove[..., :typical_min_tokens] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, x_flat_indices, sorted_indices_to_remove
    )
    x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
    # Restore (batch, seq, vocab).
    return x_flat.reshape(nb, nt, nv)
|
919 |
+
if __name__ == "__main__":
    # BUG FIX: argbind is used below (bind / parse_args / scope) but its
    # import had been commented out, so running this module as a script
    # raised NameError. Import it here, local to the script entry point.
    import argbind

    from .layers import num_params

    VampNet = argbind.bind(VampNet)

    @argbind.bind(without_prefix=True)
    def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
        """Smoke-test: build a VampNet, run one forward pass, report size and shape."""
        # 32000 Hz audio with a hop of 512 samples -> tokens per second.
        seq_len = int(32000 / 512 * seq_len_s)

        model = VampNet().to(device)

        z = torch.randint(
            0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
        ).to(device)

        r = torch.zeros(batch_size).to(device)

        z_mask_latent = torch.rand(
            batch_size, model.latent_dim * model.n_codebooks, seq_len
        ).to(device)
        z_hat = model(z_mask_latent)

        pred = z_hat.argmax(dim=1)
        pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)

        print(f"model has {num_params(model)/1e6:<.3f}M parameters")
        print(f"prediction has shape {pred.shape}")
        # Intentional: drop into the debugger for interactive inspection.
        breakpoint()

    args = argbind.parse_args()
    with argbind.scope(args):
        try_model()
+
|
953 |
+
|
vampnet/scheduler.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
class NoamScheduler:
    """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
    Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html

    The learning rate rises linearly for `warmup` steps, then decays with the
    inverse square root of the step count, scaled by `factor` and `d_model`.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        d_model: int = 512,
        factor: float = 1.0,
        warmup: int = 4000,
    ):
        # Schedule hyperparameters.
        self.warmup = warmup
        self.factor = factor
        self.d_model = d_model

        # Current learning rate (None until the first step) and step counter.
        self.lr = None
        self.steps = 0

        # The optimizer whose param groups this schedule drives.
        self.optimizer = optimizer

    def state_dict(self):
        # Everything except the optimizer itself (it has its own state_dict).
        state = dict(self.__dict__)
        state.pop("optimizer")
        return state

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def step(self):
        self.steps += 1
        warmup_term = self.steps * self.warmup ** (-1.5)
        decay_term = self.steps ** (-0.5)
        self.lr = self.factor * self.d_model ** (-0.5) * min(decay_term, warmup_term)

        for group in self.optimizer.param_groups:
            group["lr"] = self.lr
+
|
vampnet/util.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tqdm
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
def scalar_to_batch_tensor(x, batch_size):
    """Tile the scalar `x` into a 1-D tensor of length `batch_size`."""
    as_tensor = torch.tensor(x)
    return as_tensor.repeat(batch_size)
+
|
9 |
+
|
10 |
+
def parallelize(
    fn,
    *iterables,
    parallel: str = "thread_map",
    **kwargs
):
    """Map `fn` over `iterables` with a tqdm progress bar.

    `parallel` selects the execution strategy: a thread pool ("thread_map"),
    a process pool ("process_map"), or a plain sequential loop ("single").
    Extra keyword arguments are forwarded to the tqdm map helpers.
    """
    if parallel == "thread_map":
        from tqdm.contrib.concurrent import thread_map
        return thread_map(fn, *iterables, **kwargs)

    if parallel == "process_map":
        from tqdm.contrib.concurrent import process_map
        return process_map(fn, *iterables, **kwargs)

    if parallel == "single":
        # Sequential fallback, still with a progress bar.
        return [fn(x) for x in tqdm.tqdm(*iterables)]

    raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
+
|
35 |
+
def codebook_flatten(tokens: torch.Tensor):
    """
    flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)

    Output is interleaved time-major: flat position t * n_codebooks + c.
    """
    flattened = rearrange(tokens, "b c t -> b (t c)")
    return flattened
+
|
41 |
+
def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
    """
    unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)

    Inverse of `codebook_flatten`; `n_c` is the number of codebooks.
    """
    unflattened = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
    return unflattened