juancopi81 committed
Commit b100e1c (1 parent: 8f8dcb6)

Add t5x and mt3 models

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set of files.
LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -10,5 +10,3 @@ app_file: app.py
10
  pinned: false
11
  license: apache-2.0
12
  ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
mt3/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Base module for MT3."""
16
+
17
+ from mt3 import datasets
18
+ from mt3 import event_codec
19
+ from mt3 import inference
20
+ from mt3 import layers
21
+ from mt3 import metrics
22
+ from mt3 import metrics_utils
23
+ from mt3 import models
24
+ from mt3 import network
25
+ from mt3 import note_sequences
26
+ from mt3 import preprocessors
27
+ from mt3 import run_length_encoding
28
+ from mt3 import spectrograms
29
+ from mt3 import summaries
30
+ from mt3 import tasks
31
+ from mt3 import vocabularies
32
+
33
+ from mt3.version import __version__
mt3/datasets.py ADDED
@@ -0,0 +1,325 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dataset configurations."""
16
+
17
+ import dataclasses
18
+ from typing import Mapping, Sequence, Union
19
+
20
+ from mt3 import note_sequences
21
+ import tensorflow as tf
22
+
23
+
24
+
25
+ @dataclasses.dataclass
26
+ class InferEvalSplit:
27
+ # key in dictionary containing all dataset splits
28
+ name: str
29
+ # task name suffix (each eval split is a separate task)
30
+ suffix: str
31
+ # whether or not to include in the mixture of all eval tasks
32
+ include_in_mixture: bool = True
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class DatasetConfig:
37
+ """Configuration for a transcription dataset."""
38
+ # dataset name
39
+ name: str
40
+ # mapping from split name to path
41
+ paths: Mapping[str, str]
42
+ # mapping from feature name to feature
43
+ features: Mapping[str, Union[tf.io.FixedLenFeature,
44
+ tf.io.FixedLenSequenceFeature]]
45
+ # training split name
46
+ train_split: str
47
+ # training eval split name
48
+ train_eval_split: str
49
+ # list of infer eval split specs
50
+ infer_eval_splits: Sequence[InferEvalSplit]
51
+ # list of track specs to be used for metrics
52
+ track_specs: Sequence[note_sequences.TrackSpec] = dataclasses.field(
53
+ default_factory=list)
54
+
55
+ MAESTROV1_CONFIG = DatasetConfig(
56
+ name='maestrov1',
57
+ paths={
58
+ 'train':
59
+ 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_train.tfrecord-?????-of-00010',
60
+ 'train_subset':
61
+ 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_train.tfrecord-00002-of-00010',
62
+ 'validation':
63
+ 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_validation.tfrecord-?????-of-00010',
64
+ 'validation_subset':
65
+ 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_validation.tfrecord-0000[06]-of-00010',
66
+ 'test':
67
+ 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_test.tfrecord-?????-of-00010'
68
+ },
69
+ features={
70
+ 'audio': tf.io.FixedLenFeature([], dtype=tf.string),
71
+ 'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
72
+ 'id': tf.io.FixedLenFeature([], dtype=tf.string)
73
+ },
74
+ train_split='train',
75
+ train_eval_split='validation_subset',
76
+ infer_eval_splits=[
77
+ InferEvalSplit(name='train', suffix='eval_train_full',
78
+ include_in_mixture=False),
79
+ InferEvalSplit(name='train_subset', suffix='eval_train'),
80
+ InferEvalSplit(name='validation', suffix='validation_full',
81
+ include_in_mixture=False),
82
+ InferEvalSplit(name='validation_subset', suffix='validation'),
83
+ InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
84
+ ])
85
+
86
+
87
+ MAESTROV3_CONFIG = DatasetConfig(
88
+ name='maestrov3',
89
+ paths={
90
+ 'train':
91
+ 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_train.tfrecord-?????-of-00025',
92
+ 'train_subset':
93
+ 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_train.tfrecord-00004-of-00025',
94
+ 'validation':
95
+ 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_validation.tfrecord-?????-of-00025',
96
+ 'validation_subset':
97
+ 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_validation.tfrecord-0002?-of-00025',
98
+ 'test':
99
+ 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_test.tfrecord-?????-of-00025'
100
+ },
101
+ features={
102
+ 'audio': tf.io.FixedLenFeature([], dtype=tf.string),
103
+ 'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
104
+ 'id': tf.io.FixedLenFeature([], dtype=tf.string)
105
+ },
106
+ train_split='train',
107
+ train_eval_split='validation_subset',
108
+ infer_eval_splits=[
109
+ InferEvalSplit(name='train', suffix='eval_train_full',
110
+ include_in_mixture=False),
111
+ InferEvalSplit(name='train_subset', suffix='eval_train'),
112
+ InferEvalSplit(name='validation', suffix='validation_full',
113
+ include_in_mixture=False),
114
+ InferEvalSplit(name='validation_subset', suffix='validation'),
115
+ InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
116
+ ])
117
+
118
+
119
+ GUITARSET_CONFIG = DatasetConfig(
120
+ name='guitarset',
121
+ paths={
122
+ 'train':
123
+ 'gs://mt3/data/datasets/guitarset/train.tfrecord-?????-of-00019',
124
+ 'validation':
125
+ 'gs://mt3/data/datasets/guitarset/validation.tfrecord-?????-of-00006',
126
+ },
127
+ features={
128
+ 'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
129
+ 'audio': tf.io.FixedLenFeature([], dtype=tf.string),
130
+ 'velocity_range': tf.io.FixedLenFeature([], dtype=tf.string),
131
+ 'id': tf.io.FixedLenFeature([], dtype=tf.string),
132
+ },
133
+ train_split='train',
134
+ train_eval_split='validation',
135
+ infer_eval_splits=[
136
+ InferEvalSplit(name='train', suffix='eval_train'),
137
+ InferEvalSplit(name='validation', suffix='validation'),
138
+ ])
139
+
140
+
141
+ URMP_CONFIG = DatasetConfig(
142
+ name='urmp',
143
+ paths={
144
+ 'train': 'gs://mt3/data/datasets/urmp/train.tfrecord',
145
+ 'validation': 'gs://mt3/data/datasets/urmp/validation.tfrecord',
146
+ },
147
+ features={
148
+ 'id': tf.io.FixedLenFeature([], dtype=tf.string),
149
+ 'tracks': tf.io.FixedLenSequenceFeature(
150
+ [], dtype=tf.int64, allow_missing=True),
151
+ 'inst_names': tf.io.FixedLenSequenceFeature(
152
+ [], dtype=tf.string, allow_missing=True),
153
+ 'audio': tf.io.FixedLenFeature([], dtype=tf.string),
154
+ 'sequence': tf.io.FixedLenFeature([], dtype=tf.string),
155
+ 'instrument_sequences': tf.io.FixedLenSequenceFeature(
156
+ [], dtype=tf.string, allow_missing=True),
157
+ },
158
+ train_split='train',
159
+ train_eval_split='validation',
160
+ infer_eval_splits=[
161
+ InferEvalSplit(name='train', suffix='eval_train'),
162
+ InferEvalSplit(name='validation', suffix='validation')
163
+ ])
164
+
165
+
166
+ MUSICNET_CONFIG = DatasetConfig(
167
+ name='musicnet',
168
+ paths={
169
+ 'train':
170
+ 'gs://mt3/data/datasets/musicnet/musicnet-train.tfrecord-?????-of-00036',
171
+ 'validation':
172
+ 'gs://mt3/data/datasets/musicnet/musicnet-validation.tfrecord-?????-of-00005',
173
+ 'test':
174
+ 'gs://mt3/data/datasets/musicnet/musicnet-test.tfrecord-?????-of-00003'
175
+ },
176
+ features={
177
+ 'id': tf.io.FixedLenFeature([], dtype=tf.string),
178
+ 'sample_rate': tf.io.FixedLenFeature([], dtype=tf.float32),
179
+ 'audio': tf.io.FixedLenSequenceFeature(
180
+ [], dtype=tf.float32, allow_missing=True),
181
+ 'sequence': tf.io.FixedLenFeature([], dtype=tf.string)
182
+ },
183
+ train_split='train',
184
+ train_eval_split='validation',
185
+ infer_eval_splits=[
186
+ InferEvalSplit(name='train', suffix='eval_train'),
187
+ InferEvalSplit(name='validation', suffix='validation'),
188
+ InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
189
+ ])
190
+
191
+
192
+ MUSICNET_EM_CONFIG = DatasetConfig(
193
+ name='musicnet_em',
194
+ paths={
195
+ 'train':
196
+ 'gs://mt3/data/datasets/musicnet_em/train.tfrecord-?????-of-00103',
197
+ 'validation':
198
+ 'gs://mt3/data/datasets/musicnet_em/validation.tfrecord-?????-of-00005',
199
+ 'test':
200
+ 'gs://mt3/data/datasets/musicnet_em/test.tfrecord-?????-of-00006'
201
+ },
202
+ features={
203
+ 'id': tf.io.FixedLenFeature([], dtype=tf.string),
204
+ 'sample_rate': tf.io.FixedLenFeature([], dtype=tf.float32),
205
+ 'audio': tf.io.FixedLenSequenceFeature(
206
+ [], dtype=tf.float32, allow_missing=True),
207
+ 'sequence': tf.io.FixedLenFeature([], dtype=tf.string)
208
+ },
209
+ train_split='train',
210
+ train_eval_split='validation',
211
+ infer_eval_splits=[
212
+ InferEvalSplit(name='train', suffix='eval_train'),
213
+ InferEvalSplit(name='validation', suffix='validation'),
214
+ InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
215
+ ])
216
+
217
+
218
+ CERBERUS4_CONFIG = DatasetConfig(
219
+ name='cerberus4',
220
+ paths={
221
+ 'train':
222
+ 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_train_bass:drums:guitar:piano.tfrecord-?????-of-00286',
223
+ 'train_subset':
224
+ 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_train_bass:drums:guitar:piano.tfrecord-00000-of-00286',
225
+ 'validation':
226
+ 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_validation_bass:drums:guitar:piano.tfrecord-?????-of-00212',
227
+ 'validation_subset':
228
+ 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_validation_bass:drums:guitar:piano.tfrecord-0000?-of-00212',
229
+ 'test':
230
+ 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_test_bass:drums:guitar:piano.tfrecord-?????-of-00106'
231
+ },
232
+ features={
233
+ 'audio_sample_rate': tf.io.FixedLenFeature([], dtype=tf.int64),
234
+ 'inst_names': tf.io.FixedLenSequenceFeature(
235
+ [], dtype=tf.string, allow_missing=True),
236
+ 'midi_class': tf.io.FixedLenSequenceFeature(
237
+ [], dtype=tf.int64, allow_missing=True),
238
+ 'mix': tf.io.FixedLenSequenceFeature(
239
+ [], dtype=tf.float32, allow_missing=True),
240
+ 'note_sequences': tf.io.FixedLenSequenceFeature(
241
+ [], dtype=tf.string, allow_missing=True),
242
+ 'plugin_name': tf.io.FixedLenSequenceFeature(
243
+ [], dtype=tf.int64, allow_missing=True),
244
+ 'program_num': tf.io.FixedLenSequenceFeature(
245
+ [], dtype=tf.int64, allow_missing=True),
246
+ 'slakh_class': tf.io.FixedLenSequenceFeature(
247
+ [], dtype=tf.int64, allow_missing=True),
248
+ 'src_ids': tf.io.FixedLenSequenceFeature(
249
+ [], dtype=tf.string, allow_missing=True),
250
+ 'stems': tf.io.FixedLenSequenceFeature(
251
+ [], dtype=tf.float32, allow_missing=True),
252
+ 'stems_shape': tf.io.FixedLenFeature([2], dtype=tf.int64),
253
+ 'target_type': tf.io.FixedLenFeature([], dtype=tf.string),
254
+ 'track_id': tf.io.FixedLenFeature([], dtype=tf.string),
255
+ },
256
+ train_split='train',
257
+ train_eval_split='validation_subset',
258
+ infer_eval_splits=[
259
+ InferEvalSplit(name='train', suffix='eval_train_full',
260
+ include_in_mixture=False),
261
+ InferEvalSplit(name='train_subset', suffix='eval_train'),
262
+ InferEvalSplit(name='validation', suffix='validation_full',
263
+ include_in_mixture=False),
264
+ InferEvalSplit(name='validation_subset', suffix='validation'),
265
+ InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
266
+ ],
267
+ track_specs=[
268
+ note_sequences.TrackSpec('bass', program=32),
269
+ note_sequences.TrackSpec('drums', is_drum=True),
270
+ note_sequences.TrackSpec('guitar', program=24),
271
+ note_sequences.TrackSpec('piano', program=0)
272
+ ])
273
+
274
+
275
+ SLAKH_CONFIG = DatasetConfig(
276
+ name='slakh',
277
+ paths={
278
+ 'train':
279
+ 'gs://mt3/data/datasets/slakh/slakh_multi_full_subsets_10_train_all_inst.tfrecord-?????-of-02307',
280
+ 'train_subset':
281
+ 'gs://mt3/data/datasets/slakh/slakh_multi_full_subsets_10_train_all_inst.tfrecord-00000-of-02307',
282
+ 'validation':
283
+ 'gs://mt3/data/datasets/slakh/slakh_multi_full_validation_all_inst.tfrecord-?????-of-00168',
284
+ 'validation_subset':
285
+ 'gs://mt3/data/datasets/slakh/slakh_multi_full_validation_all_inst.tfrecord-0000?-of-00168',
286
+ 'test':
287
+ 'gs://mt3/data/datasets/slakh/slakh_multi_full_test_all_inst.tfrecord-?????-of-00109'
288
+ },
289
+ features={
290
+ 'audio_sample_rate': tf.io.FixedLenFeature([], dtype=tf.int64),
291
+ 'inst_names': tf.io.FixedLenSequenceFeature([], dtype=tf.string,
292
+ allow_missing=True),
293
+ 'midi_class': tf.io.FixedLenSequenceFeature([], dtype=tf.int64,
294
+ allow_missing=True),
295
+ 'mix': tf.io.FixedLenSequenceFeature([], dtype=tf.float32,
296
+ allow_missing=True),
297
+ 'note_sequences': tf.io.FixedLenSequenceFeature([], dtype=tf.string,
298
+ allow_missing=True),
299
+ 'plugin_name': tf.io.FixedLenSequenceFeature([], dtype=tf.int64,
300
+ allow_missing=True),
301
+ 'program_num': tf.io.FixedLenSequenceFeature([], dtype=tf.int64,
302
+ allow_missing=True),
303
+ 'slakh_class': tf.io.FixedLenSequenceFeature([], dtype=tf.int64,
304
+ allow_missing=True),
305
+ 'src_ids': tf.io.FixedLenSequenceFeature([], dtype=tf.string,
306
+ allow_missing=True),
307
+ 'stems': tf.io.FixedLenSequenceFeature([], dtype=tf.float32,
308
+ allow_missing=True),
309
+ 'stems_shape': tf.io.FixedLenFeature([2], dtype=tf.int64),
310
+ 'target_type': tf.io.FixedLenFeature([], dtype=tf.string),
311
+ 'track_id': tf.io.FixedLenFeature([], dtype=tf.string),
312
+ },
313
+ train_split='train',
314
+ train_eval_split='validation_subset',
315
+ infer_eval_splits=[
316
+ InferEvalSplit(name='train', suffix='eval_train_full',
317
+ include_in_mixture=False),
318
+ InferEvalSplit(name='train_subset', suffix='eval_train'),
319
+ InferEvalSplit(name='validation', suffix='validation_full',
320
+ include_in_mixture=False),
321
+ InferEvalSplit(name='validation_subset', suffix='validation'),
322
+ InferEvalSplit(name='test', suffix='test', include_in_mixture=False)
323
+ ])
324
+
325
+
mt3/event_codec.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Encode and decode events."""
16
+
17
+ import dataclasses
18
+ from typing import List, Tuple
19
+
20
+
21
+ @dataclasses.dataclass
22
+ class EventRange:
23
+ type: str
24
+ min_value: int
25
+ max_value: int
26
+
27
+
28
+ @dataclasses.dataclass
29
+ class Event:
30
+ type: str
31
+ value: int
32
+
33
+
34
+ class Codec:
35
+ """Encode and decode events.
36
+
37
+ Useful for declaring what certain ranges of a vocabulary should be used for.
38
+ This is intended to be used from Python before encoding or after decoding with
39
+ GenericTokenVocabulary. This class is more lightweight and does not include
40
+ things like EOS or UNK token handling.
41
+
42
+ To ensure that 'shift' events are always the first block of the vocab and
43
+ start at 0, that event type is required and specified separately.
44
+ """
45
+
46
+ def __init__(self, max_shift_steps: int, steps_per_second: float,
47
+ event_ranges: List[EventRange]):
48
+ """Define Codec.
49
+
50
+ Args:
51
+ max_shift_steps: Maximum number of shift steps that can be encoded.
52
+ steps_per_second: Shift steps will be interpreted as having a duration of
53
+ 1 / steps_per_second.
54
+ event_ranges: Other supported event types and their ranges.
55
+ """
56
+ self.steps_per_second = steps_per_second
57
+ self._shift_range = EventRange(
58
+ type='shift', min_value=0, max_value=max_shift_steps)
59
+ self._event_ranges = [self._shift_range] + event_ranges
60
+ # Ensure all event types have unique names.
61
+ assert len(self._event_ranges) == len(
62
+ set([er.type for er in self._event_ranges]))
63
+
64
+ @property
65
+ def num_classes(self) -> int:
66
+ return sum(er.max_value - er.min_value + 1 for er in self._event_ranges)
67
+
68
+ # The next couple methods are simplified special case methods just for shift
69
+ # events that are intended to be used from within autograph functions.
70
+
71
+ def is_shift_event_index(self, index: int) -> bool:
72
+ return (self._shift_range.min_value <= index) and (
73
+ index <= self._shift_range.max_value)
74
+
75
+ @property
76
+ def max_shift_steps(self) -> int:
77
+ return self._shift_range.max_value
78
+
79
+ def encode_event(self, event: Event) -> int:
80
+ """Encode an event to an index."""
81
+ offset = 0
82
+ for er in self._event_ranges:
83
+ if event.type == er.type:
84
+ if not er.min_value <= event.value <= er.max_value:
85
+ raise ValueError(
86
+ f'Event value {event.value} is not within valid range '
87
+ f'[{er.min_value}, {er.max_value}] for type {event.type}')
88
+ return offset + event.value - er.min_value
89
+ offset += er.max_value - er.min_value + 1
90
+
91
+ raise ValueError(f'Unknown event type: {event.type}')
92
+
93
+ def event_type_range(self, event_type: str) -> Tuple[int, int]:
94
+ """Return [min_id, max_id] for an event type."""
95
+ offset = 0
96
+ for er in self._event_ranges:
97
+ if event_type == er.type:
98
+ return offset, offset + (er.max_value - er.min_value)
99
+ offset += er.max_value - er.min_value + 1
100
+
101
+ raise ValueError(f'Unknown event type: {event_type}')
102
+
103
+ def decode_event_index(self, index: int) -> Event:
104
+ """Decode an event index to an Event."""
105
+ offset = 0
106
+ for er in self._event_ranges:
107
+ if offset <= index <= offset + er.max_value - er.min_value:
108
+ return Event(
109
+ type=er.type, value=er.min_value + index - offset)
110
+ offset += er.max_value - er.min_value + 1
111
+
112
+ raise ValueError(f'Unknown event index: {index}')
mt3/event_codec_test.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for event_codec."""
16
+
17
+ from absl.testing import absltest
18
+ from mt3 import event_codec
19
+
20
+ Event = event_codec.Event
21
+ EventRange = event_codec.EventRange
22
+
23
+
24
+ class EventCodecTest(absltest.TestCase):
25
+
26
+ def test_encode_decode(self):
27
+ ec = event_codec.Codec(
28
+ max_shift_steps=100,
29
+ steps_per_second=100,
30
+ event_ranges=[EventRange('pitch', min_value=0, max_value=127)])
31
+ events = [
32
+ Event(type='pitch', value=60),
33
+ Event(type='shift', value=5),
34
+ Event(type='pitch', value=62),
35
+ ]
36
+ encoded = [ec.encode_event(e) for e in events]
37
+ self.assertSequenceEqual([161, 5, 163], encoded)
38
+
39
+ decoded = [ec.decode_event_index(idx) for idx in encoded]
40
+ self.assertSequenceEqual(events, decoded)
41
+
42
+ def test_shift_steps(self):
43
+ ec = event_codec.Codec(
44
+ max_shift_steps=100,
45
+ steps_per_second=100,
46
+ event_ranges=[EventRange('pitch', min_value=0, max_value=127)])
47
+
48
+ self.assertEqual(100, ec.max_shift_steps)
49
+ self.assertFalse(ec.is_shift_event_index(-1))
50
+ self.assertTrue(ec.is_shift_event_index(0))
51
+ self.assertTrue(ec.is_shift_event_index(100))
52
+ self.assertFalse(ec.is_shift_event_index(101))
53
+
54
+ if __name__ == '__main__':
55
+ absltest.main()
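As a quick aside on the vocabulary layout the Codec above implements (a minimal sketch, assuming the mt3 package added in this commit is importable): shift events always occupy the first block of ids, 0 through max_shift_steps, and each additional EventRange is appended as the next contiguous block, which is why pitch value 60 encodes to index 161 in the test (101 shift ids, then 60 into the pitch block).

from mt3 import event_codec

# Same codec as in the test: 101 shift ids (0..100) followed by 128 pitch ids.
codec = event_codec.Codec(
    max_shift_steps=100,
    steps_per_second=100,
    event_ranges=[event_codec.EventRange('pitch', min_value=0, max_value=127)])

# Encoding offsets the pitch value by the size of the shift block: 101 + 60 = 161.
assert codec.encode_event(event_codec.Event(type='pitch', value=60)) == 161
# Decoding reverses the offset, recovering the original Event dataclass.
assert codec.decode_event_index(161) == event_codec.Event(type='pitch', value=60)
# The pitch block spans ids 101..228, so the full vocabulary has 229 classes.
assert codec.event_type_range('pitch') == (101, 228)
assert codec.num_classes == 229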
mt3/gin/eval.gin ADDED
@@ -0,0 +1,72 @@
1
+ # Defaults for eval.py.
2
+ #
3
+ # You must also include a binding for MODEL.
4
+ #
5
+ # Required to be set:
6
+ #
7
+ # - TASK_PREFIX
8
+ # - TASK_FEATURE_LENGTHS
9
+ # - CHECKPOINT_PATH
10
+ # - EVAL_OUTPUT_DIR
11
+ #
12
+ # Commonly overridden options:
13
+ #
14
+ # - DatasetConfig.split
15
+ # - DatasetConfig.batch_size
16
+ # - DatasetConfig.use_cached
17
+ # - RestoreCheckpointConfig.mode
18
+ # - PjitPartitioner.num_partitions
19
+
20
+ from __gin__ import dynamic_registration
21
+
22
+ import __main__ as eval_script
23
+ from mt3 import preprocessors
24
+ from mt3 import tasks
25
+ from mt3 import vocabularies
26
+ from t5x import partitioning
27
+ from t5x import utils
28
+
29
+ # Must be overridden
30
+ TASK_PREFIX = %gin.REQUIRED
31
+ TASK_FEATURE_LENGTHS = %gin.REQUIRED
32
+ CHECKPOINT_PATH = %gin.REQUIRED
33
+ EVAL_OUTPUT_DIR = %gin.REQUIRED
34
+
35
+ # Number of velocity bins: set to 1 (no velocity) or 127
36
+ NUM_VELOCITY_BINS = %gin.REQUIRED
37
+ VOCAB_CONFIG = @vocabularies.VocabularyConfig()
38
+ vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS
39
+
40
+ # Program granularity: set to 'flat', 'midi_class', or 'full'
41
+ PROGRAM_GRANULARITY = %gin.REQUIRED
42
+ preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY
43
+
44
+ TASK_SUFFIX = 'test'
45
+ tasks.construct_task_name:
46
+ task_prefix = %TASK_PREFIX
47
+ vocab_config = %VOCAB_CONFIG
48
+ task_suffix = %TASK_SUFFIX
49
+
50
+ eval_script.evaluate:
51
+ model = %MODEL # imported from separate gin file
52
+ dataset_cfg = @utils.DatasetConfig()
53
+ partitioner = @partitioning.PjitPartitioner()
54
+ restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
55
+ output_dir = %EVAL_OUTPUT_DIR
56
+
57
+ utils.DatasetConfig:
58
+ mixture_or_task_name = @tasks.construct_task_name()
59
+ task_feature_lengths = %TASK_FEATURE_LENGTHS
60
+ split = 'eval'
61
+ batch_size = 32
62
+ shuffle = False
63
+ seed = 42
64
+ use_cached = True
65
+ pack = False
66
+ use_custom_packing_ops = False
67
+
68
+ partitioning.PjitPartitioner.num_partitions = 1
69
+
70
+ utils.RestoreCheckpointConfig:
71
+ path = %CHECKPOINT_PATH
72
+ mode = 'specific'
mt3/gin/infer.gin ADDED
@@ -0,0 +1,92 @@
1
+ # Defaults for infer.py.
2
+ #
3
+ # You must also include a binding for MODEL.
4
+ #
5
+ # Required to be set:
6
+ #
7
+ # - TASK_PREFIX
8
+ # - TASK_FEATURE_LENGTHS
9
+ # - CHECKPOINT_PATH
10
+ # - INFER_OUTPUT_DIR
11
+ #
12
+ # Commonly overridden options:
13
+ #
14
+ # - infer.mode
15
+ # - infer.checkpoint_period
16
+ # - infer.shard_id
17
+ # - infer.num_shards
18
+ # - DatasetConfig.split
19
+ # - DatasetConfig.batch_size
20
+ # - DatasetConfig.use_cached
21
+ # - RestoreCheckpointConfig.is_tensorflow
22
+ # - RestoreCheckpointConfig.mode
23
+ # - PjitPartitioner.num_partitions
24
+
25
+ from __gin__ import dynamic_registration
26
+
27
+ import __main__ as infer_script
28
+ from mt3 import inference
29
+ from mt3 import preprocessors
30
+ from mt3 import tasks
31
+ from mt3 import vocabularies
32
+ from t5x import partitioning
33
+ from t5x import utils
34
+
35
+ # Must be overridden
36
+ TASK_PREFIX = %gin.REQUIRED
37
+ TASK_FEATURE_LENGTHS = %gin.REQUIRED
38
+ CHECKPOINT_PATH = %gin.REQUIRED
39
+ INFER_OUTPUT_DIR = %gin.REQUIRED
40
+
41
+ # Number of velocity bins: set to 1 (no velocity) or 127
42
+ NUM_VELOCITY_BINS = %gin.REQUIRED
43
+ VOCAB_CONFIG = @vocabularies.VocabularyConfig()
44
+ vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS
45
+
46
+ # Program granularity: set to 'flat', 'midi_class', or 'full'
47
+ PROGRAM_GRANULARITY = %gin.REQUIRED
48
+ preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY
49
+
50
+ TASK_SUFFIX = 'test'
51
+ tasks.construct_task_name:
52
+ task_prefix = %TASK_PREFIX
53
+ vocab_config = %VOCAB_CONFIG
54
+ task_suffix = %TASK_SUFFIX
55
+
56
+ ONSETS_ONLY = %gin.REQUIRED
57
+ USE_TIES = %gin.REQUIRED
58
+ inference.write_inferences_to_file:
59
+ vocab_config = %VOCAB_CONFIG
60
+ onsets_only = %ONSETS_ONLY
61
+ use_ties = %USE_TIES
62
+
63
+ infer_script.infer:
64
+ mode = 'predict'
65
+ model = %MODEL # imported from separate gin file
66
+ output_dir = %INFER_OUTPUT_DIR
67
+ dataset_cfg = @utils.DatasetConfig()
68
+ partitioner = @partitioning.PjitPartitioner()
69
+ restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
70
+ # This is a hack, but pass an extremely large value here to make sure the
71
+ # entire dataset fits in a single epoch. Otherwise, segments from a single
72
+ # example may end up in different epochs after splitting.
73
+ checkpoint_period = 1000000
74
+ shard_id = 0
75
+ num_shards = 1
76
+ write_fn = @inference.write_inferences_to_file
77
+
78
+ utils.DatasetConfig:
79
+ mixture_or_task_name = @tasks.construct_task_name()
80
+ task_feature_lengths = %TASK_FEATURE_LENGTHS
81
+ use_cached = True
82
+ split = 'eval'
83
+ batch_size = 32
84
+ shuffle = False
85
+ seed = 0
86
+ pack = False
87
+
88
+ partitioning.PjitPartitioner.num_partitions = 1
89
+
90
+ utils.RestoreCheckpointConfig:
91
+ path = %CHECKPOINT_PATH
92
+ mode = 'specific'
mt3/gin/ismir2021.gin ADDED
@@ -0,0 +1,9 @@
1
+ # Configuration for ISMIR 2021 piano-only model.
2
+
3
+ TASK_PREFIX = 'maestrov3_notes'
4
+ TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 1024}
5
+ TRAIN_STEPS = 400000
6
+ NUM_VELOCITY_BINS = 127
7
+ PROGRAM_GRANULARITY = 'flat'
8
+ ONSETS_ONLY = False
9
+ USE_TIES = False
mt3/gin/ismir2022/base.gin ADDED
@@ -0,0 +1,10 @@
1
+ # T5.1.1 Base model.
2
+ include 'model.gin'
3
+
4
+ network.T5Config:
5
+ emb_dim = 768
6
+ num_heads = 12
7
+ num_encoder_layers = 12
8
+ num_decoder_layers = 12
9
+ head_dim = 64
10
+ mlp_dim = 2048
mt3/gin/ismir2022/finetune.gin ADDED
@@ -0,0 +1,25 @@
1
+ from __gin__ import dynamic_registration
2
+
3
+ from mt3 import network
4
+ from t5x import utils
5
+
6
+ include 'train.gin'
7
+
8
+ TASK_PREFIX = 'mega_notes_ties'
9
+ TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
10
+ TRAIN_STEPS = 150000
11
+ BATCH_SIZE = 256
12
+ LABEL_SMOOTHING = 0.0
13
+ NUM_VELOCITY_BINS = 1
14
+ PROGRAM_GRANULARITY = 'full'
15
+ ONSETS_ONLY = False
16
+ USE_TIES = True
17
+ MAX_EXAMPLES_PER_MIX = None
18
+
19
+ network.T5Config.dropout_rate = 0.1
20
+
21
+ CHECKPOINT_PATH = %gin.REQUIRED
22
+ utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
23
+ utils.RestoreCheckpointConfig:
24
+ path = %CHECKPOINT_PATH
25
+ mode = 'specific'
mt3/gin/ismir2022/pretrain.gin ADDED
@@ -0,0 +1,13 @@
1
+ include 'train.gin'
2
+
3
+ TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
4
+ TRAIN_STEPS = 500000
5
+ BATCH_SIZE = 1024
6
+ LABEL_SMOOTHING = 0.1
7
+ NUM_VELOCITY_BINS = 1
8
+ PROGRAM_GRANULARITY = 'full'
9
+ ONSETS_ONLY = False
10
+ USE_TIES = True
11
+ MAX_EXAMPLES_PER_MIX = 8
12
+
13
+ network.T5Config.dropout_rate = 0.0
mt3/gin/ismir2022/small.gin ADDED
@@ -0,0 +1,2 @@
1
+ # T5.1.1 Small model.
2
+ include 'model.gin'
mt3/gin/local_tiny.gin ADDED
@@ -0,0 +1,63 @@
1
+ # A gin file to make the Transformer models tiny for faster local testing.
2
+ #
3
+ # When testing locally with CPU, there are a few things that we need.
4
+ # - tiny model size
5
+ # - small enough batch size
6
+ # - small sequence length
7
+ # - deterministic dataset pipeline
8
+ #
9
+ # This gin file adds such configs. To use this gin file, add it on top of the
10
+ # existing full-scale gin files. The ordering of the gin file matters. So this
11
+ # should be added after all the other files are added to override the same
12
+ # configurables.
13
+
14
+ from __gin__ import dynamic_registration
15
+
16
+ from t5x import partitioning
17
+ from t5x import trainer
18
+ from t5x import utils
19
+ from t5x.examples.t5 import network
20
+
21
+ import __main__ as train_script
22
+
23
+ train_script.train.random_seed = 42 # dropout seed
24
+ train/utils.DatasetConfig.seed = 42 # dataset seed
25
+
26
+ TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 16}
27
+ LABEL_SMOOTHING = 0.0
28
+
29
+ # Network specification overrides
30
+ network.Transformer.config = @network.T5Config()
31
+ network.T5Config:
32
+ dtype = 'float32'
33
+ emb_dim = 8
34
+ num_heads = 4
35
+ num_encoder_layers = 2
36
+ num_decoder_layers = 2
37
+ head_dim = 3
38
+ mlp_dim = 16
39
+ mlp_activations = ('gelu', 'linear')
40
+ dropout_rate = 0.0
41
+ logits_via_embedding = False
42
+
43
+ TRAIN_STEPS = 3
44
+
45
+ train/utils.DatasetConfig:
46
+ batch_size = 8
47
+ shuffle = False
48
+
49
+ train_eval/utils.DatasetConfig.batch_size = 8
50
+
51
+ train_script.train:
52
+ eval_period = 3
53
+ eval_steps = 3
54
+
55
+ trainer.Trainer.num_microbatches = 0
56
+ partitioning.PjitPartitioner:
57
+ num_partitions = 1
58
+ model_parallel_submesh = None
59
+
60
+ utils.CheckpointConfig:
61
+ restore = None
62
+
63
+ infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
mt3/gin/model.gin ADDED
@@ -0,0 +1,60 @@
1
+ # T5.1.1 Small model.
2
+ from __gin__ import dynamic_registration
3
+
4
+ from mt3 import models
5
+ from mt3 import network
6
+ from mt3 import spectrograms
7
+ from mt3 import vocabularies
8
+ import seqio
9
+ from t5x import adafactor
10
+
11
+ # ------------------- Loss HParam ----------------------------------------------
12
+ Z_LOSS = 0.0001
13
+ LABEL_SMOOTHING = 0.0
14
+ LOSS_NORMALIZING_FACTOR = None
15
+ models.ContinuousInputsEncoderDecoderModel:
16
+ z_loss = %Z_LOSS
17
+ label_smoothing = %LABEL_SMOOTHING
18
+ loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
19
+
20
+ # Output vocabulary
21
+ VOCAB_CONFIG = %gin.REQUIRED
22
+ OUTPUT_VOCABULARY = @vocabularies.vocabulary_from_codec()
23
+ vocabularies.vocabulary_from_codec.codec = @vocabularies.build_codec()
24
+ vocabularies.build_codec.vocab_config = %VOCAB_CONFIG
25
+
26
+ # ------------------- Optimizer ------------------------------------------------
27
+ # `learning_rate` is set by `Trainer.learning_rate_fn`.
28
+ OPTIMIZER = @adafactor.Adafactor()
29
+ adafactor.Adafactor:
30
+ decay_rate = 0.8
31
+ step_offset = 0
32
+ logical_factor_rules = @adafactor.standard_logical_factor_rules()
33
+
34
+ # ------------------- Model ----------------------------------------------------
35
+ SPECTROGRAM_CONFIG = @spectrograms.SpectrogramConfig()
36
+ MODEL = @models.ContinuousInputsEncoderDecoderModel()
37
+ models.ContinuousInputsEncoderDecoderModel:
38
+ module = @network.Transformer()
39
+ input_vocabulary = @seqio.vocabularies.PassThroughVocabulary()
40
+ output_vocabulary = %OUTPUT_VOCABULARY
41
+ optimizer_def = %OPTIMIZER
42
+ input_depth = @spectrograms.input_depth()
43
+ seqio.vocabularies.PassThroughVocabulary.size = 0
44
+ spectrograms.input_depth.spectrogram_config = %SPECTROGRAM_CONFIG
45
+
46
+ # ------------------- Network specification ------------------------------------
47
+ network.Transformer.config = @network.T5Config()
48
+ network.T5Config:
49
+ vocab_size = @vocabularies.num_embeddings()
50
+ dtype = 'float32'
51
+ emb_dim = 512
52
+ num_heads = 6
53
+ num_encoder_layers = 8
54
+ num_decoder_layers = 8
55
+ head_dim = 64
56
+ mlp_dim = 1024
57
+ mlp_activations = ('gelu', 'linear')
58
+ dropout_rate = 0.1
59
+ logits_via_embedding = False
60
+ vocabularies.num_embeddings.vocabulary = %OUTPUT_VOCABULARY
mt3/gin/mt3.gin ADDED
@@ -0,0 +1,9 @@
1
+ # Configuration for MT3 multi-task multitrack model.
2
+
3
+ TASK_PREFIX = 'mega_notes_ties'
4
+ TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024}
5
+ TRAIN_STEPS = 1000000
6
+ NUM_VELOCITY_BINS = 1
7
+ PROGRAM_GRANULARITY = 'full'
8
+ ONSETS_ONLY = False
9
+ USE_TIES = True
mt3/gin/train.gin ADDED
@@ -0,0 +1,148 @@
1
+ # Defaults for training with train.py.
2
+ #
3
+ # You must also include a binding for MODEL.
4
+ #
5
+ # Required to be set:
6
+ #
7
+ # - TASK_PREFIX
8
+ # - TASK_FEATURE_LENGTHS
9
+ # - TRAIN_STEPS
10
+ # - MODEL_DIR
11
+ #
12
+ # Commonly overridden options:
13
+ # - BATCH_SIZE
14
+ # - PjitPartitioner.num_partitions
15
+ # - Trainer.num_microbatches
16
+ # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
17
+ # on the fly.
18
+
19
+ from __gin__ import dynamic_registration
20
+
21
+ import __main__ as train_script
22
+ import seqio
23
+ from mt3 import mixing
24
+ from mt3 import preprocessors
25
+ from mt3 import tasks
26
+ from mt3 import vocabularies
27
+ from t5x import gin_utils
28
+ from t5x import partitioning
29
+ from t5x import utils
30
+ from t5x import trainer
31
+
32
+ # Must be overridden
33
+ TASK_PREFIX = %gin.REQUIRED
34
+ TASK_FEATURE_LENGTHS = %gin.REQUIRED
35
+ TRAIN_STEPS = %gin.REQUIRED
36
+ MODEL_DIR = %gin.REQUIRED
37
+
38
+ # Commonly overridden
39
+ TRAIN_TASK_SUFFIX = 'train'
40
+ EVAL_TASK_SUFFIX = 'eval'
41
+ USE_CACHED_TASKS = True
42
+ BATCH_SIZE = 256
43
+
44
+ # Sometimes overridden
45
+ EVAL_STEPS = 20
46
+
47
+ # Convenience overrides.
48
+ EVALUATOR_USE_MEMORY_CACHE = True
49
+ EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset.
50
+ JSON_WRITE_N_RESULTS = 0 # Don't write any inferences.
51
+
52
+ # Number of velocity bins: set to 1 (no velocity) or 127
53
+ NUM_VELOCITY_BINS = %gin.REQUIRED
54
+ VOCAB_CONFIG = @vocabularies.VocabularyConfig()
55
+ vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS
56
+
57
+ # Program granularity: set to 'flat', 'midi_class', or 'full'
58
+ PROGRAM_GRANULARITY = %gin.REQUIRED
59
+ preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY
60
+
61
+ # Maximum number of examples per mix, or None for no mixing
62
+ MAX_EXAMPLES_PER_MIX = None
63
+ mixing.mix_transcription_examples.max_examples_per_mix = %MAX_EXAMPLES_PER_MIX
64
+
65
+ train/tasks.construct_task_name:
66
+ task_prefix = %TASK_PREFIX
67
+ vocab_config = %VOCAB_CONFIG
68
+ task_suffix = %TRAIN_TASK_SUFFIX
69
+
70
+ eval/tasks.construct_task_name:
71
+ task_prefix = %TASK_PREFIX
72
+ vocab_config = %VOCAB_CONFIG
73
+ task_suffix = %EVAL_TASK_SUFFIX
74
+
75
+ train_script.train:
76
+ model = %MODEL # imported from separate gin file
77
+ model_dir = %MODEL_DIR
78
+ train_dataset_cfg = @train/utils.DatasetConfig()
79
+ train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
80
+ infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
81
+ checkpoint_cfg = @utils.CheckpointConfig()
82
+ partitioner = @partitioning.PjitPartitioner()
83
+ trainer_cls = @trainer.Trainer
84
+ total_steps = %TRAIN_STEPS
85
+ eval_steps = %EVAL_STEPS
86
+ eval_period = 5000
87
+ random_seed = None # use faster, hardware RNG
88
+ summarize_config_fn = @gin_utils.summarize_gin_config
89
+ inference_evaluator_cls = @seqio.Evaluator
90
+
91
+ seqio.Evaluator:
92
+ logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
93
+ num_examples = %EVALUATOR_NUM_EXAMPLES
94
+ use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
95
+
96
+ seqio.JSONLogger:
97
+ write_n_results = %JSON_WRITE_N_RESULTS
98
+
99
+ train/utils.DatasetConfig:
100
+ mixture_or_task_name = @train/tasks.construct_task_name()
101
+ task_feature_lengths = %TASK_FEATURE_LENGTHS
102
+ split = 'train'
103
+ batch_size = %BATCH_SIZE
104
+ shuffle = True
105
+ seed = None # use a new seed each run/restart
106
+ use_cached = %USE_CACHED_TASKS
107
+ pack = False
108
+
109
+ train_eval/utils.DatasetConfig:
110
+ mixture_or_task_name = @train/tasks.construct_task_name()
111
+ task_feature_lengths = %TASK_FEATURE_LENGTHS
112
+ split = 'eval'
113
+ batch_size = %BATCH_SIZE
114
+ shuffle = False
115
+ seed = 42
116
+ use_cached = %USE_CACHED_TASKS
117
+ pack = False
118
+
119
+ infer_eval/utils.DatasetConfig:
120
+ mixture_or_task_name = @eval/tasks.construct_task_name()
121
+ task_feature_lengths = %TASK_FEATURE_LENGTHS
122
+ split = 'eval'
123
+ batch_size = %BATCH_SIZE
124
+ shuffle = False
125
+ seed = 42
126
+ use_cached = %USE_CACHED_TASKS
127
+ pack = False
128
+
129
+ utils.CheckpointConfig:
130
+ restore = None
131
+ save = @utils.SaveCheckpointConfig()
132
+ utils.SaveCheckpointConfig:
133
+ period = 5000
134
+ dtype = 'float32'
135
+ keep = None # keep all checkpoints
136
+ save_dataset = False # don't checkpoint dataset state
137
+
138
+ partitioning.PjitPartitioner:
139
+ num_partitions = 1
140
+ model_parallel_submesh = None
141
+
142
+ trainer.Trainer:
143
+ num_microbatches = None
144
+ learning_rate_fn = @utils.create_learning_rate_scheduler()
145
+ utils.create_learning_rate_scheduler:
146
+ factors = 'constant'
147
+ base_learning_rate = 0.001
148
+ warmup_steps = 1000
mt3/inference.py ADDED
@@ -0,0 +1,138 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Functions for MT3 inference."""
16
+
17
+ import functools
18
+ import json
19
+
20
+ from typing import Any, Optional, Sequence
21
+
22
+ import gin
23
+
24
+ from mt3 import metrics_utils
25
+ from mt3 import note_sequences
26
+ from mt3 import tasks
27
+ from mt3 import vocabularies
28
+
29
+ import note_seq
30
+ import seqio
31
+ import tensorflow as tf
32
+
33
+
34
+ def write_inferences_to_file(
35
+ path: str,
36
+ inferences: Sequence[Any],
37
+ task_ds: tf.data.Dataset,
38
+ mode: str,
39
+ vocabulary: Optional[seqio.Vocabulary] = None,
40
+ vocab_config=gin.REQUIRED,
41
+ onsets_only=gin.REQUIRED,
42
+ use_ties=gin.REQUIRED) -> None:
43
+ """Writes model predictions, ground truth transcriptions, and input audio.
44
+
45
+ For now this only works for transcription tasks with ties.
46
+
47
+ Args:
48
+ path: File path to write to.
49
+ inferences: Model inferences, output of predict_batch.
50
+ task_ds: Original task dataset.
51
+ mode: Prediction mode; must be 'predict' as 'score' is not supported.
52
+ vocabulary: Task output vocabulary.
53
+ vocab_config: Vocabulary config object.
54
+ onsets_only: If True, only predict onsets.
55
+ use_ties: If True, use "tie" representation.
56
+ """
57
+ if mode == 'score':
58
+ raise ValueError('`score` mode currently not supported in MT3')
59
+ if not vocabulary:
60
+ raise ValueError('`vocabulary` parameter required in `predict` mode')
61
+
62
+ if onsets_only and use_ties:
63
+ raise ValueError('ties not compatible with onset-only transcription')
64
+ if onsets_only:
65
+ encoding_spec = note_sequences.NoteOnsetEncodingSpec
66
+ elif not use_ties:
67
+ encoding_spec = note_sequences.NoteEncodingSpec
68
+ else:
69
+ encoding_spec = note_sequences.NoteEncodingWithTiesSpec
70
+
71
+ codec = vocabularies.build_codec(vocab_config)
72
+
73
+ targets = []
74
+ predictions = []
75
+
76
+ for inp, output in zip(task_ds.as_numpy_iterator(), inferences):
77
+ tokens = tasks.trim_eos(vocabulary.decode_tf(output).numpy())
78
+
79
+ start_time = inp['input_times'][0]
80
+ # Round down to nearest symbolic token step.
81
+ start_time -= start_time % (1 / codec.steps_per_second)
82
+
83
+ targets.append({
84
+ 'unique_id': inp['unique_id'][0],
85
+ 'ref_ns': inp['sequence'][0] if inp['sequence'][0] else None,
86
+ })
87
+
88
+ predictions.append({
89
+ 'unique_id': inp['unique_id'][0],
90
+ 'est_tokens': tokens,
91
+ 'start_time': start_time,
92
+ # Input audio is not part of the "prediction" but the below call to
93
+ # metrics_utils.event_predictions_to_ns handles the concatenation.
94
+ 'raw_inputs': inp['raw_inputs']
95
+ })
96
+
97
+ # The first target for each full example contains the NoteSequence; just
98
+ # organize by ID.
99
+ full_targets = {}
100
+ for target in targets:
101
+ if target['ref_ns']:
102
+ full_targets[target['unique_id']] = {
103
+ 'ref_ns': note_seq.NoteSequence.FromString(target['ref_ns'])
104
+ }
105
+
106
+ full_predictions = metrics_utils.combine_predictions_by_id(
107
+ predictions=predictions,
108
+ combine_predictions_fn=functools.partial(
109
+ metrics_utils.event_predictions_to_ns,
110
+ codec=codec,
111
+ encoding_spec=encoding_spec))
112
+
113
+ assert sorted(full_targets.keys()) == sorted(full_predictions.keys())
114
+
115
+ full_target_prediction_pairs = [
116
+ (full_targets[id], full_predictions[id])
117
+ for id in sorted(full_targets.keys())
118
+ ]
119
+
120
+ def note_to_dict(note):
121
+ return {
122
+ 'start_time': note.start_time,
123
+ 'end_time': note.end_time,
124
+ 'pitch': note.pitch,
125
+ 'velocity': note.velocity,
126
+ 'program': note.program,
127
+ 'is_drum': note.is_drum
128
+ }
129
+
130
+ with tf.io.gfile.GFile(path, 'w') as f:
131
+ for target, prediction in full_target_prediction_pairs:
132
+ json_dict = {
133
+ 'id': target['ref_ns'].id,
134
+ 'est_notes':
135
+ [note_to_dict(note) for note in prediction['est_ns'].notes]
136
+ }
137
+ json_str = json.dumps(json_dict, cls=seqio.TensorAndNumpyEncoder)
138
+ f.write(json_str + '\n')
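For reference, each line that write_inferences_to_file emits is one JSON object assembled via note_to_dict, roughly of the following shape (a sketch inferred from the code above; the field values are made up for illustration):

{"id": "<id of the reference NoteSequence>", "est_notes": [{"start_time": 0.0, "end_time": 0.47, "pitch": 60, "velocity": 100, "program": 0, "is_drum": false}, ...]}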
mt3/layers.py ADDED
@@ -0,0 +1,830 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dense attention classes and mask/weighting functions."""
16
+
17
+ # pylint: disable=attribute-defined-outside-init,g-bare-generic
18
+
19
+ import dataclasses
20
+ import functools
21
+ import operator
22
+ from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
23
+
24
+ from flax import linen as nn
25
+ from flax.linen import partitioning as nn_partitioning
26
+ import jax
27
+ from jax import lax
28
+ from jax import random
29
+ import jax.numpy as jnp
30
+ import numpy as np
31
+
32
+
33
+ # from flax.linen.partitioning import param_with_axes, with_sharding_constraint
34
+ param_with_axes = nn_partitioning.param_with_axes
35
+ with_sharding_constraint = nn_partitioning.with_sharding_constraint
36
+
37
+
38
+ # Type annotations
39
+ Array = jnp.ndarray
40
+ DType = jnp.dtype
41
+ PRNGKey = jnp.ndarray
42
+ Shape = Iterable[int]
43
+ Activation = Callable[..., Array]
44
+ # Parameter initializers.
45
+ Initializer = Callable[[PRNGKey, Shape, DType], Array]
46
+
47
+ default_embed_init = nn.initializers.variance_scaling(
48
+ 1.0, 'fan_in', 'normal', out_axis=0)
49
+
50
+
51
+ def sinusoidal(min_scale: float = 1.0,
52
+ max_scale: float = 10000.0,
53
+ dtype: DType = jnp.float32) -> Initializer:
54
+ """Creates 1D Sinusoidal Position Embedding Initializer.
55
+
56
+ Args:
57
+ min_scale: Minimum frequency-scale in sine grating.
58
+ max_scale: Maximum frequency-scale in sine grating.
59
+ dtype: The DType of the returned values.
60
+
61
+ Returns:
62
+ The sinusoidal initialization function.
63
+ """
64
+
65
+ def init(key: PRNGKey, shape: Shape, dtype: DType = dtype) -> Array:
66
+ """Sinusoidal init."""
67
+ del key
68
+ if dtype != np.float32:
69
+ raise ValueError('The sinusoidal initializer only supports float32.')
70
+ if len(list(shape)) != 2:
71
+ raise ValueError(
72
+ f'Expected a 2D shape (max_len, features), but got {shape}.')
73
+ max_len, features = shape
74
+ pe = np.zeros((max_len, features), dtype=dtype)
75
+ position = np.arange(0, max_len)[:, np.newaxis]
76
+ scale_factor = -np.log(max_scale / min_scale) / (features // 2 - 1)
77
+ div_term = min_scale * np.exp(np.arange(0, features // 2) * scale_factor)
78
+ pe[:, :features // 2] = np.sin(position * div_term)
79
+ pe[:, features // 2:2 * (features // 2)] = np.cos(position * div_term)
80
+ return jnp.array(pe)
81
+
82
+ return init
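A quick sketch of using the initializer above on its own; the (2048, 512) shape is only illustrative, and the key is None because the init is deterministic:

import numpy as np

init_fn = sinusoidal()
table = init_fn(None, (2048, 512), np.float32)
# table has shape (2048, 512); row i is the sinusoidal encoding of position i.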
83
+
84
+
85
+ def dot_product_attention(query: Array,
86
+ key: Array,
87
+ value: Array,
88
+ bias: Optional[Array] = None,
89
+ dropout_rng: Optional[PRNGKey] = None,
90
+ dropout_rate: float = 0.,
91
+ deterministic: bool = False,
92
+ dtype: DType = jnp.float32,
93
+ float32_logits: bool = False):
94
+ """Computes dot-product attention given query, key, and value.
95
+
96
+ This is the core function for applying attention based on
97
+ https://arxiv.org/abs/1706.03762. It calculates the attention weights given
98
+ query and key and combines the values using the attention weights.
99
+
100
+ Args:
101
+ query: queries for calculating attention with shape of `[batch, q_length,
102
+ num_heads, qk_depth_per_head]`.
103
+ key: keys for calculating attention with shape of `[batch, kv_length,
104
+ num_heads, qk_depth_per_head]`.
105
+ value: values to be used in attention with shape of `[batch, kv_length,
106
+ num_heads, v_depth_per_head]`.
107
+ bias: bias for the attention weights. This should be broadcastable to the
108
+ shape `[batch, num_heads, q_length, kv_length]` This can be used for
109
+ incorporating causal masks, padding masks, proximity bias, etc.
110
+ dropout_rng: JAX PRNGKey: to be used for dropout
111
+ dropout_rate: dropout rate
112
+ deterministic: bool, deterministic or not (to apply dropout)
113
+ dtype: the dtype of the computation (default: float32)
114
+ float32_logits: bool, if True then compute logits in float32 to avoid
115
+ numerical issues with bfloat16.
116
+
117
+ Returns:
118
+ Output of shape `[batch, length, num_heads, v_depth_per_head]`.
119
+ """
120
+ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
121
+ assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
122
+ 'q, k, v batch dims must match.')
123
+ assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
124
+ 'q, k, v num_heads must match.')
125
+ assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
126
+ assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
127
+
128
+ # Cast logits and softmax computation to float32 for model stability.
129
+ if float32_logits:
130
+ query = query.astype(jnp.float32)
131
+ key = key.astype(jnp.float32)
132
+
133
+ # `attn_weights`: [batch, num_heads, q_length, kv_length]
134
+ attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
135
+
136
+ # Apply attention bias: masking, dropout, proximity bias, etc.
137
+ if bias is not None:
138
+ attn_weights = attn_weights + bias.astype(attn_weights.dtype)
139
+
140
+ # Normalize the attention weights across `kv_length` dimension.
141
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
142
+
143
+ # Apply attention dropout.
144
+ if not deterministic and dropout_rate > 0.:
145
+ keep_prob = 1.0 - dropout_rate
146
+ # T5 broadcasts dropout along the "length" dim, but it is unclear which
147
+ # positional dimension that corresponds to here; we assume the query dim.
148
+ dropout_shape = list(attn_weights.shape)
149
+ dropout_shape[-2] = 1
150
+ keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
151
+ keep = jnp.broadcast_to(keep, attn_weights.shape)
152
+ multiplier = (
153
+ keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
154
+ attn_weights = attn_weights * multiplier
155
+
156
+ # Take the linear combination of `value`.
157
+ return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
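A shape-only sketch of calling dot_product_attention directly; all sizes are illustrative:

import jax.numpy as jnp

q = jnp.ones((2, 3, 4, 8))   # [batch, q_length, num_heads, qk_depth_per_head]
k = jnp.ones((2, 5, 4, 8))   # [batch, kv_length, num_heads, qk_depth_per_head]
v = jnp.ones((2, 5, 4, 16))  # [batch, kv_length, num_heads, v_depth_per_head]
out = dot_product_attention(q, k, v, deterministic=True)
# out.shape == (2, 3, 4, 16): [batch, q_length, num_heads, v_depth_per_head]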
158
+
159
+
160
+ dynamic_vector_slice_in_dim = jax.vmap(
161
+ lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
162
+
163
+
164
+ class MultiHeadDotProductAttention(nn.Module):
165
+ """Multi-head dot-product attention.
166
+
167
+ Attributes:
168
+ num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
169
+ should be divisible by the number of heads.
170
+ head_dim: dimension of each head.
171
+ dtype: the dtype of the computation.
172
+ dropout_rate: dropout rate
173
+ kernel_init: initializer for the kernel of the Dense layers.
174
+ float32_logits: bool, if True then compute logits in float32 to avoid
175
+ numerical issues with bfloat16.
176
+ """
177
+
178
+ num_heads: int
179
+ head_dim: int
180
+ dtype: DType = jnp.float32
181
+ dropout_rate: float = 0.
182
+ kernel_init: Initializer = nn.initializers.variance_scaling(
183
+ 1.0, 'fan_in', 'normal')
184
+ float32_logits: bool = False # computes logits in float32 for stability.
185
+
186
+ @nn.compact
187
+ def __call__(self,
188
+ inputs_q: Array,
189
+ inputs_kv: Array,
190
+ mask: Optional[Array] = None,
191
+ bias: Optional[Array] = None,
192
+ *,
193
+ decode: bool = False,
194
+ deterministic: bool = False) -> Array:
195
+ """Applies multi-head dot product attention on the input data.
196
+
197
+ Projects the inputs into multi-headed query, key, and value vectors,
198
+ applies dot-product attention, and projects the results to an output vector.
199
+
200
+ There are two modes: decoding and non-decoding (e.g., training). The mode is
201
+ determined by the `decode` argument. For decoding, this method is called twice,
202
+ first to initialize the cache and then for an actual decoding process. The
203
+ two calls are differentiated by the presence of 'cached_key' in the variable
204
+ dict. In the cache initialization stage, the cache variables are initialized
205
+ as zeros and will be filled in the subsequent decoding process.
206
+
207
+ In the cache initialization call, `inputs_q` has a shape [batch, length,
208
+ q_features] and `inputs_kv`: [batch, length, kv_features]. During the
209
+ incremental decoding stage, query, key and value all have the shape [batch,
210
+ 1, qkv_features] corresponding to a single step.
211
+
212
+ Args:
213
+ inputs_q: input queries of shape `[batch, q_length, q_features]`.
214
+ inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
215
+ mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
216
+ bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
217
+ decode: Whether to prepare and use an autoregressive cache.
218
+ deterministic: Disables dropout if set to True.
219
+
220
+ Returns:
221
+ output of shape `[batch, length, q_features]`.
222
+ """
223
+ projection = functools.partial(
224
+ DenseGeneral,
225
+ axis=-1,
226
+ features=(self.num_heads, self.head_dim),
227
+ kernel_axes=('embed', 'joined_kv'),
228
+ dtype=self.dtype)
229
+
230
+ # NOTE: T5 does not explicitly rescale the attention logits by
231
+ # 1/sqrt(depth_kq)! This is folded into the initializers of the
232
+ # linear transformations, which is equivalent under Adafactor.
233
+ depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
234
+ query_init = lambda *args: self.kernel_init(*args) / depth_scaling
235
+
236
+ # Project inputs_q to multi-headed q/k/v
237
+ # dimensions are then [batch, length, num_heads, head_dim]
238
+ query = projection(kernel_init=query_init, name='query')(inputs_q)
239
+ key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
240
+ value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
241
+
242
+ query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv'))
243
+ key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv'))
244
+ value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv'))
245
+
246
+ if decode:
247
+ # Detect if we're initializing by absence of existing cache data.
248
+ is_initialized = self.has_variable('cache', 'cached_key')
249
+ # The key and value have dimension [batch, length, num_heads, head_dim],
250
+ # but we cache them as [batch, num_heads, head_dim, length] as a TPU
251
+ # fusion optimization. This also enables the "scatter via one-hot
252
+ # broadcast" trick, which means we do a one-hot broadcast instead of a
253
+ # scatter/gather operation, resulting in a 3-4x speedup in practice.
254
+ swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
255
+ cached_key = self.variable('cache', 'cached_key', jnp.zeros,
256
+ swap_dims(key.shape), key.dtype)
257
+ cached_value = self.variable('cache', 'cached_value', jnp.zeros,
258
+ swap_dims(value.shape), value.dtype)
259
+ cache_index = self.variable('cache', 'cache_index',
260
+ lambda: jnp.array(0, dtype=jnp.int32))
261
+ if is_initialized:
262
+ batch, num_heads, head_dim, length = (cached_key.value.shape)
263
+ # During fast autoregressive decoding, we feed one position at a time,
264
+ # and cache the keys and values step by step.
265
+ # Sanity shape check of cached key against input query.
266
+ expected_shape = (batch, 1, num_heads, head_dim)
267
+ if expected_shape != query.shape:
268
+ raise ValueError('Autoregressive cache shape error, '
269
+ 'expected query shape %s instead got %s.' %
270
+ (expected_shape, query.shape))
271
+
272
+ # Create a OHE of the current index. NOTE: the index is increased below.
273
+ cur_index = cache_index.value
274
+ one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
275
+ # In order to update the key, value caches with the current key and
276
+ # value, we move the length axis to the back, similar to what we did for
277
+ # the cached ones above.
278
+ # Note these are currently the key and value of a single position, since
279
+ # we feed one position at a time.
280
+ one_token_key = jnp.moveaxis(key, -3, -1)
281
+ one_token_value = jnp.moveaxis(value, -3, -1)
282
+ # Update key, value caches with our new 1d spatial slices.
283
+ # We implement an efficient scatter into the cache via one-hot
284
+ # broadcast and addition.
285
+ key = cached_key.value + one_token_key * one_hot_indices
286
+ value = cached_value.value + one_token_value * one_hot_indices
287
+ cached_key.value = key
288
+ cached_value.value = value
289
+ cache_index.value = cache_index.value + 1
290
+ # Move the keys and values back to their original shapes.
291
+ key = jnp.moveaxis(key, -1, -3)
292
+ value = jnp.moveaxis(value, -1, -3)
293
+
294
+ # Causal mask for cached decoder self-attention: our single query
295
+ # position should only attend to those key positions that have already
296
+ # been generated and cached, not the remaining zero elements.
297
+ mask = combine_masks(
298
+ mask,
299
+ jnp.broadcast_to(
300
+ jnp.arange(length) <= cur_index,
301
+ # (1, 1, length) represent (head dim, query length, key length)
302
+ # query length is 1 because during decoding we deal with one
303
+ # index.
304
+ # The same mask is applied to all batch elements and heads.
305
+ (batch, 1, 1, length)))
306
+
307
+ # Grab the correct relative attention bias during decoding. This is
308
+ # only required during single step decoding.
309
+ if bias is not None:
310
+ # The bias is a full attention matrix, but during decoding we only
311
+ # have to take a slice of it.
312
+ # This is equivalent to bias[..., cur_index:cur_index+1, :].
313
+ bias = dynamic_vector_slice_in_dim(
314
+ jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2)
315
+
316
+ # Convert the boolean attention mask to an attention bias.
317
+ if mask is not None:
318
+ # attention mask in the form of attention bias
319
+ attention_bias = lax.select(
320
+ mask > 0,
321
+ jnp.full(mask.shape, 0.).astype(self.dtype),
322
+ jnp.full(mask.shape, -1e10).astype(self.dtype))
323
+ else:
324
+ attention_bias = None
325
+
326
+ # Add provided bias term (e.g. relative position embedding).
327
+ if bias is not None:
328
+ attention_bias = combine_biases(attention_bias, bias)
329
+
330
+ dropout_rng = None
331
+ if not deterministic and self.dropout_rate > 0.:
332
+ dropout_rng = self.make_rng('dropout')
333
+
334
+ # Apply attention.
335
+ x = dot_product_attention(
336
+ query,
337
+ key,
338
+ value,
339
+ bias=attention_bias,
340
+ dropout_rng=dropout_rng,
341
+ dropout_rate=self.dropout_rate,
342
+ deterministic=deterministic,
343
+ dtype=self.dtype,
344
+ float32_logits=self.float32_logits)
345
+
346
+ # Back to the original inputs dimensions.
347
+ out = DenseGeneral(
348
+ features=inputs_q.shape[-1], # output dim is set to the input dim.
349
+ axis=(-2, -1),
350
+ kernel_init=self.kernel_init,
351
+ kernel_axes=('joined_kv', 'embed'),
352
+ dtype=self.dtype,
353
+ name='out')(
354
+ x)
355
+ return out
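A hedged sketch of the two-call decode protocol described in the docstring: the first call (here via init) creates the zero-filled cache, and each later call with decode=True feeds a single position and mutates the cache. All sizes are illustrative.

import jax
import jax.numpy as jnp

attn = MultiHeadDotProductAttention(num_heads=2, head_dim=4)
x = jnp.ones((1, 8, 16))  # [batch, length, features]; length sizes the cache
variables = attn.init(jax.random.PRNGKey(0), x, x, decode=True)
params, cache = variables['params'], variables['cache']

step = jnp.ones((1, 1, 16))  # a single decode position
y, mutated = attn.apply({'params': params, 'cache': cache}, step, step,
                        decode=True, mutable=['cache'])
cache = mutated['cache']  # carried forward to the next decode step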
356
+
357
+
358
+ def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
359
+ # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
360
+ return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
361
+
362
+
363
+ def _canonicalize_tuple(x):
364
+ if isinstance(x, Iterable):
365
+ return tuple(x)
366
+ else:
367
+ return (x,)
368
+
369
+
370
+ #------------------------------------------------------------------------------
371
+ # DenseGeneral for attention layers.
372
+ #------------------------------------------------------------------------------
373
+ class DenseGeneral(nn.Module):
374
+ """A linear transformation (without bias) with flexible axes.
375
+
376
+ Attributes:
377
+ features: tuple with numbers of output features.
378
+ axis: tuple with axes to apply the transformation on.
379
+ dtype: the dtype of the computation (default: float32).
380
+ kernel_init: initializer function for the weight matrix.
381
+ """
382
+ features: Union[Iterable[int], int]
383
+ axis: Union[Iterable[int], int] = -1
384
+ dtype: DType = jnp.float32
385
+ kernel_init: Initializer = nn.initializers.variance_scaling(
386
+ 1.0, 'fan_in', 'truncated_normal')
387
+ kernel_axes: Tuple[str, ...] = ()
388
+
389
+ @nn.compact
390
+ def __call__(self, inputs: Array) -> Array:
391
+ """Applies a linear transformation to the inputs along multiple dimensions.
392
+
393
+ Args:
394
+ inputs: The nd-array to be transformed.
395
+
396
+ Returns:
397
+ The transformed input.
398
+ """
399
+ features = _canonicalize_tuple(self.features)
400
+ axis = _canonicalize_tuple(self.axis)
401
+
402
+ inputs = jnp.asarray(inputs, self.dtype)
403
+ axis = _normalize_axes(axis, inputs.ndim)
404
+
405
+ kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
406
+ kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),
407
+ np.prod(features))
408
+ kernel = param_with_axes(
409
+ 'kernel',
410
+ self.kernel_init,
411
+ kernel_param_shape,
412
+ jnp.float32,
413
+ axes=self.kernel_axes)
414
+ kernel = jnp.asarray(kernel, self.dtype)
415
+ kernel = jnp.reshape(kernel, kernel_shape)
416
+
417
+ contract_ind = tuple(range(0, len(axis)))
418
+ return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
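A small sketch of DenseGeneral splitting the embedding axis into separate (heads, head_dim) output axes, as the attention projections above do; the sizes and axis names are illustrative:

import jax
import jax.numpy as jnp

dense = DenseGeneral(
    features=(2, 4), axis=-1, kernel_axes=('embed', 'joined_kv'))
y, _ = dense.init_with_output(jax.random.PRNGKey(0), jnp.ones((1, 3, 16)))
# y.shape == (1, 3, 2, 4): the last input axis (16) is projected to (2, 4).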
419
+
420
+
421
+ def _convert_to_activation_function(
422
+ fn_or_string: Union[str, Callable]) -> Callable:
423
+ """Convert a string to an activation function."""
424
+ if fn_or_string == 'linear':
425
+ return lambda x: x
426
+ elif isinstance(fn_or_string, str):
427
+ return getattr(nn, fn_or_string)
428
+ elif callable(fn_or_string):
429
+ return fn_or_string
430
+ else:
431
+ raise ValueError("don't know how to convert %s to an activation function" %
432
+ (fn_or_string,))
433
+
434
+
435
+ class MlpBlock(nn.Module):
436
+ """Transformer MLP / feed-forward block.
437
+
438
+ Attributes:
439
+ intermediate_dim: Shared dimension of hidden layers.
440
+ activations: Type of activations for each layer. Each element is either
441
+ 'linear', a string function name in flax.linen, or a function.
442
+ kernel_init: Kernel function, passed to the dense layers.
443
+ deterministic: Whether the dropout layers should be deterministic.
444
+ intermediate_dropout_rate: Dropout rate used after the intermediate layers.
445
+ dtype: Type for the dense layer.
446
+ """
447
+ intermediate_dim: int = 2048
448
+ activations: Sequence[Union[str, Callable]] = ('relu',)
449
+ kernel_init: Initializer = nn.initializers.variance_scaling(
450
+ 1.0, 'fan_in', 'truncated_normal')
451
+ intermediate_dropout_rate: float = 0.1
452
+ dtype: Any = jnp.float32
453
+
454
+ @nn.compact
455
+ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
456
+ """Applies Transformer MlpBlock module."""
457
+ # Iterate over specified MLP input activation functions.
458
+ # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
459
+ activations = []
460
+ for idx, act_fn in enumerate(self.activations):
461
+ dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}'
462
+ x = DenseGeneral(
463
+ self.intermediate_dim,
464
+ dtype=self.dtype,
465
+ kernel_init=self.kernel_init,
466
+ kernel_axes=('embed', 'mlp'),
467
+ name=dense_name)(
468
+ inputs)
469
+ x = _convert_to_activation_function(act_fn)(x)
470
+ activations.append(x)
471
+
472
+ # Take elementwise product of above intermediate activations.
473
+ x = functools.reduce(operator.mul, activations)
474
+ # Apply dropout and final dense output projection.
475
+ x = nn.Dropout(
476
+ rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
477
+ x, deterministic=deterministic) # Broadcast along length.
478
+ x = with_sharding_constraint(x, ('batch', 'length', 'mlp'))
479
+ output = DenseGeneral(
480
+ inputs.shape[-1],
481
+ dtype=self.dtype,
482
+ kernel_init=self.kernel_init,
483
+ kernel_axes=('mlp', 'embed'),
484
+ name='wo')(
485
+ x)
486
+ return output
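A sketch of the gated path mentioned in the comments: with activations=('gelu', 'linear') the block creates two parallel 'wi_*' projections, multiplies them elementwise, then applies the 'wo' projection. Sizes are illustrative.

import jax
import jax.numpy as jnp

mlp = MlpBlock(intermediate_dim=8, activations=('gelu', 'linear'))
x = jnp.ones((2, 5, 4))  # [batch, length, embed]
variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
y = mlp.apply(variables, x, deterministic=True)  # y.shape == (2, 5, 4)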
487
+
488
+
489
+ class Embed(nn.Module):
490
+ """A parameterized function from integers [0, n) to d-dimensional vectors.
491
+
492
+ Attributes:
493
+ num_embeddings: number of embeddings.
494
+ features: number of feature dimensions for each embedding.
495
+ dtype: the dtype of the embedding vectors (default: float32).
496
+ embedding_init: embedding initializer.
497
+ one_hot: performs the gather with a one-hot contraction rather than a true
498
+ gather. This is currently needed for SPMD partitioning.
499
+ """
500
+ num_embeddings: int
501
+ features: int
502
+ cast_input_dtype: Optional[DType] = None
503
+ dtype: DType = jnp.float32
504
+ attend_dtype: Optional[DType] = None
505
+ embedding_init: Initializer = default_embed_init
506
+ one_hot: bool = False
507
+ embedding: Array = dataclasses.field(init=False)
508
+
509
+ def setup(self):
510
+ self.embedding = param_with_axes(
511
+ 'embedding',
512
+ self.embedding_init, (self.num_embeddings, self.features),
513
+ jnp.float32,
514
+ axes=('vocab', 'embed'))
515
+
516
+ def __call__(self, inputs: Array) -> Array:
517
+ """Embeds the inputs along the last dimension.
518
+
519
+ Args:
520
+ inputs: input data, all dimensions are considered batch dimensions.
521
+
522
+ Returns:
523
+ Output which is embedded input data. The output shape follows the input,
524
+ with an additional `features` dimension appended.
525
+ """
526
+ if self.cast_input_dtype:
527
+ inputs = inputs.astype(self.cast_input_dtype)
528
+ if not jnp.issubdtype(inputs.dtype, jnp.integer):
529
+ raise ValueError('Input type must be an integer or unsigned integer.')
530
+ if self.one_hot:
531
+ iota = lax.iota(jnp.int32, self.num_embeddings)
532
+ one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
533
+ output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
534
+ else:
535
+ output = jnp.asarray(self.embedding, self.dtype)[inputs]
536
+ output = with_sharding_constraint(output, ('batch', 'length', 'embed'))
537
+ return output
538
+
539
+ def attend(self, query: Array) -> Array:
540
+ """Attend over the embedding using a query array.
541
+
542
+ Args:
543
+ query: array with last dimension equal the feature depth `features` of the
544
+ embedding.
545
+
546
+ Returns:
547
+ An array with final dim `num_embeddings` corresponding to the batched
548
+ inner-product of the array of query vectors against each embedding.
549
+ Commonly used for weight-sharing between embeddings and logit transform
550
+ in NLP models.
551
+ """
552
+ dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
553
+ return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
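A short sketch of the one_hot path: the gather is replaced by a one-hot matmul (friendlier to SPMD partitioning) and yields the same embeddings as a plain lookup. Sizes are illustrative.

import jax
import jax.numpy as jnp

embed = Embed(num_embeddings=10, features=4, one_hot=True)
ids = jnp.array([[1, 2, 3]])
variables = embed.init(jax.random.PRNGKey(0), ids)
vectors = embed.apply(variables, ids)  # vectors.shape == (1, 3, 4)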
554
+
555
+
556
+ class FixedEmbed(nn.Module):
557
+ """Fixed (not learnable) embeddings specified by the initializer function.
558
+
559
+ Attributes:
560
+ embedding_init: The initializer function that defines the embeddings.
561
+ max_length: The maximum supported length.
562
+ dtype: The DType to use for the embeddings.
563
+ """
564
+ features: int
565
+ max_length: int = 2048
566
+ embedding_init: Initializer = sinusoidal()
567
+ dtype: jnp.dtype = jnp.float32
568
+
569
+ def setup(self):
570
+ # The key is set to None because sinusoid init is deterministic.
571
+ shape = (self.max_length, self.features)
572
+ self.embedding = self.embedding_init(None, shape, self.dtype) # pylint: disable=too-many-function-args
573
+
574
+ @nn.compact
575
+ def __call__(self,
576
+ inputs,
577
+ *,
578
+ decode: bool = False):
579
+ """Returns the fixed position embeddings specified by the initializer.
580
+
581
+ Args:
582
+ inputs: <int>[batch_size, seq_len] input position indices.
583
+ decode: True if running in single-position autoregressive decode mode.
584
+
585
+ Returns:
586
+ The fixed position embeddings <float32>[batch_size, seq_len, features].
587
+ """
588
+ # We use a cache position index for tracking decoding position.
589
+ if decode:
590
+ position_embedder_index = self.variable(
591
+ 'cache', 'position_embedder_index',
592
+ lambda: jnp.array(-1, dtype=jnp.uint32))
593
+ i = position_embedder_index.value
594
+ position_embedder_index.value = i + 1
595
+ return jax.lax.dynamic_slice(self.embedding, jnp.array((i, 0)),
596
+ np.array((1, self.features)))
597
+
598
+ return jnp.take(self.embedding, inputs, axis=0)
599
+
600
+
601
+ #------------------------------------------------------------------------------
602
+ # T5 Layernorm - no subtraction of mean or bias.
603
+ #------------------------------------------------------------------------------
604
+ class LayerNorm(nn.Module):
605
+ """T5 Layer normalization operating on the last axis of the input data."""
606
+ epsilon: float = 1e-6
607
+ dtype: Any = jnp.float32
608
+ scale_init: Initializer = nn.initializers.ones
609
+
610
+ @nn.compact
611
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
612
+ """Applies layer normalization on the input."""
613
+ x = jnp.asarray(x, jnp.float32)
614
+ features = x.shape[-1]
615
+ mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
616
+ y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
617
+ scale = param_with_axes(
618
+ 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',))
619
+
620
+ scale = jnp.asarray(scale, self.dtype)
621
+ return y * scale
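A small sketch making the "no subtraction of mean or bias" point concrete: this is RMS-style normalization, so with the default unit scale it matches dividing by the root mean square of the last axis. Values are illustrative.

import jax
import jax.numpy as jnp

ln = LayerNorm()
x = jnp.array([[1., 2., 3., 4.]])
variables = ln.init(jax.random.PRNGKey(0), x)
y = ln.apply(variables, x)
manual = x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)
# y and manual agree because the initial scale is all ones.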
622
+
623
+
624
+ #------------------------------------------------------------------------------
625
+ # Mask-making utility functions.
626
+ #------------------------------------------------------------------------------
627
+ def make_attention_mask(query_input: Array,
628
+ key_input: Array,
629
+ pairwise_fn: Callable = jnp.multiply,
630
+ extra_batch_dims: int = 0,
631
+ dtype: DType = jnp.float32) -> Array:
632
+ """Mask-making helper for attention weights.
633
+
634
+ In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`), the
635
+ attention weights will be `[batch, heads, len_q, len_kv]` and this
636
+ function will produce `[batch, 1, len_q, len_kv]`.
637
+
638
+ Args:
639
+ query_input: a batched, flat input of query_length size
640
+ key_input: a batched, flat input of key_length size
641
+ pairwise_fn: broadcasting elementwise comparison function
642
+ extra_batch_dims: number of extra batch dims to add singleton axes for, none
643
+ by default
644
+ dtype: mask return dtype
645
+
646
+ Returns:
647
+ A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
648
+ """
649
+ # [batch, len_q, len_kv]
650
+ mask = pairwise_fn(
651
+ # [batch, len_q] -> [batch, len_q, 1]
652
+ jnp.expand_dims(query_input, axis=-1),
653
+ # [batch, len_q] -> [batch, 1, len_kv]
654
+ jnp.expand_dims(key_input, axis=-2))
655
+
656
+ # [batch, 1, len_q, len_kv]. This creates the head dim.
657
+ mask = jnp.expand_dims(mask, axis=-3)
658
+ mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
659
+ return mask.astype(dtype)
660
+
661
+
662
+ def make_causal_mask(x: Array,
663
+ extra_batch_dims: int = 0,
664
+ dtype: DType = jnp.float32) -> Array:
665
+ """Make a causal mask for self-attention.
666
+
667
+ In case of 1d inputs (i.e., `[batch, len]`), the self-attention weights
668
+ will be `[batch, heads, len, len]` and this function will produce a
669
+ causal mask of shape `[batch, 1, len, len]`.
670
+
671
+ Note that a causal mask does not depend on the values of x; it only depends on
672
+ the shape. If x has padding elements, they will not be treated in a special
673
+ manner.
674
+
675
+ Args:
676
+ x: input array of shape `[batch, len]`
677
+ extra_batch_dims: number of batch dims to add singleton axes for, none by
678
+ default
679
+ dtype: mask return dtype
680
+
681
+ Returns:
682
+ A `[batch, 1, len, len]` shaped causal mask for 1d attention.
683
+ """
684
+ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
685
+ return make_attention_mask(
686
+ idxs,
687
+ idxs,
688
+ jnp.greater_equal,
689
+ extra_batch_dims=extra_batch_dims,
690
+ dtype=dtype)
691
+
692
+
693
+ def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
694
+ """Combine attention masks.
695
+
696
+ Args:
697
+ *masks: set of attention mask arguments to combine, some can be None.
698
+ dtype: final mask dtype
699
+
700
+ Returns:
701
+ Combined mask, reduced by logical and, returns None if no masks given.
702
+ """
703
+ masks = [m for m in masks if m is not None]
704
+ if not masks:
705
+ return None
706
+ assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), (
707
+ f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
708
+ mask, *other_masks = masks
709
+ for other_mask in other_masks:
710
+ mask = jnp.logical_and(mask, other_mask)
711
+ return mask.astype(dtype)
712
+
713
+
714
+ def combine_biases(*masks: Optional[Array]):
715
+ """Combine attention biases.
716
+
717
+ Args:
718
+ *masks: set of attention bias arguments to combine, some can be None.
719
+
720
+ Returns:
721
+ Combined mask, reduced by summation, returns None if no masks given.
722
+ """
723
+ masks = [m for m in masks if m is not None]
724
+ if not masks:
725
+ return None
726
+ assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), (
727
+ f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
728
+ mask, *other_masks = masks
729
+ for other_mask in other_masks:
730
+ mask = mask + other_mask
731
+ return mask
732
+
733
+
734
+ def make_decoder_mask(decoder_target_tokens: Array,
735
+ dtype: DType,
736
+ decoder_causal_attention: Optional[Array] = None,
737
+ decoder_segment_ids: Optional[Array] = None) -> Array:
738
+ """Compute the self-attention mask for a decoder.
739
+
740
+ Decoder mask is formed by combining a causal mask, a padding mask and an
741
+ optional packing mask. If decoder_causal_attention is passed, it makes the
742
+ masking non-causal for positions that have value of 1.
743
+
744
+ A prefix LM is applied to a dataset which has a notion of "inputs" and
745
+ "targets", e.g., a machine translation task. The inputs and targets are
746
+ concatenated to form a new target. `decoder_target_tokens` is the concatenated
747
+ decoder output tokens.
748
+
749
+ The "inputs" portion of the concatenated sequence can attend to other "inputs"
750
+ tokens, even those at later time steps. In order to control this
751
+ behavior, `decoder_causal_attention` is necessary. This is a binary mask with
752
+ a value of 1 indicating that the position belongs to the "inputs" portion of
753
+ original dataset.
754
+
755
+ Example:
756
+
757
+ Suppose we have a dataset with two examples.
758
+
759
+ ds = [{"inputs": [6, 7], "targets": [8]},
760
+ {"inputs": [3, 4], "targets": [5]}]
761
+
762
+ After the data preprocessing with packing, the two examples are packed into
763
+ one example with the following three fields (some fields are skipped for
764
+ simplicity).
765
+
766
+ decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
767
+ decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
768
+ decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]
769
+
770
+ where each array has [batch, length] shape with batch size being 1. Then,
771
+ this function computes the following mask.
772
+
773
+ mask = [[[[1, 1, 0, 0, 0, 0, 0],
774
+ [1, 1, 0, 0, 0, 0, 0],
775
+ [1, 1, 1, 0, 0, 0, 0],
776
+ [0, 0, 0, 1, 1, 0, 0],
777
+ [0, 0, 0, 1, 1, 0, 0],
778
+ [0, 0, 0, 1, 1, 1, 0],
779
+ [0, 0, 0, 0, 0, 0, 0]]]]
780
+
781
+ mask[b, 0, :, :] represents the mask for the example `b` in the batch.
782
+ Because mask is for a self-attention layer, the mask's shape is a square of
783
+ shape [query length, key length].
784
+
785
+ mask[b, 0, i, j] = 1 means that the query token at position i can attend to
786
+ the key token at position j.
787
+
788
+ Args:
789
+ decoder_target_tokens: decoder output tokens. [batch, length]
790
+ dtype: dtype of the output mask.
791
+ decoder_causal_attention: a binary mask indicating which position should
792
+ only attend to earlier positions in the sequence. Others will attend
793
+ bidirectionally. [batch, length]
794
+ decoder_segment_ids: decoder segmentation info for packed examples. [batch,
795
+ length]
796
+
797
+ Returns:
798
+ the combined decoder mask.
799
+ """
800
+ masks = []
801
+ # The same mask is applied to all attention heads. So the head dimension is 1,
802
+ # i.e., the mask will be broadcast along the heads dim.
803
+ # [batch, 1, length, length]
804
+ causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)
805
+
806
+ # Positions with value 1 in `decoder_causal_attention` can attend
807
+ # bidirectionally.
808
+ if decoder_causal_attention is not None:
809
+ # [batch, 1, length, length]
810
+ inputs_mask = make_attention_mask(
811
+ decoder_causal_attention,
812
+ decoder_causal_attention,
813
+ jnp.logical_and,
814
+ dtype=dtype)
815
+ masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
816
+ else:
817
+ masks.append(causal_mask)
818
+
819
+ # Padding mask.
820
+ masks.append(
821
+ make_attention_mask(
822
+ decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))
823
+
824
+ # Packing mask
825
+ if decoder_segment_ids is not None:
826
+ masks.append(
827
+ make_attention_mask(
828
+ decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))
829
+
830
+ return combine_masks(*masks, dtype=dtype)
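The packed prefix-LM example from the make_decoder_mask docstring, as a runnable sketch (the expected matrix is the one shown above):

import jax.numpy as jnp

decoder_target_tokens = jnp.array([[6, 7, 8, 3, 4, 5, 0]])
decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]])
decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]])
mask = make_decoder_mask(
    decoder_target_tokens=decoder_target_tokens,
    dtype=jnp.float32,
    decoder_causal_attention=decoder_causal_attention,
    decoder_segment_ids=decoder_segment_ids)
# mask.shape == (1, 1, 7, 7) and mask[0, 0] matches the matrix in the docstring.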
mt3/layers_test.py ADDED
@@ -0,0 +1,545 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for attention classes."""
16
+
17
+ import dataclasses
18
+ from typing import Optional
19
+ from unittest import mock
20
+
21
+ from absl.testing import absltest
22
+ from absl.testing import parameterized
23
+ from flax import linen as nn
24
+ from flax.core import freeze
25
+ from flax.linen import partitioning as nn_partitioning
26
+ import jax
27
+ from jax import random
28
+ from jax.nn import initializers
29
+ import jax.numpy as jnp
30
+ from mt3 import layers
31
+ import numpy as np
32
+
33
+ # Parse absl flags test_srcdir and test_tmpdir.
34
+ jax.config.parse_flags_with_absl()
35
+
36
+ Array = jnp.ndarray
37
+ AxisMetadata = nn_partitioning.AxisMetadata # pylint: disable=invalid-name
38
+
39
+
40
+ class SelfAttention(layers.MultiHeadDotProductAttention):
41
+ """Self-attention special case of multi-head dot-product attention."""
42
+
43
+ @nn.compact
44
+ def __call__(self,
45
+ inputs_q: Array,
46
+ mask: Optional[Array] = None,
47
+ bias: Optional[Array] = None,
48
+ deterministic: bool = False):
49
+ return super().__call__(
50
+ inputs_q, inputs_q, mask, bias, deterministic=deterministic)
51
+
52
+
53
+ @dataclasses.dataclass(frozen=True)
54
+ class SelfAttentionArgs:
55
+ num_heads: int = 1
56
+ batch_size: int = 2
57
+ # qkv_features: int = 3
58
+ head_dim: int = 3
59
+ # out_features: int = 4
60
+ q_len: int = 5
61
+ features: int = 6
62
+ dropout_rate: float = 0.1
63
+ deterministic: bool = False
64
+ decode: bool = False
65
+ float32_logits: bool = False
66
+
67
+ def __post_init__(self):
68
+ # If we are doing decoding, the query length should be 1, because we are doing
69
+ # autoregressive decoding where we feed one position at a time.
70
+ assert not self.decode or self.q_len == 1
71
+
72
+ def init_args(self):
73
+ return dict(
74
+ num_heads=self.num_heads,
75
+ head_dim=self.head_dim,
76
+ dropout_rate=self.dropout_rate,
77
+ float32_logits=self.float32_logits)
78
+
79
+ def apply_args(self):
80
+ inputs_q = jnp.ones((self.batch_size, self.q_len, self.features))
81
+ mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len))
82
+ bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len))
83
+ return {
84
+ 'inputs_q': inputs_q,
85
+ 'mask': mask,
86
+ 'bias': bias,
87
+ 'deterministic': self.deterministic
88
+ }
89
+
90
+
91
+ class AttentionTest(parameterized.TestCase):
92
+
93
+ def test_dot_product_attention_shape(self):
94
+ # This test only checks for shape but tries to make sure all code paths are
95
+ # reached.
96
+ dropout_rng = random.PRNGKey(0)
97
+ batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6
98
+
99
+ query = jnp.ones((batch_size, q_len, num_heads, qk_depth))
100
+ key = jnp.ones((batch_size, kv_len, num_heads, qk_depth))
101
+ value = jnp.ones((batch_size, kv_len, num_heads, v_depth))
102
+ bias = jnp.ones((batch_size, num_heads, q_len, kv_len))
103
+
104
+ args = dict(
105
+ query=query,
106
+ key=key,
107
+ value=value,
108
+ bias=bias,
109
+ dropout_rng=dropout_rng,
110
+ dropout_rate=0.5,
111
+ deterministic=False,
112
+ )
113
+
114
+ output = layers.dot_product_attention(**args)
115
+ self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth))
116
+
117
+ def test_make_attention_mask_multiply_pairwise_fn(self):
118
+ decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]])
119
+ attention_mask = layers.make_attention_mask(
120
+ decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32)
121
+ expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]])
122
+ expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]])
123
+ self.assertEqual(attention_mask.shape, (2, 1, 3, 3))
124
+ np.testing.assert_array_equal(attention_mask[0, 0], expected0)
125
+ np.testing.assert_array_equal(attention_mask[1, 0], expected1)
126
+
127
+ def test_make_attention_mask_equal_pairwise_fn(self):
128
+ segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]])
129
+ attention_mask = layers.make_attention_mask(
130
+ segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32)
131
+ # Padding is not treated in a special way. So they need to be zeroed out
132
+ # separately.
133
+ expected0 = jnp.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
134
+ [0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 1, 0],
135
+ [0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]])
136
+ expected1 = jnp.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0],
137
+ [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0],
138
+ [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1]])
139
+ self.assertEqual(attention_mask.shape, (2, 1, 6, 6))
140
+ np.testing.assert_array_equal(attention_mask[0, 0], expected0)
141
+ np.testing.assert_array_equal(attention_mask[1, 0], expected1)
142
+
143
+ def test_make_causal_mask_with_padding(self):
144
+ x = jnp.array([[7, 0, 0], [8, 5, 0]])
145
+ y = layers.make_causal_mask(x)
146
+ self.assertEqual(y.shape, (2, 1, 3, 3))
147
+ # Padding is not treated in a special way. So they need to be zeroed out
148
+ # separately.
149
+ expected_y = jnp.array([[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]],
150
+ jnp.float32)
151
+ np.testing.assert_allclose(y[0], expected_y)
152
+ np.testing.assert_allclose(y[1], expected_y)
153
+
154
+ def test_make_causal_mask_extra_batch_dims(self):
155
+ x = jnp.ones((3, 3, 5))
156
+ y = layers.make_causal_mask(x, extra_batch_dims=2)
157
+ self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5))
158
+
159
+ def test_make_causal_mask(self):
160
+ x = jnp.ones((1, 3))
161
+ y = layers.make_causal_mask(x)
162
+ self.assertEqual(y.shape, (1, 1, 3, 3))
163
+ expected_y = jnp.array([[[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]]],
164
+ jnp.float32)
165
+ np.testing.assert_allclose(y, expected_y)
166
+
167
+ def test_combine_masks(self):
168
+ masks = [
169
+ jnp.array([0, 1, 0, 1], jnp.float32), None,
170
+ jnp.array([1, 1, 1, 1], jnp.float32),
171
+ jnp.array([1, 1, 1, 0], jnp.float32)
172
+ ]
173
+ y = layers.combine_masks(*masks)
174
+ np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32))
175
+
176
+ def test_combine_biases(self):
177
+ masks = [
178
+ jnp.array([0, 1, 0, 1], jnp.float32), None,
179
+ jnp.array([0, 1, 1, 1], jnp.float32),
180
+ jnp.array([0, 1, 1, 0], jnp.float32)
181
+ ]
182
+ y = layers.combine_biases(*masks)
183
+ np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32))
184
+
185
+ def test_make_decoder_mask_lm_unpacked(self):
186
+ decoder_target_tokens = jnp.array([6, 7, 3, 0])
187
+ mask = layers.make_decoder_mask(
188
+ decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32)
189
+ expected_mask = jnp.array([[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0],
190
+ [0, 0, 0, 0]]])
191
+ np.testing.assert_array_equal(mask, expected_mask)
192
+
193
+ def test_make_decoder_mask_lm_packed(self):
194
+ decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]])
195
+ decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]])
196
+ mask = layers.make_decoder_mask(
197
+ decoder_target_tokens=decoder_target_tokens,
198
+ dtype=jnp.float32,
199
+ decoder_segment_ids=decoder_segment_ids)
200
+ expected_mask = jnp.array([[[[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
201
+ [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0],
202
+ [0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]])
203
+ np.testing.assert_array_equal(mask, expected_mask)
204
+
205
+ def test_make_decoder_mask_prefix_lm_unpacked(self):
206
+ decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]])
207
+ decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]])
208
+ mask = layers.make_decoder_mask(
209
+ decoder_target_tokens=decoder_target_tokens,
210
+ dtype=jnp.float32,
211
+ decoder_causal_attention=decoder_causal_attention)
212
+ expected_mask = jnp.array(
213
+ [[[[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0],
214
+ [1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]],
215
+ dtype=jnp.float32)
216
+ np.testing.assert_array_equal(mask, expected_mask)
217
+
218
+ def test_make_decoder_mask_prefix_lm_packed(self):
219
+ decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]])
220
+ decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]])
221
+ decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]])
222
+ mask = layers.make_decoder_mask(
223
+ decoder_target_tokens=decoder_target_tokens,
224
+ dtype=jnp.float32,
225
+ decoder_causal_attention=decoder_causal_attention,
226
+ decoder_segment_ids=decoder_segment_ids)
227
+ expected_mask = jnp.array([[[[1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0],
228
+ [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0],
229
+ [0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1, 0],
230
+ [0, 0, 0, 0, 0, 0, 0]]]])
231
+ np.testing.assert_array_equal(mask, expected_mask)
232
+
233
+ def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self):
234
+ decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]])
235
+ decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]])
236
+ mask = layers.make_decoder_mask(
237
+ decoder_target_tokens=decoder_target_tokens,
238
+ dtype=jnp.float32,
239
+ decoder_causal_attention=decoder_causal_attention)
240
+ expected_mask0 = jnp.array([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0],
241
+ [0, 0, 0, 0]])
242
+ expected_mask1 = jnp.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0],
243
+ [0, 0, 0, 0]])
244
+ self.assertEqual(mask.shape, (2, 1, 4, 4))
245
+ np.testing.assert_array_equal(mask[0, 0], expected_mask0)
246
+ np.testing.assert_array_equal(mask[1, 0], expected_mask1)
247
+
248
+ def test_make_decoder_mask_composite_causal_attention(self):
249
+ decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]])
250
+ decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]])
251
+ mask = layers.make_decoder_mask(
252
+ decoder_target_tokens=decoder_target_tokens,
253
+ dtype=jnp.float32,
254
+ decoder_causal_attention=decoder_causal_attention)
255
+ expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0],
256
+ [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0],
257
+ [1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0],
258
+ [0, 0, 0, 0, 0, 0, 0]])
259
+
260
+ self.assertEqual(mask.shape, (1, 1, 7, 7))
261
+ np.testing.assert_array_equal(mask[0, 0], expected_mask0)
262
+
263
+ def test_make_decoder_mask_composite_causal_attention_packed(self):
264
+ decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]])
265
+ decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]])
266
+ decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]])
267
+ mask = layers.make_decoder_mask(
268
+ decoder_target_tokens=decoder_target_tokens,
269
+ dtype=jnp.float32,
270
+ decoder_causal_attention=decoder_causal_attention,
271
+ decoder_segment_ids=decoder_segment_ids)
272
+ expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0, 0, 0],
273
+ [1, 1, 0, 0, 1, 1, 0, 0, 0],
274
+ [1, 1, 1, 0, 0, 0, 0, 0, 0],
275
+ [1, 1, 1, 1, 0, 0, 0, 0, 0],
276
+ [1, 1, 1, 1, 1, 1, 0, 0, 0],
277
+ [1, 1, 1, 1, 1, 1, 0, 0, 0],
278
+ [0, 0, 0, 0, 0, 0, 1, 1, 0],
279
+ [0, 0, 0, 0, 0, 0, 1, 1, 0],
280
+ [0, 0, 0, 0, 0, 0, 1, 1, 1]])
281
+
282
+ self.assertEqual(mask.shape, (1, 1, 9, 9))
283
+ np.testing.assert_array_equal(mask[0, 0], expected_mask0)
284
+
285
+ @parameterized.parameters({'f': 20}, {'f': 22})
286
+ def test_multihead_dot_product_attention(self, f):
287
+ # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim
288
+ b, q, h, d, k = 2, 3, 4, 5, 6
289
+
290
+ base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0)
291
+ args = base_args.init_args()
292
+
293
+ np.random.seed(0)
294
+ inputs_q = np.random.randn(b, q, f)
295
+ inputs_kv = np.random.randn(b, k, f)
296
+
297
+ # Projection: [b, q, f] -> [b, q, h, d]
298
+ # So the kernels have to be [f, h, d]
299
+ query_kernel = np.random.randn(f, h, d)
300
+ key_kernel = np.random.randn(f, h, d)
301
+ value_kernel = np.random.randn(f, h, d)
302
+ # `out` calculation: [b, q, h, d] -> [b, q, f]
303
+ # So kernel has to be [h, d, f]
304
+ out_kernel = np.random.randn(h, d, f)
305
+
306
+ params = {
307
+ 'query': {
308
+ 'kernel': query_kernel.reshape(f, -1)
309
+ },
310
+ 'key': {
311
+ 'kernel': key_kernel.reshape(f, -1)
312
+ },
313
+ 'value': {
314
+ 'kernel': value_kernel.reshape(f, -1)
315
+ },
316
+ 'out': {
317
+ 'kernel': out_kernel.reshape(-1, f)
318
+ }
319
+ }
320
+ y = layers.MultiHeadDotProductAttention(**args).apply(
321
+ {'params': freeze(params)}, inputs_q, inputs_kv)
322
+
323
+ query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel)
324
+ key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel)
325
+ value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel)
326
+ logits = np.einsum('bqhd,bkhd->bhqk', query, key)
327
+ weights = nn.softmax(logits, axis=-1)
328
+ combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value)
329
+ y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel)
330
+ np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5)
331
+
332
+ def test_multihead_dot_product_attention_caching(self):
333
+ # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim
334
+ b, h, d, k = 2, 3, 4, 5
335
+ f = h * d
336
+
337
+ base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0)
338
+ args = base_args.init_args()
339
+
340
+ cache = {
341
+ 'cached_key': np.zeros((b, h, d, k)),
342
+ 'cached_value': np.zeros((b, h, d, k)),
343
+ 'cache_index': np.array(0)
344
+ }
345
+ inputs_q = np.random.randn(b, 1, f)
346
+ inputs_kv = np.random.randn(b, 1, f)
347
+
348
+ # Mock dense general such that q, k, v projections are replaced by simple
349
+ # reshaping.
350
+ def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument
351
+ return x.reshape(b, -1, h, d)
352
+
353
+ with mock.patch.object(
354
+ layers.DenseGeneral, '__call__', new=mock_dense_general):
355
+ _, mutated = layers.MultiHeadDotProductAttention(**args).apply(
356
+ {'cache': freeze(cache)},
357
+ inputs_q,
358
+ inputs_kv,
359
+ decode=True,
360
+ mutable=['cache'])
361
+ updated_cache = mutated['cache']
362
+
363
+ # Perform the same mocked projection to generate the expected cache.
364
+ # (key|value): [b, 1, h, d]
365
+ key = mock_dense_general(None, inputs_kv)
366
+ value = mock_dense_general(None, inputs_kv)
367
+
368
+ # cached_(key|value): [b, h, d, k]
369
+ cache['cached_key'][:, :, :, 0] = key[:, 0, :, :]
370
+ cache['cached_value'][:, :, :, 0] = value[:, 0, :, :]
371
+ cache['cache_index'] = np.array(1)
372
+ for name, array in cache.items():
373
+ np.testing.assert_allclose(array, updated_cache[name])
374
+
375
+ def test_dot_product_attention(self):
376
+ # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim
377
+ b, q, h, d, k = 2, 3, 4, 5, 6
378
+ np.random.seed(0)
379
+ query = np.random.randn(b, q, h, d)
380
+ key = np.random.randn(b, k, h, d)
381
+ value = np.random.randn(b, k, h, d)
382
+ bias = np.random.randn(b, h, q, k)
383
+ attn_out = layers.dot_product_attention(query, key, value, bias=bias)
384
+ logits = np.einsum('bqhd,bkhd->bhqk', query, key)
385
+ weights = jax.nn.softmax(logits + bias, axis=-1)
386
+ expected = np.einsum('bhqk,bkhd->bqhd', weights, value)
387
+ np.testing.assert_allclose(attn_out, expected, atol=1e-6)
388
+
389
+
390
+ class EmbeddingTest(parameterized.TestCase):
391
+
392
+ def test_embedder_raises_exception_for_incorrect_input_type(self):
393
+ """Tests that inputs are integers and that an exception is raised if not."""
394
+ embed = layers.Embed(num_embeddings=10, features=5)
395
+ inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1)
396
+ variables = embed.init(jax.random.PRNGKey(0), inputs)
397
+ bad_inputs = inputs.astype(np.float32)
398
+ with self.assertRaisesRegex(
399
+ ValueError, 'Input type must be an integer or unsigned integer.'):
400
+ _ = embed.apply(variables, bad_inputs)
401
+
402
+ @parameterized.named_parameters(
403
+ {
404
+ 'testcase_name': 'with_ones',
405
+ 'init_fn': jax.nn.initializers.ones,
406
+ 'num_embeddings': 10,
407
+ 'features': 5,
408
+ 'matrix_sum': 5 * 10,
409
+ }, {
410
+ 'testcase_name': 'with_zeros',
411
+ 'init_fn': jax.nn.initializers.zeros,
412
+ 'num_embeddings': 10,
413
+ 'features': 5,
414
+ 'matrix_sum': 0,
415
+ })
416
+ def test_embedding_initializes_correctly(self, init_fn, num_embeddings,
417
+ features, matrix_sum):
418
+ """Tests if the Embed class initializes with the requested initializer."""
419
+ embed = layers.Embed(
420
+ num_embeddings=num_embeddings,
421
+ features=features,
422
+ embedding_init=init_fn)
423
+ inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1)
424
+ variables = embed.init(jax.random.PRNGKey(0), inputs)
425
+ embedding_matrix = variables['params']['embedding']
426
+ self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum)
427
+
428
+ def test_embedding_matrix_shape(self):
429
+ """Tests that the embedding matrix has the right shape."""
430
+ num_embeddings = 10
431
+ features = 5
432
+ embed = layers.Embed(num_embeddings=num_embeddings, features=features)
433
+ inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1)
434
+ variables = embed.init(jax.random.PRNGKey(0), inputs)
435
+ embedding_matrix = variables['params']['embedding']
436
+ self.assertEqual((num_embeddings, features), embedding_matrix.shape)
437
+
438
+ def test_embedding_attend(self):
439
+ """Tests that attending with ones returns sum of embedding vectors."""
440
+ features = 5
441
+ embed = layers.Embed(num_embeddings=10, features=features)
442
+ inputs = np.array([[1]], dtype=np.int64)
443
+ variables = embed.init(jax.random.PRNGKey(0), inputs)
444
+ query = np.ones(features, dtype=np.float32)
445
+ result = embed.apply(variables, query, method=embed.attend)
446
+ expected = np.sum(variables['params']['embedding'], -1)
447
+ np.testing.assert_array_almost_equal(result, expected)
448
+
449
+
450
+ class DenseTest(parameterized.TestCase):
451
+
452
+ def test_dense_general_no_bias(self):
453
+ rng = random.PRNGKey(0)
454
+ x = jnp.ones((1, 3))
455
+ model = layers.DenseGeneral(
456
+ features=4,
457
+ kernel_init=initializers.ones,
458
+ )
459
+ y, _ = model.init_with_output(rng, x)
460
+ self.assertEqual(y.shape, (1, 4))
461
+ np.testing.assert_allclose(y, np.full((1, 4), 3.))
462
+
463
+ def test_dense_general_two_features(self):
464
+ rng = random.PRNGKey(0)
465
+ x = jnp.ones((1, 3))
466
+ model = layers.DenseGeneral(
467
+ features=(2, 2),
468
+ kernel_init=initializers.ones,
469
+ )
470
+ y, _ = model.init_with_output(rng, x)
471
+ # We transform the last input dimension to two output dimensions (2, 2).
472
+ np.testing.assert_allclose(y, np.full((1, 2, 2), 3.))
473
+
474
+ def test_dense_general_two_axes(self):
475
+ rng = random.PRNGKey(0)
476
+ x = jnp.ones((1, 2, 2))
477
+ model = layers.DenseGeneral(
478
+ features=3,
479
+ axis=(-2, 2), # Note: this is the same as (1, 2).
480
+ kernel_init=initializers.ones,
481
+ )
482
+ y, _ = model.init_with_output(rng, x)
483
+ # We transform the last two input dimensions (2, 2) to one output dimension.
484
+ np.testing.assert_allclose(y, np.full((1, 3), 4.))
485
+
486
+ def test_mlp_same_out_dim(self):
487
+ module = layers.MlpBlock(
488
+ intermediate_dim=4,
489
+ activations=('relu',),
490
+ kernel_init=nn.initializers.xavier_uniform(),
491
+ dtype=jnp.float32,
492
+ )
493
+ inputs = np.array(
494
+ [
495
+ # Batch 1.
496
+ [[1, 1], [1, 1], [1, 2]],
497
+ # Batch 2.
498
+ [[2, 2], [3, 1], [2, 2]],
499
+ ],
500
+ dtype=np.float32)
501
+ params = module.init(random.PRNGKey(0), inputs, deterministic=True)
502
+ self.assertEqual(
503
+ jax.tree_map(lambda a: a.tolist(), params), {
504
+ 'params': {
505
+ 'wi': {
506
+ 'kernel': [[
507
+ -0.8675811290740967, 0.08417510986328125,
508
+ 0.022586345672607422, -0.9124102592468262
509
+ ],
510
+ [
511
+ -0.19464373588562012, 0.49809837341308594,
512
+ 0.7808468341827393, 0.9267289638519287
513
+ ]],
514
+ },
515
+ 'wo': {
516
+ 'kernel': [[0.01154780387878418, 0.1397249698638916],
517
+ [0.974980354309082, 0.5903260707855225],
518
+ [-0.05997943878173828, 0.616570234298706],
519
+ [0.2934272289276123, 0.8181164264678955]],
520
+ },
521
+ },
522
+ 'params_axes': {
523
+ 'wi': {
524
+ 'kernel_axes': AxisMetadata(names=('embed', 'mlp')),
525
+ },
526
+ 'wo': {
527
+ 'kernel_axes': AxisMetadata(names=('mlp', 'embed')),
528
+ },
529
+ },
530
+ })
531
+ result = module.apply(params, inputs, deterministic=True)
532
+ np.testing.assert_allclose(
533
+ result.tolist(),
534
+ [[[0.5237172245979309, 0.8508185744285583],
535
+ [0.5237172245979309, 0.8508185744285583],
536
+ [1.2344461679458618, 2.3844780921936035]],
537
+ [[1.0474344491958618, 1.7016371488571167],
538
+ [0.6809444427490234, 0.9663378596305847],
539
+ [1.0474344491958618, 1.7016371488571167]]],
540
+ rtol=1e-6,
541
+ )
542
+
543
+
544
+ if __name__ == '__main__':
545
+ absltest.main()
mt3/metrics.py ADDED
@@ -0,0 +1,392 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Transcription metrics."""
16
+
17
+ import collections
18
+ import copy
19
+ import functools
20
+ from typing import Any, Iterable, Mapping, Optional, Sequence
21
+
22
+ import mir_eval
23
+
24
+ from mt3 import event_codec
25
+ from mt3 import metrics_utils
26
+ from mt3 import note_sequences
27
+ from mt3 import spectrograms
28
+ from mt3 import summaries
29
+ from mt3 import vocabularies
30
+
31
+ import note_seq
32
+ import numpy as np
33
+ import seqio
34
+
35
+
36
+ def _program_aware_note_scores(
37
+ ref_ns: note_seq.NoteSequence,
38
+ est_ns: note_seq.NoteSequence,
39
+ granularity_type: str
40
+ ) -> Mapping[str, float]:
41
+ """Compute precision/recall/F1 for notes taking program into account.
42
+
43
+ For non-drum tracks, uses onsets and offsets. For drum tracks, uses onsets
44
+ only. Applies MIDI program map of specified granularity type.
45
+
46
+ Args:
47
+ ref_ns: Reference NoteSequence with ground truth labels.
48
+ est_ns: Estimated NoteSequence.
49
+ granularity_type: String key in vocabularies.PROGRAM_GRANULARITIES dict.
50
+
51
+ Returns:
52
+ A dictionary containing precision, recall, and F1 score.
53
+ """
54
+ program_map_fn = vocabularies.PROGRAM_GRANULARITIES[
55
+ granularity_type].program_map_fn
56
+
57
+ ref_ns = copy.deepcopy(ref_ns)
58
+ for note in ref_ns.notes:
59
+ if not note.is_drum:
60
+ note.program = program_map_fn(note.program)
61
+
62
+ est_ns = copy.deepcopy(est_ns)
63
+ for note in est_ns.notes:
64
+ if not note.is_drum:
65
+ note.program = program_map_fn(note.program)
66
+
67
+ program_and_is_drum_tuples = (
68
+ set((note.program, note.is_drum) for note in ref_ns.notes) |
69
+ set((note.program, note.is_drum) for note in est_ns.notes)
70
+ )
71
+
72
+ drum_precision_sum = 0.0
73
+ drum_precision_count = 0
74
+ drum_recall_sum = 0.0
75
+ drum_recall_count = 0
76
+
77
+ nondrum_precision_sum = 0.0
78
+ nondrum_precision_count = 0
79
+ nondrum_recall_sum = 0.0
80
+ nondrum_recall_count = 0
81
+
82
+ for program, is_drum in program_and_is_drum_tuples:
83
+ est_track = note_sequences.extract_track(est_ns, program, is_drum)
84
+ ref_track = note_sequences.extract_track(ref_ns, program, is_drum)
85
+
86
+ est_intervals, est_pitches, unused_est_velocities = (
87
+ note_seq.sequences_lib.sequence_to_valued_intervals(est_track))
88
+ ref_intervals, ref_pitches, unused_ref_velocities = (
89
+ note_seq.sequences_lib.sequence_to_valued_intervals(ref_track))
90
+
91
+ args = {
92
+ 'ref_intervals': ref_intervals, 'ref_pitches': ref_pitches,
93
+ 'est_intervals': est_intervals, 'est_pitches': est_pitches
94
+ }
95
+ if is_drum:
96
+ args['offset_ratio'] = None
97
+
98
+ precision, recall, unused_f_measure, unused_avg_overlap_ratio = (
99
+ mir_eval.transcription.precision_recall_f1_overlap(**args))
100
+
101
+ if is_drum:
102
+ drum_precision_sum += precision * len(est_intervals)
103
+ drum_precision_count += len(est_intervals)
104
+ drum_recall_sum += recall * len(ref_intervals)
105
+ drum_recall_count += len(ref_intervals)
106
+ else:
107
+ nondrum_precision_sum += precision * len(est_intervals)
108
+ nondrum_precision_count += len(est_intervals)
109
+ nondrum_recall_sum += recall * len(ref_intervals)
110
+ nondrum_recall_count += len(ref_intervals)
111
+
112
+ precision_sum = drum_precision_sum + nondrum_precision_sum
113
+ precision_count = drum_precision_count + nondrum_precision_count
114
+ recall_sum = drum_recall_sum + nondrum_recall_sum
115
+ recall_count = drum_recall_count + nondrum_recall_count
116
+
117
+ precision = (precision_sum / precision_count) if precision_count else 0
118
+ recall = (recall_sum / recall_count) if recall_count else 0
119
+ f_measure = mir_eval.util.f_measure(precision, recall)
120
+
121
+ drum_precision = ((drum_precision_sum / drum_precision_count)
122
+ if drum_precision_count else 0)
123
+ drum_recall = ((drum_recall_sum / drum_recall_count)
124
+ if drum_recall_count else 0)
125
+ drum_f_measure = mir_eval.util.f_measure(drum_precision, drum_recall)
126
+
127
+ nondrum_precision = ((nondrum_precision_sum / nondrum_precision_count)
128
+ if nondrum_precision_count else 0)
129
+ nondrum_recall = ((nondrum_recall_sum / nondrum_recall_count)
130
+ if nondrum_recall_count else 0)
131
+ nondrum_f_measure = mir_eval.util.f_measure(nondrum_precision, nondrum_recall)
132
+
133
+ return {
134
+ f'Onset + offset + program precision ({granularity_type})': precision,
135
+ f'Onset + offset + program recall ({granularity_type})': recall,
136
+ f'Onset + offset + program F1 ({granularity_type})': f_measure,
137
+ f'Drum onset precision ({granularity_type})': drum_precision,
138
+ f'Drum onset recall ({granularity_type})': drum_recall,
139
+ f'Drum onset F1 ({granularity_type})': drum_f_measure,
140
+ f'Nondrum onset + offset + program precision ({granularity_type})':
141
+ nondrum_precision,
142
+ f'Nondrum onset + offset + program recall ({granularity_type})':
143
+ nondrum_recall,
144
+ f'Nondrum onset + offset + program F1 ({granularity_type})':
145
+ nondrum_f_measure
146
+ }
147
+
148
+
149
+ def _note_onset_tolerance_sweep(
150
+ ref_ns: note_seq.NoteSequence, est_ns: note_seq.NoteSequence,
151
+ tolerances: Iterable[float] = (0.01, 0.02, 0.05, 0.1, 0.2, 0.5)
152
+ ) -> Mapping[str, float]:
153
+ """Compute note precision/recall/F1 across a range of tolerances."""
154
+ est_intervals, est_pitches, unused_est_velocities = (
155
+ note_seq.sequences_lib.sequence_to_valued_intervals(est_ns))
156
+ ref_intervals, ref_pitches, unused_ref_velocities = (
157
+ note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns))
158
+
159
+ scores = {}
160
+
161
+ for tol in tolerances:
162
+ precision, recall, f_measure, _ = (
163
+ mir_eval.transcription.precision_recall_f1_overlap(
164
+ ref_intervals=ref_intervals, ref_pitches=ref_pitches,
165
+ est_intervals=est_intervals, est_pitches=est_pitches,
166
+ onset_tolerance=tol, offset_min_tolerance=tol))
167
+
168
+ scores[f'Onset + offset precision ({tol})'] = precision
169
+ scores[f'Onset + offset recall ({tol})'] = recall
170
+ scores[f'Onset + offset F1 ({tol})'] = f_measure
171
+
172
+ return scores
173
+
174
+
175
+ def transcription_metrics(
176
+ targets: Sequence[Mapping[str, Any]],
177
+ predictions: Sequence[Mapping[str, Any]],
178
+ codec: event_codec.Codec,
179
+ spectrogram_config: spectrograms.SpectrogramConfig,
180
+ onsets_only: bool,
181
+ use_ties: bool,
182
+ track_specs: Optional[Sequence[note_sequences.TrackSpec]] = None,
183
+ num_summary_examples: int = 5,
184
+ frame_fps: float = 62.5,
185
+ frame_velocity_threshold: int = 30,
186
+ ) -> Mapping[str, seqio.metrics.MetricValue]:
187
+ """Compute mir_eval transcription metrics."""
188
+ if onsets_only and use_ties:
189
+ raise ValueError('Ties not compatible with onset-only transcription.')
190
+ if onsets_only:
191
+ encoding_spec = note_sequences.NoteOnsetEncodingSpec
192
+ elif not use_ties:
193
+ encoding_spec = note_sequences.NoteEncodingSpec
194
+ else:
195
+ encoding_spec = note_sequences.NoteEncodingWithTiesSpec
196
+
197
+ # The first target for each full example contains the NoteSequence; just
198
+ # organize by ID.
199
+ full_targets = {}
200
+ for target in targets:
201
+ if target['ref_ns']:
202
+ full_targets[target['unique_id']] = {'ref_ns': target['ref_ns']}
203
+
204
+ # Gather all predictions for the same ID and concatenate them in time order,
205
+ # to construct full-length predictions.
206
+ full_predictions = metrics_utils.combine_predictions_by_id(
207
+ predictions=predictions,
208
+ combine_predictions_fn=functools.partial(
209
+ metrics_utils.event_predictions_to_ns,
210
+ codec=codec,
211
+ encoding_spec=encoding_spec))
212
+
213
+ assert sorted(full_targets.keys()) == sorted(full_predictions.keys())
214
+
215
+ full_target_prediction_pairs = [
216
+ (full_targets[id], full_predictions[id])
217
+ for id in sorted(full_targets.keys())
218
+ ]
219
+
220
+ scores = collections.defaultdict(list)
221
+ all_track_pianorolls = collections.defaultdict(list)
222
+ for target, prediction in full_target_prediction_pairs:
223
+ scores['Invalid events'].append(prediction['est_invalid_events'])
224
+ scores['Dropped events'].append(prediction['est_dropped_events'])
225
+
226
+ def remove_drums(ns):
227
+ ns_drumless = note_seq.NoteSequence()
228
+ ns_drumless.CopyFrom(ns)
229
+ del ns_drumless.notes[:]
230
+ ns_drumless.notes.extend([note for note in ns.notes if not note.is_drum])
231
+ return ns_drumless
232
+
233
+ est_ns_drumless = remove_drums(prediction['est_ns'])
234
+ ref_ns_drumless = remove_drums(target['ref_ns'])
235
+
236
+ # Whether or not there are separate tracks, compute metrics for the full
237
+ # NoteSequence minus drums.
238
+ est_tracks = [est_ns_drumless]
239
+ ref_tracks = [ref_ns_drumless]
240
+ use_track_offsets = [not onsets_only]
241
+ use_track_velocities = [not onsets_only]
242
+ track_instrument_names = ['']
243
+
244
+ if track_specs is not None:
245
+ # Compute transcription metrics separately for each track.
246
+ for spec in track_specs:
247
+ est_tracks.append(note_sequences.extract_track(
248
+ prediction['est_ns'], spec.program, spec.is_drum))
249
+ ref_tracks.append(note_sequences.extract_track(
250
+ target['ref_ns'], spec.program, spec.is_drum))
251
+ use_track_offsets.append(not onsets_only and not spec.is_drum)
252
+ use_track_velocities.append(not onsets_only)
253
+ track_instrument_names.append(spec.name)
254
+
255
+ for est_ns, ref_ns, use_offsets, use_velocities, instrument_name in zip(
256
+ est_tracks, ref_tracks, use_track_offsets, use_track_velocities,
257
+ track_instrument_names):
258
+ track_scores = {}
259
+
260
+ est_intervals, est_pitches, est_velocities = (
261
+ note_seq.sequences_lib.sequence_to_valued_intervals(est_ns))
262
+
263
+ ref_intervals, ref_pitches, ref_velocities = (
264
+ note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns))
265
+
266
+ # Precision / recall / F1 using onsets (and pitches) only.
267
+ precision, recall, f_measure, avg_overlap_ratio = (
268
+ mir_eval.transcription.precision_recall_f1_overlap(
269
+ ref_intervals=ref_intervals,
270
+ ref_pitches=ref_pitches,
271
+ est_intervals=est_intervals,
272
+ est_pitches=est_pitches,
273
+ offset_ratio=None))
274
+ del avg_overlap_ratio
275
+ track_scores['Onset precision'] = precision
276
+ track_scores['Onset recall'] = recall
277
+ track_scores['Onset F1'] = f_measure
278
+
279
+ if use_offsets:
280
+ # Precision / recall / F1 using onsets and offsets.
281
+ precision, recall, f_measure, avg_overlap_ratio = (
282
+ mir_eval.transcription.precision_recall_f1_overlap(
283
+ ref_intervals=ref_intervals,
284
+ ref_pitches=ref_pitches,
285
+ est_intervals=est_intervals,
286
+ est_pitches=est_pitches))
287
+ del avg_overlap_ratio
288
+ track_scores['Onset + offset precision'] = precision
289
+ track_scores['Onset + offset recall'] = recall
290
+ track_scores['Onset + offset F1'] = f_measure
291
+
292
+ if use_velocities:
293
+ # Precision / recall / F1 using onsets and velocities (no offsets).
294
+ precision, recall, f_measure, avg_overlap_ratio = (
295
+ mir_eval.transcription_velocity.precision_recall_f1_overlap(
296
+ ref_intervals=ref_intervals,
297
+ ref_pitches=ref_pitches,
298
+ ref_velocities=ref_velocities,
299
+ est_intervals=est_intervals,
300
+ est_pitches=est_pitches,
301
+ est_velocities=est_velocities,
302
+ offset_ratio=None))
303
+ track_scores['Onset + velocity precision'] = precision
304
+ track_scores['Onset + velocity recall'] = recall
305
+ track_scores['Onset + velocity F1'] = f_measure
306
+
307
+ if use_offsets and use_velocities:
308
+ # Precision / recall / F1 using onsets, offsets, and velocities.
309
+ precision, recall, f_measure, avg_overlap_ratio = (
310
+ mir_eval.transcription_velocity.precision_recall_f1_overlap(
311
+ ref_intervals=ref_intervals,
312
+ ref_pitches=ref_pitches,
313
+ ref_velocities=ref_velocities,
314
+ est_intervals=est_intervals,
315
+ est_pitches=est_pitches,
316
+ est_velocities=est_velocities))
317
+ track_scores['Onset + offset + velocity precision'] = precision
318
+ track_scores['Onset + offset + velocity recall'] = recall
319
+ track_scores['Onset + offset + velocity F1'] = f_measure
320
+
321
+ # Calculate framewise metrics.
322
+ is_drum = all([n.is_drum for n in ref_ns.notes])
323
+ ref_pr = metrics_utils.get_prettymidi_pianoroll(
324
+ ref_ns, frame_fps, is_drum=is_drum)
325
+ est_pr = metrics_utils.get_prettymidi_pianoroll(
326
+ est_ns, frame_fps, is_drum=is_drum)
327
+ all_track_pianorolls[instrument_name].append((est_pr, ref_pr))
328
+ frame_precision, frame_recall, frame_f1 = metrics_utils.frame_metrics(
329
+ ref_pr, est_pr, velocity_threshold=frame_velocity_threshold)
330
+ track_scores['Frame Precision'] = frame_precision
331
+ track_scores['Frame Recall'] = frame_recall
332
+ track_scores['Frame F1'] = frame_f1
333
+
334
+ for metric_name, metric_value in track_scores.items():
335
+ if instrument_name:
336
+ scores[f'{instrument_name}/{metric_name}'].append(metric_value)
337
+ else:
338
+ scores[metric_name].append(metric_value)
339
+
340
+ # Add program-aware note metrics for all program granularities.
341
+ # Note that this interacts with the training program granularity; in
342
+ # particular granularities *higher* than the training granularity are likely
343
+ # to have poor metrics.
344
+ for granularity_type in vocabularies.PROGRAM_GRANULARITIES:
345
+ for name, score in _program_aware_note_scores(
346
+ target['ref_ns'], prediction['est_ns'],
347
+ granularity_type=granularity_type).items():
348
+ scores[name].append(score)
349
+
350
+ # Add (non-program-aware) note metrics across a range of onset/offset
351
+ # tolerances.
352
+ for name, score in _note_onset_tolerance_sweep(
353
+ ref_ns=ref_ns_drumless, est_ns=est_ns_drumless).items():
354
+ scores[name].append(score)
355
+
356
+ mean_scores = {k: np.mean(v) for k, v in scores.items()}
357
+
358
+ score_histograms = {'%s (hist)' % k: seqio.metrics.Histogram(np.array(v))
359
+ for k, v in scores.items()}
360
+
361
+ # Pick several examples to summarize.
362
+ targets_to_summarize, predictions_to_summarize = zip(
363
+ *full_target_prediction_pairs[:num_summary_examples])
364
+
365
+ # Compute audio summaries.
366
+ audio_summaries = summaries.audio_summaries(
367
+ targets=targets_to_summarize,
368
+ predictions=predictions_to_summarize,
369
+ spectrogram_config=spectrogram_config)
370
+
371
+ # Compute transcription summaries.
372
+ transcription_summaries = summaries.transcription_summaries(
373
+ targets=targets_to_summarize,
374
+ predictions=predictions_to_summarize,
375
+ spectrogram_config=spectrogram_config,
376
+ ns_feature_suffix='ns',
377
+ track_specs=track_specs)
378
+
379
+ pianorolls_to_summarize = {
380
+ k: v[:num_summary_examples] for k, v in all_track_pianorolls.items()
381
+ }
382
+
383
+ prettymidi_pianoroll_summaries = summaries.prettymidi_pianoroll(
384
+ pianorolls_to_summarize, fps=frame_fps)
385
+
386
+ return {
387
+ **mean_scores,
388
+ **score_histograms,
389
+ **audio_summaries,
390
+ **transcription_summaries,
391
+ **prettymidi_pianoroll_summaries,
392
+ }
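Note on the aggregation in `_program_aware_note_scores` above: each (program, is_drum) track contributes its precision weighted by the number of estimated intervals and its recall weighted by the number of reference intervals before the weighted averages are combined into F1. Below is a minimal standalone sketch of that weighting; the per-track scores and interval counts are invented for illustration and do not come from the source.

# Illustrative sketch of the per-track weighting used in
# _program_aware_note_scores; scores and interval counts are invented.
import mir_eval

tracks = [
    # (precision, num_est_intervals, recall, num_ref_intervals)
    (0.9, 10, 0.8, 12),  # e.g. a non-drum track
    (0.5, 4, 0.6, 5),    # e.g. a drum track
]

precision_sum = sum(p * n_est for p, n_est, _, _ in tracks)
precision_count = sum(n_est for _, n_est, _, _ in tracks)
recall_sum = sum(r * n_ref for _, _, r, n_ref in tracks)
recall_count = sum(n_ref for _, _, _, n_ref in tracks)

precision = precision_sum / precision_count if precision_count else 0
recall = recall_sum / recall_count if recall_count else 0
print(precision, recall, mir_eval.util.f_measure(precision, recall))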
mt3/metrics_utils.py ADDED
@@ -0,0 +1,196 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for transcription metrics."""
16
+
17
+ import collections
18
+ import functools
19
+
20
+ from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, TypeVar
21
+
22
+ from mt3 import event_codec
23
+ from mt3 import note_sequences
24
+ from mt3 import run_length_encoding
25
+
26
+ import note_seq
27
+ import numpy as np
28
+ import pretty_midi
29
+ import sklearn
30
+
31
+ S = TypeVar('S')
32
+ T = TypeVar('T')
33
+
34
+ CombineExamplesFunctionType = Callable[[Sequence[Mapping[str, Any]]],
35
+ Mapping[str, Any]]
36
+
37
+
38
+ def _group_predictions_by_id(
39
+ predictions: Sequence[Mapping[str, T]]
40
+ ) -> Mapping[str, Sequence[T]]:
41
+ predictions_by_id = collections.defaultdict(list)
42
+ for pred in predictions:
43
+ predictions_by_id[pred['unique_id']].append(pred)
44
+ return predictions_by_id
45
+
46
+
47
+ def combine_predictions_by_id(
48
+ predictions: Sequence[Mapping[str, Any]],
49
+ combine_predictions_fn: CombineExamplesFunctionType
50
+ ) -> Mapping[str, Mapping[str, Any]]:
51
+ """Concatenate predicted examples, grouping by ID and sorting by time."""
52
+ predictions_by_id = _group_predictions_by_id(predictions)
53
+ return {
54
+ id: combine_predictions_fn(preds)
55
+ for id, preds in predictions_by_id.items()
56
+ }
57
+
58
+
59
+ def decode_and_combine_predictions(
60
+ predictions: Sequence[Mapping[str, Any]],
61
+ init_state_fn: Callable[[], S],
62
+ begin_segment_fn: Callable[[S], None],
63
+ decode_tokens_fn: Callable[[S, Sequence[int], int, Optional[int]],
64
+ Tuple[int, int]],
65
+ flush_state_fn: Callable[[S], T]
66
+ ) -> Tuple[T, int, int]:
67
+ """Decode and combine a sequence of predictions to a full result.
68
+
69
+ For time-based events, this usually means concatenation.
70
+
71
+ Args:
72
+ predictions: List of predictions, each of which is a dictionary containing
73
+ estimated tokens ('est_tokens') and start time ('start_time') fields.
74
+ init_state_fn: Function that takes no arguments and returns an initial
75
+ decoding state.
76
+ begin_segment_fn: Function that updates the decoding state at the beginning
77
+ of a segment.
78
+ decode_tokens_fn: Function that takes a decoding state, estimated tokens
79
+ (for a single segment), start time, and max time, and processes the
80
+ tokens, updating the decoding state in place. Also returns the number of
81
+ invalid and dropped events for the segment.
82
+ flush_state_fn: Function that flushes the final decoding state into the
83
+ result.
84
+
85
+ Returns:
86
+ result: The full combined decoding.
87
+ total_invalid_events: Total number of invalid event tokens across all
88
+ predictions.
89
+ total_dropped_events: Total number of dropped event tokens across all
90
+ predictions.
91
+ """
92
+ sorted_predictions = sorted(predictions, key=lambda pred: pred['start_time'])
93
+
94
+ state = init_state_fn()
95
+ total_invalid_events = 0
96
+ total_dropped_events = 0
97
+
98
+ for pred_idx, pred in enumerate(sorted_predictions):
99
+ begin_segment_fn(state)
100
+
101
+ # Depending on the audio token hop length, each symbolic token could be
102
+ # associated with multiple audio frames. Since we split up the audio frames
103
+ # into segments for prediction, this could lead to overlap. To prevent
104
+ # overlap issues, ensure that the current segment does not make any
105
+ # predictions for the time period covered by the subsequent segment.
106
+ max_decode_time = None
107
+ if pred_idx < len(sorted_predictions) - 1:
108
+ max_decode_time = sorted_predictions[pred_idx + 1]['start_time']
109
+
110
+ invalid_events, dropped_events = decode_tokens_fn(
111
+ state, pred['est_tokens'], pred['start_time'], max_decode_time)
112
+
113
+ total_invalid_events += invalid_events
114
+ total_dropped_events += dropped_events
115
+
116
+ return flush_state_fn(state), total_invalid_events, total_dropped_events
117
+
118
+
119
+ def event_predictions_to_ns(
120
+ predictions: Sequence[Mapping[str, Any]], codec: event_codec.Codec,
121
+ encoding_spec: note_sequences.NoteEncodingSpecType
122
+ ) -> Mapping[str, Any]:
123
+ """Convert a sequence of predictions to a combined NoteSequence."""
124
+ ns, total_invalid_events, total_dropped_events = decode_and_combine_predictions(
125
+ predictions=predictions,
126
+ init_state_fn=encoding_spec.init_decoding_state_fn,
127
+ begin_segment_fn=encoding_spec.begin_decoding_segment_fn,
128
+ decode_tokens_fn=functools.partial(
129
+ run_length_encoding.decode_events,
130
+ codec=codec,
131
+ decode_event_fn=encoding_spec.decode_event_fn),
132
+ flush_state_fn=encoding_spec.flush_decoding_state_fn)
133
+
134
+ # Also concatenate raw inputs from all predictions.
135
+ sorted_predictions = sorted(predictions, key=lambda pred: pred['start_time'])
136
+ raw_inputs = np.concatenate(
137
+ [pred['raw_inputs'] for pred in sorted_predictions], axis=0)
138
+ start_times = [pred['start_time'] for pred in sorted_predictions]
139
+
140
+ return {
141
+ 'raw_inputs': raw_inputs,
142
+ 'start_times': start_times,
143
+ 'est_ns': ns,
144
+ 'est_invalid_events': total_invalid_events,
145
+ 'est_dropped_events': total_dropped_events,
146
+ }
147
+
148
+
149
+ def get_prettymidi_pianoroll(ns: note_seq.NoteSequence, fps: float,
150
+ is_drum: bool):
151
+ """Convert NoteSequence to pianoroll through pretty_midi."""
152
+ for note in ns.notes:
153
+ if is_drum or note.end_time - note.start_time < 0.05:
154
+ # Give all drum notes a fixed length, and all others a min length
155
+ note.end_time = note.start_time + 0.05
156
+
157
+ pm = note_seq.note_sequence_to_pretty_midi(ns)
158
+ end_time = pm.get_end_time()
159
+ cc = [
160
+ # all sound off
161
+ pretty_midi.ControlChange(number=120, value=0, time=end_time),
162
+ # all notes off
163
+ pretty_midi.ControlChange(number=123, value=0, time=end_time)
164
+ ]
165
+ pm.instruments[0].control_changes = cc
166
+ if is_drum:
167
+ # If inst.is_drum is set, pretty_midi will return an all zero pianoroll.
168
+ for inst in pm.instruments:
169
+ inst.is_drum = False
170
+ pianoroll = pm.get_piano_roll(fs=fps)
171
+ return pianoroll
172
+
173
+
174
+ def frame_metrics(ref_pianoroll: np.ndarray,
175
+ est_pianoroll: np.ndarray,
176
+ velocity_threshold: int) -> Tuple[float, float, float]:
177
+ """Frame Precision, Recall, and F1."""
178
+ # Pad to same length
179
+ if ref_pianoroll.shape[1] > est_pianoroll.shape[1]:
180
+ diff = ref_pianoroll.shape[1] - est_pianoroll.shape[1]
181
+ est_pianoroll = np.pad(est_pianoroll, [(0, 0), (0, diff)], mode='constant')
182
+ elif est_pianoroll.shape[1] > ref_pianoroll.shape[1]:
183
+ diff = est_pianoroll.shape[1] - ref_pianoroll.shape[1]
184
+ ref_pianoroll = np.pad(ref_pianoroll, [(0, 0), (0, diff)], mode='constant')
185
+
186
+ # For ref, remove any notes that are too quiet (consistent with Cerberus).
187
+ ref_frames_bool = ref_pianoroll > velocity_threshold
188
+ # For est, keep all predicted notes.
189
+ est_frames_bool = est_pianoroll > 0
190
+
191
+ precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support(
192
+ ref_frames_bool.flatten(),
193
+ est_frames_bool.flatten(),
194
+ labels=[True, False])
195
+
196
+ return precision[0], recall[0], f1[0]
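To make the grouping contract of `combine_predictions_by_id` above concrete, here is a minimal sketch with a trivial combine function. The prediction dicts and the `concat_tokens` helper are invented for illustration, and the sketch assumes the `mt3` package is importable.

# Minimal sketch of combine_predictions_by_id; the predictions and the
# concat_tokens helper below are invented for illustration.
from mt3 import metrics_utils

predictions = [
    {'unique_id': 'a', 'start_time': 0.4, 'est_tokens': [3, 4]},
    {'unique_id': 'a', 'start_time': 0.0, 'est_tokens': [1, 2]},
    {'unique_id': 'b', 'start_time': 0.0, 'est_tokens': [5]},
]

def concat_tokens(preds):
  # Sort segments by start time and concatenate their token lists.
  preds = sorted(preds, key=lambda p: p['start_time'])
  return {'est_tokens': [t for p in preds for t in p['est_tokens']]}

combined = metrics_utils.combine_predictions_by_id(
    predictions=predictions, combine_predictions_fn=concat_tokens)
# combined == {'a': {'est_tokens': [1, 2, 3, 4]}, 'b': {'est_tokens': [5]}}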
mt3/metrics_utils_test.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for metrics_utils."""
16
+
17
+ from mt3 import event_codec
18
+ from mt3 import metrics_utils
19
+ from mt3 import note_sequences
20
+
21
+ import note_seq
22
+ import numpy as np
23
+ import tensorflow as tf
24
+
25
+
26
+ class MetricsUtilsTest(tf.test.TestCase):
27
+
28
+ def test_event_predictions_to_ns(self):
29
+ predictions = [
30
+ {
31
+ 'raw_inputs': [0, 0],
32
+ 'start_time': 0.0,
33
+ 'est_tokens': [20, 160],
34
+ },
35
+ {
36
+ 'raw_inputs': [1, 1],
37
+ 'start_time': 0.4,
38
+ # These last 2 events should be dropped.
39
+ 'est_tokens': [20, 161, 50, 162],
40
+ },
41
+ {
42
+ 'raw_inputs': [2, 2],
43
+ 'start_time': 0.8,
44
+ 'est_tokens': [163, 20, 164]
45
+ },
46
+ ]
47
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
48
+ expected_ns.notes.add(
49
+ pitch=59,
50
+ velocity=100,
51
+ start_time=0.20,
52
+ end_time=0.21)
53
+ expected_ns.notes.add(
54
+ pitch=60,
55
+ velocity=100,
56
+ start_time=0.60,
57
+ end_time=0.61)
58
+ expected_ns.notes.add(
59
+ pitch=62,
60
+ velocity=100,
61
+ start_time=0.80,
62
+ end_time=0.81)
63
+ expected_ns.notes.add(
64
+ pitch=63,
65
+ velocity=100,
66
+ start_time=1.00,
67
+ end_time=1.01)
68
+ expected_ns.total_time = 1.01
69
+
70
+ codec = event_codec.Codec(
71
+ max_shift_steps=100,
72
+ steps_per_second=100,
73
+ event_ranges=[
74
+ event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
75
+ note_seq.MAX_MIDI_PITCH)])
76
+ res = metrics_utils.event_predictions_to_ns(
77
+ predictions, codec=codec,
78
+ encoding_spec=note_sequences.NoteOnsetEncodingSpec)
79
+ self.assertProtoEquals(expected_ns, res['est_ns'])
80
+ self.assertEqual(0, res['est_invalid_events'])
81
+ self.assertEqual(2, res['est_dropped_events'])
82
+ np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs'])
83
+
84
+ def test_event_predictions_to_ns_with_offsets(self):
85
+ predictions = [
86
+ {
87
+ 'raw_inputs': [0, 0],
88
+ 'start_time': 0.0,
89
+ 'est_tokens': [20, 356, 160],
90
+ },
91
+ {
92
+ 'raw_inputs': [1, 1],
93
+ 'start_time': 0.4,
94
+ 'est_tokens': [20, 292, 161],
95
+ },
96
+ {
97
+ 'raw_inputs': [2, 2],
98
+ 'start_time': 0.8,
99
+ 'est_tokens': [20, 229, 160, 161]
100
+ },
101
+ ]
102
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
103
+ expected_ns.notes.add(
104
+ pitch=59,
105
+ velocity=127,
106
+ start_time=0.20,
107
+ end_time=1.00)
108
+ expected_ns.notes.add(
109
+ pitch=60,
110
+ velocity=63,
111
+ start_time=0.60,
112
+ end_time=1.00)
113
+ expected_ns.total_time = 1.00
114
+
115
+ codec = event_codec.Codec(
116
+ max_shift_steps=100,
117
+ steps_per_second=100,
118
+ event_ranges=[
119
+ event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
120
+ note_seq.MAX_MIDI_PITCH),
121
+ event_codec.EventRange('velocity', 0, 127)
122
+ ])
123
+ res = metrics_utils.event_predictions_to_ns(
124
+ predictions, codec=codec, encoding_spec=note_sequences.NoteEncodingSpec)
125
+ self.assertProtoEquals(expected_ns, res['est_ns'])
126
+ self.assertEqual(0, res['est_invalid_events'])
127
+ self.assertEqual(0, res['est_dropped_events'])
128
+ np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs'])
129
+
130
+ def test_event_predictions_to_ns_multitrack(self):
131
+ predictions = [
132
+ {
133
+ 'raw_inputs': [0, 0],
134
+ 'start_time': 0.0,
135
+ 'est_tokens': [20, 517, 356, 160],
136
+ },
137
+ {
138
+ 'raw_inputs': [1, 1],
139
+ 'start_time': 0.4,
140
+ 'est_tokens': [20, 356, 399],
141
+ },
142
+ {
143
+ 'raw_inputs': [2, 2],
144
+ 'start_time': 0.8,
145
+ 'est_tokens': [20, 517, 229, 160]
146
+ },
147
+ ]
148
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
149
+ expected_ns.notes.add(
150
+ pitch=42,
151
+ velocity=127,
152
+ start_time=0.60,
153
+ end_time=0.61,
154
+ is_drum=True,
155
+ instrument=9)
156
+ expected_ns.notes.add(
157
+ pitch=59,
158
+ velocity=127,
159
+ start_time=0.20,
160
+ end_time=1.00,
161
+ program=32)
162
+ expected_ns.total_time = 1.00
163
+
164
+ codec = event_codec.Codec(
165
+ max_shift_steps=100,
166
+ steps_per_second=100,
167
+ event_ranges=[
168
+ event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
169
+ note_seq.MAX_MIDI_PITCH),
170
+ event_codec.EventRange('velocity', 0, 127),
171
+ event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
172
+ note_seq.MAX_MIDI_PITCH),
173
+ event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
174
+ note_seq.MAX_MIDI_PROGRAM)
175
+ ])
176
+ res = metrics_utils.event_predictions_to_ns(
177
+ predictions, codec=codec, encoding_spec=note_sequences.NoteEncodingSpec)
178
+ self.assertProtoEquals(expected_ns, res['est_ns'])
179
+ self.assertEqual(0, res['est_invalid_events'])
180
+ self.assertEqual(0, res['est_dropped_events'])
181
+ np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs'])
182
+
183
+ def test_event_predictions_to_ns_multitrack_ties(self):
184
+ predictions = [
185
+ {
186
+ 'raw_inputs': [0, 0],
187
+ 'start_time': 0.0,
188
+ 'est_tokens': [613, # no tied notes
189
+ 20, 517, 356, 160],
190
+ },
191
+ {
192
+ 'raw_inputs': [1, 1],
193
+ 'start_time': 0.4,
194
+ 'est_tokens': [517, 160, 613, # tied note
195
+ 20, 356, 399],
196
+ },
197
+ {
198
+ 'raw_inputs': [2, 2],
199
+ 'start_time': 0.8,
200
+ 'est_tokens': [613] # no tied notes, causing active note to end
201
+ },
202
+ ]
203
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
204
+ expected_ns.notes.add(
205
+ pitch=42,
206
+ velocity=127,
207
+ start_time=0.60,
208
+ end_time=0.61,
209
+ is_drum=True,
210
+ instrument=9)
211
+ expected_ns.notes.add(
212
+ pitch=59,
213
+ velocity=127,
214
+ start_time=0.20,
215
+ end_time=0.80,
216
+ program=32)
217
+ expected_ns.total_time = 0.80
218
+
219
+ codec = event_codec.Codec(
220
+ max_shift_steps=100,
221
+ steps_per_second=100,
222
+ event_ranges=[
223
+ event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
224
+ note_seq.MAX_MIDI_PITCH),
225
+ event_codec.EventRange('velocity', 0, 127),
226
+ event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
227
+ note_seq.MAX_MIDI_PITCH),
228
+ event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
229
+ note_seq.MAX_MIDI_PROGRAM),
230
+ event_codec.EventRange('tie', 0, 0)
231
+ ])
232
+ res = metrics_utils.event_predictions_to_ns(
233
+ predictions, codec=codec,
234
+ encoding_spec=note_sequences.NoteEncodingWithTiesSpec)
235
+ self.assertProtoEquals(expected_ns, res['est_ns'])
236
+ self.assertEqual(0, res['est_invalid_events'])
237
+ self.assertEqual(0, res['est_dropped_events'])
238
+ np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs'])
239
+
240
+ def test_frame_metrics(self):
241
+ ref = np.zeros(shape=(128, 5))
242
+ est = np.zeros(shape=(128, 5))
243
+
244
+ # one overlapping note, two false positives, two false negatives
245
+ ref[10, 0] = 127
246
+ ref[10, 1] = 127
247
+ ref[10, 2] = 127
248
+
249
+ est[10, 2] = 127
250
+ est[10, 3] = 127
251
+ est[10, 4] = 127
252
+
253
+ prec, rec, _ = metrics_utils.frame_metrics(ref, est, velocity_threshold=1)
254
+ np.testing.assert_approx_equal(prec, 1/3)
255
+ np.testing.assert_approx_equal(rec, 1/3)
256
+
257
+
258
+ if __name__ == '__main__':
259
+ tf.test.main()
mt3/mixing.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Functions for mixing (in the audio sense) multiple transcription examples."""
16
+
17
+ from typing import Callable, Optional, Sequence
18
+
19
+ import gin
20
+
21
+ from mt3 import event_codec
22
+ from mt3 import run_length_encoding
23
+
24
+ import numpy as np
25
+ import seqio
26
+ import tensorflow as tf
27
+
28
+
29
+ @gin.configurable
30
+ def mix_transcription_examples(
31
+ ds: tf.data.Dataset,
32
+ sequence_length: seqio.preprocessors.SequenceLengthType,
33
+ output_features: seqio.preprocessors.OutputFeaturesType,
34
+ codec: event_codec.Codec,
35
+ inputs_feature_key: str = 'inputs',
36
+ targets_feature_keys: Sequence[str] = ('targets',),
37
+ max_examples_per_mix: Optional[int] = None,
38
+ shuffle_buffer_size: int = seqio.SHUFFLE_BUFFER_SIZE
39
+ ) -> Callable[..., tf.data.Dataset]:
40
+ """Preprocessor that mixes together "batches" of transcription examples.
41
+
42
+ Args:
43
+ ds: Dataset of individual transcription examples, each of which should
44
+ have an 'inputs' field containing 1D audio samples (currently only
45
+ audio encoders that use raw samples as an intermediate representation
46
+ are supported), and a 'targets' field containing run-length encoded
47
+ note events.
48
+ sequence_length: Dictionary mapping feature key to length.
49
+ output_features: Dictionary mapping feature key to spec.
50
+ codec: An event_codec.Codec used to interpret the target events.
51
+ inputs_feature_key: Feature key for inputs which will be mixed as audio.
52
+ targets_feature_keys: List of feature keys for targets, each of which will
53
+ be merged (separately) as run-length encoded note events.
54
+ max_examples_per_mix: Maximum number of individual examples to mix together.
55
+ shuffle_buffer_size: Size of shuffle buffer to use for shuffle prior to
56
+ mixing.
57
+
58
+ Returns:
59
+ Dataset containing mixed examples.
60
+ """
61
+ if max_examples_per_mix is None:
62
+ return ds
63
+
64
+ # TODO(iansimon): is there a way to use seqio's seed?
65
+ ds = tf.data.Dataset.sample_from_datasets([
66
+ ds.shuffle(
67
+ buffer_size=shuffle_buffer_size // max_examples_per_mix
68
+ ).padded_batch(batch_size=i) for i in range(1, max_examples_per_mix + 1)
69
+ ])
70
+
71
+ def mix_inputs(ex):
72
+ samples = tf.reduce_sum(ex[inputs_feature_key], axis=0)
73
+ norm = tf.linalg.norm(samples, ord=np.inf)
74
+ ex[inputs_feature_key] = tf.math.divide_no_nan(samples, norm)
75
+ return ex
76
+ ds = ds.map(mix_inputs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
77
+
78
+ max_tokens = sequence_length['targets']
79
+ if output_features['targets'].add_eos:
80
+ # Leave room to insert an EOS token.
81
+ max_tokens -= 1
82
+
83
+ def mix_targets(ex):
84
+ for k in targets_feature_keys:
85
+ ex[k] = run_length_encoding.merge_run_length_encoded_targets(
86
+ targets=ex[k],
87
+ codec=codec)
88
+ return ex
89
+ ds = ds.map(mix_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)
90
+
91
+ return ds
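The `mix_inputs` step above simply sums the padded sample arrays and renormalizes by the peak absolute value (infinity norm). A standalone sketch of that operation on dummy waveforms (the sample values are invented):

# Standalone sketch of the mix_inputs normalization; the sample values are
# dummy data chosen only to show the sum-then-rescale behavior.
import numpy as np
import tensorflow as tf

batched_samples = tf.constant([[0.5, -0.25, 0.0, 0.0],
                               [0.25, 0.25, -0.5, 0.0]], dtype=tf.float32)

mixed = tf.reduce_sum(batched_samples, axis=0)   # sum the waveforms
norm = tf.linalg.norm(mixed, ord=np.inf)         # peak absolute value
mixed = tf.math.divide_no_nan(mixed, norm)       # rescale into [-1, 1]
print(mixed.numpy())  # -> [ 1.  0. -0.6667  0.]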
mt3/models.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Feature converter and model for continuous inputs."""
16
+
17
+ from typing import Mapping
18
+ import seqio
19
+ from t5x import decoding
20
+ from t5x import models
21
+ import tensorflow as tf
22
+
23
+
24
+ class ContinuousInputsEncDecFeatureConverter(seqio.FeatureConverter):
25
+ """Feature converter for an encoder-decoder with continuous inputs."""
26
+
27
+ TASK_FEATURES = {
28
+ "inputs": seqio.FeatureConverter.FeatureSpec(dtype=tf.float32, rank=2),
29
+ "targets": seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
30
+ }
31
+ MODEL_FEATURES = {
32
+ "encoder_input_tokens":
33
+ seqio.FeatureConverter.FeatureSpec(dtype=tf.float32, rank=2),
34
+ "decoder_target_tokens":
35
+ seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
36
+ "decoder_input_tokens":
37
+ seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
38
+ "decoder_loss_weights":
39
+ seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
40
+ }
41
+ PACKING_FEATURE_DTYPES = {
42
+ "encoder_segment_ids": tf.int32,
43
+ "decoder_segment_ids": tf.int32,
44
+ "encoder_positions": tf.int32,
45
+ "decoder_positions": tf.int32
46
+ }
47
+
48
+ def _convert_features(
49
+ self, ds: tf.data.Dataset,
50
+ task_feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
51
+ """Convert the dataset to be fed to the encoder-decoder model.
52
+
53
+ The conversion process involves two steps:
54
+
55
+ 1. Each feature in the `task_feature_lengths` is trimmed/padded and
56
+ optionally packed depending on the value of self.pack.
57
+ 2. "inputs" fields are mapped to the encoder input and "targets" are mapped
58
+ to decoder input (after being shifted) and target.
59
+
60
+ All the keys in the `task_feature_lengths` should be present in the input
61
+ dataset, which may contain some extra features that are not in the
62
+ `task_feature_lengths`. They will not be included in the output dataset.
63
+ One common scenario is the "inputs_pretokenized" and "targets_pretokenized"
64
+ fields.
65
+
66
+ Args:
67
+ ds: an input tf.data.Dataset to be converted.
68
+ task_feature_lengths: a mapping from feature to its length.
69
+
70
+ Returns:
71
+ ds: the converted dataset.
72
+ """
73
+
74
+ def convert_example(
75
+ features: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
76
+ # targets_segment_id is present only for a packed dataset.
77
+ decoder_input_tokens = seqio.autoregressive_inputs(
78
+ features["targets"],
79
+ sequence_id=features.get("targets_segment_ids", None))
80
+
81
+ d = {"encoder_input_tokens": features["inputs"],
82
+ "decoder_target_tokens": features["targets"],
83
+ "decoder_input_tokens": decoder_input_tokens,
84
+ # Loss is computed for all but the padding positions.
85
+ "decoder_loss_weights":
86
+ seqio.non_padding_position(features["targets"])}
87
+
88
+ if self.pack:
89
+ d["encoder_segment_ids"] = features["inputs_segment_ids"]
90
+ d["decoder_segment_ids"] = features["targets_segment_ids"]
91
+ d["encoder_positions"] = features["inputs_positions"]
92
+ d["decoder_positions"] = features["targets_positions"]
93
+
94
+ return d
95
+
96
+ ds = self._pack_or_pad(ds, task_feature_lengths)
97
+ return ds.map(
98
+ convert_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
99
+
100
+ def get_model_feature_lengths(
101
+ self, task_feature_lengths: Mapping[str, int]) -> Mapping[str, int]:
102
+ """Define the length relationship between input and output features."""
103
+ encoder_length = task_feature_lengths["inputs"]
104
+ decoder_length = task_feature_lengths["targets"]
105
+
106
+ model_feature_lengths = {
107
+ "encoder_input_tokens": encoder_length,
108
+ "decoder_target_tokens": decoder_length,
109
+ "decoder_input_tokens": decoder_length,
110
+ "decoder_loss_weights": decoder_length
111
+ }
112
+ if self.pack:
113
+ model_feature_lengths["encoder_segment_ids"] = encoder_length
114
+ model_feature_lengths["decoder_segment_ids"] = decoder_length
115
+ model_feature_lengths["encoder_positions"] = encoder_length
116
+ model_feature_lengths["decoder_positions"] = decoder_length
117
+
118
+ return model_feature_lengths
119
+
120
+
121
+ class ContinuousInputsEncoderDecoderModel(models.EncoderDecoderModel):
122
+ """Encoder-decoder model with continuous inputs."""
123
+
124
+ FEATURE_CONVERTER_CLS = ContinuousInputsEncDecFeatureConverter
125
+
126
+ def __init__(self, module, input_vocabulary, output_vocabulary, optimizer_def,
127
+ input_depth, decode_fn=decoding.beam_search, label_smoothing=0.0,
128
+ z_loss=0.0, loss_normalizing_factor=None):
129
+ super().__init__(
130
+ module=module,
131
+ input_vocabulary=input_vocabulary,
132
+ output_vocabulary=output_vocabulary,
133
+ optimizer_def=optimizer_def,
134
+ decode_fn=decode_fn,
135
+ label_smoothing=label_smoothing,
136
+ z_loss=z_loss,
137
+ loss_normalizing_factor=loss_normalizing_factor)
138
+ self._input_depth = input_depth
139
+
140
+ def get_initial_variables(self, rng, input_shapes, input_types=None):
141
+ """Hacky override to bypass eval/infer inability to handle rank-3 inputs."""
142
+ encoder_shape = input_shapes["encoder_input_tokens"]
143
+ if len(encoder_shape) == 2:
144
+ input_shapes = {
145
+ "encoder_input_tokens": (*encoder_shape, self._input_depth),
146
+ **{k: v for k, v in input_shapes.items()
147
+ if k != "encoder_input_tokens"}
148
+ }
149
+ else:
150
+ assert encoder_shape[-1] == self._input_depth
151
+ return super().get_initial_variables(
152
+ rng=rng, input_shapes=input_shapes, input_types=input_types)
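As a small sanity check on the feature converter above, its length mapping can be exercised on its own. This sketch assumes the standard `pack` argument of the seqio.FeatureConverter constructor and uses arbitrary lengths.

# Sketch of the task-feature -> model-feature length mapping with packing
# disabled; the lengths below are arbitrary.
converter = ContinuousInputsEncDecFeatureConverter(pack=False)
lengths = converter.get_model_feature_lengths(
    {'inputs': 256, 'targets': 1024})
# lengths == {'encoder_input_tokens': 256,
#             'decoder_target_tokens': 1024,
#             'decoder_input_tokens': 1024,
#             'decoder_loss_weights': 1024}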
mt3/network.py ADDED
@@ -0,0 +1,409 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """T5.1.1 Transformer model."""
16
+
17
+ from typing import Any, Sequence
18
+
19
+ from flax import linen as nn
20
+ from flax import struct
21
+ import jax.numpy as jnp
22
+ from mt3 import layers
23
+
24
+
25
+ @struct.dataclass
26
+ class T5Config:
27
+ """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
28
+ vocab_size: int
29
+ # Activation dtypes.
30
+ dtype: Any = jnp.float32
31
+ emb_dim: int = 512
32
+ num_heads: int = 8
33
+ num_encoder_layers: int = 6
34
+ num_decoder_layers: int = 6
35
+ head_dim: int = 64
36
+ mlp_dim: int = 2048
37
+ # Activation functions are retrieved from Flax.
38
+ mlp_activations: Sequence[str] = ('relu',)
39
+ dropout_rate: float = 0.1
40
+ # If `True`, the embedding weights are used in the decoder output layer.
41
+ logits_via_embedding: bool = False
42
+
43
+
44
+ class EncoderLayer(nn.Module):
45
+ """Transformer encoder layer."""
46
+ config: T5Config
47
+
48
+ @nn.compact
49
+ def __call__(self, inputs, encoder_mask=None, deterministic=False):
50
+ cfg = self.config
51
+
52
+ # Attention block.
53
+ assert inputs.ndim == 3
54
+ x = layers.LayerNorm(
55
+ dtype=cfg.dtype, name='pre_attention_layer_norm')(
56
+ inputs)
57
+ # [batch, length, emb_dim] -> [batch, length, emb_dim]
58
+ x = layers.MultiHeadDotProductAttention(
59
+ num_heads=cfg.num_heads,
60
+ dtype=cfg.dtype,
61
+ head_dim=cfg.head_dim,
62
+ dropout_rate=cfg.dropout_rate,
63
+ name='attention')(
64
+ x, x, encoder_mask, deterministic=deterministic)
65
+ x = nn.Dropout(
66
+ rate=cfg.dropout_rate, broadcast_dims=(-2,))(
67
+ x, deterministic=deterministic)
68
+ x = x + inputs
69
+
70
+ # MLP block.
71
+ y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x)
72
+ # [batch, length, emb_dim] -> [batch, length, emb_dim]
73
+ y = layers.MlpBlock(
74
+ intermediate_dim=cfg.mlp_dim,
75
+ activations=cfg.mlp_activations,
76
+ intermediate_dropout_rate=cfg.dropout_rate,
77
+ dtype=cfg.dtype,
78
+ name='mlp',
79
+ )(y, deterministic=deterministic)
80
+ y = nn.Dropout(
81
+ rate=cfg.dropout_rate, broadcast_dims=(-2,))(
82
+ y, deterministic=deterministic)
83
+ y = y + x
84
+
85
+ return y
86
+
87
+
88
+ class DecoderLayer(nn.Module):
89
+ """Transformer decoder layer that attends to the encoder."""
90
+ config: T5Config
91
+
92
+ @nn.compact
93
+ def __call__(self,
94
+ inputs,
95
+ encoded,
96
+ decoder_mask=None,
97
+ encoder_decoder_mask=None,
98
+ deterministic=False,
99
+ decode=False,
100
+ max_decode_length=None):
101
+ cfg = self.config
102
+
103
+ # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
104
+ x = layers.LayerNorm(
105
+ dtype=cfg.dtype, name='pre_self_attention_layer_norm')(
106
+ inputs)
107
+
108
+ # Self-attention block
109
+ x = layers.MultiHeadDotProductAttention(
110
+ num_heads=cfg.num_heads,
111
+ dtype=cfg.dtype,
112
+ head_dim=cfg.head_dim,
113
+ dropout_rate=cfg.dropout_rate,
114
+ name='self_attention')(
115
+ x,
116
+ x,
117
+ decoder_mask,
118
+ deterministic=deterministic,
119
+ decode=decode)
120
+ x = nn.Dropout(
121
+ rate=cfg.dropout_rate, broadcast_dims=(-2,))(
122
+ x, deterministic=deterministic)
123
+ x = x + inputs
124
+
125
+ # Encoder-Decoder block.
126
+ y = layers.LayerNorm(
127
+ dtype=cfg.dtype, name='pre_cross_attention_layer_norm')(
128
+ x)
129
+ y = layers.MultiHeadDotProductAttention(
130
+ num_heads=cfg.num_heads,
131
+ dtype=cfg.dtype,
132
+ head_dim=cfg.head_dim,
133
+ dropout_rate=cfg.dropout_rate,
134
+ name='encoder_decoder_attention')(
135
+ y, encoded, encoder_decoder_mask, deterministic=deterministic)
136
+ y = nn.Dropout(
137
+ rate=cfg.dropout_rate, broadcast_dims=(-2,))(
138
+ y, deterministic=deterministic)
139
+ y = y + x
140
+
141
+ # MLP block.
142
+ z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y)
143
+ z = layers.MlpBlock(
144
+ intermediate_dim=cfg.mlp_dim,
145
+ activations=cfg.mlp_activations,
146
+ intermediate_dropout_rate=cfg.dropout_rate,
147
+ dtype=cfg.dtype,
148
+ name='mlp',
149
+ )(z, deterministic=deterministic)
150
+ z = nn.Dropout(
151
+ rate=cfg.dropout_rate, broadcast_dims=(-2,))(
152
+ z, deterministic=deterministic)
153
+ z = z + y
154
+
155
+ return z
156
+
157
+
158
+ class Encoder(nn.Module):
159
+ """A stack of encoder layers."""
160
+ config: T5Config
161
+
162
+ @nn.compact
163
+ def __call__(self,
164
+ encoder_input_tokens,
165
+ encoder_mask=None,
166
+ deterministic=False):
167
+ cfg = self.config
168
+ assert encoder_input_tokens.ndim == 3 # [batch, length, depth]
169
+
170
+ seq_length = encoder_input_tokens.shape[-2]
171
+ inputs_positions = jnp.arange(seq_length)[None, :]
172
+
173
+ # [batch, length, depth] -> [batch, length, emb_dim]
174
+ x = layers.DenseGeneral(
175
+ cfg.emb_dim,
176
+ dtype=cfg.dtype,
177
+ kernel_init=nn.linear.default_kernel_init,
178
+ kernel_axes=('vocab', 'embed'),
179
+ name='continuous_inputs_projection')(encoder_input_tokens)
180
+ x = x + layers.FixedEmbed(features=cfg.emb_dim)(inputs_positions)
181
+ x = nn.Dropout(
182
+ rate=cfg.dropout_rate, broadcast_dims=(-2,))(
183
+ x, deterministic=deterministic)
184
+ x = x.astype(cfg.dtype)
185
+
186
+ for lyr in range(cfg.num_encoder_layers):
187
+ # [batch, length, emb_dim] -> [batch, length, emb_dim]
188
+ x = EncoderLayer(
189
+ config=cfg,
190
+ name=f'layers_{lyr}')(x, encoder_mask, deterministic)
191
+
192
+ x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
193
+ return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
194
+
195
+
196
+ class Decoder(nn.Module):
197
+ """A stack of decoder layers as a part of an encoder-decoder architecture."""
198
+ config: T5Config
199
+
200
+ @nn.compact
201
+ def __call__(self,
202
+ encoded,
203
+ decoder_input_tokens,
204
+ decoder_positions=None,
205
+ decoder_mask=None,
206
+ encoder_decoder_mask=None,
207
+ deterministic=False,
208
+ decode=False,
209
+ max_decode_length=None):
210
+ cfg = self.config
211
+ assert decoder_input_tokens.ndim == 2 # [batch, len]
212
+
213
+ seq_length = decoder_input_tokens.shape[-1]
214
+ decoder_positions = jnp.arange(seq_length)[None, :]
215
+
216
+ # [batch, length] -> [batch, length, emb_dim]
217
+ y = layers.Embed(
218
+ num_embeddings=cfg.vocab_size,
219
+ features=cfg.emb_dim,
220
+ dtype=cfg.dtype,
221
+ attend_dtype=jnp.float32, # for logit training stability
222
+ embedding_init=nn.initializers.normal(stddev=1.0),
223
+ one_hot=True,
224
+ name='token_embedder')(decoder_input_tokens.astype('int32'))
225
+ y = y + layers.FixedEmbed(features=cfg.emb_dim)(
226
+ decoder_positions, decode=decode)
227
+ y = nn.Dropout(
228
+ rate=cfg.dropout_rate, broadcast_dims=(-2,))(
229
+ y, deterministic=deterministic)
230
+ y = y.astype(cfg.dtype)
231
+
232
+ for lyr in range(cfg.num_decoder_layers):
233
+ # [batch, length, emb_dim] -> [batch, length, emb_dim]
234
+ y = DecoderLayer(
235
+ config=cfg, name=f'layers_{lyr}')(
236
+ y,
237
+ encoded,
238
+ decoder_mask=decoder_mask,
239
+ encoder_decoder_mask=encoder_decoder_mask,
240
+ deterministic=deterministic,
241
+ decode=decode,
242
+ max_decode_length=max_decode_length)
243
+
244
+ y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y)
245
+ y = nn.Dropout(
246
+ rate=cfg.dropout_rate, broadcast_dims=(-2,))(
247
+ y, deterministic=deterministic)
248
+
249
+ # [batch, length, emb_dim] -> [batch, length, vocab_size]
250
+ if cfg.logits_via_embedding:
251
+ # Use the transpose of embedding matrix for logit transform.
252
+ logits = self.shared_embedding.attend(y)
253
+ # Correctly normalize pre-softmax logits for this shared case.
254
+ logits = logits / jnp.sqrt(y.shape[-1])
255
+ else:
256
+ logits = layers.DenseGeneral(
257
+ cfg.vocab_size,
258
+ dtype=jnp.float32, # Use float32 for stabiliity.
259
+ kernel_axes=('embed', 'vocab'),
260
+ name='logits_dense')(
261
+ y)
262
+ return logits
263
+
264
+
265
+ class Transformer(nn.Module):
266
+ """An encoder-decoder Transformer model."""
267
+ config: T5Config
268
+
269
+ def setup(self):
270
+ cfg = self.config
271
+
272
+ self.encoder = Encoder(config=cfg)
273
+ self.decoder = Decoder(config=cfg)
274
+
275
+ def encode(self,
276
+ encoder_input_tokens,
277
+ encoder_segment_ids=None,
278
+ enable_dropout=True):
279
+ """Applies Transformer encoder-branch on the inputs."""
280
+ cfg = self.config
281
+ assert encoder_input_tokens.ndim == 3 # (batch, length, depth)
282
+
283
+ # Make padding attention mask; we don't actually mask out any input
284
+ # positions, letting the model potentially attend to the zero vector used as
285
+ # padding.
286
+ encoder_mask = layers.make_attention_mask(
287
+ jnp.ones(encoder_input_tokens.shape[:-1]),
288
+ jnp.ones(encoder_input_tokens.shape[:-1]),
289
+ dtype=cfg.dtype)
290
+ # Add segmentation block-diagonal attention mask if using segmented data.
291
+ if encoder_segment_ids is not None:
292
+ encoder_mask = layers.combine_masks(
293
+ encoder_mask,
294
+ layers.make_attention_mask(
295
+ encoder_segment_ids,
296
+ encoder_segment_ids,
297
+ jnp.equal,
298
+ dtype=cfg.dtype))
299
+
300
+ return self.encoder(
301
+ encoder_input_tokens, encoder_mask, deterministic=not enable_dropout)
302
+
303
+ def decode(
304
+ self,
305
+ encoded,
306
+ encoder_input_tokens, # only needed for masks
307
+ decoder_input_tokens,
308
+ decoder_target_tokens,
309
+ encoder_segment_ids=None,
310
+ decoder_segment_ids=None,
311
+ decoder_positions=None,
312
+ enable_dropout=True,
313
+ decode=False,
314
+ max_decode_length=None):
315
+ """Applies Transformer decoder-branch on encoded-input and target."""
316
+ cfg = self.config
317
+
318
+ # Make padding attention masks.
319
+ if decode:
320
+ # Do not mask decoder attention based on targets padding at
321
+ # decoding/inference time.
322
+ decoder_mask = None
323
+ encoder_decoder_mask = layers.make_attention_mask(
324
+ jnp.ones_like(decoder_target_tokens),
325
+ jnp.ones(encoder_input_tokens.shape[:-1]),
326
+ dtype=cfg.dtype)
327
+ else:
328
+ decoder_mask = layers.make_decoder_mask(
329
+ decoder_target_tokens=decoder_target_tokens,
330
+ dtype=cfg.dtype,
331
+ decoder_segment_ids=decoder_segment_ids)
332
+ encoder_decoder_mask = layers.make_attention_mask(
333
+ decoder_target_tokens > 0,
334
+ jnp.ones(encoder_input_tokens.shape[:-1]),
335
+ dtype=cfg.dtype)
336
+
337
+ # Add segmentation block-diagonal attention masks if using segmented data.
338
+ if encoder_segment_ids is not None:
339
+ if decode:
340
+ raise ValueError(
341
+ 'During decoding, packing should not be used but '
342
+ '`encoder_segment_ids` was passed to `Transformer.decode`.')
343
+
344
+ encoder_decoder_mask = layers.combine_masks(
345
+ encoder_decoder_mask,
346
+ layers.make_attention_mask(
347
+ decoder_segment_ids,
348
+ encoder_segment_ids,
349
+ jnp.equal,
350
+ dtype=cfg.dtype))
351
+
352
+ logits = self.decoder(
353
+ encoded,
354
+ decoder_input_tokens=decoder_input_tokens,
355
+ decoder_positions=decoder_positions,
356
+ decoder_mask=decoder_mask,
357
+ encoder_decoder_mask=encoder_decoder_mask,
358
+ deterministic=not enable_dropout,
359
+ decode=decode,
360
+ max_decode_length=max_decode_length)
361
+ return logits.astype(self.config.dtype)
362
+
363
+ def __call__(self,
364
+ encoder_input_tokens,
365
+ decoder_input_tokens,
366
+ decoder_target_tokens,
367
+ encoder_segment_ids=None,
368
+ decoder_segment_ids=None,
369
+ encoder_positions=None,
370
+ decoder_positions=None,
371
+ *,
372
+ enable_dropout: bool = True,
373
+ decode: bool = False):
374
+ """Applies Transformer model on the inputs.
375
+
376
+ This method requires both decoder_target_tokens and decoder_input_tokens,
377
+ which is a shifted version of the former. For a packed dataset, it usually
378
+ has additional processing applied. For example, the first element of each
379
+ sequence has id 0 instead of the shifted EOS id from the previous sequence.
380
+
381
+ Args:
382
+ encoder_input_tokens: input data to the encoder.
383
+ decoder_input_tokens: input token to the decoder.
384
+ decoder_target_tokens: target token to the decoder.
385
+ encoder_segment_ids: encoder segmentation info for packed examples.
386
+ decoder_segment_ids: decoder segmentation info for packed examples.
387
+ encoder_positions: encoder subsequence positions for packed examples.
388
+ decoder_positions: decoder subsequence positions for packed examples.
389
+ enable_dropout: Enables dropout if set to True.
390
+ decode: Whether to prepare and use an autoregressive cache.
391
+
392
+ Returns:
393
+ logits array from full transformer.
394
+ """
395
+ encoded = self.encode(
396
+ encoder_input_tokens,
397
+ encoder_segment_ids=encoder_segment_ids,
398
+ enable_dropout=enable_dropout)
399
+
400
+ return self.decode(
401
+ encoded,
402
+ encoder_input_tokens, # only used for masks
403
+ decoder_input_tokens,
404
+ decoder_target_tokens,
405
+ encoder_segment_ids=encoder_segment_ids,
406
+ decoder_segment_ids=decoder_segment_ids,
407
+ decoder_positions=decoder_positions,
408
+ enable_dropout=enable_dropout,
409
+ decode=decode)
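The encode/decode split above lets inference run the encoder once and then call the decoder autoregressively with decode=True, while training uses the combined __call__. The subtler part is the mask construction: the encoder self-attention mask is all-ones (padding frames are zero vectors rather than a padding token), optionally ANDed with a block-diagonal segment mask when examples are packed. A standalone sketch of that masking idea in plain jax.numpy (the two helpers below are simplified stand-ins written for this sketch, not the actual t5x layers functions):

import jax.numpy as jnp

def make_attention_mask(query_ok, key_ok, pairwise_fn=jnp.multiply, dtype=jnp.float32):
  # [batch, len_q] x [batch, len_k] -> [batch, 1, len_q, len_k]
  mask = pairwise_fn(query_ok[:, :, None], key_ok[:, None, :])
  return mask[:, None, :, :].astype(dtype)

def combine_masks(*masks):
  # Elementwise AND of the given masks, skipping None entries.
  masks = [m for m in masks if m is not None]
  out = masks[0]
  for m in masks[1:]:
    out = out * m
  return out

# One batch element containing two packed segments (lengths 2 and 1).
segment_ids = jnp.array([[1, 1, 2]])
padding_mask = make_attention_mask(jnp.ones((1, 3)), jnp.ones((1, 3)))
segment_mask = make_attention_mask(segment_ids, segment_ids, pairwise_fn=jnp.equal)
encoder_mask = combine_masks(padding_mask, segment_mask)
# encoder_mask[0, 0] is block-diagonal: positions 0 and 1 attend to each other,
# position 2 attends only to itself.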
mt3/note_sequences.py ADDED
@@ -0,0 +1,446 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Helper functions that operate on NoteSequence protos."""
16
+
17
+ import dataclasses
18
+ import itertools
19
+
20
+ from typing import MutableMapping, MutableSet, Optional, Sequence, Tuple
21
+
22
+ from mt3 import event_codec
23
+ from mt3 import run_length_encoding
24
+ from mt3 import vocabularies
25
+
26
+ import note_seq
27
+
28
+ DEFAULT_VELOCITY = 100
29
+ DEFAULT_NOTE_DURATION = 0.01
30
+
31
+ # Quantization can result in zero-length notes; enforce a minimum duration.
32
+ MIN_NOTE_DURATION = 0.01
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class TrackSpec:
37
+ name: str
38
+ program: int = 0
39
+ is_drum: bool = False
40
+
41
+
42
+ def extract_track(ns, program, is_drum):
43
+ track = note_seq.NoteSequence(ticks_per_quarter=220)
44
+ track_notes = [note for note in ns.notes
45
+ if note.program == program and note.is_drum == is_drum]
46
+ track.notes.extend(track_notes)
47
+ track.total_time = (max(note.end_time for note in track.notes)
48
+ if track.notes else 0.0)
49
+ return track
50
+
51
+
52
+ def trim_overlapping_notes(ns: note_seq.NoteSequence) -> note_seq.NoteSequence:
53
+ """Trim overlapping notes from a NoteSequence, dropping zero-length notes."""
54
+ ns_trimmed = note_seq.NoteSequence()
55
+ ns_trimmed.CopyFrom(ns)
56
+ channels = set((note.pitch, note.program, note.is_drum)
57
+ for note in ns_trimmed.notes)
58
+ for pitch, program, is_drum in channels:
59
+ notes = [note for note in ns_trimmed.notes if note.pitch == pitch
60
+ and note.program == program and note.is_drum == is_drum]
61
+ sorted_notes = sorted(notes, key=lambda note: note.start_time)
62
+ for i in range(1, len(sorted_notes)):
63
+ if sorted_notes[i - 1].end_time > sorted_notes[i].start_time:
64
+ sorted_notes[i - 1].end_time = sorted_notes[i].start_time
65
+ valid_notes = [note for note in ns_trimmed.notes
66
+ if note.start_time < note.end_time]
67
+ del ns_trimmed.notes[:]
68
+ ns_trimmed.notes.extend(valid_notes)
69
+ return ns_trimmed
70
+
71
+
72
+ def assign_instruments(ns: note_seq.NoteSequence) -> None:
73
+ """Assign instrument numbers to notes; modifies NoteSequence in place."""
74
+ program_instruments = {}
75
+ for note in ns.notes:
76
+ if note.program not in program_instruments and not note.is_drum:
77
+ num_instruments = len(program_instruments)
78
+ note.instrument = (num_instruments if num_instruments < 9
79
+ else num_instruments + 1)
80
+ program_instruments[note.program] = note.instrument
81
+ elif note.is_drum:
82
+ note.instrument = 9
83
+ else:
84
+ note.instrument = program_instruments[note.program]
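A small usage sketch (hypothetical values, assuming mt3 and note_seq are importable): assign_instruments hands out instrument channels in the order programs are first seen, skipping channel 9, which is reserved for drums.

import note_seq
from mt3 import note_sequences

ns = note_seq.NoteSequence(ticks_per_quarter=220)
ns.notes.add(start_time=0.0, end_time=0.5, pitch=60, velocity=100, program=0)
ns.notes.add(start_time=0.0, end_time=0.5, pitch=64, velocity=100, program=40)
ns.notes.add(start_time=0.0, end_time=0.5, pitch=38, velocity=100, is_drum=True)
note_sequences.assign_instruments(ns)
print([note.instrument for note in ns.notes])  # [0, 1, 9]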
85
+
86
+
87
+ def validate_note_sequence(ns: note_seq.NoteSequence) -> None:
88
+ """Raise ValueError if NoteSequence contains invalid notes."""
89
+ for note in ns.notes:
90
+ if note.start_time >= note.end_time:
91
+ raise ValueError('note has start time >= end time: %f >= %f' %
92
+ (note.start_time, note.end_time))
93
+ if note.velocity == 0:
94
+ raise ValueError('note has zero velocity')
95
+
96
+
97
+ def note_arrays_to_note_sequence(
98
+ onset_times: Sequence[float],
99
+ pitches: Sequence[int],
100
+ offset_times: Optional[Sequence[float]] = None,
101
+ velocities: Optional[Sequence[int]] = None,
102
+ programs: Optional[Sequence[int]] = None,
103
+ is_drums: Optional[Sequence[bool]] = None
104
+ ) -> note_seq.NoteSequence:
105
+ """Convert note onset / offset / pitch / velocity arrays to NoteSequence."""
106
+ ns = note_seq.NoteSequence(ticks_per_quarter=220)
107
+ for onset_time, offset_time, pitch, velocity, program, is_drum in itertools.zip_longest(
108
+ onset_times, [] if offset_times is None else offset_times,
109
+ pitches, [] if velocities is None else velocities,
110
+ [] if programs is None else programs,
111
+ [] if is_drums is None else is_drums):
112
+ if offset_time is None:
113
+ offset_time = onset_time + DEFAULT_NOTE_DURATION
114
+ if velocity is None:
115
+ velocity = DEFAULT_VELOCITY
116
+ if program is None:
117
+ program = 0
118
+ if is_drum is None:
119
+ is_drum = False
120
+ ns.notes.add(
121
+ start_time=onset_time,
122
+ end_time=offset_time,
123
+ pitch=pitch,
124
+ velocity=velocity,
125
+ program=program,
126
+ is_drum=is_drum)
127
+ ns.total_time = max(ns.total_time, offset_time)
128
+ assign_instruments(ns)
129
+ return ns
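An illustrative call with only onsets and pitches; the remaining fields fall back to the defaults above (DEFAULT_NOTE_DURATION, DEFAULT_VELOCITY, program 0, not a drum):

from mt3 import note_sequences

ns = note_sequences.note_arrays_to_note_sequence(
    onset_times=[0.0, 0.5], pitches=[60, 64])
print([(n.start_time, n.end_time, n.velocity) for n in ns.notes])
# [(0.0, 0.01, 100), (0.5, ~0.51, 100)]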
130
+
131
+
132
+ @dataclasses.dataclass
133
+ class NoteEventData:
134
+ pitch: int
135
+ velocity: Optional[int] = None
136
+ program: Optional[int] = None
137
+ is_drum: Optional[bool] = None
138
+ instrument: Optional[int] = None
139
+
140
+
141
+ def note_sequence_to_onsets(
142
+ ns: note_seq.NoteSequence
143
+ ) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
144
+ """Extract note onsets and pitches from NoteSequence proto."""
145
+ # Sort by pitch to use as a tiebreaker for subsequent stable sort.
146
+ notes = sorted(ns.notes, key=lambda note: note.pitch)
147
+ return ([note.start_time for note in notes],
148
+ [NoteEventData(pitch=note.pitch) for note in notes])
149
+
150
+
151
+ def note_sequence_to_onsets_and_offsets(
152
+ ns: note_seq.NoteSequence,
153
+ ) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
154
+ """Extract onset & offset times and pitches from a NoteSequence proto.
155
+
156
+ The onset & offset times will not necessarily be in sorted order.
157
+
158
+ Args:
159
+ ns: NoteSequence from which to extract onsets and offsets.
160
+
161
+ Returns:
162
+ times: A list of note onset and offset times.
163
+ values: A list of NoteEventData objects where velocity is zero for note
164
+ offsets.
165
+ """
166
+ # Sort by pitch and put offsets before onsets as a tiebreaker for subsequent
167
+ # stable sort.
168
+ notes = sorted(ns.notes, key=lambda note: note.pitch)
169
+ times = ([note.end_time for note in notes] +
170
+ [note.start_time for note in notes])
171
+ values = ([NoteEventData(pitch=note.pitch, velocity=0) for note in notes] +
172
+ [NoteEventData(pitch=note.pitch, velocity=note.velocity)
173
+ for note in notes])
174
+ return times, values
175
+
176
+
177
+ def note_sequence_to_onsets_and_offsets_and_programs(
178
+ ns: note_seq.NoteSequence,
179
+ ) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
180
+ """Extract onset & offset times and pitches & programs from a NoteSequence.
181
+
182
+ The onset & offset times will not necessarily be in sorted order.
183
+
184
+ Args:
185
+ ns: NoteSequence from which to extract onsets and offsets.
186
+
187
+ Returns:
188
+ times: A list of note onset and offset times.
189
+ values: A list of NoteEventData objects where velocity is zero for note
190
+ offsets.
191
+ """
192
+ # Sort by program and pitch and put offsets before onsets as a tiebreaker for
193
+ # subsequent stable sort.
194
+ notes = sorted(ns.notes,
195
+ key=lambda note: (note.is_drum, note.program, note.pitch))
196
+ times = ([note.end_time for note in notes if not note.is_drum] +
197
+ [note.start_time for note in notes])
198
+ values = ([NoteEventData(pitch=note.pitch, velocity=0,
199
+ program=note.program, is_drum=False)
200
+ for note in notes if not note.is_drum] +
201
+ [NoteEventData(pitch=note.pitch, velocity=note.velocity,
202
+ program=note.program, is_drum=note.is_drum)
203
+ for note in notes])
204
+ return times, values
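An illustration of the ordering convention with a hypothetical one-note sequence: the offset is listed before the onset, with velocity zero marking the offset, so a later stable sort by time places each note's offset ahead of any simultaneous onset.

import note_seq
from mt3 import note_sequences

ns = note_seq.NoteSequence()
ns.notes.add(start_time=1.0, end_time=2.0, pitch=60, velocity=80, program=0)
times, values = note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns)
print(times)                                   # [2.0, 1.0]: offset time, then onset time
print(values[0].velocity, values[1].velocity)  # 0 80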
205
+
206
+
207
+ @dataclasses.dataclass
208
+ class NoteEncodingState:
209
+ """Encoding state for note transcription, keeping track of active pitches."""
210
+ # velocity bin for active pitches and programs
211
+ active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(
212
+ default_factory=dict)
213
+
214
+
215
+ def note_event_data_to_events(
216
+ state: Optional[NoteEncodingState],
217
+ value: NoteEventData,
218
+ codec: event_codec.Codec,
219
+ ) -> Sequence[event_codec.Event]:
220
+ """Convert note event data to a sequence of events."""
221
+ if value.velocity is None:
222
+ # onsets only, no program or velocity
223
+ return [event_codec.Event('pitch', value.pitch)]
224
+ else:
225
+ num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec)
226
+ velocity_bin = vocabularies.velocity_to_bin(
227
+ value.velocity, num_velocity_bins)
228
+ if value.program is None:
229
+ # onsets + offsets + velocities only, no programs
230
+ if state is not None:
231
+ state.active_pitches[(value.pitch, 0)] = velocity_bin
232
+ return [event_codec.Event('velocity', velocity_bin),
233
+ event_codec.Event('pitch', value.pitch)]
234
+ else:
235
+ if value.is_drum:
236
+ # drum events use a separate vocabulary
237
+ return [event_codec.Event('velocity', velocity_bin),
238
+ event_codec.Event('drum', value.pitch)]
239
+ else:
240
+ # program + velocity + pitch
241
+ if state is not None:
242
+ state.active_pitches[(value.pitch, value.program)] = velocity_bin
243
+ return [event_codec.Event('program', value.program),
244
+ event_codec.Event('velocity', velocity_bin),
245
+ event_codec.Event('pitch', value.pitch)]
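For instance (illustrative only, using a codec constructed the same way as in note_sequences_test.py further down), a note-on that carries a program expands to a program/velocity/pitch triple and records the pitch as active:

import note_seq
from mt3 import event_codec, note_sequences

codec = event_codec.Codec(
    max_shift_steps=100, steps_per_second=100,
    event_ranges=[
        event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
                               note_seq.MAX_MIDI_PITCH),
        event_codec.EventRange('velocity', 0, 127),
        event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
                               note_seq.MAX_MIDI_PITCH),
        event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
                               note_seq.MAX_MIDI_PROGRAM),
        event_codec.EventRange('tie', 0, 0)])

state = note_sequences.NoteEncodingState()
value = note_sequences.NoteEventData(pitch=60, velocity=100, program=40, is_drum=False)
events = note_sequences.note_event_data_to_events(state, value, codec)
print([(e.type, e.value) for e in events])
# [('program', 40), ('velocity', <bin for velocity 100>), ('pitch', 60)]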
246
+
247
+
248
+ def note_encoding_state_to_events(
249
+ state: NoteEncodingState
250
+ ) -> Sequence[event_codec.Event]:
251
+ """Output program and pitch events for active notes plus a final tie event."""
252
+ events = []
253
+ for pitch, program in sorted(
254
+ state.active_pitches.keys(), key=lambda k: k[::-1]):
255
+ if state.active_pitches[(pitch, program)]:
256
+ events += [event_codec.Event('program', program),
257
+ event_codec.Event('pitch', pitch)]
258
+ events.append(event_codec.Event('tie', 0))
259
+ return events
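Continuing the illustrative snippet above, the tie section emitted for that state is just the still-active program/pitch pairs followed by a single tie event:

tie_events = note_sequences.note_encoding_state_to_events(state)
print([(e.type, e.value) for e in tie_events])
# [('program', 40), ('pitch', 60), ('tie', 0)]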
260
+
261
+
262
+ @dataclasses.dataclass
263
+ class NoteDecodingState:
264
+ """Decoding state for note transcription."""
265
+ current_time: float = 0.0
266
+ # velocity to apply to subsequent pitch events (zero for note-off)
267
+ current_velocity: int = DEFAULT_VELOCITY
268
+ # program to apply to subsequent pitch events
269
+ current_program: int = 0
270
+ # onset time and velocity for active pitches and programs
271
+ active_pitches: MutableMapping[Tuple[int, int],
272
+ Tuple[float, int]] = dataclasses.field(
273
+ default_factory=dict)
274
+ # pitches (with programs) to continue from previous segment
275
+ tied_pitches: MutableSet[Tuple[int, int]] = dataclasses.field(
276
+ default_factory=set)
277
+ # whether or not we are in the tie section at the beginning of a segment
278
+ is_tie_section: bool = False
279
+ # partially-decoded NoteSequence
280
+ note_sequence: note_seq.NoteSequence = dataclasses.field(
281
+ default_factory=lambda: note_seq.NoteSequence(ticks_per_quarter=220))
282
+
283
+
284
+ def decode_note_onset_event(
285
+ state: NoteDecodingState,
286
+ time: float,
287
+ event: event_codec.Event,
288
+ codec: event_codec.Codec,
289
+ ) -> None:
290
+ """Process note onset event and update decoding state."""
291
+ if event.type == 'pitch':
292
+ state.note_sequence.notes.add(
293
+ start_time=time, end_time=time + DEFAULT_NOTE_DURATION,
294
+ pitch=event.value, velocity=DEFAULT_VELOCITY)
295
+ state.note_sequence.total_time = max(state.note_sequence.total_time,
296
+ time + DEFAULT_NOTE_DURATION)
297
+ else:
298
+ raise ValueError('unexpected event type: %s' % event.type)
299
+
300
+
301
+ def _add_note_to_sequence(
302
+ ns: note_seq.NoteSequence,
303
+ start_time: float, end_time: float, pitch: int, velocity: int,
304
+ program: int = 0, is_drum: bool = False
305
+ ) -> None:
306
+ end_time = max(end_time, start_time + MIN_NOTE_DURATION)
307
+ ns.notes.add(
308
+ start_time=start_time, end_time=end_time,
309
+ pitch=pitch, velocity=velocity, program=program, is_drum=is_drum)
310
+ ns.total_time = max(ns.total_time, end_time)
311
+
312
+
313
+ def decode_note_event(
314
+ state: NoteDecodingState,
315
+ time: float,
316
+ event: event_codec.Event,
317
+ codec: event_codec.Codec
318
+ ) -> None:
319
+ """Process note event and update decoding state."""
320
+ if time < state.current_time:
321
+ raise ValueError('event time < current time, %f < %f' % (
322
+ time, state.current_time))
323
+ state.current_time = time
324
+ if event.type == 'pitch':
325
+ pitch = event.value
326
+ if state.is_tie_section:
327
+ # "tied" pitch
328
+ if (pitch, state.current_program) not in state.active_pitches:
329
+ raise ValueError('inactive pitch/program in tie section: %d/%d' %
330
+ (pitch, state.current_program))
331
+ if (pitch, state.current_program) in state.tied_pitches:
332
+ raise ValueError('pitch/program is already tied: %d/%d' %
333
+ (pitch, state.current_program))
334
+ state.tied_pitches.add((pitch, state.current_program))
335
+ elif state.current_velocity == 0:
336
+ # note offset
337
+ if (pitch, state.current_program) not in state.active_pitches:
338
+ raise ValueError('note-off for inactive pitch/program: %d/%d' %
339
+ (pitch, state.current_program))
340
+ onset_time, onset_velocity = state.active_pitches.pop(
341
+ (pitch, state.current_program))
342
+ _add_note_to_sequence(
343
+ state.note_sequence, start_time=onset_time, end_time=time,
344
+ pitch=pitch, velocity=onset_velocity, program=state.current_program)
345
+ else:
346
+ # note onset
347
+ if (pitch, state.current_program) in state.active_pitches:
348
+ # The pitch is already active; this shouldn't really happen but we'll
349
+ # try to handle it gracefully by ending the previous note and starting a
350
+ # new one.
351
+ onset_time, onset_velocity = state.active_pitches.pop(
352
+ (pitch, state.current_program))
353
+ _add_note_to_sequence(
354
+ state.note_sequence, start_time=onset_time, end_time=time,
355
+ pitch=pitch, velocity=onset_velocity, program=state.current_program)
356
+ state.active_pitches[(pitch, state.current_program)] = (
357
+ time, state.current_velocity)
358
+ elif event.type == 'drum':
359
+ # drum onset (drums have no offset)
360
+ if state.current_velocity == 0:
361
+ raise ValueError('velocity cannot be zero for drum event')
362
+ offset_time = time + DEFAULT_NOTE_DURATION
363
+ _add_note_to_sequence(
364
+ state.note_sequence, start_time=time, end_time=offset_time,
365
+ pitch=event.value, velocity=state.current_velocity, is_drum=True)
366
+ elif event.type == 'velocity':
367
+ # velocity change
368
+ num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec)
369
+ velocity = vocabularies.bin_to_velocity(event.value, num_velocity_bins)
370
+ state.current_velocity = velocity
371
+ elif event.type == 'program':
372
+ # program change
373
+ state.current_program = event.value
374
+ elif event.type == 'tie':
375
+ # end of tie section; end active notes that weren't declared tied
376
+ if not state.is_tie_section:
377
+ raise ValueError('tie section end event when not in tie section')
378
+ for (pitch, program) in list(state.active_pitches.keys()):
379
+ if (pitch, program) not in state.tied_pitches:
380
+ onset_time, onset_velocity = state.active_pitches.pop((pitch, program))
381
+ _add_note_to_sequence(
382
+ state.note_sequence,
383
+ start_time=onset_time, end_time=state.current_time,
384
+ pitch=pitch, velocity=onset_velocity, program=program)
385
+ state.is_tie_section = False
386
+ else:
387
+ raise ValueError('unexpected event type: %s' % event.type)
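A minimal decoding sketch (reusing the illustrative codec and imports from above): a velocity event sets the running velocity, the following pitch event opens a note, and flushing closes it with at least the minimum duration.

from mt3 import vocabularies

state = note_sequences.NoteDecodingState()
num_bins = vocabularies.num_velocity_bins_from_codec(codec)
vel_bin = vocabularies.velocity_to_bin(100, num_bins)
note_sequences.decode_note_event(state, 0.0, event_codec.Event('velocity', vel_bin), codec)
note_sequences.decode_note_event(state, 0.0, event_codec.Event('pitch', 60), codec)
ns = note_sequences.flush_note_decoding_state(state)
# ns holds one pitch-60 note starting at 0.0, lasting MIN_NOTE_DURATION,
# at the de-quantized velocity corresponding to vel_bin.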
388
+
389
+
390
+ def begin_tied_pitches_section(state: NoteDecodingState) -> None:
391
+ """Begin the tied pitches section at the start of a segment."""
392
+ state.tied_pitches = set()
393
+ state.is_tie_section = True
394
+
395
+
396
+ def flush_note_decoding_state(
397
+ state: NoteDecodingState
398
+ ) -> note_seq.NoteSequence:
399
+ """End all active notes and return resulting NoteSequence."""
400
+ for onset_time, _ in state.active_pitches.values():
401
+ state.current_time = max(state.current_time, onset_time + MIN_NOTE_DURATION)
402
+ for (pitch, program) in list(state.active_pitches.keys()):
403
+ onset_time, onset_velocity = state.active_pitches.pop((pitch, program))
404
+ _add_note_to_sequence(
405
+ state.note_sequence, start_time=onset_time, end_time=state.current_time,
406
+ pitch=pitch, velocity=onset_velocity, program=program)
407
+ assign_instruments(state.note_sequence)
408
+ return state.note_sequence
409
+
410
+
411
+ class NoteEncodingSpecType(run_length_encoding.EventEncodingSpec):
412
+ pass
413
+
414
+
415
+ # encoding spec for modeling note onsets only
416
+ NoteOnsetEncodingSpec = NoteEncodingSpecType(
417
+ init_encoding_state_fn=lambda: None,
418
+ encode_event_fn=note_event_data_to_events,
419
+ encoding_state_to_events_fn=None,
420
+ init_decoding_state_fn=NoteDecodingState,
421
+ begin_decoding_segment_fn=lambda state: None,
422
+ decode_event_fn=decode_note_onset_event,
423
+ flush_decoding_state_fn=lambda state: state.note_sequence)
424
+
425
+
426
+ # encoding spec for modeling onsets and offsets
427
+ NoteEncodingSpec = NoteEncodingSpecType(
428
+ init_encoding_state_fn=lambda: None,
429
+ encode_event_fn=note_event_data_to_events,
430
+ encoding_state_to_events_fn=None,
431
+ init_decoding_state_fn=NoteDecodingState,
432
+ begin_decoding_segment_fn=lambda state: None,
433
+ decode_event_fn=decode_note_event,
434
+ flush_decoding_state_fn=flush_note_decoding_state)
435
+
436
+
437
+ # encoding spec for modeling onsets and offsets, with a "tie" section at the
438
+ # beginning of each segment listing already-active notes
439
+ NoteEncodingWithTiesSpec = NoteEncodingSpecType(
440
+ init_encoding_state_fn=NoteEncodingState,
441
+ encode_event_fn=note_event_data_to_events,
442
+ encoding_state_to_events_fn=note_encoding_state_to_events,
443
+ init_decoding_state_fn=NoteDecodingState,
444
+ begin_decoding_segment_fn=begin_tied_pitches_section,
445
+ decode_event_fn=decode_note_event,
446
+ flush_decoding_state_fn=flush_note_decoding_state)
mt3/note_sequences_test.py ADDED
@@ -0,0 +1,505 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for note_sequences."""
16
+
17
+ from mt3 import event_codec
18
+ from mt3 import note_sequences
19
+ from mt3 import run_length_encoding
20
+
21
+ import note_seq
22
+ import numpy as np
23
+ import tensorflow as tf
24
+
25
+ codec = event_codec.Codec(
26
+ max_shift_steps=100,
27
+ steps_per_second=100,
28
+ event_ranges=[
29
+ event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
30
+ note_seq.MAX_MIDI_PITCH),
31
+ event_codec.EventRange('velocity', 0, 127),
32
+ event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
33
+ note_seq.MAX_MIDI_PITCH),
34
+ event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
35
+ note_seq.MAX_MIDI_PROGRAM),
36
+ event_codec.EventRange('tie', 0, 0)
37
+ ])
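As a reading aid for the expected token ids in the tests below, the codec lays out its vocabulary contiguously; the layout implied by the expected values is:

# 1..100    'shift' events (0.01 s each at 100 steps/second)
# 101..228  'pitch' 0..127    (pitch 60 -> 161, pitch 61 -> 162)
# 229..356  'velocity' 0..127 (229 = note-off, 356 = velocity 127)
# 357..484  'drum' 0..127     (drum pitch 37 -> 394)
# 485..612  'program' 0..127  (program 40 -> 525)
# 613       'tie'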
38
+
39
+
40
+ class RunLengthEncodingTest(tf.test.TestCase):
41
+
42
+ def test_encode_and_index_note_sequence(self):
43
+ ns = note_seq.NoteSequence()
44
+ ns.notes.add(start_time=1.0,
45
+ end_time=1.1,
46
+ pitch=61,
47
+ velocity=100)
48
+ ns.notes.add(start_time=2.0,
49
+ end_time=2.1,
50
+ pitch=62,
51
+ velocity=100)
52
+ ns.notes.add(start_time=3.0,
53
+ end_time=3.1,
54
+ pitch=63,
55
+ velocity=100)
56
+ ns.total_time = ns.notes[-1].end_time
57
+
58
+ frame_times = np.arange(0, 4, step=.001)
59
+
60
+ event_times, event_values = note_sequences.note_sequence_to_onsets(ns)
61
+ events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events(
62
+ state=None, event_times=event_times, event_values=event_values,
63
+ encode_event_fn=note_sequences.note_event_data_to_events,
64
+ codec=codec, frame_times=frame_times)
65
+
66
+ self.assertEqual(len(frame_times), len(event_start_indices))
67
+ self.assertEqual(len(frame_times), len(event_end_indices))
68
+ self.assertLen(events, 403)
69
+ expected_events = ([1] * 100 +
70
+ [162] +
71
+ [1] * 100 +
72
+ [163] +
73
+ [1] * 100 +
74
+ [164] +
75
+ [1] * 100)
76
+ np.testing.assert_array_equal(expected_events, events)
77
+
78
+ self.assertEqual(event_start_indices[0], 0)
79
+ self.assertEqual(event_end_indices[0], 0)
80
+
81
+ self.assertEqual(162, events[100])
82
+ self.assertEqual(1.0, frame_times[1000])
83
+ self.assertEqual(event_start_indices[1000], 100)
84
+ self.assertEqual(event_end_indices[1000], 100)
85
+
86
+ self.assertEqual(163, events[201])
87
+ self.assertEqual(2.0, frame_times[2000])
88
+ self.assertEqual(event_start_indices[2000], 201)
89
+ self.assertEqual(event_end_indices[2000], 201)
90
+
91
+ self.assertEqual(164, events[302])
92
+ self.assertEqual(3.0, frame_times[3000])
93
+ self.assertEqual(event_start_indices[3000], 302)
94
+ self.assertEqual(event_end_indices[3000], 302)
95
+
96
+ self.assertEqual(1, events[-1])
97
+ self.assertEqual(3.999, frame_times[-1])
98
+ self.assertEqual(event_start_indices[-1], 402)
99
+ self.assertEqual(event_end_indices[-1], len(expected_events))
100
+
101
+ def test_encode_and_index_note_sequence_velocity(self):
102
+ ns = note_seq.NoteSequence()
103
+ ns.notes.add(start_time=1.0,
104
+ end_time=3.0,
105
+ pitch=61,
106
+ velocity=1)
107
+ ns.notes.add(start_time=2.0,
108
+ end_time=4.0,
109
+ pitch=62,
110
+ velocity=127)
111
+ ns.total_time = ns.notes[-1].end_time
112
+
113
+ frame_times = np.arange(0, 4, step=.001)
114
+
115
+ event_times, event_values = (
116
+ note_sequences.note_sequence_to_onsets_and_offsets(ns))
117
+ events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events(
118
+ state=None, event_times=event_times, event_values=event_values,
119
+ encode_event_fn=note_sequences.note_event_data_to_events,
120
+ codec=codec, frame_times=frame_times)
121
+
122
+ self.assertEqual(len(frame_times), len(event_start_indices))
123
+ self.assertEqual(len(frame_times), len(event_end_indices))
124
+ self.assertLen(events, 408)
125
+ expected_events = ([1] * 100 +
126
+ [230, 162] +
127
+ [1] * 100 +
128
+ [356, 163] +
129
+ [1] * 100 +
130
+ [229, 162] +
131
+ [1] * 100 +
132
+ [229, 163])
133
+ np.testing.assert_array_equal(expected_events, events)
134
+
135
+ self.assertEqual(event_start_indices[0], 0)
136
+ self.assertEqual(event_end_indices[0], 0)
137
+
138
+ self.assertEqual(230, events[100])
139
+ self.assertEqual(162, events[101])
140
+ self.assertEqual(1.0, frame_times[1000])
141
+ self.assertEqual(event_start_indices[1000], 100)
142
+ self.assertEqual(event_end_indices[1000], 100)
143
+
144
+ self.assertEqual(356, events[202])
145
+ self.assertEqual(163, events[203])
146
+ self.assertEqual(2.0, frame_times[2000])
147
+ self.assertEqual(event_start_indices[2000], 202)
148
+ self.assertEqual(event_end_indices[2000], 202)
149
+
150
+ self.assertEqual(229, events[304])
151
+ self.assertEqual(162, events[305])
152
+ self.assertEqual(3.0, frame_times[3000])
153
+ self.assertEqual(event_start_indices[3000], 304)
154
+ self.assertEqual(event_end_indices[3000], 304)
155
+
156
+ self.assertEqual(229, events[406])
157
+ self.assertEqual(163, events[407])
158
+ self.assertEqual(3.999, frame_times[-1])
159
+ self.assertEqual(event_start_indices[-1], 405)
160
+ self.assertEqual(event_end_indices[-1], len(expected_events))
161
+
162
+ def test_encode_and_index_note_sequence_multitrack(self):
163
+ ns = note_seq.NoteSequence()
164
+ ns.notes.add(start_time=0.0,
165
+ end_time=1.0,
166
+ pitch=37,
167
+ velocity=127,
168
+ is_drum=True)
169
+ ns.notes.add(start_time=1.0,
170
+ end_time=3.0,
171
+ pitch=61,
172
+ velocity=127,
173
+ program=0)
174
+ ns.notes.add(start_time=2.0,
175
+ end_time=4.0,
176
+ pitch=62,
177
+ velocity=127,
178
+ program=40)
179
+ ns.total_time = ns.notes[-1].end_time
180
+
181
+ frame_times = np.arange(0, 4, step=.001)
182
+
183
+ event_times, event_values = (
184
+ note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))
185
+ (tokens, event_start_indices, event_end_indices, state_tokens,
186
+ state_event_indices) = run_length_encoding.encode_and_index_events(
187
+ state=note_sequences.NoteEncodingState(),
188
+ event_times=event_times, event_values=event_values,
189
+ encode_event_fn=note_sequences.note_event_data_to_events,
190
+ codec=codec, frame_times=frame_times,
191
+ encoding_state_to_events_fn=(
192
+ note_sequences.note_encoding_state_to_events))
193
+
194
+ self.assertEqual(len(frame_times), len(event_start_indices))
195
+ self.assertEqual(len(frame_times), len(event_end_indices))
196
+ self.assertEqual(len(frame_times), len(state_event_indices))
197
+ self.assertLen(tokens, 414)
198
+
199
+ expected_events = (
200
+ [event_codec.Event('velocity', 127), event_codec.Event('drum', 37)] +
201
+ [event_codec.Event('shift', 1)] * 100 +
202
+ [event_codec.Event('program', 0),
203
+ event_codec.Event('velocity', 127), event_codec.Event('pitch', 61)] +
204
+ [event_codec.Event('shift', 1)] * 100 +
205
+ [event_codec.Event('program', 40),
206
+ event_codec.Event('velocity', 127), event_codec.Event('pitch', 62)] +
207
+ [event_codec.Event('shift', 1)] * 100 +
208
+ [event_codec.Event('program', 0),
209
+ event_codec.Event('velocity', 0), event_codec.Event('pitch', 61)] +
210
+ [event_codec.Event('shift', 1)] * 100 +
211
+ [event_codec.Event('program', 40),
212
+ event_codec.Event('velocity', 0), event_codec.Event('pitch', 62)])
213
+ expected_tokens = [codec.encode_event(e) for e in expected_events]
214
+ np.testing.assert_array_equal(expected_tokens, tokens)
215
+
216
+ expected_state_events = [
217
+ event_codec.Event('tie', 0), # state prior to first drum
218
+ event_codec.Event('tie', 0), # state prior to first onset
219
+ event_codec.Event('program', 0), # state prior to second onset
220
+ event_codec.Event('pitch', 61), # |
221
+ event_codec.Event('tie', 0), # |
222
+ event_codec.Event('program', 0), # state prior to first offset
223
+ event_codec.Event('pitch', 61), # |
224
+ event_codec.Event('program', 40), # |
225
+ event_codec.Event('pitch', 62), # |
226
+ event_codec.Event('tie', 0), # |
227
+ event_codec.Event('program', 40), # state prior to second offset
228
+ event_codec.Event('pitch', 62), # |
229
+ event_codec.Event('tie', 0) # |
230
+ ]
231
+ expected_state_tokens = [codec.encode_event(e)
232
+ for e in expected_state_events]
233
+ np.testing.assert_array_equal(expected_state_tokens, state_tokens)
234
+
235
+ self.assertEqual(event_start_indices[0], 0)
236
+ self.assertEqual(event_end_indices[0], 0)
237
+ self.assertEqual(state_event_indices[0], 0)
238
+
239
+ self.assertEqual(1.0, frame_times[1000])
240
+ self.assertEqual(event_start_indices[1000], 102)
241
+ self.assertEqual(event_end_indices[1000], 102)
242
+ self.assertEqual(state_event_indices[1000], 1)
243
+
244
+ self.assertEqual(2.0, frame_times[2000])
245
+ self.assertEqual(event_start_indices[2000], 205)
246
+ self.assertEqual(event_end_indices[2000], 205)
247
+ self.assertEqual(state_event_indices[2000], 2)
248
+
249
+ self.assertEqual(3.0, frame_times[3000])
250
+ self.assertEqual(event_start_indices[3000], 308)
251
+ self.assertEqual(event_end_indices[3000], 308)
252
+ self.assertEqual(state_event_indices[3000], 5)
253
+
254
+ self.assertEqual(3.999, frame_times[-1])
255
+ self.assertEqual(event_start_indices[-1], 410)
256
+ self.assertEqual(event_end_indices[-1], len(expected_events))
257
+ self.assertEqual(state_event_indices[-1], 10)
258
+
259
+ def test_encode_and_index_note_sequence_last_token_alignment(self):
260
+ ns = note_seq.NoteSequence()
261
+ ns.notes.add(start_time=0.0,
262
+ end_time=0.1,
263
+ pitch=60,
264
+ velocity=100)
265
+ ns.total_time = ns.notes[-1].end_time
266
+
267
+ frame_times = np.arange(0, 1.008, step=.008)
268
+
269
+ event_times, event_values = note_sequences.note_sequence_to_onsets(ns)
270
+ events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events(
271
+ state=None,
272
+ event_times=event_times,
273
+ event_values=event_values,
274
+ encode_event_fn=note_sequences.note_event_data_to_events,
275
+ codec=codec,
276
+ frame_times=frame_times)
277
+
278
+ self.assertEqual(len(frame_times), len(event_start_indices))
279
+ self.assertEqual(len(frame_times), len(event_end_indices))
280
+ self.assertLen(events, 102)
281
+ expected_events = [161] + [1] * 101
282
+
283
+ np.testing.assert_array_equal(expected_events, events)
284
+
285
+ self.assertEqual(event_start_indices[0], 0)
286
+ self.assertEqual(event_end_indices[0], 0)
287
+ self.assertEqual(event_start_indices[125], 101)
288
+ self.assertEqual(event_end_indices[125], 102)
289
+
290
+ def test_decode_note_sequence_events(self):
291
+ events = [25, 161, 50, 162]
292
+
293
+ decoding_state = note_sequences.NoteDecodingState()
294
+ invalid_ids, dropped_events = run_length_encoding.decode_events(
295
+ state=decoding_state, tokens=events, start_time=0, max_time=None,
296
+ codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
297
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
298
+
299
+ self.assertEqual(0, invalid_ids)
300
+ self.assertEqual(0, dropped_events)
301
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
302
+ expected_ns.notes.add(
303
+ pitch=60,
304
+ velocity=100,
305
+ start_time=0.25,
306
+ end_time=0.26)
307
+ expected_ns.notes.add(
308
+ pitch=61,
309
+ velocity=100,
310
+ start_time=0.50,
311
+ end_time=0.51)
312
+ expected_ns.total_time = 0.51
313
+ self.assertProtoEquals(expected_ns, ns)
314
+
315
+ def test_decode_note_sequence_events_onsets_only(self):
316
+ events = [5, 161, 25, 162]
317
+
318
+ decoding_state = note_sequences.NoteDecodingState()
319
+ invalid_ids, dropped_events = run_length_encoding.decode_events(
320
+ state=decoding_state, tokens=events, start_time=0, max_time=None,
321
+ codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
322
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
323
+
324
+ self.assertEqual(0, invalid_ids)
325
+ self.assertEqual(0, dropped_events)
326
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
327
+ expected_ns.notes.add(
328
+ pitch=60,
329
+ velocity=100,
330
+ start_time=0.05,
331
+ end_time=0.06)
332
+ expected_ns.notes.add(
333
+ pitch=61,
334
+ velocity=100,
335
+ start_time=0.25,
336
+ end_time=0.26)
337
+ expected_ns.total_time = 0.26
338
+ self.assertProtoEquals(expected_ns, ns)
339
+
340
+ def test_decode_note_sequence_events_velocity(self):
341
+ events = [5, 356, 161, 25, 229, 161]
342
+
343
+ decoding_state = note_sequences.NoteDecodingState()
344
+ invalid_ids, dropped_events = run_length_encoding.decode_events(
345
+ state=decoding_state, tokens=events, start_time=0, max_time=None,
346
+ codec=codec, decode_event_fn=note_sequences.decode_note_event)
347
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
348
+
349
+ self.assertEqual(0, invalid_ids)
350
+ self.assertEqual(0, dropped_events)
351
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
352
+ expected_ns.notes.add(
353
+ pitch=60,
354
+ velocity=127,
355
+ start_time=0.05,
356
+ end_time=0.25)
357
+ expected_ns.total_time = 0.25
358
+ self.assertProtoEquals(expected_ns, ns)
359
+
360
+ def test_decode_note_sequence_events_missing_offset(self):
361
+ events = [5, 356, 161, 10, 161, 25, 229, 161]
362
+
363
+ decoding_state = note_sequences.NoteDecodingState()
364
+ invalid_ids, dropped_events = run_length_encoding.decode_events(
365
+ state=decoding_state, tokens=events, start_time=0, max_time=None,
366
+ codec=codec, decode_event_fn=note_sequences.decode_note_event)
367
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
368
+
369
+ self.assertEqual(0, invalid_ids)
370
+ self.assertEqual(0, dropped_events)
371
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
372
+ expected_ns.notes.add(
373
+ pitch=60,
374
+ velocity=127,
375
+ start_time=0.05,
376
+ end_time=0.10)
377
+ expected_ns.notes.add(
378
+ pitch=60,
379
+ velocity=127,
380
+ start_time=0.10,
381
+ end_time=0.25)
382
+ expected_ns.total_time = 0.25
383
+ self.assertProtoEquals(expected_ns, ns)
384
+
385
+ def test_decode_note_sequence_events_multitrack(self):
386
+ events = [5, 525, 356, 161, 15, 356, 394, 25, 525, 229, 161]
387
+
388
+ decoding_state = note_sequences.NoteDecodingState()
389
+ invalid_ids, dropped_events = run_length_encoding.decode_events(
390
+ state=decoding_state, tokens=events, start_time=0, max_time=None,
391
+ codec=codec, decode_event_fn=note_sequences.decode_note_event)
392
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
393
+
394
+ self.assertEqual(0, invalid_ids)
395
+ self.assertEqual(0, dropped_events)
396
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
397
+ expected_ns.notes.add(
398
+ pitch=37,
399
+ velocity=127,
400
+ start_time=0.15,
401
+ end_time=0.16,
402
+ instrument=9,
403
+ is_drum=True)
404
+ expected_ns.notes.add(
405
+ pitch=60,
406
+ velocity=127,
407
+ start_time=0.05,
408
+ end_time=0.25,
409
+ program=40)
410
+ expected_ns.total_time = 0.25
411
+ self.assertProtoEquals(expected_ns, ns)
412
+
413
+ def test_decode_note_sequence_events_invalid_tokens(self):
414
+ events = [5, -1, 161, -2, 25, 162, 9999]
415
+
416
+ decoding_state = note_sequences.NoteDecodingState()
417
+ invalid_events, dropped_events = run_length_encoding.decode_events(
418
+ state=decoding_state, tokens=events, start_time=0, max_time=None,
419
+ codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
420
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
421
+
422
+ self.assertEqual(3, invalid_events)
423
+ self.assertEqual(0, dropped_events)
424
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
425
+ expected_ns.notes.add(
426
+ pitch=60,
427
+ velocity=100,
428
+ start_time=0.05,
429
+ end_time=0.06)
430
+ expected_ns.notes.add(
431
+ pitch=61,
432
+ velocity=100,
433
+ start_time=0.25,
434
+ end_time=0.26)
435
+ expected_ns.total_time = 0.26
436
+ self.assertProtoEquals(expected_ns, ns)
437
+
438
+ def test_decode_note_sequence_events_allow_event_at_exactly_max_time(self):
439
+ events = [161, 25, 162]
440
+
441
+ decoding_state = note_sequences.NoteDecodingState()
442
+ invalid_ids, dropped_events = run_length_encoding.decode_events(
443
+ state=decoding_state, tokens=events, start_time=1.0, max_time=1.25,
444
+ codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
445
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
446
+
447
+ self.assertEqual(0, invalid_ids)
448
+ self.assertEqual(0, dropped_events)
449
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
450
+ expected_ns.notes.add(
451
+ pitch=60,
452
+ velocity=100,
453
+ start_time=1.00,
454
+ end_time=1.01)
455
+ expected_ns.notes.add(
456
+ pitch=61,
457
+ velocity=100,
458
+ start_time=1.25,
459
+ end_time=1.26)
460
+ expected_ns.total_time = 1.26
461
+ self.assertProtoEquals(expected_ns, ns)
462
+
463
+ def test_decode_note_sequence_events_dropped_events(self):
464
+ events = [5, 161, 30, 162]
465
+
466
+ decoding_state = note_sequences.NoteDecodingState()
467
+ invalid_ids, dropped_events = run_length_encoding.decode_events(
468
+ state=decoding_state, tokens=events, start_time=1.0, max_time=1.25,
469
+ codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
470
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
471
+
472
+ self.assertEqual(0, invalid_ids)
473
+ self.assertEqual(2, dropped_events)
474
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
475
+ expected_ns.notes.add(
476
+ pitch=60,
477
+ velocity=100,
478
+ start_time=1.05,
479
+ end_time=1.06)
480
+ expected_ns.total_time = 1.06
481
+ self.assertProtoEquals(expected_ns, ns)
482
+
483
+ def test_decode_note_sequence_events_invalid_events(self):
484
+ events = [25, 230, 50, 161]
485
+
486
+ decoding_state = note_sequences.NoteDecodingState()
487
+ invalid_ids, dropped_events = run_length_encoding.decode_events(
488
+ state=decoding_state, tokens=events, start_time=0, max_time=None,
489
+ codec=codec, decode_event_fn=note_sequences.decode_note_onset_event)
490
+ ns = note_sequences.flush_note_decoding_state(decoding_state)
491
+
492
+ self.assertEqual(1, invalid_ids)
493
+ self.assertEqual(0, dropped_events)
494
+ expected_ns = note_seq.NoteSequence(ticks_per_quarter=220)
495
+ expected_ns.notes.add(
496
+ pitch=60,
497
+ velocity=100,
498
+ start_time=0.50,
499
+ end_time=0.51)
500
+ expected_ns.total_time = 0.51
501
+ self.assertProtoEquals(expected_ns, ns)
502
+
503
+
504
+ if __name__ == '__main__':
505
+ tf.test.main()
mt3/preprocessors.py ADDED
@@ -0,0 +1,669 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Transcription preprocessors."""
16
+
17
+ from typing import Any, Callable, Mapping, Optional, Sequence, Tuple
18
+
19
+ from absl import logging
20
+ import gin
21
+ from immutabledict import immutabledict
22
+ import librosa
23
+
24
+ from mt3 import event_codec
25
+ from mt3 import note_sequences
26
+ from mt3 import run_length_encoding
27
+ from mt3 import spectrograms
28
+ from mt3 import vocabularies
29
+
30
+ import note_seq
31
+ import numpy as np
32
+ import seqio
33
+ import tensorflow as tf
34
+
35
+
36
+ def add_unique_id(ds: tf.data.Dataset) -> tf.data.Dataset:
37
+ """Add unique integer ID to each example in a dataset."""
38
+ def add_id_field(i, ex):
39
+ ex['unique_id'] = [i]
40
+ return ex
41
+ return ds.enumerate().map(
42
+ add_id_field, num_parallel_calls=tf.data.experimental.AUTOTUNE)
43
+
44
+
45
+ @seqio.map_over_dataset
46
+ def pad_notesequence_array(ex):
47
+ """Pad the NoteSequence array so that it can later be "split"."""
48
+ ex['sequence'] = tf.pad(tf.expand_dims(ex['sequence'], 0),
49
+ [[0, len(ex['input_times']) - 1]])
50
+ return ex
51
+
52
+
53
+ @seqio.map_over_dataset
54
+ def add_dummy_targets(ex):
55
+ """Add dummy targets; used in eval when targets are not actually used."""
56
+ ex['targets'] = np.array([], dtype=np.int32)
57
+ return ex
58
+
59
+
60
+ def _audio_to_frames(
61
+ samples: Sequence[float],
62
+ spectrogram_config: spectrograms.SpectrogramConfig,
63
+ ) -> Tuple[Sequence[Sequence[int]], np.ndarray]:
64
+ """Convert audio samples to non-overlapping frames and frame times."""
65
+ frame_size = spectrogram_config.hop_width
66
+ logging.info('Padding %d samples to multiple of %d', len(samples), frame_size)
67
+ samples = np.pad(samples,
68
+ [0, frame_size - len(samples) % frame_size],
69
+ mode='constant')
70
+
71
+ frames = spectrograms.split_audio(samples, spectrogram_config)
72
+
73
+ num_frames = len(samples) // frame_size
74
+ logging.info('Encoded %d samples to %d frames (%d samples each)',
75
+ len(samples), num_frames, frame_size)
76
+
77
+ times = np.arange(num_frames) / spectrogram_config.frames_per_second
78
+ return frames, times
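The framing arithmetic, with hypothetical numbers (the hop width and sample rate below are placeholders for this sketch, not necessarily the configured MT3 defaults):

import numpy as np

hop_width = 128           # samples per frame (hypothetical)
sample_rate = 16000       # Hz (hypothetical)
frames_per_second = sample_rate / hop_width   # 125.0

samples = np.zeros(16000)                     # 1.0 s of audio
pad = hop_width - len(samples) % hop_width    # 128: a full extra frame gets padded on
samples = np.pad(samples, [0, pad], mode='constant')
num_frames = len(samples) // hop_width        # 126
times = np.arange(num_frames) / frames_per_second  # 0.0, 0.008, ..., 1.0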
79
+
80
+
81
+ def _include_inputs(ds, input_record, fields_to_omit=('audio',)):
82
+ """Include fields from input record (other than audio) in dataset records."""
83
+ def include_inputs_fn(output_record):
84
+ for key in set(input_record.keys()) - set(output_record.keys()):
85
+ output_record[key] = input_record[key]
86
+ for key in fields_to_omit:
87
+ del output_record[key]
88
+ return output_record
89
+ return ds.map(include_inputs_fn,
90
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
91
+
92
+
93
+ def tokenize_transcription_example(
94
+ ds: tf.data.Dataset, spectrogram_config: spectrograms.SpectrogramConfig,
95
+ codec: event_codec.Codec, is_training_data: bool,
96
+ onsets_only: bool, include_ties: bool, audio_is_samples: bool,
97
+ id_feature_key: Optional[str] = None
98
+ ) -> tf.data.Dataset:
99
+ """Tokenize a note transcription example for run-length encoding.
100
+
101
+ Outputs include:
102
+ inputs: audio sample frames, num_frames-by-frame_size
103
+ input_times: timestamp for each frame
104
+ targets: symbolic sequence of note-related events
105
+ input_event_start_indices: start target index for every input index
106
+ input_event_end_indices: end target index for every input index
107
+
108
+ Args:
109
+ ds: Input dataset.
110
+ spectrogram_config: Spectrogram configuration.
111
+ codec: Event vocabulary codec.
112
+ is_training_data: Unused.
113
+ onsets_only: If True, include only onset events (not offset, velocity, or
114
+ program).
115
+ include_ties: If True, also write state events containing active notes to
116
+ support a "tie" section after run-length encoding.
117
+ audio_is_samples: If True, audio is floating-point samples instead of
118
+ serialized WAV.
119
+ id_feature_key: If not None, replace sequence ID with specified key field
120
+ from the dataset.
121
+
122
+ Returns:
123
+ Dataset with the outputs described above.
124
+ """
125
+ del is_training_data
126
+
127
+ if onsets_only and include_ties:
128
+ raise ValueError('Ties not supported when only modeling onsets.')
129
+
130
+ def tokenize(sequence, audio, sample_rate, example_id=None):
131
+ ns = note_seq.NoteSequence.FromString(sequence)
132
+ note_sequences.validate_note_sequence(ns)
133
+
134
+ if example_id is not None:
135
+ ns.id = example_id
136
+
137
+ if audio_is_samples:
138
+ samples = audio
139
+ if sample_rate != spectrogram_config.sample_rate:
140
+ samples = librosa.resample(
141
+ samples, sample_rate, spectrogram_config.sample_rate)
142
+ else:
143
+ samples = note_seq.audio_io.wav_data_to_samples_librosa(
144
+ audio, sample_rate=spectrogram_config.sample_rate)
145
+
146
+ logging.info('Got samples for %s::%s with length %d',
147
+ ns.id, ns.filename, len(samples))
148
+
149
+ frames, frame_times = _audio_to_frames(samples, spectrogram_config)
150
+
151
+ if onsets_only:
152
+ times, values = note_sequences.note_sequence_to_onsets(ns)
153
+ else:
154
+ ns = note_seq.apply_sustain_control_changes(ns)
155
+ times, values = (
156
+ note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))
157
+
158
+ # The original NoteSequence can have a lot of control changes we don't need;
159
+ # delete them.
160
+ del ns.control_changes[:]
161
+
162
+ (events, event_start_indices, event_end_indices,
163
+ state_events, state_event_indices) = (
164
+ run_length_encoding.encode_and_index_events(
165
+ state=note_sequences.NoteEncodingState() if include_ties else None,
166
+ event_times=times,
167
+ event_values=values,
168
+ encode_event_fn=note_sequences.note_event_data_to_events,
169
+ codec=codec,
170
+ frame_times=frame_times,
171
+ encoding_state_to_events_fn=(
172
+ note_sequences.note_encoding_state_to_events
173
+ if include_ties else None)))
174
+
175
+ yield {
176
+ 'inputs': frames,
177
+ 'input_times': frame_times,
178
+ 'targets': events,
179
+ 'input_event_start_indices': event_start_indices,
180
+ 'input_event_end_indices': event_end_indices,
181
+ 'state_events': state_events,
182
+ 'input_state_event_indices': state_event_indices,
183
+ 'sequence': ns.SerializeToString()
184
+ }
185
+
186
+ def process_record(input_record):
187
+ if audio_is_samples and 'sample_rate' not in input_record:
188
+ raise ValueError('Must provide sample rate when audio is samples.')
189
+
190
+ args = [
191
+ input_record['sequence'],
192
+ input_record['audio'],
193
+ input_record['sample_rate'] if 'sample_rate' in input_record else 0
194
+ ]
195
+ if id_feature_key is not None:
196
+ args.append(input_record[id_feature_key])
197
+
198
+ ds = tf.data.Dataset.from_generator(
199
+ tokenize,
200
+ output_signature={
201
+ 'inputs':
202
+ tf.TensorSpec(
203
+ shape=(None, spectrogram_config.hop_width),
204
+ dtype=tf.float32),
205
+ 'input_times':
206
+ tf.TensorSpec(shape=(None,), dtype=tf.float32),
207
+ 'targets':
208
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
209
+ 'input_event_start_indices':
210
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
211
+ 'input_event_end_indices':
212
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
213
+ 'state_events':
214
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
215
+ 'input_state_event_indices':
216
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
217
+ 'sequence':
218
+ tf.TensorSpec(shape=(), dtype=tf.string)
219
+ },
220
+ args=args)
221
+
222
+ ds = _include_inputs(ds, input_record)
223
+ return ds
224
+
225
+ tokenized_records = ds.flat_map(process_record)
226
+ return tokenized_records
227
+
228
+
229
+ def tokenize_guitarset_example(
230
+ ds: tf.data.Dataset, spectrogram_config: spectrograms.SpectrogramConfig,
231
+ codec: event_codec.Codec, is_training_data: bool,
232
+ onsets_only: bool, include_ties: bool
233
+ ) -> tf.data.Dataset:
234
+ """Tokenize a GuitarSet transcription example."""
235
+ def _preprocess_example(ex, name):
236
+ assert 'inst_names' not in ex, 'Key `inst_names` is already populated.'
237
+ ex['inst_names'] = [name]
238
+ ex['instrument_sequences'] = [ex.pop('sequence')]
239
+ return ex
240
+
241
+ ds = ds.map(
242
+ lambda x: _preprocess_example(x, 'Clean Guitar'),
243
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
244
+ ds = tokenize_example_with_program_lookup(
245
+ ds,
246
+ spectrogram_config=spectrogram_config,
247
+ codec=codec,
248
+ is_training_data=is_training_data,
249
+ inst_name_to_program_fn=guitarset_instrument_to_program,
250
+ onsets_only=onsets_only,
251
+ include_ties=include_ties,
252
+ id_feature_key='id')
253
+ return ds
254
+
255
+
256
+ def guitarset_instrument_to_program(instrument: str) -> int:
257
+ """GuitarSet is all guitar, return the first MIDI guitar program."""
258
+ if instrument == 'Clean Guitar':
259
+ return 24
260
+ else:
261
+ raise ValueError('Unknown GuitarSet instrument: %s' % instrument)
262
+
263
+
264
+ def tokenize_example_with_program_lookup(
265
+ ds: tf.data.Dataset,
266
+ spectrogram_config: spectrograms.SpectrogramConfig,
267
+ codec: event_codec.Codec,
268
+ is_training_data: bool,
269
+ onsets_only: bool,
270
+ include_ties: bool,
271
+ inst_name_to_program_fn: Callable[[str], int],
272
+ id_feature_key: Optional[str] = None
273
+ ) -> tf.data.Dataset:
274
+ """Tokenize an example, optionally looking up and assigning program numbers.
275
+
276
+ This can be used by any dataset where a mapping function can be used to
277
+ map from the inst_names feature to a set of program numbers.
278
+
279
+ Args:
280
+ ds: Input dataset.
281
+ spectrogram_config: Spectrogram configuration.
282
+ codec: Event vocabulary codec.
283
+ is_training_data: Unused.
284
+ onsets_only: If True, include only onset events (not offset & velocity).
285
+ include_ties: If True, include tie events.
286
+ inst_name_to_program_fn: A function used to map the instrument names
287
+ in the `inst_names` feature of each example to a MIDI program number.
288
+ id_feature_key: If not None, replace sequence ID with specified key field
289
+ from the dataset.
290
+
291
+ Returns:
292
+ Dataset with the outputs described above.
293
+ """
294
+ del is_training_data
295
+
296
+ def tokenize(sequences, inst_names, audio, example_id=None):
297
+ # Add all the notes from the tracks to a single NoteSequence.
298
+ ns = note_seq.NoteSequence(ticks_per_quarter=220)
299
+ tracks = [note_seq.NoteSequence.FromString(seq) for seq in sequences]
300
+ assert len(tracks) == len(inst_names)
301
+ for track, inst_name in zip(tracks, inst_names):
302
+ program = inst_name_to_program_fn(
303
+ inst_name.decode())
304
+
305
+ # Note that there are no pitch bends in URMP data; the below block will
306
+ # raise PitchBendError if one is encountered.
307
+ add_track_to_notesequence(ns, track, program=program, is_drum=False,
308
+ ignore_pitch_bends=False)
309
+
310
+ note_sequences.assign_instruments(ns)
311
+ note_sequences.validate_note_sequence(ns)
312
+
313
+ if example_id is not None:
314
+ ns.id = example_id
315
+
316
+ samples = note_seq.audio_io.wav_data_to_samples_librosa(
317
+ audio, sample_rate=spectrogram_config.sample_rate)
318
+
319
+ logging.info('Got samples for %s::%s with length %d',
320
+ ns.id, ns.filename, len(samples))
321
+
322
+ frames, frame_times = _audio_to_frames(samples, spectrogram_config)
323
+
324
+ if onsets_only:
325
+ times, values = note_sequences.note_sequence_to_onsets(ns)
326
+ else:
327
+ times, values = (
328
+ note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))
329
+
330
+ # The original NoteSequence can have a lot of control changes we don't need;
331
+ # delete them.
332
+ del ns.control_changes[:]
333
+
334
+ (events, event_start_indices, event_end_indices,
335
+ state_events, state_event_indices) = (
336
+ run_length_encoding.encode_and_index_events(
337
+ state=note_sequences.NoteEncodingState() if include_ties else None,
338
+ event_times=times,
339
+ event_values=values,
340
+ encode_event_fn=note_sequences.note_event_data_to_events,
341
+ codec=codec,
342
+ frame_times=frame_times,
343
+ encoding_state_to_events_fn=(
344
+ note_sequences.note_encoding_state_to_events
345
+ if include_ties else None)))
346
+
347
+ yield {
348
+ 'inputs': frames,
349
+ 'input_times': frame_times,
350
+ 'targets': events,
351
+ 'input_event_start_indices': event_start_indices,
352
+ 'input_event_end_indices': event_end_indices,
353
+ 'state_events': state_events,
354
+ 'input_state_event_indices': state_event_indices,
355
+ 'sequence': ns.SerializeToString()
356
+ }
357
+
358
+ def process_record(input_record):
359
+ args = [
360
+ input_record['instrument_sequences'],
361
+ input_record['inst_names'],
362
+ input_record['audio'],
363
+ ]
364
+ if id_feature_key is not None:
365
+ args.append(input_record[id_feature_key])
366
+
367
+ ds = tf.data.Dataset.from_generator(
368
+ tokenize,
369
+ output_signature={
370
+ 'inputs':
371
+ tf.TensorSpec(
372
+ shape=(None, spectrogram_config.hop_width),
373
+ dtype=tf.float32),
374
+ 'input_times':
375
+ tf.TensorSpec(shape=(None,), dtype=tf.float32),
376
+ 'targets':
377
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
378
+ 'input_event_start_indices':
379
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
380
+ 'input_event_end_indices':
381
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
382
+ 'state_events':
383
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
384
+ 'input_state_event_indices':
385
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
386
+ 'sequence':
387
+ tf.TensorSpec(shape=(), dtype=tf.string)
388
+ },
389
+ args=args)
390
+
391
+ ds = _include_inputs(ds, input_record)
392
+ return ds
393
+
394
+ tokenized_records = ds.flat_map(process_record)
395
+ return tokenized_records
396
+
397
+
398
+ _URMP_INSTRUMENT_PROGRAMS = immutabledict({
399
+ 'vn': 40, # violin
400
+ 'va': 41, # viola
401
+ 'vc': 42, # cello
402
+ 'db': 43, # double bass
403
+ 'tpt': 56, # trumpet
404
+ 'tbn': 57, # trombone
405
+ 'tba': 58, # tuba
406
+ 'hn': 60, # French horn
407
+ 'sax': 64, # saxophone
408
+ 'ob': 68, # oboe
409
+ 'bn': 70, # bassoon
410
+ 'cl': 71, # clarinet
411
+ 'fl': 73 # flute
412
+ })
413
+
414
+
415
+ def urmp_instrument_to_program(urmp_instrument: str) -> int:
416
+ """Fetch the program number associated with a given URMP instrument code."""
417
+ if urmp_instrument not in _URMP_INSTRUMENT_PROGRAMS:
418
+ raise ValueError('unknown URMP instrument: %s' % urmp_instrument)
419
+ return _URMP_INSTRUMENT_PROGRAMS[urmp_instrument]
420
+
421
+
422
+ _SLAKH_CLASS_PROGRAMS = immutabledict({
423
+ 'Acoustic Piano': 0,
424
+ 'Electric Piano': 4,
425
+ 'Chromatic Percussion': 8,
426
+ 'Organ': 16,
427
+ 'Acoustic Guitar': 24,
428
+ 'Clean Electric Guitar': 26,
429
+ 'Distorted Electric Guitar': 29,
430
+ 'Acoustic Bass': 32,
431
+ 'Electric Bass': 33,
432
+ 'Violin': 40,
433
+ 'Viola': 41,
434
+ 'Cello': 42,
435
+ 'Contrabass': 43,
436
+ 'Orchestral Harp': 46,
437
+ 'Timpani': 47,
438
+ 'String Ensemble': 48,
439
+ 'Synth Strings': 50,
440
+ 'Choir and Voice': 52,
441
+ 'Orchestral Hit': 55,
442
+ 'Trumpet': 56,
443
+ 'Trombone': 57,
444
+ 'Tuba': 58,
445
+ 'French Horn': 60,
446
+ 'Brass Section': 61,
447
+ 'Soprano/Alto Sax': 64,
448
+ 'Tenor Sax': 66,
449
+ 'Baritone Sax': 67,
450
+ 'Oboe': 68,
451
+ 'English Horn': 69,
452
+ 'Bassoon': 70,
453
+ 'Clarinet': 71,
454
+ 'Pipe': 73,
455
+ 'Synth Lead': 80,
456
+ 'Synth Pad': 88
457
+ })
458
+
459
+
460
+ def slakh_class_to_program_and_is_drum(slakh_class: str) -> Tuple[int, bool]:
461
+ """Map Slakh class string to program number and boolean indicating drums."""
462
+ if slakh_class == 'Drums':
463
+ return 0, True
464
+ elif slakh_class not in _SLAKH_CLASS_PROGRAMS:
465
+ raise ValueError('unknown Slakh class: %s' % slakh_class)
466
+ else:
467
+ return _SLAKH_CLASS_PROGRAMS[slakh_class], False
468
+
469
+
470
+ class PitchBendError(Exception):
471
+ pass
472
+
473
+
474
+ def add_track_to_notesequence(ns: note_seq.NoteSequence,
475
+ track: note_seq.NoteSequence,
476
+ program: int, is_drum: bool,
477
+ ignore_pitch_bends: bool):
478
+ """Add a track to a NoteSequence."""
479
+ if track.pitch_bends and not ignore_pitch_bends:
480
+ raise PitchBendError
481
+ track_sus = note_seq.apply_sustain_control_changes(track)
482
+ for note in track_sus.notes:
483
+ note.program = program
484
+ note.is_drum = is_drum
485
+ ns.notes.extend([note])
486
+ ns.total_time = max(ns.total_time, note.end_time)
487
+
488
+
489
+ def tokenize_slakh_example(
490
+ ds: tf.data.Dataset,
491
+ spectrogram_config: spectrograms.SpectrogramConfig,
492
+ codec: event_codec.Codec,
493
+ is_training_data: bool,
494
+ onsets_only: bool,
495
+ include_ties: bool,
496
+ track_specs: Optional[Sequence[note_sequences.TrackSpec]],
497
+ ignore_pitch_bends: bool
498
+ ) -> tf.data.Dataset:
499
+ """Tokenize a Slakh multitrack note transcription example."""
500
+ def tokenize(sequences, samples, sample_rate, inst_names, example_id):
501
+ if sample_rate != spectrogram_config.sample_rate:
502
+ samples = librosa.resample(
503
+ samples, sample_rate, spectrogram_config.sample_rate)
504
+
505
+ frames, frame_times = _audio_to_frames(samples, spectrogram_config)
506
+
507
+ # Add all the notes from the tracks to a single NoteSequence.
508
+ ns = note_seq.NoteSequence(ticks_per_quarter=220)
509
+ tracks = [note_seq.NoteSequence.FromString(seq) for seq in sequences]
510
+ assert len(tracks) == len(inst_names)
511
+ if track_specs:
512
+ # Specific tracks expected.
513
+ assert len(tracks) == len(track_specs)
514
+ for track, spec, inst_name in zip(tracks, track_specs, inst_names):
515
+ # Make sure the instrument name matches what we expect.
516
+ assert inst_name.decode() == spec.name
517
+ try:
518
+ add_track_to_notesequence(ns, track,
519
+ program=spec.program, is_drum=spec.is_drum,
520
+ ignore_pitch_bends=ignore_pitch_bends)
521
+ except PitchBendError:
522
+ # TODO(iansimon): is there a way to count these?
523
+ return
524
+ else:
525
+ for track, inst_name in zip(tracks, inst_names):
526
+ # Instrument name should be Slakh class.
527
+ program, is_drum = slakh_class_to_program_and_is_drum(
528
+ inst_name.decode())
529
+ try:
530
+ add_track_to_notesequence(ns, track, program=program, is_drum=is_drum,
531
+ ignore_pitch_bends=ignore_pitch_bends)
532
+ except PitchBendError:
533
+ # TODO(iansimon): is there a way to count these?
534
+ return
535
+
536
+ note_sequences.assign_instruments(ns)
537
+ note_sequences.validate_note_sequence(ns)
538
+ if is_training_data:
539
+ # Trim overlapping notes in training (as our event vocabulary cannot
540
+ # represent them), but preserve original NoteSequence for eval.
541
+ ns = note_sequences.trim_overlapping_notes(ns)
542
+
543
+ ns.id = example_id
544
+
545
+ if onsets_only:
546
+ times, values = note_sequences.note_sequence_to_onsets(ns)
547
+ else:
548
+ times, values = (
549
+ note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns))
550
+
551
+ (events, event_start_indices, event_end_indices,
552
+ state_events, state_event_indices) = (
553
+ run_length_encoding.encode_and_index_events(
554
+ state=note_sequences.NoteEncodingState() if include_ties else None,
555
+ event_times=times,
556
+ event_values=values,
557
+ encode_event_fn=note_sequences.note_event_data_to_events,
558
+ codec=codec,
559
+ frame_times=frame_times,
560
+ encoding_state_to_events_fn=(
561
+ note_sequences.note_encoding_state_to_events
562
+ if include_ties else None)))
563
+
564
+ yield {
565
+ 'inputs': frames,
566
+ 'input_times': frame_times,
567
+ 'targets': events,
568
+ 'input_event_start_indices': event_start_indices,
569
+ 'input_event_end_indices': event_end_indices,
570
+ 'state_events': state_events,
571
+ 'input_state_event_indices': state_event_indices,
572
+ 'sequence': ns.SerializeToString()
573
+ }
574
+
575
+ def process_record(input_record):
576
+ ds = tf.data.Dataset.from_generator(
577
+ tokenize,
578
+ output_signature={
579
+ 'inputs':
580
+ tf.TensorSpec(
581
+ shape=(None, spectrogram_config.hop_width),
582
+ dtype=tf.float32),
583
+ 'input_times':
584
+ tf.TensorSpec(shape=(None,), dtype=tf.float32),
585
+ 'targets':
586
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
587
+ 'input_event_start_indices':
588
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
589
+ 'input_event_end_indices':
590
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
591
+ 'state_events':
592
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
593
+ 'input_state_event_indices':
594
+ tf.TensorSpec(shape=(None,), dtype=tf.int32),
595
+ 'sequence':
596
+ tf.TensorSpec(shape=(), dtype=tf.string)
597
+ },
598
+ args=[
599
+ input_record['note_sequences'], input_record['mix'],
600
+ input_record['audio_sample_rate'], input_record['inst_names'],
601
+ input_record['track_id']
602
+ ])
603
+
604
+ ds = _include_inputs(ds, input_record, fields_to_omit=['mix', 'stems'])
605
+ return ds
606
+
607
+ tokenized_records = ds.flat_map(process_record)
608
+ return tokenized_records
609
+
610
+
611
+
612
+
613
+ @seqio.map_over_dataset
614
+ def compute_spectrograms(ex, spectrogram_config):
615
+ samples = spectrograms.flatten_frames(ex['inputs'])
616
+ ex['inputs'] = spectrograms.compute_spectrogram(samples, spectrogram_config)
617
+ ex['raw_inputs'] = samples
618
+ return ex
619
+
620
+
621
+ def handle_too_long(dataset: tf.data.Dataset,
622
+ output_features: seqio.preprocessors.OutputFeaturesType,
623
+ sequence_length: seqio.preprocessors.SequenceLengthType,
624
+ skip: bool = False) -> tf.data.Dataset:
625
+ """Handle sequences that are too long, by either failing or skipping them."""
626
+ def max_length_for_key(key):
627
+ max_length = sequence_length[key]
628
+ if output_features[key].add_eos:
629
+ max_length -= 1
630
+ return max_length
631
+
632
+ if skip:
633
+ # Drop examples where one of the features is longer than its maximum
634
+ # sequence length.
635
+ def is_not_too_long(ex):
636
+ return not tf.reduce_any(
637
+ [k in output_features and len(v) > max_length_for_key(k)
638
+ for k, v in ex.items()])
639
+ dataset = dataset.filter(is_not_too_long)
640
+
641
+ def assert_not_too_long(key: str, value: tf.Tensor) -> tf.Tensor:
642
+ if key in output_features:
643
+ max_length = max_length_for_key(key)
644
+ tf.debugging.assert_less_equal(
645
+ tf.shape(value)[0], max_length,
646
+ f'Value for "{key}" field exceeds maximum length')
647
+ return value
648
+
649
+ # Assert that no examples have features longer than their maximum sequence
650
+ # length.
651
+ return dataset.map(
652
+ lambda ex: {k: assert_not_too_long(k, v) for k, v in ex.items()},
653
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
654
+
655
+
656
+ @gin.configurable
657
+ def map_midi_programs(
658
+ ds: tf.data.Dataset,
659
+ codec: event_codec.Codec,
660
+ granularity_type: str = 'full',
661
+ feature_key: str = 'targets'
662
+ ) -> tf.data.Dataset:
663
+ """Apply MIDI program map to token sequences."""
664
+ granularity = vocabularies.PROGRAM_GRANULARITIES[granularity_type]
665
+ def _map_program_tokens(ex):
666
+ ex[feature_key] = granularity.tokens_map_fn(ex[feature_key], codec)
667
+ return ex
668
+ return ds.map(_map_program_tokens,
669
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
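[Editor's note] The preprocessors above (the tokenize generators, compute_spectrograms, handle_too_long, map_midi_programs) all operate on the same per-example feature dictionary. As a reading aid only, here is a minimal sketch, not part of the diff, of that dictionary with illustrative shapes matching the output_signature declared in process_record; the concrete sizes (3 seconds of 16 kHz audio, hop width 128) are assumptions for the example.

import numpy as np

num_frames = 375   # assumed: 3 s of 16 kHz audio framed with hop_width=128
hop_width = 128
num_events = 42    # arbitrary number of encoded event tokens

example = {
    'inputs': np.zeros((num_frames, hop_width), np.float32),      # raw audio frames
    'input_times': np.zeros(num_frames, np.float32),              # time of each frame
    'targets': np.zeros(num_events, np.int32),                    # encoded event tokens
    'input_event_start_indices': np.zeros(num_frames, np.int32),  # per-frame start index
    'input_event_end_indices': np.zeros(num_frames, np.int32),    # per-frame end index
    'state_events': np.zeros(0, np.int32),                        # empty unless include_ties
    'input_state_event_indices': np.zeros(num_frames, np.int32),
    'sequence': b'',                                               # serialized NoteSequence
}
# compute_spectrograms later flattens 'inputs' back to samples, replaces it with a
# log-mel spectrogram, and keeps the raw samples under 'raw_inputs'.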
mt3/pytest.ini ADDED
@@ -0,0 +1,3 @@
1
+ [pytest]
2
+ python_files = *_test.py
3
+ log_level = INFO
mt3/run_length_encoding.py ADDED
@@ -0,0 +1,423 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tools for run length encoding."""
16
+
17
+ import dataclasses
18
+ from typing import Any, Callable, Mapping, MutableMapping, Tuple, Optional, Sequence, TypeVar
19
+
20
+ from absl import logging
21
+ from mt3 import event_codec
22
+
23
+ import numpy as np
24
+ import seqio
25
+ import tensorflow as tf
26
+
27
+ Event = event_codec.Event
28
+
29
+ # These should be type variables, but unfortunately those are incompatible with
30
+ # dataclasses.
31
+ EventData = Any
32
+ EncodingState = Any
33
+ DecodingState = Any
34
+ DecodeResult = Any
35
+
36
+ T = TypeVar('T', bound=EventData)
37
+ ES = TypeVar('ES', bound=EncodingState)
38
+ DS = TypeVar('DS', bound=DecodingState)
39
+
40
+
41
+ @dataclasses.dataclass
42
+ class EventEncodingSpec:
43
+ """Spec for encoding events."""
44
+ # initialize encoding state
45
+ init_encoding_state_fn: Callable[[], EncodingState]
46
+ # convert EventData into zero or more events, updating encoding state
47
+ encode_event_fn: Callable[[EncodingState, EventData, event_codec.Codec],
48
+ Sequence[event_codec.Event]]
49
+ # convert encoding state (at beginning of segment) into events
50
+ encoding_state_to_events_fn: Optional[Callable[[EncodingState],
51
+ Sequence[event_codec.Event]]]
52
+ # create empty decoding state
53
+ init_decoding_state_fn: Callable[[], DecodingState]
54
+ # update decoding state when entering new segment
55
+ begin_decoding_segment_fn: Callable[[DecodingState], None]
56
+ # consume time and Event and update decoding state
57
+ decode_event_fn: Callable[
58
+ [DecodingState, float, event_codec.Event, event_codec.Codec], None]
59
+ # flush decoding state into result
60
+ flush_decoding_state_fn: Callable[[DecodingState], DecodeResult]
61
+
62
+
63
+ def encode_and_index_events(
64
+ state: ES,
65
+ event_times: Sequence[float],
66
+ event_values: Sequence[T],
67
+ encode_event_fn: Callable[[ES, T, event_codec.Codec],
68
+ Sequence[event_codec.Event]],
69
+ codec: event_codec.Codec,
70
+ frame_times: Sequence[float],
71
+ encoding_state_to_events_fn: Optional[
72
+ Callable[[ES], Sequence[event_codec.Event]]] = None,
73
+ ) -> Tuple[Sequence[int], Sequence[int], Sequence[int],
74
+ Sequence[int], Sequence[int]]:
75
+ """Encode a sequence of timed events and index to audio frame times.
76
+
77
+ Encodes time shifts as repeated single step shifts for later run length
78
+ encoding.
79
+
80
+ Optionally, also encodes a sequence of "state events", keeping track of the
81
+ current encoding state at each audio frame. This can be used e.g. to prepend
82
+ events representing the current state to a targets segment.
83
+
84
+ Args:
85
+ state: Initial event encoding state.
86
+ event_times: Sequence of event times.
87
+ event_values: Sequence of event values.
88
+ encode_event_fn: Function that transforms event value into a sequence of one
89
+ or more event_codec.Event objects.
90
+ codec: An event_codec.Codec object that maps Event objects to indices.
91
+ frame_times: Time for every audio frame.
92
+ encoding_state_to_events_fn: Function that transforms encoding state into a
93
+ sequence of one or more event_codec.Event objects.
94
+
95
+ Returns:
96
+ events: Encoded events and shifts.
97
+ event_start_indices: Corresponding start event index for every audio frame.
98
+ Note: one event can correspond to multiple audio indices due to sampling
99
+ rate differences. This makes splitting sequences tricky because the same
100
+ event can appear at the end of one sequence and the beginning of
101
+ another.
102
+ event_end_indices: Corresponding end event index for every audio frame. Used
103
+ to ensure when slicing that one chunk ends where the next begins. Should
104
+ always be true that event_end_indices[i] = event_start_indices[i + 1].
105
+ state_events: Encoded "state" events representing the encoding state before
106
+ each event.
107
+ state_event_indices: Corresponding state event index for every audio frame.
108
+ """
109
+ indices = np.argsort(event_times, kind='stable')
110
+ event_steps = [round(event_times[i] * codec.steps_per_second)
111
+ for i in indices]
112
+ event_values = [event_values[i] for i in indices]
113
+
114
+ events = []
115
+ state_events = []
116
+ event_start_indices = []
117
+ state_event_indices = []
118
+
119
+ cur_step = 0
120
+ cur_event_idx = 0
121
+ cur_state_event_idx = 0
122
+
123
+ def fill_event_start_indices_to_cur_step():
124
+ while(len(event_start_indices) < len(frame_times) and
125
+ frame_times[len(event_start_indices)] <
126
+ cur_step / codec.steps_per_second):
127
+ event_start_indices.append(cur_event_idx)
128
+ state_event_indices.append(cur_state_event_idx)
129
+
130
+ for event_step, event_value in zip(event_steps, event_values):
131
+ while event_step > cur_step:
132
+ events.append(codec.encode_event(Event(type='shift', value=1)))
133
+ cur_step += 1
134
+ fill_event_start_indices_to_cur_step()
135
+ cur_event_idx = len(events)
136
+ cur_state_event_idx = len(state_events)
137
+ if encoding_state_to_events_fn:
138
+ # Dump state to state events *before* processing the next event, because
139
+ # we want to capture the state prior to the occurrence of the event.
140
+ for e in encoding_state_to_events_fn(state):
141
+ state_events.append(codec.encode_event(e))
142
+ for e in encode_event_fn(state, event_value, codec):
143
+ events.append(codec.encode_event(e))
144
+
145
+ # After the last event, continue filling out the event_start_indices array.
146
+ # The inequality is not strict because if our current step lines up exactly
147
+ # with (the start of) an audio frame, we need to add an additional shift event
148
+ # to "cover" that frame.
149
+ while cur_step / codec.steps_per_second <= frame_times[-1]:
150
+ events.append(codec.encode_event(Event(type='shift', value=1)))
151
+ cur_step += 1
152
+ fill_event_start_indices_to_cur_step()
153
+ cur_event_idx = len(events)
154
+
155
+ # Now fill in event_end_indices. We need this extra array to make sure that
156
+ # when we slice events, each slice ends exactly where the subsequent slice
157
+ # begins.
158
+ event_end_indices = event_start_indices[1:] + [len(events)]
159
+
160
+ events = np.array(events)
161
+ state_events = np.array(state_events)
162
+ event_start_indices = np.array(event_start_indices)
163
+ event_end_indices = np.array(event_end_indices)
164
+ state_event_indices = np.array(state_event_indices)
165
+
166
+ return (events, event_start_indices, event_end_indices,
167
+ state_events, state_event_indices)
168
+
169
+
170
+ @seqio.map_over_dataset
171
+ def extract_target_sequence_with_indices(features, state_events_end_token=None):
172
+ """Extract target sequence corresponding to audio token segment."""
173
+ target_start_idx = features['input_event_start_indices'][0]
174
+ target_end_idx = features['input_event_end_indices'][-1]
175
+
176
+ features['targets'] = features['targets'][target_start_idx:target_end_idx]
177
+
178
+ if state_events_end_token is not None:
179
+ # Extract the state events corresponding to the audio start token, and
180
+ # prepend them to the targets array.
181
+ state_event_start_idx = features['input_state_event_indices'][0]
182
+ state_event_end_idx = state_event_start_idx + 1
183
+ while features['state_events'][
184
+ state_event_end_idx - 1] != state_events_end_token:
185
+ state_event_end_idx += 1
186
+ features['targets'] = tf.concat([
187
+ features['state_events'][state_event_start_idx:state_event_end_idx],
188
+ features['targets']
189
+ ], axis=0)
190
+
191
+ return features
192
+
193
+
194
+ def remove_redundant_state_changes_fn(
195
+ codec: event_codec.Codec,
196
+ feature_key: str = 'targets',
197
+ state_change_event_types: Sequence[str] = ()
198
+ ) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]:
199
+ """Return preprocessing function that removes redundant state change events.
200
+
201
+ Args:
202
+ codec: The event_codec.Codec used to interpret the events.
203
+ feature_key: The feature key for which to remove redundant state changes.
204
+ state_change_event_types: A list of event types that represent state
205
+ changes; tokens corresponding to these event types will be interpreted
206
+ as state changes and redundant ones will be removed.
207
+
208
+ Returns:
209
+ A preprocessing function that removes redundant state change events.
210
+ """
211
+ state_change_event_ranges = [codec.event_type_range(event_type)
212
+ for event_type in state_change_event_types]
213
+
214
+ def remove_redundant_state_changes(
215
+ features: MutableMapping[str, Any],
216
+ ) -> Mapping[str, Any]:
217
+ """Remove redundant tokens e.g. duplicate velocity changes from sequence."""
218
+ current_state = tf.zeros(len(state_change_event_ranges), dtype=tf.int32)
219
+ output = tf.constant([], dtype=tf.int32)
220
+
221
+ for event in features[feature_key]:
222
+ # Let autograph know that the shape of 'output' will change during the
223
+ # loop.
224
+ tf.autograph.experimental.set_loop_options(
225
+ shape_invariants=[(output, tf.TensorShape([None]))])
226
+ is_redundant = False
227
+ for i, (min_index, max_index) in enumerate(state_change_event_ranges):
228
+ if (min_index <= event) and (event <= max_index):
229
+ if current_state[i] == event:
230
+ is_redundant = True
231
+ current_state = tf.tensor_scatter_nd_update(
232
+ current_state, indices=[[i]], updates=[event])
233
+ if not is_redundant:
234
+ output = tf.concat([output, [event]], axis=0)
235
+
236
+ features[feature_key] = output
237
+ return features
238
+
239
+ return seqio.map_over_dataset(remove_redundant_state_changes)
240
+
241
+
242
+ def run_length_encode_shifts_fn(
243
+ codec: event_codec.Codec,
244
+ feature_key: str = 'targets'
245
+ ) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]:
246
+ """Return a function that run-length encodes shifts for a given codec.
247
+
248
+ Args:
249
+ codec: The Codec to use for shift events.
250
+ feature_key: The feature key for which to run-length encode shifts.
251
+
252
+ Returns:
253
+ A preprocessing function that run-length encodes single-step shifts.
254
+ """
255
+ def run_length_encode_shifts(
256
+ features: MutableMapping[str, Any]
257
+ ) -> Mapping[str, Any]:
258
+ """Combine leading/interior shifts, trim trailing shifts.
259
+
260
+ Args:
261
+ features: Dict of features to process.
262
+
263
+ Returns:
264
+ A dict of features.
265
+ """
266
+ events = features[feature_key]
267
+
268
+ shift_steps = 0
269
+ total_shift_steps = 0
270
+ output = tf.constant([], dtype=tf.int32)
271
+
272
+ for event in events:
273
+ # Let autograph know that the shape of 'output' will change during the
274
+ # loop.
275
+ tf.autograph.experimental.set_loop_options(
276
+ shape_invariants=[(output, tf.TensorShape([None]))])
277
+ if codec.is_shift_event_index(event):
278
+ shift_steps += 1
279
+ total_shift_steps += 1
280
+
281
+ else:
282
+ # Once we've reached a non-shift event, RLE all previous shift events
283
+ # before outputting the non-shift event.
284
+ if shift_steps > 0:
285
+ shift_steps = total_shift_steps
286
+ while shift_steps > 0:
287
+ output_steps = tf.minimum(codec.max_shift_steps, shift_steps)
288
+ output = tf.concat([output, [output_steps]], axis=0)
289
+ shift_steps -= output_steps
290
+ output = tf.concat([output, [event]], axis=0)
291
+
292
+ features[feature_key] = output
293
+ return features
294
+
295
+ return seqio.map_over_dataset(run_length_encode_shifts)
296
+
297
+
298
+ def merge_run_length_encoded_targets(
299
+ targets: np.ndarray,
300
+ codec: event_codec.Codec
301
+ ) -> Sequence[int]:
302
+ """Merge multiple tracks of target events into a single stream.
303
+
304
+ Args:
305
+ targets: A 2D array (# tracks by # events) of integer event values.
306
+ codec: The event_codec.Codec used to interpret the events.
307
+
308
+ Returns:
309
+ A 1D array of merged events.
310
+ """
311
+ num_tracks = tf.shape(targets)[0]
312
+ targets_length = tf.shape(targets)[1]
313
+
314
+ current_step = 0
315
+ current_offsets = tf.zeros(num_tracks, dtype=tf.int32)
316
+
317
+ output = tf.constant([], dtype=tf.int32)
318
+ done = tf.constant(False)
319
+
320
+ while not done:
321
+ # Let autograph know that the shape of 'output' will change during the loop.
322
+ tf.autograph.experimental.set_loop_options(
323
+ shape_invariants=[(output, tf.TensorShape([None]))])
324
+
325
+ # Determine which targets track has the earliest next step.
326
+ next_step = codec.max_shift_steps + 1
327
+ next_track = -1
328
+ for i in range(num_tracks):
329
+ if (current_offsets[i] == targets_length or
330
+ targets[i][current_offsets[i]] == 0):
331
+ # Already reached the end of this targets track.
332
+ # (Zero is technically a valid shift event but we never actually use it;
333
+ # it is always padding.)
334
+ continue
335
+ if not codec.is_shift_event_index(targets[i][current_offsets[i]]):
336
+ # The only way we would be at a non-shift event is if we have not yet
337
+ # reached the first shift event, which means we're at step zero.
338
+ next_step = 0
339
+ next_track = i
340
+ elif targets[i][current_offsets[i]] < next_step:
341
+ next_step = targets[i][current_offsets[i]]
342
+ next_track = i
343
+
344
+ if next_track == -1:
345
+ # We've already merged all of the target tracks in their entirety.
346
+ done = tf.constant(True)
347
+ break
348
+
349
+ if next_step == current_step and next_step > 0:
350
+ # We don't need to include the shift event itself as it's the same step as
351
+ # the previous shift.
352
+ start_offset = current_offsets[next_track] + 1
353
+ else:
354
+ start_offset = current_offsets[next_track]
355
+
356
+ # Merge in events up to but not including the next shift.
357
+ end_offset = start_offset + 1
358
+ while end_offset < targets_length and not codec.is_shift_event_index(
359
+ targets[next_track][end_offset]):
360
+ end_offset += 1
361
+ output = tf.concat(
362
+ [output, targets[next_track][start_offset:end_offset]], axis=0)
363
+
364
+ current_step = next_step
365
+ current_offsets = tf.tensor_scatter_nd_update(
366
+ current_offsets, indices=[[next_track]], updates=[end_offset])
367
+
368
+ return output
369
+
370
+
371
+ def decode_events(
372
+ state: DS,
373
+ tokens: np.ndarray,
374
+ start_time: int,
375
+ max_time: Optional[int],
376
+ codec: event_codec.Codec,
377
+ decode_event_fn: Callable[[DS, float, event_codec.Event, event_codec.Codec],
378
+ None],
379
+ ) -> Tuple[int, int]:
380
+ """Decode a series of tokens, maintaining a decoding state object.
381
+
382
+ Args:
383
+ state: Decoding state object; will be modified in-place.
384
+ tokens: event tokens to convert.
385
+ start_time: offset start time if decoding in the middle of a sequence.
386
+ max_time: Events at or beyond this time will be dropped.
387
+ codec: An event_codec.Codec object that maps indices to Event objects.
388
+ decode_event_fn: Function that consumes an Event (and the current time) and
389
+ updates the decoding state.
390
+
391
+ Returns:
392
+ invalid_events: number of events that could not be decoded.
393
+ dropped_events: number of events dropped due to max_time restriction.
394
+ """
395
+ invalid_events = 0
396
+ dropped_events = 0
397
+ cur_steps = 0
398
+ cur_time = start_time
399
+ token_idx = 0
400
+ for token_idx, token in enumerate(tokens):
401
+ try:
402
+ event = codec.decode_event_index(token)
403
+ except ValueError:
404
+ invalid_events += 1
405
+ continue
406
+ if event.type == 'shift':
407
+ cur_steps += event.value
408
+ cur_time = start_time + cur_steps / codec.steps_per_second
409
+ if max_time and cur_time > max_time:
410
+ dropped_events = len(tokens) - token_idx
411
+ break
412
+ else:
413
+ cur_steps = 0
414
+ try:
415
+ decode_event_fn(state, cur_time, event, codec)
416
+ except ValueError:
417
+ invalid_events += 1
418
+ logging.info(
419
+ 'Got invalid event when decoding event %s at time %f. '
420
+ 'Invalid event counter now at %d.',
421
+ event, cur_time, invalid_events, exc_info=True)
422
+ continue
423
+ return invalid_events, dropped_events
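[Editor's note] Before the unit tests in the next file, it may help to see the shift run-length encoding on a concrete token sequence. The following is a minimal standalone sketch, not repo code: it mirrors only the codec used in run_length_encoding_test.py (single-step shift token 1, max_shift_steps 100) and reproduces the key behavior of run_length_encode_shifts_fn — runs of single-step shifts collapse into the *total* step count since the start of the sequence (split into chunks of at most max_shift_steps), and trailing shifts are dropped.

def rle_shifts(tokens, max_shift_steps=100, shift_token=1):
    # Simplification (assumption): only token value 1 counts as a single-step
    # shift, matching the inputs exercised by the tests below.
    output = []
    total_steps = 0   # absolute step count since the start of the sequence
    pending = False   # whether any shifts occurred since the last non-shift event
    for t in tokens:
        if t == shift_token:
            total_steps += 1
            pending = True
        else:
            if pending:
                steps = total_steps
                while steps > 0:
                    emit = min(max_shift_steps, steps)
                    output.append(emit)
                    steps -= emit
                pending = False
            output.append(t)
    return output  # trailing shifts are never flushed, i.e. dropped

print(rle_shifts([1, 1, 1, 161, 1, 1, 1, 162, 1, 1, 1]))  # [3, 161, 6, 162]
print(rle_shifts([1] * 202 + [161, 1, 1, 1]))             # [100, 100, 2, 161]
print(rle_shifts([1, 1, 1, 161, 162, 1, 1, 1]))           # [3, 161, 162]

Note the second emitted shift in the first example is 6, not 3: the encoded value is the absolute step, which is why the real implementation tracks total_shift_steps separately.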
mt3/run_length_encoding_test.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for run_length_encoding."""
16
+
17
+ from mt3 import event_codec
18
+ from mt3 import run_length_encoding
19
+
20
+ import note_seq
21
+ import numpy as np
22
+ import seqio
23
+ import tensorflow as tf
24
+
25
+ assert_dataset = seqio.test_utils.assert_dataset
26
+ codec = event_codec.Codec(
27
+ max_shift_steps=100,
28
+ steps_per_second=100,
29
+ event_ranges=[
30
+ event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
31
+ note_seq.MAX_MIDI_PITCH),
32
+ event_codec.EventRange('velocity', 0, 127),
33
+ event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
34
+ note_seq.MAX_MIDI_PITCH),
35
+ event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
36
+ note_seq.MAX_MIDI_PROGRAM),
37
+ event_codec.EventRange('tie', 0, 0)
38
+ ])
39
+ run_length_encode_shifts = run_length_encoding.run_length_encode_shifts_fn(
40
+ codec=codec)
41
+
42
+
43
+ class RunLengthEncodingTest(tf.test.TestCase):
44
+
45
+ def test_remove_redundant_state_changes(self):
46
+ og_dataset = tf.data.Dataset.from_tensors({
47
+ 'targets': [3, 525, 356, 161, 2, 525, 356, 161, 355, 394]
48
+ })
49
+
50
+ assert_dataset(
51
+ run_length_encoding.remove_redundant_state_changes_fn(
52
+ codec=codec,
53
+ state_change_event_types=['velocity', 'program'])(og_dataset),
54
+ {
55
+ 'targets': [3, 525, 356, 161, 2, 161, 355, 394],
56
+ })
57
+
58
+ def test_run_length_encode_shifts(self):
59
+ og_dataset = tf.data.Dataset.from_tensors({
60
+ 'targets': [1, 1, 1, 161, 1, 1, 1, 162, 1, 1, 1]
61
+ })
62
+
63
+ assert_dataset(
64
+ run_length_encode_shifts(og_dataset),
65
+ {
66
+ 'targets': [3, 161, 6, 162],
67
+ })
68
+
69
+ def test_run_length_encode_shifts_beyond_max_length(self):
70
+ og_dataset = tf.data.Dataset.from_tensors({
71
+ 'targets': [1] * 202 + [161, 1, 1, 1]
72
+ })
73
+
74
+ assert_dataset(
75
+ run_length_encode_shifts(og_dataset),
76
+ {
77
+ 'targets': [100, 100, 2, 161],
78
+ })
79
+
80
+ def test_run_length_encode_shifts_simultaneous(self):
81
+ og_dataset = tf.data.Dataset.from_tensors({
82
+ 'targets': [1, 1, 1, 161, 162, 1, 1, 1]
83
+ })
84
+
85
+ assert_dataset(
86
+ run_length_encode_shifts(og_dataset),
87
+ {
88
+ 'targets': [3, 161, 162],
89
+ })
90
+
91
+ def test_merge_run_length_encoded_targets(self):
92
+ # pylint: disable=bad-whitespace
93
+ targets = np.array([
94
+ [ 3, 161, 162, 5, 163],
95
+ [160, 164, 3, 165, 0]
96
+ ])
97
+ # pylint: enable=bad-whitespace
98
+ merged_targets = run_length_encoding.merge_run_length_encoded_targets(
99
+ targets=targets, codec=codec)
100
+ expected_merged_targets = [
101
+ 160, 164, 3, 161, 162, 165, 5, 163
102
+ ]
103
+ np.testing.assert_array_equal(expected_merged_targets, merged_targets)
104
+
105
+
106
+ if __name__ == '__main__':
107
+ tf.test.main()
mt3/scripts/dump_task.py ADDED
@@ -0,0 +1,80 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simple debugging utility for printing out task contents."""
16
+
17
+ import re
18
+
19
+ from absl import app
20
+ from absl import flags
21
+
22
+ import mt3.tasks # pylint: disable=unused-import
23
+
24
+ import seqio
25
+ import tensorflow as tf
26
+
27
+
28
+ FLAGS = flags.FLAGS
29
+
30
+ flags.DEFINE_string("task", None, "A registered Task.")
31
+ flags.DEFINE_string("task_cache_dir", None, "Directory to use for task cache.")
32
+ flags.DEFINE_integer("max_examples", 10,
33
+ "Maximum number of examples (-1 for no limit).")
34
+ flags.DEFINE_string("format_string", "targets = {targets}",
35
+ "Format for printing examples.")
36
+ flags.DEFINE_string("split", "train",
37
+ "Which split of the dataset, e.g. train or validation.")
38
+ flags.DEFINE_integer("sequence_length_inputs", 256,
39
+ "Sequence length for inputs.")
40
+ flags.DEFINE_integer("sequence_length_targets", 1024,
41
+ "Sequence length for targets.")
42
+
43
+
44
+ def main(_):
45
+ if FLAGS.task_cache_dir:
46
+ seqio.add_global_cache_dirs([FLAGS.task_cache_dir])
47
+
48
+ task = seqio.get_mixture_or_task(FLAGS.task)
49
+
50
+ ds = task.get_dataset(
51
+ sequence_length={
52
+ "inputs": FLAGS.sequence_length_inputs,
53
+ "targets": FLAGS.sequence_length_targets,
54
+ },
55
+ split=FLAGS.split,
56
+ use_cached=bool(FLAGS.task_cache_dir),
57
+ shuffle=False)
58
+
59
+ keys = re.findall(r"{([\w+]+)}", FLAGS.format_string)
60
+ def _example_to_string(ex):
61
+ key_to_string = {}
62
+ for k in keys:
63
+ if k in ex:
64
+ v = ex[k].numpy().tolist()
65
+ key_to_string[k] = task.output_features[k].vocabulary.decode(v)
66
+ else:
67
+ key_to_string[k] = ""
68
+ return FLAGS.format_string.format(**key_to_string)
69
+
70
+ for ex in ds.take(FLAGS.max_examples):
71
+ for k, v in ex.items():
72
+ print(f"{k}: {tf.shape(v)}")
73
+ print(_example_to_string(ex))
74
+ print()
75
+
76
+
77
+ if __name__ == "__main__":
78
+ flags.mark_flags_as_required(["task"])
79
+
80
+ app.run(main)
mt3/scripts/extract_monophonic_examples.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Detect monophonic tracks and extract notes."""
16
+
17
+ import collections
18
+ import os
19
+
20
+ from absl import app
21
+ from absl import flags
22
+ from absl import logging
23
+
24
+ import ddsp
25
+ import librosa
26
+ import note_seq
27
+ import numpy as np
28
+ import scipy
29
+ import tensorflow as tf
30
+
31
+
32
+ _INPUT_DIR = flags.DEFINE_string(
33
+ 'input_dir', None,
34
+ 'Input directory containing WAV files.')
35
+ _OUTPUT_TFRECORD_PATH = flags.DEFINE_string(
36
+ 'output_tfrecord_path', None,
37
+ 'Path to the output TFRecord containing tf.train.Example protos with '
38
+ 'monophonic tracks and inferred NoteSequence protos.')
39
+
40
+
41
+ CREPE_SAMPLE_RATE = 16000
42
+ CREPE_FRAME_RATE = 100
43
+
44
+ MONOPHONIC_CONFIDENCE_THRESHOLD = 0.95 # confidence must be greater than this
45
+ MONOPHONIC_CONFIDENCE_FRAC = 0.2 # for this fraction of frames
46
+
47
+ # split input audio into clips
48
+ CLIP_LENGTH_SECONDS = 5
49
+
50
+
51
+ def is_monophonic_heuristic(f0_confidence):
52
+ """Heuristic to check for monophonicity using f0 confidence."""
53
+ return (np.sum(f0_confidence >= MONOPHONIC_CONFIDENCE_THRESHOLD) /
54
+ len(f0_confidence) >= MONOPHONIC_CONFIDENCE_FRAC)
55
+
56
+
57
+ # HMM parameters for modeling notes and F0 tracks.
58
+ F0_MIDI_SIGMA = 0.2
59
+ OCTAVE_ERROR_PROB = 0.05
60
+ NOTES_PER_SECOND = 2
61
+ NOTE_CHANGE_PROB = NOTES_PER_SECOND / CREPE_FRAME_RATE
62
+ F0_CONFIDENCE_EXP = 7.5
63
+
64
+
65
+ def f0_hmm_matrices(f0_hz, f0_confidence):
66
+ """Observation and transition matrices for hidden Markov model of F0."""
67
+ f0_midi = librosa.hz_to_midi(f0_hz)
68
+ f0_midi_diff = f0_midi[:, np.newaxis] - np.arange(128)[np.newaxis, :]
69
+
70
+ # Compute the probability of each pitch at each frame, taking octave errors
71
+ # into account.
72
+ f0_midi_prob_octave_correct = scipy.stats.norm.pdf(
73
+ f0_midi_diff, scale=F0_MIDI_SIGMA)
74
+ f0_midi_prob_octave_low = scipy.stats.norm.pdf(
75
+ f0_midi_diff + 12, scale=F0_MIDI_SIGMA)
76
+ f0_midi_prob_octave_high = scipy.stats.norm.pdf(
77
+ f0_midi_diff - 12, scale=F0_MIDI_SIGMA)
78
+
79
+ # distribution of pitch values given note
80
+ f0_midi_loglik = ((1 - OCTAVE_ERROR_PROB) * f0_midi_prob_octave_correct +
81
+ 0.5 * OCTAVE_ERROR_PROB * f0_midi_prob_octave_low +
82
+ 0.5 * OCTAVE_ERROR_PROB * f0_midi_prob_octave_high)
83
+ # (uniform) distribution of pitch values given rest
84
+ f0_midi_rest_loglik = -np.log(128)
85
+
86
+ # Here we interpret confidence, after adjusting by exponent, as P(not rest).
87
+ f0_confidence_prob = np.power(f0_confidence, F0_CONFIDENCE_EXP)[:, np.newaxis]
88
+
89
+ obs_loglik = np.concatenate([
90
+ # probability of note (normalized by number of possible notes)
91
+ f0_midi_loglik + np.log(f0_confidence_prob) - np.log(128),
92
+ # probability of rest
93
+ f0_midi_rest_loglik + np.log(1.0 - f0_confidence_prob)
94
+ ], axis=1)
95
+
96
+ # Normalize to adjust P(confidence | note) by uniform P(note).
97
+ # TODO(iansimon): Not sure how correct this is but it doesn't affect the path.
98
+ obs_loglik += np.log(129)
99
+
100
+ trans_prob = ((NOTE_CHANGE_PROB / 128) * np.ones(129) +
101
+ (1 - NOTE_CHANGE_PROB - NOTE_CHANGE_PROB / 128) * np.eye(129))
102
+ trans_loglik = np.log(trans_prob)
103
+
104
+ return obs_loglik, trans_loglik
105
+
106
+
107
+ def hmm_forward(obs_loglik, trans_loglik):
108
+ """Forward algorithm for a hidden Markov model."""
109
+ n, k = obs_loglik.shape
110
+ trans = np.exp(trans_loglik)
111
+
112
+ loglik = 0.0
113
+
114
+ l = obs_loglik[0] - np.log(k)
115
+ c = scipy.special.logsumexp(l)
116
+ loglik += c
117
+
118
+ for i in range(1, n):
119
+ p = np.exp(l - c)
120
+ l = np.log(np.dot(p, trans)) + obs_loglik[i]
121
+ c = scipy.special.logsumexp(l)
122
+ loglik += c
123
+
124
+ return loglik
125
+
126
+
127
+ def hmm_viterbi(obs_loglik, trans_loglik):
128
+ """Viterbi algorithm for a hidden Markov model."""
129
+ n, k = obs_loglik.shape
130
+
131
+ loglik_matrix = np.zeros_like(obs_loglik)
132
+ path_matrix = np.zeros_like(obs_loglik, dtype=np.int32)
133
+
134
+ loglik_matrix[0, :] = obs_loglik[0, :] - np.log(k)
135
+
136
+ for i in range(1, n):
137
+ mat = np.tile(loglik_matrix[i - 1][:, np.newaxis], [1, 129]) + trans_loglik
138
+ path_matrix[i, :] = mat.argmax(axis=0)
139
+ loglik_matrix[i, :] = mat[path_matrix[i, :], range(129)] + obs_loglik[i]
140
+
141
+ path = [np.argmax(loglik_matrix[-1])]
142
+ for i in range(n, 1, -1):
143
+ path.append(path_matrix[i - 1, path[-1]])
144
+
145
+ return [(pitch if pitch < 128 else None) for pitch in path[::-1]]
146
+
147
+
148
+ def pitches_to_notesequence(pitches):
149
+ """Convert sequence of pitches output by Viterbi to NoteSequence proto."""
150
+ ns = note_seq.NoteSequence(ticks_per_quarter=220)
151
+ current_pitch = None
152
+ start_time = None
153
+ for frame, pitch in enumerate(pitches):
154
+ time = frame / CREPE_FRAME_RATE
155
+ if pitch != current_pitch:
156
+ if current_pitch is not None:
157
+ ns.notes.add(
158
+ pitch=current_pitch, velocity=100,
159
+ start_time=start_time, end_time=time)
160
+ current_pitch = pitch
161
+ start_time = time
162
+ if current_pitch is not None:
163
+ ns.notes.add(
164
+ pitch=current_pitch, velocity=100,
165
+ start_time=start_time, end_time=len(pitches) / CREPE_FRAME_RATE)
166
+ if ns.notes:
167
+ ns.total_time = ns.notes[-1].end_time
168
+ return ns
169
+
170
+
171
+ # Per-frame log likelihood threshold below which an F0 track will be discarded.
172
+ # Note that this is dependent on the HMM parameters specified above, so if those
173
+ # change then this threshold should also change.
174
+ PER_FRAME_LOGLIK_THRESHOLD = 0.3
175
+
176
+
177
+ def extract_note_sequence(crepe, samples, counters):
178
+ """Use CREPE to attempt to extract a monophonic NoteSequence from audio."""
179
+ f0_hz, f0_confidence = crepe.predict_f0_and_confidence(
180
+ samples[np.newaxis, :], viterbi=False)
181
+
182
+ f0_hz = f0_hz[0].numpy()
183
+ f0_confidence = f0_confidence[0].numpy()
184
+
185
+ if not is_monophonic_heuristic(f0_confidence):
186
+ counters['not_monophonic'] += 1
187
+ return None
188
+
189
+ obs_loglik, trans_loglik = f0_hmm_matrices(f0_hz, f0_confidence)
190
+
191
+ loglik = hmm_forward(obs_loglik, trans_loglik)
192
+ if loglik / len(obs_loglik) < PER_FRAME_LOGLIK_THRESHOLD:
193
+ counters['low_likelihood'] += 1
194
+ return None
195
+
196
+ pitches = hmm_viterbi(obs_loglik, trans_loglik)
197
+ ns = pitches_to_notesequence(pitches)
198
+
199
+ counters['extracted_monophonic_sequence'] += 1
200
+ return ns
201
+
202
+
203
+ def process_wav_file(wav_filename, crepe, counters):
204
+ """Extract monophonic transcription examples from a WAV file."""
205
+ wav_data = tf.io.gfile.GFile(wav_filename, 'rb').read()
206
+ samples = note_seq.audio_io.wav_data_to_samples_librosa(
207
+ wav_data, sample_rate=CREPE_SAMPLE_RATE)
208
+ clip_length_samples = int(CREPE_SAMPLE_RATE * CLIP_LENGTH_SECONDS)
209
+ for start_sample in range(0, len(samples), clip_length_samples):
210
+ clip_samples = samples[start_sample:start_sample + clip_length_samples]
211
+ if len(clip_samples) < clip_length_samples:
212
+ clip_samples = np.pad(
213
+ clip_samples, [(0, clip_length_samples - len(clip_samples))])
214
+ ns = extract_note_sequence(crepe, clip_samples, counters)
215
+ if ns:
216
+ feature = {
217
+ 'audio': tf.train.Feature(
218
+ float_list=tf.train.FloatList(value=clip_samples.tolist())),
219
+ 'filename': tf.train.Feature(
220
+ bytes_list=tf.train.BytesList(value=[wav_filename.encode()])),
221
+ 'offset': tf.train.Feature(
222
+ int64_list=tf.train.Int64List(value=[start_sample])),
223
+ 'sampling_rate': tf.train.Feature(
224
+ float_list=tf.train.FloatList(value=[CREPE_SAMPLE_RATE])),
225
+ 'sequence': tf.train.Feature(
226
+ bytes_list=tf.train.BytesList(value=[ns.SerializeToString()]))
227
+ }
228
+ yield tf.train.Example(features=tf.train.Features(feature=feature))
229
+
230
+
231
+ def main(unused_argv):
232
+ flags.mark_flags_as_required(['input_dir', 'output_tfrecord_path'])
233
+ crepe = ddsp.spectral_ops.PretrainedCREPE('full')
234
+ counters = collections.defaultdict(int)
235
+ with tf.io.TFRecordWriter(_OUTPUT_TFRECORD_PATH.value) as writer:
236
+ for filename in tf.io.gfile.listdir(_INPUT_DIR.value):
237
+ if not filename.endswith('.wav'):
238
+ logging.info('skipping %s...', filename)
239
+ counters['non_wav_files_skipped'] += 1
240
+ continue
241
+ logging.info('processing %s...', filename)
242
+ for ex in process_wav_file(
243
+ os.path.join(_INPUT_DIR.value, filename), crepe, counters):
244
+ writer.write(ex.SerializeToString())
245
+ counters['wav_files_processed'] += 1
246
+ for k, v in counters.items():
247
+ logging.info('COUNTER: %s = %d', k, v)
248
+
249
+
250
+ if __name__ == '__main__':
251
+ app.run(main)
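[Editor's note] One non-obvious piece of the script above is the note-transition model in f0_hmm_matrices. As a quick sanity check, here is a small sketch, not repo code, verifying that each row of trans_prob is a proper distribution over the 129 HMM states (128 MIDI pitches plus a rest state): a state repeats with probability 1 - NOTE_CHANGE_PROB and moves to each of the other 128 states with probability NOTE_CHANGE_PROB / 128.

import numpy as np

NOTES_PER_SECOND = 2
CREPE_FRAME_RATE = 100
NOTE_CHANGE_PROB = NOTES_PER_SECOND / CREPE_FRAME_RATE  # 0.02 per frame

# Same construction as in f0_hmm_matrices above.
trans_prob = ((NOTE_CHANGE_PROB / 128) * np.ones(129) +
              (1 - NOTE_CHANGE_PROB - NOTE_CHANGE_PROB / 128) * np.eye(129))

print(np.allclose(trans_prob.sum(axis=1), 1.0))  # True: every row sums to 1
print(trans_prob[0, 0])                          # 0.98, stay on the same note/rest
print(trans_prob[0, 1])                          # 0.00015625, move to one specific other state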
mt3/spectrograms.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Audio spectrogram functions."""
16
+
17
+ import dataclasses
18
+
19
+ from ddsp import spectral_ops
20
+ import tensorflow as tf
21
+
22
+ # defaults for spectrogram config
23
+ DEFAULT_SAMPLE_RATE = 16000
24
+ DEFAULT_HOP_WIDTH = 128
25
+ DEFAULT_NUM_MEL_BINS = 512
26
+
27
+ # fixed constants; add these to SpectrogramConfig before changing
28
+ FFT_SIZE = 2048
29
+ MEL_LO_HZ = 20.0
30
+
31
+
32
+ @dataclasses.dataclass
33
+ class SpectrogramConfig:
34
+ """Spectrogram configuration parameters."""
35
+ sample_rate: int = DEFAULT_SAMPLE_RATE
36
+ hop_width: int = DEFAULT_HOP_WIDTH
37
+ num_mel_bins: int = DEFAULT_NUM_MEL_BINS
38
+
39
+ @property
40
+ def abbrev_str(self):
41
+ s = ''
42
+ if self.sample_rate != DEFAULT_SAMPLE_RATE:
43
+ s += 'sr%d' % self.sample_rate
44
+ if self.hop_width != DEFAULT_HOP_WIDTH:
45
+ s += 'hw%d' % self.hop_width
46
+ if self.num_mel_bins != DEFAULT_NUM_MEL_BINS:
47
+ s += 'mb%d' % self.num_mel_bins
48
+ return s
49
+
50
+ @property
51
+ def frames_per_second(self):
52
+ return self.sample_rate / self.hop_width
53
+
54
+
55
+ def split_audio(samples, spectrogram_config):
56
+ """Split audio into frames."""
57
+ return tf.signal.frame(
58
+ samples,
59
+ frame_length=spectrogram_config.hop_width,
60
+ frame_step=spectrogram_config.hop_width,
61
+ pad_end=True)
62
+
63
+
64
+ def compute_spectrogram(samples, spectrogram_config):
65
+ """Compute a mel spectrogram."""
66
+ overlap = 1 - (spectrogram_config.hop_width / FFT_SIZE)
67
+ return spectral_ops.compute_logmel(
68
+ samples,
69
+ bins=spectrogram_config.num_mel_bins,
70
+ lo_hz=MEL_LO_HZ,
71
+ overlap=overlap,
72
+ fft_size=FFT_SIZE,
73
+ sample_rate=spectrogram_config.sample_rate)
74
+
75
+
76
+ def flatten_frames(frames):
77
+ """Convert frames back into a flat array of samples."""
78
+ return tf.reshape(frames, [-1])
79
+
80
+
81
+ def input_depth(spectrogram_config):
82
+ return spectrogram_config.num_mel_bins
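[Editor's note] A brief usage sketch for the helpers above, assuming the module is importable as mt3.spectrograms; the printed values follow from the defaults (16 kHz audio, hop_width 128, i.e. 125 frames per second). split_audio produces non-overlapping hop_width-sample frames and flatten_frames undoes the framing.

import tensorflow as tf
from mt3 import spectrograms

config = spectrograms.SpectrogramConfig()      # defaults: 16 kHz, hop_width 128, 512 mel bins
samples = tf.zeros([3 * config.sample_rate])   # 3 seconds of silence

frames = spectrograms.split_audio(samples, config)
print(frames.shape)                            # (375, 128)
print(config.frames_per_second)                # 125.0

roundtrip = spectrograms.flatten_frames(frames)
print(roundtrip.shape)                         # (48000,): lossless here because 48000
                                               # is an exact multiple of hop_width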
mt3/summaries.py ADDED
@@ -0,0 +1,471 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """TensorBoard summaries and utilities."""
16
+
17
+ from typing import Any, Mapping, Optional, Sequence, Tuple
18
+
19
+ import librosa
20
+
21
+ from mt3 import note_sequences
22
+ from mt3 import spectrograms
23
+
24
+ import note_seq
25
+ from note_seq import midi_synth
26
+ from note_seq import sequences_lib
27
+ from note_seq.protobuf import music_pb2
28
+
29
+ import numpy as np
30
+ import seqio
31
+
32
+
33
+ _DEFAULT_AUDIO_SECONDS = 30.0
34
+ _DEFAULT_PIANOROLL_FRAMES_PER_SECOND = 15
35
+
36
+ # TODO(iansimon): pick a SoundFont; for some reason the default is all organ
37
+
38
+
39
+ def _extract_example_audio(
40
+ examples: Sequence[Mapping[str, Any]],
41
+ sample_rate: float,
42
+ num_seconds: float,
43
+ audio_key: str = 'raw_inputs'
44
+ ) -> np.ndarray:
45
+ """Extract audio from examples.
46
+
47
+ Args:
48
+ examples: List of examples containing raw audio.
49
+ sample_rate: Number of samples per second.
50
+ num_seconds: Number of seconds of audio to include.
51
+ audio_key: Dictionary key for the raw audio.
52
+
53
+ Returns:
54
+ An n-by-num_samples numpy array of samples.
55
+ """
56
+ n = len(examples)
57
+ num_samples = round(num_seconds * sample_rate)
58
+ all_samples = np.zeros([n, num_samples])
59
+ for i, ex in enumerate(examples):
60
+ samples = ex[audio_key][:num_samples]
61
+ all_samples[i, :len(samples)] = samples
62
+ return all_samples
63
+
64
+
65
+ def _example_to_note_sequence(
66
+ example: Mapping[str, Sequence[float]],
67
+ ns_feature_name: str,
68
+ note_onset_feature_name: str,
69
+ note_offset_feature_name: str,
70
+ note_frequency_feature_name: str,
71
+ note_confidence_feature_name: str,
72
+ num_seconds: float
73
+ ) -> music_pb2.NoteSequence:
74
+ """Extract NoteSequence from example."""
75
+ if ns_feature_name:
76
+ ns = example[ns_feature_name]
77
+
78
+ else:
79
+ onset_times = np.array(example[note_onset_feature_name])
80
+ pitches = librosa.hz_to_midi(
81
+ example[note_frequency_feature_name]).round().astype(int)
82
+ assert len(onset_times) == len(pitches)
83
+
84
+ if note_offset_feature_name or note_confidence_feature_name:
85
+ offset_times = (
86
+ example[note_offset_feature_name]
87
+ if note_offset_feature_name
88
+ else onset_times + note_sequences.DEFAULT_NOTE_DURATION
89
+ )
90
+ assert len(onset_times) == len(offset_times)
91
+
92
+ confidences = (np.array(example[note_confidence_feature_name])
93
+ if note_confidence_feature_name else None)
94
+ velocities = np.ceil(
95
+ note_seq.MAX_MIDI_VELOCITY * confidences if confidences is not None
96
+ else note_sequences.DEFAULT_VELOCITY * np.ones_like(onset_times)
97
+ ).astype(int)
98
+ assert len(onset_times) == len(velocities)
99
+
100
+ ns = note_sequences.note_arrays_to_note_sequence(
101
+ onset_times=onset_times, offset_times=offset_times,
102
+ pitches=pitches, velocities=velocities)
103
+
104
+ else:
105
+ ns = note_sequences.note_arrays_to_note_sequence(
106
+ onset_times=onset_times, pitches=pitches)
107
+
108
+ return sequences_lib.trim_note_sequence(ns, 0, num_seconds)
109
+
110
+
111
+ def _synthesize_example_notes(
112
+ examples: Sequence[Mapping[str, Sequence[float]]],
113
+ ns_feature_name: str,
114
+ note_onset_feature_name: str,
115
+ note_offset_feature_name: str,
116
+ note_frequency_feature_name: str,
117
+ note_confidence_feature_name: str,
118
+ sample_rate: float,
119
+ num_seconds: float,
120
+ ) -> np.ndarray:
121
+ """Synthesize example notes to audio.
122
+
123
+ Args:
124
+ examples: List of example dictionaries, containing either serialized
125
+ NoteSequence protos or note onset times and pitches.
126
+ ns_feature_name: Name of serialized NoteSequence feature.
127
+ note_onset_feature_name: Name of note onset times feature.
128
+ note_offset_feature_name: Name of note offset times feature.
129
+ note_frequency_feature_name: Name of note frequencies feature.
130
+ note_confidence_feature_name: Name of note confidences (velocities) feature.
131
+ sample_rate: Sample rate at which to synthesize.
132
+ num_seconds: Number of seconds to synthesize for each example.
133
+
134
+ Returns:
135
+ An n-by-num_samples numpy array of samples.
136
+ """
137
+ if (ns_feature_name is not None) == (note_onset_feature_name is not None):
138
+ raise ValueError(
139
+ 'must specify exactly one of NoteSequence feature and onset feature')
140
+
141
+ n = len(examples)
142
+ num_samples = round(num_seconds * sample_rate)
143
+
144
+ all_samples = np.zeros([n, num_samples])
145
+
146
+ for i, ex in enumerate(examples):
147
+ ns = _example_to_note_sequence(
148
+ ex,
149
+ ns_feature_name=ns_feature_name,
150
+ note_onset_feature_name=note_onset_feature_name,
151
+ note_offset_feature_name=note_offset_feature_name,
152
+ note_frequency_feature_name=note_frequency_feature_name,
153
+ note_confidence_feature_name=note_confidence_feature_name,
154
+ num_seconds=num_seconds)
155
+ fluidsynth = midi_synth.fluidsynth
156
+ samples = fluidsynth(ns, sample_rate=sample_rate)
157
+ if len(samples) > num_samples:
158
+ samples = samples[:num_samples]
159
+ all_samples[i, :len(samples)] = samples
160
+
161
+ return all_samples
162
+
163
+
164
+ def _examples_to_pianorolls(
165
+ targets: Sequence[Mapping[str, Sequence[float]]],
166
+ predictions: Sequence[Mapping[str, Sequence[float]]],
167
+ ns_feature_suffix: str,
168
+ note_onset_feature_suffix: str,
169
+ note_offset_feature_suffix: str,
170
+ note_frequency_feature_suffix: str,
171
+ note_confidence_feature_suffix: str,
172
+ track_specs: Optional[Sequence[note_sequences.TrackSpec]],
173
+ num_seconds: float,
174
+ frames_per_second: float
175
+ ) -> Tuple[np.ndarray, np.ndarray]:
176
+ """Generate pianoroll images from example notes.
177
+
178
+ Args:
179
+ targets: List of target dictionaries, containing either serialized
180
+ NoteSequence protos or note onset times and pitches.
181
+ predictions: List of prediction dictionaries, containing either serialized
182
+ NoteSequence protos or note onset times and pitches.
183
+ ns_feature_suffix: Suffix of serialized NoteSequence feature.
184
+ note_onset_feature_suffix: Suffix of note onset times feature.
185
+ note_offset_feature_suffix: Suffix of note offset times feature.
186
+ note_frequency_feature_suffix: Suffix of note frequencies feature.
187
+ note_confidence_feature_suffix: Suffix of note confidences (velocities)
188
+ feature.
189
+ track_specs: Optional list of TrackSpec objects to indicate a set of tracks
190
+ into which each NoteSequence should be split. Tracks will be stacked
191
+ vertically in the pianorolls
192
+ num_seconds: Number of seconds to show for each example.
193
+ frames_per_second: Number of pianoroll frames per second.
194
+
195
+ Returns:
196
+ onset_pianorolls: An n-by-num_pitches-by-num_frames-by-4 numpy array of
197
+ pianoroll images showing only onsets.
198
+ full_pianorolls: An n-by-num_pitches-by-num_frames-by-4 numpy array of
199
+ pianoroll images.
200
+ """
201
+ if (ns_feature_suffix is not None) == (note_onset_feature_suffix is not None):
202
+ raise ValueError(
203
+ 'must specify exactly one of NoteSequence feature and onset feature')
204
+
205
+ def ex_to_ns(example, prefix):
206
+ return _example_to_note_sequence(
207
+ example=example,
208
+ ns_feature_name=(prefix + ns_feature_suffix
209
+ if ns_feature_suffix else None),
210
+ note_onset_feature_name=(prefix + note_onset_feature_suffix
211
+ if note_onset_feature_suffix else None),
212
+ note_offset_feature_name=(prefix + note_offset_feature_suffix
213
+ if note_offset_feature_suffix else None),
214
+ note_frequency_feature_name=(
215
+ prefix + note_frequency_feature_suffix
216
+ if note_frequency_feature_suffix else None),
217
+ note_confidence_feature_name=(
218
+ prefix + note_confidence_feature_suffix
219
+ if note_confidence_feature_suffix else None),
220
+ num_seconds=num_seconds)
221
+
222
+ n = len(targets)
223
+ num_pitches = note_seq.MAX_MIDI_PITCH - note_seq.MIN_MIDI_PITCH + 1
224
+ num_frames = round(num_seconds * frames_per_second)
225
+ num_tracks = len(track_specs) if track_specs else 1
226
+ pianoroll_height = num_tracks * num_pitches + (num_tracks - 1)
227
+
228
+ onset_images = np.zeros([n, pianoroll_height, num_frames, 3])
229
+ full_images = np.zeros([n, pianoroll_height, num_frames, 3])
230
+
231
+ for i, (target, pred) in enumerate(zip(targets, predictions)):
232
+ target_ns, pred_ns = [
233
+ ex_to_ns(ex, prefix)
234
+ for (ex, prefix) in [(target, 'ref_'), (pred, 'est_')]
235
+ ]
236
+
237
+ # Show lines at frame boundaries. To ensure that these lines are drawn with
238
+ # the same downsampling and frame selection logic as the real NoteSequences,
239
+ # use this hack to draw the lines with a NoteSequence that contains notes
240
+ # across all pitches at all frame start times.
241
+ start_times_ns = note_seq.NoteSequence()
242
+ start_times_ns.CopyFrom(target_ns)
243
+ del start_times_ns.notes[:]
244
+ for start_time in pred['start_times']:
245
+ if start_time < target_ns.total_time:
246
+ for pitch in range(
247
+ note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH + 1):
248
+ start_times_ns.notes.add(
249
+ pitch=pitch,
250
+ velocity=100,
251
+ start_time=start_time,
252
+ end_time=start_time + (1 / frames_per_second))
253
+
254
+ start_time_roll = sequences_lib.sequence_to_pianoroll(
255
+ start_times_ns,
256
+ frames_per_second=frames_per_second,
257
+ min_pitch=note_seq.MIN_MIDI_PITCH,
258
+ max_pitch=note_seq.MAX_MIDI_PITCH,
259
+ onset_mode='length_ms')
260
+ num_start_time_frames = min(len(start_time_roll.onsets), num_frames)
261
+
262
+ if track_specs is not None:
263
+ target_tracks = [note_sequences.extract_track(target_ns,
264
+ spec.program, spec.is_drum)
265
+ for spec in track_specs]
266
+ pred_tracks = [note_sequences.extract_track(pred_ns,
267
+ spec.program, spec.is_drum)
268
+ for spec in track_specs]
269
+ else:
270
+ target_tracks = [target_ns]
271
+ pred_tracks = [pred_ns]
272
+
273
+ for j, (target_track, pred_track) in enumerate(zip(target_tracks[::-1],
274
+ pred_tracks[::-1])):
275
+ target_roll = sequences_lib.sequence_to_pianoroll(
276
+ target_track,
277
+ frames_per_second=frames_per_second,
278
+ min_pitch=note_seq.MIN_MIDI_PITCH,
279
+ max_pitch=note_seq.MAX_MIDI_PITCH,
280
+ onset_mode='length_ms')
281
+ pred_roll = sequences_lib.sequence_to_pianoroll(
282
+ pred_track,
283
+ frames_per_second=frames_per_second,
284
+ min_pitch=note_seq.MIN_MIDI_PITCH,
285
+ max_pitch=note_seq.MAX_MIDI_PITCH,
286
+ onset_mode='length_ms')
287
+
288
+ num_target_frames = min(len(target_roll.onsets), num_frames)
289
+ num_pred_frames = min(len(pred_roll.onsets), num_frames)
290
+
291
+ start_offset = j * (num_pitches + 1)
292
+ end_offset = (j + 1) * (num_pitches + 1) - 1
293
+
294
+ # Onsets
295
+ onset_images[
296
+ i, start_offset:end_offset, :num_start_time_frames, 0
297
+ ] = start_time_roll.onsets[:num_start_time_frames, :].T
298
+ onset_images[
299
+ i, start_offset:end_offset, :num_target_frames, 1
300
+ ] = target_roll.onsets[:num_target_frames, :].T
301
+ onset_images[
302
+ i, start_offset:end_offset, :num_pred_frames, 2
303
+ ] = pred_roll.onsets[:num_pred_frames, :].T
304
+
305
+ # Full notes
306
+ full_images[
307
+ i, start_offset:end_offset, :num_start_time_frames, 0
308
+ ] = start_time_roll.onsets[:num_start_time_frames, :].T
309
+ full_images[
310
+ i, start_offset:end_offset, :num_target_frames, 1
311
+ ] = target_roll.active[:num_target_frames, :].T
312
+ full_images[
313
+ i, start_offset:end_offset, :num_pred_frames, 2
314
+ ] = pred_roll.active[:num_pred_frames, :].T
315
+
316
+ # Add separator between tracks.
317
+ if j < num_tracks - 1:
318
+ onset_images[i, end_offset, :, 0] = 1
319
+ full_images[i, end_offset, :, 0] = 1
320
+
321
+ return onset_images[:, ::-1, :, :], full_images[:, ::-1, :, :]
322
+
323
+
324
+ def prettymidi_pianoroll(
325
+ track_pianorolls: Mapping[str, Sequence[Tuple[np.ndarray, np.ndarray]]],
326
+ fps: float,
327
+ num_seconds=_DEFAULT_AUDIO_SECONDS
328
+ ) -> Mapping[str, seqio.metrics.MetricValue]:
329
+ """Create summary from given pianorolls."""
330
+ max_len = int(num_seconds * fps)
331
+ summaries = {}
332
+ for inst_name, all_prs in track_pianorolls.items():
333
+
334
+ est_prs, ref_prs = zip(*all_prs)
335
+
336
+ bs = len(ref_prs)
337
+ pianoroll_image_batch = np.zeros(shape=(bs, 128, max_len, 3))
338
+ for i in range(bs):
339
+ ref_pr = ref_prs[i][:, :max_len]
340
+ est_pr = est_prs[i][:, :max_len]
341
+
342
+ pianoroll_image_batch[i, :, :est_pr.shape[1], 2] = est_pr
343
+ pianoroll_image_batch[i, :, :ref_pr.shape[1], 1] = ref_pr
344
+ if not inst_name:
345
+ inst_name = 'all instruments'
346
+
347
+ summaries[f'{inst_name} pretty_midi pianoroll'] = seqio.metrics.Image(
348
+ image=pianoroll_image_batch, max_outputs=bs)
349
+
350
+ return summaries
351
+
352
+
353
+ def audio_summaries(
354
+ targets: Sequence[Mapping[str, Sequence[float]]],
355
+ predictions: Sequence[Mapping[str, Sequence[float]]],
356
+ spectrogram_config: spectrograms.SpectrogramConfig,
357
+ num_seconds: float = _DEFAULT_AUDIO_SECONDS
358
+ ) -> Mapping[str, seqio.metrics.MetricValue]:
359
+ """Compute audio summaries for a list of examples.
360
+
361
+ Args:
362
+ targets: List of targets, unused as we pass the input audio tokens via
363
+ predictions.
364
+ predictions: List of predictions, including input audio tokens.
365
+ spectrogram_config: Spectrogram configuration.
366
+ num_seconds: Number of seconds of audio to include in the summaries.
367
+ Longer audio will be cropped (from the beginning), shorter audio will be
368
+ padded with silence (at the end).
369
+
370
+ Returns:
371
+ A dictionary mapping "audio" to the audio summaries.
372
+ """
373
+ del targets
374
+ samples = _extract_example_audio(
375
+ examples=predictions,
376
+ sample_rate=spectrogram_config.sample_rate,
377
+ num_seconds=num_seconds)
378
+ return {
379
+ 'audio': seqio.metrics.Audio(
380
+ audiodata=samples[:, :, np.newaxis],
381
+ sample_rate=spectrogram_config.sample_rate,
382
+ max_outputs=samples.shape[0])
383
+ }
384
+
385
+
386
+ def transcription_summaries(
387
+ targets: Sequence[Mapping[str, Sequence[float]]],
388
+ predictions: Sequence[Mapping[str, Sequence[float]]],
389
+ spectrogram_config: spectrograms.SpectrogramConfig,
390
+ ns_feature_suffix: Optional[str] = None,
391
+ note_onset_feature_suffix: Optional[str] = None,
392
+ note_offset_feature_suffix: Optional[str] = None,
393
+ note_frequency_feature_suffix: Optional[str] = None,
394
+ note_confidence_feature_suffix: Optional[str] = None,
395
+ track_specs: Optional[Sequence[note_sequences.TrackSpec]] = None,
396
+ num_seconds: float = _DEFAULT_AUDIO_SECONDS,
397
+ pianoroll_frames_per_second: float = _DEFAULT_PIANOROLL_FRAMES_PER_SECOND,
398
+ ) -> Mapping[str, seqio.metrics.MetricValue]:
399
+ """Compute note transcription summaries for multiple examples.
400
+
401
+ Args:
402
+ targets: List of targets containing ground truth.
403
+ predictions: List of predictions, including raw input audio.
404
+ spectrogram_config: The spectrogram configuration.
405
+ ns_feature_suffix: Suffix of serialized NoteSequence feature.
406
+ note_onset_feature_suffix: Suffix of note onset times feature.
407
+ note_offset_feature_suffix: Suffix of note offset times feature.
408
+ note_frequency_feature_suffix: Suffix of note frequencies feature.
409
+ note_confidence_feature_suffix: Suffix of note confidences (velocities)
410
+ feature.
411
+ track_specs: Optional list of TrackSpec objects to indicate a set of tracks
412
+ into which each NoteSequence should be split.
413
+ num_seconds: Number of seconds of audio to include in the summaries.
414
+ Longer audio will be cropped (from the beginning), shorter audio will be
415
+ padded with silence (at the end).
416
+ pianoroll_frames_per_second: Temporal resolution of pianoroll images.
417
+
418
+ Returns:
419
+ A dictionary of input, ground truth, and transcription summaries.
420
+ """
421
+ audio_samples = _extract_example_audio(
422
+ examples=predictions,
423
+ sample_rate=spectrogram_config.sample_rate,
424
+ num_seconds=num_seconds)
425
+
426
+ def synthesize(examples, prefix):
427
+ return _synthesize_example_notes(
428
+ examples=examples,
429
+ ns_feature_name=(prefix + ns_feature_suffix
430
+ if ns_feature_suffix else None),
431
+ note_onset_feature_name=(prefix + note_onset_feature_suffix
432
+ if note_onset_feature_suffix else None),
433
+ note_offset_feature_name=(prefix + note_offset_feature_suffix
434
+ if note_offset_feature_suffix else None),
435
+ note_frequency_feature_name=(
436
+ prefix + note_frequency_feature_suffix
437
+ if note_frequency_feature_suffix else None),
438
+ note_confidence_feature_name=(
439
+ prefix + note_confidence_feature_suffix
440
+ if note_confidence_feature_suffix else None),
441
+ sample_rate=spectrogram_config.sample_rate,
442
+ num_seconds=num_seconds)
443
+
444
+ synthesized_predictions = synthesize(predictions, 'est_')
445
+
446
+ onset_pianoroll_images, full_pianoroll_images = _examples_to_pianorolls(
447
+ targets=targets,
448
+ predictions=predictions,
449
+ ns_feature_suffix=ns_feature_suffix,
450
+ note_onset_feature_suffix=note_onset_feature_suffix,
451
+ note_offset_feature_suffix=note_offset_feature_suffix,
452
+ note_frequency_feature_suffix=note_frequency_feature_suffix,
453
+ note_confidence_feature_suffix=note_confidence_feature_suffix,
454
+ track_specs=track_specs,
455
+ num_seconds=num_seconds,
456
+ frames_per_second=pianoroll_frames_per_second)
457
+
458
+ return {
459
+ 'input_with_transcription': seqio.metrics.Audio(
460
+ audiodata=np.stack([audio_samples, synthesized_predictions], axis=2),
461
+ sample_rate=spectrogram_config.sample_rate,
462
+ max_outputs=audio_samples.shape[0]),
463
+
464
+ 'pianoroll': seqio.metrics.Image(
465
+ image=full_pianoroll_images,
466
+ max_outputs=full_pianoroll_images.shape[0]),
467
+
468
+ 'onset_pianoroll': seqio.metrics.Image(
469
+ image=onset_pianoroll_images,
470
+ max_outputs=onset_pianoroll_images.shape[0]),
471
+ }
mt3/tasks.py ADDED
@@ -0,0 +1,402 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Transcription task definitions."""
16
+
17
+ import functools
18
+ from typing import Optional, Sequence
19
+
20
+ from mt3 import datasets
21
+ from mt3 import event_codec
22
+ from mt3 import metrics
23
+ from mt3 import mixing
24
+ from mt3 import preprocessors
25
+ from mt3 import run_length_encoding
26
+ from mt3 import spectrograms
27
+ from mt3 import vocabularies
28
+
29
+ import note_seq
30
+ import numpy as np
31
+ import seqio
32
+ import t5
33
+ import tensorflow as tf
34
+
35
+ # Split audio frame sequences into this length before the cache placeholder.
36
+ MAX_NUM_CACHED_FRAMES = 2000
37
+
38
+ seqio.add_global_cache_dirs(['gs://mt3/data/cache_tasks/'])
39
+
40
+
41
+ def construct_task_name(
42
+ task_prefix: str,
43
+ spectrogram_config=spectrograms.SpectrogramConfig(),
44
+ vocab_config=vocabularies.VocabularyConfig(),
45
+ task_suffix: Optional[str] = None
46
+ ) -> str:
47
+ """Construct task name from prefix, config, and optional suffix."""
48
+ fields = [task_prefix]
49
+ if spectrogram_config.abbrev_str:
50
+ fields.append(spectrogram_config.abbrev_str)
51
+ if vocab_config.abbrev_str:
52
+ fields.append(vocab_config.abbrev_str)
53
+ if task_suffix:
54
+ fields.append(task_suffix)
55
+ return '_'.join(fields)
56
+
57
+
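A minimal usage sketch of the naming scheme (not part of the diff). The dataset prefix is illustrative, and it assumes the default SpectrogramConfig abbreviation is empty; the 'vb1' part comes from VocabularyConfig.abbrev_str in mt3/vocabularies.py below.

```python
# Sketch: assemble a task name from prefix + config abbreviations + suffix.
from mt3 import spectrograms, vocabularies
from mt3.tasks import construct_task_name  # importing mt3.tasks also registers the tasks

name = construct_task_name(
    task_prefix='maestrov3_notes_ties',
    spectrogram_config=spectrograms.SpectrogramConfig(),
    vocab_config=vocabularies.VocabularyConfig(num_velocity_bins=1),
    task_suffix='train')
# name == 'maestrov3_notes_ties_vb1_train' (if the spectrogram abbreviation is empty)
```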
58
+ def trim_eos(tokens: Sequence[int]) -> np.ndarray:
59
+ """If EOS is present, remove it and everything after."""
60
+ tokens = np.array(tokens, np.int32)
61
+ if vocabularies.DECODED_EOS_ID in tokens:
62
+ tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
63
+ return tokens
64
+
65
+
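A tiny worked example of trim_eos (illustration only; DECODED_EOS_ID is -1, as defined in mt3/vocabularies.py below):

```python
# Everything from the first decoded EOS (-1) onward is dropped.
trim_eos([17, 42, -1, 5])   # -> array([17, 42], dtype=int32)
```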
66
+ def postprocess(tokens, example, is_target, codec):
67
+ """Transcription postprocessing function."""
68
+ tokens = trim_eos(tokens)
69
+
70
+ if is_target:
71
+ return {
72
+ 'unique_id': example['unique_id'][0],
73
+ 'ref_ns': (note_seq.NoteSequence.FromString(example['sequence'][0])
74
+ if example['sequence'][0] else None),
75
+ 'ref_tokens': tokens,
76
+ }
77
+
78
+ start_time = example['input_times'][0]
79
+ # Round down to nearest symbolic token step.
80
+ start_time -= start_time % (1 / codec.steps_per_second)
81
+
82
+ return {
83
+ 'unique_id': example['unique_id'][0],
84
+ 'raw_inputs': example['raw_inputs'],
85
+ 'est_tokens': tokens,
86
+ 'start_time': start_time
87
+ }
88
+
89
+
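A worked example of the start-time rounding in postprocess (illustration only; assumes codec.steps_per_second == 100):

```python
# Snap a chunk start time down to the previous symbolic token step (10 ms here).
steps_per_second = 100
start_time = 12.3456
start_time -= start_time % (1 / steps_per_second)
# start_time is now ~12.34 (subject to float rounding)
```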
90
+ def add_transcription_task_to_registry(
91
+ dataset_config: datasets.DatasetConfig,
92
+ spectrogram_config: spectrograms.SpectrogramConfig,
93
+ vocab_config: vocabularies.VocabularyConfig,
94
+ tokenize_fn, # TODO(iansimon): add type signature
95
+ onsets_only: bool,
96
+ include_ties: bool,
97
+ skip_too_long: bool = False
98
+ ) -> None:
99
+ """Add note transcription task to seqio.TaskRegistry."""
100
+ codec = vocabularies.build_codec(vocab_config)
101
+ vocabulary = vocabularies.vocabulary_from_codec(codec)
102
+
103
+ output_features = {
104
+ 'targets': seqio.Feature(vocabulary=vocabulary),
105
+ 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2)
106
+ }
107
+
108
+ task_name = 'onsets' if onsets_only else 'notes'
109
+ if include_ties:
110
+ task_name += '_ties'
111
+ task_prefix = f'{dataset_config.name}_{task_name}'
112
+
113
+ train_task_name = construct_task_name(
114
+ task_prefix=task_prefix,
115
+ spectrogram_config=spectrogram_config,
116
+ vocab_config=vocab_config,
117
+ task_suffix='train')
118
+
119
+ mixture_task_names = []
120
+
121
+ tie_token = codec.encode_event(event_codec.Event('tie', 0))
122
+ track_specs = (dataset_config.track_specs
123
+ if dataset_config.track_specs else None)
124
+
125
+ # Add transcription training task.
126
+ seqio.TaskRegistry.add(
127
+ train_task_name,
128
+ source=seqio.TFExampleDataSource(
129
+ split_to_filepattern={
130
+ 'train': dataset_config.paths[dataset_config.train_split],
131
+ 'eval': dataset_config.paths[dataset_config.train_eval_split]
132
+ },
133
+ feature_description=dataset_config.features),
134
+ output_features=output_features,
135
+ preprocessors=[
136
+ functools.partial(
137
+ tokenize_fn,
138
+ spectrogram_config=spectrogram_config, codec=codec,
139
+ is_training_data=True, onsets_only=onsets_only,
140
+ include_ties=include_ties),
141
+ functools.partial(
142
+ t5.data.preprocessors.split_tokens,
143
+ max_tokens_per_segment=MAX_NUM_CACHED_FRAMES,
144
+ feature_key='inputs',
145
+ additional_feature_keys=[
146
+ 'input_event_start_indices', 'input_event_end_indices',
147
+ 'input_state_event_indices'
148
+ ],
149
+ passthrough_feature_keys=['targets', 'state_events']),
150
+ seqio.CacheDatasetPlaceholder(),
151
+ functools.partial(
152
+ t5.data.preprocessors.select_random_chunk,
153
+ feature_key='inputs',
154
+ additional_feature_keys=[
155
+ 'input_event_start_indices', 'input_event_end_indices',
156
+ 'input_state_event_indices'
157
+ ],
158
+ passthrough_feature_keys=['targets', 'state_events'],
159
+ uniform_random_start=True),
160
+ functools.partial(
161
+ run_length_encoding.extract_target_sequence_with_indices,
162
+ state_events_end_token=tie_token if include_ties else None),
163
+ functools.partial(preprocessors.map_midi_programs, codec=codec),
164
+ run_length_encoding.run_length_encode_shifts_fn(
165
+ codec,
166
+ feature_key='targets'),
167
+ functools.partial(
168
+ mixing.mix_transcription_examples,
169
+ codec=codec,
170
+ targets_feature_keys=['targets']),
171
+ run_length_encoding.remove_redundant_state_changes_fn(
172
+ feature_key='targets', codec=codec,
173
+ state_change_event_types=['velocity', 'program']),
174
+ functools.partial(
175
+ preprocessors.compute_spectrograms,
176
+ spectrogram_config=spectrogram_config),
177
+ functools.partial(preprocessors.handle_too_long, skip=skip_too_long),
178
+ functools.partial(
179
+ seqio.preprocessors.tokenize_and_append_eos,
180
+ copy_pretokenized=False)
181
+ ],
182
+ postprocess_fn=None,
183
+ metric_fns=[],
184
+ )
185
+
186
+ # Add transcription eval tasks.
187
+ for split in dataset_config.infer_eval_splits:
188
+ eval_task_name = construct_task_name(
189
+ task_prefix=task_prefix,
190
+ spectrogram_config=spectrogram_config,
191
+ vocab_config=vocab_config,
192
+ task_suffix=split.suffix)
193
+
194
+ if split.include_in_mixture:
195
+ mixture_task_names.append(eval_task_name)
196
+
197
+ seqio.TaskRegistry.add(
198
+ eval_task_name,
199
+ source=seqio.TFExampleDataSource(
200
+ split_to_filepattern={'eval': dataset_config.paths[split.name]},
201
+ feature_description=dataset_config.features),
202
+ output_features=output_features,
203
+ preprocessors=[
204
+ functools.partial(
205
+ tokenize_fn,
206
+ spectrogram_config=spectrogram_config, codec=codec,
207
+ is_training_data='train' in split.name, onsets_only=onsets_only,
208
+ include_ties=include_ties),
209
+ seqio.CacheDatasetPlaceholder(),
210
+ preprocessors.add_unique_id,
211
+ preprocessors.pad_notesequence_array,
212
+ functools.partial(
213
+ t5.data.preprocessors.split_tokens_to_inputs_length,
214
+ feature_key='inputs',
215
+ additional_feature_keys=['input_times', 'sequence'],
216
+ passthrough_feature_keys=['unique_id']),
217
+ # Add dummy targets as they are dropped during the above split to
218
+ # avoid memory blowups, but are expected to be present by seqio; the
219
+ # evaluation metrics currently only use the target NoteSequence.
220
+ preprocessors.add_dummy_targets,
221
+ functools.partial(
222
+ preprocessors.compute_spectrograms,
223
+ spectrogram_config=spectrogram_config),
224
+ functools.partial(preprocessors.handle_too_long, skip=False),
225
+ functools.partial(
226
+ seqio.preprocessors.tokenize_and_append_eos,
227
+ copy_pretokenized=False)
228
+ ],
229
+ postprocess_fn=functools.partial(postprocess, codec=codec),
230
+ metric_fns=[
231
+ functools.partial(
232
+ metrics.transcription_metrics,
233
+ codec=codec,
234
+ spectrogram_config=spectrogram_config,
235
+ onsets_only=onsets_only,
236
+ use_ties=include_ties,
237
+ track_specs=track_specs)
238
+ ],
239
+ )
240
+
241
+ seqio.MixtureRegistry.add(
242
+ construct_task_name(
243
+ task_prefix=task_prefix, spectrogram_config=spectrogram_config,
244
+ vocab_config=vocab_config, task_suffix='eval'),
245
+ mixture_task_names,
246
+ default_rate=1)
247
+
248
+
249
+ # Just use default spectrogram config.
250
+ SPECTROGRAM_CONFIG = spectrograms.SpectrogramConfig()
251
+
252
+ # Create two vocabulary configs, one default and one with only on-off velocity.
253
+ VOCAB_CONFIG_FULL = vocabularies.VocabularyConfig()
254
+ VOCAB_CONFIG_NOVELOCITY = vocabularies.VocabularyConfig(num_velocity_bins=1)
255
+
256
+ # Transcribe MAESTRO v1.
257
+ add_transcription_task_to_registry(
258
+ dataset_config=datasets.MAESTROV1_CONFIG,
259
+ spectrogram_config=SPECTROGRAM_CONFIG,
260
+ vocab_config=VOCAB_CONFIG_FULL,
261
+ tokenize_fn=functools.partial(
262
+ preprocessors.tokenize_transcription_example,
263
+ audio_is_samples=False,
264
+ id_feature_key='id'),
265
+ onsets_only=False,
266
+ include_ties=False)
267
+
268
+ # Transcribe MAESTRO v3.
269
+ add_transcription_task_to_registry(
270
+ dataset_config=datasets.MAESTROV3_CONFIG,
271
+ spectrogram_config=SPECTROGRAM_CONFIG,
272
+ vocab_config=VOCAB_CONFIG_FULL,
273
+ tokenize_fn=functools.partial(
274
+ preprocessors.tokenize_transcription_example,
275
+ audio_is_samples=False,
276
+ id_feature_key='id'),
277
+ onsets_only=False,
278
+ include_ties=False)
279
+
280
+ # Transcribe MAESTRO v3 without velocities, with ties.
281
+ add_transcription_task_to_registry(
282
+ dataset_config=datasets.MAESTROV3_CONFIG,
283
+ spectrogram_config=SPECTROGRAM_CONFIG,
284
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
285
+ tokenize_fn=functools.partial(
286
+ preprocessors.tokenize_transcription_example,
287
+ audio_is_samples=False,
288
+ id_feature_key='id'),
289
+ onsets_only=False,
290
+ include_ties=True)
291
+
292
+ # Transcribe GuitarSet, with ties.
293
+ add_transcription_task_to_registry(
294
+ dataset_config=datasets.GUITARSET_CONFIG,
295
+ spectrogram_config=SPECTROGRAM_CONFIG,
296
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
297
+ tokenize_fn=preprocessors.tokenize_guitarset_example,
298
+ onsets_only=False,
299
+ include_ties=True)
300
+
301
+ # Transcribe URMP mixes, with ties.
302
+ add_transcription_task_to_registry(
303
+ dataset_config=datasets.URMP_CONFIG,
304
+ spectrogram_config=SPECTROGRAM_CONFIG,
305
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
306
+ tokenize_fn=functools.partial(
307
+ preprocessors.tokenize_example_with_program_lookup,
308
+ inst_name_to_program_fn=preprocessors.urmp_instrument_to_program,
309
+ id_feature_key='id'),
310
+ onsets_only=False,
311
+ include_ties=True)
312
+
313
+ # Transcribe MusicNet, with ties.
314
+ add_transcription_task_to_registry(
315
+ dataset_config=datasets.MUSICNET_CONFIG,
316
+ spectrogram_config=SPECTROGRAM_CONFIG,
317
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
318
+ tokenize_fn=functools.partial(
319
+ preprocessors.tokenize_transcription_example,
320
+ audio_is_samples=True,
321
+ id_feature_key='id'),
322
+ onsets_only=False,
323
+ include_ties=True)
324
+
325
+ # Transcribe MusicNetEM, with ties.
326
+ add_transcription_task_to_registry(
327
+ dataset_config=datasets.MUSICNET_EM_CONFIG,
328
+ spectrogram_config=SPECTROGRAM_CONFIG,
329
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
330
+ tokenize_fn=functools.partial(
331
+ preprocessors.tokenize_transcription_example,
332
+ audio_is_samples=True,
333
+ id_feature_key='id'),
334
+ onsets_only=False,
335
+ include_ties=True)
336
+
337
+ # Transcribe Cerberus4 (piano-guitar-bass-drums quartets), with ties.
338
+ add_transcription_task_to_registry(
339
+ dataset_config=datasets.CERBERUS4_CONFIG,
340
+ spectrogram_config=SPECTROGRAM_CONFIG,
341
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
342
+ tokenize_fn=functools.partial(
343
+ preprocessors.tokenize_slakh_example,
344
+ track_specs=datasets.CERBERUS4_CONFIG.track_specs,
345
+ ignore_pitch_bends=True),
346
+ onsets_only=False,
347
+ include_ties=True)
348
+
349
+ # Transcribe 10 random sub-mixes of each song from Slakh, with ties.
350
+ add_transcription_task_to_registry(
351
+ dataset_config=datasets.SLAKH_CONFIG,
352
+ spectrogram_config=SPECTROGRAM_CONFIG,
353
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
354
+ tokenize_fn=functools.partial(
355
+ preprocessors.tokenize_slakh_example,
356
+ track_specs=None,
357
+ ignore_pitch_bends=True),
358
+ onsets_only=False,
359
+ include_ties=True)
360
+
361
+
362
+ # Construct task names to include in transcription mixture.
363
+ MIXTURE_DATASET_NAMES = [
364
+ 'maestrov3', 'guitarset', 'urmp', 'musicnet_em', 'cerberus4', 'slakh'
365
+ ]
366
+ MIXTURE_TRAIN_TASK_NAMES = []
367
+ MIXTURE_EVAL_TASK_NAMES = []
368
+ MIXTURE_TEST_TASK_NAMES = []
369
+ for dataset_name in MIXTURE_DATASET_NAMES:
370
+ MIXTURE_TRAIN_TASK_NAMES.append(
371
+ construct_task_name(task_prefix=f'{dataset_name}_notes_ties',
372
+ spectrogram_config=SPECTROGRAM_CONFIG,
373
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
374
+ task_suffix='train'))
375
+ MIXTURE_EVAL_TASK_NAMES.append(
376
+ construct_task_name(task_prefix=f'{dataset_name}_notes_ties',
377
+ spectrogram_config=SPECTROGRAM_CONFIG,
378
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
379
+ task_suffix='validation'))
380
+ MIXING_TEMPERATURE = 10 / 3
381
+
382
+ # Add the mixture of all transcription tasks, with ties.
383
+ seqio.MixtureRegistry.add(
384
+ construct_task_name(
385
+ task_prefix='mega_notes_ties',
386
+ spectrogram_config=SPECTROGRAM_CONFIG,
387
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
388
+ task_suffix='train'),
389
+ MIXTURE_TRAIN_TASK_NAMES,
390
+ default_rate=functools.partial(
391
+ seqio.mixing_rate_num_examples,
392
+ temperature=MIXING_TEMPERATURE))
393
+ seqio.MixtureRegistry.add(
394
+ construct_task_name(
395
+ task_prefix='mega_notes_ties',
396
+ spectrogram_config=SPECTROGRAM_CONFIG,
397
+ vocab_config=VOCAB_CONFIG_NOVELOCITY,
398
+ task_suffix='eval'),
399
+ MIXTURE_EVAL_TASK_NAMES,
400
+ default_rate=functools.partial(
401
+ seqio.mixing_rate_num_examples,
402
+ temperature=MIXING_TEMPERATURE))
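A rough sketch of what the temperature-scaled mixture does (the example counts are made up; it assumes seqio.mixing_rate_num_examples makes each task's rate proportional to num_examples ** (1 / temperature), as in T5-style temperature mixing):

```python
import numpy as np

num_examples = np.array([1000.0, 100.0, 10.0])  # hypothetical per-task example counts
temperature = 10 / 3
rates = num_examples ** (1 / temperature)
print(rates / rates.sum())
# ~[0.57, 0.29, 0.14] instead of [0.90, 0.09, 0.01] for proportional-to-size mixing,
# i.e. small datasets are sampled far more often than their share of examples.
```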
mt3/version.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """MT3 version."""
16
+ __version__ = '0.0.1'
mt3/vocabularies.py ADDED
@@ -0,0 +1,282 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Model vocabulary."""
16
+
17
+ import dataclasses
18
+ import math
19
+
20
+ from typing import Callable, Optional, Sequence
21
+ from mt3 import event_codec
22
+
23
+ import note_seq
24
+ import seqio
25
+ import t5.data
26
+ import tensorflow as tf
27
+
28
+
29
+ DECODED_EOS_ID = -1
30
+ DECODED_INVALID_ID = -2
31
+
32
+ # defaults for vocabulary config
33
+ DEFAULT_STEPS_PER_SECOND = 100
34
+ DEFAULT_MAX_SHIFT_SECONDS = 10
35
+ DEFAULT_NUM_VELOCITY_BINS = 127
36
+
37
+
38
+ @dataclasses.dataclass
39
+ class VocabularyConfig:
40
+ """Vocabulary configuration parameters."""
41
+ steps_per_second: int = DEFAULT_STEPS_PER_SECOND
42
+ max_shift_seconds: int = DEFAULT_MAX_SHIFT_SECONDS
43
+ num_velocity_bins: int = DEFAULT_NUM_VELOCITY_BINS
44
+
45
+ @property
46
+ def abbrev_str(self):
47
+ s = ''
48
+ if self.steps_per_second != DEFAULT_STEPS_PER_SECOND:
49
+ s += 'ss%d' % self.steps_per_second
50
+ if self.max_shift_seconds != DEFAULT_MAX_SHIFT_SECONDS:
51
+ s += 'ms%d' % self.max_shift_seconds
52
+ if self.num_velocity_bins != DEFAULT_NUM_VELOCITY_BINS:
53
+ s += 'vb%d' % self.num_velocity_bins
54
+ return s
55
+
56
+
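A few values of the abbreviation, derived from the property above (only non-default fields contribute):

```python
from mt3.vocabularies import VocabularyConfig

assert VocabularyConfig().abbrev_str == ''
assert VocabularyConfig(num_velocity_bins=1).abbrev_str == 'vb1'
assert VocabularyConfig(steps_per_second=50, num_velocity_bins=1).abbrev_str == 'ss50vb1'
```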
57
+ def num_velocity_bins_from_codec(codec: event_codec.Codec):
58
+ """Get number of velocity bins from event codec."""
59
+ lo, hi = codec.event_type_range('velocity')
60
+ return hi - lo
61
+
62
+
63
+ def velocity_to_bin(velocity, num_velocity_bins):
64
+ if velocity == 0:
65
+ return 0
66
+ else:
67
+ return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY)
68
+
69
+
70
+ def bin_to_velocity(velocity_bin, num_velocity_bins):
71
+ if velocity_bin == 0:
72
+ return 0
73
+ else:
74
+ return int(note_seq.MAX_MIDI_VELOCITY * velocity_bin / num_velocity_bins)
75
+
76
+
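Worked examples of the quantization above (velocity bin 0 is reserved for note-off, and note_seq.MAX_MIDI_VELOCITY is 127):

```python
from mt3.vocabularies import bin_to_velocity, velocity_to_bin

velocity_to_bin(64, num_velocity_bins=1)      # -> 1 (any nonzero velocity)
bin_to_velocity(1, num_velocity_bins=1)       # -> 127
velocity_to_bin(100, num_velocity_bins=127)   # -> 100 (identity with 127 bins)
bin_to_velocity(100, num_velocity_bins=127)   # -> 100
```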
77
+ def drop_programs(tokens, codec: event_codec.Codec):
78
+ """Drops program change events from a token sequence."""
79
+ min_program_id, max_program_id = codec.event_type_range('program')
80
+ return tokens[(tokens < min_program_id) | (tokens > max_program_id)]
81
+
82
+
83
+ def programs_to_midi_classes(tokens, codec):
84
+ """Modifies program events to be the first program in the MIDI class."""
85
+ min_program_id, max_program_id = codec.event_type_range('program')
86
+ is_program = (tokens >= min_program_id) & (tokens <= max_program_id)
87
+ return tf.where(
88
+ is_program,
89
+ min_program_id + 8 * ((tokens - min_program_id) // 8),
90
+ tokens)
91
+
92
+
93
+ @dataclasses.dataclass
94
+ class ProgramGranularity:
95
+ # both tokens_map_fn and program_map_fn should be idempotent
96
+ tokens_map_fn: Callable[[Sequence[int], event_codec.Codec], Sequence[int]]
97
+ program_map_fn: Callable[[int], int]
98
+
99
+
100
+ PROGRAM_GRANULARITIES = {
101
+ # "flat" granularity; drop program change tokens and set NoteSequence
102
+ # programs to zero
103
+ 'flat': ProgramGranularity(
104
+ tokens_map_fn=drop_programs,
105
+ program_map_fn=lambda program: 0),
106
+
107
+ # map each program to the first program in its MIDI class
108
+ 'midi_class': ProgramGranularity(
109
+ tokens_map_fn=programs_to_midi_classes,
110
+ program_map_fn=lambda program: 8 * (program // 8)),
111
+
112
+ # leave programs as is
113
+ 'full': ProgramGranularity(
114
+ tokens_map_fn=lambda tokens, codec: tokens,
115
+ program_map_fn=lambda program: program)
116
+ }
117
+
118
+
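A worked example of the three granularities, using the program_map_fn side (General MIDI program 30 is a guitar, MIDI class 24-31):

```python
from mt3.vocabularies import PROGRAM_GRANULARITIES

PROGRAM_GRANULARITIES['flat'].program_map_fn(30)        # -> 0
PROGRAM_GRANULARITIES['midi_class'].program_map_fn(30)  # -> 24
PROGRAM_GRANULARITIES['full'].program_map_fn(30)        # -> 30
```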
119
+ def build_codec(vocab_config: VocabularyConfig):
120
+ """Build event codec."""
121
+ event_ranges = [
122
+ event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH,
123
+ note_seq.MAX_MIDI_PITCH),
124
+ # velocity bin 0 is used for note-off
125
+ event_codec.EventRange('velocity', 0, vocab_config.num_velocity_bins),
126
+ # used to indicate that a pitch is present at the beginning of a segment
127
+ # (only has an "off" event as when using ties all pitch events until the
128
+ # "tie" event belong to the tie section)
129
+ event_codec.EventRange('tie', 0, 0),
130
+ event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM,
131
+ note_seq.MAX_MIDI_PROGRAM),
132
+ event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH,
133
+ note_seq.MAX_MIDI_PITCH),
134
+ ]
135
+
136
+ return event_codec.Codec(
137
+ max_shift_steps=(vocab_config.steps_per_second *
138
+ vocab_config.max_shift_seconds),
139
+ steps_per_second=vocab_config.steps_per_second,
140
+ event_ranges=event_ranges)
141
+
142
+
143
+ def vocabulary_from_codec(codec: event_codec.Codec) -> seqio.Vocabulary:
144
+ return GenericTokenVocabulary(
145
+ codec.num_classes, extra_ids=t5.data.DEFAULT_EXTRA_IDS)
146
+
147
+
148
+ class GenericTokenVocabulary(seqio.Vocabulary):
149
+ """Vocabulary with pass-through encoding of tokens."""
150
+
151
+ def __init__(self, regular_ids: int, extra_ids: int = 0):
152
+ # The special tokens: 0=PAD, 1=EOS, and 2=UNK
153
+ self._num_special_tokens = 3
154
+ self._num_regular_tokens = regular_ids
155
+ super().__init__(extra_ids=extra_ids)
156
+
157
+ @property
158
+ def eos_id(self) -> Optional[int]:
159
+ return 1
160
+
161
+ @property
162
+ def unk_id(self) -> Optional[int]:
163
+ return 2
164
+
165
+ @property
166
+ def _base_vocab_size(self) -> int:
167
+ """Number of ids.
168
+
169
+ Returns:
170
+ an integer, the vocabulary size
171
+ """
172
+ return self._num_special_tokens + self._num_regular_tokens
173
+
174
+ def _encode(self, token_ids: Sequence[int]) -> Sequence[int]:
175
+ """Encode a list of tokens ids as a list of integers.
176
+
177
+ To keep the first few ids for special tokens, increase ids by the number
178
+ of special tokens.
179
+
180
+ Args:
181
+ token_ids: array of token ids.
182
+
183
+ Returns:
184
+ a list of integers (not terminated by EOS)
185
+ """
186
+ encoded = []
187
+ for token_id in token_ids:
188
+ if not 0 <= token_id < self._num_regular_tokens:
189
+ raise ValueError(
190
+ f'token_id {token_id} does not fall within valid range of '
191
+ f'[0, {self._num_regular_tokens})')
192
+ encoded.append(token_id + self._num_special_tokens)
193
+
194
+ return encoded
195
+
196
+ def _decode(self, ids: Sequence[int]) -> Sequence[int]:
197
+ """Decode a list of integers to a list of token ids.
198
+
199
+ The special tokens of PAD and UNK as well as extra_ids will be
200
+ replaced with DECODED_INVALID_ID in the output. If EOS is present, it will
201
+ be the final token in the decoded output and will be represented by
202
+ DECODED_EOS_ID.
203
+
204
+ Args:
205
+ ids: a list of integers
206
+
207
+ Returns:
208
+ a list of token ids.
209
+ """
210
+ # convert all the extra ids to INVALID_ID
211
+ def _decode_id(encoded_id):
212
+ if encoded_id == self.eos_id:
213
+ return DECODED_EOS_ID
214
+ elif encoded_id < self._num_special_tokens:
215
+ return DECODED_INVALID_ID
216
+ elif encoded_id >= self._base_vocab_size:
217
+ return DECODED_INVALID_ID
218
+ else:
219
+ return encoded_id - self._num_special_tokens
220
+ ids = [_decode_id(int(i)) for i in ids]
221
+ return ids
222
+
223
+ def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor:
224
+ """Encode a list of tokens to a tf.Tensor.
225
+
226
+ Args:
227
+ token_ids: array of audio token ids.
228
+
229
+ Returns:
230
+ a 1d tf.Tensor with dtype tf.int32
231
+ """
232
+ with tf.control_dependencies(
233
+ [tf.debugging.assert_less(
234
+ token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)),
235
+ tf.debugging.assert_greater_equal(
236
+ token_ids, tf.cast(0, token_ids.dtype))
237
+ ]):
238
+ tf_ids = token_ids + self._num_special_tokens
239
+ return tf_ids
240
+
241
+ def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
242
+ """Decode in TensorFlow.
243
+
244
+ The special tokens of PAD and UNK as well as extra_ids will be
245
+ replaced with DECODED_INVALID_ID in the output. If EOS is present, it and
246
+ all following tokens in the decoded output will be represented by
247
+ DECODED_EOS_ID.
248
+
249
+ Args:
250
+ ids: a 1d tf.Tensor with dtype tf.int32
251
+
252
+ Returns:
253
+ a 1d tf.Tensor with dtype tf.int32
254
+ """
255
+ # Create a mask that is true from the first EOS position onward.
256
+ # First, create an array that is True whenever there is an EOS, then cumsum
257
+ # that array so that every position after and including the first True is
258
+ # >1, then cast back to bool for the final mask.
259
+ eos_and_after = tf.cumsum(
260
+ tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1)
261
+ eos_and_after = tf.cast(eos_and_after, tf.bool)
262
+
263
+ return tf.where(
264
+ eos_and_after,
265
+ DECODED_EOS_ID,
266
+ tf.where(
267
+ tf.logical_and(
268
+ tf.greater_equal(ids, self._num_special_tokens),
269
+ tf.less(ids, self._base_vocab_size)),
270
+ ids - self._num_special_tokens,
271
+ DECODED_INVALID_ID))
272
+
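A worked trace of the cumsum EOS-masking trick above (numpy used for illustration; the special tokens are 0=PAD, 1=EOS, 2=UNK, so regular ids are shifted down by 3):

```python
import numpy as np

ids = np.array([4, 1, 0, 5])               # eos_id is 1
eos_and_after = np.cumsum(ids == 1) > 0    # [False, True, True, True]
# Masked positions decode to DECODED_EOS_ID (-1); the remaining id 4 is a
# regular token and decodes to 4 - 3 = 1, giving [1, -1, -1, -1].
```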
273
+ def __eq__(self, other):
274
+ their_extra_ids = other.extra_ids
275
+ their_num_regular_tokens = other._num_regular_tokens
276
+ return (self.extra_ids == their_extra_ids and
277
+ self._num_regular_tokens == their_num_regular_tokens)
278
+
279
+
280
+ def num_embeddings(vocabulary: GenericTokenVocabulary) -> int:
281
+ """Vocabulary size as a multiple of 128 for TPU efficiency."""
282
+ return 128 * math.ceil(vocabulary.vocab_size / 128)
mt3/vocabularies_test.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for vocabularies."""
16
+
17
+ from absl.testing import absltest
18
+ from mt3 import vocabularies
19
+
20
+ import numpy as np
21
+ import tensorflow.compat.v2 as tf
22
+
23
+ tf.compat.v1.enable_eager_execution()
24
+
25
+
26
+ class VocabulariesTest(absltest.TestCase):
27
+
28
+ def test_velocity_quantization(self):
29
+ self.assertEqual(0, vocabularies.velocity_to_bin(0, num_velocity_bins=1))
30
+ self.assertEqual(0, vocabularies.velocity_to_bin(0, num_velocity_bins=127))
31
+ self.assertEqual(0, vocabularies.bin_to_velocity(0, num_velocity_bins=1))
32
+ self.assertEqual(0, vocabularies.bin_to_velocity(0, num_velocity_bins=127))
33
+
34
+ self.assertEqual(
35
+ 1,
36
+ vocabularies.velocity_to_bin(
37
+ vocabularies.bin_to_velocity(1, num_velocity_bins=1),
38
+ num_velocity_bins=1))
39
+
40
+ for velocity_bin in range(1, 128):
41
+ self.assertEqual(
42
+ velocity_bin,
43
+ vocabularies.velocity_to_bin(
44
+ vocabularies.bin_to_velocity(velocity_bin, num_velocity_bins=127),
45
+ num_velocity_bins=127))
46
+
47
+ def test_encode_decode(self):
48
+ vocab = vocabularies.GenericTokenVocabulary(32)
49
+ input_tokens = [1, 2, 3]
50
+ expected_encoded = [4, 5, 6]
51
+
52
+ # Encode
53
+ self.assertSequenceEqual(vocab.encode(input_tokens), expected_encoded)
54
+ np.testing.assert_array_equal(
55
+ vocab.encode_tf(tf.convert_to_tensor(input_tokens)).numpy(),
56
+ expected_encoded)
57
+
58
+ # Decode
59
+ self.assertSequenceEqual(vocab.decode(expected_encoded), input_tokens)
60
+ np.testing.assert_array_equal(
61
+ vocab.decode_tf(tf.convert_to_tensor(expected_encoded)).numpy(),
62
+ input_tokens)
63
+
64
+ def test_decode_invalid_ids(self):
65
+ vocab = vocabularies.GenericTokenVocabulary(32, extra_ids=4)
66
+ encoded = [0, 2, 3, 4, 34, 35]
67
+ expected_decoded = [-2, -2, 0, 1, 31, -2]
68
+ self.assertSequenceEqual(vocab.decode(encoded), expected_decoded)
69
+ np.testing.assert_array_equal(
70
+ vocab.decode_tf(tf.convert_to_tensor(encoded)).numpy(),
71
+ expected_decoded)
72
+
73
+ def test_decode_eos(self):
74
+ vocab = vocabularies.GenericTokenVocabulary(32)
75
+ encoded = [0, 2, 3, 4, 1, 0, 1, 0]
76
+ # Python decode function truncates everything after first EOS.
77
+ expected_decoded = [-2, -2, 0, 1, -1]
78
+ self.assertSequenceEqual(vocab.decode(encoded), expected_decoded)
79
+ # TF decode function preserves array length.
80
+ expected_decoded_tf = [-2, -2, 0, 1, -1, -1, -1, -1]
81
+ np.testing.assert_array_equal(
82
+ vocab.decode_tf(tf.convert_to_tensor(encoded)).numpy(),
83
+ expected_decoded_tf)
84
+
85
+ def test_encode_invalid_id(self):
86
+ vocab = vocabularies.GenericTokenVocabulary(32)
87
+ inputs = [0, 15, 31]
88
+ # No exception expected.
89
+ vocab.encode(inputs)
90
+ vocab.encode_tf(tf.convert_to_tensor(inputs))
91
+
92
+ inputs_too_low = [-1, 15, 31]
93
+ with self.assertRaises(ValueError):
94
+ vocab.encode(inputs_too_low)
95
+ with self.assertRaises(tf.errors.InvalidArgumentError):
96
+ vocab.encode_tf(tf.convert_to_tensor(inputs_too_low))
97
+
98
+ inputs_too_high = [0, 15, 32]
99
+ with self.assertRaises(ValueError):
100
+ vocab.encode(inputs_too_high)
101
+ with self.assertRaises(tf.errors.InvalidArgumentError):
102
+ vocab.encode_tf(tf.convert_to_tensor(inputs_too_high))
103
+
104
+ def test_encode_dtypes(self):
105
+ vocab = vocabularies.GenericTokenVocabulary(32)
106
+ inputs = [0, 15, 31]
107
+ encoded32 = vocab.encode_tf(tf.convert_to_tensor(inputs, tf.int32))
108
+ self.assertEqual(tf.int32, encoded32.dtype)
109
+ encoded64 = vocab.encode_tf(tf.convert_to_tensor(inputs, tf.int64))
110
+ self.assertEqual(tf.int64, encoded64.dtype)
111
+
112
+
113
+ if __name__ == '__main__':
114
+ absltest.main()
pytest.ini ADDED
@@ -0,0 +1,3 @@
1
+ [pytest]
2
+ python_files = *_test.py
3
+ log_level = INFO
setup.cfg ADDED
@@ -0,0 +1,2 @@
1
+ [aliases]
2
+ test=pytest
setup.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright 2022 The MT3 Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Install mt3."""
16
+
17
+ import os
18
+ import sys
19
+ import setuptools
20
+
21
+ # To enable importing version.py directly, we add its path to sys.path.
22
+ version_path = os.path.join(os.path.dirname(__file__), 'mt3')
23
+ sys.path.append(version_path)
24
+ from version import __version__ # pylint: disable=g-import-not-at-top
25
+
26
+ setuptools.setup(
27
+ name='mt3',
28
+ version=__version__,
29
+ description='Multi-Task Multitrack Music Transcription',
30
+ author='Google Inc.',
31
+ author_email='no-reply@google.com',
32
+ url='http://github.com/magenta/mt3',
33
+ license='Apache 2.0',
34
+ packages=setuptools.find_packages(),
35
+ package_data={
36
+ '': ['*.gin'],
37
+ },
38
+ scripts=[],
39
+ install_requires=[
40
+ 'absl-py == 1.1.0',
41
+ 'ddsp == 3.4.4',
42
+ 'flax == 0.5.2',
43
+ 'gin-config == 0.5.0',
44
+ 'immutabledict == 2.2.1',
45
+ 'librosa == 0.9.2',
46
+ 'mir_eval == 0.7',
47
+ 'note_seq == 0.0.3',
48
+ 'numpy == 1.21.6',
49
+ 'pretty_midi == 0.2.9',
50
+ 'scikit-learn == 1.0.2',
51
+ 'scipy == 1.7.3',
52
+ 'seqio == 0.0.8',
53
+ 't5 == 0.9.3',
54
+ 'tensorflow',
55
+ 'tensorflow-datasets == 4.6.0',
56
+ ],
57
+ classifiers=[
58
+ 'Development Status :: 4 - Beta',
59
+ 'Intended Audience :: Developers',
60
+ 'Intended Audience :: Science/Research',
61
+ 'License :: OSI Approved :: Apache Software License',
62
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
63
+ ],
64
+ tests_require=['pytest'],
65
+ setup_requires=['pytest-runner'],
66
+ keywords='music transcription machinelearning audio',
67
+ )
t5x/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Import API modules."""
16
+
17
+ import t5x.adafactor
18
+ import t5x.checkpoints
19
+ import t5x.decoding
20
+ import t5x.gin_utils
21
+ import t5x.losses
22
+ import t5x.models
23
+ import t5x.partitioning
24
+ import t5x.state_utils
25
+ import t5x.train_state
26
+ import t5x.trainer
27
+ import t5x.utils
28
+
29
+ # Version number.
30
+ from t5x.version import __version__
31
+
32
+ # TODO(adarob): Move clients to t5x.checkpointing and rename
33
+ # checkpoints.py to checkpointing.py
34
+ checkpointing = t5x.checkpoints
t5x/adafactor.py ADDED
@@ -0,0 +1,608 @@
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Adafactor Optimizer.
16
+
17
+ Specialized Adafactor implementation for T5X with:
18
+ - custom factorization specification rules.
19
+ - support for stacked parameters from scanned layers and parameter fusions.
20
+
21
+ Why do we need custom factorization? In the Adafactor paper, scalar, vector and
22
+ matrix parameters are considered. This is sufficiently general because higher
23
+ dimensional parameters can be reshaped. In practice, there are situations where
24
+ higher dimensional parameters are desirable. For example, consider the
25
+ multi-headed attention. It has projection kernels. This is naturally
26
+ represented as a 3-dimensional array [d_model, num_head, head_dim]. Keeping the
27
+ 3-dimensional structure can be beneficial for performance optimization, e.g., by
28
+ giving compilers an additional degree of freedom to do layout optimization.
29
+
30
+ The default heuristic behavior for the second-moment estimator can lead to an
31
+ unexpected result because it assumes that the parameters are matrices (vectors
32
+ and scalars are not factored). The dimensions are sorted and the smaller
33
+ dimension is assigned to the row dim and the larger dim to the col dim (unless
34
+ the two largest dims have an equal size and then the original ordering of the
35
+ dimensions is used). Then `v_row` (i.e., the optimizer state for the row) is
36
+ obtained by removing the col dim. In other words, `rank(v_row) = rank(v) - 1`.
37
+ If the parameter is higher dimensional, v_row and v_col are higher dimensional.
38
+ Therefore, the outer product of v_row and v_col does not necessarily correspond
39
+ to the low-rank approximation that minimizes the generalized Kullback-Leibler
40
+ divergence (the original Adafactor formulation).
41
+
42
+ This Adafactor implementation generalizes the default behavior such that we
43
+ obtain the correct second moment estimator even for higher dimensional
44
+ parameters.
45
+
46
+ """
47
+ import enum
48
+ import re
49
+ from typing import Any, Mapping, Optional, Sequence, Tuple, Union
50
+
51
+ from absl import logging
52
+ from flax import struct
53
+ from flax.core import freeze
54
+ from flax.core import FrozenDict
55
+ from flax.core import unfreeze
56
+ from flax.serialization import from_state_dict
57
+ from flax.serialization import to_state_dict
58
+ from flax.traverse_util import flatten_dict
59
+ from flax.traverse_util import unflatten_dict
60
+ import jax
61
+ import jax.numpy as jnp
62
+ import numpy as np
63
+ from t5x import utils
64
+ from t5x.optimizers import OptimizerDef
65
+ from t5x.optimizers import OptimizerState
66
+
67
+ Dtype = Any
68
+
69
+
70
+ class FactorDim(enum.Enum):
71
+ # Don't factorize this dimension.
72
+ NONE = None
73
+ # A batch-like dimension that we should not average over.
74
+ BATCH = 1
75
+ ROW = 2
76
+ COLUMN = 3
77
+
78
+
79
+ # Sentinel value signifying the legacy heuristic factorization rule.
80
+ class HeuristicRule(enum.Enum):
81
+ token = 1
82
+
83
+
84
+ HEURISTIC_RULE = HeuristicRule.token
85
+ FactorRule = Union[HeuristicRule, Tuple[FactorDim]]
86
+
87
+
88
+ def _restore(target, flat):
89
+ state_dict = unflatten_dict({tuple(k.split('/')): v for k, v in flat.items()})
90
+ if isinstance(target, FrozenDict):
91
+ return freeze(state_dict)
92
+ else:
93
+ return state_dict
94
+
95
+
96
+ def _insert(tpl, idx, x):
97
+ tmp = list(tpl)
98
+ tmp.insert(idx, x)
99
+ return tuple(tmp)
100
+
101
+
102
+ def standard_logical_factor_rules():
103
+ return freeze({
104
+ 'vocab': FactorDim.COLUMN,
105
+ 'embed': FactorDim.ROW,
106
+ 'mlp': FactorDim.COLUMN,
107
+ 'heads': FactorDim.COLUMN,
108
+ 'kv': FactorDim.COLUMN,
109
+ 'joined_kv': FactorDim.COLUMN,
110
+ 'relpos_buckets': FactorDim.NONE,
111
+ 'layers': FactorDim.BATCH, # used in scanned layers
112
+ 'stack': FactorDim.BATCH, # used in stacked params
113
+ # 'batch', 'length' should not occur in parameters
114
+ 'q_wi_fused': FactorDim.COLUMN,
115
+ 'o_wo_fused': FactorDim.COLUMN,
116
+ 'multiquery_heads': FactorDim.COLUMN,
117
+ 'kv_fused': FactorDim.COLUMN,
118
+ 'layer_norm_scale': FactorDim.NONE,
119
+ 'mlp_activations': FactorDim.COLUMN,
120
+ })
121
+
122
+
123
+ def factor_name_to_factordim(name):
124
+ if not isinstance(name, str):
125
+ return name
126
+ name = name.lower()
127
+ return {
128
+ 'row': FactorDim.ROW,
129
+ 'col': FactorDim.COLUMN,
130
+ 'column': FactorDim.COLUMN,
131
+ 'batch': FactorDim.BATCH,
132
+ 'none': FactorDim.NONE,
133
+ 'unfactorized': FactorDim.NONE
134
+ }[name]
135
+
136
+
137
+ class HParamMap:
138
+ """Maps parameter path names to hparams.
139
+
140
+ Names of parameters nested in a PyTree (e.g., an Optimizer) are formed by
141
+ joining the names along the path to the parameter leaf with '/'.
142
+ """
143
+
144
+ def __init__(self, rules):
145
+ self._rules = [(re.compile(r), p) for r, p in rules]
146
+
147
+ def __getitem__(self, key: str) -> Any:
148
+ for r, p in self._rules:
149
+ if r.search(key):
150
+ return p
151
+ raise KeyError(f'No factor rule found for parameter: {key}')
152
+
153
+ def __call__(self, params):
154
+ """Returns a copy of the params with mapped hparams in leaves."""
155
+ flat_state_dict = flatten_dict(to_state_dict(params))
156
+ flat_rules_dict = {k: self['/'.join(k)] for k in flat_state_dict.keys()}
157
+ return from_state_dict(params, unflatten_dict(flat_rules_dict))
158
+
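A usage sketch (the parameter paths are hypothetical): hparams are resolved by regex search over the '/'-joined parameter path, and the first matching rule wins.

```python
from t5x.adafactor import Adafactor, HParamMap

scale_rules = HParamMap([
    (r'layer_norm_scale', False),  # don't rescale the LR by parameter norm here
    (r'.*', True),                 # default for all other parameters
])
scale_rules['encoder/layers_0/attention/query/kernel']  # -> True
scale_rules['decoder/decoder_norm/layer_norm_scale']    # -> False
# e.g. Adafactor(multiply_by_parameter_scale=scale_rules)
```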
159
+
160
+ @struct.dataclass
161
+ class _AdafactorHyperParams:
162
+ """Hparams for Adafactor optimizer."""
163
+ learning_rate: Optional[float]
164
+ factored: bool
165
+ multiply_by_parameter_scale: Union[bool, HParamMap]
166
+ beta1: Optional[float]
167
+ decay_rate: float
168
+ step_offset: int
169
+ clipping_threshold: Optional[float]
170
+ weight_decay_rate: Optional[float]
171
+ min_dim_size_to_factor: int
172
+ epsilon1: float
173
+ epsilon2: float
174
+ factor_map: Optional[HParamMap] = None
175
+ logical_factor_rules: Any = None
176
+ weight_decay_rate_lr_exponent: Optional[float] = None
177
+ global_norm_clip_threshold: Optional[float] = None
178
+ max_parameter_scale: Optional[float] = None
179
+ skip_nan_updates: Optional[bool] = False
180
+
181
+
182
+ @struct.dataclass
183
+ class _AdafactorParamState:
184
+ v_row: np.ndarray # used in normal factored version
185
+ v_col: np.ndarray
186
+ v: np.ndarray # only used without factoring
187
+ m: np.ndarray # only used with momentum
188
+
189
+
190
+ class Adafactor(OptimizerDef):
191
+ """Adafactor optimizer.
192
+
193
+ Adafactor is described in https://arxiv.org/abs/1804.04235.
194
+ """
195
+
196
+ def __init__(self,
197
+ learning_rate: Optional[float] = None,
198
+ factored: bool = True,
199
+ multiply_by_parameter_scale: Union[bool, HParamMap] = True,
200
+ beta1: Optional[float] = None,
201
+ decay_rate: float = 0.8,
202
+ step_offset: int = 0,
203
+ clipping_threshold: Optional[float] = 1.0,
204
+ weight_decay_rate: Optional[float] = None,
205
+ min_dim_size_to_factor: int = 128,
206
+ epsilon1: float = 1e-30,
207
+ epsilon2: float = 1e-3,
208
+ dtype_momentum: Dtype = jnp.float32,
209
+ factor_map: Optional[HParamMap] = None,
210
+ logical_factor_rules: Optional[Mapping[str, FactorDim]] = None,
211
+ weight_decay_rate_lr_exponent: Optional[float] = None,
212
+ global_norm_clip_threshold: Optional[float] = None,
213
+ max_parameter_scale: Optional[float] = None,
214
+ skip_nan_updates: Optional[bool] = False):
215
+ """Constructor for the Adafactor optimizer.
216
+
217
+
218
+ Args:
219
+ learning_rate: float: learning rate. NB: the natural scale for adafactor
220
+ LR is markedly different from Adam; one doesn't use the 1/sqrt(hidden)
221
+ correction for this optimizer with attention-based models.
222
+ factored: boolean: whether to use factored second-moment estimator for 2d
223
+ variables.
224
+ multiply_by_parameter_scale: boolean: if True, then scale provided
225
+ learning_rate by parameter norm. if False, provided learning_rate is
226
+ absolute step size.
227
+ beta1: an optional float value between 0 and 1, enables momentum and uses
228
+ extra memory if non-None! None by default.
229
+ decay_rate: float: controls second-moment exponential decay schedule.
230
+ step_offset: for finetuning, one may optionally set this to the starting
231
+ step-number of the finetuning phase to reset the second moment
232
+ accumulators after pretraining. Does not affect the momentum even if it
233
+ was used during pretraining.
234
+ clipping_threshold: an optional float >= 1, if None no update clipping.
235
+ weight_decay_rate: optional rate at which to decay weights.
236
+ min_dim_size_to_factor: only factor accumulator if two array dimensions
237
+ are at least this size.
238
+ epsilon1: Regularization constant for squared gradient.
239
+ epsilon2: Regularization constant for parameter scale.
240
+ dtype_momentum: dtype of momentum buffers.
241
+ factor_map: hparam-map from key path to manual factorization rules.
242
+ logical_factor_rules: factorization rules provided as a set of mappings
243
+ from logical axis name to ROW, COLUMN, BATCH, or NONE. Supersedes
244
+ factor_map if `set_param_axes` is called.
245
+ weight_decay_rate_lr_exponent: If present, weight decay rate is computed
246
+ as (learning_rate ** weight_decay_rate_lr_exponent). If
247
+ weight_decay_rate is also present, then multiply by it.
248
+ global_norm_clip_threshold: If set, will clip gradients by global norm
249
+ before Adafactor stats are applied.
250
+ max_parameter_scale: If set, clips the parameter scale to a maximum value,
251
+ which helps prevent parameters from growing without bound.
252
+ skip_nan_updates: If set, any parameter that would have been updated by a
253
+ NaN value after applying gradients will be kept with the earlier
254
+ value it had.
255
+ """
256
+ if not factored and factor_map is not None:
257
+ raise ValueError('Adafactor factored is False but factorization rules '
258
+ 'have been provided.')
259
+ if not isinstance(multiply_by_parameter_scale, (bool, HParamMap)):
260
+ raise TypeError(
261
+ '`multiply_by_parameter_scale` must be either bool or `HParamMap` '
262
+ f'type. Got {type(multiply_by_parameter_scale)}')
263
+
264
+ if not isinstance(factor_map, (type(None), HParamMap)):
265
+ raise TypeError(
266
+ '`factor_map` must be either None or `HParamMap` type. Got '
267
+ f'{type(factor_map)}')
268
+
269
+ hyper_params = _AdafactorHyperParams(
270
+ learning_rate, factored, multiply_by_parameter_scale, beta1, decay_rate,
271
+ step_offset, clipping_threshold, weight_decay_rate,
272
+ min_dim_size_to_factor, epsilon1, epsilon2, factor_map,
273
+ logical_factor_rules, weight_decay_rate_lr_exponent,
274
+ global_norm_clip_threshold, max_parameter_scale, skip_nan_updates)
275
+ self.dtype_momentum = jax.dtypes.canonicalize_dtype(dtype_momentum)
276
+ super().__init__(hyper_params)
277
+
278
+ @staticmethod
279
+ def _decay_rate_pow(i: int, exponent: float = 0.8) -> float:
280
+ """Default Adafactor second-moment decay schedule."""
281
+ t = jnp.array(i, jnp.float32) + 1.0
282
+ return 1.0 - t**(-exponent)
283
+
284
+ @staticmethod
285
+ def _parse_rule(
286
+ rule: Optional[FactorRule],
287
+ shape: Sequence[int],
288
+ path: str,
289
+ fallback_to_heuristics=True
290
+ ) -> Tuple[Tuple[int, ...], Optional[Union[HeuristicRule, Tuple[Tuple[
291
+ int, ...], Tuple[int, ...]]]]]:
292
+ """Parses specification and return factored dims and dims for averaging.
293
+
294
+ Adafactor needs to know the two largest dimensions to factorize along.
295
+ Traditionally it used a heuristic, but we want finer control over these
296
+ factorization dimensions. Additionally, there are situations where
297
+ parameters are batched together, e.g. for scanned layers and QKV fusion,
298
+ and we want to ensure that the scale updates and clipping thresholds are
299
+ calculated _within_ each array and not across the entire batched array.
300
+
301
+ Args:
302
+ rule: the rule is either None (default to heuristic behavior) or a tuple
303
+ of the same rank as the `param` array containing a FactorDim.ROW or
304
+ FactorDim.COLUMN to mark dimensions to factorize in two row and column
305
+ sets, and optionally dimensions marked FactorDim.BATCH to denote batched
306
+ dimensions that should not be averaged over. e.g. (BATCH, ROW, COLUMN,
307
+ COLUMN)
308
+ shape: shape of the variable
309
+ path: '/' joined parameter path.
310
+ fallback_to_heuristics: whether to fallback to heuristic factorization
311
+ rule. For most cases this should be set to `True`.
312
+
313
+ Returns:
314
+ tuple of: tuple of dimensions to average over, 2-tuple of dimensions to
315
+ factorize over.
316
+ """
317
+ param_ndim = len(shape)
318
+
319
+ if rule is None:
320
+ # No factorization.
321
+ return tuple(np.arange(param_ndim)), None
322
+
323
+ if rule is HEURISTIC_RULE:
324
+ if param_ndim > 2:
325
+ raise ValueError(
326
+ f'A parameter with rank strictly higher than 2 must have an '
327
+ f'explicit factorization rule: {path}, {shape}')
328
+ # Even if no explicit rule is provided for the param, we still want to
329
+ # average over all the dimensions for computing the RMS scale.
330
+ return tuple(np.arange(param_ndim)), HEURISTIC_RULE
331
+
332
+ if len(rule) != param_ndim:
333
+ raise ValueError(f'Factorization rule {rule} has incorrect rank '
334
+ f'for param of rank {param_ndim}: {path}, {shape}')
335
+
336
+ row_dims = tuple(idx for idx, d in enumerate(rule) if d == FactorDim.ROW)
337
+ col_dims = tuple(idx for idx, d in enumerate(rule) if d == FactorDim.COLUMN)
338
+ batched_dims = tuple(
339
+ idx for idx, d in enumerate(rule) if d == FactorDim.BATCH)
340
+ averaging_dims = tuple(np.delete(np.arange(param_ndim), batched_dims))
341
+ factor_dims = (row_dims, col_dims)
342
+ if factor_dims == ((), ()):
343
+ factor_dims = None
344
+
345
+ if fallback_to_heuristics and param_ndim <= 2 and not batched_dims:
346
+ logging.warning(
347
+ 'Since rank of parameter %s %d is less than or equal to 2, the '
348
+ 'factorization method falls back to heuristics and the provided '
349
+ 'factor rule %s is ignored.', path, param_ndim, rule)
350
+ return tuple(np.arange(param_ndim)), HEURISTIC_RULE
351
+
352
+ return averaging_dims, factor_dims
353
+
354
+ def _factored_dims(
355
+ self, shape: Sequence[int]) -> Optional[Tuple[Tuple[int], Tuple[int]]]:
356
+ """Whether to use a factored second moment estimator.
357
+
358
+ If there are not two dimensions of size >= min_dim_size_to_factor, then we
359
+ do not factor. If we do factor the accumulator, then this function returns a
360
+ tuple of the two largest axes to reduce over.
361
+
362
+ Args:
363
+ shape: a Shape
364
+
365
+ Returns:
366
+ None or a tuple of ints
367
+ """
368
+ if not self.hyper_params.factored or len(shape) < 2:
369
+ return None
370
+ sorted_dims = np.argsort(shape)
371
+ if shape[sorted_dims[-2]] < self.hyper_params.min_dim_size_to_factor:
372
+ return None
373
+ return (int(sorted_dims[-2]),), (int(sorted_dims[-1]),)
374
+
375
+ def init_param_state(self, param, path):
376
+ shape = param.shape
377
+ state = {k: jnp.zeros((1,)) for k in ['v_row', 'v_col', 'v', 'm']}
378
+ if self.hyper_params.factored:
379
+ factor_rule = (
380
+ self.hyper_params.factor_map[path]
381
+ if self.hyper_params.factor_map else HEURISTIC_RULE)
382
+ else:
383
+ factor_rule = None
384
+ _, factored_dims = self._parse_rule(factor_rule, param.shape, path)
385
+ if factored_dims is HEURISTIC_RULE:
386
+ factored_dims = self._factored_dims(shape)
387
+ if factored_dims is not None:
388
+ d1, d0 = factored_dims
389
+ vr_shape = np.delete(shape, d0)
390
+ vc_shape = np.delete(shape, d1)
391
+ state['v_row'] = jnp.zeros(vr_shape, dtype=jnp.float32)
392
+ state['v_col'] = jnp.zeros(vc_shape, dtype=jnp.float32)
393
+ else:
394
+ state['v'] = jnp.zeros(param.shape, dtype=jnp.float32)
395
+ if self.hyper_params.beta1 is not None:
396
+ state['m'] = jnp.zeros(param.shape, dtype=self.dtype_momentum)
397
+ return _AdafactorParamState(**state)
398
+
399
+ def init_state(self, params):
400
+ params_flat = utils.flatten_dict_string_keys(params)
401
+ param_states_flat = [
402
+ self.init_param_state(param, path)
403
+ for path, param in params_flat.items()
404
+ ]
405
+ param_states_flat = {
406
+ k: v for k, v in zip(params_flat.keys(), param_states_flat)
407
+ }
408
+ param_states = _restore(params, param_states_flat)
409
+ state = OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states)
410
+ return state
411
+
412
+ def apply_param_gradient(self, step, hyper_params, param, state, grad, path):
413
+ assert hyper_params.learning_rate is not None, 'no learning rate provided.'
414
+ learning_rate = hyper_params.learning_rate
415
+ beta1 = hyper_params.beta1
416
+ decay_rate = hyper_params.decay_rate
417
+ step_offset = hyper_params.step_offset
418
+ multiply_by_parameter_scale = hyper_params.multiply_by_parameter_scale
419
+ max_parameter_scale = hyper_params.max_parameter_scale
420
+ clipping_threshold = hyper_params.clipping_threshold
421
+ weight_decay_rate = hyper_params.weight_decay_rate
422
+ epsilon1 = hyper_params.epsilon1
423
+ epsilon2 = hyper_params.epsilon2
424
+ if hyper_params.weight_decay_rate_lr_exponent:
425
+ weight_decay_rate = (
426
+ (weight_decay_rate or 1.0) *
427
+ learning_rate**hyper_params.weight_decay_rate_lr_exponent)
428
+
429
+ if self.hyper_params.factored:
430
+ factor_rule = (
431
+ self.hyper_params.factor_map[path]
432
+ if self.hyper_params.factor_map else HEURISTIC_RULE)
433
+ else:
434
+ factor_rule = None
435
+ averaging_dims, factored_dims = self._parse_rule(factor_rule, param.shape,
436
+ path)
437
+
438
+ grad = grad.astype(jnp.float32)
439
+
440
+ updates = {k: jnp.zeros((1,)) for k in ['v_row', 'v_col', 'v', 'm']}
441
+ decay_rate = self._decay_rate_pow(step - step_offset, exponent=decay_rate)
442
+ update_scale = learning_rate
443
+
444
+ if isinstance(multiply_by_parameter_scale, HParamMap):
445
+ multiply_by_parameter_scale = multiply_by_parameter_scale[path]
446
+ if multiply_by_parameter_scale:
447
+ param_scale = jnp.sqrt(
448
+ jnp.mean(param * param, axis=averaging_dims, keepdims=True))
449
+ # Clip param_scale to a minimum value of epsilon2.
450
+ param_scale = jnp.maximum(param_scale, epsilon2)
451
+ # Clip param_scale to a maximum value, if specified.
452
+ if max_parameter_scale is not None:
453
+ param_scale = jnp.minimum(param_scale, max_parameter_scale)
454
+ update_scale *= param_scale
455
+ mixing_rate = 1.0 - decay_rate
456
+
457
+ grad_sqr = grad * grad + epsilon1
458
+ if factored_dims is HEURISTIC_RULE:
459
+ factored_dims = self._factored_dims(param.shape)
460
+ if factored_dims is not None:
461
+ d1, d0 = factored_dims
462
+ new_v_row = (
463
+ decay_rate * state.v_row + mixing_rate * jnp.mean(grad_sqr, axis=d0))
464
+ new_v_col = (
465
+ decay_rate * state.v_col + mixing_rate * jnp.mean(grad_sqr, axis=d1))
466
+ updates['v_row'] = new_v_row
467
+ updates['v_col'] = new_v_col
468
+ reduced_d1 = tuple(d - len([e for e in d0 if e < d]) for d in d1)
469
+
470
+ row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True)
471
+ row_factor = (new_v_row / row_col_mean)**-0.5
472
+ col_factor = (new_v_col)**-0.5
473
+ y = (
474
+ grad * jnp.expand_dims(row_factor, axis=d0) *
475
+ jnp.expand_dims(col_factor, axis=d1))
476
+ else:
477
+ new_v = decay_rate * state.v + mixing_rate * grad_sqr
478
+ updates['v'] = new_v
479
+ y = grad * (new_v)**-0.5
480
+
481
+ if clipping_threshold is not None:
482
+ clipping_denom = (
483
+ jnp.maximum(
484
+ 1.0,
485
+ jnp.sqrt(jnp.mean(y * y, axis=averaging_dims, keepdims=True)) /
486
+ clipping_threshold))
487
+ y /= clipping_denom
488
+
489
+ subtrahend = update_scale * y
490
+ if beta1 is not None:
491
+ new_m = beta1 * state.m + (1.0 - beta1) * subtrahend
492
+ subtrahend = new_m
493
+ updates['m'] = new_m.astype(self.dtype_momentum)
494
+
495
+ if weight_decay_rate is not None:
496
+ new_param = (1.0 - weight_decay_rate) * param - subtrahend
497
+ else:
498
+ new_param = param - subtrahend
499
+
500
+ if hyper_params.skip_nan_updates:
501
+ updates['v_row'] = jnp.where(
502
+ jnp.isnan(updates['v_row']), state.v_row, updates['v_row'])
503
+ updates['v_col'] = jnp.where(
504
+ jnp.isnan(updates['v_col']), state.v_col, updates['v_col'])
505
+ updates['v'] = jnp.where(jnp.isnan(updates['v']), state.v, updates['v'])
506
+ updates['m'] = jnp.where(jnp.isnan(updates['m']), state.m, updates['m'])
507
+ new_param = jnp.where(jnp.isnan(new_param), param, new_param)
508
+ new_state = _AdafactorParamState(**updates)
509
+
510
+ return new_param.astype(param.dtype), new_state
511
+
512
+ def apply_gradient(self, hyper_params, params, state, grads):
513
+ """Applies a gradient for a set of parameters.
514
+
515
+ Args:
516
+ hyper_params: a named tuple of hyper parameters.
517
+ params: the parameters that should be updated.
518
+ state: a named tuple containing the state of the optimizer
519
+ grads: the gradient tensors for the parameters.
520
+
521
+ Returns:
522
+ A tuple containing the new parameters and the new optimizer state.
523
+ """
524
+ step = state.step
525
+ # We assume that params, param_states, and grads are all dict-like here.
526
+ params_flat_dict = utils.flatten_dict_string_keys(params)
527
+ params_paths = params_flat_dict.keys()
528
+ params_flat = params_flat_dict.values()
529
+ # extra paranoia to guarantee identical value ordering
530
+ states_flat = utils.flatten_dict_string_keys(state.param_states)
531
+ states_flat = [states_flat[k] for k in params_paths]
532
+ grads_flat = utils.flatten_dict_string_keys(grads)
533
+ grads_flat = [grads_flat[k] for k in params_paths]
534
+
535
+ if hyper_params.global_norm_clip_threshold:
536
+ # Paper: http://proceedings.mlr.press/v28/pascanu13.pdf
537
+ # TF: https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm
538
+ squared_l2_norms = [jnp.sum(jnp.square(g)) for g in grads_flat]
539
+ global_norm = jnp.sqrt(jnp.sum(jnp.array(squared_l2_norms)))
540
+ scale = hyper_params.global_norm_clip_threshold * jnp.minimum(
541
+ 1.0 / hyper_params.global_norm_clip_threshold, 1.0 / global_norm)
542
+ grads_flat = [g * scale for g in grads_flat]
543
+
544
+ out = [
545
+ self.apply_param_gradient(step, hyper_params, param, state, grad, path)
546
+ for param, state, grad, path in zip(params_flat, states_flat,
547
+ grads_flat, params_paths)
548
+ ]
549
+
550
+ new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
551
+ new_params_flat = {k: v for k, v in zip(params_paths, new_params_flat)}
552
+ new_states_flat = {k: v for k, v in zip(params_paths, new_states_flat)}
553
+ new_params = _restore(params, new_params_flat)
554
+ new_param_states = _restore(params, new_states_flat)
555
+ new_state = OptimizerState(step + 1, new_param_states)
556
+
557
+ return new_params, new_state
558
+
559
+ def set_param_axes(self, param_logical_axes):
560
+ """Sets Adafactor factorization map from logical axis names tree."""
561
+ logical_factor_rules = self.hyper_params.logical_factor_rules
562
+ if logical_factor_rules is None:
563
+ return
564
+
565
+ # pylint:disable=invalid-name
566
+ NONE = FactorDim.NONE
567
+ COLUMN = FactorDim.COLUMN
568
+ ROW = FactorDim.ROW
569
+
570
+ # pylint:enable=invalid-name
571
+
572
+ def apply_rules(axes):
573
+ # Partially factorized params are marked as unfactorized, preserving
574
+ # only BATCH axis annotations. We also check for incompletely factorized
575
+ # params that have ROW, COLUMN but also accidental NONE dimensions and
576
+ # raise an error in that case.
577
+ axis_rules = tuple(logical_factor_rules[x] for x in axes)
578
+ axis_rules = tuple(factor_name_to_factordim(x) for x in axis_rules)
579
+ if ROW in axis_rules and COLUMN in axis_rules and NONE in axis_rules:
580
+ raise ValueError(f'Incomplete adafactor spec {axis_rules} for {axes}!')
581
+ if ROW not in axis_rules or COLUMN not in axis_rules:
582
+ axis_rules = tuple(
583
+ NONE if x in (ROW, COLUMN) else x for x in axis_rules)
584
+ return axis_rules
585
+
586
+ factor_map = jax.tree_map(apply_rules, param_logical_axes)
587
+ factor_map = utils.flatten_dict_string_keys(factor_map)
588
+
589
+ self.hyper_params = self.hyper_params.replace(factor_map=factor_map)
590
+
591
+ def derive_logical_axes(self, optimizer_state, param_logical_axes):
592
+ """Derives optimizer logical partitioning from model logical partitions."""
593
+ optimizer_logical_axes = jax.tree_map(lambda x: None,
594
+ optimizer_state.state_dict())
595
+ optimizer_logical_axes['target'] = param_logical_axes
596
+
597
+ def factor_rule(logical_axes, adafactor_leaf):
598
+ return dict(
599
+ v_row=None,
600
+ v_col=None,
601
+ v=logical_axes if adafactor_leaf['v'].shape != (1,) else None,
602
+ m=logical_axes if self.hyper_params.beta1 else None)
603
+
604
+ optimizer_logical_axes['state']['param_states'] = jax.tree_map(
605
+ factor_rule, unfreeze(param_logical_axes),
606
+ optimizer_state.state_dict()['state']['param_states'])
607
+
608
+ return optimizer_state.restore_state(unfreeze(optimizer_logical_axes))
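Note (editor's sketch, not part of the committed diff): the explicit factorization rules above are passed via `factor_map`; the snippet below mirrors the usage pattern in the tests that follow, with a made-up parameter name and shape.

import jax.numpy as jnp
from t5x import adafactor

_ROW = adafactor.FactorDim.ROW
_COL = adafactor.FactorDim.COLUMN
_BATCH = adafactor.FactorDim.BATCH

# A rank-3 parameter (e.g. fused heads x input x output) requires an explicit rule.
params = {'qkv': jnp.ones((3, 128, 256))}
# Map the parameter-path regex 'qkv' to one FactorDim entry per array dimension.
factor_map = adafactor.HParamMap(((r'qkv', (_BATCH, _ROW, _COL)),))
opt_def = adafactor.Adafactor(learning_rate=0.1, factor_map=factor_map)
optimizer = opt_def.create(params)
# The v_row/v_col accumulators keep the BATCH axis, so RMS scale and clipping
# statistics are computed within each fused sub-array rather than across them.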
t5x/adafactor_test.py ADDED
@@ -0,0 +1,527 @@
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for t5x.adafactor."""
16
+
17
+ import functools
18
+ import operator
19
+ from typing import Sequence
20
+
21
+ from absl.testing import absltest
22
+ from absl.testing import parameterized
23
+
24
+ import flax
25
+ from flax import optim # used for equivalence testing only
26
+ from flax import traverse_util
27
+ import jax
28
+ from jax import numpy as jnp
29
+ from jax import random
30
+ import numpy as np
31
+
32
+ from t5x import adafactor
33
+ from t5x import optimizers
34
+
35
+ OptimizerState = optimizers.OptimizerState
36
+
37
+ _AdafactorHyperParams = adafactor._AdafactorHyperParams
38
+ _AdafactorParamState = adafactor._AdafactorParamState
39
+
40
+ _BATCH = adafactor.FactorDim.BATCH
41
+ _ROW = adafactor.FactorDim.ROW
42
+ _COL = adafactor.FactorDim.COLUMN
43
+
44
+ # Testing helpers
45
+
46
+
47
+ def _assert_numpy_allclose(a, b, atol=None, rtol=None):
48
+ a, b = jnp.array(a), jnp.array(b)
49
+ a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a
50
+ b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b
51
+ kw = {}
52
+ if atol:
53
+ kw['atol'] = atol
54
+ if rtol:
55
+ kw['rtol'] = rtol
56
+ np.testing.assert_allclose(a, b, **kw)
57
+
58
+
59
+ def check_eq(xs, ys, atol=None, rtol=None):
60
+ xs_leaves, xs_tree = jax.tree_flatten(xs)
61
+ ys_leaves, ys_tree = jax.tree_flatten(ys)
62
+ assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}"
63
+ assert jax.tree_util.tree_all(
64
+ jax.tree_multimap(lambda x, y: np.array(x).shape == np.array(y).shape,
65
+ xs_leaves, ys_leaves)), "Leaves' shapes don't match."
66
+ assert jax.tree_multimap(
67
+ functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol),
68
+ xs_leaves, ys_leaves)
69
+
70
+
71
+ def flattened_state_dict(x):
72
+ s = flax.serialization.to_state_dict(x)
73
+ return flax.traverse_util.flatten_dict(s, sep='/')
74
+
75
+
76
+ def tree_shape(x):
77
+ return jax.tree_map(jnp.shape, x)
78
+
79
+
80
+ def tree_equals(x, y):
81
+ return jax.tree_util.tree_all(jax.tree_multimap(operator.eq, x, y))
82
+
83
+
84
+ def _get_multi_adafactor(
85
+ learning_rate: float, step_offset: int,
86
+ adafactor_exclude_from_parameter_scale: Sequence[str]
87
+ ) -> optim.MultiOptimizer:
88
+ """Get adafactor with support for excluding some parameters from scaling."""
89
+
90
+ def _should_not_scale(path):
91
+ return any([s in path for s in adafactor_exclude_from_parameter_scale])
92
+
93
+ scaled_vars = traverse_util.ModelParamTraversal(
94
+ lambda path, _: not _should_not_scale(path))
95
+ unscaled_vars = traverse_util.ModelParamTraversal(
96
+ lambda path, _: _should_not_scale(path))
97
+ scaled_opt = optim.Adafactor(
98
+ learning_rate, decay_rate=0.8, step_offset=step_offset)
99
+ unscaled_opt = optim.Adafactor(
100
+ learning_rate,
101
+ decay_rate=0.8,
102
+ step_offset=step_offset,
103
+ multiply_by_parameter_scale=False)
104
+ return optim.MultiOptimizer((scaled_vars, scaled_opt),
105
+ (unscaled_vars, unscaled_opt))
106
+
107
+
108
+ # Inline test data
109
+
110
+ MODEL_SHAPE = {
111
+ 'decoder': {
112
+ 'decoder_norm': {'scale': [128]},
113
+ 'layers_0': {
114
+ 'encoder_decoder_attention': {
115
+ 'key': {'kernel': [128, 256]},
116
+ 'out': {'kernel': [256, 128]},
117
+ 'query': {'kernel': [128, 256]},
118
+ 'value': {'kernel': [128, 256]}},
119
+ 'mlp': {
120
+ 'wi': {'kernel': [128, 512]},
121
+ 'wo': {'kernel': [512, 128]}},
122
+ 'pre_cross_attention_layer_norm': {'scale': [128]},
123
+ 'pre_mlp_layer_norm': {'scale': [128]},
124
+ 'pre_self_attention_layer_norm': {'scale': [128]},
125
+ 'self_attention': {
126
+ 'key': {'kernel': [128, 256]},
127
+ 'out': {'kernel': [256, 128]},
128
+ 'query': {'kernel': [128, 256]},
129
+ 'value': {'kernel': [128, 256]}}},
130
+ 'layers_1': {
131
+ 'encoder_decoder_attention': {
132
+ 'key': {'kernel': [128, 128]},
133
+ 'out': {'kernel': [128, 128]},
134
+ 'query': {'kernel': [128, 128]},
135
+ 'value': {'kernel': [128, 128]}},
136
+ 'mlp': {
137
+ 'wi': {'kernel': [128, 512]},
138
+ 'wo': {'kernel': [512, 128]}},
139
+ 'pre_cross_attention_layer_norm': {'scale': [128]},
140
+ 'pre_mlp_layer_norm': {'scale': [128]},
141
+ 'pre_self_attention_layer_norm': {'scale': [128]},
142
+ 'self_attention': {
143
+ 'key': {'kernel': [128, 256]},
144
+ 'out': {'kernel': [256, 128]},
145
+ 'query': {'kernel': [128, 256]},
146
+ 'value': {'kernel': [128, 256]}}},
147
+ 'relpos_bias': {'rel_embedding': [2, 32]}},
148
+ 'encoder': {
149
+ 'encoder_norm': {'scale': [128]},
150
+ 'layers_0': {
151
+ 'attention': {
152
+ 'key': {'kernel': [128, 256]},
153
+ 'out': {'kernel': [256, 128]},
154
+ 'query': {'kernel': [128, 256]},
155
+ 'value': {'kernel': [128, 256]}},
156
+ 'mlp': {
157
+ 'wi': {'kernel': [128, 512]},
158
+ 'wo': {'kernel': [512, 128]}},
159
+ 'pre_attention_layer_norm': {'scale': [128]},
160
+ 'pre_mlp_layer_norm': {'scale': [128]}},
161
+ 'layers_1': {
162
+ 'attention': {
163
+ 'key': {'kernel': [128, 256]},
164
+ 'out': {'kernel': [256, 128]},
165
+ 'query': {'kernel': [128, 256]},
166
+ 'value': {'kernel': [128, 256]}},
167
+ 'mlp': {
168
+ 'wi': {'kernel': [128, 512]},
169
+ 'wo': {'kernel': [512, 128]}},
170
+ 'pre_attention_layer_norm': {'scale': [128]},
171
+ 'pre_mlp_layer_norm': {'scale': [128]}},
172
+ 'relpos_bias': {'rel_embedding': [2, 32]}},
173
+ 'token_embedder': {'embedding': [32128, 128]}} # pyformat: disable
174
+
175
+
176
+ class AdafactorTest(parameterized.TestCase):
177
+
178
+ # Classic Adafactor Behavior Tests
179
+
180
+ def test_2D_simple(self):
181
+ x = {'a': jnp.ones((24, 16))}
182
+ opt_def = adafactor.Adafactor(min_dim_size_to_factor=8)
183
+ optimizer = opt_def.create(x)
184
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
185
+ ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)}
186
+ self.assertTrue(tree_equals(shapes, ref))
187
+
188
+ def test_2D_simple_nofactor(self):
189
+ x = {'a': jnp.ones((24, 16))}
190
+ opt_def = adafactor.Adafactor(min_dim_size_to_factor=32)
191
+ optimizer = opt_def.create(x)
192
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
193
+ ref = {'a/m': (1,), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)}
194
+ self.assertTrue(tree_equals(shapes, ref))
195
+
196
+ def test_2D_simple_nofactor_momentum(self):
197
+ x = {'a': jnp.ones((24, 16))}
198
+ opt_def = adafactor.Adafactor(min_dim_size_to_factor=32, beta1=0.1)
199
+ optimizer = opt_def.create(x)
200
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
201
+ ref = {'a/m': (24, 16), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)}
202
+ self.assertTrue(tree_equals(shapes, ref))
203
+
204
+ def test_3D_simple(self):
205
+ x = {'a': jnp.ones((24, 4, 16))}
206
+ factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),))
207
+ opt_def = adafactor.Adafactor(
208
+ min_dim_size_to_factor=8, factor_map=factor_map)
209
+ optimizer = opt_def.create(x)
210
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
211
+ ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)}
212
+ self.assertTrue(tree_equals(shapes, ref))
213
+
214
+ def test_init_state(self):
215
+ params = {'x': np.zeros((3, 2))}
216
+ optimizer_def = adafactor.Adafactor(
217
+ learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=0)
218
+ state = optimizer_def.init_state(params)
219
+
220
+ expected_hyper_params = _AdafactorHyperParams(0.1, True, True, None, 0.8, 0,
221
+ 1.0, None, 0, 1e-30, 1e-3)
222
+ self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
223
+ expected_state = OptimizerState(
224
+ 0, {
225
+ 'x':
226
+ _AdafactorParamState(
227
+ np.zeros((2,)), np.zeros((3,)), np.zeros(
228
+ (1,)), np.zeros((1,)))
229
+ })
230
+ check_eq(state, expected_state)
231
+
232
+ # unfactorized
233
+ optimizer_def = adafactor.Adafactor(
234
+ learning_rate=0.1, decay_rate=0.8, beta1=0.0, min_dim_size_to_factor=32)
235
+ state = optimizer_def.init_state(params)
236
+
237
+ expected_hyper_params = _AdafactorHyperParams(0.1, True, True, 0.0, 0.8, 0,
238
+ 1.0, None, 32, 1e-30, 1e-3)
239
+ self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
240
+ expected_state = OptimizerState(
241
+ 0, {
242
+ 'x':
243
+ _AdafactorParamState(
244
+ np.zeros((1,)), np.zeros((1,)), np.zeros(
245
+ (3, 2)), np.zeros((3, 2)))
246
+ })
247
+ check_eq(state, expected_state)
248
+
249
+ def test_apply_gradient(self):
250
+ optimizer_def = adafactor.Adafactor(
251
+ learning_rate=0.1, decay_rate=0.8, min_dim_size_to_factor=0)
252
+ params = {'x': np.ones((3, 2), np.float32)}
253
+ state = OptimizerState(
254
+ 1, {
255
+ 'x':
256
+ _AdafactorParamState(
257
+ np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
258
+ np.zeros((1,)), np.zeros((1,)))
259
+ })
260
+ grads = {'x': np.ones((3, 2), np.float32)}
261
+ new_params, new_state = optimizer_def.apply_gradient(
262
+ optimizer_def.hyper_params, params, state, grads)
263
+ expected_new_state = OptimizerState(
264
+ 2, {
265
+ 'x':
266
+ _AdafactorParamState(
267
+ np.array([0.9574349, 0.9574349]),
268
+ np.array([0.6169143, 0.6169143, 0.6169143]), np.zeros(
269
+ (1,)), np.zeros((1,)))
270
+ })
271
+ expected_new_params = {'x': 0.9 * np.ones((3, 2))}
272
+ check_eq(new_params, expected_new_params)
273
+ check_eq(new_state, expected_new_state, rtol=1e-6)
274
+
275
+ # unfactored w momentum
276
+ optimizer_def = adafactor.Adafactor(
277
+ learning_rate=0.1, beta1=0.0, decay_rate=0.8, min_dim_size_to_factor=32)
278
+ params = {'x': np.ones((3, 2), np.float32)}
279
+ state = OptimizerState(
280
+ 1, {
281
+ 'x':
282
+ _AdafactorParamState(
283
+ np.zeros(1,), np.zeros(1,), 0.5 * np.ones(
284
+ (3, 2)), np.zeros((3, 2)))
285
+ })
286
+ grads = {'x': np.ones((3, 2), np.float32)}
287
+ new_params, new_state = optimizer_def.apply_gradient(
288
+ optimizer_def.hyper_params, params, state, grads)
289
+ expected_new_params = {'x': 0.9 * np.ones((3, 2))}
290
+ check_eq(new_params, expected_new_params)
291
+ expected_new_state = OptimizerState(
292
+ 2, {
293
+ 'x':
294
+ _AdafactorParamState(
295
+ np.array([0.0]), np.array([0.0]), 0.787174 * np.ones(
296
+ (3, 2)), 0.1 * np.ones((3, 2)))
297
+ })
298
+ check_eq(new_state, expected_new_state, rtol=1e-6)
299
+
300
+ def test_apply_gradient_with_global_norm_clipping(self):
301
+ optimizer_def = adafactor.Adafactor(
302
+ learning_rate=0.1,
303
+ decay_rate=0.8,
304
+ min_dim_size_to_factor=0,
305
+ global_norm_clip_threshold=1.0)
306
+ params = {'x': np.ones((3, 2), np.float32)}
307
+ state = OptimizerState(
308
+ 1, {
309
+ 'x':
310
+ _AdafactorParamState(
311
+ np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
312
+ np.zeros((1,)), np.zeros((1,)))
313
+ })
314
+ grads = {'x': np.ones((3, 2), np.float32)}
315
+ new_params, new_state = optimizer_def.apply_gradient(
316
+ optimizer_def.hyper_params, params, state, grads)
317
+ expected_new_state = OptimizerState(
318
+ 2, {
319
+ 'x':
320
+ _AdafactorParamState(
321
+ np.array([0.478811, 0.478811]),
322
+ np.array([0.13829, 0.13829, 0.13829]), np.zeros(
323
+ (1,)), np.zeros((1,)))
324
+ })
325
+ expected_new_params = {'x': 0.9 * np.ones((3, 2))}
326
+ check_eq(new_params, expected_new_params)
327
+ check_eq(new_state, expected_new_state, rtol=1e-6)
328
+
329
+ def test_factorizes(self):
330
+ params = {'x': np.zeros((64, 64))}
331
+ optimizer_def = adafactor.Adafactor(
332
+ learning_rate=0.1,
333
+ decay_rate=0.8,
334
+ beta1=None,
335
+ min_dim_size_to_factor=32)
336
+ state = optimizer_def.init_state(params)
337
+ self.assertEqual(state.param_states['x'].v.shape, (1,))
338
+ self.assertEqual(state.param_states['x'].m.shape, (1,))
339
+ self.assertEqual(state.param_states['x'].v_row.shape, (64,))
340
+ self.assertEqual(state.param_states['x'].v_col.shape, (64,))
341
+
342
+ params = {'x': np.zeros((31, 64))}
343
+ optimizer_def = adafactor.Adafactor(
344
+ learning_rate=0.1,
345
+ decay_rate=0.8,
346
+ beta1=None,
347
+ min_dim_size_to_factor=32)
348
+ state = optimizer_def.init_state(params)
349
+ self.assertEqual(state.param_states['x'].v.shape, (31, 64))
350
+ self.assertEqual(state.param_states['x'].m.shape, (1,))
351
+ self.assertEqual(state.param_states['x'].v_row.shape, (1,))
352
+ self.assertEqual(state.param_states['x'].v_col.shape, (1,))
353
+
354
+ # Manually specified factorization rules tests.
355
+
356
+ @parameterized.parameters(
357
+ {'rule': (_ROW, _COL)},
358
+ {'rule': (_COL, _ROW)},
359
+ )
360
+ def test_2D_ignore_specified_factor_rule(self, rule):
361
+ x = {'a': jnp.ones((24, 16))}
362
+ factor_map = adafactor.HParamMap((('a', rule),))
363
+ opt_def = adafactor.Adafactor(
364
+ min_dim_size_to_factor=8, factor_map=factor_map)
365
+ optimizer = opt_def.create(x)
366
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
367
+ # Since the param is 2D, the explicit factor rule is ignored and we fall
368
+ # back to heuristics where v_row corresponds to the smaller dim.
369
+ ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)}
370
+ self.assertTrue(tree_equals(shapes, ref))
371
+
372
+ def test_3D_simple_manual_rules(self):
373
+ x = {'a': jnp.ones((24, 4, 16))}
374
+
375
+ factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),))
376
+ opt_def = adafactor.Adafactor(
377
+ min_dim_size_to_factor=8, factor_map=factor_map)
378
+ optimizer = opt_def.create(x)
379
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
380
+ ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)}
381
+ self.assertTrue(tree_equals(shapes, ref))
382
+
383
+ factor_map = adafactor.HParamMap((('a', (_ROW, _BATCH, _COL)),))
384
+ opt_def = adafactor.Adafactor(
385
+ min_dim_size_to_factor=8, factor_map=factor_map)
386
+ optimizer = opt_def.create(x)
387
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
388
+ ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (4, 16), 'a/v_row': (24, 4)}
389
+ self.assertTrue(tree_equals(shapes, ref))
390
+
391
+ factor_map = adafactor.HParamMap((('a', (_COL, _ROW, _ROW)),))
392
+ opt_def = adafactor.Adafactor(
393
+ min_dim_size_to_factor=8, factor_map=factor_map)
394
+ optimizer = opt_def.create(x)
395
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
396
+ ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (4, 16)}
397
+ self.assertTrue(tree_equals(shapes, ref))
398
+
399
+ factor_map = adafactor.HParamMap((('a', (_COL, _COL, _ROW)),))
400
+ opt_def = adafactor.Adafactor(
401
+ min_dim_size_to_factor=8, factor_map=factor_map)
402
+ optimizer = opt_def.create(x)
403
+ shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
404
+ ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (16,)}
405
+ self.assertTrue(tree_equals(shapes, ref))
406
+
407
+ def test_standard_factor_rules(self):
408
+ # one-off test to double-check that we're following the previous
409
+ # heuristic convention for rows/columns.
410
+ def test_standard_factor_rules():
411
+ token_embedding = (_COL, _ROW)
412
+ attn_qkv = (_ROW, _COL)
413
+ attn_out = (_COL, _ROW)
414
+ mlp_in = (_ROW, _COL)
415
+ mlp_out = (_COL, _ROW)
416
+ return ((r'_layer_norm/(bias|scale)',
417
+ None), (r'(encoder|decoder)_norm/(bias|scale)', None),
418
+ (r'(encoder_decoder_|self_|\b)attention/(query|key|value)/kernel',
419
+ attn_qkv), (r'(encoder_decoder_|self_|\b)attention/out/kernel',
420
+ attn_out), (r'mlp/DenseGeneral_\d+/bias', None),
421
+ (r'mlp/wi(_\d+)?/kernel', mlp_in), (r'mlp/wo/kernel', mlp_out),
422
+ (r'\brelpos_bias', None), (r'token_embedder', token_embedding),
423
+ (r'.*', adafactor.HEURISTIC_RULE))
424
+
425
+ # create fake model parameters
426
+ k = jax.random.PRNGKey(0)
427
+ params = jax.tree_map(
428
+ lambda shape: jax.random.uniform(k, shape),
429
+ MODEL_SHAPE,
430
+ is_leaf=lambda x: isinstance(x, list))
431
+ # make traditional adafactor state with heuristic
432
+ factor_map1 = adafactor.HParamMap(((r'.*', adafactor.HEURISTIC_RULE),))
433
+ optimizer_def1 = adafactor.Adafactor(
434
+ 0.1,
435
+ decay_rate=0.8,
436
+ step_offset=0,
437
+ multiply_by_parameter_scale=True,
438
+ factor_map=factor_map1)
439
+ optimizer1 = optimizer_def1.create(params)
440
+ # make traditional adafactor state with explicit rules
441
+ factor_map2 = adafactor.HParamMap(test_standard_factor_rules())
442
+ optimizer_def2 = adafactor.Adafactor(
443
+ 0.1,
444
+ decay_rate=0.8,
445
+ step_offset=0,
446
+ multiply_by_parameter_scale=True,
447
+ factor_map=factor_map2)
448
+ optimizer2 = optimizer_def2.create(params)
449
+ # are they the same?
450
+ check_eq(optimizer1.state.param_states, optimizer2.state.param_states)
451
+
452
+ @parameterized.parameters(
453
+ {'shape': (64, 64)},
454
+ {'shape': (64, 132)},
455
+ {'shape': (132, 64)},
456
+ {'shape': (132, 132)},
457
+ {'shape': (132, 140)},
458
+ {'shape': (140, 132)},
459
+ )
460
+ def test_no_factor_map_equivalence(self, shape):
461
+ k = random.PRNGKey(0)
462
+ k1, k2 = random.split(k)
463
+ p = {'a': random.uniform(k1, shape)}
464
+ g = {'a': random.uniform(k2, shape)}
465
+
466
+ orig_opt = optim.Adafactor(0.1).create(p)
467
+ new_opt = adafactor.Adafactor(0.1, factor_map=None).create(p)
468
+ check_eq(orig_opt.state_dict(), new_opt.state_dict())
469
+
470
+ orig_opt1 = orig_opt.apply_gradient(g)
471
+ new_opt1 = new_opt.apply_gradient(g)
472
+ check_eq(orig_opt1.state_dict(), new_opt1.state_dict())
473
+
474
+ @parameterized.parameters({
475
+ 'shape': (128, 128),
476
+ 'rule': (_ROW, _COL)
477
+ }, {
478
+ 'shape': (132, 128),
479
+ 'rule': (_COL, _ROW)
480
+ }, {
481
+ 'shape': (128, 132),
482
+ 'rule': (_ROW, _COL)
483
+ })
484
+ def test_simple_equivalence(self, shape, rule):
485
+ k = random.PRNGKey(0)
486
+ k1, k2 = random.split(k)
487
+ k3, k4 = random.split(k1)
488
+ k5, k6 = random.split(k2)
489
+
490
+ p = {'a': random.uniform(k3, shape), 'b': random.uniform(k4, shape)}
491
+ g = {'a': random.uniform(k5, shape), 'b': random.uniform(k6, shape)}
492
+
493
+ orig_opt = optim.Adafactor(0.1).create(p)
494
+ factor_map = adafactor.HParamMap(
495
+ rules=((('a'), rule), ('.*', adafactor.HEURISTIC_RULE)))
496
+ new_opt = adafactor.Adafactor(0.1, factor_map=factor_map).create(p)
497
+ check_eq(orig_opt.state_dict(), new_opt.state_dict())
498
+
499
+ orig_opt1 = orig_opt.apply_gradient(g)
500
+ new_opt1 = new_opt.apply_gradient(g)
501
+ check_eq(orig_opt1.state_dict(), new_opt1.state_dict())
502
+
503
+ @parameterized.parameters({'shape': (64, 64)}, {'shape': (132, 132)})
504
+ def test_multiply_by_parameter_scale_equivalence(self, shape):
505
+ # Use large parameter values to magnify the parameter scaling effect.
506
+ p = {'a': np.random.randn(*shape) * 100, 'b': np.random.randn(*shape) * 100}
507
+ g = {'a': np.random.randn(*shape), 'b': np.random.randn(*shape)}
508
+ orig_opt = _get_multi_adafactor(
509
+ 3.0, 0, adafactor_exclude_from_parameter_scale=('a',)).create(p)
510
+ scaling_map = adafactor.HParamMap([('a', False), ('.*', True)])
511
+ new_opt = adafactor.Adafactor(
512
+ 3.0, multiply_by_parameter_scale=scaling_map).create(p)
513
+ check_eq(orig_opt.state_dict(), new_opt.state_dict())
514
+
515
+ orig_opt1 = orig_opt.apply_gradient(g)
516
+ new_opt1 = new_opt.apply_gradient(g)
517
+ check_eq(orig_opt1.state_dict(), new_opt1.state_dict())
518
+
519
+ def test_3d_without_factor_map(self):
520
+ x = {'a': jnp.ones((24, 4, 16))}
521
+ opt_def = adafactor.Adafactor(factor_map=None)
522
+ with self.assertRaises(ValueError):
523
+ _ = opt_def.create(x)
524
+
525
+
526
+ if __name__ == '__main__':
527
+ absltest.main()
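For reference, the global-norm clipping performed in `apply_gradient` when `global_norm_clip_threshold` is set can be reproduced in isolation. This is a standalone illustration, not part of the diff:

import jax.numpy as jnp

threshold = 1.0
grads_flat = [jnp.ones((3, 2)), 2.0 * jnp.ones((4,))]
# Global norm over all gradient leaves.
global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in grads_flat))
# Rescale so the global norm never exceeds the threshold (a no-op when it is
# already below the threshold).
scale = threshold * jnp.minimum(1.0 / threshold, 1.0 / global_norm)
grads_flat = [g * scale for g in grads_flat]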
t5x/checkpoint_importer.py ADDED
@@ -0,0 +1,485 @@
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """T5 Checkpoint Importer."""
16
+
17
+ import asyncio
18
+ from concurrent.futures import thread
19
+ import re
20
+ from typing import Any, Callable, Mapping, MutableMapping, Optional, Union
21
+
22
+ from flax import traverse_util
23
+ import jax
24
+ from jax import numpy as jnp
25
+ import numpy as np
26
+ import orbax.checkpoint
27
+ import tensorflow as tf
28
+ import tensorstore as ts
29
+
30
+ # TODO(b/233659813): Cleanup clients depending on t5x.checkpoint_importer for
31
+ # LazyArray. Reconcile divergence in subclass implementation when possible.
32
+ LazyArray = orbax.checkpoint.lazy_array.LazyArray
33
+
34
+
35
+ # TODO(brianlester): The choice between using a `LazyThreadPoolArray` or a
36
+ # `LazyAwaitableArray` depends on whether the user-provided `get_fn` is blocking
37
+ # or async, respectively; if we could detect which it is, we could automatically
38
+ # proxy to the correct subclass. We cannot detect whether `get_fn` is a lambda
39
+ # that wraps an async call, so this isn't possible yet. Add this dispatch once we
40
+ # are able to detect that; python3.8+ can detect async for partial'ed functions
41
+ # but not lambdas.
42
+ class LazyThreadPoolArray(LazyArray):
43
+ """Lazily and asynchronously loads an array when the `get_fn` blocks."""
44
+
45
+ # Uses a global threadpool to enable asynchronous loading.
46
+ executor = thread.ThreadPoolExecutor()
47
+
48
+ def get_async(self) -> asyncio.Future:
49
+ return asyncio.wrap_future(self.executor.submit(self.get))
50
+
51
+ def get(self) -> np.ndarray:
52
+ arr = self._get_fn()
53
+ if arr.dtype != self.dtype:
54
+ arr = arr.astype(self.dtype)
55
+ return arr
56
+
57
+
58
+ class LazyAwaitableArray(LazyArray):
59
+ """Lazily and asynchronously loads an array when the `get_fn` is async.
60
+
61
+ Note:
62
+ The synchronous load method `.get` requires the asyncio event loop and
63
+ calling `.run_until_complete`. This is not supported when the event loop is
64
+ already running (for example, from inside another async function).
65
+
66
+ Note:
67
+ Currently, this class has a few helper methods for creating a
68
+ LazyAwaitableArray when the input could be either an array, or a TensorStore
69
+ spec. Most people use async code when dealing with TensorStore so the
70
+ classmethods have been placed here. When someone eventually uses a blocking
71
+ function to read from TensorStore they can be moved to the LazyArray base
72
+ class.
73
+ """
74
+
75
+ def get_async(self) -> asyncio.Future:
76
+
77
+ async def _get_and_cast():
78
+ # Pytype has a false positive here, where it treats our _get_fn (_read_ts
79
+ # in this case) as having a return type of `np.ndarray` instead of
80
+ # wrapping it in an Awaitable. Related to this bug
81
+ # https://github.com/google/pytype/issues/527
82
+ arr = await self._get_fn() # pytype: disable=bad-return-type
83
+ if arr.dtype != self.dtype:
84
+ arr = arr.astype(self.dtype)
85
+ return arr
86
+
87
+ return asyncio.ensure_future(_get_and_cast())
88
+
89
+ def get(self) -> np.ndarray:
90
+ loop = asyncio.get_event_loop()
91
+ return loop.run_until_complete(self.get_async())
92
+
93
+ @classmethod
94
+ def from_tensor_store_spec(
95
+ cls,
96
+ ts_spec: ts.Spec,
97
+ get_fn: Callable[[], np.ndarray],
98
+ dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
99
+ """Create a LazyAwaitableArray based on a tensorstore.Spec."""
100
+ ts_spec = ts_spec.to_json()
101
+ shape = ts_spec['metadata']['shape']
102
+ if dtype is None:
103
+ dtype = jnp.dtype(ts_spec['dtype'])
104
+ else:
105
+ dtype = jnp.dtype(dtype)
106
+ # v2 T5X checkpoints use uint16 as the TensorStore datatype and then store
107
+ # the bfloat16 bytes in the 16 bits that uint16 provides (no actual cast).
108
+ # When reading the dtype from the TensorStore, if we keep the dtype of these
109
+ # v2 checkpoints as np.uint16 then the _get_fn (which has a possible cast to
110
+ # support the `restore_dtype` parameter for the checkpointer) will actually
111
+ # cast the bfloat16 values to uint16, generally resulting in an array of all
112
+ # zeros. This check avoids the actual cast to uint16 by replacing the dtype.
113
+ if dtype == np.uint16:
114
+ dtype = jnp.bfloat16
115
+ return cls(shape, dtype, get_fn)
116
+
117
+ @classmethod
118
+ def from_array(cls,
119
+ array: np.ndarray,
120
+ get_fn: Callable[[], np.ndarray],
121
+ dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
122
+ """Create a LazyAwaitableArray based on an array or python number."""
123
+ if dtype is None:
124
+ dtype = array.dtype
125
+ else:
126
+ dtype = jnp.dtype(dtype)
127
+ return cls(array.shape, dtype, get_fn)
128
+
129
+ @classmethod
130
+ def from_tensor_store_spec_or_array(
131
+ cls,
132
+ maybe_ts_spec: Union[ts.Spec, np.ndarray],
133
+ get_fn: Callable[[], np.ndarray],
134
+ dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray':
135
+ """Create a LazyAwaitableArray based on an array or a tensorstore.Spec."""
136
+ if isinstance(maybe_ts_spec, ts.Spec):
137
+ return cls.from_tensor_store_spec(maybe_ts_spec, get_fn, dtype=dtype)
138
+ return cls.from_array(maybe_ts_spec, get_fn, dtype=dtype)
139
+
140
+
141
+ class CheckpointTranslator:
142
+ """Utility class for defining mapping rules from one flatdict to another.
143
+
144
+ We assume a checkpoint is loaded as a dictionary with flattened keys of the
145
+ form: 'name0/name1/name2/.../nameN'
146
+
147
+ A rule is added with the 'add' decorator, which takes a regex matching rule
148
+ and wraps a conversion function, feeding it (opts, key, val, **regex_groups)
149
+ where opts is a dict containing apply-time keyword options for use by the
150
+ conversion functions.
151
+ """
152
+
153
+ def __init__(self):
154
+ self.rules = []
155
+
156
+ def add(self, pattern):
157
+ """Adds a new keyval conversion rule.
158
+
159
+ Args:
160
+ pattern: regex with capture groups for matching given sets of model
161
+ variables. We terminate all regexes with '$' to force complete matches.
162
+
163
+ Returns:
164
+ Translation function decorator for associating with the provided
165
+ pattern.
166
+ """
167
+
168
+ def register_translation_fn_decorator(fn):
169
+ # We force a complete match by adding end-of-string match.
170
+ self.rules.append((re.compile(pattern + '$'), fn))
171
+ return fn
172
+
173
+ return register_translation_fn_decorator
174
+
175
+ def apply(self, flatdict, **opts):
176
+ """Applies rules to a flattened dictionary.
177
+
178
+ Args:
179
+ flatdict: flat-key dictionary of variables.
180
+ **opts: additional config options for translation rules supplied at
181
+ application time.
182
+
183
+ Returns:
184
+ Checkpoint data with translated key/values in flat-key dict format.
185
+ """
186
+ new_dict = {}
187
+ unmatched = {}
188
+ for k, v in flatdict.items():
189
+ matched = False
190
+ for rule_pat, rule_fn in self.rules:
191
+ if rule_pat.match(k):
192
+ groups = rule_pat.match(k).groups()
193
+ new_k, new_v = rule_fn(opts, k, v, *groups)
194
+ if new_k is not None:
195
+ new_dict[new_k] = new_v
196
+ matched = True
197
+ break
198
+ if not matched:
199
+ unmatched[k] = v
200
+
201
+ # We force every key-value pair in the checkpoint to have a rule associated with
202
+ # it.
203
+ if unmatched:
204
+ raise ValueError('Unmapped tensor keys exist: %s' % unmatched)
205
+
206
+ return new_dict
207
+
208
+
209
+ # Create a translation rule set for importing T5 & T5.1.1 model checkpoints.
210
+ # -----------------------------------------------------------------------------
211
+ t5_importer = CheckpointTranslator()
212
+
213
+ # Name mappings.
214
+ SLOT_MAP = {'_slot_vc': 'v_col', '_slot_vr': 'v_row', '_slot_v': 'v'}
215
+ TOWER_MAP = {'transformer': 'decoder'}
216
+
217
+
218
+ @t5_importer.add(r'global_step')
219
+ def global_step(opts, key, val):
220
+ del opts, key
221
+ return 'state/step', val.astype(np.int32).get() if isinstance(
222
+ val, LazyArray) else val
223
+
224
+
225
+ @t5_importer.add(r'shared/embedding(\w*)')
226
+ def shared_embeddings(opts, key, val, slot):
227
+ del opts, key
228
+ prefix = 'state/param_states' if slot else 'target'
229
+ suffix = '/' + SLOT_MAP[slot] if slot else ''
230
+ newkey = f'{prefix}/token_embedder/embedding{suffix}'
231
+ return newkey, val
232
+
233
+
234
+ @t5_importer.add(r'(encoder|decoder|transformer)/embedding(\w*)')
235
+ def separate_embeddings(opts, key, val, encdec, slot):
236
+ del opts, key
237
+ prefix = 'state/param_states' if slot else 'target'
238
+ suffix = '/' + SLOT_MAP[slot] if slot else ''
239
+ encdec = TOWER_MAP.get(encdec, encdec)
240
+ newkey = f'{prefix}/{encdec}/token_embedder/embedding{suffix}'
241
+ return newkey, val
242
+
243
+
244
+ # In the Mesh TensorFlow T5 code, relative_attention_bias always occurs in layer
245
+ # 0 because SelfAttention precedes other sublayers within the same block.
246
+ @t5_importer.add(
247
+ r'(encoder|decoder|transformer)/block_(\d+)/layer_000/SelfAttention/relative_attention_bias(\w*)'
248
+ )
249
+ def rel_embeddings(opts, key, val, encdec, blocknum, slot):
250
+ """Process relpos bias assuming that they are not shared across layers."""
251
+ del opts, key
252
+ prefix = 'state/param_states' if slot else 'target'
253
+ suffix = '/' + SLOT_MAP[slot] if slot else ''
254
+ blocknum = int(blocknum)
255
+ encdec = TOWER_MAP.get(encdec, encdec)
256
+ # At this point, we can't determine whether the relpos bias was shared across
257
+ # layers or not. We first assume that it was not shared. During post
258
+ # processing, we remove the layers_0 scope if it was shared.
259
+ newkey = f'{prefix}/{encdec}/layers_{blocknum}/relpos_bias/rel_embedding{suffix}'
260
+ return newkey, val
261
+
262
+
263
+ @t5_importer.add(
264
+ r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/(SelfAttention|EncDecAttention)/(q|k|v|o)(\w*)'
265
+ )
266
+ def attention_layers(opts, key, val, encdec, blocknum, attntype, qkvo, slot):
267
+ """Process attention layers."""
268
+ del opts, key
269
+ prefix = 'state/param_states' if slot else 'target'
270
+ suffix = '/' + SLOT_MAP[slot] if slot else ''
271
+ blocknum = int(blocknum)
272
+ encdec = TOWER_MAP.get(encdec, encdec)
273
+ matrix = {'q': 'query', 'k': 'key', 'v': 'value', 'o': 'out'}[qkvo]
274
+
275
+ if encdec == 'encoder':
276
+ attntype = 'attention'
277
+ else:
278
+ attntype = {
279
+ 'SelfAttention': 'self_attention',
280
+ 'EncDecAttention': 'encoder_decoder_attention'
281
+ }[attntype]
282
+ newkey = f'{prefix}/{encdec}/layers_{blocknum}/{attntype}/{matrix}/kernel{suffix}'
283
+ return newkey, val
284
+
285
+
286
+ @t5_importer.add(
287
+ r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/DenseReluDense/(wi|wo)(?:_(\d+))?/kernel(\w*)'
288
+ )
289
+ def mlpblock(opts, key, val, encdec, blocknum, io_name, io_num, slot):
290
+ """Process MLP blocks."""
291
+ del opts, key
292
+ prefix = 'state/param_states' if slot else 'target'
293
+ suffix = '/' + SLOT_MAP[slot] if slot else ''
294
+ blocknum = int(blocknum)
295
+ encdec = TOWER_MAP.get(encdec, encdec)
296
+ io_num = f'_{io_num}' if io_num else ''
297
+ newkey = f'{prefix}/{encdec}/layers_{blocknum}/mlp/{io_name}{io_num}/kernel{suffix}'
298
+ return newkey, val
299
+
300
+
301
+ @t5_importer.add(
302
+ r'(encoder|decoder|transformer)/block_(\d+)/layer_(\d+)/(?:layer|rms)_norm/scale(\w*)'
303
+ )
304
+ def layernorms(opts, key, val, encdec, blocknum, lyrnum, slot):
305
+ """Process layer norms assuming that they are pre-layernorms."""
306
+ del opts, key
307
+ prefix = 'state/param_states' if slot else 'target'
308
+ suffix = '/' + SLOT_MAP[slot] if slot else ''
309
+ lyrnum = int(lyrnum)
310
+
311
+ if encdec == 'transformer':
312
+ layernorm_type = ['pre_self_attention_layer_norm',
313
+ 'pre_mlp_layer_norm'][lyrnum]
314
+
315
+ elif encdec == 'encoder':
316
+ layernorm_type = ['pre_attention_layer_norm', 'pre_mlp_layer_norm'][lyrnum]
317
+ else: # decoder
318
+ layernorm_type = [
319
+ 'pre_self_attention_layer_norm', 'pre_cross_attention_layer_norm',
320
+ 'pre_mlp_layer_norm'
321
+ ][lyrnum]
322
+
323
+ encdec = TOWER_MAP.get(encdec, encdec)
324
+ newkey = f'{prefix}/{encdec}/layers_{int(blocknum)}/{layernorm_type}/scale{suffix}'
325
+ return newkey, val
326
+
327
+
328
+ @t5_importer.add(
329
+ r'(encoder|decoder|transformer)/(?:final_layer|rms)_norm/scale(\w*)')
330
+ def final_layernorms(opts, key, val, encdec, slot):
331
+ """Process final layer norms."""
332
+ del opts, key
333
+ prefix = 'state/param_states' if slot else 'target'
334
+ suffix = '/' + SLOT_MAP[slot] if slot else ''
335
+ norm = {
336
+ 'encoder': 'encoder_norm',
337
+ 'decoder': 'decoder_norm',
338
+ 'transformer': 'decoder_norm'
339
+ }[encdec]
340
+ encdec = TOWER_MAP.get(encdec, encdec)
341
+ newkey = f'{prefix}/{encdec}/{norm}/scale{suffix}'
342
+ return newkey, val
343
+
344
+
345
+ @t5_importer.add(r'(?:decoder|transformer)/logits/kernel(\w*)')
346
+ def final_logits(opts, key, val, slot):
347
+ del opts, key
348
+ prefix = 'state/param_states' if slot else 'target'
349
+ suffix = '/' + SLOT_MAP[slot] if slot else ''
350
+ newkey = f'{prefix}/decoder/logits_dense/kernel{suffix}'
351
+ return newkey, val
352
+
353
+
354
+ def _add_missing_param_states(t5_data):
355
+ """Add dummy slots that Flax Adafactor requires but TF does not."""
356
+ updates = {}
357
+ for k in t5_data:
358
+ if k.startswith('target'):
359
+ state_leaf = 'state/param_states' + k[len('target'):]
360
+ updates[state_leaf + '/m'] = np.zeros((1,), np.float32)
361
+ if state_leaf + '/v' in t5_data:
362
+ updates[state_leaf + '/v_row'] = np.zeros((1,), np.float32)
363
+ updates[state_leaf + '/v_col'] = np.zeros((1,), np.float32)
364
+ elif state_leaf + '/v_row' in t5_data:
365
+ updates[state_leaf + '/v'] = np.zeros((1,), np.float32)
366
+ t5_data.update(**updates)
367
+ return t5_data
368
+
369
+
370
+ def _maybe_correct_relpos_bias(t5_data):
371
+ """Correct the relpos_bias format if it is shared across layers."""
372
+ max_layer_ind = 0
373
+ for k, v in t5_data.items():
374
+ match = re.search(r'layers_(\d+)/relpos_bias', k)
375
+ if match:
376
+ layer_ind = int(match.groups()[0])
377
+ max_layer_ind = max(max_layer_ind, layer_ind)
378
+
379
+ modified_dict = {}
380
+ if max_layer_ind == 0:
381
+ # Relative position biases are shared across layers
382
+ for k, v in t5_data.items():
383
+ new_k = re.sub(r'layers_\d+/relpos_bias', 'relpos_bias', k)
384
+ modified_dict[new_k] = v
385
+ else:
386
+ # Relative position biases are unique in each layer. No more processing is
387
+ # necessary.
388
+ modified_dict = t5_data
389
+
390
+ return modified_dict
391
+
392
+
393
+ # Load checkpoint, translate, and update flax optimizer and model.
394
+ # -----------------------------------------------------------------------------
395
+ def load_tf_ckpt(path):
396
+ """Load a TF checkpoint as a flat dictionary of numpy arrays."""
397
+ ckpt_reader = tf.train.load_checkpoint(path)
398
+ ckpt_shape_map = ckpt_reader.get_variable_to_shape_map()
399
+ ckpt_dtype_map = ckpt_reader.get_variable_to_dtype_map()
400
+ datamap = { # pylint: disable=g-complex-comprehension
401
+ k: LazyThreadPoolArray(
402
+ s,
403
+ jnp.dtype(ckpt_dtype_map[k].as_numpy_dtype),
404
+ lambda x=k: ckpt_reader.get_tensor(x))
405
+ for k, s in ckpt_shape_map.items()
406
+ }
407
+ return datamap
408
+
409
+
410
+ def _update_state_dict(state_dict: Mapping[str, Any],
411
+ t5_data: MutableMapping[str, LazyArray],
412
+ strict: bool = True) -> Mapping[str, Any]:
413
+ """Update flax optimizer for T5 model.
414
+
415
+ Args:
416
+ state_dict: Optimizer to update with T5 parameters.
417
+ t5_data: T5 model parameters, typically loaded from a checkpoint.
418
+ strict: If True requires that optimizer and t5_data mappings contain the
419
+ same set of names (variables). If False, updating will succeed even if
420
+ t5_data contains variables not in the optimizer. If the optimizer has
421
+ variables not in t5_data, this function will still fail.
422
+
423
+ Returns:
424
+ Updated optimizer.
425
+ """
426
+ flat_state_dict = traverse_util.flatten_dict(state_dict, sep='/')
427
+
428
+ # Remove parameters from the checkpoint not found in the optimizer (this
429
+ # allows us to load checkpoints that contain more parameters than our current
430
+ # model).
431
+ if not strict:
432
+ for k in list(t5_data):
433
+ if k not in flat_state_dict:
434
+ t5_data.pop(k)
435
+
436
+ # Shape check.
437
+ for k, v in t5_data.items():
438
+ if flat_state_dict[k].shape != v.shape:
439
+ raise ValueError(
440
+ f'Variable {k} has shape {v.shape} != {flat_state_dict[k].shape}')
441
+ flat_state_dict = t5_data
442
+ state_dict = traverse_util.unflatten_dict(
443
+ {tuple(k.split('/')): v for k, v in flat_state_dict.items()})
444
+ return state_dict
445
+
446
+
447
+ def restore_from_t5_checkpoint(
448
+ state_dict: Mapping[str, Any],
449
+ path: str,
450
+ lazy_parameters: bool = False,
451
+ strict: bool = True,
452
+ translator: Optional[CheckpointTranslator] = None) -> Mapping[str, Any]:
453
+ """Load T5 checkpoint and update Adafactor optimizer and T5 model from it.
454
+
455
+ We require that the final translated checkpoint structure exactly matches
456
+ that of the Flax Adafactor + Transformer data, up to shape agreement of
457
+ the leaves.
458
+
459
+ Args:
460
+ state_dict: Flax Adafactor Optimizer for T5 transformer encoder-decoder.
461
+ path: a path to checkpoint file or directory.
462
+ lazy_parameters: whether to leave the parameters as LazyArrays to preserve
463
+ memory.
464
+ strict: If True requires that optimizer and t5_data mappings contain the
465
+ same set of names (variables). If False, updating will succeed even if
466
+ t5_data contains variables not in the optimizer. If the optimizer has
467
+ variables not in t5_data, this function will still fail.
468
+ translator: The mapping rules for conversion. If None, then default T5
469
+ conversion rules will be used.
470
+
471
+ Returns:
472
+ Adafactor optimizer updated with parameters and optimizer state from
473
+ T5 checkpoint.
474
+ """
475
+ if translator is None:
476
+ translator = t5_importer
477
+ ckpt_data = load_tf_ckpt(path)
478
+ t5_data = translator.apply(ckpt_data)
479
+ t5_data = _add_missing_param_states(t5_data)
480
+ t5_data = _maybe_correct_relpos_bias(t5_data)
481
+ state_dict = _update_state_dict(state_dict, t5_data, strict=strict)
482
+ if not lazy_parameters:
483
+ state_dict = jax.tree_map(
484
+ lambda x: x.get() if isinstance(x, LazyArray) else x, state_dict)
485
+ return state_dict
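The `CheckpointTranslator` above is driven entirely by regex rules. A minimal sketch of registering and applying a custom rule (the key names are hypothetical, not taken from a real checkpoint):

from t5x import checkpoint_importer

translator = checkpoint_importer.CheckpointTranslator()

@translator.add(r'my_scope/(\w+)/kernel')
def my_kernels(opts, key, val, name):
  # Rename 'my_scope/<name>/kernel' to 'target/<name>/kernel'.
  del opts, key
  return f'target/{name}/kernel', val

flat_ckpt = {'my_scope/dense/kernel': 1}
print(translator.apply(flat_ckpt))  # {'target/dense/kernel': 1}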
t5x/checkpoint_importer_test.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright 2022 The T5X Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tests for t5x.checkpoint_importer."""
+
+ import json
+ import os
+
+ from absl import flags
+ from absl.testing import absltest
+ import jax
+ import numpy as np
+ from t5x import checkpoint_importer
+ import tensorflow as tf
+
+
+ class CheckpointImporterTest(absltest.TestCase):
+
+   def test_rel_embeddings_shared_layers(self):
+     # This represents a ckpt where the Mesh TensorFlow's
+     # transformer_layers.SelfAttention.relative_attention_type = "bias_shared",
+     # i.e., the same relative attention parameters are shared by all layers
+     # within the (en|de)coder.
+     ckpt_data = {
+         'encoder/block_000/layer_000/SelfAttention/relative_attention_bias':
+             1,
+         'decoder/block_000/layer_000/SelfAttention/relative_attention_bias':
+             2,
+         'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v':
+             3,
+     }
+     t5_data = checkpoint_importer.t5_importer.apply(ckpt_data)
+     t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data)
+     expected = {
+         'target/encoder/relpos_bias/rel_embedding': 1,
+         'target/decoder/relpos_bias/rel_embedding': 2,
+         'state/param_states/decoder/relpos_bias/rel_embedding/v': 3,
+     }
+     self.assertEqual(t5_data, expected)
+
+   def test_rel_embeddings_per_layer(self):
+     # This represents a ckpt where the Mesh TensorFlow's
+     # transformer_layers.SelfAttention.relative_attention_type = "bias", i.e.,
+     # each layer has its own relative attention parameters.
+     ckpt_data = {
+         'encoder/block_000/layer_000/SelfAttention/relative_attention_bias':
+             1,
+         'encoder/block_001/layer_000/SelfAttention/relative_attention_bias':
+             2,
+         'decoder/block_000/layer_000/SelfAttention/relative_attention_bias':
+             3,
+         'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v':
+             4,
+         'decoder/block_011/layer_000/SelfAttention/relative_attention_bias':
+             5
+     }
+     t5_data = checkpoint_importer.t5_importer.apply(ckpt_data)
+     t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data)
+     expected = {
+         'target/encoder/layers_0/relpos_bias/rel_embedding': 1,
+         'target/encoder/layers_1/relpos_bias/rel_embedding': 2,
+         'target/decoder/layers_0/relpos_bias/rel_embedding': 3,
+         'state/param_states/decoder/layers_0/relpos_bias/rel_embedding/v': 4,
+         'target/decoder/layers_11/relpos_bias/rel_embedding': 5,
+     }
+     self.assertEqual(t5_data, expected)
+
+
+ if __name__ == '__main__':
+   absltest.main()
t5x/checkpoint_utils.py ADDED
@@ -0,0 +1,91 @@
+ # Copyright 2022 The T5X Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Checkpoint helper functions for managing checkpoints.
+
+ Supports marking checkpoints as pinned to exclude them from the checkpointer
+ removal process.
+ """
+
+ import os
+
+ from absl import logging
+
+ from tensorflow.io import gfile
+
+ # A PINNED file in the checkpoint directory indicates that the checkpoint
+ # should not be removed during the automatic pruning of old checkpoints.
+ _PINNED_CHECKPOINT_FILENAME = 'PINNED'
+
+
+ def pinned_checkpoint_filepath(ckpt_dir: str) -> str:
+   """Full path of the pinned checkpoint file."""
+   return os.path.join(ckpt_dir, _PINNED_CHECKPOINT_FILENAME)
+
+
+ def is_pinned_checkpoint(ckpt_dir: str) -> bool:
+   """Returns whether the checkpoint is pinned, and should NOT be removed."""
+   pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
+   if gfile.exists(pinned_ckpt_file):
+     return True
+   return False
+
+
+ def pin_checkpoint(ckpt_dir: str, txt: str = '1') -> None:
+   """Pins a checkpoint so it does not get deleted by the normal pruning process.
+
+   Creates a PINNED file in the checkpoint directory to indicate the checkpoint
+   should be excluded from the deletion of old checkpoints.
+
+   Args:
+     ckpt_dir: The checkpoint step dir that is to be always kept.
+     txt: Text to be written into the checkpoint's PINNED file.
+   """
+   pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
+   with gfile.GFile(pinned_ckpt_file, 'w') as f:
+     logging.debug('Write %s file : %s.', pinned_ckpt_file, txt)
+     f.write(txt)
+
+
+ def unpin_checkpoint(ckpt_dir: str) -> None:
+   """Removes the pinned status of the checkpoint so it is open for deletion."""
+   if not is_pinned_checkpoint(ckpt_dir):
+     logging.debug('%s is not PINNED. Nothing to do here.', ckpt_dir)
+     return
+   try:
+     pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
+     logging.debug('Remove %s file.', pinned_ckpt_file)
+     gfile.rmtree(pinned_ckpt_file)
+   except IOError:
+     logging.exception('Failed to unpin %s', ckpt_dir)
+
+
+ def remove_checkpoint_dir(ckpt_dir: str) -> None:
+   """Removes the checkpoint dir if it is not pinned."""
+   if not is_pinned_checkpoint(ckpt_dir):
+     logging.info('Deleting checkpoint: %s', ckpt_dir)
+     gfile.rmtree(ckpt_dir)
+   else:
+     logging.info('Keeping pinned checkpoint: %s', ckpt_dir)
+
+
+ def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None:
+   """Removes dataset checkpoints if the checkpoint is not pinned."""
+   if not is_pinned_checkpoint(ckpt_dir):
+     train_ds_pattern = os.path.join(ckpt_dir, train_ds_prefix + '*')
+     logging.info('Deleting dataset checkpoint: %s', train_ds_pattern)
+     for file in gfile.glob(train_ds_pattern):
+       gfile.remove(file)
+   else:
+     logging.info('Keeping pinned checkpoint: %s', ckpt_dir)
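As a rough usage sketch under assumed paths (the step directory below is hypothetical, not part of this commit), a pinned checkpoint survives the pruning helpers while an unpinned one is deleted:

from t5x import checkpoint_utils

ckpt_dir = '/tmp/model_dir/checkpoint_10000'      # hypothetical checkpoint step dir
checkpoint_utils.pin_checkpoint(ckpt_dir)         # writes the PINNED marker file
checkpoint_utils.remove_checkpoint_dir(ckpt_dir)  # kept: directory is pinned
checkpoint_utils.unpin_checkpoint(ckpt_dir)       # removes the PINNED marker
checkpoint_utils.remove_checkpoint_dir(ckpt_dir)  # now deletes the directory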
t5x/checkpoint_utils_test.py ADDED
@@ -0,0 +1,149 @@
+ # Copyright 2022 The T5X Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tests for t5x.checkpoint_utils."""
+
+ import os
+ import traceback
+
+ from absl.testing import absltest
+ from t5x import checkpoint_utils
+ from tensorflow.io import gfile
+
+ TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+
+
+ class CheckpointsUtilsTest(absltest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     self.checkpoints_dir = self.create_tempdir()
+     self.ckpt_dir_path = self.checkpoints_dir.full_path
+     self.pinned_ckpt_file = os.path.join(self.ckpt_dir_path, "PINNED")
+     self.checkpoints_dir.create_file("checkpoint")
+     # Create a `train_ds` file representing the dataset checkpoint.
+     train_ds_basename = "train_ds-00000-of-00001"
+     self.train_ds_file = os.path.join(self.ckpt_dir_path, train_ds_basename)
+     self.checkpoints_dir.create_file(train_ds_basename)
+
+   def test_always_keep_checkpoint_file(self):
+     self.assertEqual(
+         "/path/to/ckpt/dir/PINNED",
+         checkpoint_utils.pinned_checkpoint_filepath("/path/to/ckpt/dir"))
+
+   def test_is_pinned_checkpoint_false_by_default(self):
+     # Ensure regular checkpoint without PINNED file.
+     self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
+
+     # Validate checkpoints are not pinned by default.
+     self.assertFalse(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path))
+
+   def test_is_pinned_checkpoint(self):
+     # Ensure the testdata checkpoint directory is marked as pinned.
+     pinned_ckpt_testdata = os.path.join(TESTDATA, "pinned_ckpt_dir")
+     pinned_file = os.path.join(pinned_ckpt_testdata, "PINNED")
+     self.assertTrue(gfile.exists(pinned_file))
+
+     # Test and validate.
+     self.assertTrue(checkpoint_utils.is_pinned_checkpoint(pinned_ckpt_testdata))
+
+   def test_is_pinned_missing_ckpt(self):
+     self.assertFalse(
+         checkpoint_utils.is_pinned_checkpoint(
+             os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist")))
+
+   def test_pin_checkpoint(self):
+     # Ensure directory isn't already pinned.
+     self.assertFalse(gfile.exists(self.pinned_ckpt_file))
+
+     # Test.
+     checkpoint_utils.pin_checkpoint(self.ckpt_dir_path)
+
+     # Validate.
+     self.assertTrue(gfile.exists(self.pinned_ckpt_file))
+     with open(self.pinned_ckpt_file) as f:
+       self.assertEqual("1", f.read())
+
+   def test_pin_checkpoint_txt(self):
+     checkpoint_utils.pin_checkpoint(self.ckpt_dir_path, "TEXT_IN_PINNED")
+     self.assertTrue(os.path.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
+     with open(self.pinned_ckpt_file) as f:
+       self.assertEqual("TEXT_IN_PINNED", f.read())
+
+   def test_unpin_checkpoint(self):
+     # Mark the checkpoint directory as pinned.
+     self.checkpoints_dir.create_file("PINNED")
+     self.assertTrue(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path))
+
+     # Test.
+     checkpoint_utils.unpin_checkpoint(self.ckpt_dir_path)
+
+     # Validate the "PINNED" checkpoint file got removed.
+     self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
+
+   def test_unpin_checkpoint_does_not_exist(self):
+     missing_ckpt_path = os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist")
+     self.assertFalse(gfile.exists(missing_ckpt_path))
+
+     # Test. Assert does not raise error.
+     try:
+       checkpoint_utils.unpin_checkpoint(missing_ckpt_path)
+     except IOError:
+       # TODO(b/172262005): Remove traceback.format_exc() from the error message.
+       self.fail("Unpin checkpoint failed with: %s" % traceback.format_exc())
+
+   def test_remove_checkpoint_dir(self):
+     # Ensure the checkpoint directory is setup.
+     assert gfile.exists(self.ckpt_dir_path)
+
+     # Test.
+     checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path)
+
+     # Validate the checkpoint directory got removed.
+     self.assertFalse(gfile.exists(self.ckpt_dir_path))
+
+   def test_remove_checkpoint_dir_pinned(self):
+     # Mark the checkpoint directory as pinned so it does not get removed.
+     self.checkpoints_dir.create_file("PINNED")
+
+     # Test.
+     checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path)
+
+     # Validate the checkpoint directory still exists.
+     self.assertTrue(gfile.exists(self.ckpt_dir_path))
+
+   def test_remove_dataset_checkpoint(self):
+     # Ensure the checkpoint directory is setup.
+     assert gfile.exists(self.ckpt_dir_path)
+
+     # Test.
+     checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds")
+
+     # Validate the dataset checkpoint got removed, but the directory remains.
+     self.assertFalse(gfile.exists(self.train_ds_file))
+     self.assertTrue(gfile.exists(self.ckpt_dir_path))
+
+   def test_remove_dataset_checkpoint_pinned(self):
+     # Mark the checkpoint directory as pinned so it does not get removed.
+     self.checkpoints_dir.create_file("PINNED")
+
+     # Test.
+     checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds")
+
+     # Validate the dataset checkpoint and the directory still exist.
+     self.assertTrue(gfile.exists(self.train_ds_file))
+     self.assertTrue(gfile.exists(self.ckpt_dir_path))
+
+ if __name__ == "__main__":
+   absltest.main()