Upload 19 files
Browse files- .gitattributes +36 -35
- .gitignore +154 -0
- Dockerfile +43 -0
- MIDI.py +1735 -0
- README.md +13 -12
- app.py +534 -0
- app_onnx.py +626 -0
- example/Bach--Fugue-in-D-Minor.mid +3 -0
- example/Beethoven--Symphony-No5-in-C-Minor-Fate-Opus-67.mid +3 -0
- example/Chopin--Nocturne No. 9 in B Major, Opus 32 No.1, Andante Sostenuto.mid +3 -0
- example/Mozart--Requiem, No.1..mid +3 -0
- example/castle_in_the_sky.mid +3 -0
- example/eva-残酷な天使のテーゼ.mid +3 -0
- javascript/app.js +732 -0
- midi_model.py +250 -0
- midi_synthesizer.py +81 -0
- midi_tokenizer.py +1196 -0
- packages.txt +1 -0
- requirements.txt +11 -0
.gitattributes
CHANGED
@@ -1,35 +1,36 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.mid filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
105 |
+
__pypackages__/
|
106 |
+
|
107 |
+
# Celery stuff
|
108 |
+
celerybeat-schedule
|
109 |
+
celerybeat.pid
|
110 |
+
|
111 |
+
# SageMath parsed files
|
112 |
+
*.sage.py
|
113 |
+
|
114 |
+
# Environments
|
115 |
+
.env
|
116 |
+
.venv
|
117 |
+
env/
|
118 |
+
venv/
|
119 |
+
ENV/
|
120 |
+
env.bak/
|
121 |
+
venv.bak/
|
122 |
+
|
123 |
+
# Spyder project settings
|
124 |
+
.spyderproject
|
125 |
+
.spyproject
|
126 |
+
|
127 |
+
# Rope project settings
|
128 |
+
.ropeproject
|
129 |
+
|
130 |
+
# mkdocs documentation
|
131 |
+
/site
|
132 |
+
|
133 |
+
# mypy
|
134 |
+
.mypy_cache/
|
135 |
+
.dmypy.json
|
136 |
+
dmypy.json
|
137 |
+
|
138 |
+
# Pyre type checker
|
139 |
+
.pyre/
|
140 |
+
|
141 |
+
# pytype static type analyzer
|
142 |
+
.pytype/
|
143 |
+
|
144 |
+
# Cython debug symbols
|
145 |
+
cython_debug/
|
146 |
+
|
147 |
+
# PyCharm
|
148 |
+
# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
|
149 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
150 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
151 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
152 |
+
.idea/
|
153 |
+
output.mid
|
154 |
+
/outputs/
|
Dockerfile
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04
|
2 |
+
|
3 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
4 |
+
|
5 |
+
ENV PYTHONUNBUFFERED=1
|
6 |
+
|
7 |
+
RUN apt-get update && apt-get install --no-install-recommends -y \
|
8 |
+
build-essential \
|
9 |
+
python3.9 \
|
10 |
+
python3-pip \
|
11 |
+
git \
|
12 |
+
ffmpeg \
|
13 |
+
fluidsynth \
|
14 |
+
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
15 |
+
|
16 |
+
WORKDIR /code
|
17 |
+
|
18 |
+
COPY ./requirements.txt /code/requirements.txt
|
19 |
+
|
20 |
+
# Set up a new user named "user" with user ID 1000
|
21 |
+
RUN useradd -m -u 1000 user
|
22 |
+
# Switch to the "user" user
|
23 |
+
USER user
|
24 |
+
# Set home to the user's home directory
|
25 |
+
ENV HOME=/home/user \
|
26 |
+
PATH=/home/user/.local/bin:$PATH \
|
27 |
+
PYTHONPATH=$HOME/app \
|
28 |
+
PYTHONUNBUFFERED=1 \
|
29 |
+
GRADIO_ALLOW_FLAGGING=never \
|
30 |
+
GRADIO_NUM_PORTS=1 \
|
31 |
+
GRADIO_SERVER_NAME=0.0.0.0 \
|
32 |
+
GRADIO_THEME=huggingface \
|
33 |
+
SYSTEM=spaces
|
34 |
+
|
35 |
+
RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
|
36 |
+
|
37 |
+
# Set the working directory to the user's home directory
|
38 |
+
WORKDIR $HOME/app
|
39 |
+
|
40 |
+
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
41 |
+
COPY --chown=user . $HOME/app
|
42 |
+
|
43 |
+
CMD ["python3", "app.py"]
|
MIDI.py
ADDED
@@ -0,0 +1,1735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/python3
|
2 |
+
# unsupported 20091104 ...
|
3 |
+
# ['set_sequence_number', dtime, sequence]
|
4 |
+
# ['raw_data', dtime, raw]
|
5 |
+
|
6 |
+
# 20150914 jimbo1qaz MIDI.py str/bytes bug report
|
7 |
+
# I found a MIDI file which had Shift-JIS titles. When midi.py decodes it as
|
8 |
+
# latin-1, it produces a string which cannot even be accessed without raising
|
9 |
+
# a UnicodeDecodeError. Maybe, when converting raw byte strings from MIDI,
|
10 |
+
# you should keep them as bytes, not improperly decode them. However, this
|
11 |
+
# would change the API. (ie: text = a "string" ? of 0 or more bytes). It
|
12 |
+
# could break compatiblity, but there's not much else you can do to fix the bug
|
13 |
+
# https://en.wikipedia.org/wiki/Shift_JIS
|
14 |
+
|
15 |
+
r'''
|
16 |
+
This module offers functions: concatenate_scores(), grep(),
|
17 |
+
merge_scores(), mix_scores(), midi2opus(), midi2score(), opus2midi(),
|
18 |
+
opus2score(), play_score(), score2midi(), score2opus(), score2stats(),
|
19 |
+
score_type(), segment(), timeshift() and to_millisecs(),
|
20 |
+
where "midi" means the MIDI-file bytes (as can be put in a .mid file,
|
21 |
+
or piped into aplaymidi), and "opus" and "score" are list-structures
|
22 |
+
as inspired by Sean Burke's MIDI-Perl CPAN module.
|
23 |
+
|
24 |
+
Warning: Version 6.4 is not necessarily backward-compatible with
|
25 |
+
previous versions, in that text-data is now bytes, not strings.
|
26 |
+
This reflects the fact that many MIDI files have text data in
|
27 |
+
encodings other that ISO-8859-1, for example in Shift-JIS.
|
28 |
+
|
29 |
+
Download MIDI.py from http://www.pjb.com.au/midi/free/MIDI.py
|
30 |
+
and put it in your PYTHONPATH. MIDI.py depends on Python3.
|
31 |
+
|
32 |
+
There is also a call-compatible translation into Lua of this
|
33 |
+
module: see http://www.pjb.com.au/comp/lua/MIDI.html
|
34 |
+
|
35 |
+
The "opus" is a direct translation of the midi-file-events, where
|
36 |
+
the times are delta-times, in ticks, since the previous event.
|
37 |
+
|
38 |
+
The "score" is more human-centric; it uses absolute times, and
|
39 |
+
combines the separate note_on and note_off events into one "note"
|
40 |
+
event, with a duration:
|
41 |
+
['note', start_time, duration, channel, note, velocity] # in a "score"
|
42 |
+
|
43 |
+
EVENTS (in an "opus" structure)
|
44 |
+
['note_off', dtime, channel, note, velocity] # in an "opus"
|
45 |
+
['note_on', dtime, channel, note, velocity] # in an "opus"
|
46 |
+
['key_after_touch', dtime, channel, note, velocity]
|
47 |
+
['control_change', dtime, channel, controller(0-127), value(0-127)]
|
48 |
+
['patch_change', dtime, channel, patch]
|
49 |
+
['channel_after_touch', dtime, channel, velocity]
|
50 |
+
['pitch_wheel_change', dtime, channel, pitch_wheel]
|
51 |
+
['text_event', dtime, text]
|
52 |
+
['copyright_text_event', dtime, text]
|
53 |
+
['track_name', dtime, text]
|
54 |
+
['instrument_name', dtime, text]
|
55 |
+
['lyric', dtime, text]
|
56 |
+
['marker', dtime, text]
|
57 |
+
['cue_point', dtime, text]
|
58 |
+
['text_event_08', dtime, text]
|
59 |
+
['text_event_09', dtime, text]
|
60 |
+
['text_event_0a', dtime, text]
|
61 |
+
['text_event_0b', dtime, text]
|
62 |
+
['text_event_0c', dtime, text]
|
63 |
+
['text_event_0d', dtime, text]
|
64 |
+
['text_event_0e', dtime, text]
|
65 |
+
['text_event_0f', dtime, text]
|
66 |
+
['end_track', dtime]
|
67 |
+
['set_tempo', dtime, tempo]
|
68 |
+
['smpte_offset', dtime, hr, mn, se, fr, ff]
|
69 |
+
['time_signature', dtime, nn, dd, cc, bb]
|
70 |
+
['key_signature', dtime, sf, mi]
|
71 |
+
['sequencer_specific', dtime, raw]
|
72 |
+
['raw_meta_event', dtime, command(0-255), raw]
|
73 |
+
['sysex_f0', dtime, raw]
|
74 |
+
['sysex_f7', dtime, raw]
|
75 |
+
['song_position', dtime, song_pos]
|
76 |
+
['song_select', dtime, song_number]
|
77 |
+
['tune_request', dtime]
|
78 |
+
|
79 |
+
DATA TYPES
|
80 |
+
channel = a value 0 to 15
|
81 |
+
controller = 0 to 127 (see http://www.pjb.com.au/muscript/gm.html#cc )
|
82 |
+
dtime = time measured in "ticks", 0 to 268435455
|
83 |
+
velocity = a value 0 (soft) to 127 (loud)
|
84 |
+
note = a value 0 to 127 (middle-C is 60)
|
85 |
+
patch = 0 to 127 (see http://www.pjb.com.au/muscript/gm.html )
|
86 |
+
pitch_wheel = a value -8192 to 8191 (0x1FFF)
|
87 |
+
raw = bytes, of length 0 or more (for sysex events see below)
|
88 |
+
sequence_number = a value 0 to 65,535 (0xFFFF)
|
89 |
+
song_pos = a value 0 to 16,383 (0x3FFF)
|
90 |
+
song_number = a value 0 to 127
|
91 |
+
tempo = microseconds per crochet (quarter-note), 0 to 16777215
|
92 |
+
text = bytes, of length 0 or more
|
93 |
+
ticks = the number of ticks per crochet (quarter-note)
|
94 |
+
|
95 |
+
In sysex_f0 events, the raw data must not start with a \xF0 byte,
|
96 |
+
since this gets added automatically;
|
97 |
+
but it must end with an explicit \xF7 byte!
|
98 |
+
In the very unlikely case that you ever need to split sysex data
|
99 |
+
into one sysex_f0 followed by one or more sysex_f7s, then only the
|
100 |
+
last of those sysex_f7 events must end with the explicit \xF7 byte
|
101 |
+
(again, the raw data of individual sysex_f7 events must not start
|
102 |
+
with any \xF7 byte, since this gets added automatically).
|
103 |
+
|
104 |
+
Since version 6.4, text data is in bytes, not in a ISO-8859-1 string.
|
105 |
+
|
106 |
+
|
107 |
+
GOING THROUGH A SCORE WITHIN A PYTHON PROGRAM
|
108 |
+
channels = {2,3,5,8,13}
|
109 |
+
itrack = 1 # skip 1st element which is ticks
|
110 |
+
while itrack < len(score):
|
111 |
+
for event in score[itrack]:
|
112 |
+
if event[0] == 'note': # for example,
|
113 |
+
pass # do something to all notes
|
114 |
+
# or, to work on events in only particular channels...
|
115 |
+
channel_index = MIDI.Event2channelindex.get(event[0], False)
|
116 |
+
if channel_index and (event[channel_index] in channels):
|
117 |
+
pass # do something to channels 2,3,5,8 and 13
|
118 |
+
itrack += 1
|
119 |
+
|
120 |
+
'''
|
121 |
+
|
122 |
+
import sys, struct, copy
|
123 |
+
# sys.stdout = os.fdopen(sys.stdout.fileno(), 'wb')
|
124 |
+
Version = '6.7'
|
125 |
+
VersionDate = '20201120'
|
126 |
+
# 20201120 6.7 call to bytest() removed, and protect _unshift_ber_int
|
127 |
+
# 20160702 6.6 to_millisecs() now handles set_tempo across multiple Tracks
|
128 |
+
# 20150921 6.5 segment restores controllers as well as patch and tempo
|
129 |
+
# 20150914 6.4 text data is bytes or bytearray, not ISO-8859-1 strings
|
130 |
+
# 20150628 6.3 absent any set_tempo, default is 120bpm (see MIDI file spec 1.1)
|
131 |
+
# 20150101 6.2 all text events can be 8-bit; let user get the right encoding
|
132 |
+
# 20141231 6.1 fix _some_text_event; sequencer_specific data can be 8-bit
|
133 |
+
# 20141230 6.0 synth_specific data can be 8-bit
|
134 |
+
# 20120504 5.9 add the contents of mid_opus_tracks()
|
135 |
+
# 20120208 5.8 fix num_notes_by_channel() ; should be a dict
|
136 |
+
# 20120129 5.7 _encode handles empty tracks; score2stats num_notes_by_channel
|
137 |
+
# 20111111 5.6 fix patch 45 and 46 in Number2patch, should be Harp
|
138 |
+
# 20110129 5.5 add mix_opus_tracks() and event2alsaseq()
|
139 |
+
# 20110126 5.4 "previous message repeated N times" to save space on stderr
|
140 |
+
# 20110125 5.2 opus2score terminates unended notes at the end of the track
|
141 |
+
# 20110124 5.1 the warnings in midi2opus display track_num
|
142 |
+
# 21110122 5.0 if garbage, midi2opus returns the opus so far
|
143 |
+
# 21110119 4.9 non-ascii chars stripped out of the text_events
|
144 |
+
# 21110110 4.8 note_on with velocity=0 treated as a note-off
|
145 |
+
# 21110108 4.6 unknown F-series event correctly eats just one byte
|
146 |
+
# 21011010 4.2 segment() uses start_time, end_time named params
|
147 |
+
# 21011005 4.1 timeshift() must not pad the set_tempo command
|
148 |
+
# 21011003 4.0 pitch2note_event must be chapitch2note_event
|
149 |
+
# 21010918 3.9 set_sequence_number supported, FWIW
|
150 |
+
# 20100913 3.7 many small bugfixes; passes all tests
|
151 |
+
# 20100910 3.6 concatenate_scores enforce ticks=1000, just like merge_scores
|
152 |
+
# 20100908 3.5 minor bugs fixed in score2stats
|
153 |
+
# 20091104 3.4 tune_request now supported
|
154 |
+
# 20091104 3.3 fixed bug in decoding song_position and song_select
|
155 |
+
# 20091104 3.2 unsupported: set_sequence_number tune_request raw_data
|
156 |
+
# 20091101 3.1 document how to traverse a score within Python
|
157 |
+
# 20091021 3.0 fixed bug in score2stats detecting GM-mode = 0
|
158 |
+
# 20091020 2.9 score2stats reports GM-mode and bank msb,lsb events
|
159 |
+
# 20091019 2.8 in merge_scores, channel 9 must remain channel 9 (in GM)
|
160 |
+
# 20091018 2.7 handles empty tracks gracefully
|
161 |
+
# 20091015 2.6 grep() selects channels
|
162 |
+
# 20091010 2.5 merge_scores reassigns channels to avoid conflicts
|
163 |
+
# 20091010 2.4 fixed bug in to_millisecs which now only does opusses
|
164 |
+
# 20091010 2.3 score2stats returns channels & patch_changes, by_track & total
|
165 |
+
# 20091010 2.2 score2stats() returns also pitches and percussion dicts
|
166 |
+
# 20091010 2.1 bugs: >= not > in segment, to notice patch_change at time 0
|
167 |
+
# 20091010 2.0 bugs: spurious pop(0) ( in _decode sysex
|
168 |
+
# 20091008 1.9 bugs: ISO decoding in sysex; str( not int( in note-off warning
|
169 |
+
# 20091008 1.8 add concatenate_scores()
|
170 |
+
# 20091006 1.7 score2stats() measures nticks and ticks_per_quarter
|
171 |
+
# 20091004 1.6 first mix_scores() and merge_scores()
|
172 |
+
# 20090424 1.5 timeshift() bugfix: earliest only sees events after from_time
|
173 |
+
# 20090330 1.4 timeshift() has also a from_time argument
|
174 |
+
# 20090322 1.3 timeshift() has also a start_time argument
|
175 |
+
# 20090319 1.2 add segment() and timeshift()
|
176 |
+
# 20090301 1.1 add to_millisecs()
|
177 |
+
|
178 |
+
_previous_warning = '' # 5.4
|
179 |
+
_previous_times = 0 # 5.4
|
180 |
+
_no_warning = True
|
181 |
+
#------------------------------- Encoding stuff --------------------------
|
182 |
+
|
183 |
+
def opus2midi(opus=[]):
|
184 |
+
r'''The argument is a list: the first item in the list is the "ticks"
|
185 |
+
parameter, the others are the tracks. Each track is a list
|
186 |
+
of midi-events, and each event is itself a list; see above.
|
187 |
+
opus2midi() returns a bytestring of the MIDI, which can then be
|
188 |
+
written either to a file opened in binary mode (mode='wb'),
|
189 |
+
or to stdout by means of: sys.stdout.buffer.write()
|
190 |
+
|
191 |
+
my_opus = [
|
192 |
+
96,
|
193 |
+
[ # track 0:
|
194 |
+
['patch_change', 0, 1, 8], # and these are the events...
|
195 |
+
['note_on', 5, 1, 25, 96],
|
196 |
+
['note_off', 96, 1, 25, 0],
|
197 |
+
['note_on', 0, 1, 29, 96],
|
198 |
+
['note_off', 96, 1, 29, 0],
|
199 |
+
], # end of track 0
|
200 |
+
]
|
201 |
+
my_midi = opus2midi(my_opus)
|
202 |
+
sys.stdout.buffer.write(my_midi)
|
203 |
+
'''
|
204 |
+
if len(opus) < 2:
|
205 |
+
opus=[1000, [],]
|
206 |
+
tracks = copy.deepcopy(opus)
|
207 |
+
ticks = int(tracks.pop(0))
|
208 |
+
ntracks = len(tracks)
|
209 |
+
if ntracks == 1:
|
210 |
+
format = 0
|
211 |
+
else:
|
212 |
+
format = 1
|
213 |
+
|
214 |
+
my_midi = b"MThd\x00\x00\x00\x06"+struct.pack('>HHH',format,ntracks,ticks)
|
215 |
+
for track in tracks:
|
216 |
+
events = _encode(track)
|
217 |
+
my_midi += b'MTrk' + struct.pack('>I',len(events)) + events
|
218 |
+
_clean_up_warnings()
|
219 |
+
return my_midi
|
220 |
+
|
221 |
+
|
222 |
+
def score2opus(score=None):
|
223 |
+
r'''
|
224 |
+
The argument is a list: the first item in the list is the "ticks"
|
225 |
+
parameter, the others are the tracks. Each track is a list
|
226 |
+
of score-events, and each event is itself a list. A score-event
|
227 |
+
is similar to an opus-event (see above), except that in a score:
|
228 |
+
1) the times are expressed as an absolute number of ticks
|
229 |
+
from the track's start time
|
230 |
+
2) the pairs of 'note_on' and 'note_off' events in an "opus"
|
231 |
+
are abstracted into a single 'note' event in a "score":
|
232 |
+
['note', start_time, duration, channel, pitch, velocity]
|
233 |
+
score2opus() returns a list specifying the equivalent "opus".
|
234 |
+
|
235 |
+
my_score = [
|
236 |
+
96,
|
237 |
+
[ # track 0:
|
238 |
+
['patch_change', 0, 1, 8],
|
239 |
+
['note', 5, 96, 1, 25, 96],
|
240 |
+
['note', 101, 96, 1, 29, 96]
|
241 |
+
], # end of track 0
|
242 |
+
]
|
243 |
+
my_opus = score2opus(my_score)
|
244 |
+
'''
|
245 |
+
if len(score) < 2:
|
246 |
+
score=[1000, [],]
|
247 |
+
tracks = copy.deepcopy(score)
|
248 |
+
ticks = int(tracks.pop(0))
|
249 |
+
opus_tracks = []
|
250 |
+
for scoretrack in tracks:
|
251 |
+
time2events = dict([])
|
252 |
+
for scoreevent in scoretrack:
|
253 |
+
if scoreevent[0] == 'note':
|
254 |
+
note_on_event = ['note_on',scoreevent[1],
|
255 |
+
scoreevent[3],scoreevent[4],scoreevent[5]]
|
256 |
+
note_off_event = ['note_off',scoreevent[1]+scoreevent[2],
|
257 |
+
scoreevent[3],scoreevent[4],scoreevent[5]]
|
258 |
+
if time2events.get(note_on_event[1]):
|
259 |
+
time2events[note_on_event[1]].append(note_on_event)
|
260 |
+
else:
|
261 |
+
time2events[note_on_event[1]] = [note_on_event,]
|
262 |
+
if time2events.get(note_off_event[1]):
|
263 |
+
time2events[note_off_event[1]].append(note_off_event)
|
264 |
+
else:
|
265 |
+
time2events[note_off_event[1]] = [note_off_event,]
|
266 |
+
continue
|
267 |
+
if time2events.get(scoreevent[1]):
|
268 |
+
time2events[scoreevent[1]].append(scoreevent)
|
269 |
+
else:
|
270 |
+
time2events[scoreevent[1]] = [scoreevent,]
|
271 |
+
|
272 |
+
sorted_times = [] # list of keys
|
273 |
+
for k in time2events.keys():
|
274 |
+
sorted_times.append(k)
|
275 |
+
sorted_times.sort()
|
276 |
+
|
277 |
+
sorted_events = [] # once-flattened list of values sorted by key
|
278 |
+
for time in sorted_times:
|
279 |
+
sorted_events.extend(time2events[time])
|
280 |
+
|
281 |
+
abs_time = 0
|
282 |
+
for event in sorted_events: # convert abs times => delta times
|
283 |
+
delta_time = event[1] - abs_time
|
284 |
+
abs_time = event[1]
|
285 |
+
event[1] = delta_time
|
286 |
+
opus_tracks.append(sorted_events)
|
287 |
+
opus_tracks.insert(0,ticks)
|
288 |
+
_clean_up_warnings()
|
289 |
+
return opus_tracks
|
290 |
+
|
291 |
+
def score2midi(score=None):
|
292 |
+
r'''
|
293 |
+
Translates a "score" into MIDI, using score2opus() then opus2midi()
|
294 |
+
'''
|
295 |
+
return opus2midi(score2opus(score))
|
296 |
+
|
297 |
+
#--------------------------- Decoding stuff ------------------------
|
298 |
+
|
299 |
+
def midi2opus(midi=b''):
|
300 |
+
r'''Translates MIDI into a "opus". For a description of the
|
301 |
+
"opus" format, see opus2midi()
|
302 |
+
'''
|
303 |
+
my_midi=bytearray(midi)
|
304 |
+
if len(my_midi) < 4:
|
305 |
+
_clean_up_warnings()
|
306 |
+
return [1000,[],]
|
307 |
+
id = bytes(my_midi[0:4])
|
308 |
+
if id != b'MThd':
|
309 |
+
_warn("midi2opus: midi starts with "+str(id)+" instead of 'MThd'")
|
310 |
+
_clean_up_warnings()
|
311 |
+
return [1000,[],]
|
312 |
+
[length, format, tracks_expected, ticks] = struct.unpack(
|
313 |
+
'>IHHH', bytes(my_midi[4:14]))
|
314 |
+
if length != 6:
|
315 |
+
_warn("midi2opus: midi header length was "+str(length)+" instead of 6")
|
316 |
+
_clean_up_warnings()
|
317 |
+
return [1000,[],]
|
318 |
+
my_opus = [ticks,]
|
319 |
+
my_midi = my_midi[14:]
|
320 |
+
track_num = 1 # 5.1
|
321 |
+
while len(my_midi) >= 8:
|
322 |
+
track_type = bytes(my_midi[0:4])
|
323 |
+
if track_type != b'MTrk':
|
324 |
+
_warn('midi2opus: Warning: track #'+str(track_num)+' type is '+str(track_type)+" instead of b'MTrk'")
|
325 |
+
[track_length] = struct.unpack('>I', my_midi[4:8])
|
326 |
+
my_midi = my_midi[8:]
|
327 |
+
if track_length > len(my_midi):
|
328 |
+
_warn('midi2opus: track #'+str(track_num)+' length '+str(track_length)+' is too large')
|
329 |
+
_clean_up_warnings()
|
330 |
+
return my_opus # 5.0
|
331 |
+
my_midi_track = my_midi[0:track_length]
|
332 |
+
my_track = _decode(my_midi_track)
|
333 |
+
my_opus.append(my_track)
|
334 |
+
my_midi = my_midi[track_length:]
|
335 |
+
track_num += 1 # 5.1
|
336 |
+
_clean_up_warnings()
|
337 |
+
return my_opus
|
338 |
+
|
339 |
+
def opus2score(opus=[]):
|
340 |
+
r'''For a description of the "opus" and "score" formats,
|
341 |
+
see opus2midi() and score2opus().
|
342 |
+
'''
|
343 |
+
if len(opus) < 2:
|
344 |
+
_clean_up_warnings()
|
345 |
+
return [1000,[],]
|
346 |
+
tracks = copy.deepcopy(opus) # couple of slices probably quicker...
|
347 |
+
ticks = int(tracks.pop(0))
|
348 |
+
score = [ticks,]
|
349 |
+
for opus_track in tracks:
|
350 |
+
ticks_so_far = 0
|
351 |
+
score_track = []
|
352 |
+
chapitch2note_on_events = dict([]) # 4.0
|
353 |
+
for opus_event in opus_track:
|
354 |
+
ticks_so_far += opus_event[1]
|
355 |
+
if opus_event[0] == 'note_off' or (opus_event[0] == 'note_on' and opus_event[4] == 0): # 4.8
|
356 |
+
cha = opus_event[2]
|
357 |
+
pitch = opus_event[3]
|
358 |
+
key = cha*128 + pitch
|
359 |
+
if chapitch2note_on_events.get(key):
|
360 |
+
new_event = chapitch2note_on_events[key].pop(0)
|
361 |
+
new_event[2] = ticks_so_far - new_event[1]
|
362 |
+
score_track.append(new_event)
|
363 |
+
elif pitch > 127:
|
364 |
+
pass #_warn('opus2score: note_off with no note_on, bad pitch='+str(pitch))
|
365 |
+
else:
|
366 |
+
pass #_warn('opus2score: note_off with no note_on cha='+str(cha)+' pitch='+str(pitch))
|
367 |
+
elif opus_event[0] == 'note_on':
|
368 |
+
cha = opus_event[2]
|
369 |
+
pitch = opus_event[3]
|
370 |
+
key = cha*128 + pitch
|
371 |
+
new_event = ['note',ticks_so_far,0,cha,pitch, opus_event[4]]
|
372 |
+
if chapitch2note_on_events.get(key):
|
373 |
+
chapitch2note_on_events[key].append(new_event)
|
374 |
+
else:
|
375 |
+
chapitch2note_on_events[key] = [new_event,]
|
376 |
+
else:
|
377 |
+
opus_event[1] = ticks_so_far
|
378 |
+
score_track.append(opus_event)
|
379 |
+
# check for unterminated notes (Oisín) -- 5.2
|
380 |
+
for chapitch in chapitch2note_on_events:
|
381 |
+
note_on_events = chapitch2note_on_events[chapitch]
|
382 |
+
for new_e in note_on_events:
|
383 |
+
new_e[2] = ticks_so_far - new_e[1]
|
384 |
+
score_track.append(new_e)
|
385 |
+
pass #_warn("opus2score: note_on with no note_off cha="+str(new_e[3])+' pitch='+str(new_e[4])+'; adding note_off at end')
|
386 |
+
score.append(score_track)
|
387 |
+
_clean_up_warnings()
|
388 |
+
return score
|
389 |
+
|
390 |
+
def midi2score(midi=b''):
|
391 |
+
r'''
|
392 |
+
Translates MIDI into a "score", using midi2opus() then opus2score()
|
393 |
+
'''
|
394 |
+
return opus2score(midi2opus(midi))
|
395 |
+
|
396 |
+
def midi2ms_score(midi=b''):
|
397 |
+
r'''
|
398 |
+
Translates MIDI into a "score" with one beat per second and one
|
399 |
+
tick per millisecond, using midi2opus() then to_millisecs()
|
400 |
+
then opus2score()
|
401 |
+
'''
|
402 |
+
return opus2score(to_millisecs(midi2opus(midi)))
|
403 |
+
|
404 |
+
#------------------------ Other Transformations ---------------------
|
405 |
+
|
406 |
+
def to_millisecs(old_opus=None):
|
407 |
+
r'''Recallibrates all the times in an "opus" to use one beat
|
408 |
+
per second and one tick per millisecond. This makes it
|
409 |
+
hard to retrieve any information about beats or barlines,
|
410 |
+
but it does make it easy to mix different scores together.
|
411 |
+
'''
|
412 |
+
if old_opus == None:
|
413 |
+
return [1000,[],]
|
414 |
+
try:
|
415 |
+
old_tpq = int(old_opus[0])
|
416 |
+
except IndexError: # 5.0
|
417 |
+
_warn('to_millisecs: the opus '+str(type(old_opus))+' has no elements')
|
418 |
+
return [1000,[],]
|
419 |
+
new_opus = [1000,]
|
420 |
+
# 6.7 first go through building a table of set_tempos by absolute-tick
|
421 |
+
ticks2tempo = {}
|
422 |
+
itrack = 1
|
423 |
+
while itrack < len(old_opus):
|
424 |
+
ticks_so_far = 0
|
425 |
+
for old_event in old_opus[itrack]:
|
426 |
+
if old_event[0] == 'note':
|
427 |
+
raise TypeError('to_millisecs needs an opus, not a score')
|
428 |
+
ticks_so_far += old_event[1]
|
429 |
+
if old_event[0] == 'set_tempo':
|
430 |
+
ticks2tempo[ticks_so_far] = old_event[2]
|
431 |
+
itrack += 1
|
432 |
+
# then get the sorted-array of their keys
|
433 |
+
tempo_ticks = [] # list of keys
|
434 |
+
for k in ticks2tempo.keys():
|
435 |
+
tempo_ticks.append(k)
|
436 |
+
tempo_ticks.sort()
|
437 |
+
# then go through converting to millisec, testing if the next
|
438 |
+
# set_tempo lies before the next track-event, and using it if so.
|
439 |
+
itrack = 1
|
440 |
+
while itrack < len(old_opus):
|
441 |
+
ms_per_old_tick = 500.0 / old_tpq # float: will round later 6.3
|
442 |
+
i_tempo_ticks = 0
|
443 |
+
ticks_so_far = 0
|
444 |
+
ms_so_far = 0.0
|
445 |
+
previous_ms_so_far = 0.0
|
446 |
+
new_track = [['set_tempo',0,1000000],] # new "crochet" is 1 sec
|
447 |
+
for old_event in old_opus[itrack]:
|
448 |
+
# detect if ticks2tempo has something before this event
|
449 |
+
# 20160702 if ticks2tempo is at the same time, leave it
|
450 |
+
event_delta_ticks = old_event[1]
|
451 |
+
if (i_tempo_ticks < len(tempo_ticks) and
|
452 |
+
tempo_ticks[i_tempo_ticks] < (ticks_so_far + old_event[1])):
|
453 |
+
delta_ticks = tempo_ticks[i_tempo_ticks] - ticks_so_far
|
454 |
+
ms_so_far += (ms_per_old_tick * delta_ticks)
|
455 |
+
ticks_so_far = tempo_ticks[i_tempo_ticks]
|
456 |
+
ms_per_old_tick = ticks2tempo[ticks_so_far] / (1000.0*old_tpq)
|
457 |
+
i_tempo_ticks += 1
|
458 |
+
event_delta_ticks -= delta_ticks
|
459 |
+
new_event = copy.deepcopy(old_event) # now handle the new event
|
460 |
+
ms_so_far += (ms_per_old_tick * old_event[1])
|
461 |
+
new_event[1] = round(ms_so_far - previous_ms_so_far)
|
462 |
+
if old_event[0] != 'set_tempo':
|
463 |
+
previous_ms_so_far = ms_so_far
|
464 |
+
new_track.append(new_event)
|
465 |
+
ticks_so_far += event_delta_ticks
|
466 |
+
new_opus.append(new_track)
|
467 |
+
itrack += 1
|
468 |
+
_clean_up_warnings()
|
469 |
+
return new_opus
|
470 |
+
|
471 |
+
def event2alsaseq(event=None): # 5.5
|
472 |
+
r'''Converts an event into the format needed by the alsaseq module,
|
473 |
+
http://pp.com.mx/python/alsaseq
|
474 |
+
The type of track (opus or score) is autodetected.
|
475 |
+
'''
|
476 |
+
pass
|
477 |
+
|
478 |
+
def grep(score=None, channels=None):
|
479 |
+
r'''Returns a "score" containing only the channels specified
|
480 |
+
'''
|
481 |
+
if score == None:
|
482 |
+
return [1000,[],]
|
483 |
+
ticks = score[0]
|
484 |
+
new_score = [ticks,]
|
485 |
+
if channels == None:
|
486 |
+
return new_score
|
487 |
+
channels = set(channels)
|
488 |
+
global Event2channelindex
|
489 |
+
itrack = 1
|
490 |
+
while itrack < len(score):
|
491 |
+
new_score.append([])
|
492 |
+
for event in score[itrack]:
|
493 |
+
channel_index = Event2channelindex.get(event[0], False)
|
494 |
+
if channel_index:
|
495 |
+
if event[channel_index] in channels:
|
496 |
+
new_score[itrack].append(event)
|
497 |
+
else:
|
498 |
+
new_score[itrack].append(event)
|
499 |
+
itrack += 1
|
500 |
+
return new_score
|
501 |
+
|
502 |
+
def play_score(score=None):
|
503 |
+
r'''Converts the "score" to midi, and feeds it into 'aplaymidi -'
|
504 |
+
'''
|
505 |
+
if score == None:
|
506 |
+
return
|
507 |
+
import subprocess
|
508 |
+
pipe = subprocess.Popen(['aplaymidi','-'], stdin=subprocess.PIPE)
|
509 |
+
if score_type(score) == 'opus':
|
510 |
+
pipe.stdin.write(opus2midi(score))
|
511 |
+
else:
|
512 |
+
pipe.stdin.write(score2midi(score))
|
513 |
+
pipe.stdin.close()
|
514 |
+
|
515 |
+
def timeshift(score=None, shift=None, start_time=None, from_time=0, tracks={0,1,2,3,4,5,6,7,8,10,12,13,14,15}):
|
516 |
+
r'''Returns a "score" shifted in time by "shift" ticks, or shifted
|
517 |
+
so that the first event starts at "start_time" ticks.
|
518 |
+
|
519 |
+
If "from_time" is specified, only those events in the score
|
520 |
+
that begin after it are shifted. If "start_time" is less than
|
521 |
+
"from_time" (or "shift" is negative), then the intermediate
|
522 |
+
notes are deleted, though patch-change events are preserved.
|
523 |
+
|
524 |
+
If "tracks" are specified, then only those tracks get shifted.
|
525 |
+
"tracks" can be a list, tuple or set; it gets converted to set
|
526 |
+
internally.
|
527 |
+
|
528 |
+
It is deprecated to specify both "shift" and "start_time".
|
529 |
+
If this does happen, timeshift() will print a warning to
|
530 |
+
stderr and ignore the "shift" argument.
|
531 |
+
|
532 |
+
If "shift" is negative and sufficiently large that it would
|
533 |
+
leave some event with a negative tick-value, then the score
|
534 |
+
is shifted so that the first event occurs at time 0. This
|
535 |
+
also occurs if "start_time" is negative, and is also the
|
536 |
+
default if neither "shift" nor "start_time" are specified.
|
537 |
+
'''
|
538 |
+
#_warn('tracks='+str(tracks))
|
539 |
+
if score == None or len(score) < 2:
|
540 |
+
return [1000, [],]
|
541 |
+
new_score = [score[0],]
|
542 |
+
my_type = score_type(score)
|
543 |
+
if my_type == '':
|
544 |
+
return new_score
|
545 |
+
if my_type == 'opus':
|
546 |
+
_warn("timeshift: opus format is not supported\n")
|
547 |
+
# _clean_up_scores() 6.2; doesn't exist! what was it supposed to do?
|
548 |
+
return new_score
|
549 |
+
if not (shift == None) and not (start_time == None):
|
550 |
+
_warn("timeshift: shift and start_time specified: ignoring shift\n")
|
551 |
+
shift = None
|
552 |
+
if shift == None:
|
553 |
+
if (start_time == None) or (start_time < 0):
|
554 |
+
start_time = 0
|
555 |
+
# shift = start_time - from_time
|
556 |
+
|
557 |
+
i = 1 # ignore first element (ticks)
|
558 |
+
tracks = set(tracks) # defend against tuples and lists
|
559 |
+
earliest = 1000000000
|
560 |
+
if not (start_time == None) or shift < 0: # first find the earliest event
|
561 |
+
while i < len(score):
|
562 |
+
if len(tracks) and not ((i-1) in tracks):
|
563 |
+
i += 1
|
564 |
+
continue
|
565 |
+
for event in score[i]:
|
566 |
+
if event[1] < from_time:
|
567 |
+
continue # just inspect the to_be_shifted events
|
568 |
+
if event[1] < earliest:
|
569 |
+
earliest = event[1]
|
570 |
+
i += 1
|
571 |
+
if earliest > 999999999:
|
572 |
+
earliest = 0
|
573 |
+
if shift == None:
|
574 |
+
shift = start_time - earliest
|
575 |
+
elif (earliest + shift) < 0:
|
576 |
+
start_time = 0
|
577 |
+
shift = 0 - earliest
|
578 |
+
|
579 |
+
i = 1 # ignore first element (ticks)
|
580 |
+
while i < len(score):
|
581 |
+
if len(tracks) == 0 or not ((i-1) in tracks): # 3.8
|
582 |
+
new_score.append(score[i])
|
583 |
+
i += 1
|
584 |
+
continue
|
585 |
+
new_track = []
|
586 |
+
for event in score[i]:
|
587 |
+
new_event = list(event)
|
588 |
+
#if new_event[1] == 0 and shift > 0 and new_event[0] != 'note':
|
589 |
+
# pass
|
590 |
+
#elif new_event[1] >= from_time:
|
591 |
+
if new_event[1] >= from_time:
|
592 |
+
# 4.1 must not rightshift set_tempo
|
593 |
+
if new_event[0] != 'set_tempo' or shift<0:
|
594 |
+
new_event[1] += shift
|
595 |
+
elif (shift < 0) and (new_event[1] >= (from_time+shift)):
|
596 |
+
continue
|
597 |
+
new_track.append(new_event)
|
598 |
+
if len(new_track) > 0:
|
599 |
+
new_score.append(new_track)
|
600 |
+
i += 1
|
601 |
+
_clean_up_warnings()
|
602 |
+
return new_score
|
603 |
+
|
604 |
+
def segment(score=None, start_time=None, end_time=None, start=0, end=100000000,
|
605 |
+
tracks={0,1,2,3,4,5,6,7,8,10,11,12,13,14,15}):
|
606 |
+
r'''Returns a "score" which is a segment of the one supplied
|
607 |
+
as the argument, beginning at "start_time" ticks and ending
|
608 |
+
at "end_time" ticks (or at the end if "end_time" is not supplied).
|
609 |
+
If the set "tracks" is specified, only those tracks will
|
610 |
+
be returned.
|
611 |
+
'''
|
612 |
+
if score == None or len(score) < 2:
|
613 |
+
return [1000, [],]
|
614 |
+
if start_time == None: # as of 4.2 start_time is recommended
|
615 |
+
start_time = start # start is legacy usage
|
616 |
+
if end_time == None: # likewise
|
617 |
+
end_time = end
|
618 |
+
new_score = [score[0],]
|
619 |
+
my_type = score_type(score)
|
620 |
+
if my_type == '':
|
621 |
+
return new_score
|
622 |
+
if my_type == 'opus':
|
623 |
+
# more difficult (disconnecting note_on's from their note_off's)...
|
624 |
+
_warn("segment: opus format is not supported\n")
|
625 |
+
_clean_up_warnings()
|
626 |
+
return new_score
|
627 |
+
i = 1 # ignore first element (ticks); we count in ticks anyway
|
628 |
+
tracks = set(tracks) # defend against tuples and lists
|
629 |
+
while i < len(score):
|
630 |
+
if len(tracks) and not ((i-1) in tracks):
|
631 |
+
i += 1
|
632 |
+
continue
|
633 |
+
new_track = []
|
634 |
+
channel2cc_num = {} # most recent controller change before start
|
635 |
+
channel2cc_val = {}
|
636 |
+
channel2cc_time = {}
|
637 |
+
channel2patch_num = {} # keep most recent patch change before start
|
638 |
+
channel2patch_time = {}
|
639 |
+
set_tempo_num = 500000 # most recent tempo change before start 6.3
|
640 |
+
set_tempo_time = 0
|
641 |
+
earliest_note_time = end_time
|
642 |
+
for event in score[i]:
|
643 |
+
if event[0] == 'control_change': # 6.5
|
644 |
+
cc_time = channel2cc_time.get(event[2]) or 0
|
645 |
+
if (event[1] <= start_time) and (event[1] >= cc_time):
|
646 |
+
channel2cc_num[event[2]] = event[3]
|
647 |
+
channel2cc_val[event[2]] = event[4]
|
648 |
+
channel2cc_time[event[2]] = event[1]
|
649 |
+
elif event[0] == 'patch_change':
|
650 |
+
patch_time = channel2patch_time.get(event[2]) or 0
|
651 |
+
if (event[1]<=start_time) and (event[1] >= patch_time): # 2.0
|
652 |
+
channel2patch_num[event[2]] = event[3]
|
653 |
+
channel2patch_time[event[2]] = event[1]
|
654 |
+
elif event[0] == 'set_tempo':
|
655 |
+
if (event[1]<=start_time) and (event[1]>=set_tempo_time): #6.4
|
656 |
+
set_tempo_num = event[2]
|
657 |
+
set_tempo_time = event[1]
|
658 |
+
if (event[1] >= start_time) and (event[1] <= end_time):
|
659 |
+
new_track.append(event)
|
660 |
+
if (event[0] == 'note') and (event[1] < earliest_note_time):
|
661 |
+
earliest_note_time = event[1]
|
662 |
+
if len(new_track) > 0:
|
663 |
+
new_track.append(['set_tempo', start_time, set_tempo_num])
|
664 |
+
for c in channel2patch_num:
|
665 |
+
new_track.append(['patch_change',start_time,c,channel2patch_num[c]],)
|
666 |
+
for c in channel2cc_num: # 6.5
|
667 |
+
new_track.append(['control_change',start_time,c,channel2cc_num[c],channel2cc_val[c]])
|
668 |
+
new_score.append(new_track)
|
669 |
+
i += 1
|
670 |
+
_clean_up_warnings()
|
671 |
+
return new_score
|
672 |
+
|
673 |
+
def score_type(opus_or_score=None):
|
674 |
+
r'''Returns a string, either 'opus' or 'score' or ''
|
675 |
+
'''
|
676 |
+
if opus_or_score == None or str(type(opus_or_score)).find('list')<0 or len(opus_or_score) < 2:
|
677 |
+
return ''
|
678 |
+
i = 1 # ignore first element
|
679 |
+
while i < len(opus_or_score):
|
680 |
+
for event in opus_or_score[i]:
|
681 |
+
if event[0] == 'note':
|
682 |
+
return 'score'
|
683 |
+
elif event[0] == 'note_on':
|
684 |
+
return 'opus'
|
685 |
+
i += 1
|
686 |
+
return ''
|
687 |
+
|
688 |
+
def concatenate_scores(scores):
|
689 |
+
r'''Concatenates a list of scores into one score.
|
690 |
+
If the scores differ in their "ticks" parameter,
|
691 |
+
they will all get converted to millisecond-tick format.
|
692 |
+
'''
|
693 |
+
# the deepcopys are needed if the input_score's are refs to the same obj
|
694 |
+
# e.g. if invoked by midisox's repeat()
|
695 |
+
input_scores = _consistentise_ticks(scores) # 3.7
|
696 |
+
output_score = copy.deepcopy(input_scores[0])
|
697 |
+
for input_score in input_scores[1:]:
|
698 |
+
output_stats = score2stats(output_score)
|
699 |
+
delta_ticks = output_stats['nticks']
|
700 |
+
itrack = 1
|
701 |
+
while itrack < len(input_score):
|
702 |
+
if itrack >= len(output_score): # new output track if doesn't exist
|
703 |
+
output_score.append([])
|
704 |
+
for event in input_score[itrack]:
|
705 |
+
output_score[itrack].append(copy.deepcopy(event))
|
706 |
+
output_score[itrack][-1][1] += delta_ticks
|
707 |
+
itrack += 1
|
708 |
+
return output_score
|
709 |
+
|
710 |
+
def merge_scores(scores):
|
711 |
+
r'''Merges a list of scores into one score. A merged score comprises
|
712 |
+
all of the tracks from all of the input scores; un-merging is possible
|
713 |
+
by selecting just some of the tracks. If the scores differ in their
|
714 |
+
"ticks" parameter, they will all get converted to millisecond-tick
|
715 |
+
format. merge_scores attempts to resolve channel-conflicts,
|
716 |
+
but there are of course only 15 available channels...
|
717 |
+
'''
|
718 |
+
input_scores = _consistentise_ticks(scores) # 3.6
|
719 |
+
output_score = [1000]
|
720 |
+
channels_so_far = set()
|
721 |
+
all_channels = {0,1,2,3,4,5,6,7,8,10,11,12,13,14,15}
|
722 |
+
global Event2channelindex
|
723 |
+
for input_score in input_scores:
|
724 |
+
new_channels = set(score2stats(input_score).get('channels_total', []))
|
725 |
+
new_channels.discard(9) # 2.8 cha9 must remain cha9 (in GM)
|
726 |
+
for channel in channels_so_far & new_channels:
|
727 |
+
# consistently choose lowest avaiable, to ease testing
|
728 |
+
free_channels = list(all_channels - (channels_so_far|new_channels))
|
729 |
+
if len(free_channels) > 0:
|
730 |
+
free_channels.sort()
|
731 |
+
free_channel = free_channels[0]
|
732 |
+
else:
|
733 |
+
free_channel = None
|
734 |
+
break
|
735 |
+
itrack = 1
|
736 |
+
while itrack < len(input_score):
|
737 |
+
for input_event in input_score[itrack]:
|
738 |
+
channel_index=Event2channelindex.get(input_event[0],False)
|
739 |
+
if channel_index and input_event[channel_index]==channel:
|
740 |
+
input_event[channel_index] = free_channel
|
741 |
+
itrack += 1
|
742 |
+
channels_so_far.add(free_channel)
|
743 |
+
|
744 |
+
channels_so_far |= new_channels
|
745 |
+
output_score.extend(input_score[1:])
|
746 |
+
return output_score
|
747 |
+
|
748 |
+
def _ticks(event):
|
749 |
+
return event[1]
|
750 |
+
def mix_opus_tracks(input_tracks): # 5.5
|
751 |
+
r'''Mixes an array of tracks into one track. A mixed track
|
752 |
+
cannot be un-mixed. It is assumed that the tracks share the same
|
753 |
+
ticks parameter and the same tempo.
|
754 |
+
Mixing score-tracks is trivial (just insert all events into one array).
|
755 |
+
Mixing opus-tracks is only slightly harder, but it's common enough
|
756 |
+
that a dedicated function is useful.
|
757 |
+
'''
|
758 |
+
output_score = [1000, []]
|
759 |
+
for input_track in input_tracks: # 5.8
|
760 |
+
input_score = opus2score([1000, input_track])
|
761 |
+
for event in input_score[1]:
|
762 |
+
output_score[1].append(event)
|
763 |
+
output_score[1].sort(key=_ticks)
|
764 |
+
output_opus = score2opus(output_score)
|
765 |
+
return output_opus[1]
|
766 |
+
|
767 |
+
def mix_scores(scores):
|
768 |
+
r'''Mixes a list of scores into one one-track score.
|
769 |
+
A mixed score cannot be un-mixed. Hopefully the scores
|
770 |
+
have no undesirable channel-conflicts between them.
|
771 |
+
If the scores differ in their "ticks" parameter,
|
772 |
+
they will all get converted to millisecond-tick format.
|
773 |
+
'''
|
774 |
+
input_scores = _consistentise_ticks(scores) # 3.6
|
775 |
+
output_score = [1000, []]
|
776 |
+
for input_score in input_scores:
|
777 |
+
for input_track in input_score[1:]:
|
778 |
+
output_score[1].extend(input_track)
|
779 |
+
return output_score
|
780 |
+
|
781 |
+
def score2stats(opus_or_score=None):
|
782 |
+
r'''Returns a dict of some basic stats about the score, like
|
783 |
+
bank_select (list of tuples (msb,lsb)),
|
784 |
+
channels_by_track (list of lists), channels_total (set),
|
785 |
+
general_midi_mode (list),
|
786 |
+
ntracks, nticks, patch_changes_by_track (list of dicts),
|
787 |
+
num_notes_by_channel (list of numbers),
|
788 |
+
patch_changes_total (set),
|
789 |
+
percussion (dict histogram of channel 9 events),
|
790 |
+
pitches (dict histogram of pitches on channels other than 9),
|
791 |
+
pitch_range_by_track (list, by track, of two-member-tuples),
|
792 |
+
pitch_range_sum (sum over tracks of the pitch_ranges),
|
793 |
+
'''
|
794 |
+
bank_select_msb = -1
|
795 |
+
bank_select_lsb = -1
|
796 |
+
bank_select = []
|
797 |
+
channels_by_track = []
|
798 |
+
channels_total = set([])
|
799 |
+
general_midi_mode = []
|
800 |
+
num_notes_by_channel = dict([])
|
801 |
+
patches_used_by_track = []
|
802 |
+
patches_used_total = set([])
|
803 |
+
patch_changes_by_track = []
|
804 |
+
patch_changes_total = set([])
|
805 |
+
percussion = dict([]) # histogram of channel 9 "pitches"
|
806 |
+
pitches = dict([]) # histogram of pitch-occurrences channels 0-8,10-15
|
807 |
+
pitch_range_sum = 0 # u pitch-ranges of each track
|
808 |
+
pitch_range_by_track = []
|
809 |
+
is_a_score = True
|
810 |
+
if opus_or_score == None:
|
811 |
+
return {'bank_select':[], 'channels_by_track':[], 'channels_total':[],
|
812 |
+
'general_midi_mode':[], 'ntracks':0, 'nticks':0,
|
813 |
+
'num_notes_by_channel':dict([]),
|
814 |
+
'patch_changes_by_track':[], 'patch_changes_total':[],
|
815 |
+
'percussion':{}, 'pitches':{}, 'pitch_range_by_track':[],
|
816 |
+
'ticks_per_quarter':0, 'pitch_range_sum':0}
|
817 |
+
ticks_per_quarter = opus_or_score[0]
|
818 |
+
i = 1 # ignore first element, which is ticks
|
819 |
+
nticks = 0
|
820 |
+
while i < len(opus_or_score):
|
821 |
+
highest_pitch = 0
|
822 |
+
lowest_pitch = 128
|
823 |
+
channels_this_track = set([])
|
824 |
+
patch_changes_this_track = dict({})
|
825 |
+
for event in opus_or_score[i]:
|
826 |
+
if event[0] == 'note':
|
827 |
+
num_notes_by_channel[event[3]] = num_notes_by_channel.get(event[3],0) + 1
|
828 |
+
if event[3] == 9:
|
829 |
+
percussion[event[4]] = percussion.get(event[4],0) + 1
|
830 |
+
else:
|
831 |
+
pitches[event[4]] = pitches.get(event[4],0) + 1
|
832 |
+
if event[4] > highest_pitch:
|
833 |
+
highest_pitch = event[4]
|
834 |
+
if event[4] < lowest_pitch:
|
835 |
+
lowest_pitch = event[4]
|
836 |
+
channels_this_track.add(event[3])
|
837 |
+
channels_total.add(event[3])
|
838 |
+
finish_time = event[1] + event[2]
|
839 |
+
if finish_time > nticks:
|
840 |
+
nticks = finish_time
|
841 |
+
elif event[0] == 'note_off' or (event[0] == 'note_on' and event[4] == 0): # 4.8
|
842 |
+
finish_time = event[1]
|
843 |
+
if finish_time > nticks:
|
844 |
+
nticks = finish_time
|
845 |
+
elif event[0] == 'note_on':
|
846 |
+
is_a_score = False
|
847 |
+
num_notes_by_channel[event[2]] = num_notes_by_channel.get(event[2],0) + 1
|
848 |
+
if event[2] == 9:
|
849 |
+
percussion[event[3]] = percussion.get(event[3],0) + 1
|
850 |
+
else:
|
851 |
+
pitches[event[3]] = pitches.get(event[3],0) + 1
|
852 |
+
if event[3] > highest_pitch:
|
853 |
+
highest_pitch = event[3]
|
854 |
+
if event[3] < lowest_pitch:
|
855 |
+
lowest_pitch = event[3]
|
856 |
+
channels_this_track.add(event[2])
|
857 |
+
channels_total.add(event[2])
|
858 |
+
elif event[0] == 'patch_change':
|
859 |
+
patch_changes_this_track[event[2]] = event[3]
|
860 |
+
patch_changes_total.add(event[3])
|
861 |
+
elif event[0] == 'control_change':
|
862 |
+
if event[3] == 0: # bank select MSB
|
863 |
+
bank_select_msb = event[4]
|
864 |
+
elif event[3] == 32: # bank select LSB
|
865 |
+
bank_select_lsb = event[4]
|
866 |
+
if bank_select_msb >= 0 and bank_select_lsb >= 0:
|
867 |
+
bank_select.append((bank_select_msb,bank_select_lsb))
|
868 |
+
bank_select_msb = -1
|
869 |
+
bank_select_lsb = -1
|
870 |
+
elif event[0] == 'sysex_f0':
|
871 |
+
if _sysex2midimode.get(event[2], -1) >= 0:
|
872 |
+
general_midi_mode.append(_sysex2midimode.get(event[2]))
|
873 |
+
if is_a_score:
|
874 |
+
if event[1] > nticks:
|
875 |
+
nticks = event[1]
|
876 |
+
else:
|
877 |
+
nticks += event[1]
|
878 |
+
if lowest_pitch == 128:
|
879 |
+
lowest_pitch = 0
|
880 |
+
channels_by_track.append(channels_this_track)
|
881 |
+
patch_changes_by_track.append(patch_changes_this_track)
|
882 |
+
pitch_range_by_track.append((lowest_pitch,highest_pitch))
|
883 |
+
pitch_range_sum += (highest_pitch-lowest_pitch)
|
884 |
+
i += 1
|
885 |
+
|
886 |
+
return {'bank_select':bank_select,
|
887 |
+
'channels_by_track':channels_by_track,
|
888 |
+
'channels_total':channels_total,
|
889 |
+
'general_midi_mode':general_midi_mode,
|
890 |
+
'ntracks':len(opus_or_score)-1,
|
891 |
+
'nticks':nticks,
|
892 |
+
'num_notes_by_channel':num_notes_by_channel,
|
893 |
+
'patch_changes_by_track':patch_changes_by_track,
|
894 |
+
'patch_changes_total':patch_changes_total,
|
895 |
+
'percussion':percussion,
|
896 |
+
'pitches':pitches,
|
897 |
+
'pitch_range_by_track':pitch_range_by_track,
|
898 |
+
'pitch_range_sum':pitch_range_sum,
|
899 |
+
'ticks_per_quarter':ticks_per_quarter}
|
900 |
+
|
901 |
+
#----------------------------- Event stuff --------------------------
|
902 |
+
|
903 |
+
_sysex2midimode = {
|
904 |
+
"\x7E\x7F\x09\x01\xF7": 1,
|
905 |
+
"\x7E\x7F\x09\x02\xF7": 0,
|
906 |
+
"\x7E\x7F\x09\x03\xF7": 2,
|
907 |
+
}
|
908 |
+
|
909 |
+
# Some public-access tuples:
|
910 |
+
MIDI_events = tuple('''note_off note_on key_after_touch
|
911 |
+
control_change patch_change channel_after_touch
|
912 |
+
pitch_wheel_change'''.split())
|
913 |
+
|
914 |
+
Text_events = tuple('''text_event copyright_text_event
|
915 |
+
track_name instrument_name lyric marker cue_point text_event_08
|
916 |
+
text_event_09 text_event_0a text_event_0b text_event_0c
|
917 |
+
text_event_0d text_event_0e text_event_0f'''.split())
|
918 |
+
|
919 |
+
Nontext_meta_events = tuple('''end_track set_tempo
|
920 |
+
smpte_offset time_signature key_signature sequencer_specific
|
921 |
+
raw_meta_event sysex_f0 sysex_f7 song_position song_select
|
922 |
+
tune_request'''.split())
|
923 |
+
# unsupported: raw_data
|
924 |
+
|
925 |
+
# Actually, 'tune_request' is is F-series event, not strictly a meta-event...
|
926 |
+
Meta_events = Text_events + Nontext_meta_events
|
927 |
+
All_events = MIDI_events + Meta_events
|
928 |
+
|
929 |
+
# And three dictionaries:
|
930 |
+
Number2patch = { # General MIDI patch numbers:
|
931 |
+
0:'Acoustic Grand',
|
932 |
+
1:'Bright Acoustic',
|
933 |
+
2:'Electric Grand',
|
934 |
+
3:'Honky-Tonk',
|
935 |
+
4:'Electric Piano 1',
|
936 |
+
5:'Electric Piano 2',
|
937 |
+
6:'Harpsichord',
|
938 |
+
7:'Clav',
|
939 |
+
8:'Celesta',
|
940 |
+
9:'Glockenspiel',
|
941 |
+
10:'Music Box',
|
942 |
+
11:'Vibraphone',
|
943 |
+
12:'Marimba',
|
944 |
+
13:'Xylophone',
|
945 |
+
14:'Tubular Bells',
|
946 |
+
15:'Dulcimer',
|
947 |
+
16:'Drawbar Organ',
|
948 |
+
17:'Percussive Organ',
|
949 |
+
18:'Rock Organ',
|
950 |
+
19:'Church Organ',
|
951 |
+
20:'Reed Organ',
|
952 |
+
21:'Accordion',
|
953 |
+
22:'Harmonica',
|
954 |
+
23:'Tango Accordion',
|
955 |
+
24:'Acoustic Guitar(nylon)',
|
956 |
+
25:'Acoustic Guitar(steel)',
|
957 |
+
26:'Electric Guitar(jazz)',
|
958 |
+
27:'Electric Guitar(clean)',
|
959 |
+
28:'Electric Guitar(muted)',
|
960 |
+
29:'Overdriven Guitar',
|
961 |
+
30:'Distortion Guitar',
|
962 |
+
31:'Guitar Harmonics',
|
963 |
+
32:'Acoustic Bass',
|
964 |
+
33:'Electric Bass(finger)',
|
965 |
+
34:'Electric Bass(pick)',
|
966 |
+
35:'Fretless Bass',
|
967 |
+
36:'Slap Bass 1',
|
968 |
+
37:'Slap Bass 2',
|
969 |
+
38:'Synth Bass 1',
|
970 |
+
39:'Synth Bass 2',
|
971 |
+
40:'Violin',
|
972 |
+
41:'Viola',
|
973 |
+
42:'Cello',
|
974 |
+
43:'Contrabass',
|
975 |
+
44:'Tremolo Strings',
|
976 |
+
45:'Pizzicato Strings',
|
977 |
+
46:'Orchestral Harp',
|
978 |
+
47:'Timpani',
|
979 |
+
48:'String Ensemble 1',
|
980 |
+
49:'String Ensemble 2',
|
981 |
+
50:'SynthStrings 1',
|
982 |
+
51:'SynthStrings 2',
|
983 |
+
52:'Choir Aahs',
|
984 |
+
53:'Voice Oohs',
|
985 |
+
54:'Synth Voice',
|
986 |
+
55:'Orchestra Hit',
|
987 |
+
56:'Trumpet',
|
988 |
+
57:'Trombone',
|
989 |
+
58:'Tuba',
|
990 |
+
59:'Muted Trumpet',
|
991 |
+
60:'French Horn',
|
992 |
+
61:'Brass Section',
|
993 |
+
62:'SynthBrass 1',
|
994 |
+
63:'SynthBrass 2',
|
995 |
+
64:'Soprano Sax',
|
996 |
+
65:'Alto Sax',
|
997 |
+
66:'Tenor Sax',
|
998 |
+
67:'Baritone Sax',
|
999 |
+
68:'Oboe',
|
1000 |
+
69:'English Horn',
|
1001 |
+
70:'Bassoon',
|
1002 |
+
71:'Clarinet',
|
1003 |
+
72:'Piccolo',
|
1004 |
+
73:'Flute',
|
1005 |
+
74:'Recorder',
|
1006 |
+
75:'Pan Flute',
|
1007 |
+
76:'Blown Bottle',
|
1008 |
+
77:'Skakuhachi',
|
1009 |
+
78:'Whistle',
|
1010 |
+
79:'Ocarina',
|
1011 |
+
80:'Lead 1 (square)',
|
1012 |
+
81:'Lead 2 (sawtooth)',
|
1013 |
+
82:'Lead 3 (calliope)',
|
1014 |
+
83:'Lead 4 (chiff)',
|
1015 |
+
84:'Lead 5 (charang)',
|
1016 |
+
85:'Lead 6 (voice)',
|
1017 |
+
86:'Lead 7 (fifths)',
|
1018 |
+
87:'Lead 8 (bass+lead)',
|
1019 |
+
88:'Pad 1 (new age)',
|
1020 |
+
89:'Pad 2 (warm)',
|
1021 |
+
90:'Pad 3 (polysynth)',
|
1022 |
+
91:'Pad 4 (choir)',
|
1023 |
+
92:'Pad 5 (bowed)',
|
1024 |
+
93:'Pad 6 (metallic)',
|
1025 |
+
94:'Pad 7 (halo)',
|
1026 |
+
95:'Pad 8 (sweep)',
|
1027 |
+
96:'FX 1 (rain)',
|
1028 |
+
97:'FX 2 (soundtrack)',
|
1029 |
+
98:'FX 3 (crystal)',
|
1030 |
+
99:'FX 4 (atmosphere)',
|
1031 |
+
100:'FX 5 (brightness)',
|
1032 |
+
101:'FX 6 (goblins)',
|
1033 |
+
102:'FX 7 (echoes)',
|
1034 |
+
103:'FX 8 (sci-fi)',
|
1035 |
+
104:'Sitar',
|
1036 |
+
105:'Banjo',
|
1037 |
+
106:'Shamisen',
|
1038 |
+
107:'Koto',
|
1039 |
+
108:'Kalimba',
|
1040 |
+
109:'Bagpipe',
|
1041 |
+
110:'Fiddle',
|
1042 |
+
111:'Shanai',
|
1043 |
+
112:'Tinkle Bell',
|
1044 |
+
113:'Agogo',
|
1045 |
+
114:'Steel Drums',
|
1046 |
+
115:'Woodblock',
|
1047 |
+
116:'Taiko Drum',
|
1048 |
+
117:'Melodic Tom',
|
1049 |
+
118:'Synth Drum',
|
1050 |
+
119:'Reverse Cymbal',
|
1051 |
+
120:'Guitar Fret Noise',
|
1052 |
+
121:'Breath Noise',
|
1053 |
+
122:'Seashore',
|
1054 |
+
123:'Bird Tweet',
|
1055 |
+
124:'Telephone Ring',
|
1056 |
+
125:'Helicopter',
|
1057 |
+
126:'Applause',
|
1058 |
+
127:'Gunshot',
|
1059 |
+
}
|
1060 |
+
Notenum2percussion = { # General MIDI Percussion (on Channel 9):
|
1061 |
+
35:'Acoustic Bass Drum',
|
1062 |
+
36:'Bass Drum 1',
|
1063 |
+
37:'Side Stick',
|
1064 |
+
38:'Acoustic Snare',
|
1065 |
+
39:'Hand Clap',
|
1066 |
+
40:'Electric Snare',
|
1067 |
+
41:'Low Floor Tom',
|
1068 |
+
42:'Closed Hi-Hat',
|
1069 |
+
43:'High Floor Tom',
|
1070 |
+
44:'Pedal Hi-Hat',
|
1071 |
+
45:'Low Tom',
|
1072 |
+
46:'Open Hi-Hat',
|
1073 |
+
47:'Low-Mid Tom',
|
1074 |
+
48:'Hi-Mid Tom',
|
1075 |
+
49:'Crash Cymbal 1',
|
1076 |
+
50:'High Tom',
|
1077 |
+
51:'Ride Cymbal 1',
|
1078 |
+
52:'Chinese Cymbal',
|
1079 |
+
53:'Ride Bell',
|
1080 |
+
54:'Tambourine',
|
1081 |
+
55:'Splash Cymbal',
|
1082 |
+
56:'Cowbell',
|
1083 |
+
57:'Crash Cymbal 2',
|
1084 |
+
58:'Vibraslap',
|
1085 |
+
59:'Ride Cymbal 2',
|
1086 |
+
60:'Hi Bongo',
|
1087 |
+
61:'Low Bongo',
|
1088 |
+
62:'Mute Hi Conga',
|
1089 |
+
63:'Open Hi Conga',
|
1090 |
+
64:'Low Conga',
|
1091 |
+
65:'High Timbale',
|
1092 |
+
66:'Low Timbale',
|
1093 |
+
67:'High Agogo',
|
1094 |
+
68:'Low Agogo',
|
1095 |
+
69:'Cabasa',
|
1096 |
+
70:'Maracas',
|
1097 |
+
71:'Short Whistle',
|
1098 |
+
72:'Long Whistle',
|
1099 |
+
73:'Short Guiro',
|
1100 |
+
74:'Long Guiro',
|
1101 |
+
75:'Claves',
|
1102 |
+
76:'Hi Wood Block',
|
1103 |
+
77:'Low Wood Block',
|
1104 |
+
78:'Mute Cuica',
|
1105 |
+
79:'Open Cuica',
|
1106 |
+
80:'Mute Triangle',
|
1107 |
+
81:'Open Triangle',
|
1108 |
+
}
|
1109 |
+
|
1110 |
+
Event2channelindex = { 'note':3, 'note_off':2, 'note_on':2,
|
1111 |
+
'key_after_touch':2, 'control_change':2, 'patch_change':2,
|
1112 |
+
'channel_after_touch':2, 'pitch_wheel_change':2
|
1113 |
+
}
|
1114 |
+
|
1115 |
+
################################################################
|
1116 |
+
# The code below this line is full of frightening things, all to
|
1117 |
+
# do with the actual encoding and decoding of binary MIDI data.
|
1118 |
+
|
1119 |
+
def _twobytes2int(byte_a):
|
1120 |
+
r'''decode a 16 bit quantity from two bytes,'''
|
1121 |
+
return (byte_a[1] | (byte_a[0] << 8))
|
1122 |
+
|
1123 |
+
def _int2twobytes(int_16bit):
|
1124 |
+
r'''encode a 16 bit quantity into two bytes,'''
|
1125 |
+
return bytes([(int_16bit>>8) & 0xFF, int_16bit & 0xFF])
|
1126 |
+
|
1127 |
+
def _read_14_bit(byte_a):
|
1128 |
+
r'''decode a 14 bit quantity from two bytes,'''
|
1129 |
+
return (byte_a[0] | (byte_a[1] << 7))
|
1130 |
+
|
1131 |
+
def _write_14_bit(int_14bit):
|
1132 |
+
r'''encode a 14 bit quantity into two bytes,'''
|
1133 |
+
return bytes([int_14bit & 0x7F, (int_14bit>>7) & 0x7F])
|
1134 |
+
|
1135 |
+
def _ber_compressed_int(integer):
|
1136 |
+
r'''BER compressed integer (not an ASN.1 BER, see perlpacktut for
|
1137 |
+
details). Its bytes represent an unsigned integer in base 128,
|
1138 |
+
most significant digit first, with as few digits as possible.
|
1139 |
+
Bit eight (the high bit) is set on each byte except the last.
|
1140 |
+
'''
|
1141 |
+
ber = bytearray(b'')
|
1142 |
+
seven_bits = 0x7F & integer
|
1143 |
+
ber.insert(0, seven_bits) # XXX surely should convert to a char ?
|
1144 |
+
integer >>= 7
|
1145 |
+
while integer > 0:
|
1146 |
+
seven_bits = 0x7F & integer
|
1147 |
+
ber.insert(0, 0x80|seven_bits) # XXX surely should convert to a char ?
|
1148 |
+
integer >>= 7
|
1149 |
+
return ber
|
1150 |
+
|
1151 |
+
def _unshift_ber_int(ba):
|
1152 |
+
r'''Given a bytearray, returns a tuple of (the ber-integer at the
|
1153 |
+
start, and the remainder of the bytearray).
|
1154 |
+
'''
|
1155 |
+
if not len(ba): # 6.7
|
1156 |
+
_warn('_unshift_ber_int: no integer found')
|
1157 |
+
return ((0, b""))
|
1158 |
+
byte = ba.pop(0)
|
1159 |
+
integer = 0
|
1160 |
+
while True:
|
1161 |
+
integer += (byte & 0x7F)
|
1162 |
+
if not (byte & 0x80):
|
1163 |
+
return ((integer, ba))
|
1164 |
+
if not len(ba):
|
1165 |
+
_warn('_unshift_ber_int: no end-of-integer found')
|
1166 |
+
return ((0, ba))
|
1167 |
+
byte = ba.pop(0)
|
1168 |
+
integer <<= 7
|
1169 |
+
|
1170 |
+
def _clean_up_warnings(): # 5.4
|
1171 |
+
# Call this before returning from any publicly callable function
|
1172 |
+
# whenever there's a possibility that a warning might have been printed
|
1173 |
+
# by the function, or by any private functions it might have called.
|
1174 |
+
if _no_warning:
|
1175 |
+
return
|
1176 |
+
global _previous_times
|
1177 |
+
global _previous_warning
|
1178 |
+
if _previous_times > 1:
|
1179 |
+
# E:1176, 0: invalid syntax (<string>, line 1176) (syntax-error) ???
|
1180 |
+
# print(' previous message repeated '+str(_previous_times)+' times', file=sys.stderr)
|
1181 |
+
# 6.7
|
1182 |
+
sys.stderr.write(' previous message repeated {0} times\n'.format(_previous_times))
|
1183 |
+
elif _previous_times > 0:
|
1184 |
+
sys.stderr.write(' previous message repeated\n')
|
1185 |
+
_previous_times = 0
|
1186 |
+
_previous_warning = ''
|
1187 |
+
|
1188 |
+
def _warn(s=''):
|
1189 |
+
if _no_warning:
|
1190 |
+
return
|
1191 |
+
global _previous_times
|
1192 |
+
global _previous_warning
|
1193 |
+
if s == _previous_warning: # 5.4
|
1194 |
+
_previous_times = _previous_times + 1
|
1195 |
+
else:
|
1196 |
+
_clean_up_warnings()
|
1197 |
+
sys.stderr.write(str(s)+"\n")
|
1198 |
+
_previous_warning = s
|
1199 |
+
|
1200 |
+
def _some_text_event(which_kind=0x01, text=b'some_text'):
|
1201 |
+
if str(type(text)).find("'str'") >= 0: # 6.4 test for back-compatibility
|
1202 |
+
data = bytes(text, encoding='ISO-8859-1')
|
1203 |
+
else:
|
1204 |
+
data = bytes(text)
|
1205 |
+
return b'\xFF'+bytes((which_kind,))+_ber_compressed_int(len(data))+data
|
1206 |
+
|
1207 |
+
def _consistentise_ticks(scores): # 3.6
|
1208 |
+
# used by mix_scores, merge_scores, concatenate_scores
|
1209 |
+
if len(scores) == 1:
|
1210 |
+
return copy.deepcopy(scores)
|
1211 |
+
are_consistent = True
|
1212 |
+
ticks = scores[0][0]
|
1213 |
+
iscore = 1
|
1214 |
+
while iscore < len(scores):
|
1215 |
+
if scores[iscore][0] != ticks:
|
1216 |
+
are_consistent = False
|
1217 |
+
break
|
1218 |
+
iscore += 1
|
1219 |
+
if are_consistent:
|
1220 |
+
return copy.deepcopy(scores)
|
1221 |
+
new_scores = []
|
1222 |
+
iscore = 0
|
1223 |
+
while iscore < len(scores):
|
1224 |
+
score = scores[iscore]
|
1225 |
+
new_scores.append(opus2score(to_millisecs(score2opus(score))))
|
1226 |
+
iscore += 1
|
1227 |
+
return new_scores
|
1228 |
+
|
1229 |
+
|
1230 |
+
###########################################################################
|
1231 |
+
|
1232 |
+
def _decode(trackdata=b'', exclude=None, include=None,
|
1233 |
+
event_callback=None, exclusive_event_callback=None, no_eot_magic=False):
|
1234 |
+
r'''Decodes MIDI track data into an opus-style list of events.
|
1235 |
+
The options:
|
1236 |
+
'exclude' is a list of event types which will be ignored SHOULD BE A SET
|
1237 |
+
'include' (and no exclude), makes exclude a list
|
1238 |
+
of all possible events, /minus/ what include specifies
|
1239 |
+
'event_callback' is a coderef
|
1240 |
+
'exclusive_event_callback' is a coderef
|
1241 |
+
'''
|
1242 |
+
trackdata = bytearray(trackdata)
|
1243 |
+
if exclude == None:
|
1244 |
+
exclude = []
|
1245 |
+
if include == None:
|
1246 |
+
include = []
|
1247 |
+
if include and not exclude:
|
1248 |
+
exclude = All_events
|
1249 |
+
include = set(include)
|
1250 |
+
exclude = set(exclude)
|
1251 |
+
|
1252 |
+
# Pointer = 0; not used here; we eat through the bytearray instead.
|
1253 |
+
event_code = -1; # used for running status
|
1254 |
+
event_count = 0;
|
1255 |
+
events = []
|
1256 |
+
|
1257 |
+
while(len(trackdata)):
|
1258 |
+
# loop while there's anything to analyze ...
|
1259 |
+
eot = False # When True, the event registrar aborts this loop
|
1260 |
+
event_count += 1
|
1261 |
+
|
1262 |
+
E = []
|
1263 |
+
# E for events - we'll feed it to the event registrar at the end.
|
1264 |
+
|
1265 |
+
# Slice off the delta time code, and analyze it
|
1266 |
+
[time, remainder] = _unshift_ber_int(trackdata)
|
1267 |
+
|
1268 |
+
# Now let's see what we can make of the command
|
1269 |
+
first_byte = trackdata.pop(0) & 0xFF
|
1270 |
+
|
1271 |
+
if (first_byte < 0xF0): # It's a MIDI event
|
1272 |
+
if (first_byte & 0x80):
|
1273 |
+
event_code = first_byte
|
1274 |
+
else:
|
1275 |
+
# It wants running status; use last event_code value
|
1276 |
+
trackdata.insert(0, first_byte)
|
1277 |
+
if (event_code == -1):
|
1278 |
+
_warn("Running status not set; Aborting track.")
|
1279 |
+
return []
|
1280 |
+
|
1281 |
+
command = event_code & 0xF0
|
1282 |
+
channel = event_code & 0x0F
|
1283 |
+
|
1284 |
+
if (command == 0xF6): # 0-byte argument
|
1285 |
+
pass
|
1286 |
+
elif (command == 0xC0 or command == 0xD0): # 1-byte argument
|
1287 |
+
parameter = trackdata.pop(0) # could be B
|
1288 |
+
else: # 2-byte argument could be BB or 14-bit
|
1289 |
+
parameter = (trackdata.pop(0), trackdata.pop(0))
|
1290 |
+
|
1291 |
+
#################################################################
|
1292 |
+
# MIDI events
|
1293 |
+
|
1294 |
+
if (command == 0x80):
|
1295 |
+
if 'note_off' in exclude:
|
1296 |
+
continue
|
1297 |
+
E = ['note_off', time, channel, parameter[0], parameter[1]]
|
1298 |
+
elif (command == 0x90):
|
1299 |
+
if 'note_on' in exclude:
|
1300 |
+
continue
|
1301 |
+
E = ['note_on', time, channel, parameter[0], parameter[1]]
|
1302 |
+
elif (command == 0xA0):
|
1303 |
+
if 'key_after_touch' in exclude:
|
1304 |
+
continue
|
1305 |
+
E = ['key_after_touch',time,channel,parameter[0],parameter[1]]
|
1306 |
+
elif (command == 0xB0):
|
1307 |
+
if 'control_change' in exclude:
|
1308 |
+
continue
|
1309 |
+
E = ['control_change',time,channel,parameter[0],parameter[1]]
|
1310 |
+
elif (command == 0xC0):
|
1311 |
+
if 'patch_change' in exclude:
|
1312 |
+
continue
|
1313 |
+
E = ['patch_change', time, channel, parameter]
|
1314 |
+
elif (command == 0xD0):
|
1315 |
+
if 'channel_after_touch' in exclude:
|
1316 |
+
continue
|
1317 |
+
E = ['channel_after_touch', time, channel, parameter]
|
1318 |
+
elif (command == 0xE0):
|
1319 |
+
if 'pitch_wheel_change' in exclude:
|
1320 |
+
continue
|
1321 |
+
E = ['pitch_wheel_change', time, channel,
|
1322 |
+
_read_14_bit(parameter)-0x2000]
|
1323 |
+
else:
|
1324 |
+
_warn("Shouldn't get here; command="+hex(command))
|
1325 |
+
|
1326 |
+
elif (first_byte == 0xFF): # It's a Meta-Event! ##################
|
1327 |
+
#[command, length, remainder] =
|
1328 |
+
# unpack("xCwa*", substr(trackdata, $Pointer, 6));
|
1329 |
+
#Pointer += 6 - len(remainder);
|
1330 |
+
# # Move past JUST the length-encoded.
|
1331 |
+
command = trackdata.pop(0) & 0xFF
|
1332 |
+
[length, trackdata] = _unshift_ber_int(trackdata)
|
1333 |
+
if (command == 0x00):
|
1334 |
+
if (length == 2):
|
1335 |
+
E = ['set_sequence_number',time,_twobytes2int(trackdata)]
|
1336 |
+
else:
|
1337 |
+
_warn('set_sequence_number: length must be 2, not '+str(length))
|
1338 |
+
E = ['set_sequence_number', time, 0]
|
1339 |
+
|
1340 |
+
elif command >= 0x01 and command <= 0x0f: # Text events
|
1341 |
+
# 6.2 take it in bytes; let the user get the right encoding.
|
1342 |
+
# text_str = trackdata[0:length].decode('ascii','ignore')
|
1343 |
+
# text_str = trackdata[0:length].decode('ISO-8859-1')
|
1344 |
+
# 6.4 take it in bytes; let the user get the right encoding.
|
1345 |
+
text_data = bytes(trackdata[0:length]) # 6.4
|
1346 |
+
# Defined text events
|
1347 |
+
if (command == 0x01):
|
1348 |
+
E = ['text_event', time, text_data]
|
1349 |
+
elif (command == 0x02):
|
1350 |
+
E = ['copyright_text_event', time, text_data]
|
1351 |
+
elif (command == 0x03):
|
1352 |
+
E = ['track_name', time, text_data]
|
1353 |
+
elif (command == 0x04):
|
1354 |
+
E = ['instrument_name', time, text_data]
|
1355 |
+
elif (command == 0x05):
|
1356 |
+
E = ['lyric', time, text_data]
|
1357 |
+
elif (command == 0x06):
|
1358 |
+
E = ['marker', time, text_data]
|
1359 |
+
elif (command == 0x07):
|
1360 |
+
E = ['cue_point', time, text_data]
|
1361 |
+
# Reserved but apparently unassigned text events
|
1362 |
+
elif (command == 0x08):
|
1363 |
+
E = ['text_event_08', time, text_data]
|
1364 |
+
elif (command == 0x09):
|
1365 |
+
E = ['text_event_09', time, text_data]
|
1366 |
+
elif (command == 0x0a):
|
1367 |
+
E = ['text_event_0a', time, text_data]
|
1368 |
+
elif (command == 0x0b):
|
1369 |
+
E = ['text_event_0b', time, text_data]
|
1370 |
+
elif (command == 0x0c):
|
1371 |
+
E = ['text_event_0c', time, text_data]
|
1372 |
+
elif (command == 0x0d):
|
1373 |
+
E = ['text_event_0d', time, text_data]
|
1374 |
+
elif (command == 0x0e):
|
1375 |
+
E = ['text_event_0e', time, text_data]
|
1376 |
+
elif (command == 0x0f):
|
1377 |
+
E = ['text_event_0f', time, text_data]
|
1378 |
+
|
1379 |
+
# Now the sticky events -------------------------------------
|
1380 |
+
elif (command == 0x2F):
|
1381 |
+
E = ['end_track', time]
|
1382 |
+
# The code for handling this, oddly, comes LATER,
|
1383 |
+
# in the event registrar.
|
1384 |
+
elif (command == 0x51): # DTime, Microseconds/Crochet
|
1385 |
+
if length != 3:
|
1386 |
+
_warn('set_tempo event, but length='+str(length))
|
1387 |
+
E = ['set_tempo', time,
|
1388 |
+
struct.unpack(">I", b'\x00'+trackdata[0:3])[0]]
|
1389 |
+
elif (command == 0x54):
|
1390 |
+
if length != 5: # DTime, HR, MN, SE, FR, FF
|
1391 |
+
_warn('smpte_offset event, but length='+str(length))
|
1392 |
+
E = ['smpte_offset',time] + list(struct.unpack(">BBBBB",trackdata[0:5]))
|
1393 |
+
elif (command == 0x58):
|
1394 |
+
if length != 4: # DTime, NN, DD, CC, BB
|
1395 |
+
_warn('time_signature event, but length='+str(length))
|
1396 |
+
E = ['time_signature', time]+list(trackdata[0:4])
|
1397 |
+
elif (command == 0x59):
|
1398 |
+
if length != 2: # DTime, SF(signed), MI
|
1399 |
+
_warn('key_signature event, but length='+str(length))
|
1400 |
+
E = ['key_signature',time] + list(struct.unpack(">bB",trackdata[0:2]))
|
1401 |
+
elif (command == 0x7F): # 6.4
|
1402 |
+
E = ['sequencer_specific',time, bytes(trackdata[0:length])]
|
1403 |
+
else:
|
1404 |
+
E = ['raw_meta_event', time, command,
|
1405 |
+
bytes(trackdata[0:length])] # 6.0
|
1406 |
+
#"[uninterpretable meta-event command of length length]"
|
1407 |
+
# DTime, Command, Binary Data
|
1408 |
+
# It's uninterpretable; record it as raw_data.
|
1409 |
+
|
1410 |
+
# Pointer += length; # Now move Pointer
|
1411 |
+
trackdata = trackdata[length:]
|
1412 |
+
|
1413 |
+
######################################################################
|
1414 |
+
elif (first_byte == 0xF0 or first_byte == 0xF7):
|
1415 |
+
# Note that sysexes in MIDI /files/ are different than sysexes
|
1416 |
+
# in MIDI transmissions!! The vast majority of system exclusive
|
1417 |
+
# messages will just use the F0 format. For instance, the
|
1418 |
+
# transmitted message F0 43 12 00 07 F7 would be stored in a
|
1419 |
+
# MIDI file as F0 05 43 12 00 07 F7. As mentioned above, it is
|
1420 |
+
# required to include the F7 at the end so that the reader of the
|
1421 |
+
# MIDI file knows that it has read the entire message. (But the F7
|
1422 |
+
# is omitted if this is a non-final block in a multiblock sysex;
|
1423 |
+
# but the F7 (if there) is counted in the message's declared
|
1424 |
+
# length, so we don't have to think about it anyway.)
|
1425 |
+
#command = trackdata.pop(0)
|
1426 |
+
[length, trackdata] = _unshift_ber_int(trackdata)
|
1427 |
+
if first_byte == 0xF0:
|
1428 |
+
# 20091008 added ISO-8859-1 to get an 8-bit str
|
1429 |
+
# 6.4 return bytes instead
|
1430 |
+
E = ['sysex_f0', time, bytes(trackdata[0:length])]
|
1431 |
+
else:
|
1432 |
+
E = ['sysex_f7', time, bytes(trackdata[0:length])]
|
1433 |
+
trackdata = trackdata[length:]
|
1434 |
+
|
1435 |
+
######################################################################
|
1436 |
+
# Now, the MIDI file spec says:
|
1437 |
+
# <track data> = <MTrk event>+
|
1438 |
+
# <MTrk event> = <delta-time> <event>
|
1439 |
+
# <event> = <MIDI event> | <sysex event> | <meta-event>
|
1440 |
+
# I know that, on the wire, <MIDI event> can include note_on,
|
1441 |
+
# note_off, and all the other 8x to Ex events, AND Fx events
|
1442 |
+
# other than F0, F7, and FF -- namely, <song position msg>,
|
1443 |
+
# <song select msg>, and <tune request>.
|
1444 |
+
#
|
1445 |
+
# Whether these can occur in MIDI files is not clear specified
|
1446 |
+
# from the MIDI file spec. So, I'm going to assume that
|
1447 |
+
# they CAN, in practice, occur. I don't know whether it's
|
1448 |
+
# proper for you to actually emit these into a MIDI file.
|
1449 |
+
|
1450 |
+
elif (first_byte == 0xF2): # DTime, Beats
|
1451 |
+
# <song position msg> ::= F2 <data pair>
|
1452 |
+
E = ['song_position', time, _read_14_bit(trackdata[:2])]
|
1453 |
+
trackdata = trackdata[2:]
|
1454 |
+
|
1455 |
+
elif (first_byte == 0xF3): # <song select msg> ::= F3 <data singlet>
|
1456 |
+
# E = ['song_select', time, struct.unpack('>B',trackdata.pop(0))[0]]
|
1457 |
+
E = ['song_select', time, trackdata[0]]
|
1458 |
+
trackdata = trackdata[1:]
|
1459 |
+
# DTime, Thing (what?! song number? whatever ...)
|
1460 |
+
|
1461 |
+
elif (first_byte == 0xF6): # DTime
|
1462 |
+
E = ['tune_request', time]
|
1463 |
+
# What would a tune request be doing in a MIDI /file/?
|
1464 |
+
|
1465 |
+
#########################################################
|
1466 |
+
# ADD MORE META-EVENTS HERE. TODO:
|
1467 |
+
# f1 -- MTC Quarter Frame Message. One data byte follows
|
1468 |
+
# the Status; it's the time code value, from 0 to 127.
|
1469 |
+
# f8 -- MIDI clock. no data.
|
1470 |
+
# fa -- MIDI start. no data.
|
1471 |
+
# fb -- MIDI continue. no data.
|
1472 |
+
# fc -- MIDI stop. no data.
|
1473 |
+
# fe -- Active sense. no data.
|
1474 |
+
# f4 f5 f9 fd -- unallocated
|
1475 |
+
|
1476 |
+
r'''
|
1477 |
+
elif (first_byte > 0xF0) { # Some unknown kinda F-series event ####
|
1478 |
+
# Here we only produce a one-byte piece of raw data.
|
1479 |
+
# But the encoder for 'raw_data' accepts any length of it.
|
1480 |
+
E = [ 'raw_data',
|
1481 |
+
time, substr(trackdata,Pointer,1) ]
|
1482 |
+
# DTime and the Data (in this case, the one Event-byte)
|
1483 |
+
++Pointer; # itself
|
1484 |
+
|
1485 |
+
'''
|
1486 |
+
elif first_byte > 0xF0: # Some unknown F-series event
|
1487 |
+
# Here we only produce a one-byte piece of raw data.
|
1488 |
+
# E = ['raw_data', time, bytest(trackdata[0])] # 6.4
|
1489 |
+
E = ['raw_data', time, trackdata[0]] # 6.4 6.7
|
1490 |
+
trackdata = trackdata[1:]
|
1491 |
+
else: # Fallthru.
|
1492 |
+
_warn("Aborting track. Command-byte first_byte="+hex(first_byte))
|
1493 |
+
break
|
1494 |
+
# End of the big if-group
|
1495 |
+
|
1496 |
+
|
1497 |
+
######################################################################
|
1498 |
+
# THE EVENT REGISTRAR...
|
1499 |
+
if E and (E[0] == 'end_track'):
|
1500 |
+
# This is the code for exceptional handling of the EOT event.
|
1501 |
+
eot = True
|
1502 |
+
if not no_eot_magic:
|
1503 |
+
if E[1] > 0: # a null text-event to carry the delta-time
|
1504 |
+
E = ['text_event', E[1], '']
|
1505 |
+
else:
|
1506 |
+
E = [] # EOT with a delta-time of 0; ignore it.
|
1507 |
+
|
1508 |
+
if E and not (E[0] in exclude):
|
1509 |
+
#if ( $exclusive_event_callback ):
|
1510 |
+
# &{ $exclusive_event_callback }( @E );
|
1511 |
+
#else:
|
1512 |
+
# &{ $event_callback }( @E ) if $event_callback;
|
1513 |
+
events.append(E)
|
1514 |
+
if eot:
|
1515 |
+
break
|
1516 |
+
|
1517 |
+
# End of the big "Event" while-block
|
1518 |
+
|
1519 |
+
return events
|
1520 |
+
|
1521 |
+
|
1522 |
+
###########################################################################
|
1523 |
+
def _encode(events_lol, unknown_callback=None, never_add_eot=False,
|
1524 |
+
no_eot_magic=False, no_running_status=False):
|
1525 |
+
# encode an event structure, presumably for writing to a file
|
1526 |
+
# Calling format:
|
1527 |
+
# $data_r = MIDI::Event::encode( \@event_lol, { options } );
|
1528 |
+
# Takes a REFERENCE to an event structure (a LoL)
|
1529 |
+
# Returns an (unblessed) REFERENCE to track data.
|
1530 |
+
|
1531 |
+
# If you want to use this to encode a /single/ event,
|
1532 |
+
# you still have to do it as a reference to an event structure (a LoL)
|
1533 |
+
# that just happens to have just one event. I.e.,
|
1534 |
+
# encode( [ $event ] ) or encode( [ [ 'note_on', 100, 5, 42, 64] ] )
|
1535 |
+
# If you're doing this, consider the never_add_eot track option, as in
|
1536 |
+
# print MIDI ${ encode( [ $event], { 'never_add_eot' => 1} ) };
|
1537 |
+
|
1538 |
+
data = [] # what I'll store the chunks of byte-data in
|
1539 |
+
|
1540 |
+
# This is so my end_track magic won't corrupt the original
|
1541 |
+
events = copy.deepcopy(events_lol)
|
1542 |
+
|
1543 |
+
if not never_add_eot:
|
1544 |
+
# One way or another, tack on an 'end_track'
|
1545 |
+
if events:
|
1546 |
+
last = events[-1]
|
1547 |
+
if not (last[0] == 'end_track'): # no end_track already
|
1548 |
+
if (last[0] == 'text_event' and len(last[2]) == 0):
|
1549 |
+
# 0-length text event at track-end.
|
1550 |
+
if no_eot_magic:
|
1551 |
+
# Exceptional case: don't mess with track-final
|
1552 |
+
# 0-length text_events; just peg on an end_track
|
1553 |
+
events.append(['end_track', 0])
|
1554 |
+
else:
|
1555 |
+
# NORMAL CASE: replace with an end_track, leaving DTime
|
1556 |
+
last[0] = 'end_track'
|
1557 |
+
else:
|
1558 |
+
# last event was neither 0-length text_event nor end_track
|
1559 |
+
events.append(['end_track', 0])
|
1560 |
+
else: # an eventless track!
|
1561 |
+
events = [['end_track', 0],]
|
1562 |
+
|
1563 |
+
# maybe_running_status = not no_running_status # unused? 4.7
|
1564 |
+
last_status = -1
|
1565 |
+
|
1566 |
+
for event_r in (events):
|
1567 |
+
E = copy.deepcopy(event_r)
|
1568 |
+
# otherwise the shifting'd corrupt the original
|
1569 |
+
if not E:
|
1570 |
+
continue
|
1571 |
+
|
1572 |
+
event = E.pop(0)
|
1573 |
+
if not len(event):
|
1574 |
+
continue
|
1575 |
+
|
1576 |
+
dtime = int(E.pop(0))
|
1577 |
+
# print('event='+str(event)+' dtime='+str(dtime))
|
1578 |
+
|
1579 |
+
event_data = ''
|
1580 |
+
|
1581 |
+
if ( # MIDI events -- eligible for running status
|
1582 |
+
event == 'note_on'
|
1583 |
+
or event == 'note_off'
|
1584 |
+
or event == 'control_change'
|
1585 |
+
or event == 'key_after_touch'
|
1586 |
+
or event == 'patch_change'
|
1587 |
+
or event == 'channel_after_touch'
|
1588 |
+
or event == 'pitch_wheel_change' ):
|
1589 |
+
|
1590 |
+
# This block is where we spend most of the time. Gotta be tight.
|
1591 |
+
if (event == 'note_off'):
|
1592 |
+
status = 0x80 | (int(E[0]) & 0x0F)
|
1593 |
+
parameters = struct.pack('>BB', int(E[1])&0x7F, int(E[2])&0x7F)
|
1594 |
+
elif (event == 'note_on'):
|
1595 |
+
status = 0x90 | (int(E[0]) & 0x0F)
|
1596 |
+
parameters = struct.pack('>BB', int(E[1])&0x7F, int(E[2])&0x7F)
|
1597 |
+
elif (event == 'key_after_touch'):
|
1598 |
+
status = 0xA0 | (int(E[0]) & 0x0F)
|
1599 |
+
parameters = struct.pack('>BB', int(E[1])&0x7F, int(E[2])&0x7F)
|
1600 |
+
elif (event == 'control_change'):
|
1601 |
+
status = 0xB0 | (int(E[0]) & 0x0F)
|
1602 |
+
parameters = struct.pack('>BB', int(E[1])&0xFF, int(E[2])&0xFF)
|
1603 |
+
elif (event == 'patch_change'):
|
1604 |
+
status = 0xC0 | (int(E[0]) & 0x0F)
|
1605 |
+
parameters = struct.pack('>B', int(E[1]) & 0xFF)
|
1606 |
+
elif (event == 'channel_after_touch'):
|
1607 |
+
status = 0xD0 | (int(E[0]) & 0x0F)
|
1608 |
+
parameters = struct.pack('>B', int(E[1]) & 0xFF)
|
1609 |
+
elif (event == 'pitch_wheel_change'):
|
1610 |
+
status = 0xE0 | (int(E[0]) & 0x0F)
|
1611 |
+
parameters = _write_14_bit(int(E[1]) + 0x2000)
|
1612 |
+
else:
|
1613 |
+
_warn("BADASS FREAKOUT ERROR 31415!")
|
1614 |
+
|
1615 |
+
# And now the encoding
|
1616 |
+
# w = BER compressed integer (not ASN.1 BER, see perlpacktut for
|
1617 |
+
# details). Its bytes represent an unsigned integer in base 128,
|
1618 |
+
# most significant digit first, with as few digits as possible.
|
1619 |
+
# Bit eight (the high bit) is set on each byte except the last.
|
1620 |
+
|
1621 |
+
data.append(_ber_compressed_int(dtime))
|
1622 |
+
if (status != last_status) or no_running_status:
|
1623 |
+
data.append(struct.pack('>B', status))
|
1624 |
+
data.append(parameters)
|
1625 |
+
|
1626 |
+
last_status = status
|
1627 |
+
continue
|
1628 |
+
else:
|
1629 |
+
# Not a MIDI event.
|
1630 |
+
# All the code in this block could be more efficient,
|
1631 |
+
# but this is not where the code needs to be tight.
|
1632 |
+
# print "zaz $event\n";
|
1633 |
+
last_status = -1
|
1634 |
+
|
1635 |
+
if event == 'raw_meta_event':
|
1636 |
+
event_data = _some_text_event(int(E[0]), E[1])
|
1637 |
+
elif (event == 'set_sequence_number'): # 3.9
|
1638 |
+
event_data = b'\xFF\x00\x02'+_int2twobytes(E[0])
|
1639 |
+
|
1640 |
+
# Text meta-events...
|
1641 |
+
# a case for a dict, I think (pjb) ...
|
1642 |
+
elif (event == 'text_event'):
|
1643 |
+
event_data = _some_text_event(0x01, E[0])
|
1644 |
+
elif (event == 'copyright_text_event'):
|
1645 |
+
event_data = _some_text_event(0x02, E[0])
|
1646 |
+
elif (event == 'track_name'):
|
1647 |
+
event_data = _some_text_event(0x03, E[0])
|
1648 |
+
elif (event == 'instrument_name'):
|
1649 |
+
event_data = _some_text_event(0x04, E[0])
|
1650 |
+
elif (event == 'lyric'):
|
1651 |
+
event_data = _some_text_event(0x05, E[0])
|
1652 |
+
elif (event == 'marker'):
|
1653 |
+
event_data = _some_text_event(0x06, E[0])
|
1654 |
+
elif (event == 'cue_point'):
|
1655 |
+
event_data = _some_text_event(0x07, E[0])
|
1656 |
+
elif (event == 'text_event_08'):
|
1657 |
+
event_data = _some_text_event(0x08, E[0])
|
1658 |
+
elif (event == 'text_event_09'):
|
1659 |
+
event_data = _some_text_event(0x09, E[0])
|
1660 |
+
elif (event == 'text_event_0a'):
|
1661 |
+
event_data = _some_text_event(0x0A, E[0])
|
1662 |
+
elif (event == 'text_event_0b'):
|
1663 |
+
event_data = _some_text_event(0x0B, E[0])
|
1664 |
+
elif (event == 'text_event_0c'):
|
1665 |
+
event_data = _some_text_event(0x0C, E[0])
|
1666 |
+
elif (event == 'text_event_0d'):
|
1667 |
+
event_data = _some_text_event(0x0D, E[0])
|
1668 |
+
elif (event == 'text_event_0e'):
|
1669 |
+
event_data = _some_text_event(0x0E, E[0])
|
1670 |
+
elif (event == 'text_event_0f'):
|
1671 |
+
event_data = _some_text_event(0x0F, E[0])
|
1672 |
+
# End of text meta-events
|
1673 |
+
|
1674 |
+
elif (event == 'end_track'):
|
1675 |
+
event_data = b"\xFF\x2F\x00"
|
1676 |
+
|
1677 |
+
elif (event == 'set_tempo'):
|
1678 |
+
#event_data = struct.pack(">BBwa*", 0xFF, 0x51, 3,
|
1679 |
+
# substr( struct.pack('>I', E[0]), 1, 3))
|
1680 |
+
event_data = b'\xFF\x51\x03'+struct.pack('>I',E[0])[1:]
|
1681 |
+
elif (event == 'smpte_offset'):
|
1682 |
+
# event_data = struct.pack(">BBwBBBBB", 0xFF, 0x54, 5, E[0:5] )
|
1683 |
+
event_data = struct.pack(">BBBbBBBB", 0xFF,0x54,0x05,E[0],E[1],E[2],E[3],E[4])
|
1684 |
+
elif (event == 'time_signature'):
|
1685 |
+
# event_data = struct.pack(">BBwBBBB", 0xFF, 0x58, 4, E[0:4] )
|
1686 |
+
event_data = struct.pack(">BBBbBBB", 0xFF, 0x58, 0x04, E[0],E[1],E[2],E[3])
|
1687 |
+
elif (event == 'key_signature'):
|
1688 |
+
event_data = struct.pack(">BBBbB", 0xFF, 0x59, 0x02, E[0],E[1])
|
1689 |
+
elif (event == 'sequencer_specific'):
|
1690 |
+
# event_data = struct.pack(">BBwa*", 0xFF,0x7F, len(E[0]), E[0])
|
1691 |
+
event_data = _some_text_event(0x7F, E[0])
|
1692 |
+
# End of Meta-events
|
1693 |
+
|
1694 |
+
# Other Things...
|
1695 |
+
elif (event == 'sysex_f0'):
|
1696 |
+
#event_data = struct.pack(">Bwa*", 0xF0, len(E[0]), E[0])
|
1697 |
+
#B=bitstring w=BER-compressed-integer a=null-padded-ascii-str
|
1698 |
+
event_data = bytearray(b'\xF0')+_ber_compressed_int(len(E[0]))+bytearray(E[0])
|
1699 |
+
elif (event == 'sysex_f7'):
|
1700 |
+
#event_data = struct.pack(">Bwa*", 0xF7, len(E[0]), E[0])
|
1701 |
+
event_data = bytearray(b'\xF7')+_ber_compressed_int(len(E[0]))+bytearray(E[0])
|
1702 |
+
|
1703 |
+
elif (event == 'song_position'):
|
1704 |
+
event_data = b"\xF2" + _write_14_bit( E[0] )
|
1705 |
+
elif (event == 'song_select'):
|
1706 |
+
event_data = struct.pack('>BB', 0xF3, E[0] )
|
1707 |
+
elif (event == 'tune_request'):
|
1708 |
+
event_data = b"\xF6"
|
1709 |
+
elif (event == 'raw_data'):
|
1710 |
+
_warn("_encode: raw_data event not supported")
|
1711 |
+
# event_data = E[0]
|
1712 |
+
continue
|
1713 |
+
# End of Other Stuff
|
1714 |
+
|
1715 |
+
else:
|
1716 |
+
# The Big Fallthru
|
1717 |
+
if unknown_callback:
|
1718 |
+
# push(@data, &{ $unknown_callback }( @$event_r ))
|
1719 |
+
pass
|
1720 |
+
else:
|
1721 |
+
_warn("Unknown event: "+str(event))
|
1722 |
+
# To surpress complaint here, just set
|
1723 |
+
# 'unknown_callback' => sub { return () }
|
1724 |
+
continue
|
1725 |
+
|
1726 |
+
#print "Event $event encoded part 2\n"
|
1727 |
+
if str(type(event_data)).find("'str'") >= 0:
|
1728 |
+
event_data = bytearray(event_data.encode('Latin1', 'ignore'))
|
1729 |
+
if len(event_data): # how could $event_data be empty
|
1730 |
+
# data.append(struct.pack('>wa*', dtime, event_data))
|
1731 |
+
# print(' event_data='+str(event_data))
|
1732 |
+
data.append(_ber_compressed_int(dtime)+event_data)
|
1733 |
+
|
1734 |
+
return b''.join(data)
|
1735 |
+
|
README.md
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
-
---
|
2 |
-
title: Music
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.0.
|
8 |
-
app_file:
|
9 |
-
pinned:
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
1 |
+
---
|
2 |
+
title: Midi Music Generator
|
3 |
+
emoji: 🎼🎶
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: indigo
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.0.1
|
8 |
+
app_file: app_onnx.py
|
9 |
+
pinned: true
|
10 |
+
license: apache-2.0
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import random
|
3 |
+
import argparse
|
4 |
+
import glob
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
from concurrent.futures import ThreadPoolExecutor
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import tqdm
|
15 |
+
from huggingface_hub import hf_hub_download
|
16 |
+
from transformers import DynamicCache
|
17 |
+
|
18 |
+
import MIDI
|
19 |
+
from midi_model import MIDIModel, MIDIModelConfig
|
20 |
+
from midi_synthesizer import MidiSynthesizer
|
21 |
+
|
22 |
+
MAX_SEED = np.iinfo(np.int32).max
|
23 |
+
in_space = os.getenv("SYSTEM") == "spaces"
|
24 |
+
|
25 |
+
|
26 |
+
@torch.inference_mode()
|
27 |
+
def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
28 |
+
disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
|
29 |
+
tokenizer = model.tokenizer
|
30 |
+
if disable_channels is not None:
|
31 |
+
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
32 |
+
else:
|
33 |
+
disable_channels = []
|
34 |
+
max_token_seq = tokenizer.max_token_seq
|
35 |
+
if prompt is None:
|
36 |
+
input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
|
37 |
+
input_tensor[0, 0] = tokenizer.bos_id # bos
|
38 |
+
input_tensor = input_tensor.unsqueeze(0)
|
39 |
+
input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
|
40 |
+
else:
|
41 |
+
if len(prompt.shape) == 2:
|
42 |
+
prompt = prompt[None, :]
|
43 |
+
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
44 |
+
elif prompt.shape[0] == 1:
|
45 |
+
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
46 |
+
elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
|
47 |
+
raise ValueError(f"invalid shape for prompt, {prompt.shape}")
|
48 |
+
prompt = prompt[..., :max_token_seq]
|
49 |
+
if prompt.shape[-1] < max_token_seq:
|
50 |
+
prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
|
51 |
+
mode="constant", constant_values=tokenizer.pad_id)
|
52 |
+
input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
|
53 |
+
cur_len = input_tensor.shape[1]
|
54 |
+
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
55 |
+
cache1 = DynamicCache()
|
56 |
+
past_len = 0
|
57 |
+
with bar:
|
58 |
+
while cur_len < max_len:
|
59 |
+
end = [False] * batch_size
|
60 |
+
hidden = model.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
|
61 |
+
next_token_seq = None
|
62 |
+
event_names = [""] * batch_size
|
63 |
+
cache2 = DynamicCache()
|
64 |
+
for i in range(max_token_seq):
|
65 |
+
mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=model.device)
|
66 |
+
for b in range(batch_size):
|
67 |
+
if end[b]:
|
68 |
+
mask[b, tokenizer.pad_id] = 1
|
69 |
+
continue
|
70 |
+
if i == 0:
|
71 |
+
mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
|
72 |
+
if disable_patch_change:
|
73 |
+
mask_ids.remove(tokenizer.event_ids["patch_change"])
|
74 |
+
if disable_control_change:
|
75 |
+
mask_ids.remove(tokenizer.event_ids["control_change"])
|
76 |
+
mask[b, mask_ids] = 1
|
77 |
+
else:
|
78 |
+
param_names = tokenizer.events[event_names[b]]
|
79 |
+
if i > len(param_names):
|
80 |
+
mask[b, tokenizer.pad_id] = 1
|
81 |
+
continue
|
82 |
+
param_name = param_names[i - 1]
|
83 |
+
mask_ids = tokenizer.parameter_ids[param_name]
|
84 |
+
if param_name == "channel":
|
85 |
+
mask_ids = [i for i in mask_ids if i not in disable_channels]
|
86 |
+
mask[b, mask_ids] = 1
|
87 |
+
mask = mask.unsqueeze(1)
|
88 |
+
x = next_token_seq
|
89 |
+
if i != 0:
|
90 |
+
hidden = None
|
91 |
+
x = x[:, -1:]
|
92 |
+
logits = model.forward_token(hidden, x, cache=cache2)[:, -1:]
|
93 |
+
scores = torch.softmax(logits / temp, dim=-1) * mask
|
94 |
+
samples = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
|
95 |
+
if i == 0:
|
96 |
+
next_token_seq = samples
|
97 |
+
for b in range(batch_size):
|
98 |
+
if end[b]:
|
99 |
+
continue
|
100 |
+
eid = samples[b].item()
|
101 |
+
if eid == tokenizer.eos_id:
|
102 |
+
end[b] = True
|
103 |
+
else:
|
104 |
+
event_names[b] = tokenizer.id_events[eid]
|
105 |
+
else:
|
106 |
+
next_token_seq = torch.cat([next_token_seq, samples], dim=1)
|
107 |
+
if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
|
108 |
+
break
|
109 |
+
if next_token_seq.shape[1] < max_token_seq:
|
110 |
+
next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
|
111 |
+
"constant", value=tokenizer.pad_id)
|
112 |
+
next_token_seq = next_token_seq.unsqueeze(1)
|
113 |
+
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
114 |
+
past_len = cur_len
|
115 |
+
cur_len += 1
|
116 |
+
bar.update(1)
|
117 |
+
yield next_token_seq[:, 0].cpu().numpy()
|
118 |
+
if all(end):
|
119 |
+
break
|
120 |
+
|
121 |
+
|
122 |
+
def create_msg(name, data):
|
123 |
+
return {"name": name, "data": data}
|
124 |
+
|
125 |
+
|
126 |
+
def send_msgs(msgs):
|
127 |
+
return json.dumps(msgs)
|
128 |
+
|
129 |
+
|
130 |
+
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
|
131 |
+
time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
132 |
+
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
133 |
+
t = gen_events // 23
|
134 |
+
if "large" in model_name:
|
135 |
+
t = gen_events // 14
|
136 |
+
return t + 5
|
137 |
+
|
138 |
+
|
139 |
+
@spaces.GPU(duration=get_duration)
|
140 |
+
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
|
141 |
+
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
|
142 |
+
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
143 |
+
model = models[model_name]
|
144 |
+
model.to(device=opt.device)
|
145 |
+
tokenizer = model.tokenizer
|
146 |
+
bpm = int(bpm)
|
147 |
+
if time_sig == "auto":
|
148 |
+
time_sig = None
|
149 |
+
time_sig_nn = 4
|
150 |
+
time_sig_dd = 2
|
151 |
+
else:
|
152 |
+
time_sig_nn, time_sig_dd = time_sig.split('/')
|
153 |
+
time_sig_nn = int(time_sig_nn)
|
154 |
+
time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
|
155 |
+
if key_sig == 0:
|
156 |
+
key_sig = None
|
157 |
+
key_sig_sf = 0
|
158 |
+
key_sig_mi = 0
|
159 |
+
else:
|
160 |
+
key_sig = (key_sig - 1)
|
161 |
+
key_sig_sf = key_sig // 2 - 7
|
162 |
+
key_sig_mi = key_sig % 2
|
163 |
+
gen_events = int(gen_events)
|
164 |
+
max_len = gen_events
|
165 |
+
if seed_rand:
|
166 |
+
seed = random.randint(0, MAX_SEED)
|
167 |
+
generator = torch.Generator(opt.device).manual_seed(seed)
|
168 |
+
disable_patch_change = False
|
169 |
+
disable_channels = None
|
170 |
+
if tab == 0:
|
171 |
+
i = 0
|
172 |
+
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
|
173 |
+
if tokenizer.version == "v2":
|
174 |
+
if time_sig is not None:
|
175 |
+
mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
|
176 |
+
if key_sig is not None:
|
177 |
+
mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
|
178 |
+
if bpm != 0:
|
179 |
+
mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
|
180 |
+
patches = {}
|
181 |
+
if instruments is None:
|
182 |
+
instruments = []
|
183 |
+
for instr in instruments:
|
184 |
+
patches[i] = patch2number[instr]
|
185 |
+
i = (i + 1) if i != 8 else 10
|
186 |
+
if drum_kit != "None":
|
187 |
+
patches[9] = drum_kits2number[drum_kit]
|
188 |
+
for i, (c, p) in enumerate(patches.items()):
|
189 |
+
mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
|
190 |
+
mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
|
191 |
+
mid_seq = mid.tolist()
|
192 |
+
if len(instruments) > 0:
|
193 |
+
disable_patch_change = True
|
194 |
+
disable_channels = [i for i in range(16) if i not in patches]
|
195 |
+
elif tab == 1 and mid is not None:
|
196 |
+
eps = 4 if reduce_cc_st else 0
|
197 |
+
mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
|
198 |
+
remap_track_channel=remap_track_channel,
|
199 |
+
add_default_instr=add_default_instr,
|
200 |
+
remove_empty_channels=remove_empty_channels)
|
201 |
+
mid = mid[:int(midi_events)]
|
202 |
+
mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
|
203 |
+
mid_seq = mid.tolist()
|
204 |
+
elif tab == 2 and mid_seq is not None:
|
205 |
+
mid = np.asarray(mid_seq, dtype=np.int64)
|
206 |
+
if continuation_select > 0:
|
207 |
+
continuation_state.append(mid_seq)
|
208 |
+
mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
|
209 |
+
mid_seq = mid.tolist()
|
210 |
+
else:
|
211 |
+
continuation_state.append(mid.shape[1])
|
212 |
+
else:
|
213 |
+
continuation_state = [0]
|
214 |
+
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
|
215 |
+
mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
|
216 |
+
mid_seq = mid.tolist()
|
217 |
+
|
218 |
+
if mid is not None:
|
219 |
+
max_len += mid.shape[1]
|
220 |
+
|
221 |
+
init_msgs = [create_msg("progress", [0, gen_events])]
|
222 |
+
if not (tab == 2 and continuation_select == 0):
|
223 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
224 |
+
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
|
225 |
+
init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
|
226 |
+
create_msg("visualizer_append", [i, events])]
|
227 |
+
yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
|
228 |
+
midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
|
229 |
+
top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
|
230 |
+
disable_control_change=not allow_cc, disable_channels=disable_channels,
|
231 |
+
generator=generator)
|
232 |
+
events = [list() for i in range(OUTPUT_BATCH_SIZE)]
|
233 |
+
t = time.time() + 1
|
234 |
+
for i, token_seqs in enumerate(midi_generator):
|
235 |
+
token_seqs = token_seqs.tolist()
|
236 |
+
for j in range(OUTPUT_BATCH_SIZE):
|
237 |
+
token_seq = token_seqs[j]
|
238 |
+
mid_seq[j].append(token_seq)
|
239 |
+
events[j].append(tokenizer.tokens2event(token_seq))
|
240 |
+
if time.time() - t > 0.5:
|
241 |
+
msgs = [create_msg("progress", [i + 1, gen_events])]
|
242 |
+
for j in range(OUTPUT_BATCH_SIZE):
|
243 |
+
msgs += [create_msg("visualizer_append", [j, events[j]])]
|
244 |
+
events[j] = list()
|
245 |
+
yield mid_seq, continuation_state, seed, send_msgs(msgs)
|
246 |
+
t = time.time()
|
247 |
+
yield mid_seq, continuation_state, seed, send_msgs([])
|
248 |
+
|
249 |
+
|
250 |
+
def finish_run(model_name, mid_seq):
|
251 |
+
if mid_seq is None:
|
252 |
+
outputs = [None] * OUTPUT_BATCH_SIZE
|
253 |
+
return *outputs, []
|
254 |
+
tokenizer = models[model_name].tokenizer
|
255 |
+
outputs = []
|
256 |
+
end_msgs = [create_msg("progress", [0, 0])]
|
257 |
+
if not os.path.exists("outputs"):
|
258 |
+
os.mkdir("outputs")
|
259 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
260 |
+
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
|
261 |
+
mid = tokenizer.detokenize(mid_seq[i])
|
262 |
+
with open(f"outputs/output{i + 1}.mid", 'wb') as f:
|
263 |
+
f.write(MIDI.score2midi(mid))
|
264 |
+
outputs.append(f"outputs/output{i + 1}.mid")
|
265 |
+
end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
|
266 |
+
create_msg("visualizer_append", [i, events]),
|
267 |
+
create_msg("visualizer_end", i)]
|
268 |
+
return *outputs, send_msgs(end_msgs)
|
269 |
+
|
270 |
+
|
271 |
+
def synthesis_task(mid):
|
272 |
+
return synthesizer.synthesis(MIDI.score2opus(mid))
|
273 |
+
|
274 |
+
def render_audio(model_name, mid_seq, should_render_audio):
|
275 |
+
if (not should_render_audio) or mid_seq is None:
|
276 |
+
outputs = [None] * OUTPUT_BATCH_SIZE
|
277 |
+
return tuple(outputs)
|
278 |
+
tokenizer = models[model_name].tokenizer
|
279 |
+
outputs = []
|
280 |
+
if not os.path.exists("outputs"):
|
281 |
+
os.mkdir("outputs")
|
282 |
+
audio_futures = []
|
283 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
284 |
+
mid = tokenizer.detokenize(mid_seq[i])
|
285 |
+
audio_future = thread_pool.submit(synthesis_task, mid)
|
286 |
+
audio_futures.append(audio_future)
|
287 |
+
for future in audio_futures:
|
288 |
+
outputs.append((44100, future.result()))
|
289 |
+
if OUTPUT_BATCH_SIZE == 1:
|
290 |
+
return outputs[0]
|
291 |
+
return tuple(outputs)
|
292 |
+
|
293 |
+
|
294 |
+
def undo_continuation(model_name, mid_seq, continuation_state):
|
295 |
+
if mid_seq is None or len(continuation_state) < 2:
|
296 |
+
return mid_seq, continuation_state, send_msgs([])
|
297 |
+
tokenizer = models[model_name].tokenizer
|
298 |
+
if isinstance(continuation_state[-1], list):
|
299 |
+
mid_seq = continuation_state[-1]
|
300 |
+
else:
|
301 |
+
mid_seq = [ms[:continuation_state[-1]] for ms in mid_seq]
|
302 |
+
continuation_state = continuation_state[:-1]
|
303 |
+
end_msgs = [create_msg("progress", [0, 0])]
|
304 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
305 |
+
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
|
306 |
+
end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
|
307 |
+
create_msg("visualizer_append", [i, events]),
|
308 |
+
create_msg("visualizer_end", i)]
|
309 |
+
return mid_seq, continuation_state, send_msgs(end_msgs)
|
310 |
+
|
311 |
+
|
312 |
+
def load_javascript(dir="javascript"):
|
313 |
+
scripts_list = glob.glob(f"{dir}/*.js")
|
314 |
+
javascript = ""
|
315 |
+
for path in scripts_list:
|
316 |
+
with open(path, "r", encoding="utf8") as jsfile:
|
317 |
+
js_content = jsfile.read()
|
318 |
+
js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
|
319 |
+
f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
|
320 |
+
javascript += f"\n<!-- {path} --><script>{js_content}</script>"
|
321 |
+
template_response_ori = gr.routes.templates.TemplateResponse
|
322 |
+
|
323 |
+
def template_response(*args, **kwargs):
|
324 |
+
res = template_response_ori(*args, **kwargs)
|
325 |
+
res.body = res.body.replace(
|
326 |
+
b'</head>', f'{javascript}</head>'.encode("utf8"))
|
327 |
+
res.init_headers()
|
328 |
+
return res
|
329 |
+
|
330 |
+
gr.routes.templates.TemplateResponse = template_response
|
331 |
+
|
332 |
+
|
333 |
+
def hf_hub_download_retry(repo_id, filename):
|
334 |
+
print(f"downloading {repo_id} {filename}")
|
335 |
+
retry = 0
|
336 |
+
err = None
|
337 |
+
while retry < 30:
|
338 |
+
try:
|
339 |
+
return hf_hub_download(repo_id=repo_id, filename=filename)
|
340 |
+
except Exception as e:
|
341 |
+
err = e
|
342 |
+
retry += 1
|
343 |
+
if err:
|
344 |
+
raise err
|
345 |
+
|
346 |
+
|
347 |
+
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
348 |
+
40: "Blush", 48: "Orchestra"}
|
349 |
+
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
350 |
+
drum_kits2number = {v: k for k, v in number2drum_kits.items()}
|
351 |
+
key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
|
352 |
+
'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
|
353 |
+
|
354 |
+
if __name__ == "__main__":
|
355 |
+
parser = argparse.ArgumentParser()
|
356 |
+
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
357 |
+
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
358 |
+
parser.add_argument("--device", type=str, default="cuda", help="device to run model")
|
359 |
+
parser.add_argument("--batch", type=int, default=8, help="batch size")
|
360 |
+
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
361 |
+
opt = parser.parse_args()
|
362 |
+
OUTPUT_BATCH_SIZE = opt.batch
|
363 |
+
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
364 |
+
thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
|
365 |
+
synthesizer = MidiSynthesizer(soundfont_path)
|
366 |
+
models_info = {
|
367 |
+
"generic pretrain model (tv2o-medium) by skytnt": [
|
368 |
+
"skytnt/midi-model-tv2o-medium", {
|
369 |
+
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
370 |
+
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
371 |
+
}
|
372 |
+
],
|
373 |
+
"generic pretrain model (tv2o-large) by asigalov61": [
|
374 |
+
"asigalov61/Music-Llama", {}
|
375 |
+
],
|
376 |
+
"generic pretrain model (tv2o-medium) by asigalov61": [
|
377 |
+
"asigalov61/Music-Llama-Medium", {}
|
378 |
+
],
|
379 |
+
"generic pretrain model (tv1-medium) by skytnt": [
|
380 |
+
"skytnt/midi-model", {}
|
381 |
+
]
|
382 |
+
}
|
383 |
+
models = {}
|
384 |
+
if opt.device == "cuda":
|
385 |
+
torch.backends.cudnn.deterministic = True
|
386 |
+
torch.backends.cudnn.benchmark = False
|
387 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
388 |
+
torch.backends.cudnn.allow_tf32 = True
|
389 |
+
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
390 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
391 |
+
for name, (repo_id, loras) in models_info.items():
|
392 |
+
model = MIDIModel.from_pretrained(repo_id)
|
393 |
+
model.to(device="cpu", dtype=torch.float32)
|
394 |
+
models[name] = model
|
395 |
+
for lora_name, lora_repo in loras.items():
|
396 |
+
model = MIDIModel.from_pretrained(repo_id)
|
397 |
+
print(f"loading lora {lora_repo} for {name}")
|
398 |
+
model = model.load_merge_lora(lora_repo)
|
399 |
+
model.to(device="cpu", dtype=torch.float32)
|
400 |
+
models[f"{name} with {lora_name} lora"] = model
|
401 |
+
|
402 |
+
load_javascript()
|
403 |
+
app = gr.Blocks(theme=gr.themes.Soft())
|
404 |
+
with app:
|
405 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
|
406 |
+
gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
|
407 |
+
"Midi event transformer for symbolic music generation\n\n"
|
408 |
+
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
409 |
+
"[Open In Colab]"
|
410 |
+
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
411 |
+
" or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
|
412 |
+
" for unlimited generation\n\n"
|
413 |
+
"**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
|
414 |
+
"The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
|
415 |
+
)
|
416 |
+
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
417 |
+
js_msg.change(None, [js_msg], [], js="""
|
418 |
+
(msg_json) =>{
|
419 |
+
let msgs = JSON.parse(msg_json);
|
420 |
+
executeCallbacks(msgReceiveCallbacks, msgs);
|
421 |
+
return [];
|
422 |
+
}
|
423 |
+
""")
|
424 |
+
input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
|
425 |
+
type="value", value=list(models.keys())[0])
|
426 |
+
tab_select = gr.State(value=0)
|
427 |
+
with gr.Tabs():
|
428 |
+
with gr.TabItem("custom prompt") as tab1:
|
429 |
+
input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
|
430 |
+
multiselect=True, max_choices=15, type="value")
|
431 |
+
input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
|
432 |
+
value="None")
|
433 |
+
input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
|
434 |
+
step=1,
|
435 |
+
value=0)
|
436 |
+
input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
|
437 |
+
value="auto",
|
438 |
+
choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
|
439 |
+
"2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
|
440 |
+
)
|
441 |
+
input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
|
442 |
+
value="auto",
|
443 |
+
choices=["auto"] + key_signatures,
|
444 |
+
type="index"
|
445 |
+
)
|
446 |
+
example1 = gr.Examples([
|
447 |
+
[[], "None"],
|
448 |
+
[["Acoustic Grand"], "None"],
|
449 |
+
[['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
|
450 |
+
'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
|
451 |
+
[['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
|
452 |
+
'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
|
453 |
+
[['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
|
454 |
+
'Oboe', 'Pizzicato Strings'], "Orchestra"],
|
455 |
+
[['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
|
456 |
+
'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
|
457 |
+
[["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
|
458 |
+
"Electric Bass(finger)"], "Standard"]
|
459 |
+
], [input_instruments, input_drum_kit])
|
460 |
+
with gr.TabItem("midi prompt") as tab2:
|
461 |
+
input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
|
462 |
+
input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
|
463 |
+
step=1,
|
464 |
+
value=128)
|
465 |
+
input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
|
466 |
+
input_remap_track_channel = gr.Checkbox(
|
467 |
+
label="remap tracks and channels so each track has only one channel and in order", value=True)
|
468 |
+
input_add_default_instr = gr.Checkbox(
|
469 |
+
label="add a default instrument to channels that don't have an instrument", value=True)
|
470 |
+
input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
|
471 |
+
example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
|
472 |
+
[input_midi, input_midi_events])
|
473 |
+
with gr.TabItem("last output prompt") as tab3:
|
474 |
+
gr.Markdown("Continue generating on the last output.")
|
475 |
+
input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
|
476 |
+
choices=["all"] + [f"output{i + 1}" for i in
|
477 |
+
range(OUTPUT_BATCH_SIZE)],
|
478 |
+
type="index"
|
479 |
+
)
|
480 |
+
undo_btn = gr.Button("undo the last continuation")
|
481 |
+
|
482 |
+
tab1.select(lambda: 0, None, tab_select, queue=False)
|
483 |
+
tab2.select(lambda: 1, None, tab_select, queue=False)
|
484 |
+
tab3.select(lambda: 2, None, tab_select, queue=False)
|
485 |
+
input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
|
486 |
+
step=1, value=0)
|
487 |
+
input_seed_rand = gr.Checkbox(label="random seed", value=True)
|
488 |
+
input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
|
489 |
+
step=1, value=opt.max_gen // 2)
|
490 |
+
with gr.Accordion("options", open=False):
|
491 |
+
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
492 |
+
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
|
493 |
+
input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
|
494 |
+
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
495 |
+
input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
|
496 |
+
example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
|
497 |
+
[input_temp, input_top_p, input_top_k])
|
498 |
+
run_btn = gr.Button("generate", variant="primary")
|
499 |
+
# stop_btn = gr.Button("stop and output")
|
500 |
+
output_midi_seq = gr.State()
|
501 |
+
output_continuation_state = gr.State([0])
|
502 |
+
midi_outputs = []
|
503 |
+
audio_outputs = []
|
504 |
+
with gr.Tabs(elem_id="output_tabs"):
|
505 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
506 |
+
with gr.TabItem(f"output {i + 1}") as tab1:
|
507 |
+
output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
|
508 |
+
output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
|
509 |
+
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
510 |
+
midi_outputs.append(output_midi)
|
511 |
+
audio_outputs.append(output_audio)
|
512 |
+
run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
|
513 |
+
input_continuation_select, input_instruments, input_drum_kit, input_bpm,
|
514 |
+
input_time_sig, input_key_sig, input_midi, input_midi_events,
|
515 |
+
input_reduce_cc_st, input_remap_track_channel,
|
516 |
+
input_add_default_instr, input_remove_empty_channels,
|
517 |
+
input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
|
518 |
+
input_top_k, input_allow_cc],
|
519 |
+
[output_midi_seq, output_continuation_state, input_seed, js_msg],
|
520 |
+
concurrency_limit=10, queue=True)
|
521 |
+
finish_run_event = run_event.then(fn=finish_run,
|
522 |
+
inputs=[input_model, output_midi_seq],
|
523 |
+
outputs=midi_outputs + [js_msg],
|
524 |
+
queue=False)
|
525 |
+
finish_run_event.then(fn=render_audio,
|
526 |
+
inputs=[input_model, output_midi_seq, input_render_audio],
|
527 |
+
outputs=audio_outputs,
|
528 |
+
queue=False)
|
529 |
+
# stop_btn.click(None, [], [], cancels=run_event,
|
530 |
+
# queue=False)
|
531 |
+
undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
|
532 |
+
[output_midi_seq, output_continuation_state, js_msg], queue=False)
|
533 |
+
app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
|
534 |
+
thread_pool.shutdown()
|
app_onnx.py
ADDED
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import random
|
3 |
+
import argparse
|
4 |
+
import glob
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
from concurrent.futures import ThreadPoolExecutor
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import numpy as np
|
12 |
+
import onnxruntime as rt
|
13 |
+
import tqdm
|
14 |
+
from huggingface_hub import hf_hub_download
|
15 |
+
|
16 |
+
import MIDI
|
17 |
+
from midi_synthesizer import MidiSynthesizer
|
18 |
+
from midi_tokenizer import MIDITokenizer
|
19 |
+
|
20 |
+
MAX_SEED = np.iinfo(np.int32).max
|
21 |
+
in_space = os.getenv("SYSTEM") == "spaces"
|
22 |
+
|
23 |
+
|
24 |
+
def softmax(x, axis):
|
25 |
+
x_max = np.amax(x, axis=axis, keepdims=True)
|
26 |
+
exp_x_shifted = np.exp(x - x_max)
|
27 |
+
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
|
28 |
+
|
29 |
+
|
30 |
+
def sample_top_p_k(probs, p, k, generator=None):
|
31 |
+
if generator is None:
|
32 |
+
generator = np.random
|
33 |
+
probs_idx = np.argsort(-probs, axis=-1)
|
34 |
+
probs_sort = np.take_along_axis(probs, probs_idx, -1)
|
35 |
+
probs_sum = np.cumsum(probs_sort, axis=-1)
|
36 |
+
mask = probs_sum - probs_sort > p
|
37 |
+
probs_sort[mask] = 0.0
|
38 |
+
mask = np.zeros(probs_sort.shape[-1])
|
39 |
+
mask[:k] = 1
|
40 |
+
probs_sort = probs_sort * mask
|
41 |
+
probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
|
42 |
+
shape = probs_sort.shape
|
43 |
+
probs_sort_flat = probs_sort.reshape(-1, shape[-1])
|
44 |
+
probs_idx_flat = probs_idx.reshape(-1, shape[-1])
|
45 |
+
next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
|
46 |
+
next_token = next_token.reshape(*shape[:-1])
|
47 |
+
return next_token
|
48 |
+
|
49 |
+
|
50 |
+
def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
|
51 |
+
io_binding = model.io_binding()
|
52 |
+
for input_ in model.get_inputs():
|
53 |
+
name = input_.name
|
54 |
+
if name.startswith("past_key_values"):
|
55 |
+
present_name = name.replace("past_key_values", "present")
|
56 |
+
if present_name in outputs:
|
57 |
+
v = outputs[present_name]
|
58 |
+
else:
|
59 |
+
v = rt.OrtValue.ortvalue_from_shape_and_type(
|
60 |
+
(batch_size, input_.shape[1], past_len, input_.shape[3]),
|
61 |
+
element_type=np.float32,
|
62 |
+
device_type=device)
|
63 |
+
inputs[name] = v
|
64 |
+
else:
|
65 |
+
v = inputs[name]
|
66 |
+
io_binding.bind_ortvalue_input(name, v)
|
67 |
+
|
68 |
+
for output in model.get_outputs():
|
69 |
+
name = output.name
|
70 |
+
if name.startswith("present"):
|
71 |
+
v = rt.OrtValue.ortvalue_from_shape_and_type(
|
72 |
+
(batch_size, output.shape[1], cur_len, output.shape[3]),
|
73 |
+
element_type=np.float32,
|
74 |
+
device_type=device)
|
75 |
+
outputs[name] = v
|
76 |
+
else:
|
77 |
+
v = outputs[name]
|
78 |
+
io_binding.bind_ortvalue_output(name, v)
|
79 |
+
return io_binding
|
80 |
+
|
81 |
+
def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
82 |
+
disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
|
83 |
+
tokenizer = model[2]
|
84 |
+
if disable_channels is not None:
|
85 |
+
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
86 |
+
else:
|
87 |
+
disable_channels = []
|
88 |
+
if generator is None:
|
89 |
+
generator = np.random
|
90 |
+
max_token_seq = tokenizer.max_token_seq
|
91 |
+
if prompt is None:
|
92 |
+
input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
|
93 |
+
input_tensor[0, 0] = tokenizer.bos_id # bos
|
94 |
+
input_tensor = input_tensor[None, :, :]
|
95 |
+
input_tensor = np.repeat(input_tensor, repeats=batch_size, axis=0)
|
96 |
+
else:
|
97 |
+
if len(prompt.shape) == 2:
|
98 |
+
prompt = prompt[None, :]
|
99 |
+
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
100 |
+
elif prompt.shape[0] == 1:
|
101 |
+
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
102 |
+
elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
|
103 |
+
raise ValueError(f"invalid shape for prompt, {prompt.shape}")
|
104 |
+
prompt = prompt[..., :max_token_seq]
|
105 |
+
if prompt.shape[-1] < max_token_seq:
|
106 |
+
prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
|
107 |
+
mode="constant", constant_values=tokenizer.pad_id)
|
108 |
+
input_tensor = prompt
|
109 |
+
cur_len = input_tensor.shape[1]
|
110 |
+
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
111 |
+
model0_inputs = {}
|
112 |
+
model0_outputs = {}
|
113 |
+
emb_size = 1024
|
114 |
+
for output in model[0].get_outputs():
|
115 |
+
if output.name == "hidden":
|
116 |
+
emb_size = output.shape[2]
|
117 |
+
past_len = 0
|
118 |
+
with bar:
|
119 |
+
while cur_len < max_len:
|
120 |
+
end = [False] * batch_size
|
121 |
+
model0_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(input_tensor[:, past_len:], device_type=device)
|
122 |
+
model0_outputs["hidden"] = rt.OrtValue.ortvalue_from_shape_and_type(
|
123 |
+
(batch_size, cur_len - past_len, emb_size),
|
124 |
+
element_type=np.float32,
|
125 |
+
device_type=device)
|
126 |
+
io_binding = apply_io_binding(model[0], model0_inputs, model0_outputs, batch_size, past_len, cur_len)
|
127 |
+
io_binding.synchronize_inputs()
|
128 |
+
model[0].run_with_iobinding(io_binding)
|
129 |
+
io_binding.synchronize_outputs()
|
130 |
+
|
131 |
+
hidden = model0_outputs["hidden"].numpy()[:, -1:]
|
132 |
+
next_token_seq = np.zeros((batch_size, 0), dtype=np.int64)
|
133 |
+
event_names = [""] * batch_size
|
134 |
+
model1_inputs = {"hidden": rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)}
|
135 |
+
model1_outputs = {}
|
136 |
+
for i in range(max_token_seq):
|
137 |
+
mask = np.zeros((batch_size, tokenizer.vocab_size), dtype=np.int64)
|
138 |
+
for b in range(batch_size):
|
139 |
+
if end[b]:
|
140 |
+
mask[b, tokenizer.pad_id] = 1
|
141 |
+
continue
|
142 |
+
if i == 0:
|
143 |
+
mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
|
144 |
+
if disable_patch_change:
|
145 |
+
mask_ids.remove(tokenizer.event_ids["patch_change"])
|
146 |
+
if disable_control_change:
|
147 |
+
mask_ids.remove(tokenizer.event_ids["control_change"])
|
148 |
+
mask[b, mask_ids] = 1
|
149 |
+
else:
|
150 |
+
param_names = tokenizer.events[event_names[b]]
|
151 |
+
if i > len(param_names):
|
152 |
+
mask[b, tokenizer.pad_id] = 1
|
153 |
+
continue
|
154 |
+
param_name = param_names[i - 1]
|
155 |
+
mask_ids = tokenizer.parameter_ids[param_name]
|
156 |
+
if param_name == "channel":
|
157 |
+
mask_ids = [i for i in mask_ids if i not in disable_channels]
|
158 |
+
mask[b, mask_ids] = 1
|
159 |
+
mask = mask[:, None, :]
|
160 |
+
x = next_token_seq
|
161 |
+
if i != 0:
|
162 |
+
# cached
|
163 |
+
if i == 1:
|
164 |
+
hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
|
165 |
+
model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
|
166 |
+
x = x[:, -1:]
|
167 |
+
model1_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(x, device_type=device)
|
168 |
+
model1_outputs["y"] = rt.OrtValue.ortvalue_from_shape_and_type(
|
169 |
+
(batch_size, 1, tokenizer.vocab_size),
|
170 |
+
element_type=np.float32,
|
171 |
+
device_type=device
|
172 |
+
)
|
173 |
+
io_binding = apply_io_binding(model[1], model1_inputs, model1_outputs, batch_size, i, i+1)
|
174 |
+
io_binding.synchronize_inputs()
|
175 |
+
model[1].run_with_iobinding(io_binding)
|
176 |
+
io_binding.synchronize_outputs()
|
177 |
+
logits = model1_outputs["y"].numpy()
|
178 |
+
scores = softmax(logits / temp, -1) * mask
|
179 |
+
samples = sample_top_p_k(scores, top_p, top_k, generator)
|
180 |
+
if i == 0:
|
181 |
+
next_token_seq = samples
|
182 |
+
for b in range(batch_size):
|
183 |
+
if end[b]:
|
184 |
+
continue
|
185 |
+
eid = samples[b].item()
|
186 |
+
if eid == tokenizer.eos_id:
|
187 |
+
end[b] = True
|
188 |
+
else:
|
189 |
+
event_names[b] = tokenizer.id_events[eid]
|
190 |
+
else:
|
191 |
+
next_token_seq = np.concatenate([next_token_seq, samples], axis=1)
|
192 |
+
if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
|
193 |
+
break
|
194 |
+
if next_token_seq.shape[1] < max_token_seq:
|
195 |
+
next_token_seq = np.pad(next_token_seq,
|
196 |
+
((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
|
197 |
+
mode="constant", constant_values=tokenizer.pad_id)
|
198 |
+
next_token_seq = next_token_seq[:, None, :]
|
199 |
+
input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
|
200 |
+
past_len = cur_len
|
201 |
+
cur_len += 1
|
202 |
+
bar.update(1)
|
203 |
+
yield next_token_seq[:, 0]
|
204 |
+
if all(end):
|
205 |
+
break
|
206 |
+
|
207 |
+
|
208 |
+
def create_msg(name, data):
|
209 |
+
return {"name": name, "data": data}
|
210 |
+
|
211 |
+
|
212 |
+
def send_msgs(msgs):
|
213 |
+
return json.dumps(msgs)
|
214 |
+
|
215 |
+
|
216 |
+
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
|
217 |
+
time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
218 |
+
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
219 |
+
t = gen_events // 30
|
220 |
+
if "large" in model_name:
|
221 |
+
t = gen_events // 23
|
222 |
+
return t + 5
|
223 |
+
|
224 |
+
|
225 |
+
@spaces.GPU(duration=get_duration)
|
226 |
+
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
|
227 |
+
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
|
228 |
+
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
229 |
+
model = models[model_name]
|
230 |
+
model_base = rt.InferenceSession(model[0], providers=providers)
|
231 |
+
model_token = rt.InferenceSession(model[1], providers=providers)
|
232 |
+
tokenizer = model[2]
|
233 |
+
model = [model_base, model_token, tokenizer]
|
234 |
+
bpm = int(bpm)
|
235 |
+
if time_sig == "auto":
|
236 |
+
time_sig = None
|
237 |
+
time_sig_nn = 4
|
238 |
+
time_sig_dd = 2
|
239 |
+
else:
|
240 |
+
time_sig_nn, time_sig_dd = time_sig.split('/')
|
241 |
+
time_sig_nn = int(time_sig_nn)
|
242 |
+
time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
|
243 |
+
if key_sig == 0:
|
244 |
+
key_sig = None
|
245 |
+
key_sig_sf = 0
|
246 |
+
key_sig_mi = 0
|
247 |
+
else:
|
248 |
+
key_sig = (key_sig - 1)
|
249 |
+
key_sig_sf = key_sig // 2 - 7
|
250 |
+
key_sig_mi = key_sig % 2
|
251 |
+
gen_events = int(gen_events)
|
252 |
+
max_len = gen_events
|
253 |
+
if seed_rand:
|
254 |
+
seed = random.randint(0, MAX_SEED)
|
255 |
+
generator = np.random.RandomState(seed)
|
256 |
+
disable_patch_change = False
|
257 |
+
disable_channels = None
|
258 |
+
if tab == 0:
|
259 |
+
i = 0
|
260 |
+
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
|
261 |
+
if tokenizer.version == "v2":
|
262 |
+
if time_sig is not None:
|
263 |
+
mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
|
264 |
+
if key_sig is not None:
|
265 |
+
mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
|
266 |
+
if bpm != 0:
|
267 |
+
mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
|
268 |
+
patches = {}
|
269 |
+
if instruments is None:
|
270 |
+
instruments = []
|
271 |
+
for instr in instruments:
|
272 |
+
patches[i] = patch2number[instr]
|
273 |
+
i = (i + 1) if i != 8 else 10
|
274 |
+
if drum_kit != "None":
|
275 |
+
patches[9] = drum_kits2number[drum_kit]
|
276 |
+
for i, (c, p) in enumerate(patches.items()):
|
277 |
+
mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
|
278 |
+
mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
|
279 |
+
mid_seq = mid.tolist()
|
280 |
+
if len(instruments) > 0:
|
281 |
+
disable_patch_change = True
|
282 |
+
disable_channels = [i for i in range(16) if i not in patches]
|
283 |
+
elif tab == 1 and mid is not None:
|
284 |
+
eps = 4 if reduce_cc_st else 0
|
285 |
+
mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
|
286 |
+
remap_track_channel=remap_track_channel,
|
287 |
+
add_default_instr=add_default_instr,
|
288 |
+
remove_empty_channels=remove_empty_channels)
|
289 |
+
mid = mid[:int(midi_events)]
|
290 |
+
mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
|
291 |
+
mid_seq = mid.tolist()
|
292 |
+
elif tab == 2 and mid_seq is not None:
|
293 |
+
mid = np.asarray(mid_seq, dtype=np.int64)
|
294 |
+
if continuation_select > 0:
|
295 |
+
continuation_state.append(mid_seq)
|
296 |
+
mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
|
297 |
+
mid_seq = mid.tolist()
|
298 |
+
else:
|
299 |
+
continuation_state.append(mid.shape[1])
|
300 |
+
else:
|
301 |
+
continuation_state = [0]
|
302 |
+
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
|
303 |
+
mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
|
304 |
+
mid_seq = mid.tolist()
|
305 |
+
|
306 |
+
if mid is not None:
|
307 |
+
max_len += mid.shape[1]
|
308 |
+
|
309 |
+
init_msgs = [create_msg("progress", [0, gen_events])]
|
310 |
+
if not (tab == 2 and continuation_select == 0):
|
311 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
312 |
+
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
|
313 |
+
init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
|
314 |
+
create_msg("visualizer_append", [i, events])]
|
315 |
+
yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
|
316 |
+
midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
|
317 |
+
top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
|
318 |
+
disable_control_change=not allow_cc, disable_channels=disable_channels,
|
319 |
+
generator=generator)
|
320 |
+
events = [list() for i in range(OUTPUT_BATCH_SIZE)]
|
321 |
+
t = time.time() + 1
|
322 |
+
for i, token_seqs in enumerate(midi_generator):
|
323 |
+
token_seqs = token_seqs.tolist()
|
324 |
+
for j in range(OUTPUT_BATCH_SIZE):
|
325 |
+
token_seq = token_seqs[j]
|
326 |
+
mid_seq[j].append(token_seq)
|
327 |
+
events[j].append(tokenizer.tokens2event(token_seq))
|
328 |
+
if time.time() - t > 0.5:
|
329 |
+
msgs = [create_msg("progress", [i + 1, gen_events])]
|
330 |
+
for j in range(OUTPUT_BATCH_SIZE):
|
331 |
+
msgs += [create_msg("visualizer_append", [j, events[j]])]
|
332 |
+
events[j] = list()
|
333 |
+
yield mid_seq, continuation_state, seed, send_msgs(msgs)
|
334 |
+
t = time.time()
|
335 |
+
yield mid_seq, continuation_state, seed, send_msgs([])
|
336 |
+
|
337 |
+
|
338 |
+
def finish_run(model_name, mid_seq):
|
339 |
+
if mid_seq is None:
|
340 |
+
outputs = [None] * OUTPUT_BATCH_SIZE
|
341 |
+
return *outputs, []
|
342 |
+
tokenizer = models[model_name][2]
|
343 |
+
outputs = []
|
344 |
+
end_msgs = [create_msg("progress", [0, 0])]
|
345 |
+
if not os.path.exists("outputs"):
|
346 |
+
os.mkdir("outputs")
|
347 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
348 |
+
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
|
349 |
+
mid = tokenizer.detokenize(mid_seq[i])
|
350 |
+
with open(f"outputs/output{i + 1}.mid", 'wb') as f:
|
351 |
+
f.write(MIDI.score2midi(mid))
|
352 |
+
outputs.append(f"outputs/output{i + 1}.mid")
|
353 |
+
end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
|
354 |
+
create_msg("visualizer_append", [i, events]),
|
355 |
+
create_msg("visualizer_end", i)]
|
356 |
+
return *outputs, send_msgs(end_msgs)
|
357 |
+
|
358 |
+
|
359 |
+
def synthesis_task(mid):
|
360 |
+
return synthesizer.synthesis(MIDI.score2opus(mid))
|
361 |
+
|
362 |
+
def render_audio(model_name, mid_seq, should_render_audio):
|
363 |
+
if (not should_render_audio) or mid_seq is None:
|
364 |
+
outputs = [None] * OUTPUT_BATCH_SIZE
|
365 |
+
return tuple(outputs)
|
366 |
+
tokenizer = models[model_name][2]
|
367 |
+
outputs = []
|
368 |
+
if not os.path.exists("outputs"):
|
369 |
+
os.mkdir("outputs")
|
370 |
+
audio_futures = []
|
371 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
372 |
+
mid = tokenizer.detokenize(mid_seq[i])
|
373 |
+
audio_future = thread_pool.submit(synthesis_task, mid)
|
374 |
+
audio_futures.append(audio_future)
|
375 |
+
for future in audio_futures:
|
376 |
+
outputs.append((44100, future.result()))
|
377 |
+
if OUTPUT_BATCH_SIZE == 1:
|
378 |
+
return outputs[0]
|
379 |
+
return tuple(outputs)
|
380 |
+
|
381 |
+
|
382 |
+
def undo_continuation(model_name, mid_seq, continuation_state):
|
383 |
+
if mid_seq is None or len(continuation_state) < 2:
|
384 |
+
return mid_seq, continuation_state, send_msgs([])
|
385 |
+
tokenizer = models[model_name][2]
|
386 |
+
if isinstance(continuation_state[-1], list):
|
387 |
+
mid_seq = continuation_state[-1]
|
388 |
+
else:
|
389 |
+
mid_seq = [ms[:continuation_state[-1]] for ms in mid_seq]
|
390 |
+
continuation_state = continuation_state[:-1]
|
391 |
+
end_msgs = [create_msg("progress", [0, 0])]
|
392 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
393 |
+
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
|
394 |
+
end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
|
395 |
+
create_msg("visualizer_append", [i, events]),
|
396 |
+
create_msg("visualizer_end", i)]
|
397 |
+
return mid_seq, continuation_state, send_msgs(end_msgs)
|
398 |
+
|
399 |
+
|
400 |
+
def load_javascript(dir="javascript"):
|
401 |
+
scripts_list = glob.glob(f"{dir}/*.js")
|
402 |
+
javascript = ""
|
403 |
+
for path in scripts_list:
|
404 |
+
with open(path, "r", encoding="utf8") as jsfile:
|
405 |
+
js_content = jsfile.read()
|
406 |
+
js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
|
407 |
+
f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
|
408 |
+
javascript += f"\n<!-- {path} --><script>{js_content}</script>"
|
409 |
+
template_response_ori = gr.routes.templates.TemplateResponse
|
410 |
+
|
411 |
+
def template_response(*args, **kwargs):
|
412 |
+
res = template_response_ori(*args, **kwargs)
|
413 |
+
res.body = res.body.replace(
|
414 |
+
b'</head>', f'{javascript}</head>'.encode("utf8"))
|
415 |
+
res.init_headers()
|
416 |
+
return res
|
417 |
+
|
418 |
+
gr.routes.templates.TemplateResponse = template_response
|
419 |
+
|
420 |
+
|
421 |
+
def hf_hub_download_retry(repo_id, filename):
|
422 |
+
print(f"downloading {repo_id} {filename}")
|
423 |
+
retry = 0
|
424 |
+
err = None
|
425 |
+
while retry < 30:
|
426 |
+
try:
|
427 |
+
return hf_hub_download(repo_id=repo_id, filename=filename)
|
428 |
+
except Exception as e:
|
429 |
+
err = e
|
430 |
+
retry += 1
|
431 |
+
if err:
|
432 |
+
raise err
|
433 |
+
|
434 |
+
|
435 |
+
def get_tokenizer(repo_id):
|
436 |
+
config_path = hf_hub_download_retry(repo_id=repo_id, filename=f"config.json")
|
437 |
+
with open(config_path, "r") as f:
|
438 |
+
config = json.load(f)
|
439 |
+
tokenizer = MIDITokenizer(config["tokenizer"]["version"])
|
440 |
+
tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
|
441 |
+
return tokenizer
|
442 |
+
|
443 |
+
|
444 |
+
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
445 |
+
40: "Blush", 48: "Orchestra"}
|
446 |
+
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
447 |
+
drum_kits2number = {v: k for k, v in number2drum_kits.items()}
|
448 |
+
key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
|
449 |
+
'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
|
450 |
+
|
451 |
+
if __name__ == "__main__":
|
452 |
+
parser = argparse.ArgumentParser()
|
453 |
+
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
454 |
+
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
455 |
+
parser.add_argument("--device", type=str, default="cuda", help="device to run model")
|
456 |
+
parser.add_argument("--batch", type=int, default=8, help="batch size")
|
457 |
+
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
458 |
+
opt = parser.parse_args()
|
459 |
+
OUTPUT_BATCH_SIZE = opt.batch
|
460 |
+
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
461 |
+
thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
|
462 |
+
synthesizer = MidiSynthesizer(soundfont_path)
|
463 |
+
models_info = {
|
464 |
+
"generic pretrain model (tv2o-medium) by skytnt": [
|
465 |
+
"skytnt/midi-model-tv2o-medium", "", {
|
466 |
+
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
467 |
+
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
468 |
+
}
|
469 |
+
],
|
470 |
+
"generic pretrain model (tv2o-large) by asigalov61": [
|
471 |
+
"asigalov61/Music-Llama", "", {}
|
472 |
+
],
|
473 |
+
"generic pretrain model (tv2o-medium) by asigalov61": [
|
474 |
+
"asigalov61/Music-Llama-Medium", "", {}
|
475 |
+
],
|
476 |
+
"generic pretrain model (tv1-medium) by skytnt": [
|
477 |
+
"skytnt/midi-model", "", {}
|
478 |
+
]
|
479 |
+
}
|
480 |
+
models = {}
|
481 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
482 |
+
device = "cuda"
|
483 |
+
|
484 |
+
for name, (repo_id, path, loras) in models_info.items():
|
485 |
+
model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
|
486 |
+
model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
|
487 |
+
tokenizer = get_tokenizer(repo_id)
|
488 |
+
models[name] = [model_base_path, model_token_path, tokenizer]
|
489 |
+
for lora_name, lora_repo in loras.items():
|
490 |
+
model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
|
491 |
+
model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
|
492 |
+
models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
|
493 |
+
|
494 |
+
load_javascript()
|
495 |
+
app = gr.Blocks(theme=gr.themes.Soft())
|
496 |
+
with app:
|
497 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
|
498 |
+
gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
|
499 |
+
"Midi event transformer for symbolic music generation\n\n"
|
500 |
+
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
501 |
+
"[Open In Colab]"
|
502 |
+
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
503 |
+
" or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
|
504 |
+
" for unlimited generation\n\n"
|
505 |
+
"**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
|
506 |
+
"The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
|
507 |
+
)
|
508 |
+
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
509 |
+
js_msg.change(None, [js_msg], [], js="""
|
510 |
+
(msg_json) =>{
|
511 |
+
let msgs = JSON.parse(msg_json);
|
512 |
+
executeCallbacks(msgReceiveCallbacks, msgs);
|
513 |
+
return [];
|
514 |
+
}
|
515 |
+
""")
|
516 |
+
input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
|
517 |
+
type="value", value=list(models.keys())[0])
|
518 |
+
tab_select = gr.State(value=0)
|
519 |
+
with gr.Tabs():
|
520 |
+
with gr.TabItem("custom prompt") as tab1:
|
521 |
+
input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
|
522 |
+
multiselect=True, max_choices=15, type="value")
|
523 |
+
input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
|
524 |
+
value="None")
|
525 |
+
input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
|
526 |
+
step=1,
|
527 |
+
value=0)
|
528 |
+
input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
|
529 |
+
value="auto",
|
530 |
+
choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
|
531 |
+
"2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
|
532 |
+
)
|
533 |
+
input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
|
534 |
+
value="auto",
|
535 |
+
choices=["auto"] + key_signatures,
|
536 |
+
type="index"
|
537 |
+
)
|
538 |
+
example1 = gr.Examples([
|
539 |
+
[[], "None"],
|
540 |
+
[["Acoustic Grand"], "None"],
|
541 |
+
[['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
|
542 |
+
'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
|
543 |
+
[['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
|
544 |
+
'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
|
545 |
+
[['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
|
546 |
+
'Oboe', 'Pizzicato Strings'], "Orchestra"],
|
547 |
+
[['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
|
548 |
+
'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
|
549 |
+
[["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
|
550 |
+
"Electric Bass(finger)"], "Standard"]
|
551 |
+
], [input_instruments, input_drum_kit])
|
552 |
+
with gr.TabItem("midi prompt") as tab2:
|
553 |
+
input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
|
554 |
+
input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
|
555 |
+
step=1,
|
556 |
+
value=128)
|
557 |
+
input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
|
558 |
+
input_remap_track_channel = gr.Checkbox(
|
559 |
+
label="remap tracks and channels so each track has only one channel and in order", value=True)
|
560 |
+
input_add_default_instr = gr.Checkbox(
|
561 |
+
label="add a default instrument to channels that don't have an instrument", value=True)
|
562 |
+
input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
|
563 |
+
example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
|
564 |
+
[input_midi, input_midi_events])
|
565 |
+
with gr.TabItem("last output prompt") as tab3:
|
566 |
+
gr.Markdown("Continue generating on the last output.")
|
567 |
+
input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
|
568 |
+
choices=["all"] + [f"output{i + 1}" for i in
|
569 |
+
range(OUTPUT_BATCH_SIZE)],
|
570 |
+
type="index"
|
571 |
+
)
|
572 |
+
undo_btn = gr.Button("undo the last continuation")
|
573 |
+
|
574 |
+
tab1.select(lambda: 0, None, tab_select, queue=False)
|
575 |
+
tab2.select(lambda: 1, None, tab_select, queue=False)
|
576 |
+
tab3.select(lambda: 2, None, tab_select, queue=False)
|
577 |
+
input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
|
578 |
+
step=1, value=0)
|
579 |
+
input_seed_rand = gr.Checkbox(label="random seed", value=True)
|
580 |
+
input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
|
581 |
+
step=1, value=opt.max_gen // 2)
|
582 |
+
with gr.Accordion("options", open=False):
|
583 |
+
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
584 |
+
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
|
585 |
+
input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
|
586 |
+
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
587 |
+
input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
|
588 |
+
example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
|
589 |
+
[input_temp, input_top_p, input_top_k])
|
590 |
+
run_btn = gr.Button("generate", variant="primary")
|
591 |
+
# stop_btn = gr.Button("stop and output")
|
592 |
+
output_midi_seq = gr.State()
|
593 |
+
output_continuation_state = gr.State([0])
|
594 |
+
midi_outputs = []
|
595 |
+
audio_outputs = []
|
596 |
+
with gr.Tabs(elem_id="output_tabs"):
|
597 |
+
for i in range(OUTPUT_BATCH_SIZE):
|
598 |
+
with gr.TabItem(f"output {i + 1}") as tab1:
|
599 |
+
output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
|
600 |
+
output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
|
601 |
+
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
602 |
+
midi_outputs.append(output_midi)
|
603 |
+
audio_outputs.append(output_audio)
|
604 |
+
run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
|
605 |
+
input_continuation_select, input_instruments, input_drum_kit, input_bpm,
|
606 |
+
input_time_sig, input_key_sig, input_midi, input_midi_events,
|
607 |
+
input_reduce_cc_st, input_remap_track_channel,
|
608 |
+
input_add_default_instr, input_remove_empty_channels,
|
609 |
+
input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
|
610 |
+
input_top_k, input_allow_cc],
|
611 |
+
[output_midi_seq, output_continuation_state, input_seed, js_msg],
|
612 |
+
concurrency_limit=10, queue=True)
|
613 |
+
finish_run_event = run_event.then(fn=finish_run,
|
614 |
+
inputs=[input_model, output_midi_seq],
|
615 |
+
outputs=midi_outputs + [js_msg],
|
616 |
+
queue=False)
|
617 |
+
finish_run_event.then(fn=render_audio,
|
618 |
+
inputs=[input_model, output_midi_seq, input_render_audio],
|
619 |
+
outputs=audio_outputs,
|
620 |
+
queue=False)
|
621 |
+
# stop_btn.click(None, [], [], cancels=run_event,
|
622 |
+
# queue=False)
|
623 |
+
undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
|
624 |
+
[output_midi_seq, output_continuation_state, js_msg], queue=False)
|
625 |
+
app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
|
626 |
+
thread_pool.shutdown()
|
example/Bach--Fugue-in-D-Minor.mid
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1398121eb86a33e73f90ec84be71dac6abc0ddf11372ea7cdd9e01586938a56b
|
3 |
+
size 7720
|
example/Beethoven--Symphony-No5-in-C-Minor-Fate-Opus-67.mid
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:28ff6fdcd644e781d36411bf40ab7a1f4849adddbcd1040eaec22751c5ca99d2
|
3 |
+
size 87090
|
example/Chopin--Nocturne No. 9 in B Major, Opus 32 No.1, Andante Sostenuto.mid
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3a236e647ad9f5d0af680d3ca19d3b60f334c4bde6b4f86310f63405245c476e
|
3 |
+
size 13484
|
example/Mozart--Requiem, No.1..mid
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa49bf4633401e16777fe47f6f53a494c2166f5101af6dafc60114932a59b9bd
|
3 |
+
size 14695
|
example/castle_in_the_sky.mid
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa14aec6f1be15c4fddd0decc6d9152204f160d4e07e05d8d1dc9f209c309ff7
|
3 |
+
size 7957
|
example/eva-残酷な天使のテーゼ.mid
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e513487543d7e27ec5dc30f027302d2a3b5a3aaf9af554def1e5cd6a7a8d355a
|
3 |
+
size 17671
|
javascript/app.js
ADDED
@@ -0,0 +1,732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
const MIDI_OUTPUT_BATCH_SIZE=4;
|
2 |
+
//Do not change MIDI_OUTPUT_BATCH_SIZE. It will be automatically replaced.
|
3 |
+
|
4 |
+
/**
|
5 |
+
* 自动绕过 shadowRoot 的 querySelector
|
6 |
+
* @param {string} selector - 要查询的 CSS 选择器
|
7 |
+
* @returns {Element|null} - 匹配的元素或 null 如果未找到
|
8 |
+
*/
|
9 |
+
function deepQuerySelector(selector) {
|
10 |
+
/**
|
11 |
+
* 在指定的根元素或文档对象下深度查询元素
|
12 |
+
* @param {Element|Document} root - 要开始搜索的根元素或文档对象
|
13 |
+
* @param {string} selector - 要查询的 CSS 选择器
|
14 |
+
* @returns {Element|null} - 匹配的元素或 null 如果未找到
|
15 |
+
*/
|
16 |
+
function deepSearch(root, selector) {
|
17 |
+
// 在当前根元素下查找
|
18 |
+
let element = root.querySelector(selector);
|
19 |
+
if (element) {
|
20 |
+
return element;
|
21 |
+
}
|
22 |
+
|
23 |
+
// 如果未找到,递归检查 shadow DOM
|
24 |
+
const shadowHosts = root.querySelectorAll('*');
|
25 |
+
|
26 |
+
for (let i = 0; i < shadowHosts.length; i++) {
|
27 |
+
const host = shadowHosts[i];
|
28 |
+
|
29 |
+
// 检查当前元素是否有 shadowRoot
|
30 |
+
if (host.shadowRoot) {
|
31 |
+
element = deepSearch(host.shadowRoot, selector);
|
32 |
+
if (element) {
|
33 |
+
return element;
|
34 |
+
}
|
35 |
+
}
|
36 |
+
}
|
37 |
+
// 未找到元素
|
38 |
+
return null;
|
39 |
+
}
|
40 |
+
|
41 |
+
return deepSearch(this, selector);
|
42 |
+
}
|
43 |
+
|
44 |
+
Element.prototype.deepQuerySelector = deepQuerySelector;
|
45 |
+
Document.prototype.deepQuerySelector = deepQuerySelector;
|
46 |
+
|
47 |
+
function gradioApp() {
|
48 |
+
const elems = document.getElementsByTagName('gradio-app')
|
49 |
+
const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
|
50 |
+
return !!gradioShadowRoot ? gradioShadowRoot : document;
|
51 |
+
}
|
52 |
+
|
53 |
+
uiUpdateCallbacks = []
|
54 |
+
msgReceiveCallbacks = []
|
55 |
+
|
56 |
+
function onUiUpdate(callback){
|
57 |
+
uiUpdateCallbacks.push(callback)
|
58 |
+
}
|
59 |
+
|
60 |
+
function onMsgReceive(callback){
|
61 |
+
msgReceiveCallbacks.push(callback)
|
62 |
+
}
|
63 |
+
|
64 |
+
function runCallback(x, m){
|
65 |
+
try {
|
66 |
+
x(m)
|
67 |
+
} catch (e) {
|
68 |
+
(console.error || console.log).call(console, e.message, e);
|
69 |
+
}
|
70 |
+
}
|
71 |
+
function executeCallbacks(queue, m) {
|
72 |
+
queue.forEach(function(x){runCallback(x, m)})
|
73 |
+
}
|
74 |
+
|
75 |
+
document.addEventListener("DOMContentLoaded", function() {
|
76 |
+
var mutationObserver = new MutationObserver(function(m){
|
77 |
+
executeCallbacks(uiUpdateCallbacks, m);
|
78 |
+
});
|
79 |
+
mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
|
80 |
+
});
|
81 |
+
|
82 |
+
function HSVtoRGB(h, s, v) {
|
83 |
+
let r, g, b, i, f, p, q, t;
|
84 |
+
i = Math.floor(h * 6);
|
85 |
+
f = h * 6 - i;
|
86 |
+
p = v * (1 - s);
|
87 |
+
q = v * (1 - f * s);
|
88 |
+
t = v * (1 - (1 - f) * s);
|
89 |
+
switch (i % 6) {
|
90 |
+
case 0: r = v; g = t; b = p; break;
|
91 |
+
case 1: r = q; g = v; b = p; break;
|
92 |
+
case 2: r = p; g = v; b = t; break;
|
93 |
+
case 3: r = p; g = q; b = v; break;
|
94 |
+
case 4: r = t; g = p; b = v; break;
|
95 |
+
case 5: r = v; g = p; b = q; break;
|
96 |
+
}
|
97 |
+
return {
|
98 |
+
r: Math.round(r * 255),
|
99 |
+
g: Math.round(g * 255),
|
100 |
+
b: Math.round(b * 255)
|
101 |
+
};
|
102 |
+
}
|
103 |
+
|
104 |
+
function isMobile(){
|
105 |
+
return /(iPhone|iPad|iPod|iOS|Android|Windows Phone)/i.test(navigator.userAgent);
|
106 |
+
}
|
107 |
+
|
108 |
+
const number2patch = ['Acoustic Grand', 'Bright Acoustic', 'Electric Grand', 'Honky-Tonk', 'Electric Piano 1', 'Electric Piano 2', 'Harpsichord', 'Clav', 'Celesta', 'Glockenspiel', 'Music Box', 'Vibraphone', 'Marimba', 'Xylophone', 'Tubular Bells', 'Dulcimer', 'Drawbar Organ', 'Percussive Organ', 'Rock Organ', 'Church Organ', 'Reed Organ', 'Accordion', 'Harmonica', 'Tango Accordion', 'Acoustic Guitar(nylon)', 'Acoustic Guitar(steel)', 'Electric Guitar(jazz)', 'Electric Guitar(clean)', 'Electric Guitar(muted)', 'Overdriven Guitar', 'Distortion Guitar', 'Guitar Harmonics', 'Acoustic Bass', 'Electric Bass(finger)', 'Electric Bass(pick)', 'Fretless Bass', 'Slap Bass 1', 'Slap Bass 2', 'Synth Bass 1', 'Synth Bass 2', 'Violin', 'Viola', 'Cello', 'Contrabass', 'Tremolo Strings', 'Pizzicato Strings', 'Orchestral Harp', 'Timpani', 'String Ensemble 1', 'String Ensemble 2', 'SynthStrings 1', 'SynthStrings 2', 'Choir Aahs', 'Voice Oohs', 'Synth Voice', 'Orchestra Hit', 'Trumpet', 'Trombone', 'Tuba', 'Muted Trumpet', 'French Horn', 'Brass Section', 'SynthBrass 1', 'SynthBrass 2', 'Soprano Sax', 'Alto Sax', 'Tenor Sax', 'Baritone Sax', 'Oboe', 'English Horn', 'Bassoon', 'Clarinet', 'Piccolo', 'Flute', 'Recorder', 'Pan Flute', 'Blown Bottle', 'Skakuhachi', 'Whistle', 'Ocarina', 'Lead 1 (square)', 'Lead 2 (sawtooth)', 'Lead 3 (calliope)', 'Lead 4 (chiff)', 'Lead 5 (charang)', 'Lead 6 (voice)', 'Lead 7 (fifths)', 'Lead 8 (bass+lead)', 'Pad 1 (new age)', 'Pad 2 (warm)', 'Pad 3 (polysynth)', 'Pad 4 (choir)', 'Pad 5 (bowed)', 'Pad 6 (metallic)', 'Pad 7 (halo)', 'Pad 8 (sweep)', 'FX 1 (rain)', 'FX 2 (soundtrack)', 'FX 3 (crystal)', 'FX 4 (atmosphere)', 'FX 5 (brightness)', 'FX 6 (goblins)', 'FX 7 (echoes)', 'FX 8 (sci-fi)', 'Sitar', 'Banjo', 'Shamisen', 'Koto', 'Kalimba', 'Bagpipe', 'Fiddle', 'Shanai', 'Tinkle Bell', 'Agogo', 'Steel Drums', 'Woodblock', 'Taiko Drum', 'Melodic Tom', 'Synth Drum', 'Reverse Cymbal', 'Guitar Fret Noise', 'Breath Noise', 'Seashore', 'Bird Tweet', 'Telephone Ring', 'Helicopter', 'Applause', 'Gunshot']
|
109 |
+
const number2drum_kits = {0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz", 40: "Blush", 48: "Orchestra"}
|
110 |
+
|
111 |
+
class MidiVisualizer extends HTMLElement{
|
112 |
+
constructor() {
|
113 |
+
super();
|
114 |
+
this.midiEvents = [];
|
115 |
+
this.activeNotes = [];
|
116 |
+
this.midiTimes = [];
|
117 |
+
this.trackMap = new Map()
|
118 |
+
this.patches = [];
|
119 |
+
for (let i=0;i<16;i++){
|
120 |
+
this.patches.push([[0,0]])
|
121 |
+
}
|
122 |
+
this.container = null;
|
123 |
+
this.trackList = null
|
124 |
+
this.pianoRoll = null;
|
125 |
+
this.svg = null;
|
126 |
+
this.timeLine = null;
|
127 |
+
this.config = {
|
128 |
+
noteHeight : 4,
|
129 |
+
beatWidth: 32
|
130 |
+
}
|
131 |
+
if (isMobile()){
|
132 |
+
this.config.noteHeight = 1;
|
133 |
+
this.config.beatWidth = 16;
|
134 |
+
}
|
135 |
+
this.timePreBeat = 16
|
136 |
+
this.svgWidth = 0;
|
137 |
+
this.t1 = 0;
|
138 |
+
this.totalTimeMs = 0
|
139 |
+
this.playTime = 0
|
140 |
+
this.playTimeMs = 0
|
141 |
+
this.lastUpdateTime = 0
|
142 |
+
this.colorMap = new Map();
|
143 |
+
this.playing = false;
|
144 |
+
this.timer = null;
|
145 |
+
this.version = "v2"
|
146 |
+
this.init();
|
147 |
+
}
|
148 |
+
|
149 |
+
init(){
|
150 |
+
this.innerHTML=''
|
151 |
+
const shadow = this.attachShadow({mode: 'open'});
|
152 |
+
const style = document.createElement("style");
|
153 |
+
style.textContent = ".note.active {stroke: black;stroke-width: 0.75;stroke-opacity: 0.75;}";
|
154 |
+
const container = document.createElement('div');
|
155 |
+
container.style.display="flex";
|
156 |
+
container.style.height=`${this.config.noteHeight*128 + 25}px`;
|
157 |
+
const trackListContainer = document.createElement('div');
|
158 |
+
trackListContainer.style.width = "260px";
|
159 |
+
trackListContainer.style.minWidth = "260px";
|
160 |
+
trackListContainer.style.height = "100%";
|
161 |
+
trackListContainer.style.display="flex";
|
162 |
+
trackListContainer.style.flexDirection="column";
|
163 |
+
const trackList = document.createElement('div');
|
164 |
+
trackList.style.width = "100%";
|
165 |
+
trackList.style.height = "100%";
|
166 |
+
trackList.style.overflowY= "scroll";
|
167 |
+
trackList.style.display="flex";
|
168 |
+
trackList.style.flexDirection="column";
|
169 |
+
trackList.style.flexGrow="1";
|
170 |
+
const trackControls = document.createElement('div');
|
171 |
+
trackControls.style.display="flex";
|
172 |
+
trackControls.style.flexDirection="row";
|
173 |
+
trackControls.style.width = "100%";
|
174 |
+
trackControls.style.height = "50px";
|
175 |
+
trackControls.style.minHeight = "50px";
|
176 |
+
const allTrackBtn = document.createElement('button');
|
177 |
+
allTrackBtn.textContent = "All";
|
178 |
+
allTrackBtn.style.width = "50%";
|
179 |
+
allTrackBtn.style.height = "100%";
|
180 |
+
allTrackBtn.style.backgroundColor = "rgba(200, 200, 200, 0.3)";
|
181 |
+
allTrackBtn.style.color = 'inherit';
|
182 |
+
allTrackBtn.style.border = "none";
|
183 |
+
allTrackBtn.style.cursor = 'pointer';
|
184 |
+
let self = this;
|
185 |
+
allTrackBtn.onclick = function (){
|
186 |
+
self.trackMap.forEach((track, id) => {
|
187 |
+
track.setChecked(true);
|
188 |
+
})
|
189 |
+
};
|
190 |
+
const noneTrackBtn = document.createElement('button');
|
191 |
+
noneTrackBtn.textContent = "None";
|
192 |
+
noneTrackBtn.style.width = "50%";
|
193 |
+
noneTrackBtn.style.height = "100%";
|
194 |
+
noneTrackBtn.style.backgroundColor = "rgba(200, 200, 200, 0.3)";
|
195 |
+
noneTrackBtn.style.color = 'inherit';
|
196 |
+
noneTrackBtn.style.border = "none";
|
197 |
+
noneTrackBtn.style.cursor = 'pointer';
|
198 |
+
noneTrackBtn.onclick = function (){
|
199 |
+
self.trackMap.forEach((track, id) => {
|
200 |
+
track.setChecked(false);
|
201 |
+
});
|
202 |
+
};
|
203 |
+
const pianoRoll = document.createElement('div');
|
204 |
+
pianoRoll.style.overflowX= "scroll";
|
205 |
+
pianoRoll.style.flexGrow="1";
|
206 |
+
const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg');
|
207 |
+
svg.style.height = `${this.config.noteHeight*128}px`;
|
208 |
+
svg.style.width = `${this.svgWidth}px`;
|
209 |
+
const timeLine = document.createElementNS('http://www.w3.org/2000/svg', 'line');
|
210 |
+
timeLine.style.stroke = "green"
|
211 |
+
timeLine.style.strokeWidth = "2";
|
212 |
+
|
213 |
+
if (isMobile()){
|
214 |
+
trackListContainer.style.display = "none";
|
215 |
+
timeLine.style.strokeWidth = "1";
|
216 |
+
}
|
217 |
+
shadow.appendChild(style)
|
218 |
+
shadow.appendChild(container);
|
219 |
+
container.appendChild(trackListContainer);
|
220 |
+
trackListContainer.appendChild(trackList);
|
221 |
+
trackListContainer.appendChild(trackControls);
|
222 |
+
trackControls.appendChild(allTrackBtn);
|
223 |
+
trackControls.appendChild(noneTrackBtn);
|
224 |
+
container.appendChild(pianoRoll);
|
225 |
+
pianoRoll.appendChild(svg);
|
226 |
+
svg.appendChild(timeLine)
|
227 |
+
this.container = container;
|
228 |
+
this.trackList = trackList;
|
229 |
+
this.pianoRoll = pianoRoll;
|
230 |
+
this.svg = svg;
|
231 |
+
this.timeLine= timeLine;
|
232 |
+
for(let i = 0; i < 128 ; i++){
|
233 |
+
this.colorMap.set(i, HSVtoRGB(i / 128, 1, 1))
|
234 |
+
}
|
235 |
+
this.setPlayTime(0);
|
236 |
+
}
|
237 |
+
|
238 |
+
addTrack(id, tr, cl, name, color){
|
239 |
+
const track = {id, tr, cl, name, color, empty: true,
|
240 |
+
lastCC: new Map(),
|
241 |
+
instrument: cl===9?"Standard Drum":"Acoustic Grand",
|
242 |
+
svg: document.createElementNS('http://www.w3.org/2000/svg', 'g'),
|
243 |
+
ccPaths: new Map()
|
244 |
+
}
|
245 |
+
this.svg.appendChild(track.svg)
|
246 |
+
const trackItem = this.createTrackItem(track);
|
247 |
+
this.trackList.appendChild(trackItem);
|
248 |
+
this.trackMap.set(id, track);
|
249 |
+
return track;
|
250 |
+
}
|
251 |
+
|
252 |
+
getTrack(tr, cl){
|
253 |
+
const id = tr * 16 + cl
|
254 |
+
let track = this.trackMap.get(id)
|
255 |
+
if (!!track){
|
256 |
+
return track
|
257 |
+
}
|
258 |
+
let color = this.colorMap.get((this.trackMap.size*53)%128)
|
259 |
+
return this.addTrack(id, tr, cl, `Track ${tr}, Channel ${cl}`, color)
|
260 |
+
}
|
261 |
+
|
262 |
+
createTrackItem(track) {
|
263 |
+
const trackItem = document.createElement('div');
|
264 |
+
trackItem.style.display = 'flex';
|
265 |
+
trackItem.style.alignItems = 'center';
|
266 |
+
trackItem.style.width = '100%';
|
267 |
+
trackItem.style.position = 'relative';
|
268 |
+
|
269 |
+
const colorBar = document.createElement('div');
|
270 |
+
colorBar.style.width = '5%';
|
271 |
+
colorBar.style.height = '100%';
|
272 |
+
colorBar.style.position = 'absolute';
|
273 |
+
colorBar.style.left = '0';
|
274 |
+
colorBar.style.top = '0';
|
275 |
+
let color = track.color;
|
276 |
+
colorBar.style.backgroundColor = `rgb(${color.r}, ${color.g}, ${color.b})`;
|
277 |
+
trackItem.appendChild(colorBar);
|
278 |
+
|
279 |
+
const content = document.createElement('div');
|
280 |
+
content.style.paddingLeft = '30px';
|
281 |
+
content.style.flexGrow = '1';
|
282 |
+
content.style.color = "grey"
|
283 |
+
content.innerHTML = `<p>${track.name}<br>${track.instrument}</p>`;
|
284 |
+
trackItem.appendChild(content);
|
285 |
+
track.updateInstrument = function (instrument){
|
286 |
+
track.instrument = instrument;
|
287 |
+
content.innerHTML = `<p>${track.name}<br>${track.instrument}</p>`;
|
288 |
+
}
|
289 |
+
track.setEmpty = function (empty){
|
290 |
+
if (empty!==track.empty){
|
291 |
+
content.style.color = empty?"grey":"inherit";
|
292 |
+
}
|
293 |
+
}
|
294 |
+
|
295 |
+
const toggleSwitch = document.createElement('input');
|
296 |
+
toggleSwitch.type = 'checkbox';
|
297 |
+
toggleSwitch.checked = true;
|
298 |
+
toggleSwitch.style.marginLeft = 'auto';
|
299 |
+
toggleSwitch.style.marginRight = '10px';
|
300 |
+
toggleSwitch.style.width = '20px';
|
301 |
+
toggleSwitch.style.height = '20px';
|
302 |
+
toggleSwitch.style.cursor = 'pointer';
|
303 |
+
|
304 |
+
toggleSwitch.onchange = function () {
|
305 |
+
track.svg.setAttribute('visibility',toggleSwitch.checked? "visible" : "hidden")
|
306 |
+
};
|
307 |
+
track.setChecked = function (checked){
|
308 |
+
toggleSwitch.checked = checked;
|
309 |
+
track.svg.setAttribute('visibility',toggleSwitch.checked? "visible" : "hidden")
|
310 |
+
}
|
311 |
+
trackItem.appendChild(toggleSwitch);
|
312 |
+
return trackItem;
|
313 |
+
}
|
314 |
+
|
315 |
+
clearMidiEvents(){
|
316 |
+
this.pause()
|
317 |
+
this.midiEvents = [];
|
318 |
+
this.activeNotes = [];
|
319 |
+
this.midiTimes = [];
|
320 |
+
this.trackMap = new Map()
|
321 |
+
this.patches = [];
|
322 |
+
for (let i=0;i<16;i++){
|
323 |
+
this.patches.push([[0,0]])
|
324 |
+
}
|
325 |
+
this.t1 = 0
|
326 |
+
this.setPlayTime(0);
|
327 |
+
this.totalTimeMs = 0;
|
328 |
+
this.playTimeMs = 0
|
329 |
+
this.lastUpdateTime = 0
|
330 |
+
this.trackList.innerHTML = ''
|
331 |
+
this.svgWidth = 0
|
332 |
+
this.svg.innerHTML = ''
|
333 |
+
this.svg.style.width = `${this.svgWidth}px`;
|
334 |
+
this.svg.appendChild(this.timeLine)
|
335 |
+
}
|
336 |
+
|
337 |
+
appendMidiEvent(midiEvent){
|
338 |
+
if(midiEvent instanceof Array && midiEvent.length > 0){
|
339 |
+
|
340 |
+
this.t1 += midiEvent[1]
|
341 |
+
let t = this.t1*this.timePreBeat + midiEvent[2]
|
342 |
+
midiEvent = [midiEvent[0], t].concat(midiEvent.slice(3))
|
343 |
+
if(midiEvent[0] === "note"){
|
344 |
+
let track = midiEvent[2]
|
345 |
+
let duration = 0
|
346 |
+
let channel = 0
|
347 |
+
let pitch = 0
|
348 |
+
let velocity = 0
|
349 |
+
if(this.version === "v1"){
|
350 |
+
duration = midiEvent[3]
|
351 |
+
channel = midiEvent[4]
|
352 |
+
pitch = midiEvent[5]
|
353 |
+
velocity = midiEvent[6]
|
354 |
+
}else if (this.version === "v2"){
|
355 |
+
channel = midiEvent[3]
|
356 |
+
pitch = midiEvent[4]
|
357 |
+
velocity = midiEvent[5]
|
358 |
+
duration = midiEvent[6]
|
359 |
+
}
|
360 |
+
let vis_track = this.getTrack(track, channel);
|
361 |
+
vis_track.setEmpty(false);
|
362 |
+
let x = (t/this.timePreBeat)*this.config.beatWidth
|
363 |
+
let y = (127 - pitch)*this.config.noteHeight
|
364 |
+
let w = (duration/this.timePreBeat)*this.config.beatWidth
|
365 |
+
let h = this.config.noteHeight
|
366 |
+
this.svgWidth = Math.ceil(Math.max(x + w, this.svgWidth))
|
367 |
+
let opacity = Math.min(1, velocity/127 + 0.1).toFixed(2)
|
368 |
+
let rect = this.drawNote(vis_track, x,y,w,h, opacity)
|
369 |
+
midiEvent.push(rect);
|
370 |
+
this.setPlayTime(t);
|
371 |
+
this.pianoRoll.scrollTo(this.svgWidth - this.pianoRoll.offsetWidth, this.pianoRoll.scrollTop)
|
372 |
+
}else if(midiEvent[0] === "patch_change"){
|
373 |
+
let track = midiEvent[2];
|
374 |
+
let channel = midiEvent[3];
|
375 |
+
this.patches[channel].push([t, midiEvent[4]]);
|
376 |
+
this.patches[channel].sort((a, b) => a[0] - b[0]);
|
377 |
+
this.getTrack(track, channel);
|
378 |
+
}else if(midiEvent[0] === "control_change"){
|
379 |
+
let track = midiEvent[2];
|
380 |
+
let channel = midiEvent[3];
|
381 |
+
let controller = midiEvent[4];
|
382 |
+
let value = midiEvent[5];
|
383 |
+
let vis_track = this.getTrack(track, channel);
|
384 |
+
this.drawCC(vis_track, t, controller, value);
|
385 |
+
this.setPlayTime(t);
|
386 |
+
}
|
387 |
+
this.midiEvents.push(midiEvent);
|
388 |
+
this.svg.style.width = `${this.svgWidth}px`;
|
389 |
+
}
|
390 |
+
|
391 |
+
}
|
392 |
+
|
393 |
+
drawNote(track, x, y, w, h, opacity) {
|
394 |
+
if (!track.svg) {
|
395 |
+
return null;
|
396 |
+
}
|
397 |
+
const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect');
|
398 |
+
rect.classList.add('note');
|
399 |
+
const color = track.color;
|
400 |
+
rect.setAttribute('fill', `rgba(${color.r}, ${color.g}, ${color.b}, ${opacity})`);
|
401 |
+
// Round values to the nearest integer to avoid partially filled pixels.
|
402 |
+
rect.setAttribute('x', `${Math.round(x)}`);
|
403 |
+
rect.setAttribute('y', `${Math.round(y)}`);
|
404 |
+
rect.setAttribute('width', `${Math.round(w)}`);
|
405 |
+
rect.setAttribute('height', `${Math.round(h)}`);
|
406 |
+
track.svg.appendChild(rect);
|
407 |
+
return rect
|
408 |
+
}
|
409 |
+
|
410 |
+
drawCC(track, t, controller, value){
|
411 |
+
if (!track.svg) {
|
412 |
+
return null;
|
413 |
+
}
|
414 |
+
let path = track.ccPaths.get(controller);
|
415 |
+
let x = (t/this.timePreBeat)*this.config.beatWidth
|
416 |
+
let y = (127 - value)*this.config.noteHeight
|
417 |
+
if (!path){
|
418 |
+
path = document.createElementNS('http://www.w3.org/2000/svg', 'path');
|
419 |
+
path.setAttribute('visibility',"hidden");
|
420 |
+
path.setAttribute('fill', "transparent");
|
421 |
+
const color = track.color;
|
422 |
+
path.setAttribute('stroke', `rgba(${color.r}, ${color.g}, ${color.b}, 0.6)`);
|
423 |
+
path.setAttribute('stroke-width', "1");
|
424 |
+
path.setAttribute('d',
|
425 |
+
t===0?`M ${x} ${y}`:`M 0 ${127*this.config.noteHeight} H ${x} V ${y}`);
|
426 |
+
track.svg.appendChild(path);
|
427 |
+
track.ccPaths.set(controller, path);
|
428 |
+
track.lastCC.set(controller, value);
|
429 |
+
return path;
|
430 |
+
}
|
431 |
+
let lastVal = track.lastCC.get(controller);
|
432 |
+
if(lastVal !== value){
|
433 |
+
path.removeAttribute('visibility');
|
434 |
+
}
|
435 |
+
let d = path.getAttribute("d");
|
436 |
+
d += `H ${x} V ${y}`
|
437 |
+
path.setAttribute('d', d);
|
438 |
+
return path
|
439 |
+
}
|
440 |
+
|
441 |
+
finishAppendMidiEvent(){
|
442 |
+
this.pause()
|
443 |
+
let midiEvents = this.midiEvents.sort((a, b)=>a[1]-b[1])
|
444 |
+
let tempo = (60 / 120) * 10 ** 3
|
445 |
+
let ms = 0
|
446 |
+
let lastT = 0
|
447 |
+
this.midiTimes.push({ms:ms, t: 0, tempo: tempo})
|
448 |
+
midiEvents.forEach((midiEvent)=>{
|
449 |
+
let t = midiEvent[1]
|
450 |
+
ms += ((t- lastT) / this.timePreBeat) * tempo
|
451 |
+
if(midiEvent[0]==="set_tempo"){
|
452 |
+
tempo = (60 / midiEvent[3]) * 10 ** 3
|
453 |
+
this.midiTimes.push({ms:ms, t: t, tempo: tempo})
|
454 |
+
}
|
455 |
+
if(midiEvent[0]==="note"){
|
456 |
+
this.totalTimeMs = Math.max(this.totalTimeMs, ms + (midiEvent[3]/ this.timePreBeat)*tempo)
|
457 |
+
}else{
|
458 |
+
this.totalTimeMs = Math.max(this.totalTimeMs, ms);
|
459 |
+
}
|
460 |
+
lastT = t;
|
461 |
+
})
|
462 |
+
let x = (lastT/this.timePreBeat)*this.config.beatWidth;
|
463 |
+
this.trackMap.forEach((track, id)=>{
|
464 |
+
track.ccPaths.forEach((path, controller)=>{
|
465 |
+
let d = path.getAttribute("d");
|
466 |
+
d += `H ${x}`
|
467 |
+
path.setAttribute('d', d);
|
468 |
+
})
|
469 |
+
})
|
470 |
+
}
|
471 |
+
|
472 |
+
setPlayTime(t){
|
473 |
+
this.playTime = t
|
474 |
+
let x = Math.round((t/this.timePreBeat)*this.config.beatWidth)
|
475 |
+
this.timeLine.setAttribute('x1', `${x}`);
|
476 |
+
this.timeLine.setAttribute('y1', '0');
|
477 |
+
this.timeLine.setAttribute('x2', `${x}`);
|
478 |
+
this.timeLine.setAttribute('y2', `${this.config.noteHeight*128}`);
|
479 |
+
|
480 |
+
this.pianoRoll.scrollTo(Math.max(0, x - this.pianoRoll.offsetWidth/2), this.pianoRoll.scrollTop)
|
481 |
+
|
482 |
+
this.trackMap.forEach((track, id)=>{
|
483 |
+
let instrument = track.instrument
|
484 |
+
let cl = track.cl;
|
485 |
+
let patches = this.patches[cl]
|
486 |
+
let p = 0
|
487 |
+
for (let i = 0; i < patches.length ; i++){
|
488 |
+
let tp = patches[i]
|
489 |
+
if (t < tp[0])
|
490 |
+
break
|
491 |
+
p = tp[1]
|
492 |
+
}
|
493 |
+
if (cl === 9){
|
494 |
+
let drumKit = number2drum_kits[`${p}`];
|
495 |
+
if (!!drumKit)
|
496 |
+
instrument = drumKit + " Drum";
|
497 |
+
}else{
|
498 |
+
instrument = number2patch[p]
|
499 |
+
}
|
500 |
+
if (instrument !== track.instrument)
|
501 |
+
track.updateInstrument(instrument)
|
502 |
+
});
|
503 |
+
|
504 |
+
let dt = Date.now() - this.lastUpdateTime; // limit the update rate of ActiveNotes
|
505 |
+
if(this.playing && dt > 50){
|
506 |
+
let activeNotes = []
|
507 |
+
this.removeActiveNotes(this.activeNotes)
|
508 |
+
this.midiEvents.forEach((midiEvent)=>{
|
509 |
+
if(midiEvent[0] === "note"){
|
510 |
+
let time = midiEvent[1]
|
511 |
+
let duration = this.version==="v1"? midiEvent[3]:midiEvent[6]
|
512 |
+
let note = midiEvent[midiEvent.length - 1]
|
513 |
+
if(time <=this.playTime && time+duration>= this.playTime){
|
514 |
+
activeNotes.push(note)
|
515 |
+
}
|
516 |
+
}
|
517 |
+
});
|
518 |
+
this.addActiveNotes(activeNotes)
|
519 |
+
this.lastUpdateTime = Date.now();
|
520 |
+
}
|
521 |
+
|
522 |
+
}
|
523 |
+
|
524 |
+
setPlayTimeMs(ms){
|
525 |
+
this.playTimeMs = ms
|
526 |
+
let playTime = 0
|
527 |
+
for(let i =0;i<this.midiTimes.length;i++){
|
528 |
+
let midiTime = this.midiTimes[i]
|
529 |
+
if(midiTime.ms>=ms){
|
530 |
+
break;
|
531 |
+
}
|
532 |
+
playTime = midiTime.t + (ms-midiTime.ms) * this.timePreBeat / midiTime.tempo
|
533 |
+
}
|
534 |
+
this.setPlayTime(playTime)
|
535 |
+
}
|
536 |
+
|
537 |
+
addActiveNotes(notes){
|
538 |
+
notes.forEach((note)=>{
|
539 |
+
this.activeNotes.push(note)
|
540 |
+
note.classList.add('active');
|
541 |
+
});
|
542 |
+
}
|
543 |
+
|
544 |
+
removeActiveNotes(notes){
|
545 |
+
notes.forEach((note)=>{
|
546 |
+
let idx = this.activeNotes.indexOf(note)
|
547 |
+
if(idx>-1)
|
548 |
+
this.activeNotes.splice(idx, 1);
|
549 |
+
note.classList.remove('active');
|
550 |
+
});
|
551 |
+
}
|
552 |
+
|
553 |
+
play(){
|
554 |
+
this.playing = true;
|
555 |
+
}
|
556 |
+
|
557 |
+
pause(){
|
558 |
+
this.removeActiveNotes(this.activeNotes)
|
559 |
+
this.playing = false;
|
560 |
+
}
|
561 |
+
|
562 |
+
|
563 |
+
bindAudioPlayer(audio){
|
564 |
+
this.pause()
|
565 |
+
audio.addEventListener("play", (event)=>{
|
566 |
+
this.play()
|
567 |
+
})
|
568 |
+
audio.addEventListener("pause", (event)=>{
|
569 |
+
this.pause()
|
570 |
+
})
|
571 |
+
audio.addEventListener("loadedmetadata", (event)=>{
|
572 |
+
//I don't know why the calculated totalTimeMs is different from audio.duration*10**3
|
573 |
+
this.totalTimeMs = audio.duration*10**3;
|
574 |
+
})
|
575 |
+
}
|
576 |
+
|
577 |
+
bindWaveformCursor(cursor){
|
578 |
+
let self = this;
|
579 |
+
const callback = function(mutationsList, observer) {
|
580 |
+
for(let mutation of mutationsList) {
|
581 |
+
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
582 |
+
let progress = parseFloat(mutation.target.style.left.slice(0,-1))*0.01;
|
583 |
+
if(!isNaN(progress)){
|
584 |
+
self.setPlayTimeMs(progress*self.totalTimeMs);
|
585 |
+
}
|
586 |
+
}
|
587 |
+
}
|
588 |
+
};
|
589 |
+
const observer = new MutationObserver(callback);
|
590 |
+
observer.observe(cursor, {
|
591 |
+
attributes: true,
|
592 |
+
attributeFilter: ['style']
|
593 |
+
});
|
594 |
+
}
|
595 |
+
}
|
596 |
+
|
597 |
+
customElements.define('midi-visualizer', MidiVisualizer);
|
598 |
+
|
599 |
+
(()=>{
|
600 |
+
function midi_visualizer_setup(idx, midi_visualizer){
|
601 |
+
let midi_visualizer_container_inited = null
|
602 |
+
let midi_audio_audio_inited = null;
|
603 |
+
let midi_audio_cursor_inited = null;
|
604 |
+
onUiUpdate((m)=>{
|
605 |
+
let app = gradioApp()
|
606 |
+
let midi_visualizer_container = app.querySelector(`#midi_visualizer_container_${idx}`);
|
607 |
+
if(!!midi_visualizer_container && midi_visualizer_container_inited!== midi_visualizer_container){
|
608 |
+
midi_visualizer_container.appendChild(midi_visualizer)
|
609 |
+
midi_visualizer_container_inited = midi_visualizer_container;
|
610 |
+
}
|
611 |
+
let midi_audio = app.querySelector(`#midi_audio_${idx}`);
|
612 |
+
if (!!midi_audio){
|
613 |
+
let midi_audio_cursor = midi_audio.deepQuerySelector(".cursor");
|
614 |
+
if(!!midi_audio_cursor && midi_audio_cursor_inited!==midi_audio_cursor){
|
615 |
+
midi_visualizer.bindWaveformCursor(midi_audio_cursor)
|
616 |
+
midi_audio_cursor_inited = midi_audio_cursor
|
617 |
+
}
|
618 |
+
let midi_audio_waveform = midi_audio.deepQuerySelector("#waveform");
|
619 |
+
if(!!midi_audio_waveform){
|
620 |
+
let midi_audio_audio = midi_audio_waveform.deepQuerySelector("audio");
|
621 |
+
if(!!midi_audio_audio && midi_audio_audio_inited!==midi_audio_audio){
|
622 |
+
midi_visualizer.bindAudioPlayer(midi_audio_audio)
|
623 |
+
midi_audio_audio_inited = midi_audio_audio
|
624 |
+
}
|
625 |
+
}
|
626 |
+
}
|
627 |
+
});
|
628 |
+
}
|
629 |
+
|
630 |
+
let midi_visualizers = []
|
631 |
+
for (let i = 0; i < MIDI_OUTPUT_BATCH_SIZE ; i++){
|
632 |
+
let midi_visualizer = document.createElement('midi-visualizer');
|
633 |
+
midi_visualizers.push(midi_visualizer);
|
634 |
+
midi_visualizer_setup(i, midi_visualizer)
|
635 |
+
}
|
636 |
+
|
637 |
+
let hasProgressBar = false;
|
638 |
+
let output_tabs_inited = null;
|
639 |
+
onUiUpdate((m)=>{
|
640 |
+
let app = gradioApp()
|
641 |
+
let output_tabs = app.querySelector("#output_tabs");
|
642 |
+
if(!!output_tabs && output_tabs_inited!== output_tabs){
|
643 |
+
output_tabs_inited = output_tabs;
|
644 |
+
}
|
645 |
+
});
|
646 |
+
|
647 |
+
function createProgressBar(progressbarContainer){
|
648 |
+
let parentProgressbar = progressbarContainer.parentNode;
|
649 |
+
let divProgress = document.createElement('div');
|
650 |
+
divProgress.className='progressDiv';
|
651 |
+
let rect = progressbarContainer.getBoundingClientRect();
|
652 |
+
divProgress.style.width = rect.width + "px";
|
653 |
+
divProgress.style.background = "#b4c0cc";
|
654 |
+
divProgress.style.borderRadius = "8px";
|
655 |
+
let divInner = document.createElement('div');
|
656 |
+
divInner.className='progress';
|
657 |
+
divInner.style.color = "white";
|
658 |
+
divInner.style.background = "#0060df";
|
659 |
+
divInner.style.textAlign = "right";
|
660 |
+
divInner.style.fontWeight = "bold";
|
661 |
+
divInner.style.borderRadius = "8px";
|
662 |
+
divInner.style.height = "20px";
|
663 |
+
divInner.style.lineHeight = "20px";
|
664 |
+
divInner.style.paddingRight = "8px"
|
665 |
+
divInner.style.width = "0%";
|
666 |
+
divProgress.appendChild(divInner);
|
667 |
+
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
668 |
+
hasProgressBar = true;
|
669 |
+
}
|
670 |
+
|
671 |
+
function removeProgressBar(progressbarContainer){
|
672 |
+
let parentProgressbar = progressbarContainer.parentNode;
|
673 |
+
let divProgress = parentProgressbar.querySelector(".progressDiv");
|
674 |
+
parentProgressbar.removeChild(divProgress);
|
675 |
+
hasProgressBar = false;
|
676 |
+
}
|
677 |
+
|
678 |
+
function setProgressBar(progress, total){
|
679 |
+
if (!hasProgressBar)
|
680 |
+
createProgressBar(output_tabs_inited)
|
681 |
+
if (hasProgressBar && total === 0){
|
682 |
+
removeProgressBar(output_tabs_inited)
|
683 |
+
return
|
684 |
+
}
|
685 |
+
let parentProgressbar = output_tabs_inited.parentNode;
|
686 |
+
// let divProgress = parentProgressbar.querySelector(".progressDiv");
|
687 |
+
let divInner = parentProgressbar.querySelector(".progress");
|
688 |
+
if(total===0)
|
689 |
+
total = 1;
|
690 |
+
divInner.style.width = `${(progress/total)*100}%`;
|
691 |
+
divInner.textContent = `${progress}/${total}`;
|
692 |
+
}
|
693 |
+
|
694 |
+
onMsgReceive((msgs)=>{
|
695 |
+
for(let msg of msgs){
|
696 |
+
if(msg instanceof Array){
|
697 |
+
msg.forEach((o)=>{handleMsg(o)});
|
698 |
+
}else{
|
699 |
+
handleMsg(msg);
|
700 |
+
}
|
701 |
+
}
|
702 |
+
})
|
703 |
+
function handleMsg(msg){
|
704 |
+
let idx;
|
705 |
+
switch (msg.name) {
|
706 |
+
case "visualizer_clear":
|
707 |
+
idx = msg.data[0];
|
708 |
+
let ver = msg.data[1];
|
709 |
+
midi_visualizers[idx].clearMidiEvents(false);
|
710 |
+
midi_visualizers[idx].version = ver;
|
711 |
+
break;
|
712 |
+
case "visualizer_append":
|
713 |
+
idx = msg.data[0];
|
714 |
+
let events = msg.data[1];
|
715 |
+
events.forEach( value => {
|
716 |
+
midi_visualizers[idx].appendMidiEvent(value);
|
717 |
+
})
|
718 |
+
break;
|
719 |
+
case "visualizer_end":
|
720 |
+
idx = msg.data;
|
721 |
+
midi_visualizers[idx].finishAppendMidiEvent()
|
722 |
+
midi_visualizers[idx].setPlayTime(0);
|
723 |
+
break;
|
724 |
+
case "progress":
|
725 |
+
let progress = msg.data[0]
|
726 |
+
let total = msg.data[1]
|
727 |
+
setProgressBar(progress, total)
|
728 |
+
break;
|
729 |
+
default:
|
730 |
+
}
|
731 |
+
}
|
732 |
+
})();
|
midi_model.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Union, Dict, Any
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import tqdm
|
9 |
+
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
10 |
+
from transformers import LlamaModel, LlamaConfig, DynamicCache, PretrainedConfig, PreTrainedModel
|
11 |
+
|
12 |
+
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
13 |
+
|
14 |
+
config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
|
15 |
+
|
16 |
+
|
17 |
+
class MIDIModelConfig(PretrainedConfig):
|
18 |
+
model_type = "midi_model"
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict]=None,
|
22 |
+
net_config: Union[LlamaConfig, Dict]=None,
|
23 |
+
net_token_config: Union[LlamaConfig, Dict]=None,
|
24 |
+
**kwargs):
|
25 |
+
super().__init__(**kwargs)
|
26 |
+
if tokenizer:
|
27 |
+
if isinstance(tokenizer, dict):
|
28 |
+
self.tokenizer = MIDITokenizer(tokenizer["version"])
|
29 |
+
self.tokenizer.set_optimise_midi(tokenizer["optimise_midi"])
|
30 |
+
else:
|
31 |
+
self.tokenizer = tokenizer
|
32 |
+
else:
|
33 |
+
self.tokenizer = MIDITokenizer()
|
34 |
+
if net_config:
|
35 |
+
if isinstance(net_config, dict):
|
36 |
+
self.net_config = LlamaConfig(**net_config)
|
37 |
+
else:
|
38 |
+
self.net_config = net_config
|
39 |
+
else:
|
40 |
+
self.net_config = LlamaConfig()
|
41 |
+
if net_token_config:
|
42 |
+
if isinstance(net_token_config, dict):
|
43 |
+
self.net_token_config = LlamaConfig(**net_token_config)
|
44 |
+
else:
|
45 |
+
self.net_token_config = net_token_config
|
46 |
+
else:
|
47 |
+
self.net_token_config = LlamaConfig()
|
48 |
+
self.n_embd = self.net_token_config.hidden_size
|
49 |
+
|
50 |
+
def to_dict(self) -> Dict[str, Any]:
|
51 |
+
d = super().to_dict()
|
52 |
+
d["tokenizer"] = self.tokenizer.to_dict()
|
53 |
+
return d
|
54 |
+
|
55 |
+
def __str__(self):
|
56 |
+
d = {
|
57 |
+
"net": self.net_config.to_json_string(use_diff=False),
|
58 |
+
"net_token": self.net_token_config.to_json_string(use_diff=False)
|
59 |
+
}
|
60 |
+
return json.dumps(d, indent=4)
|
61 |
+
|
62 |
+
@staticmethod
|
63 |
+
def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
|
64 |
+
tokenizer = MIDITokenizer(tokenizer_ver)
|
65 |
+
tokenizer.set_optimise_midi(optimise_midi)
|
66 |
+
net_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
|
67 |
+
hidden_size=n_embd, num_attention_heads=n_head,
|
68 |
+
num_hidden_layers=n_layer, intermediate_size=n_inner,
|
69 |
+
pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
|
70 |
+
use_cache=False)
|
71 |
+
net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
|
72 |
+
hidden_size=n_embd, num_attention_heads=n_head // 4,
|
73 |
+
num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
|
74 |
+
pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
|
75 |
+
use_cache=False)
|
76 |
+
return MIDIModelConfig(tokenizer, net_config, net_token_config)
|
77 |
+
|
78 |
+
@staticmethod
|
79 |
+
def from_name(name="tv2o-medium"):
|
80 |
+
tv, size = name.split("-")
|
81 |
+
tv = tv[1:]
|
82 |
+
if tv[-1] == "o":
|
83 |
+
o = True
|
84 |
+
tv = tv[:-1]
|
85 |
+
else:
|
86 |
+
o = False
|
87 |
+
if tv not in ["v1", "v2"]:
|
88 |
+
raise ValueError(f"Unknown tokenizer version {tv}")
|
89 |
+
if size == "medium":
|
90 |
+
return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
|
91 |
+
n_layer=12, n_head=16, n_embd=1024, n_inner=4096)
|
92 |
+
elif size == "large":
|
93 |
+
return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
|
94 |
+
n_layer=24, n_head=16, n_embd=1024, n_inner=4096)
|
95 |
+
else:
|
96 |
+
raise ValueError(f"Unknown model size {size}")
|
97 |
+
|
98 |
+
|
99 |
+
class MIDIModel(PreTrainedModel):
|
100 |
+
config_class = MIDIModelConfig
|
101 |
+
|
102 |
+
def __init__(self, config: MIDIModelConfig, *args, **kwargs):
|
103 |
+
super(MIDIModel, self).__init__(config, *args, **kwargs)
|
104 |
+
self.tokenizer = config.tokenizer
|
105 |
+
self.net = LlamaModel(config.net_config)
|
106 |
+
self.net_token = LlamaModel(config.net_token_config)
|
107 |
+
self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)
|
108 |
+
|
109 |
+
def load_merge_lora(self, model_id):
|
110 |
+
peft_config = PeftConfig.from_pretrained(model_id)
|
111 |
+
model = LoraModel(self, peft_config, adapter_name="default")
|
112 |
+
adapter_state_dict = load_peft_weights(model_id, device=str(self.device))
|
113 |
+
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
114 |
+
return model.merge_and_unload()
|
115 |
+
|
116 |
+
def forward_token(self, hidden_state=None, x=None, cache=None):
|
117 |
+
"""
|
118 |
+
|
119 |
+
:param hidden_state: (batch_size, n_embd)
|
120 |
+
:param x: (batch_size, token_sequence_length)
|
121 |
+
:param cache: Cache
|
122 |
+
:return: (batch_size, 1 + token_sequence_length, vocab_size)
|
123 |
+
"""
|
124 |
+
if hidden_state is not None:
|
125 |
+
#if you use cache, you don't need to pass in hidden_state
|
126 |
+
hidden_state = hidden_state.unsqueeze(1) # (batch_size, 1, n_embd)
|
127 |
+
if x is not None:
|
128 |
+
x = self.net_token.embed_tokens(x)
|
129 |
+
if hidden_state is not None:
|
130 |
+
x = torch.cat([hidden_state, x], dim=1)
|
131 |
+
hidden_state = x
|
132 |
+
hidden_state = self.net_token.forward(inputs_embeds=hidden_state,
|
133 |
+
past_key_values=cache,
|
134 |
+
use_cache=cache is not None).last_hidden_state
|
135 |
+
return self.lm_head(hidden_state)
|
136 |
+
|
137 |
+
def forward(self, x, cache = None):
|
138 |
+
"""
|
139 |
+
:param x: (batch_size, midi_sequence_length, token_sequence_length)
|
140 |
+
:param cache: Cache
|
141 |
+
:return: hidden (batch_size, midi_sequence_length, n_embd)
|
142 |
+
"""
|
143 |
+
|
144 |
+
# merge token sequence
|
145 |
+
x = self.net.embed_tokens(x)
|
146 |
+
x = x.sum(dim=-2)
|
147 |
+
x = self.net.forward(inputs_embeds=x,
|
148 |
+
past_key_values=cache,
|
149 |
+
use_cache=cache is not None)
|
150 |
+
return x.last_hidden_state
|
151 |
+
|
152 |
+
def sample_top_p_k(self, probs, p, k, generator=None):
|
153 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
154 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
155 |
+
mask = probs_sum - probs_sort > p
|
156 |
+
probs_sort[mask] = 0.0
|
157 |
+
mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
|
158 |
+
mask[:k] = 1
|
159 |
+
probs_sort = probs_sort * mask
|
160 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
161 |
+
shape = probs_sort.shape
|
162 |
+
next_token = torch.multinomial(probs_sort.reshape(-1, shape[-1]),
|
163 |
+
num_samples=1, generator=generator).reshape(*shape[:-1], 1)
|
164 |
+
next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
|
165 |
+
return next_token
|
166 |
+
|
167 |
+
@torch.inference_mode()
|
168 |
+
def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
|
169 |
+
tokenizer = self.tokenizer
|
170 |
+
max_token_seq = tokenizer.max_token_seq
|
171 |
+
if prompt is None:
|
172 |
+
input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
|
173 |
+
input_tensor[0, 0] = tokenizer.bos_id # bos
|
174 |
+
input_tensor = input_tensor.unsqueeze(0)
|
175 |
+
input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
|
176 |
+
else:
|
177 |
+
if len(prompt.shape) == 2:
|
178 |
+
prompt = prompt[None, :]
|
179 |
+
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
180 |
+
elif prompt.shape[0] == 1:
|
181 |
+
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
182 |
+
elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
|
183 |
+
raise ValueError(f"invalid shape for prompt, {prompt.shape}")
|
184 |
+
prompt = prompt[..., :max_token_seq]
|
185 |
+
if prompt.shape[-1] < max_token_seq:
|
186 |
+
prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
|
187 |
+
mode="constant", constant_values=tokenizer.pad_id)
|
188 |
+
input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)
|
189 |
+
|
190 |
+
cur_len = input_tensor.shape[1]
|
191 |
+
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
192 |
+
cache1 = DynamicCache()
|
193 |
+
past_len = 0
|
194 |
+
with bar:
|
195 |
+
while cur_len < max_len:
|
196 |
+
end = [False] * batch_size
|
197 |
+
hidden = self.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
|
198 |
+
next_token_seq = None
|
199 |
+
event_names = [""] * batch_size
|
200 |
+
cache2 = DynamicCache()
|
201 |
+
for i in range(max_token_seq):
|
202 |
+
mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
|
203 |
+
for b in range(batch_size):
|
204 |
+
if end[b]:
|
205 |
+
mask[b, tokenizer.pad_id] = 1
|
206 |
+
continue
|
207 |
+
if i == 0:
|
208 |
+
mask[b, list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
|
209 |
+
else:
|
210 |
+
param_names = tokenizer.events[event_names[b]]
|
211 |
+
if i > len(param_names):
|
212 |
+
mask[b, tokenizer.pad_id] = 1
|
213 |
+
continue
|
214 |
+
mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
|
215 |
+
mask = mask.unsqueeze(1)
|
216 |
+
x = next_token_seq
|
217 |
+
if i != 0:
|
218 |
+
# cached
|
219 |
+
hidden = None
|
220 |
+
x = x[:, -1:]
|
221 |
+
logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
|
222 |
+
scores = torch.softmax(logits / temp, dim=-1) * mask
|
223 |
+
samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
|
224 |
+
if i == 0:
|
225 |
+
next_token_seq = samples
|
226 |
+
for b in range(batch_size):
|
227 |
+
if end[b]:
|
228 |
+
continue
|
229 |
+
eid = samples[b].item()
|
230 |
+
if eid == tokenizer.eos_id:
|
231 |
+
end[b] = True
|
232 |
+
else:
|
233 |
+
event_names[b] = tokenizer.id_events[eid]
|
234 |
+
else:
|
235 |
+
next_token_seq = torch.cat([next_token_seq, samples], dim=1)
|
236 |
+
if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
|
237 |
+
break
|
238 |
+
|
239 |
+
if next_token_seq.shape[1] < max_token_seq:
|
240 |
+
next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
|
241 |
+
"constant", value=tokenizer.pad_id)
|
242 |
+
next_token_seq = next_token_seq.unsqueeze(1)
|
243 |
+
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
244 |
+
past_len = cur_len
|
245 |
+
cur_len += 1
|
246 |
+
bar.update(1)
|
247 |
+
|
248 |
+
if all(end):
|
249 |
+
break
|
250 |
+
return input_tensor.cpu().numpy()
|
midi_synthesizer.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from threading import Lock
|
2 |
+
|
3 |
+
import fluidsynth
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class MidiSynthesizer:
|
8 |
+
def __init__(self, soundfont_path, sample_rate=44100):
|
9 |
+
self.soundfont_path = soundfont_path
|
10 |
+
self.sample_rate = sample_rate
|
11 |
+
fl = fluidsynth.Synth(samplerate=float(sample_rate))
|
12 |
+
sfid = fl.sfload(soundfont_path)
|
13 |
+
self.devices = [[fl, sfid, False]]
|
14 |
+
self.file_lock = Lock()
|
15 |
+
|
16 |
+
def get_fluidsynth(self):
|
17 |
+
for device in self.devices:
|
18 |
+
if not device[2]:
|
19 |
+
device[2] = True
|
20 |
+
return device
|
21 |
+
with self.file_lock:
|
22 |
+
fl = fluidsynth.Synth(samplerate=float(self.sample_rate))
|
23 |
+
sfid = fl.sfload(self.soundfont_path)
|
24 |
+
device = [fl, sfid, True]
|
25 |
+
self.devices.append(device)
|
26 |
+
return device
|
27 |
+
|
28 |
+
def release_fluidsynth(self, device):
|
29 |
+
device[0].system_reset()
|
30 |
+
device[0].get_samples(self.sample_rate*5) # wait for silence
|
31 |
+
device[2] = False
|
32 |
+
|
33 |
+
def synthesis(self, midi_opus):
|
34 |
+
ticks_per_beat = midi_opus[0]
|
35 |
+
event_list = []
|
36 |
+
for track_idx, track in enumerate(midi_opus[1:]):
|
37 |
+
abs_t = 0
|
38 |
+
for event in track:
|
39 |
+
abs_t += event[1]
|
40 |
+
event_new = [*event]
|
41 |
+
event_new[1] = abs_t
|
42 |
+
event_list.append(event_new)
|
43 |
+
event_list = sorted(event_list, key=lambda e: e[1])
|
44 |
+
|
45 |
+
tempo = int((60 / 120) * 10 ** 6) # default 120 bpm
|
46 |
+
ss = np.empty((0, 2), dtype=np.int16)
|
47 |
+
device = self.get_fluidsynth()
|
48 |
+
fl, sfid = device[:-1]
|
49 |
+
last_t = 0
|
50 |
+
for c in range(16):
|
51 |
+
fl.program_select(c, sfid, 128 if c == 9 else 0, 0)
|
52 |
+
for event in event_list:
|
53 |
+
name = event[0]
|
54 |
+
sample_len = int(((event[1] / ticks_per_beat) * tempo / (10 ** 6)) * self.sample_rate)
|
55 |
+
sample_len -= int(((last_t / ticks_per_beat) * tempo / (10 ** 6)) * self.sample_rate)
|
56 |
+
last_t = event[1]
|
57 |
+
if sample_len > 0:
|
58 |
+
sample = fl.get_samples(sample_len).reshape(sample_len, 2)
|
59 |
+
ss = np.concatenate([ss, sample])
|
60 |
+
if name == "set_tempo":
|
61 |
+
tempo = event[2]
|
62 |
+
elif name == "patch_change":
|
63 |
+
c, p = event[2:4]
|
64 |
+
fl.program_select(c, sfid, 128 if c == 9 else 0, p)
|
65 |
+
elif name == "control_change":
|
66 |
+
c, cc, v = event[2:5]
|
67 |
+
fl.cc(c, cc, v)
|
68 |
+
elif name == "note_on" and event[3] > 0:
|
69 |
+
c, p, v = event[2:5]
|
70 |
+
fl.noteon(c, p, v)
|
71 |
+
elif name == "note_off" or (name == "note_on" and event[3] == 0):
|
72 |
+
c, p = event[2:4]
|
73 |
+
fl.noteoff(c, p)
|
74 |
+
|
75 |
+
self.release_fluidsynth(device)
|
76 |
+
if ss.shape[0] > 0:
|
77 |
+
max_val = np.abs(ss).max()
|
78 |
+
if max_val != 0:
|
79 |
+
ss = (ss / max_val) * np.iinfo(np.int16).max
|
80 |
+
ss = ss.astype(np.int16)
|
81 |
+
return ss
|
midi_tokenizer.py
ADDED
@@ -0,0 +1,1196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import Dict, Any
|
3 |
+
|
4 |
+
import PIL.Image
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class MIDITokenizerV1:
|
9 |
+
def __init__(self):
|
10 |
+
self.version = "v1"
|
11 |
+
self.optimise_midi = False
|
12 |
+
self.vocab_size = 0
|
13 |
+
|
14 |
+
def allocate_ids(size):
|
15 |
+
ids = [self.vocab_size + i for i in range(size)]
|
16 |
+
self.vocab_size += size
|
17 |
+
return ids
|
18 |
+
|
19 |
+
self.pad_id = allocate_ids(1)[0]
|
20 |
+
self.bos_id = allocate_ids(1)[0]
|
21 |
+
self.eos_id = allocate_ids(1)[0]
|
22 |
+
self.events = {
|
23 |
+
"note": ["time1", "time2", "track", "duration", "channel", "pitch", "velocity"],
|
24 |
+
"patch_change": ["time1", "time2", "track", "channel", "patch"],
|
25 |
+
"control_change": ["time1", "time2", "track", "channel", "controller", "value"],
|
26 |
+
"set_tempo": ["time1", "time2", "track", "bpm"],
|
27 |
+
}
|
28 |
+
self.event_parameters = {
|
29 |
+
"time1": 128, "time2": 16, "duration": 2048, "track": 128, "channel": 16, "pitch": 128, "velocity": 128,
|
30 |
+
"patch": 128, "controller": 128, "value": 128, "bpm": 256
|
31 |
+
}
|
32 |
+
self.event_ids = {e: allocate_ids(1)[0] for e in self.events.keys()}
|
33 |
+
self.id_events = {i: e for e, i in self.event_ids.items()}
|
34 |
+
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
35 |
+
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
36 |
+
|
37 |
+
def to_dict(self) -> Dict[str, Any]:
|
38 |
+
d = {
|
39 |
+
"version":self.version,
|
40 |
+
"optimise_midi":self.optimise_midi,
|
41 |
+
"vocab_size": self.vocab_size,
|
42 |
+
"events": self.events,
|
43 |
+
"event_parameters": self.event_parameters,
|
44 |
+
"max_token_seq": self.max_token_seq,
|
45 |
+
"pad_id": self.pad_id,
|
46 |
+
"bos_id": self.bos_id,
|
47 |
+
"eos_id": self.eos_id,
|
48 |
+
}
|
49 |
+
return d
|
50 |
+
|
51 |
+
def set_optimise_midi(self, optimise_midi=True):
|
52 |
+
self.optimise_midi = optimise_midi
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def tempo2bpm(tempo):
|
56 |
+
tempo = tempo / 10 ** 6 # us to s
|
57 |
+
bpm = 60 / tempo
|
58 |
+
return bpm
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def bpm2tempo(bpm):
|
62 |
+
if bpm == 0:
|
63 |
+
bpm = 1
|
64 |
+
tempo = int((60 / bpm) * 10 ** 6)
|
65 |
+
return tempo
|
66 |
+
|
67 |
+
def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
|
68 |
+
remap_track_channel=None, add_default_instr=None, remove_empty_channels=None):
|
69 |
+
if remap_track_channel is None: # set default value
|
70 |
+
remap_track_channel = self.optimise_midi
|
71 |
+
if add_default_instr is None:
|
72 |
+
add_default_instr = self.optimise_midi
|
73 |
+
if remove_empty_channels is None:
|
74 |
+
remove_empty_channels = self.optimise_midi
|
75 |
+
|
76 |
+
ticks_per_beat = midi_score[0]
|
77 |
+
event_list = {}
|
78 |
+
track_idx_map = {i: dict() for i in range(16)}
|
79 |
+
track_idx_dict = {}
|
80 |
+
channels = []
|
81 |
+
patch_channels = []
|
82 |
+
empty_channels = [True] * 16
|
83 |
+
channel_note_tracks = {i: list() for i in range(16)}
|
84 |
+
for track_idx, track in enumerate(midi_score[1:129]):
|
85 |
+
last_notes = {}
|
86 |
+
patch_dict = {}
|
87 |
+
control_dict = {}
|
88 |
+
last_tempo = 0
|
89 |
+
for event in track:
|
90 |
+
if event[0] not in self.events:
|
91 |
+
continue
|
92 |
+
c = -1
|
93 |
+
t = round(16 * event[1] / ticks_per_beat) # quantization
|
94 |
+
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
95 |
+
if event[0] == "note":
|
96 |
+
c = event[3]
|
97 |
+
if c > 15 or c < 0:
|
98 |
+
continue
|
99 |
+
empty_channels[c] = False
|
100 |
+
track_idx_dict.setdefault(c, track_idx)
|
101 |
+
note_tracks = channel_note_tracks[c]
|
102 |
+
if track_idx not in note_tracks:
|
103 |
+
note_tracks.append(track_idx)
|
104 |
+
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
105 |
+
elif event[0] == "set_tempo":
|
106 |
+
if new_event[4] == 0: # invalid tempo
|
107 |
+
continue
|
108 |
+
bpm = int(self.tempo2bpm(new_event[4]))
|
109 |
+
new_event[4] = min(bpm, 255)
|
110 |
+
if event[0] == "note":
|
111 |
+
key = tuple(new_event[:4] + new_event[5:-1])
|
112 |
+
else:
|
113 |
+
key = tuple(new_event[:-1])
|
114 |
+
if event[0] == "patch_change":
|
115 |
+
c, p = event[2:]
|
116 |
+
if c > 15 or c < 0:
|
117 |
+
continue
|
118 |
+
last_p = patch_dict.setdefault(c, None)
|
119 |
+
if last_p == p:
|
120 |
+
continue
|
121 |
+
patch_dict[c] = p
|
122 |
+
if c not in patch_channels:
|
123 |
+
patch_channels.append(c)
|
124 |
+
elif event[0] == "control_change":
|
125 |
+
c, cc, v = event[2:]
|
126 |
+
if c > 15 or c < 0:
|
127 |
+
continue
|
128 |
+
last_v = control_dict.setdefault((c, cc), 0)
|
129 |
+
if abs(last_v - v) < cc_eps:
|
130 |
+
continue
|
131 |
+
control_dict[(c, cc)] = v
|
132 |
+
elif event[0] == "set_tempo":
|
133 |
+
tempo = new_event[-1]
|
134 |
+
if abs(last_tempo - tempo) < tempo_eps:
|
135 |
+
continue
|
136 |
+
last_tempo = tempo
|
137 |
+
|
138 |
+
if c != -1:
|
139 |
+
if c not in channels:
|
140 |
+
channels.append(c)
|
141 |
+
tr_map = track_idx_map[c]
|
142 |
+
if track_idx not in tr_map:
|
143 |
+
tr_map[track_idx] = 0
|
144 |
+
|
145 |
+
if event[0] == "note": # to eliminate note overlap due to quantization
|
146 |
+
cp = tuple(new_event[5:7])
|
147 |
+
if cp in last_notes:
|
148 |
+
last_note_key, last_note = last_notes[cp]
|
149 |
+
last_t = last_note[1] * 16 + last_note[2]
|
150 |
+
last_note[4] = max(0, min(last_note[4], t - last_t))
|
151 |
+
if last_note[4] == 0:
|
152 |
+
event_list.pop(last_note_key)
|
153 |
+
last_notes[cp] = (key, new_event)
|
154 |
+
event_list[key] = new_event
|
155 |
+
event_list = list(event_list.values())
|
156 |
+
|
157 |
+
empty_channels = [c for c in channels if empty_channels[c]]
|
158 |
+
|
159 |
+
if remap_track_channel:
|
160 |
+
patch_channels = []
|
161 |
+
channels_count = 0
|
162 |
+
channels_map = {9: 9} if 9 in channels else {}
|
163 |
+
if remove_empty_channels:
|
164 |
+
channels = sorted(channels, key=lambda x: 1 if x in empty_channels else 0)
|
165 |
+
for c in channels:
|
166 |
+
if c == 9:
|
167 |
+
continue
|
168 |
+
channels_map[c] = channels_count
|
169 |
+
channels_count += 1
|
170 |
+
if channels_count == 9:
|
171 |
+
channels_count = 10
|
172 |
+
channels = list(channels_map.values())
|
173 |
+
|
174 |
+
track_count = 0
|
175 |
+
track_idx_map_order = [k for k, v in sorted(list(channels_map.items()), key=lambda x: x[1])]
|
176 |
+
for c in track_idx_map_order: # tracks not to remove
|
177 |
+
if remove_empty_channels and c in empty_channels:
|
178 |
+
continue
|
179 |
+
tr_map = track_idx_map[c]
|
180 |
+
for track_idx in tr_map:
|
181 |
+
note_tracks = channel_note_tracks[c]
|
182 |
+
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
183 |
+
continue
|
184 |
+
track_count += 1
|
185 |
+
tr_map[track_idx] = track_count
|
186 |
+
for c in track_idx_map_order: # tracks to remove
|
187 |
+
if not (remove_empty_channels and c in empty_channels):
|
188 |
+
continue
|
189 |
+
tr_map = track_idx_map[c]
|
190 |
+
for track_idx in tr_map:
|
191 |
+
note_tracks = channel_note_tracks[c]
|
192 |
+
if not (len(note_tracks) != 0 and track_idx not in note_tracks):
|
193 |
+
continue
|
194 |
+
track_count += 1
|
195 |
+
tr_map[track_idx] = track_count
|
196 |
+
|
197 |
+
empty_channels = [channels_map[c] for c in empty_channels]
|
198 |
+
track_idx_dict = {}
|
199 |
+
for event in event_list:
|
200 |
+
name = event[0]
|
201 |
+
track_idx = event[3]
|
202 |
+
if name == "note":
|
203 |
+
c = event[5]
|
204 |
+
event[5] = channels_map[c]
|
205 |
+
event[3] = track_idx_map[c][track_idx]
|
206 |
+
track_idx_dict.setdefault(event[5], event[3])
|
207 |
+
# setdefault, so the track_idx is first of the channel
|
208 |
+
elif name == "set_tempo":
|
209 |
+
event[3] = 0
|
210 |
+
elif name == "control_change" or name == "patch_change":
|
211 |
+
c = event[4]
|
212 |
+
event[4] = channels_map[c]
|
213 |
+
tr_map = track_idx_map[c]
|
214 |
+
# move the event to first track of the channel if it's original track is empty
|
215 |
+
note_tracks = channel_note_tracks[c]
|
216 |
+
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
217 |
+
track_idx = channel_note_tracks[c][0]
|
218 |
+
new_track_idx = tr_map[track_idx]
|
219 |
+
event[3] = new_track_idx
|
220 |
+
if name == "patch_change" and event[4] not in patch_channels:
|
221 |
+
patch_channels.append(event[4])
|
222 |
+
|
223 |
+
if add_default_instr:
|
224 |
+
for c in channels:
|
225 |
+
if c not in patch_channels and c in track_idx_dict:
|
226 |
+
event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
|
227 |
+
|
228 |
+
events_name_order = {"set_tempo": 0, "patch_change": 1, "control_change": 2, "note": 3}
|
229 |
+
events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
|
230 |
+
event_list = sorted(event_list, key=events_order)
|
231 |
+
|
232 |
+
setup_events = {}
|
233 |
+
notes_in_setup = False
|
234 |
+
for i, event in enumerate(event_list): # optimise setup
|
235 |
+
new_event = [*event]
|
236 |
+
if event[0] != "note":
|
237 |
+
new_event[1] = 0
|
238 |
+
new_event[2] = 0
|
239 |
+
has_next = False
|
240 |
+
has_pre = False
|
241 |
+
if i < len(event_list) - 1:
|
242 |
+
next_event = event_list[i + 1]
|
243 |
+
has_next = event[1] + event[2] == next_event[1] + next_event[2]
|
244 |
+
if notes_in_setup and i > 0:
|
245 |
+
pre_event = event_list[i - 1]
|
246 |
+
has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
|
247 |
+
if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre):
|
248 |
+
event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
|
249 |
+
break
|
250 |
+
else:
|
251 |
+
if event[0] == "note":
|
252 |
+
notes_in_setup = True
|
253 |
+
key = tuple([event[0]] + event[3:-2])
|
254 |
+
else:
|
255 |
+
key = tuple([event[0]] + event[3:-1])
|
256 |
+
setup_events[key] = new_event
|
257 |
+
|
258 |
+
last_t1 = 0
|
259 |
+
midi_seq = []
|
260 |
+
for event in event_list:
|
261 |
+
if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
|
262 |
+
continue
|
263 |
+
cur_t1 = event[1]
|
264 |
+
event[1] = event[1] - last_t1
|
265 |
+
tokens = self.event2tokens(event)
|
266 |
+
if not tokens:
|
267 |
+
continue
|
268 |
+
midi_seq.append(tokens)
|
269 |
+
last_t1 = cur_t1
|
270 |
+
|
271 |
+
if add_bos_eos:
|
272 |
+
bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
|
273 |
+
eos = [self.eos_id] + [self.pad_id] * (self.max_token_seq - 1)
|
274 |
+
midi_seq = [bos] + midi_seq + [eos]
|
275 |
+
return midi_seq
|
276 |
+
|
277 |
+
def event2tokens(self, event):
|
278 |
+
name = event[0]
|
279 |
+
params = event[1:]
|
280 |
+
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
281 |
+
return []
|
282 |
+
tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
|
283 |
+
for i, p in enumerate(self.events[name])]
|
284 |
+
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
285 |
+
return tokens
|
286 |
+
|
287 |
+
def tokens2event(self, tokens):
|
288 |
+
if tokens[0] not in self.id_events:
|
289 |
+
return []
|
290 |
+
name = self.id_events[tokens[0]]
|
291 |
+
if len(tokens) <= len(self.events[name]):
|
292 |
+
return []
|
293 |
+
params = tokens[1:]
|
294 |
+
params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
|
295 |
+
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
296 |
+
return []
|
297 |
+
event = [name] + params
|
298 |
+
return event
|
299 |
+
|
300 |
+
def detokenize(self, midi_seq):
|
301 |
+
ticks_per_beat = 480
|
302 |
+
tracks_dict = {}
|
303 |
+
t1 = 0
|
304 |
+
for tokens in midi_seq:
|
305 |
+
if tokens[0] in self.id_events:
|
306 |
+
event = self.tokens2event(tokens)
|
307 |
+
if not event:
|
308 |
+
continue
|
309 |
+
name = event[0]
|
310 |
+
if name == "set_tempo":
|
311 |
+
event[4] = self.bpm2tempo(event[4])
|
312 |
+
if event[0] == "note":
|
313 |
+
event[4] = int(event[4] * ticks_per_beat / 16)
|
314 |
+
t1 += event[1]
|
315 |
+
t = t1 * 16 + event[2]
|
316 |
+
t = int(t * ticks_per_beat / 16)
|
317 |
+
track_idx = event[3]
|
318 |
+
if track_idx not in tracks_dict:
|
319 |
+
tracks_dict[track_idx] = []
|
320 |
+
tracks_dict[track_idx].append([event[0], t] + event[4:])
|
321 |
+
tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]
|
322 |
+
|
323 |
+
for i in range(len(tracks)): # to eliminate note overlap
|
324 |
+
track = tracks[i]
|
325 |
+
track = sorted(track, key=lambda e: e[1])
|
326 |
+
last_note_t = {}
|
327 |
+
zero_len_notes = []
|
328 |
+
for e in reversed(track):
|
329 |
+
if e[0] == "note":
|
330 |
+
t, d, c, p = e[1:5]
|
331 |
+
key = (c, p)
|
332 |
+
if key in last_note_t:
|
333 |
+
d = min(d, max(last_note_t[key] - t, 0))
|
334 |
+
last_note_t[key] = t
|
335 |
+
e[2] = d
|
336 |
+
if d == 0:
|
337 |
+
zero_len_notes.append(e)
|
338 |
+
for e in zero_len_notes:
|
339 |
+
track.remove(e)
|
340 |
+
tracks[i] = track
|
341 |
+
return [ticks_per_beat, *tracks]
|
342 |
+
|
343 |
+
def midi2img(self, midi_score):
|
344 |
+
ticks_per_beat = midi_score[0]
|
345 |
+
notes = []
|
346 |
+
max_time = 1
|
347 |
+
track_num = len(midi_score[1:])
|
348 |
+
for track_idx, track in enumerate(midi_score[1:]):
|
349 |
+
for event in track:
|
350 |
+
t = round(16 * event[1] / ticks_per_beat)
|
351 |
+
if event[0] == "note":
|
352 |
+
d = max(1, round(16 * event[2] / ticks_per_beat))
|
353 |
+
c, p = event[3:5]
|
354 |
+
max_time = max(max_time, t + d + 1)
|
355 |
+
notes.append((track_idx, c, p, t, d))
|
356 |
+
img = np.zeros((128, max_time, 3), dtype=np.uint8)
|
357 |
+
colors = {(i, j): np.random.randint(50, 256, 3) for i in range(track_num) for j in range(16)}
|
358 |
+
for note in notes:
|
359 |
+
tr, c, p, t, d = note
|
360 |
+
img[p, t: t + d] = colors[(tr, c)]
|
361 |
+
img = PIL.Image.fromarray(np.flip(img, 0))
|
362 |
+
return img
|
363 |
+
|
364 |
+
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
|
365 |
+
max_track_shift=0, max_channel_shift=16):
|
366 |
+
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
367 |
+
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
368 |
+
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
369 |
+
bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
|
370 |
+
track_shift = random.randint(0, max_track_shift)
|
371 |
+
channel_shift = random.randint(0, max_channel_shift)
|
372 |
+
midi_seq_new = []
|
373 |
+
for tokens in midi_seq:
|
374 |
+
tokens_new = [*tokens]
|
375 |
+
if tokens[0] in self.id_events:
|
376 |
+
name = self.id_events[tokens[0]]
|
377 |
+
for i, pn in enumerate(self.events[name]):
|
378 |
+
if pn == "track":
|
379 |
+
tr = tokens[1 + i] - self.parameter_ids[pn][0]
|
380 |
+
tr += track_shift
|
381 |
+
tr = tr % self.event_parameters[pn]
|
382 |
+
tokens_new[1 + i] = self.parameter_ids[pn][tr]
|
383 |
+
elif pn == "channel":
|
384 |
+
c = tokens[1 + i] - self.parameter_ids[pn][0]
|
385 |
+
c0 = c
|
386 |
+
c += channel_shift
|
387 |
+
c = c % self.event_parameters[pn]
|
388 |
+
if c0 == 9:
|
389 |
+
c = 9
|
390 |
+
elif c == 9:
|
391 |
+
c = (9 + channel_shift) % self.event_parameters[pn]
|
392 |
+
tokens_new[1 + i] = self.parameter_ids[pn][c]
|
393 |
+
|
394 |
+
if name == "note":
|
395 |
+
c = tokens[5] - self.parameter_ids["channel"][0]
|
396 |
+
p = tokens[6] - self.parameter_ids["pitch"][0]
|
397 |
+
v = tokens[7] - self.parameter_ids["velocity"][0]
|
398 |
+
if c != 9: # no shift for drums
|
399 |
+
p += pitch_shift
|
400 |
+
if not 0 <= p < 128:
|
401 |
+
return midi_seq
|
402 |
+
v += vel_shift
|
403 |
+
v = max(1, min(127, v))
|
404 |
+
tokens_new[6] = self.parameter_ids["pitch"][p]
|
405 |
+
tokens_new[7] = self.parameter_ids["velocity"][v]
|
406 |
+
elif name == "control_change":
|
407 |
+
cc = tokens[5] - self.parameter_ids["controller"][0]
|
408 |
+
val = tokens[6] - self.parameter_ids["value"][0]
|
409 |
+
if cc in [1, 2, 7, 11]:
|
410 |
+
val += cc_val_shift
|
411 |
+
val = max(1, min(127, val))
|
412 |
+
tokens_new[6] = self.parameter_ids["value"][val]
|
413 |
+
elif name == "set_tempo":
|
414 |
+
bpm = tokens[4] - self.parameter_ids["bpm"][0]
|
415 |
+
bpm += bpm_shift
|
416 |
+
bpm = max(1, min(255, bpm))
|
417 |
+
tokens_new[4] = self.parameter_ids["bpm"][bpm]
|
418 |
+
midi_seq_new.append(tokens_new)
|
419 |
+
return midi_seq_new
|
420 |
+
|
421 |
+
def check_quality(self, midi_seq, alignment_min=0.3, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3,
|
422 |
+
notes_density_max=50, notes_density_min=2.5, total_notes_max=20000, total_notes_min=256,
|
423 |
+
note_window_size=16):
|
424 |
+
total_notes = 0
|
425 |
+
channels = []
|
426 |
+
time_hist = [0] * 16
|
427 |
+
note_windows = {}
|
428 |
+
notes_sametime = []
|
429 |
+
notes_density_list = []
|
430 |
+
tonality_list = []
|
431 |
+
notes_bandwidth_list = []
|
432 |
+
instruments = {}
|
433 |
+
piano_channels = []
|
434 |
+
abs_t1 = 0
|
435 |
+
last_t = 0
|
436 |
+
for tsi, tokens in enumerate(midi_seq):
|
437 |
+
event = self.tokens2event(tokens)
|
438 |
+
if not event:
|
439 |
+
continue
|
440 |
+
t1, t2, tr = event[1:4]
|
441 |
+
abs_t1 += t1
|
442 |
+
t = abs_t1 * 16 + t2
|
443 |
+
c = None
|
444 |
+
if event[0] == "note":
|
445 |
+
d, c, p, v = event[4:]
|
446 |
+
total_notes += 1
|
447 |
+
time_hist[t2] += 1
|
448 |
+
if c != 9: # ignore drum channel
|
449 |
+
if c not in instruments:
|
450 |
+
instruments[c] = 0
|
451 |
+
if c not in piano_channels:
|
452 |
+
piano_channels.append(c)
|
453 |
+
note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
|
454 |
+
if last_t != t:
|
455 |
+
notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
|
456 |
+
notes_sametime_p = [p_ for _, p_ in notes_sametime]
|
457 |
+
if len(notes_sametime) > 0:
|
458 |
+
notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
|
459 |
+
notes_sametime.append((t + d - 1, p))
|
460 |
+
elif event[0] == "patch_change":
|
461 |
+
c, p = event[4:]
|
462 |
+
instruments[c] = p
|
463 |
+
if p == 0 and c not in piano_channels:
|
464 |
+
piano_channels.append(c)
|
465 |
+
if c is not None and c not in channels:
|
466 |
+
channels.append(c)
|
467 |
+
last_t = t
|
468 |
+
reasons = []
|
469 |
+
if total_notes < total_notes_min:
|
470 |
+
reasons.append("total_min")
|
471 |
+
if total_notes > total_notes_max:
|
472 |
+
reasons.append("total_max")
|
473 |
+
if len(note_windows) == 0 and total_notes > 0:
|
474 |
+
reasons.append("drum_only")
|
475 |
+
if reasons:
|
476 |
+
return False, reasons
|
477 |
+
time_hist = sorted(time_hist, reverse=True)
|
478 |
+
alignment = sum(time_hist[:2]) / total_notes
|
479 |
+
for notes in note_windows.values():
|
480 |
+
key_hist = [0] * 12
|
481 |
+
for p in notes:
|
482 |
+
key_hist[p % 12] += 1
|
483 |
+
key_hist = sorted(key_hist, reverse=True)
|
484 |
+
tonality_list.append(sum(key_hist[:7]) / len(notes))
|
485 |
+
notes_density_list.append(len(notes) / note_window_size)
|
486 |
+
tonality_list = sorted(tonality_list)
|
487 |
+
tonality = sum(tonality_list) / len(tonality_list)
|
488 |
+
notes_bandwidth = sum(notes_bandwidth_list) / len(notes_bandwidth_list) if notes_bandwidth_list else 0
|
489 |
+
notes_density = max(notes_density_list) if notes_density_list else 0
|
490 |
+
piano_ratio = len(piano_channels) / len(channels)
|
491 |
+
if len(channels) <= 3: # ignore piano threshold if it is a piano solo midi
|
492 |
+
piano_max = 1
|
493 |
+
if alignment < alignment_min: # check weather the notes align to the bars (because some midi files are recorded)
|
494 |
+
reasons.append("alignment")
|
495 |
+
if tonality < tonality_min: # check whether the music is tonal
|
496 |
+
reasons.append("tonality")
|
497 |
+
if notes_bandwidth < notes_bandwidth_min: # check whether music is melodic line only
|
498 |
+
reasons.append("bandwidth")
|
499 |
+
if not notes_density_min < notes_density < notes_density_max:
|
500 |
+
reasons.append("density")
|
501 |
+
if piano_ratio > piano_max: # check whether most instruments is piano (because some midi files don't have instruments assigned correctly)
|
502 |
+
reasons.append("piano")
|
503 |
+
return not reasons, reasons
|
504 |
+
|
505 |
+
|
506 |
+
class MIDITokenizerV2:
|
507 |
+
def __init__(self):
|
508 |
+
self.version = "v2"
|
509 |
+
self.optimise_midi = False
|
510 |
+
self.vocab_size = 0
|
511 |
+
|
512 |
+
def allocate_ids(size):
|
513 |
+
ids = [self.vocab_size + i for i in range(size)]
|
514 |
+
self.vocab_size += size
|
515 |
+
return ids
|
516 |
+
|
517 |
+
self.pad_id = allocate_ids(1)[0]
|
518 |
+
self.bos_id = allocate_ids(1)[0]
|
519 |
+
self.eos_id = allocate_ids(1)[0]
|
520 |
+
self.events = {
|
521 |
+
"note": ["time1", "time2", "track", "channel", "pitch", "velocity", "duration"],
|
522 |
+
"patch_change": ["time1", "time2", "track", "channel", "patch"],
|
523 |
+
"control_change": ["time1", "time2", "track", "channel", "controller", "value"],
|
524 |
+
"set_tempo": ["time1", "time2", "track", "bpm"],
|
525 |
+
"time_signature": ["time1", "time2", "track", "nn", "dd"],
|
526 |
+
"key_signature": ["time1", "time2", "track", "sf", "mi"],
|
527 |
+
}
|
528 |
+
self.event_parameters = {
|
529 |
+
"time1": 128, "time2": 16, "duration": 2048, "track": 128, "channel": 16, "pitch": 128, "velocity": 128,
|
530 |
+
"patch": 128, "controller": 128, "value": 128, "bpm": 384, "nn": 16, "dd": 4, "sf": 15, "mi": 2
|
531 |
+
}
|
532 |
+
self.event_ids = {e: allocate_ids(1)[0] for e in self.events.keys()}
|
533 |
+
self.id_events = {i: e for e, i in self.event_ids.items()}
|
534 |
+
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
535 |
+
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
536 |
+
|
537 |
+
def to_dict(self) -> Dict[str, Any]:
|
538 |
+
d = {
|
539 |
+
"version":self.version,
|
540 |
+
"optimise_midi":self.optimise_midi,
|
541 |
+
"vocab_size": self.vocab_size,
|
542 |
+
"events": self.events,
|
543 |
+
"event_parameters": self.event_parameters,
|
544 |
+
"max_token_seq": self.max_token_seq,
|
545 |
+
"pad_id": self.pad_id,
|
546 |
+
"bos_id": self.bos_id,
|
547 |
+
"eos_id": self.eos_id,
|
548 |
+
}
|
549 |
+
return d
|
550 |
+
|
551 |
+
def set_optimise_midi(self, optimise_midi=True):
|
552 |
+
self.optimise_midi = optimise_midi
|
553 |
+
|
554 |
+
@staticmethod
|
555 |
+
def tempo2bpm(tempo):
|
556 |
+
tempo = tempo / 10 ** 6 # us to s
|
557 |
+
bpm = 60 / tempo
|
558 |
+
return bpm
|
559 |
+
|
560 |
+
@staticmethod
|
561 |
+
def bpm2tempo(bpm):
|
562 |
+
if bpm == 0:
|
563 |
+
bpm = 1
|
564 |
+
tempo = int((60 / bpm) * 10 ** 6)
|
565 |
+
return tempo
|
566 |
+
|
567 |
+
@staticmethod
|
568 |
+
def sf2key(sf):
|
569 |
+
# sf in key_signature to key.
|
570 |
+
# key represents the sequence from C note to B note (12 in total)
|
571 |
+
return (sf * 7) % 12
|
572 |
+
|
573 |
+
@staticmethod
|
574 |
+
def key2sf(k, mi):
|
575 |
+
# key to sf
|
576 |
+
sf = (k * 7) % 12
|
577 |
+
if sf > 6 or (mi == 1 and sf >= 5):
|
578 |
+
sf -= 12
|
579 |
+
return sf
|
580 |
+
|
581 |
+
@staticmethod
|
582 |
+
def detect_key_signature(key_hist, threshold=0.7):
|
583 |
+
if len(key_hist) != 12:
|
584 |
+
return None
|
585 |
+
if sum(key_hist) == 0:
|
586 |
+
return None
|
587 |
+
p = sum(sorted(key_hist, reverse=True)[:7]) / sum(key_hist)
|
588 |
+
if p < threshold:
|
589 |
+
return None
|
590 |
+
keys = [x[1] for x in sorted(zip(key_hist, range(len(key_hist))), reverse=True, key=lambda x: x[0])[:7]]
|
591 |
+
keys = sorted(keys)
|
592 |
+
semitones = []
|
593 |
+
for i in range(len(keys)):
|
594 |
+
dis = keys[i] - keys[i - 1]
|
595 |
+
if dis == 1 or dis == -11:
|
596 |
+
semitones.append(keys[i])
|
597 |
+
if len(semitones) != 2:
|
598 |
+
return None
|
599 |
+
semitones_dis = semitones[1] - semitones[0]
|
600 |
+
if semitones_dis == 5:
|
601 |
+
root_key = semitones[0]
|
602 |
+
elif semitones_dis == 7:
|
603 |
+
root_key = semitones[1]
|
604 |
+
else:
|
605 |
+
return None
|
606 |
+
return root_key
|
607 |
+
|
608 |
+
def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
|
609 |
+
remap_track_channel=None, add_default_instr=None, remove_empty_channels=None):
|
610 |
+
if remap_track_channel is None: # set default value
|
611 |
+
remap_track_channel = self.optimise_midi
|
612 |
+
if add_default_instr is None:
|
613 |
+
add_default_instr = self.optimise_midi
|
614 |
+
if remove_empty_channels is None:
|
615 |
+
remove_empty_channels = self.optimise_midi
|
616 |
+
|
617 |
+
ticks_per_beat = midi_score[0]
|
618 |
+
event_list = {}
|
619 |
+
track_idx_map = {i: dict() for i in range(16)}
|
620 |
+
track_idx_dict = {}
|
621 |
+
channels = []
|
622 |
+
patch_channels = []
|
623 |
+
empty_channels = [True] * 16
|
624 |
+
channel_note_tracks = {i: list() for i in range(16)}
|
625 |
+
note_key_hist = [0]*12
|
626 |
+
key_sigs = []
|
627 |
+
track_to_channels = {}
|
628 |
+
for track_idx, track in enumerate(midi_score[1:129]):
|
629 |
+
last_notes = {}
|
630 |
+
patch_dict = {}
|
631 |
+
control_dict = {}
|
632 |
+
last_bpm = 0
|
633 |
+
track_channels = []
|
634 |
+
track_to_channels.setdefault(track_idx, track_channels)
|
635 |
+
for event in track:
|
636 |
+
if event[0] not in self.events:
|
637 |
+
continue
|
638 |
+
name = event[0]
|
639 |
+
c = -1
|
640 |
+
t = round(16 * event[1] / ticks_per_beat) # quantization
|
641 |
+
new_event = [name, t // 16, t % 16, track_idx]
|
642 |
+
if name == "note":
|
643 |
+
d, c, p, v = event[2:]
|
644 |
+
if not (0 <= c <= 15):
|
645 |
+
continue
|
646 |
+
d = max(1, round(16 * d / ticks_per_beat))
|
647 |
+
new_event += [c, p, v, d]
|
648 |
+
empty_channels[c] = False
|
649 |
+
track_idx_dict.setdefault(c, track_idx)
|
650 |
+
note_tracks = channel_note_tracks[c]
|
651 |
+
if track_idx not in note_tracks:
|
652 |
+
note_tracks.append(track_idx)
|
653 |
+
if c != 9:
|
654 |
+
note_key_hist[p%12] += 1
|
655 |
+
if c not in track_channels:
|
656 |
+
track_channels.append(c)
|
657 |
+
elif name == "patch_change":
|
658 |
+
c, p = event[2:]
|
659 |
+
if not (0 <= c <= 15):
|
660 |
+
continue
|
661 |
+
new_event += [c, p]
|
662 |
+
last_p = patch_dict.setdefault(c, None)
|
663 |
+
if last_p == p:
|
664 |
+
continue
|
665 |
+
patch_dict[c] = p
|
666 |
+
if c not in patch_channels:
|
667 |
+
patch_channels.append(c)
|
668 |
+
elif name == "control_change":
|
669 |
+
c, cc, v = event[2:]
|
670 |
+
if not (0 <= c <= 15):
|
671 |
+
continue
|
672 |
+
new_event += [c, cc, v]
|
673 |
+
last_v = control_dict.setdefault((c, cc), 0)
|
674 |
+
if abs(last_v - v) < cc_eps:
|
675 |
+
continue
|
676 |
+
control_dict[(c, cc)] = v
|
677 |
+
elif name == "set_tempo":
|
678 |
+
tempo = event[2]
|
679 |
+
if tempo == 0: # invalid tempo
|
680 |
+
continue
|
681 |
+
bpm = min(int(self.tempo2bpm(tempo)), 383)
|
682 |
+
new_event += [bpm]
|
683 |
+
if abs(last_bpm - bpm) < tempo_eps:
|
684 |
+
continue
|
685 |
+
last_bpm = bpm
|
686 |
+
elif name == "time_signature":
|
687 |
+
nn, dd = event[2:4]
|
688 |
+
if not (1 <= nn <= 16 and 1 <= dd <= 4): # invalid
|
689 |
+
continue
|
690 |
+
nn -= 1 # make it start from 0
|
691 |
+
dd -= 1
|
692 |
+
new_event += [nn, dd]
|
693 |
+
elif name == "key_signature":
|
694 |
+
sf, mi = event[2:]
|
695 |
+
if not (-7 <= sf <= 7 and 0 <= mi <= 1): # invalid
|
696 |
+
continue
|
697 |
+
sf += 7
|
698 |
+
new_event += [sf, mi]
|
699 |
+
key_sigs.append(new_event)
|
700 |
+
|
701 |
+
if name in ["note", "time_signature", "key_signature"]:
|
702 |
+
key = tuple(new_event[:-2])
|
703 |
+
else:
|
704 |
+
key = tuple(new_event[:-1])
|
705 |
+
|
706 |
+
if c != -1:
|
707 |
+
if c not in channels:
|
708 |
+
channels.append(c)
|
709 |
+
tr_map = track_idx_map[c]
|
710 |
+
if track_idx not in tr_map:
|
711 |
+
tr_map[track_idx] = 0
|
712 |
+
|
713 |
+
if event[0] == "note": # to eliminate note overlap due to quantization
|
714 |
+
cp = tuple(new_event[4:6]) # channel pitch
|
715 |
+
if cp in last_notes:
|
716 |
+
last_note_key, last_note = last_notes[cp]
|
717 |
+
last_t = last_note[1] * 16 + last_note[2]
|
718 |
+
last_note[-1] = max(0, min(last_note[-1], t - last_t)) # modify duration
|
719 |
+
if last_note[-1] == 0:
|
720 |
+
event_list.pop(last_note_key)
|
721 |
+
last_notes[cp] = (key, new_event)
|
722 |
+
event_list[key] = new_event
|
723 |
+
event_list = list(event_list.values())
|
724 |
+
|
725 |
+
empty_channels = [c for c in channels if empty_channels[c]]
|
726 |
+
|
727 |
+
if remap_track_channel:
|
728 |
+
patch_channels = []
|
729 |
+
channels_count = 0
|
730 |
+
channels_map = {9: 9} if 9 in channels else {}
|
731 |
+
if remove_empty_channels:
|
732 |
+
channels = sorted(channels, key=lambda x: 1 if x in empty_channels else 0)
|
733 |
+
for c in channels:
|
734 |
+
if c == 9:
|
735 |
+
continue
|
736 |
+
channels_map[c] = channels_count
|
737 |
+
channels_count += 1
|
738 |
+
if channels_count == 9:
|
739 |
+
channels_count = 10
|
740 |
+
channels = list(channels_map.values())
|
741 |
+
|
742 |
+
track_count = 0
|
743 |
+
track_idx_map_order = [k for k, v in sorted(list(channels_map.items()), key=lambda x: x[1])]
|
744 |
+
for c in track_idx_map_order: # tracks not to remove
|
745 |
+
if remove_empty_channels and c in empty_channels:
|
746 |
+
continue
|
747 |
+
tr_map = track_idx_map[c]
|
748 |
+
for track_idx in tr_map:
|
749 |
+
note_tracks = channel_note_tracks[c]
|
750 |
+
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
751 |
+
continue
|
752 |
+
track_count += 1
|
753 |
+
tr_map[track_idx] = track_count
|
754 |
+
for c in track_idx_map_order: # tracks to remove
|
755 |
+
if not (remove_empty_channels and c in empty_channels):
|
756 |
+
continue
|
757 |
+
tr_map = track_idx_map[c]
|
758 |
+
for track_idx in tr_map:
|
759 |
+
note_tracks = channel_note_tracks[c]
|
760 |
+
if not (len(note_tracks) != 0 and track_idx not in note_tracks):
|
761 |
+
continue
|
762 |
+
track_count += 1
|
763 |
+
tr_map[track_idx] = track_count
|
764 |
+
|
765 |
+
empty_channels = [channels_map[c] for c in empty_channels]
|
766 |
+
track_idx_dict = {}
|
767 |
+
key_sigs = []
|
768 |
+
key_signature_to_add = []
|
769 |
+
key_signature_to_remove = []
|
770 |
+
for event in event_list:
|
771 |
+
name = event[0]
|
772 |
+
track_idx = event[3]
|
773 |
+
if name == "note":
|
774 |
+
c = event[4]
|
775 |
+
event[4] = channels_map[c] # channel
|
776 |
+
event[3] = track_idx_map[c][track_idx] # track
|
777 |
+
track_idx_dict.setdefault(event[4], event[3])
|
778 |
+
# setdefault, so the track_idx is first of the channel
|
779 |
+
elif name in ["set_tempo", "time_signature"]:
|
780 |
+
event[3] = 0 # set track 0 for meta events
|
781 |
+
elif name == "key_signature":
|
782 |
+
new_channel_track_idxs = []
|
783 |
+
for c, tr_map in track_idx_map.items():
|
784 |
+
if track_idx in tr_map:
|
785 |
+
new_track_idx = tr_map[track_idx]
|
786 |
+
c = channels_map[c]
|
787 |
+
new_channel_track_idx = (c, new_track_idx)
|
788 |
+
if new_track_idx == 0:
|
789 |
+
continue
|
790 |
+
if new_channel_track_idx not in new_channel_track_idxs:
|
791 |
+
new_channel_track_idxs.append(new_channel_track_idx)
|
792 |
+
|
793 |
+
if len(new_channel_track_idxs) == 0:
|
794 |
+
if event[3] == 0: # keep key_signature on track 0 (meta)
|
795 |
+
key_sigs.append(event)
|
796 |
+
continue
|
797 |
+
event[3] = -1 # avoid remove same event
|
798 |
+
key_signature_to_remove.append(event) # empty track
|
799 |
+
continue
|
800 |
+
c, nt = new_channel_track_idxs[0]
|
801 |
+
event[3] = nt
|
802 |
+
key_sigs.append(event)
|
803 |
+
if c == 9:
|
804 |
+
event[4] = 7 # sf=0
|
805 |
+
for c, nt in new_channel_track_idxs[1:]:
|
806 |
+
new_event = [*event]
|
807 |
+
new_event[3] = nt
|
808 |
+
if c == 9:
|
809 |
+
new_event[4] = 7 # sf=0
|
810 |
+
key_sigs.append(new_event)
|
811 |
+
key_signature_to_add.append(new_event)
|
812 |
+
elif name == "control_change" or name == "patch_change":
|
813 |
+
c = event[4]
|
814 |
+
event[4] = channels_map[c] # channel
|
815 |
+
tr_map = track_idx_map[c]
|
816 |
+
# move the event to first track of the channel if it's original track is empty
|
817 |
+
note_tracks = channel_note_tracks[c]
|
818 |
+
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
819 |
+
track_idx = channel_note_tracks[c][0]
|
820 |
+
new_track_idx = tr_map[track_idx]
|
821 |
+
event[3] = new_track_idx
|
822 |
+
if name == "patch_change" and event[4] not in patch_channels:
|
823 |
+
patch_channels.append(event[4])
|
824 |
+
for key_sig in key_signature_to_remove:
|
825 |
+
event_list.remove(key_sig)
|
826 |
+
event_list += key_signature_to_add
|
827 |
+
track_to_channels ={}
|
828 |
+
for c, tr_map in track_idx_map.items():
|
829 |
+
if c not in channels_map:
|
830 |
+
continue
|
831 |
+
c = channels_map[c]
|
832 |
+
for _, track_idx in tr_map.items():
|
833 |
+
track_to_channels.setdefault(track_idx, [])
|
834 |
+
cs = track_to_channels[track_idx]
|
835 |
+
if c not in cs:
|
836 |
+
cs.append(c)
|
837 |
+
|
838 |
+
if add_default_instr:
|
839 |
+
for c in channels:
|
840 |
+
if c not in patch_channels and c in track_idx_dict:
|
841 |
+
event_list.append(["patch_change", 0, 0, track_idx_dict[c], c, 0])
|
842 |
+
|
843 |
+
if len(key_sigs) == 0 or all([key_sig[4]==7 for key_sig in key_sigs]):
|
844 |
+
# detect key signature or fix the default key signature
|
845 |
+
root_key = self.detect_key_signature(note_key_hist)
|
846 |
+
if root_key is not None:
|
847 |
+
sf = self.key2sf(root_key, 0)
|
848 |
+
# print("detect_key_signature",sf)
|
849 |
+
if len(key_sigs) == 0:
|
850 |
+
for tr, cs in track_to_channels.items():
|
851 |
+
if remap_track_channel and tr == 0:
|
852 |
+
continue
|
853 |
+
new_event = ["key_signature", 0, 0, tr, (0 if (len(cs) == 1 and cs[0] == 9) else sf) + 7, 0]
|
854 |
+
event_list.append(new_event)
|
855 |
+
else:
|
856 |
+
for key_sig in key_sigs:
|
857 |
+
tr = key_sig[3]
|
858 |
+
if tr in track_to_channels:
|
859 |
+
cs = track_to_channels[tr]
|
860 |
+
if len(cs) == 1 and cs[0] == 9:
|
861 |
+
continue
|
862 |
+
key_sig[4] = sf + 7
|
863 |
+
key_sig[5] = 0
|
864 |
+
else:
|
865 |
+
# remove default key signature
|
866 |
+
for key_sig in key_sigs:
|
867 |
+
event_list.remove(key_sig)
|
868 |
+
|
869 |
+
events_name_order = ["time_signature", "key_signature", "set_tempo", "patch_change", "control_change", "note"]
|
870 |
+
events_name_order = {name: i for i, name in enumerate(events_name_order)}
|
871 |
+
events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
|
872 |
+
event_list = sorted(event_list, key=events_order)
|
873 |
+
|
874 |
+
setup_events = {}
|
875 |
+
notes_in_setup = False
|
876 |
+
for i, event in enumerate(event_list): # optimise setup
|
877 |
+
new_event = [*event] # make copy of event
|
878 |
+
if event[0] not in ["note", "time_signature"]:
|
879 |
+
new_event[1] = 0
|
880 |
+
new_event[2] = 0
|
881 |
+
has_next = False
|
882 |
+
has_pre = False
|
883 |
+
if i < len(event_list) - 1:
|
884 |
+
next_event = event_list[i + 1]
|
885 |
+
has_next = event[1] + event[2] == next_event[1] + next_event[2]
|
886 |
+
if notes_in_setup and i > 0:
|
887 |
+
pre_event = event_list[i - 1]
|
888 |
+
has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
|
889 |
+
if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre):
|
890 |
+
event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
|
891 |
+
break
|
892 |
+
else:
|
893 |
+
if event[0] == "note":
|
894 |
+
notes_in_setup = True
|
895 |
+
if event[0] in ["note", "time_signature", "key_signature"]:
|
896 |
+
key = tuple([event[0]]+event[3:-2])
|
897 |
+
else:
|
898 |
+
key = tuple([event[0]]+event[3:-1])
|
899 |
+
setup_events[key] = new_event
|
900 |
+
|
901 |
+
last_t1 = 0
|
902 |
+
midi_seq = []
|
903 |
+
for event in event_list:
|
904 |
+
if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
|
905 |
+
continue
|
906 |
+
cur_t1 = event[1]
|
907 |
+
event[1] = event[1] - last_t1
|
908 |
+
tokens = self.event2tokens(event)
|
909 |
+
if not tokens:
|
910 |
+
continue
|
911 |
+
midi_seq.append(tokens)
|
912 |
+
last_t1 = cur_t1
|
913 |
+
|
914 |
+
if add_bos_eos:
|
915 |
+
bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
|
916 |
+
eos = [self.eos_id] + [self.pad_id] * (self.max_token_seq - 1)
|
917 |
+
midi_seq = [bos] + midi_seq + [eos]
|
918 |
+
return midi_seq
|
919 |
+
|
920 |
+
def event2tokens(self, event):
|
921 |
+
name = event[0]
|
922 |
+
params = event[1:]
|
923 |
+
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
924 |
+
return []
|
925 |
+
tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
|
926 |
+
for i, p in enumerate(self.events[name])]
|
927 |
+
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
928 |
+
return tokens
|
929 |
+
|
930 |
+
def tokens2event(self, tokens):
|
931 |
+
if tokens[0] not in self.id_events:
|
932 |
+
return []
|
933 |
+
name = self.id_events[tokens[0]]
|
934 |
+
if len(tokens) <= len(self.events[name]):
|
935 |
+
return []
|
936 |
+
params = tokens[1:]
|
937 |
+
params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
|
938 |
+
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
939 |
+
return []
|
940 |
+
event = [name] + params
|
941 |
+
return event
|
942 |
+
|
943 |
+
def detokenize(self, midi_seq):
|
944 |
+
ticks_per_beat = 480
|
945 |
+
tracks_dict = {}
|
946 |
+
t1 = 0
|
947 |
+
for tokens in midi_seq:
|
948 |
+
if tokens[0] in self.id_events:
|
949 |
+
event = self.tokens2event(tokens)
|
950 |
+
if not event:
|
951 |
+
continue
|
952 |
+
name = event[0]
|
953 |
+
t1 += event[1]
|
954 |
+
t = t1 * 16 + event[2]
|
955 |
+
t = int(t * ticks_per_beat / 16)
|
956 |
+
track_idx = event[3]
|
957 |
+
event_new = [name, t]
|
958 |
+
if name == "note":
|
959 |
+
c, p, v, d = event[4:]
|
960 |
+
d = int(d * ticks_per_beat / 16)
|
961 |
+
event_new += [d, c, p, v]
|
962 |
+
elif name == "control_change" or name == "patch_change":
|
963 |
+
event_new += event[4:]
|
964 |
+
elif name == "set_tempo":
|
965 |
+
event_new += [self.bpm2tempo(event[4])]
|
966 |
+
elif name == "time_signature":
|
967 |
+
nn, dd = event[4:]
|
968 |
+
nn += 1
|
969 |
+
dd += 1
|
970 |
+
event_new += [nn, dd, 24, 8] # usually cc, bb = 24, 8
|
971 |
+
elif name == "key_signature":
|
972 |
+
sf, mi = event[4:]
|
973 |
+
sf -= 7
|
974 |
+
event_new += [sf, mi]
|
975 |
+
else: # should not go here
|
976 |
+
continue
|
977 |
+
if track_idx not in tracks_dict:
|
978 |
+
tracks_dict[track_idx] = []
|
979 |
+
tracks_dict[track_idx].append(event_new)
|
980 |
+
tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]
|
981 |
+
|
982 |
+
for i in range(len(tracks)): # to eliminate note overlap
|
983 |
+
track = tracks[i]
|
984 |
+
track = sorted(track, key=lambda e: e[1])
|
985 |
+
last_note_t = {}
|
986 |
+
zero_len_notes = []
|
987 |
+
for e in reversed(track):
|
988 |
+
if e[0] == "note":
|
989 |
+
t, d, c, p = e[1:5]
|
990 |
+
key = (c, p)
|
991 |
+
if key in last_note_t:
|
992 |
+
d = min(d, max(last_note_t[key] - t, 0))
|
993 |
+
last_note_t[key] = t
|
994 |
+
e[2] = d
|
995 |
+
if d == 0:
|
996 |
+
zero_len_notes.append(e)
|
997 |
+
for e in zero_len_notes:
|
998 |
+
track.remove(e)
|
999 |
+
tracks[i] = track
|
1000 |
+
return [ticks_per_beat, *tracks]
|
1001 |
+
|
1002 |
+
def midi2img(self, midi_score):
|
1003 |
+
ticks_per_beat = midi_score[0]
|
1004 |
+
notes = []
|
1005 |
+
max_time = 1
|
1006 |
+
track_num = len(midi_score[1:])
|
1007 |
+
for track_idx, track in enumerate(midi_score[1:]):
|
1008 |
+
for event in track:
|
1009 |
+
t = round(16 * event[1] / ticks_per_beat)
|
1010 |
+
if event[0] == "note":
|
1011 |
+
d = max(1, round(16 * event[2] / ticks_per_beat))
|
1012 |
+
c, p = event[3:5]
|
1013 |
+
max_time = max(max_time, t + d + 1)
|
1014 |
+
notes.append((track_idx, c, p, t, d))
|
1015 |
+
img = np.zeros((128, max_time, 3), dtype=np.uint8)
|
1016 |
+
colors = {(i, j): np.random.randint(50, 256, 3) for i in range(track_num) for j in range(16)}
|
1017 |
+
for note in notes:
|
1018 |
+
tr, c, p, t, d = note
|
1019 |
+
img[p, t: t + d] = colors[(tr, c)]
|
1020 |
+
img = PIL.Image.fromarray(np.flip(img, 0))
|
1021 |
+
return img
|
1022 |
+
|
1023 |
+
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
|
1024 |
+
max_track_shift=0, max_channel_shift=16):
|
1025 |
+
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
1026 |
+
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
1027 |
+
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
1028 |
+
bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
|
1029 |
+
track_shift = random.randint(0, max_track_shift)
|
1030 |
+
channel_shift = random.randint(0, max_channel_shift)
|
1031 |
+
midi_seq_new = []
|
1032 |
+
key_signature_tokens = []
|
1033 |
+
track_to_channels = {}
|
1034 |
+
for tokens in midi_seq:
|
1035 |
+
tokens_new = [*tokens]
|
1036 |
+
if tokens[0] in self.id_events:
|
1037 |
+
name = self.id_events[tokens[0]]
|
1038 |
+
for i, pn in enumerate(self.events[name]):
|
1039 |
+
if pn == "track":
|
1040 |
+
tr = tokens[1 + i] - self.parameter_ids[pn][0]
|
1041 |
+
tr += track_shift
|
1042 |
+
tr = tr % self.event_parameters[pn]
|
1043 |
+
tokens_new[1 + i] = self.parameter_ids[pn][tr]
|
1044 |
+
elif pn == "channel":
|
1045 |
+
c = tokens[1 + i] - self.parameter_ids[pn][0]
|
1046 |
+
c0 = c
|
1047 |
+
c += channel_shift
|
1048 |
+
c = c % self.event_parameters[pn]
|
1049 |
+
if c0 == 9:
|
1050 |
+
c = 9
|
1051 |
+
elif c == 9:
|
1052 |
+
c = (9 + channel_shift) % self.event_parameters[pn]
|
1053 |
+
tokens_new[1 + i] = self.parameter_ids[pn][c]
|
1054 |
+
|
1055 |
+
if name == "note":
|
1056 |
+
tr = tokens[3] - self.parameter_ids["track"][0]
|
1057 |
+
c = tokens[4] - self.parameter_ids["channel"][0]
|
1058 |
+
p = tokens[5] - self.parameter_ids["pitch"][0]
|
1059 |
+
v = tokens[6] - self.parameter_ids["velocity"][0]
|
1060 |
+
if c != 9: # no shift for drums
|
1061 |
+
p += pitch_shift
|
1062 |
+
if not 0 <= p < 128:
|
1063 |
+
return midi_seq
|
1064 |
+
v += vel_shift
|
1065 |
+
v = max(1, min(127, v))
|
1066 |
+
tokens_new[5] = self.parameter_ids["pitch"][p]
|
1067 |
+
tokens_new[6] = self.parameter_ids["velocity"][v]
|
1068 |
+
track_to_channels.setdefault(tr, [])
|
1069 |
+
cs = track_to_channels[tr]
|
1070 |
+
if c not in cs:
|
1071 |
+
cs.append(c)
|
1072 |
+
elif name == "control_change":
|
1073 |
+
cc = tokens[5] - self.parameter_ids["controller"][0]
|
1074 |
+
val = tokens[6] - self.parameter_ids["value"][0]
|
1075 |
+
if cc in [1, 2, 7, 11]:
|
1076 |
+
val += cc_val_shift
|
1077 |
+
val = max(1, min(127, val))
|
1078 |
+
tokens_new[6] = self.parameter_ids["value"][val]
|
1079 |
+
elif name == "set_tempo":
|
1080 |
+
bpm = tokens[4] - self.parameter_ids["bpm"][0]
|
1081 |
+
bpm += bpm_shift
|
1082 |
+
bpm = max(1, min(383, bpm))
|
1083 |
+
tokens_new[4] = self.parameter_ids["bpm"][bpm]
|
1084 |
+
elif name == "key_signature":
|
1085 |
+
sf = tokens[4] - self.parameter_ids["sf"][0]
|
1086 |
+
mi = tokens[5] - self.parameter_ids["mi"][0]
|
1087 |
+
sf -= 7
|
1088 |
+
k = self.sf2key(sf)
|
1089 |
+
k = (k + pitch_shift) % 12
|
1090 |
+
sf = self.key2sf(k, mi)
|
1091 |
+
sf += 7
|
1092 |
+
tokens_new[4] = self.parameter_ids["sf"][sf]
|
1093 |
+
tokens_new[5] = self.parameter_ids["mi"][mi]
|
1094 |
+
key_signature_tokens.append(tokens_new)
|
1095 |
+
midi_seq_new.append(tokens_new)
|
1096 |
+
for tokens in key_signature_tokens:
|
1097 |
+
tr = tokens[3] - self.parameter_ids["track"][0]
|
1098 |
+
if tr in track_to_channels:
|
1099 |
+
cs = track_to_channels[tr]
|
1100 |
+
if len(cs) == 1 and cs[0] == 9:
|
1101 |
+
tokens[4] = self.parameter_ids["sf"][7] # sf=0
|
1102 |
+
return midi_seq_new
|
1103 |
+
|
1104 |
+
def check_quality(self, midi_seq, alignment_min=0.3, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3,
|
1105 |
+
notes_density_max=50, notes_density_min=2.5, total_notes_max=20000, total_notes_min=256,
|
1106 |
+
note_window_size=16):
|
1107 |
+
total_notes = 0
|
1108 |
+
channels = []
|
1109 |
+
time_hist = [0] * 16
|
1110 |
+
note_windows = {}
|
1111 |
+
notes_sametime = []
|
1112 |
+
notes_density_list = []
|
1113 |
+
tonality_list = []
|
1114 |
+
notes_bandwidth_list = []
|
1115 |
+
instruments = {}
|
1116 |
+
piano_channels = []
|
1117 |
+
abs_t1 = 0
|
1118 |
+
last_t = 0
|
1119 |
+
for tsi, tokens in enumerate(midi_seq):
|
1120 |
+
event = self.tokens2event(tokens)
|
1121 |
+
if not event:
|
1122 |
+
continue
|
1123 |
+
t1, t2, tr = event[1:4]
|
1124 |
+
abs_t1 += t1
|
1125 |
+
t = abs_t1 * 16 + t2
|
1126 |
+
c = None
|
1127 |
+
if event[0] == "note":
|
1128 |
+
c, p, v, d = event[4:]
|
1129 |
+
total_notes += 1
|
1130 |
+
time_hist[t2] += 1
|
1131 |
+
if c != 9: # ignore drum channel
|
1132 |
+
if c not in instruments:
|
1133 |
+
instruments[c] = 0
|
1134 |
+
if c not in piano_channels:
|
1135 |
+
piano_channels.append(c)
|
1136 |
+
note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
|
1137 |
+
if last_t != t:
|
1138 |
+
notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
|
1139 |
+
notes_sametime_p = [p_ for _, p_ in notes_sametime]
|
1140 |
+
if len(notes_sametime) > 0:
|
1141 |
+
notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
|
1142 |
+
notes_sametime.append((t + d - 1, p))
|
1143 |
+
elif event[0] == "patch_change":
|
1144 |
+
c, p = event[4:]
|
1145 |
+
instruments[c] = p
|
1146 |
+
if p == 0 and c not in piano_channels:
|
1147 |
+
piano_channels.append(c)
|
1148 |
+
if c is not None and c not in channels:
|
1149 |
+
channels.append(c)
|
1150 |
+
last_t = t
|
1151 |
+
reasons = []
|
1152 |
+
if total_notes < total_notes_min:
|
1153 |
+
reasons.append("total_min")
|
1154 |
+
if total_notes > total_notes_max:
|
1155 |
+
reasons.append("total_max")
|
1156 |
+
if len(note_windows) == 0 and total_notes > 0:
|
1157 |
+
reasons.append("drum_only")
|
1158 |
+
if reasons:
|
1159 |
+
return False, reasons
|
1160 |
+
time_hist = sorted(time_hist, reverse=True)
|
1161 |
+
alignment = sum(time_hist[:2]) / total_notes
|
1162 |
+
for notes in note_windows.values():
|
1163 |
+
key_hist = [0] * 12
|
1164 |
+
for p in notes:
|
1165 |
+
key_hist[p % 12] += 1
|
1166 |
+
key_hist = sorted(key_hist, reverse=True)
|
1167 |
+
tonality_list.append(sum(key_hist[:7]) / len(notes))
|
1168 |
+
notes_density_list.append(len(notes) / note_window_size)
|
1169 |
+
tonality_list = sorted(tonality_list)
|
1170 |
+
tonality = sum(tonality_list) / len(tonality_list)
|
1171 |
+
notes_bandwidth = sum(notes_bandwidth_list) / len(notes_bandwidth_list) if notes_bandwidth_list else 0
|
1172 |
+
notes_density = max(notes_density_list) if notes_density_list else 0
|
1173 |
+
piano_ratio = len(piano_channels) / len(channels)
|
1174 |
+
if len(channels) <= 3: # ignore piano threshold if it is a piano solo midi
|
1175 |
+
piano_max = 1
|
1176 |
+
if alignment < alignment_min: # check weather the notes align to the bars (because some midi files are recorded)
|
1177 |
+
reasons.append("alignment")
|
1178 |
+
if tonality < tonality_min: # check whether the music is tonal
|
1179 |
+
reasons.append("tonality")
|
1180 |
+
if notes_bandwidth < notes_bandwidth_min: # check whether music is melodic line only
|
1181 |
+
reasons.append("bandwidth")
|
1182 |
+
if not notes_density_min < notes_density < notes_density_max:
|
1183 |
+
reasons.append("density")
|
1184 |
+
if piano_ratio > piano_max: # check whether most instruments is piano (because some midi files don't have instruments assigned correctly)
|
1185 |
+
reasons.append("piano")
|
1186 |
+
return not reasons, reasons
|
1187 |
+
|
1188 |
+
|
1189 |
+
class MIDITokenizer:
|
1190 |
+
def __new__(cls, version="v2"):
|
1191 |
+
if version == "v1":
|
1192 |
+
return MIDITokenizerV1()
|
1193 |
+
elif version == "v2":
|
1194 |
+
return MIDITokenizerV2()
|
1195 |
+
else:
|
1196 |
+
raise ValueError(f"Unsupported version: {version}")
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
fluidsynth
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu124
|
2 |
+
Pillow
|
3 |
+
numpy
|
4 |
+
torch
|
5 |
+
onnxruntime-gpu
|
6 |
+
peft>=0.13.0
|
7 |
+
transformers>=4.36
|
8 |
+
gradio==5.0.1
|
9 |
+
pyfluidsynth
|
10 |
+
tqdm
|
11 |
+
huggingface_hub
|